mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-23 06:02:11 -03:00
39 lines
1.4 KiB
Python
39 lines
1.4 KiB
Python
import logging
|
|
from typing import List
|
|
|
|
from ..utils.models import CheckpointMetadata
|
|
from ..config import config
|
|
from .model_scanner import ModelScanner
|
|
from .model_hash_index import ModelHashIndex
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class CheckpointScanner(ModelScanner):
|
|
"""Service for scanning and managing checkpoint files"""
|
|
|
|
def __init__(self):
|
|
# Define supported file extensions
|
|
file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
|
|
super().__init__(
|
|
model_type="checkpoint",
|
|
model_class=CheckpointMetadata,
|
|
file_extensions=file_extensions,
|
|
hash_index=ModelHashIndex()
|
|
)
|
|
|
|
def get_model_roots(self) -> List[str]:
|
|
"""Get checkpoint root directories"""
|
|
return config.base_models_roots
|
|
|
|
# Checkpoint-specific hash index functionality
|
|
def has_checkpoint_hash(self, sha256: str) -> bool:
|
|
"""Check if a checkpoint with given hash exists"""
|
|
return self.has_hash(sha256)
|
|
|
|
def get_checkpoint_path_by_hash(self, sha256: str) -> str:
|
|
"""Get file path for a checkpoint by its hash"""
|
|
return self.get_path_by_hash(sha256)
|
|
|
|
def get_checkpoint_hash_by_path(self, file_path: str) -> str:
|
|
"""Get hash for a checkpoint by its file path"""
|
|
return self.get_hash_by_path(file_path) |