fix(cache): sync persistent metadata updates

This commit is contained in:
pixelpaws
2025-10-03 14:57:44 +08:00
parent 28bc966b76
commit fe9fbdb93c
2 changed files with 90 additions and 6 deletions

View File

@@ -657,9 +657,11 @@ class ModelScanner:
# Update folders list
all_folders = set(item.get('folder', '') for item in self._cache.raw_data)
self._cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
# Resort cache
await self._cache.resort()
await self._persist_current_cache()
logger.info(f"{self.model_type.capitalize()} Scanner: Cache reconciliation completed in {time.time() - start_time:.2f} seconds. Added {total_added}, removed {total_removed} models.")
except Exception as e:
@@ -1087,7 +1089,7 @@ class ModelScanner:
async def update_single_model_cache(self, original_path: str, new_path: str, metadata: Dict) -> bool:
"""Update cache after a model has been moved or modified"""
cache = await self.get_cached_data()
existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
if existing_item and 'tags' in existing_item:
for tag in existing_item.get('tags', []):
@@ -1099,10 +1101,12 @@ class ModelScanner:
self._hash_index.remove_by_path(original_path)
cache.raw_data = [
item for item in cache.raw_data
item for item in cache.raw_data
if item['file_path'] != original_path
]
cache_modified = bool(existing_item) or bool(metadata)
if metadata:
normalized_new_path = new_path.replace(os.sep, '/')
if original_path == new_path and existing_item:
@@ -1129,7 +1133,10 @@ class ModelScanner:
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
await cache.resort()
if cache_modified:
await self._persist_current_cache()
return True
def has_hash(self, sha256: str) -> bool:
@@ -1363,7 +1370,9 @@ class ModelScanner:
# Resort cache
await self._cache.resort()
await self._persist_current_cache()
return True
except Exception as e:

View File

@@ -1,5 +1,6 @@
import asyncio
import os
import sqlite3
from pathlib import Path
from typing import List
@@ -304,3 +305,77 @@ async def test_load_persisted_cache_populates_cache(tmp_path: Path, monkeypatch)
assert scanner._tags_count == {'alpha': 1}
assert ws_stub.payloads[-1]['stage'] == 'loading_cache'
assert ws_stub.payloads[-1]['progress'] == 1
@pytest.mark.asyncio
async def test_update_single_model_cache_persists_changes(tmp_path: Path, monkeypatch):
monkeypatch.setenv('LORA_MANAGER_DISABLE_PERSISTENT_CACHE', '0')
db_path = tmp_path / 'cache.sqlite'
monkeypatch.setenv('LORA_MANAGER_CACHE_DB', str(db_path))
monkeypatch.setattr(PersistentModelCache, '_instance', None, raising=False)
_create_files(tmp_path)
scanner = DummyScanner(tmp_path)
await scanner._initialize_cache()
normalized = _normalize_path(tmp_path / 'one.txt')
updated_metadata = {
'file_path': normalized,
'file_name': 'one',
'model_name': 'renamed',
'sha256': 'hash-one',
'tags': ['gamma', 'delta'],
'size': 42,
'modified': 456.0,
'base_model': 'base',
'from_civitai': True,
}
await scanner.update_single_model_cache(normalized, normalized, updated_metadata)
with sqlite3.connect(db_path) as conn:
conn.row_factory = sqlite3.Row
row = conn.execute(
"SELECT model_name FROM models WHERE file_path = ?",
(normalized,),
).fetchone()
assert row is not None
assert row['model_name'] == 'renamed'
tags = {
record['tag']
for record in conn.execute(
"SELECT tag FROM model_tags WHERE file_path = ?",
(normalized,),
)
}
assert tags == {'gamma', 'delta'}
@pytest.mark.asyncio
async def test_batch_delete_persists_removal(tmp_path: Path, monkeypatch):
monkeypatch.setenv('LORA_MANAGER_DISABLE_PERSISTENT_CACHE', '0')
db_path = tmp_path / 'cache.sqlite'
monkeypatch.setenv('LORA_MANAGER_CACHE_DB', str(db_path))
monkeypatch.setattr(PersistentModelCache, '_instance', None, raising=False)
first, _, _ = _create_files(tmp_path)
scanner = DummyScanner(tmp_path)
await scanner._initialize_cache()
normalized = _normalize_path(first)
removed = await scanner._batch_update_cache_for_deleted_models([normalized])
assert removed is True
with sqlite3.connect(db_path) as conn:
remaining = conn.execute(
"SELECT COUNT(*) FROM models WHERE file_path = ?",
(normalized,),
).fetchone()[0]
assert remaining == 0