feat(download): auto-route diffusion models to unet folder based on baseModel, see #770

CivitAI does not distinguish between checkpoint and diffusion model types -
both are labeled as "checkpoint". For certain base model types like
"ZImageTurbo", all models are actually diffusion models and should be
saved to the unet/diffusion model folder instead of the checkpoint folder.

- Add DIFFUSION_MODEL_BASE_MODELS constant for known diffusion model types
- Add default_unet_root setting with auto-set logic
- Route downloads to unet folder when baseModel matches known diffusion types

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Will Miao
2026-01-18 11:57:23 +08:00
parent aab1797269
commit dad549f65f
3 changed files with 64 additions and 4 deletions

View File

@@ -10,7 +10,7 @@ import uuid
from typing import Dict, List, Optional, Set, Tuple from typing import Dict, List, Optional, Set, Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES from ..utils.constants import CARD_PREVIEW_WIDTH, DIFFUSION_MODEL_BASE_MODELS, VALID_LORA_TYPES
from ..utils.civitai_utils import rewrite_preview_url from ..utils.civitai_utils import rewrite_preview_url
from ..utils.preview_selection import select_preview_media from ..utils.preview_selection import select_preview_media
from ..utils.utils import sanitize_folder_name from ..utils.utils import sanitize_folder_name
@@ -343,6 +343,14 @@ class DownloadManager:
"error": f'Model type "{model_type_from_info}" is not supported for download', "error": f'Model type "{model_type_from_info}" is not supported for download',
} }
# Check if this checkpoint should be treated as a diffusion model based on baseModel
is_diffusion_model = False
if model_type == "checkpoint":
base_model_value = version_info.get('baseModel', '')
if base_model_value in DIFFUSION_MODEL_BASE_MODELS:
is_diffusion_model = True
logger.info(f"baseModel '{base_model_value}' is a known diffusion model, routing to unet folder")
# Case 2: model_version_id was None, check after getting version_info # Case 2: model_version_id was None, check after getting version_info
if model_version_id is None: if model_version_id is None:
version_id = version_info.get("id") version_id = version_info.get("id")
@@ -377,11 +385,16 @@ class DownloadManager:
settings_manager = get_settings_manager() settings_manager = get_settings_manager()
# Set save_dir based on model type # Set save_dir based on model type
if model_type == "checkpoint": if model_type == "checkpoint":
default_path = settings_manager.get("default_checkpoint_root") if is_diffusion_model:
default_path = settings_manager.get("default_unet_root")
error_msg = "Default unet root path not set in settings"
else:
default_path = settings_manager.get("default_checkpoint_root")
error_msg = "Default checkpoint root path not set in settings"
if not default_path: if not default_path:
return { return {
"success": False, "success": False,
"error": "Default checkpoint root path not set in settings", "error": error_msg,
} }
save_dir = default_path save_dir = default_path
elif model_type == "lora": elif model_type == "lora":

View File

