mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-23 22:22:11 -03:00
checkpoint
This commit is contained in:
@@ -2,12 +2,12 @@ import logging
|
||||
import os
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Dict, Optional
|
||||
import time
|
||||
from typing import Dict, Optional, Type
|
||||
|
||||
from .model_utils import determine_base_model
|
||||
|
||||
from .lora_metadata import extract_lora_metadata
|
||||
from .models import LoraMetadata
|
||||
from .lora_metadata import extract_lora_metadata, extract_checkpoint_metadata
|
||||
from .models import BaseModelMetadata, LoraMetadata, CheckpointMetadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -15,7 +15,7 @@ async def calculate_sha256(file_path: str) -> str:
|
||||
"""Calculate SHA256 hash of a file"""
|
||||
sha256_hash = hashlib.sha256()
|
||||
with open(file_path, "rb") as f:
|
||||
for byte_block in iter(lambda: f.read(4096), b""):
|
||||
for byte_block in iter(lambda: f.read(128 * 1024), b""):
|
||||
sha256_hash.update(byte_block)
|
||||
return sha256_hash.hexdigest()
|
||||
|
||||
@@ -42,8 +42,8 @@ def normalize_path(path: str) -> str:
|
||||
"""Normalize file path to use forward slashes"""
|
||||
return path.replace(os.sep, "/") if path else path
|
||||
|
||||
async def get_file_info(file_path: str) -> Optional[LoraMetadata]:
|
||||
"""Get basic file information as LoraMetadata object"""
|
||||
async def get_file_info(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]:
|
||||
"""Get basic file information as a model metadata object"""
|
||||
# First check if file actually exists and resolve symlinks
|
||||
try:
|
||||
real_path = os.path.realpath(file_path)
|
||||
@@ -74,27 +74,52 @@ async def get_file_info(file_path: str) -> Optional[LoraMetadata]:
|
||||
try:
|
||||
# If we didn't get SHA256 from the .json file, calculate it
|
||||
if not sha256:
|
||||
start_time = time.time()
|
||||
sha256 = await calculate_sha256(real_path)
|
||||
logger.debug(f"Calculated SHA256 for {file_path} in {time.time() - start_time:.2f} seconds")
|
||||
|
||||
# Create default metadata based on model class
|
||||
if model_class == CheckpointMetadata:
|
||||
metadata = CheckpointMetadata(
|
||||
file_name=base_name,
|
||||
model_name=base_name,
|
||||
file_path=normalize_path(file_path),
|
||||
size=os.path.getsize(real_path),
|
||||
modified=os.path.getmtime(real_path),
|
||||
sha256=sha256,
|
||||
base_model="Unknown", # Will be updated later
|
||||
preview_url=normalize_path(preview_url),
|
||||
tags=[],
|
||||
modelDescription="",
|
||||
model_type="checkpoint"
|
||||
)
|
||||
|
||||
metadata = LoraMetadata(
|
||||
file_name=base_name,
|
||||
model_name=base_name,
|
||||
file_path=normalize_path(file_path),
|
||||
size=os.path.getsize(real_path),
|
||||
modified=os.path.getmtime(real_path),
|
||||
sha256=sha256,
|
||||
base_model="Unknown", # Will be updated later
|
||||
usage_tips="",
|
||||
notes="",
|
||||
from_civitai=True,
|
||||
preview_url=normalize_path(preview_url),
|
||||
tags=[],
|
||||
modelDescription=""
|
||||
)
|
||||
# Extract checkpoint-specific metadata
|
||||
# model_info = await extract_checkpoint_metadata(real_path)
|
||||
# metadata.base_model = model_info['base_model']
|
||||
# if 'model_type' in model_info:
|
||||
# metadata.model_type = model_info['model_type']
|
||||
|
||||
else: # Default to LoraMetadata
|
||||
metadata = LoraMetadata(
|
||||
file_name=base_name,
|
||||
model_name=base_name,
|
||||
file_path=normalize_path(file_path),
|
||||
size=os.path.getsize(real_path),
|
||||
modified=os.path.getmtime(real_path),
|
||||
sha256=sha256,
|
||||
base_model="Unknown", # Will be updated later
|
||||
usage_tips="{}",
|
||||
preview_url=normalize_path(preview_url),
|
||||
tags=[],
|
||||
modelDescription=""
|
||||
)
|
||||
|
||||
# Extract lora-specific metadata
|
||||
model_info = await extract_lora_metadata(real_path)
|
||||
metadata.base_model = model_info['base_model']
|
||||
|
||||
# create metadata file
|
||||
base_model_info = await extract_lora_metadata(real_path)
|
||||
metadata.base_model = base_model_info['base_model']
|
||||
# Save metadata to file
|
||||
await save_metadata(file_path, metadata)
|
||||
|
||||
return metadata
|
||||
@@ -102,7 +127,7 @@ async def get_file_info(file_path: str) -> Optional[LoraMetadata]:
|
||||
logger.error(f"Error getting file info for {file_path}: {e}")
|
||||
return None
|
||||
|
||||
async def save_metadata(file_path: str, metadata: LoraMetadata) -> None:
|
||||
async def save_metadata(file_path: str, metadata: BaseModelMetadata) -> None:
|
||||
"""Save metadata to .metadata.json file"""
|
||||
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
|
||||
try:
|
||||
@@ -115,7 +140,7 @@ async def save_metadata(file_path: str, metadata: LoraMetadata) -> None:
|
||||
except Exception as e:
|
||||
print(f"Error saving metadata to {metadata_path}: {str(e)}")
|
||||
|
||||
async def load_metadata(file_path: str) -> Optional[LoraMetadata]:
|
||||
async def load_metadata(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]:
|
||||
"""Load metadata from .metadata.json file"""
|
||||
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
|
||||
try:
|
||||
@@ -162,12 +187,22 @@ async def load_metadata(file_path: str) -> Optional[LoraMetadata]:
|
||||
if 'modelDescription' not in data:
|
||||
data['modelDescription'] = ""
|
||||
needs_update = True
|
||||
|
||||
# For checkpoint metadata
|
||||
if model_class == CheckpointMetadata and 'model_type' not in data:
|
||||
data['model_type'] = "checkpoint"
|
||||
needs_update = True
|
||||
|
||||
# For lora metadata
|
||||
if model_class == LoraMetadata and 'usage_tips' not in data:
|
||||
data['usage_tips'] = "{}"
|
||||
needs_update = True
|
||||
|
||||
if needs_update:
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
return LoraMetadata.from_dict(data)
|
||||
return model_class.from_dict(data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading metadata from {metadata_path}: {str(e)}")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from safetensors import safe_open
|
||||
from typing import Dict
|
||||
from .model_utils import determine_base_model
|
||||
import os
|
||||
|
||||
async def extract_lora_metadata(file_path: str) -> Dict:
|
||||
"""Extract essential metadata from safetensors file"""
|
||||
@@ -13,4 +14,67 @@ async def extract_lora_metadata(file_path: str) -> Dict:
|
||||
return {"base_model": base_model}
|
||||
except Exception as e:
|
||||
print(f"Error reading metadata from {file_path}: {str(e)}")
|
||||
return {"base_model": "Unknown"}
|
||||
return {"base_model": "Unknown"}
|
||||
|
||||
async def extract_checkpoint_metadata(file_path: str) -> dict:
|
||||
"""Extract metadata from a checkpoint file to determine model type and base model"""
|
||||
try:
|
||||
# Analyze filename for clues about the model
|
||||
filename = os.path.basename(file_path).lower()
|
||||
|
||||
model_info = {
|
||||
'base_model': 'Unknown',
|
||||
'model_type': 'checkpoint'
|
||||
}
|
||||
|
||||
# Detect base model from filename
|
||||
if 'xl' in filename or 'sdxl' in filename:
|
||||
model_info['base_model'] = 'SDXL'
|
||||
elif 'sd3' in filename:
|
||||
model_info['base_model'] = 'SD3'
|
||||
elif 'sd2' in filename or 'v2' in filename:
|
||||
model_info['base_model'] = 'SD2.x'
|
||||
elif 'sd1' in filename or 'v1' in filename:
|
||||
model_info['base_model'] = 'SD1.5'
|
||||
|
||||
# Detect model type from filename
|
||||
if 'inpaint' in filename:
|
||||
model_info['model_type'] = 'inpainting'
|
||||
elif 'anime' in filename:
|
||||
model_info['model_type'] = 'anime'
|
||||
elif 'realistic' in filename:
|
||||
model_info['model_type'] = 'realistic'
|
||||
|
||||
# Try to peek at the safetensors file structure if available
|
||||
if file_path.endswith('.safetensors'):
|
||||
import json
|
||||
import struct
|
||||
|
||||
with open(file_path, 'rb') as f:
|
||||
header_size = struct.unpack('<Q', f.read(8))[0]
|
||||
header_json = f.read(header_size)
|
||||
header = json.loads(header_json)
|
||||
|
||||
# Look for specific keys to identify model type
|
||||
metadata = header.get('__metadata__', {})
|
||||
if metadata:
|
||||
# Try to determine if it's SDXL
|
||||
if any(key.startswith('conditioner.embedders.1') for key in header):
|
||||
model_info['base_model'] = 'SDXL'
|
||||
|
||||
# Look for model type info
|
||||
if metadata.get('modelspec.architecture') == 'SD-XL':
|
||||
model_info['base_model'] = 'SDXL'
|
||||
elif metadata.get('modelspec.architecture') == 'SD-3':
|
||||
model_info['base_model'] = 'SD3'
|
||||
|
||||
# Check for specific use case
|
||||
if metadata.get('modelspec.purpose') == 'inpainting':
|
||||
model_info['model_type'] = 'inpainting'
|
||||
|
||||
return model_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting checkpoint metadata for {file_path}: {e}")
|
||||
# Return default values
|
||||
return {'base_model': 'Unknown', 'model_type': 'checkpoint'}
|
||||
@@ -5,20 +5,19 @@ import os
|
||||
from .model_utils import determine_base_model
|
||||
|
||||
@dataclass
|
||||
class LoraMetadata:
|
||||
"""Represents the metadata structure for a Lora model"""
|
||||
file_name: str # The filename without extension of the lora
|
||||
model_name: str # The lora's name defined by the creator, initially same as file_name
|
||||
file_path: str # Full path to the safetensors file
|
||||
class BaseModelMetadata:
|
||||
"""Base class for all model metadata structures"""
|
||||
file_name: str # The filename without extension
|
||||
model_name: str # The model's name defined by the creator
|
||||
file_path: str # Full path to the model file
|
||||
size: int # File size in bytes
|
||||
modified: float # Last modified timestamp
|
||||
sha256: str # SHA256 hash of the file
|
||||
base_model: str # Base model (SD1.5/SD2.1/SDXL/etc.)
|
||||
base_model: str # Base model type (SD1.5/SD2.1/SDXL/etc.)
|
||||
preview_url: str # Preview image URL
|
||||
preview_nsfw_level: int = 0 # NSFW level of the preview image
|
||||
usage_tips: str = "{}" # Usage tips for the model, json string
|
||||
notes: str = "" # Additional notes
|
||||
from_civitai: bool = True # Whether the lora is from Civitai
|
||||
from_civitai: bool = True # Whether from Civitai
|
||||
civitai: Optional[Dict] = None # Civitai API data if available
|
||||
tags: List[str] = None # Model tags
|
||||
modelDescription: str = "" # Full model description
|
||||
@@ -29,32 +28,11 @@ class LoraMetadata:
|
||||
self.tags = []
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> 'LoraMetadata':
|
||||
"""Create LoraMetadata instance from dictionary"""
|
||||
# Create a copy of the data to avoid modifying the input
|
||||
def from_dict(cls, data: Dict) -> 'BaseModelMetadata':
|
||||
"""Create instance from dictionary"""
|
||||
data_copy = data.copy()
|
||||
return cls(**data_copy)
|
||||
|
||||
@classmethod
|
||||
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'LoraMetadata':
|
||||
"""Create LoraMetadata instance from Civitai version info"""
|
||||
file_name = file_info['name']
|
||||
base_model = determine_base_model(version_info.get('baseModel', ''))
|
||||
|
||||
return cls(
|
||||
file_name=os.path.splitext(file_name)[0],
|
||||
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
|
||||
file_path=save_path.replace(os.sep, '/'),
|
||||
size=file_info.get('sizeKB', 0) * 1024,
|
||||
modified=datetime.now().timestamp(),
|
||||
sha256=file_info['hashes'].get('SHA256', '').lower(),
|
||||
base_model=base_model,
|
||||
preview_url=None, # Will be updated after preview download
|
||||
preview_nsfw_level=0, # Will be updated after preview download, it is decided by the nsfw level of the preview image
|
||||
from_civitai=True,
|
||||
civitai=version_info
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert to dictionary for JSON serialization"""
|
||||
return asdict(self)
|
||||
@@ -76,30 +54,54 @@ class LoraMetadata:
|
||||
self.file_path = file_path.replace(os.sep, '/')
|
||||
|
||||
@dataclass
|
||||
class CheckpointMetadata:
|
||||
"""Represents the metadata structure for a Checkpoint model"""
|
||||
file_name: str # The filename without extension
|
||||
model_name: str # The checkpoint's name defined by the creator
|
||||
file_path: str # Full path to the model file
|
||||
size: int # File size in bytes
|
||||
modified: float # Last modified timestamp
|
||||
sha256: str # SHA256 hash of the file
|
||||
base_model: str # Base model type (SD1.5/SD2.1/SDXL/etc.)
|
||||
preview_url: str # Preview image URL
|
||||
preview_nsfw_level: int = 0 # NSFW level of the preview image
|
||||
model_type: str = "checkpoint" # Model type (checkpoint, inpainting, etc.)
|
||||
notes: str = "" # Additional notes
|
||||
from_civitai: bool = True # Whether from Civitai
|
||||
civitai: Optional[Dict] = None # Civitai API data if available
|
||||
tags: List[str] = None # Model tags
|
||||
modelDescription: str = "" # Full model description
|
||||
|
||||
# Additional checkpoint-specific fields
|
||||
resolution: Optional[str] = None # Native resolution (e.g., 512x512, 1024x1024)
|
||||
vae_included: bool = False # Whether VAE is included in the checkpoint
|
||||
architecture: str = "" # Model architecture (if known)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tags is None:
|
||||
self.tags = []
|
||||
class LoraMetadata(BaseModelMetadata):
|
||||
"""Represents the metadata structure for a Lora model"""
|
||||
usage_tips: str = "{}" # Usage tips for the model, json string
|
||||
|
||||
@classmethod
|
||||
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'LoraMetadata':
|
||||
"""Create LoraMetadata instance from Civitai version info"""
|
||||
file_name = file_info['name']
|
||||
base_model = determine_base_model(version_info.get('baseModel', ''))
|
||||
|
||||
return cls(
|
||||
file_name=os.path.splitext(file_name)[0],
|
||||
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
|
||||
file_path=save_path.replace(os.sep, '/'),
|
||||
size=file_info.get('sizeKB', 0) * 1024,
|
||||
modified=datetime.now().timestamp(),
|
||||
sha256=file_info['hashes'].get('SHA256', '').lower(),
|
||||
base_model=base_model,
|
||||
preview_url=None, # Will be updated after preview download
|
||||
preview_nsfw_level=0, # Will be updated after preview download
|
||||
from_civitai=True,
|
||||
civitai=version_info
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class CheckpointMetadata(BaseModelMetadata):
|
||||
"""Represents the metadata structure for a Checkpoint model"""
|
||||
model_type: str = "checkpoint" # Model type (checkpoint, inpainting, etc.)
|
||||
|
||||
@classmethod
|
||||
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'CheckpointMetadata':
|
||||
"""Create CheckpointMetadata instance from Civitai version info"""
|
||||
file_name = file_info['name']
|
||||
base_model = determine_base_model(version_info.get('baseModel', ''))
|
||||
model_type = version_info.get('type', 'checkpoint')
|
||||
|
||||
return cls(
|
||||
file_name=os.path.splitext(file_name)[0],
|
||||
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
|
||||
file_path=save_path.replace(os.sep, '/'),
|
||||
size=file_info.get('sizeKB', 0) * 1024,
|
||||
modified=datetime.now().timestamp(),
|
||||
sha256=file_info['hashes'].get('SHA256', '').lower(),
|
||||
base_model=base_model,
|
||||
preview_url=None, # Will be updated after preview download
|
||||
preview_nsfw_level=0,
|
||||
from_civitai=True,
|
||||
civitai=version_info,
|
||||
model_type=model_type
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user