checkpoint

This commit is contained in:
Will Miao
2025-04-10 09:08:36 +08:00
parent 64c9e4aeca
commit 8fdfb68741
12 changed files with 1397 additions and 484 deletions

View File

@@ -2,12 +2,12 @@ import logging
import os
import hashlib
import json
from typing import Dict, Optional
import time
from typing import Dict, Optional, Type
from .model_utils import determine_base_model
from .lora_metadata import extract_lora_metadata
from .models import LoraMetadata
from .lora_metadata import extract_lora_metadata, extract_checkpoint_metadata
from .models import BaseModelMetadata, LoraMetadata, CheckpointMetadata
logger = logging.getLogger(__name__)
@@ -15,7 +15,7 @@ async def calculate_sha256(file_path: str) -> str:
"""Calculate SHA256 hash of a file"""
sha256_hash = hashlib.sha256()
with open(file_path, "rb") as f:
for byte_block in iter(lambda: f.read(4096), b""):
for byte_block in iter(lambda: f.read(128 * 1024), b""):
sha256_hash.update(byte_block)
return sha256_hash.hexdigest()
@@ -42,8 +42,8 @@ def normalize_path(path: str) -> str:
"""Normalize file path to use forward slashes"""
return path.replace(os.sep, "/") if path else path
async def get_file_info(file_path: str) -> Optional[LoraMetadata]:
"""Get basic file information as LoraMetadata object"""
async def get_file_info(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]:
"""Get basic file information as a model metadata object"""
# First check if file actually exists and resolve symlinks
try:
real_path = os.path.realpath(file_path)
@@ -74,27 +74,52 @@ async def get_file_info(file_path: str) -> Optional[LoraMetadata]:
try:
# If we didn't get SHA256 from the .json file, calculate it
if not sha256:
start_time = time.time()
sha256 = await calculate_sha256(real_path)
logger.debug(f"Calculated SHA256 for {file_path} in {time.time() - start_time:.2f} seconds")
# Create default metadata based on model class
if model_class == CheckpointMetadata:
metadata = CheckpointMetadata(
file_name=base_name,
model_name=base_name,
file_path=normalize_path(file_path),
size=os.path.getsize(real_path),
modified=os.path.getmtime(real_path),
sha256=sha256,
base_model="Unknown", # Will be updated later
preview_url=normalize_path(preview_url),
tags=[],
modelDescription="",
model_type="checkpoint"
)
metadata = LoraMetadata(
file_name=base_name,
model_name=base_name,
file_path=normalize_path(file_path),
size=os.path.getsize(real_path),
modified=os.path.getmtime(real_path),
sha256=sha256,
base_model="Unknown", # Will be updated later
usage_tips="",
notes="",
from_civitai=True,
preview_url=normalize_path(preview_url),
tags=[],
modelDescription=""
)
# Extract checkpoint-specific metadata
# model_info = await extract_checkpoint_metadata(real_path)
# metadata.base_model = model_info['base_model']
# if 'model_type' in model_info:
# metadata.model_type = model_info['model_type']
else: # Default to LoraMetadata
metadata = LoraMetadata(
file_name=base_name,
model_name=base_name,
file_path=normalize_path(file_path),
size=os.path.getsize(real_path),
modified=os.path.getmtime(real_path),
sha256=sha256,
base_model="Unknown", # Will be updated later
usage_tips="{}",
preview_url=normalize_path(preview_url),
tags=[],
modelDescription=""
)
# Extract lora-specific metadata
model_info = await extract_lora_metadata(real_path)
metadata.base_model = model_info['base_model']
# create metadata file
base_model_info = await extract_lora_metadata(real_path)
metadata.base_model = base_model_info['base_model']
# Save metadata to file
await save_metadata(file_path, metadata)
return metadata
@@ -102,7 +127,7 @@ async def get_file_info(file_path: str) -> Optional[LoraMetadata]:
logger.error(f"Error getting file info for {file_path}: {e}")
return None
async def save_metadata(file_path: str, metadata: LoraMetadata) -> None:
async def save_metadata(file_path: str, metadata: BaseModelMetadata) -> None:
"""Save metadata to .metadata.json file"""
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
try:
@@ -115,7 +140,7 @@ async def save_metadata(file_path: str, metadata: LoraMetadata) -> None:
except Exception as e:
print(f"Error saving metadata to {metadata_path}: {str(e)}")
async def load_metadata(file_path: str) -> Optional[LoraMetadata]:
async def load_metadata(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]:
"""Load metadata from .metadata.json file"""
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
try:
@@ -162,12 +187,22 @@ async def load_metadata(file_path: str) -> Optional[LoraMetadata]:
if 'modelDescription' not in data:
data['modelDescription'] = ""
needs_update = True
# For checkpoint metadata
if model_class == CheckpointMetadata and 'model_type' not in data:
data['model_type'] = "checkpoint"
needs_update = True
# For lora metadata
if model_class == LoraMetadata and 'usage_tips' not in data:
data['usage_tips'] = "{}"
needs_update = True
if needs_update:
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(data, f, indent=2, ensure_ascii=False)
return LoraMetadata.from_dict(data)
return model_class.from_dict(data)
except Exception as e:
print(f"Error loading metadata from {metadata_path}: {str(e)}")

