diff --git a/static/js/api/modelApiFactory.js b/static/js/api/modelApiFactory.js index 51c0c055..154b103b 100644 --- a/static/js/api/modelApiFactory.js +++ b/static/js/api/modelApiFactory.js @@ -1,17 +1,17 @@ import { LoraApiClient } from './loraApi.js'; import { CheckpointApiClient } from './checkpointApi.js'; import { EmbeddingApiClient } from './embeddingApi.js'; -import { MODEL_TYPES } from './apiConfig.js'; +import { MODEL_TYPES, isValidModelType } from './apiConfig.js'; import { state } from '../state/index.js'; export function createModelApiClient(modelType) { switch (modelType) { case MODEL_TYPES.LORA: - return new LoraApiClient(); + return new LoraApiClient(MODEL_TYPES.LORA); case MODEL_TYPES.CHECKPOINT: - return new CheckpointApiClient(); + return new CheckpointApiClient(MODEL_TYPES.CHECKPOINT); case MODEL_TYPES.EMBEDDING: - return new EmbeddingApiClient(); + return new EmbeddingApiClient(MODEL_TYPES.EMBEDDING); default: throw new Error(`Unsupported model type: ${modelType}`); } @@ -20,7 +20,13 @@ export function createModelApiClient(modelType) { let _singletonClients = new Map(); export function getModelApiClient(modelType = null) { - const targetType = modelType || state.currentPageType; + let targetType = modelType; + + if (!isValidModelType(targetType)) { + targetType = isValidModelType(state.currentPageType) + ? state.currentPageType + : MODEL_TYPES.LORA; + } if (!_singletonClients.has(targetType)) { _singletonClients.set(targetType, createModelApiClient(targetType)); @@ -32,4 +38,4 @@ export function getModelApiClient(modelType = null) { export function resetAndReload(updateFolders = false) { const client = getModelApiClient(); return client.loadMoreWithVirtualScroll(true, updateFolders); -} \ No newline at end of file +}