From d8e13de096181782f4518bba929ea9382c175330 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Mon, 4 Aug 2025 17:06:46 +0800 Subject: [PATCH] feat: enhance metadata adjustment in CheckpointScanner and ModelScanner for improved model type handling --- py/services/checkpoint_scanner.py | 8 ++++++++ py/services/model_scanner.py | 8 +++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/py/services/checkpoint_scanner.py b/py/services/checkpoint_scanner.py index d4696631..c81be7ef 100644 --- a/py/services/checkpoint_scanner.py +++ b/py/services/checkpoint_scanner.py @@ -21,6 +21,14 @@ class CheckpointScanner(ModelScanner): hash_index=ModelHashIndex() ) + def adjust_metadata(self, metadata, file_path, root_path): + if hasattr(metadata, "model_type"): + if root_path in config.checkpoints_roots: + metadata.model_type = "checkpoint" + elif root_path in config.unet_roots: + metadata.model_type = "diffusion_model" + return metadata + def get_model_roots(self) -> List[str]: """Get checkpoint root directories""" return config.base_models_roots \ No newline at end of file diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index 66961a62..bf58edde 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -604,7 +604,10 @@ class ModelScanner: return os.path.dirname(rel_path).replace(os.path.sep, '/') return '' - # Common methods shared between scanners + def adjust_metadata(self, metadata, file_path, root_path): + """Hook for subclasses: adjust metadata during scanning""" + return metadata + async def _process_model_file(self, file_path: str, root_path: str) -> Dict: """Process a single model file and return its metadata""" metadata = await MetadataManager.load_metadata(file_path, self.model_class) @@ -658,6 +661,9 @@ class ModelScanner: if metadata is None: metadata = await self._create_default_metadata(file_path) + # Hook: allow subclasses to adjust metadata + metadata = self.adjust_metadata(metadata, file_path, root_path) + model_data = metadata.to_dict() # Skip excluded models