View File

@@ -1,6 +1,7 @@
from safetensors import safe_open
from typing import Dict
from .model_utils import determine_base_model
import os
async def extract_lora_metadata(file_path: str) -> Dict:
"""Extract essential metadata from safetensors file"""
@@ -13,4 +14,67 @@ async def extract_lora_metadata(file_path: str) -> Dict:
return {"base_model": base_model}
except Exception as e:
print(f"Error reading metadata from {file_path}: {str(e)}")
return {"base_model": "Unknown"}
return {"base_model": "Unknown"}
async def extract_checkpoint_metadata(file_path: str) -> dict:
"""Extract metadata from a checkpoint file to determine model type and base model"""
try:
# Analyze filename for clues about the model
filename = os.path.basename(file_path).lower()
model_info = {
'base_model': 'Unknown',
'model_type': 'checkpoint'
}
# Detect base model from filename
if 'xl' in filename or 'sdxl' in filename:
model_info['base_model'] = 'SDXL'
elif 'sd3' in filename:
model_info['base_model'] = 'SD3'
elif 'sd2' in filename or 'v2' in filename:
model_info['base_model'] = 'SD2.x'
elif 'sd1' in filename or 'v1' in filename:
model_info['base_model'] = 'SD1.5'
# Detect model type from filename
if 'inpaint' in filename:
model_info['model_type'] = 'inpainting'
elif 'anime' in filename:
model_info['model_type'] = 'anime'
elif 'realistic' in filename:
model_info['model_type'] = 'realistic'
# Try to peek at the safetensors file structure if available
if file_path.endswith('.safetensors'):
import json
import struct
with open(file_path, 'rb') as f:
header_size = struct.unpack('<Q', f.read(8))[0]
header_json = f.read(header_size)
header = json.loads(header_json)
# Look for specific keys to identify model type
metadata = header.get('__metadata__', {})
if metadata:
# Try to determine if it's SDXL
if any(key.startswith('conditioner.embedders.1') for key in header):
model_info['base_model'] = 'SDXL'
# Look for model type info
if metadata.get('modelspec.architecture') == 'SD-XL':
model_info['base_model'] = 'SDXL'
elif metadata.get('modelspec.architecture') == 'SD-3':
model_info['base_model'] = 'SD3'
# Check for specific use case
if metadata.get('modelspec.purpose') == 'inpainting':
model_info['model_type'] = 'inpainting'
return model_info
except Exception as e:
logger.error(f"Error extracting checkpoint metadata for {file_path}: {e}")
# Return default values
return {'base_model': 'Unknown', 'model_type': 'checkpoint'}

View File

