feat(api): enhance model API creation with validation and default fallback

Refactored `createModelApiClient` to pass the specific model type as a parameter to each client constructor. Introduced `isValidModelType` for validation and added logic to set a default model type if provided type is invalid or not specified. Updated `getModelApiClient` function to utilize these improvements, ensuring robust model API instantiation.
This commit is contained in:
Will Miao
2025-10-12 15:28:30 +08:00
parent 0040863a03
commit a9a6f66035

View File

@@ -1,17 +1,17 @@
import { LoraApiClient } from './loraApi.js'; import { LoraApiClient } from './loraApi.js';
import { CheckpointApiClient } from './checkpointApi.js'; import { CheckpointApiClient } from './checkpointApi.js';
import { EmbeddingApiClient } from './embeddingApi.js'; import { EmbeddingApiClient } from './embeddingApi.js';
import { MODEL_TYPES } from './apiConfig.js'; import { MODEL_TYPES, isValidModelType } from './apiConfig.js';
import { state } from '../state/index.js'; import { state } from '../state/index.js';
export function createModelApiClient(modelType) { export function createModelApiClient(modelType) {
switch (modelType) { switch (modelType) {
case MODEL_TYPES.LORA: case MODEL_TYPES.LORA:
return new LoraApiClient(); return new LoraApiClient(MODEL_TYPES.LORA);
case MODEL_TYPES.CHECKPOINT: case MODEL_TYPES.CHECKPOINT:
return new CheckpointApiClient(); return new CheckpointApiClient(MODEL_TYPES.CHECKPOINT);
case MODEL_TYPES.EMBEDDING: case MODEL_TYPES.EMBEDDING:
return new EmbeddingApiClient(); return new EmbeddingApiClient(MODEL_TYPES.EMBEDDING);
default: default:
throw new Error(`Unsupported model type: ${modelType}`); throw new Error(`Unsupported model type: ${modelType}`);
} }
@@ -20,7 +20,13 @@ export function createModelApiClient(modelType) {
let _singletonClients = new Map(); let _singletonClients = new Map();
export function getModelApiClient(modelType = null) { export function getModelApiClient(modelType = null) {
const targetType = modelType || state.currentPageType; let targetType = modelType;
if (!isValidModelType(targetType)) {
targetType = isValidModelType(state.currentPageType)
? state.currentPageType
: MODEL_TYPES.LORA;
}
if (!_singletonClients.has(targetType)) { if (!_singletonClients.has(targetType)) {
_singletonClients.set(targetType, createModelApiClient(targetType)); _singletonClients.set(targetType, createModelApiClient(targetType));
@@ -32,4 +38,4 @@ export function getModelApiClient(modelType = null) {
export function resetAndReload(updateFolders = false) { export function resetAndReload(updateFolders = false) {
const client = getModelApiClient(); const client = getModelApiClient();
return client.loadMoreWithVirtualScroll(true, updateFolders); return client.loadMoreWithVirtualScroll(true, updateFolders);
} }