@@ -44,6 +44,7 @@ DEFAULT_SETTINGS: Dict[str, Any] = {
"proxy_type": "http", "proxy_type": "http",
"default_lora_root": "", "default_lora_root": "",
"default_checkpoint_root": "", "default_checkpoint_root": "",
"default_unet_root": "",
"default_embedding_root": "", "default_embedding_root": "",
"base_model_path_mappings": {}, "base_model_path_mappings": {},
"download_path_templates": {}, "download_path_templates": {},
@@ -215,6 +216,7 @@ class SettingsManager:
folder_paths=merged.get("folder_paths", {}), folder_paths=merged.get("folder_paths", {}),
default_lora_root=merged.get("default_lora_root"), default_lora_root=merged.get("default_lora_root"),
default_checkpoint_root=merged.get("default_checkpoint_root"), default_checkpoint_root=merged.get("default_checkpoint_root"),
default_unet_root=merged.get("default_unet_root"),
default_embedding_root=merged.get("default_embedding_root"), default_embedding_root=merged.get("default_embedding_root"),
) )
} }
@@ -300,6 +302,7 @@ class SettingsManager:
folder_paths=normalized_top_level_paths, folder_paths=normalized_top_level_paths,
default_lora_root=self.settings.get("default_lora_root", ""), default_lora_root=self.settings.get("default_lora_root", ""),
default_checkpoint_root=self.settings.get("default_checkpoint_root", ""), default_checkpoint_root=self.settings.get("default_checkpoint_root", ""),
default_unet_root=self.settings.get("default_unet_root", ""),
default_embedding_root=self.settings.get("default_embedding_root", ""), default_embedding_root=self.settings.get("default_embedding_root", ""),
) )
libraries = {library_name: library_payload} libraries = {library_name: library_payload}
@@ -342,6 +345,7 @@ class SettingsManager:
folder_paths=candidate_folder_paths, folder_paths=candidate_folder_paths,
default_lora_root=data.get("default_lora_root"), default_lora_root=data.get("default_lora_root"),
default_checkpoint_root=data.get("default_checkpoint_root"), default_checkpoint_root=data.get("default_checkpoint_root"),
default_unet_root=data.get("default_unet_root"),
default_embedding_root=data.get("default_embedding_root"), default_embedding_root=data.get("default_embedding_root"),
metadata=data.get("metadata"), metadata=data.get("metadata"),
base=data, base=data,
@@ -380,6 +384,7 @@ class SettingsManager:
self.settings["folder_paths"] = folder_paths self.settings["folder_paths"] = folder_paths
self.settings["default_lora_root"] = active_library.get("default_lora_root", "") self.settings["default_lora_root"] = active_library.get("default_lora_root", "")
self.settings["default_checkpoint_root"] = active_library.get("default_checkpoint_root", "") self.settings["default_checkpoint_root"] = active_library.get("default_checkpoint_root", "")
self.settings["default_unet_root"] = active_library.get("default_unet_root", "")
self.settings["default_embedding_root"] = active_library.get("default_embedding_root", "") self.settings["default_embedding_root"] = active_library.get("default_embedding_root", "")
if save: if save:
@@ -394,6 +399,7 @@ class SettingsManager:
folder_paths: Optional[Mapping[str, Iterable[str]]] = None, folder_paths: Optional[Mapping[str, Iterable[str]]] = None,
default_lora_root: Optional[str] = None, default_lora_root: Optional[str] = None,
default_checkpoint_root: Optional[str] = None, default_checkpoint_root: Optional[str] = None,
default_unet_root: Optional[str] = None,
default_embedding_root: Optional[str] = None, default_embedding_root: Optional[str] = None,
metadata: Optional[Mapping[str, Any]] = None, metadata: Optional[Mapping[str, Any]] = None,
base: Optional[Mapping[str, Any]] = None, base: Optional[Mapping[str, Any]] = None,
@@ -416,6 +422,11 @@ class SettingsManager:
else: else:
payload.setdefault("default_checkpoint_root", "") payload.setdefault("default_checkpoint_root", "")
if default_unet_root is not None:
payload["default_unet_root"] = default_unet_root
else:
payload.setdefault("default_unet_root", "")
if default_embedding_root is not None: if default_embedding_root is not None:
payload["default_embedding_root"] = default_embedding_root payload["default_embedding_root"] = default_embedding_root
else: else:
@@ -517,6 +528,7 @@ class SettingsManager:
folder_paths: Optional[Mapping[str, Iterable[str]]] = None, folder_paths: Optional[Mapping[str, Iterable[str]]] = None,
default_lora_root: Optional[str] = None, default_lora_root: Optional[str] = None,
default_checkpoint_root: Optional[str] = None, default_checkpoint_root: Optional[str] = None,
default_unet_root: Optional[str] = None,
default_embedding_root: Optional[str] = None, default_embedding_root: Optional[str] = None,
) -> bool: ) -> bool:
libraries = self.settings.get("libraries", {}) libraries = self.settings.get("libraries", {})
@@ -541,6 +553,10 @@ class SettingsManager:
library["default_checkpoint_root"] = default_checkpoint_root library["default_checkpoint_root"] = default_checkpoint_root
changed = True changed = True
if default_unet_root is not None and library.get("default_unet_root") != default_unet_root:
library["default_unet_root"] = default_unet_root
changed = True
if default_embedding_root is not None and library.get("default_embedding_root") != default_embedding_root: if default_embedding_root is not None and library.get("default_embedding_root") != default_embedding_root:
library["default_embedding_root"] = default_embedding_root library["default_embedding_root"] = default_embedding_root
changed = True changed = True
@@ -596,7 +612,11 @@ class SettingsManager:
logger.info("Migration completed") logger.info("Migration completed")
def _auto_set_default_roots(self): def _auto_set_default_roots(self):
"""Auto set default root paths when only one folder is present and the current default is unset or not among the options.""" """Auto set default root paths when the current default is unset or not among the options.
For single-path cases, always use that path.
For multi-path cases, only set if current default is empty or invalid.
"""
folder_paths = self.settings.get('folder_paths', {}) folder_paths = self.settings.get('folder_paths', {})
updated = False updated = False
# loras # loras
@@ -613,6 +633,14 @@ class SettingsManager:
if current_checkpoint_root not in checkpoints: if current_checkpoint_root not in checkpoints:
self.settings['default_checkpoint_root'] = checkpoints[0] self.settings['default_checkpoint_root'] = checkpoints[0]
updated = True updated = True
# unet (diffusion models) - auto-set if empty or invalid
unet_paths = folder_paths.get('unet', [])
if isinstance(unet_paths, list) and len(unet_paths) >= 1:
current_unet_root = self.settings.get('default_unet_root')
# Set to first path if current is empty or not in the valid paths
if not current_unet_root or current_unet_root not in unet_paths:
self.settings['default_unet_root'] = unet_paths[0]
updated = True
# embeddings # embeddings
embeddings = folder_paths.get('embeddings', []) embeddings = folder_paths.get('embeddings', [])
if isinstance(embeddings, list) and len(embeddings) == 1: if isinstance(embeddings, list) and len(embeddings) == 1:
@@ -624,6 +652,7 @@ class SettingsManager:
self._update_active_library_entry( self._update_active_library_entry(
default_lora_root=self.settings.get('default_lora_root'), default_lora_root=self.settings.get('default_lora_root'),
default_checkpoint_root=self.settings.get('default_checkpoint_root'), default_checkpoint_root=self.settings.get('default_checkpoint_root'),
default_unet_root=self.settings.get('default_unet_root'),
default_embedding_root=self.settings.get('default_embedding_root'), default_embedding_root=self.settings.get('default_embedding_root'),
) )
if self._bootstrap_reason == "missing": if self._bootstrap_reason == "missing":
@@ -851,6 +880,8 @@ class SettingsManager:
self._update_active_library_entry(default_lora_root=str(value)) self._update_active_library_entry(default_lora_root=str(value))
elif key == 'default_checkpoint_root': elif key == 'default_checkpoint_root':
self._update_active_library_entry(default_checkpoint_root=str(value)) self._update_active_library_entry(default_checkpoint_root=str(value))
elif key == 'default_unet_root':
self._update_active_library_entry(default_unet_root=str(value))
elif key == 'default_embedding_root': elif key == 'default_embedding_root':
self._update_active_library_entry(default_embedding_root=str(value)) self._update_active_library_entry(default_embedding_root=str(value))
elif key == 'model_name_display': elif key == 'model_name_display':
@@ -1131,6 +1162,7 @@ class SettingsManager:
folder_paths: Optional[Mapping[str, Iterable[str]]] = None, folder_paths: Optional[Mapping[str, Iterable[str]]] = None,
default_lora_root: Optional[str] = None, default_lora_root: Optional[str] = None,
default_checkpoint_root: Optional[str] = None, default_checkpoint_root: Optional[str] = None,
default_unet_root: Optional[str] = None,
default_embedding_root: Optional[str] = None, default_embedding_root: Optional[str] = None,
metadata: Optional[Mapping[str, Any]] = None, metadata: Optional[Mapping[str, Any]] = None,
activate: bool = False, activate: bool = False,
@@ -1155,6 +1187,11 @@ class SettingsManager:
if default_checkpoint_root is not None if default_checkpoint_root is not None
else existing.get("default_checkpoint_root") else existing.get("default_checkpoint_root")
), ),
default_unet_root=(
default_unet_root
if default_unet_root is not None
else existing.get("default_unet_root")
),
default_embedding_root=( default_embedding_root=(
default_embedding_root default_embedding_root
if default_embedding_root is not None if default_embedding_root is not None
@@ -1184,6 +1221,7 @@ class SettingsManager:
folder_paths: Mapping[str, Iterable[str]], folder_paths: Mapping[str, Iterable[str]],
default_lora_root: str = "", default_lora_root: str = "",
default_checkpoint_root: str = "", default_checkpoint_root: str = "",
default_unet_root: str = "",
default_embedding_root: str = "", default_embedding_root: str = "",
metadata: Optional[Mapping[str, Any]] = None, metadata: Optional[Mapping[str, Any]] = None,
activate: bool = False, activate: bool = False,
@@ -1199,6 +1237,7 @@ class SettingsManager:
folder_paths=folder_paths, folder_paths=folder_paths,
default_lora_root=default_lora_root, default_lora_root=default_lora_root,
default_checkpoint_root=default_checkpoint_root, default_checkpoint_root=default_checkpoint_root,
default_unet_root=default_unet_root,
default_embedding_root=default_embedding_root, default_embedding_root=default_embedding_root,
metadata=metadata, metadata=metadata,
activate=activate, activate=activate,
@@ -1256,6 +1295,7 @@ class SettingsManager:
*, *,
default_lora_root: Optional[str] = None, default_lora_root: Optional[str] = None,
default_checkpoint_root: Optional[str] = None, default_checkpoint_root: Optional[str] = None,
default_unet_root: Optional[str] = None,
default_embedding_root: Optional[str] = None, default_embedding_root: Optional[str] = None,
) -> None: ) -> None:
"""Update folder paths for the active library.""" """Update folder paths for the active library."""
@@ -1266,6 +1306,7 @@ class SettingsManager:
folder_paths=folder_paths, folder_paths=folder_paths,
default_lora_root=default_lora_root, default_lora_root=default_lora_root,
default_checkpoint_root=default_checkpoint_root, default_checkpoint_root=default_checkpoint_root,
default_unet_root=default_unet_root,
default_embedding_root=default_embedding_root, default_embedding_root=default_embedding_root,
activate=True, activate=True,
) )

View File

@@ -75,3 +75,9 @@ DEFAULT_PRIORITY_TAG_CONFIG = {
'checkpoint': ', '.join(CIVITAI_MODEL_TAGS), 'checkpoint': ', '.join(CIVITAI_MODEL_TAGS),
'embedding': ', '.join(CIVITAI_MODEL_TAGS), 'embedding': ', '.join(CIVITAI_MODEL_TAGS),
} }
# baseModel values from CivitAI that should be treated as diffusion models (unet)
# These model types are incorrectly labeled as "checkpoint" by CivitAI but are actually diffusion models
DIFFUSION_MODEL_BASE_MODELS = frozenset([
"ZImageTurbo",
])