@@ -5,20 +5,19 @@ import os
from .model_utils import determine_base_model
@dataclass
class LoraMetadata:
"""Represents the metadata structure for a Lora model"""
file_name: str # The filename without extension of the lora
model_name: str # The lora's name defined by the creator, initially same as file_name
file_path: str # Full path to the safetensors file
class BaseModelMetadata:
"""Base class for all model metadata structures"""
file_name: str # The filename without extension
model_name: str # The model's name defined by the creator
file_path: str # Full path to the model file
size: int # File size in bytes
modified: float # Last modified timestamp
sha256: str # SHA256 hash of the file
base_model: str # Base model (SD1.5/SD2.1/SDXL/etc.)
base_model: str # Base model type (SD1.5/SD2.1/SDXL/etc.)
preview_url: str # Preview image URL
preview_nsfw_level: int = 0 # NSFW level of the preview image
usage_tips: str = "{}" # Usage tips for the model, json string
notes: str = "" # Additional notes
from_civitai: bool = True # Whether the lora is from Civitai
from_civitai: bool = True # Whether from Civitai
civitai: Optional[Dict] = None # Civitai API data if available
tags: List[str] = None # Model tags
modelDescription: str = "" # Full model description
@@ -29,32 +28,11 @@ class LoraMetadata:
self.tags = []
@classmethod
def from_dict(cls, data: Dict) -> 'LoraMetadata':
"""Create LoraMetadata instance from dictionary"""
# Create a copy of the data to avoid modifying the input
def from_dict(cls, data: Dict) -> 'BaseModelMetadata':
"""Create instance from dictionary"""
data_copy = data.copy()
return cls(**data_copy)
@classmethod
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'LoraMetadata':
"""Create LoraMetadata instance from Civitai version info"""
file_name = file_info['name']
base_model = determine_base_model(version_info.get('baseModel', ''))
return cls(
file_name=os.path.splitext(file_name)[0],
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
file_path=save_path.replace(os.sep, '/'),
size=file_info.get('sizeKB', 0) * 1024,
modified=datetime.now().timestamp(),
sha256=file_info['hashes'].get('SHA256', '').lower(),
base_model=base_model,
preview_url=None, # Will be updated after preview download
preview_nsfw_level=0, # Will be updated after preview download, it is decided by the nsfw level of the preview image
from_civitai=True,
civitai=version_info
)
def to_dict(self) -> Dict:
"""Convert to dictionary for JSON serialization"""
return asdict(self)
@@ -76,30 +54,54 @@ class LoraMetadata:
self.file_path = file_path.replace(os.sep, '/')
@dataclass
class CheckpointMetadata:
"""Represents the metadata structure for a Checkpoint model"""
file_name: str # The filename without extension
model_name: str # The checkpoint's name defined by the creator
file_path: str # Full path to the model file
size: int # File size in bytes
modified: float # Last modified timestamp
sha256: str # SHA256 hash of the file
base_model: str # Base model type (SD1.5/SD2.1/SDXL/etc.)
preview_url: str # Preview image URL
preview_nsfw_level: int = 0 # NSFW level of the preview image
model_type: str = "checkpoint" # Model type (checkpoint, inpainting, etc.)
notes: str = "" # Additional notes
from_civitai: bool = True # Whether from Civitai
civitai: Optional[Dict] = None # Civitai API data if available
tags: List[str] = None # Model tags
modelDescription: str = "" # Full model description
# Additional checkpoint-specific fields
resolution: Optional[str] = None # Native resolution (e.g., 512x512, 1024x1024)
vae_included: bool = False # Whether VAE is included in the checkpoint
architecture: str = "" # Model architecture (if known)
def __post_init__(self):
if self.tags is None:
self.tags = []
class LoraMetadata(BaseModelMetadata):
"""Represents the metadata structure for a Lora model"""
usage_tips: str = "{}" # Usage tips for the model, json string
@classmethod
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'LoraMetadata':
"""Create LoraMetadata instance from Civitai version info"""
file_name = file_info['name']
base_model = determine_base_model(version_info.get('baseModel', ''))
return cls(
file_name=os.path.splitext(file_name)[0],
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
file_path=save_path.replace(os.sep, '/'),
size=file_info.get('sizeKB', 0) * 1024,
modified=datetime.now().timestamp(),
sha256=file_info['hashes'].get('SHA256', '').lower(),
base_model=base_model,
preview_url=None, # Will be updated after preview download
preview_nsfw_level=0, # Will be updated after preview download
from_civitai=True,
civitai=version_info
)
@dataclass
class CheckpointMetadata(BaseModelMetadata):
"""Represents the metadata structure for a Checkpoint model"""
model_type: str = "checkpoint" # Model type (checkpoint, inpainting, etc.)
@classmethod
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'CheckpointMetadata':
"""Create CheckpointMetadata instance from Civitai version info"""
file_name = file_info['name']
base_model = determine_base_model(version_info.get('baseModel', ''))
model_type = version_info.get('type', 'checkpoint')
return cls(
file_name=os.path.splitext(file_name)[0],
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
file_path=save_path.replace(os.sep, '/'),
size=file_info.get('sizeKB', 0) * 1024,
modified=datetime.now().timestamp(),
sha256=file_info['hashes'].get('SHA256', '').lower(),
base_model=base_model,
preview_url=None, # Will be updated after preview download
preview_nsfw_level=0,
from_civitai=True,
civitai=version_info,
model_type=model_type
)