Merge pull request #585 from willmiao/codex/fix-model_type-not-updating-for-checkpoints

fix: apply adjust_cached_entry during model reconciliation
This commit is contained in:
pixelpaws
2025-10-23 08:33:33 +08:00
committed by GitHub
4 changed files with 155 additions and 4 deletions

View File

@@ -643,10 +643,10 @@ class DownloadManager:
# 4. Update file information (size and modified time) # 4. Update file information (size and modified time)
metadata.update_file_info(save_path) metadata.update_file_info(save_path)
# 5. Final metadata update scanner = None
await MetadataManager.save_metadata(save_path, metadata) adjust_root: Optional[str] = None
# 6. Update cache based on model type # 5. Determine scanner and adjust metadata for cache consistency
if model_type == "checkpoint": if model_type == "checkpoint":
scanner = await self._get_checkpoint_scanner() scanner = await self._get_checkpoint_scanner()
logger.info(f"Updating checkpoint cache for {save_path}") logger.info(f"Updating checkpoint cache for {save_path}")
@@ -656,9 +656,33 @@ class DownloadManager:
elif model_type == "embedding": elif model_type == "embedding":
scanner = await ServiceRegistry.get_embedding_scanner() scanner = await ServiceRegistry.get_embedding_scanner()
logger.info(f"Updating embedding cache for {save_path}") logger.info(f"Updating embedding cache for {save_path}")
if scanner is not None:
file_path_for_adjust = getattr(metadata, "file_path", save_path)
if isinstance(file_path_for_adjust, str):
normalized_file_path = file_path_for_adjust.replace(os.sep, "/")
else:
normalized_file_path = str(file_path_for_adjust)
find_root = getattr(scanner, "_find_root_for_file", None)
if callable(find_root):
try:
adjust_root = find_root(normalized_file_path)
except TypeError:
adjust_root = None
adjust_metadata = getattr(scanner, "adjust_metadata", None)
if callable(adjust_metadata):
metadata = adjust_metadata(metadata, normalized_file_path, adjust_root)
# 6. Persist metadata with any adjustments
await MetadataManager.save_metadata(save_path, metadata)
# Convert metadata to dictionary # Convert metadata to dictionary
metadata_dict = metadata.to_dict() metadata_dict = metadata.to_dict()
adjust_cached_entry = getattr(scanner, "adjust_cached_entry", None) if scanner is not None else None
if callable(adjust_cached_entry):
metadata_dict = adjust_cached_entry(metadata_dict)
# Add model to cache and save to disk in a single operation # Add model to cache and save to disk in a single operation
await scanner.add_model_to_cache(metadata_dict, relative_path) await scanner.add_model_to_cache(metadata_dict, relative_path)

View File

@@ -678,6 +678,9 @@ class ModelScanner:
if root_path: if root_path:
model_data = await self._process_model_file(path, root_path) model_data = await self._process_model_file(path, root_path)
if model_data: if model_data:
model_data = self.adjust_cached_entry(dict(model_data))
if not model_data:
continue
# Add to cache # Add to cache
self._cache.raw_data.append(model_data) self._cache.raw_data.append(model_data)
self._cache.add_to_version_index(model_data) self._cache.add_to_version_index(model_data)

View File

