diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index 6ed99cce..25750f78 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -657,9 +657,11 @@ class ModelScanner: # Update folders list all_folders = set(item.get('folder', '') for item in self._cache.raw_data) self._cache.folders = sorted(list(all_folders), key=lambda x: x.lower()) - + # Resort cache await self._cache.resort() + + await self._persist_current_cache() logger.info(f"{self.model_type.capitalize()} Scanner: Cache reconciliation completed in {time.time() - start_time:.2f} seconds. Added {total_added}, removed {total_removed} models.") except Exception as e: @@ -1087,7 +1089,7 @@ class ModelScanner: async def update_single_model_cache(self, original_path: str, new_path: str, metadata: Dict) -> bool: """Update cache after a model has been moved or modified""" cache = await self.get_cached_data() - + existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None) if existing_item and 'tags' in existing_item: for tag in existing_item.get('tags', []): @@ -1099,10 +1101,12 @@ class ModelScanner: self._hash_index.remove_by_path(original_path) cache.raw_data = [ - item for item in cache.raw_data + item for item in cache.raw_data if item['file_path'] != original_path ] - + + cache_modified = bool(existing_item) or bool(metadata) + if metadata: normalized_new_path = new_path.replace(os.sep, '/') if original_path == new_path and existing_item: @@ -1129,7 +1133,10 @@ class ModelScanner: self._tags_count[tag] = self._tags_count.get(tag, 0) + 1 await cache.resort() - + + if cache_modified: + await self._persist_current_cache() + return True def has_hash(self, sha256: str) -> bool: @@ -1363,7 +1370,9 @@ class ModelScanner: # Resort cache await self._cache.resort() - + + await self._persist_current_cache() + return True except Exception as e: diff --git a/tests/services/test_model_scanner.py b/tests/services/test_model_scanner.py index 0dbfaf01..6d87d56e 100644 --- a/tests/services/test_model_scanner.py +++ b/tests/services/test_model_scanner.py @@ -1,5 +1,6 @@ import asyncio import os +import sqlite3 from pathlib import Path from typing import List @@ -304,3 +305,77 @@ async def test_load_persisted_cache_populates_cache(tmp_path: Path, monkeypatch) assert scanner._tags_count == {'alpha': 1} assert ws_stub.payloads[-1]['stage'] == 'loading_cache' assert ws_stub.payloads[-1]['progress'] == 1 + + +@pytest.mark.asyncio +async def test_update_single_model_cache_persists_changes(tmp_path: Path, monkeypatch): + monkeypatch.setenv('LORA_MANAGER_DISABLE_PERSISTENT_CACHE', '0') + db_path = tmp_path / 'cache.sqlite' + monkeypatch.setenv('LORA_MANAGER_CACHE_DB', str(db_path)) + monkeypatch.setattr(PersistentModelCache, '_instance', None, raising=False) + + _create_files(tmp_path) + scanner = DummyScanner(tmp_path) + + await scanner._initialize_cache() + + normalized = _normalize_path(tmp_path / 'one.txt') + updated_metadata = { + 'file_path': normalized, + 'file_name': 'one', + 'model_name': 'renamed', + 'sha256': 'hash-one', + 'tags': ['gamma', 'delta'], + 'size': 42, + 'modified': 456.0, + 'base_model': 'base', + 'from_civitai': True, + } + + await scanner.update_single_model_cache(normalized, normalized, updated_metadata) + + with sqlite3.connect(db_path) as conn: + conn.row_factory = sqlite3.Row + row = conn.execute( + "SELECT model_name FROM models WHERE file_path = ?", + (normalized,), + ).fetchone() + + assert row is not None + assert row['model_name'] == 'renamed' + + tags = { + record['tag'] + for record in conn.execute( + "SELECT tag FROM model_tags WHERE file_path = ?", + (normalized,), + ) + } + + assert tags == {'gamma', 'delta'} + + +@pytest.mark.asyncio +async def test_batch_delete_persists_removal(tmp_path: Path, monkeypatch): + monkeypatch.setenv('LORA_MANAGER_DISABLE_PERSISTENT_CACHE', '0') + db_path = tmp_path / 'cache.sqlite' + monkeypatch.setenv('LORA_MANAGER_CACHE_DB', str(db_path)) + monkeypatch.setattr(PersistentModelCache, '_instance', None, raising=False) + + first, _, _ = _create_files(tmp_path) + scanner = DummyScanner(tmp_path) + + await scanner._initialize_cache() + + normalized = _normalize_path(first) + removed = await scanner._batch_update_cache_for_deleted_models([normalized]) + + assert removed is True + + with sqlite3.connect(db_path) as conn: + remaining = conn.execute( + "SELECT COUNT(*) FROM models WHERE file_path = ?", + (normalized,), + ).fetchone()[0] + + assert remaining == 0