checkpoint

This commit is contained in:
Will Miao
2025-04-10 09:08:36 +08:00
parent 64c9e4aeca
commit 8fdfb68741
12 changed files with 1397 additions and 484 deletions

View File

@@ -2,12 +2,12 @@ import logging
import os
import hashlib
import json
from typing import Dict, Optional
import time
from typing import Dict, Optional, Type
from .model_utils import determine_base_model
from .lora_metadata import extract_lora_metadata
from .models import LoraMetadata
from .lora_metadata import extract_lora_metadata, extract_checkpoint_metadata
from .models import BaseModelMetadata, LoraMetadata, CheckpointMetadata
logger = logging.getLogger(__name__)
@@ -15,7 +15,7 @@ async def calculate_sha256(file_path: str) -> str:
"""Calculate SHA256 hash of a file"""
sha256_hash = hashlib.sha256()
with open(file_path, "rb") as f:
for byte_block in iter(lambda: f.read(4096), b""):
for byte_block in iter(lambda: f.read(128 * 1024), b""):
sha256_hash.update(byte_block)
return sha256_hash.hexdigest()
@@ -42,8 +42,8 @@ def normalize_path(path: str) -> str:
"""Normalize file path to use forward slashes"""
return path.replace(os.sep, "/") if path else path
async def get_file_info(file_path: str) -> Optional[LoraMetadata]:
"""Get basic file information as LoraMetadata object"""
async def get_file_info(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]:
"""Get basic file information as a model metadata object"""
# First check if file actually exists and resolve symlinks
try:
real_path = os.path.realpath(file_path)
@@ -74,27 +74,52 @@ async def get_file_info(file_path: str) -> Optional[LoraMetadata]:
try:
# If we didn't get SHA256 from the .json file, calculate it
if not sha256:
start_time = time.time()
sha256 = await calculate_sha256(real_path)
logger.debug(f"Calculated SHA256 for {file_path} in {time.time() - start_time:.2f} seconds")
# Create default metadata based on model class
if model_class == CheckpointMetadata:
metadata = CheckpointMetadata(
file_name=base_name,
model_name=base_name,
file_path=normalize_path(file_path),
size=os.path.getsize(real_path),
modified=os.path.getmtime(real_path),
sha256=sha256,
base_model="Unknown", # Will be updated later
preview_url=normalize_path(preview_url),
tags=[],
modelDescription="",
model_type="checkpoint"
)
metadata = LoraMetadata(
file_name=base_name,
model_name=base_name,
file_path=normalize_path(file_path),
size=os.path.getsize(real_path),
modified=os.path.getmtime(real_path),
sha256=sha256,
base_model="Unknown", # Will be updated later
usage_tips="",
notes="",
from_civitai=True,
preview_url=normalize_path(preview_url),
tags=[],
modelDescription=""
)
# Extract checkpoint-specific metadata
# model_info = await extract_checkpoint_metadata(real_path)
# metadata.base_model = model_info['base_model']
# if 'model_type' in model_info:
# metadata.model_type = model_info['model_type']
else: # Default to LoraMetadata
metadata = LoraMetadata(
file_name=base_name,
model_name=base_name,
file_path=normalize_path(file_path),
size=os.path.getsize(real_path),
modified=os.path.getmtime(real_path),
sha256=sha256,
base_model="Unknown", # Will be updated later
usage_tips="{}",
preview_url=normalize_path(preview_url),
tags=[],
modelDescription=""
)
# Extract lora-specific metadata
model_info = await extract_lora_metadata(real_path)
metadata.base_model = model_info['base_model']
# create metadata file
base_model_info = await extract_lora_metadata(real_path)
metadata.base_model = base_model_info['base_model']
# Save metadata to file
await save_metadata(file_path, metadata)
return metadata
@@ -102,7 +127,7 @@ async def get_file_info(file_path: str) -> Optional[LoraMetadata]:
logger.error(f"Error getting file info for {file_path}: {e}")
return None
async def save_metadata(file_path: str, metadata: LoraMetadata) -> None:
async def save_metadata(file_path: str, metadata: BaseModelMetadata) -> None:
"""Save metadata to .metadata.json file"""
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
try:
@@ -115,7 +140,7 @@ async def save_metadata(file_path: str, metadata: LoraMetadata) -> None:
except Exception as e:
print(f"Error saving metadata to {metadata_path}: {str(e)}")
async def load_metadata(file_path: str) -> Optional[LoraMetadata]:
async def load_metadata(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]:
"""Load metadata from .metadata.json file"""
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
try:
@@ -162,12 +187,22 @@ async def load_metadata(file_path: str) -> Optional[LoraMetadata]:
if 'modelDescription' not in data:
data['modelDescription'] = ""
needs_update = True
# For checkpoint metadata
if model_class == CheckpointMetadata and 'model_type' not in data:
data['model_type'] = "checkpoint"
needs_update = True
# For lora metadata
if model_class == LoraMetadata and 'usage_tips' not in data:
data['usage_tips'] = "{}"
needs_update = True
if needs_update:
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(data, f, indent=2, ensure_ascii=False)
return LoraMetadata.from_dict(data)
return model_class.from_dict(data)
except Exception as e:
print(f"Error loading metadata from {metadata_path}: {str(e)}")