@@ -1,6 +1,7 @@
import asyncio import asyncio
import os import os
from pathlib import Path from pathlib import Path
from typing import Optional
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
@@ -431,6 +432,97 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path):
assert dummy_scanner.calls # ensure cache updated assert dummy_scanner.calls # ensure cache updated
async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_path):
    """Verify _execute_download reclassifies a checkpoint via the scanner hooks.

    The scanner's ``adjust_metadata`` / ``adjust_cached_entry`` hooks should be
    applied before the metadata is persisted and before the entry is added to
    the cache, so both the saved metadata and the cached entry carry the
    adjusted ``model_type`` ("diffusion_model" here).
    """
    manager = DownloadManager()
    root_dir = tmp_path / "checkpoints"
    root_dir.mkdir()
    save_dir = root_dir
    target_path = save_dir / "model.safetensors"

    class DummyMetadata:
        # Minimal stand-in for the project's metadata object: only the
        # attributes and methods _execute_download touches are implemented.
        def __init__(self, path: Path):
            self.file_path = path.as_posix()
            self.sha256 = "sha256"
            self.file_name = path.stem
            self.preview_url = None
            self.preview_nsfw_level = 0
            self.model_type = "checkpoint"

        def generate_unique_filename(self, *_args, **_kwargs):
            return os.path.basename(self.file_path)

        def update_file_info(self, updated_path):
            # Normalize to POSIX separators, matching scanner expectations.
            self.file_path = Path(updated_path).as_posix()

        def to_dict(self):
            return {
                "file_path": self.file_path,
                "model_type": self.model_type,
                "sha256": self.sha256,
            }

    metadata = DummyMetadata(target_path)
    version_info = {"images": []}
    download_urls = ["https://example.invalid/model.safetensors"]

    class DummyDownloader:
        # Pretend the download succeeded by just writing a small file.
        async def download_file(self, _url, path, progress_callback=None, use_auth=None):
            Path(path).write_text("content")
            return True, "ok"

    monkeypatch.setattr(
        download_manager,
        "get_downloader",
        AsyncMock(return_value=DummyDownloader()),
    )

    class DummyCheckpointScanner:
        # Scanner double exposing the adjustment hooks under test.
        def __init__(self, root: Path):
            self.root = root.as_posix()
            self.add_calls = []

        def _find_root_for_file(self, file_path: str):
            return self.root if file_path.startswith(self.root) else None

        def adjust_metadata(self, metadata_obj, _file_path: str, root_path: Optional[str]):
            # Reclassify anything under our root as a diffusion model.
            if root_path:
                metadata_obj.model_type = "diffusion_model"
            return metadata_obj

        def adjust_cached_entry(self, entry):
            if entry.get("file_path", "").startswith(self.root):
                entry["model_type"] = "diffusion_model"
            return entry

        async def add_model_to_cache(self, metadata_dict, relative_path):
            self.add_calls.append((metadata_dict, relative_path))
            return True

    dummy_scanner = DummyCheckpointScanner(root_dir)
    monkeypatch.setattr(
        DownloadManager,
        "_get_checkpoint_scanner",
        AsyncMock(return_value=dummy_scanner),
    )
    monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))

    result = await manager._execute_download(
        download_urls=download_urls,
        save_dir=str(save_dir),
        metadata=metadata,
        version_info=version_info,
        relative_path="",
        progress_callback=None,
        model_type="checkpoint",
        download_id=None,
    )

    assert result == {"success": True}
    # adjust_metadata must have mutated the in-memory metadata object...
    assert metadata.model_type == "diffusion_model"
    # ...and the adjusted object is what was persisted.
    saved_metadata = MetadataManager.save_metadata.await_args.args[1]
    assert saved_metadata.model_type == "diffusion_model"
    # adjust_cached_entry must have been applied before caching.
    assert dummy_scanner.add_calls
    cached_entry, _ = dummy_scanner.add_calls[0]
    assert cached_entry["model_type"] == "diffusion_model"
async def test_pause_download_updates_state(): async def test_pause_download_updates_state():
manager = DownloadManager() manager = DownloadManager()

View File

@@ -3,6 +3,7 @@ import os
import sqlite3 import sqlite3
from pathlib import Path from pathlib import Path
from typing import List from typing import List
from types import MethodType
import pytest import pytest
@@ -470,6 +471,37 @@ async def test_reconcile_cache_adds_new_files_and_updates_hash_index(tmp_path: P
assert cache.folders == [""] assert cache.folders == [""]
@pytest.mark.asyncio
async def test_reconcile_cache_applies_adjust_cached_entry(tmp_path: Path):
    """Verify _reconcile_cache runs adjust_cached_entry on newly found files.

    A hook is monkey-patched onto the scanner instance that records every
    entry it sees and rewrites its model_type; after reconciliation the new
    file must have passed through the hook and carry the rewritten value.
    """
    existing = tmp_path / "one.txt"
    existing.write_text("one", encoding="utf-8")

    scanner = DummyScanner(tmp_path)
    applied: List[str] = []

    def _adjust(self, entry: dict) -> dict:
        # Record which paths the hook saw and tag each entry.
        applied.append(entry["file_path"])
        entry["model_type"] = "adjusted"
        return entry

    # Bind the hook as an instance method so `self` is passed through.
    scanner.adjust_cached_entry = MethodType(_adjust, scanner)

    await scanner._initialize_cache()
    # Only calls made during reconciliation matter; drop the init calls.
    applied.clear()

    new_file = tmp_path / "two.txt"
    new_file.write_text("two", encoding="utf-8")

    await scanner._reconcile_cache()

    normalized_new = _normalize_path(new_file)
    assert normalized_new in applied
    new_entry = next(
        item for item in scanner._cache.raw_data
        if item["file_path"] == normalized_new
    )
    assert new_entry["model_type"] == "adjusted"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_count_model_files_handles_symlink_loops(tmp_path: Path): async def test_count_model_files_handles_symlink_loops(tmp_path: Path):
scanner = DummyScanner(tmp_path) scanner = DummyScanner(tmp_path)