mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Merge pull request #585 from willmiao/codex/fix-model_type-not-updating-for-checkpoints
fix: apply adjust_cached_entry during model reconciliation
This commit is contained in:
@@ -643,10 +643,10 @@ class DownloadManager:
|
|||||||
# 4. Update file information (size and modified time)
|
# 4. Update file information (size and modified time)
|
||||||
metadata.update_file_info(save_path)
|
metadata.update_file_info(save_path)
|
||||||
|
|
||||||
# 5. Final metadata update
|
scanner = None
|
||||||
await MetadataManager.save_metadata(save_path, metadata)
|
adjust_root: Optional[str] = None
|
||||||
|
|
||||||
# 6. Update cache based on model type
|
# 5. Determine scanner and adjust metadata for cache consistency
|
||||||
if model_type == "checkpoint":
|
if model_type == "checkpoint":
|
||||||
scanner = await self._get_checkpoint_scanner()
|
scanner = await self._get_checkpoint_scanner()
|
||||||
logger.info(f"Updating checkpoint cache for {save_path}")
|
logger.info(f"Updating checkpoint cache for {save_path}")
|
||||||
@@ -656,9 +656,33 @@ class DownloadManager:
|
|||||||
elif model_type == "embedding":
|
elif model_type == "embedding":
|
||||||
scanner = await ServiceRegistry.get_embedding_scanner()
|
scanner = await ServiceRegistry.get_embedding_scanner()
|
||||||
logger.info(f"Updating embedding cache for {save_path}")
|
logger.info(f"Updating embedding cache for {save_path}")
|
||||||
|
|
||||||
|
if scanner is not None:
|
||||||
|
file_path_for_adjust = getattr(metadata, "file_path", save_path)
|
||||||
|
if isinstance(file_path_for_adjust, str):
|
||||||
|
normalized_file_path = file_path_for_adjust.replace(os.sep, "/")
|
||||||
|
else:
|
||||||
|
normalized_file_path = str(file_path_for_adjust)
|
||||||
|
|
||||||
|
find_root = getattr(scanner, "_find_root_for_file", None)
|
||||||
|
if callable(find_root):
|
||||||
|
try:
|
||||||
|
adjust_root = find_root(normalized_file_path)
|
||||||
|
except TypeError:
|
||||||
|
adjust_root = None
|
||||||
|
|
||||||
|
adjust_metadata = getattr(scanner, "adjust_metadata", None)
|
||||||
|
if callable(adjust_metadata):
|
||||||
|
metadata = adjust_metadata(metadata, normalized_file_path, adjust_root)
|
||||||
|
|
||||||
|
# 6. Persist metadata with any adjustments
|
||||||
|
await MetadataManager.save_metadata(save_path, metadata)
|
||||||
|
|
||||||
# Convert metadata to dictionary
|
# Convert metadata to dictionary
|
||||||
metadata_dict = metadata.to_dict()
|
metadata_dict = metadata.to_dict()
|
||||||
|
adjust_cached_entry = getattr(scanner, "adjust_cached_entry", None) if scanner is not None else None
|
||||||
|
if callable(adjust_cached_entry):
|
||||||
|
metadata_dict = adjust_cached_entry(metadata_dict)
|
||||||
|
|
||||||
# Add model to cache and save to disk in a single operation
|
# Add model to cache and save to disk in a single operation
|
||||||
await scanner.add_model_to_cache(metadata_dict, relative_path)
|
await scanner.add_model_to_cache(metadata_dict, relative_path)
|
||||||
|
|||||||
@@ -678,6 +678,9 @@ class ModelScanner:
|
|||||||
if root_path:
|
if root_path:
|
||||||
model_data = await self._process_model_file(path, root_path)
|
model_data = await self._process_model_file(path, root_path)
|
||||||
if model_data:
|
if model_data:
|
||||||
|
model_data = self.adjust_cached_entry(dict(model_data))
|
||||||
|
if not model_data:
|
||||||
|
continue
|
||||||
# Add to cache
|
# Add to cache
|
||||||
self._cache.raw_data.append(model_data)
|
self._cache.raw_data.append(model_data)
|
||||||
self._cache.add_to_version_index(model_data)
|
self._cache.add_to_version_index(model_data)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
@@ -431,6 +432,97 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path):
|
|||||||
assert dummy_scanner.calls # ensure cache updated
|
assert dummy_scanner.calls # ensure cache updated
|
||||||
|
|
||||||
|
|
||||||
|
async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_path):
|
||||||
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
root_dir = tmp_path / "checkpoints"
|
||||||
|
root_dir.mkdir()
|
||||||
|
save_dir = root_dir
|
||||||
|
target_path = save_dir / "model.safetensors"
|
||||||
|
|
||||||
|
class DummyMetadata:
|
||||||
|
def __init__(self, path: Path):
|
||||||
|
self.file_path = path.as_posix()
|
||||||
|
self.sha256 = "sha256"
|
||||||
|
self.file_name = path.stem
|
||||||
|
self.preview_url = None
|
||||||
|
self.preview_nsfw_level = 0
|
||||||
|
self.model_type = "checkpoint"
|
||||||
|
|
||||||
|
def generate_unique_filename(self, *_args, **_kwargs):
|
||||||
|
return os.path.basename(self.file_path)
|
||||||
|
|
||||||
|
def update_file_info(self, updated_path):
|
||||||
|
self.file_path = Path(updated_path).as_posix()
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {
|
||||||
|
"file_path": self.file_path,
|
||||||
|
"model_type": self.model_type,
|
||||||
|
"sha256": self.sha256,
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata = DummyMetadata(target_path)
|
||||||
|
version_info = {"images": []}
|
||||||
|
download_urls = ["https://example.invalid/model.safetensors"]
|
||||||
|
|
||||||
|
class DummyDownloader:
|
||||||
|
async def download_file(self, _url, path, progress_callback=None, use_auth=None):
|
||||||
|
Path(path).write_text("content")
|
||||||
|
return True, "ok"
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_manager,
|
||||||
|
"get_downloader",
|
||||||
|
AsyncMock(return_value=DummyDownloader()),
|
||||||
|
)
|
||||||
|
|
||||||
|
class DummyCheckpointScanner:
|
||||||
|
def __init__(self, root: Path):
|
||||||
|
self.root = root.as_posix()
|
||||||
|
self.add_calls = []
|
||||||
|
|
||||||
|
def _find_root_for_file(self, file_path: str):
|
||||||
|
return self.root if file_path.startswith(self.root) else None
|
||||||
|
|
||||||
|
def adjust_metadata(self, metadata_obj, _file_path: str, root_path: Optional[str]):
|
||||||
|
if root_path:
|
||||||
|
metadata_obj.model_type = "diffusion_model"
|
||||||
|
return metadata_obj
|
||||||
|
|
||||||
|
def adjust_cached_entry(self, entry):
|
||||||
|
if entry.get("file_path", "").startswith(self.root):
|
||||||
|
entry["model_type"] = "diffusion_model"
|
||||||
|
return entry
|
||||||
|
|
||||||
|
async def add_model_to_cache(self, metadata_dict, relative_path):
|
||||||
|
self.add_calls.append((metadata_dict, relative_path))
|
||||||
|
return True
|
||||||
|
|
||||||
|
dummy_scanner = DummyCheckpointScanner(root_dir)
|
||||||
|
monkeypatch.setattr(DownloadManager, "_get_checkpoint_scanner", AsyncMock(return_value=dummy_scanner))
|
||||||
|
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
|
||||||
|
|
||||||
|
result = await manager._execute_download(
|
||||||
|
download_urls=download_urls,
|
||||||
|
save_dir=str(save_dir),
|
||||||
|
metadata=metadata,
|
||||||
|
version_info=version_info,
|
||||||
|
relative_path="",
|
||||||
|
progress_callback=None,
|
||||||
|
model_type="checkpoint",
|
||||||
|
download_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {"success": True}
|
||||||
|
assert metadata.model_type == "diffusion_model"
|
||||||
|
saved_metadata = MetadataManager.save_metadata.await_args.args[1]
|
||||||
|
assert saved_metadata.model_type == "diffusion_model"
|
||||||
|
assert dummy_scanner.add_calls
|
||||||
|
cached_entry, _ = dummy_scanner.add_calls[0]
|
||||||
|
assert cached_entry["model_type"] == "diffusion_model"
|
||||||
|
|
||||||
|
|
||||||
async def test_pause_download_updates_state():
|
async def test_pause_download_updates_state():
|
||||||
manager = DownloadManager()
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import os
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
from types import MethodType
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -470,6 +471,37 @@ async def test_reconcile_cache_adds_new_files_and_updates_hash_index(tmp_path: P
|
|||||||
assert cache.folders == [""]
|
assert cache.folders == [""]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reconcile_cache_applies_adjust_cached_entry(tmp_path: Path):
|
||||||
|
existing = tmp_path / "one.txt"
|
||||||
|
existing.write_text("one", encoding="utf-8")
|
||||||
|
|
||||||
|
scanner = DummyScanner(tmp_path)
|
||||||
|
|
||||||
|
applied: List[str] = []
|
||||||
|
|
||||||
|
def _adjust(self, entry: dict) -> dict:
|
||||||
|
applied.append(entry["file_path"])
|
||||||
|
entry["model_type"] = "adjusted"
|
||||||
|
return entry
|
||||||
|
|
||||||
|
scanner.adjust_cached_entry = MethodType(_adjust, scanner)
|
||||||
|
|
||||||
|
await scanner._initialize_cache()
|
||||||
|
applied.clear()
|
||||||
|
|
||||||
|
new_file = tmp_path / "two.txt"
|
||||||
|
new_file.write_text("two", encoding="utf-8")
|
||||||
|
|
||||||
|
await scanner._reconcile_cache()
|
||||||
|
|
||||||
|
normalized_new = _normalize_path(new_file)
|
||||||
|
assert normalized_new in applied
|
||||||
|
|
||||||
|
new_entry = next(item for item in scanner._cache.raw_data if item["file_path"] == normalized_new)
|
||||||
|
assert new_entry["model_type"] == "adjusted"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_count_model_files_handles_symlink_loops(tmp_path: Path):
|
async def test_count_model_files_handles_symlink_loops(tmp_path: Path):
|
||||||
scanner = DummyScanner(tmp_path)
|
scanner = DummyScanner(tmp_path)
|
||||||
|
|||||||
Reference in New Issue
Block a user