mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-23 22:22:11 -03:00
Reorganize python files
This commit is contained in:
1
py/utils/__init__.py
Normal file
1
py/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Empty file to mark directory as Python package
|
||||
137
py/utils/file_utils.py
Normal file
137
py/utils/file_utils.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import logging
|
||||
import os
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Dict, Optional
|
||||
|
||||
from .lora_metadata import extract_lora_metadata
|
||||
from .models import LoraMetadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def calculate_sha256(file_path: str) -> str:
|
||||
"""Calculate SHA256 hash of a file"""
|
||||
sha256_hash = hashlib.sha256()
|
||||
with open(file_path, "rb") as f:
|
||||
for byte_block in iter(lambda: f.read(4096), b""):
|
||||
sha256_hash.update(byte_block)
|
||||
return sha256_hash.hexdigest()
|
||||
|
||||
def _find_preview_file(base_name: str, dir_path: str) -> str:
|
||||
"""Find preview file for given base name in directory"""
|
||||
preview_patterns = [
|
||||
f"{base_name}.preview.png",
|
||||
f"{base_name}.preview.jpg",
|
||||
f"{base_name}.preview.jpeg",
|
||||
f"{base_name}.preview.mp4",
|
||||
f"{base_name}.png",
|
||||
f"{base_name}.jpg",
|
||||
f"{base_name}.jpeg",
|
||||
f"{base_name}.mp4"
|
||||
]
|
||||
|
||||
for pattern in preview_patterns:
|
||||
full_pattern = os.path.join(dir_path, pattern)
|
||||
if os.path.exists(full_pattern):
|
||||
return full_pattern.replace(os.sep, "/")
|
||||
return ""
|
||||
|
||||
def normalize_path(path: str) -> str:
|
||||
"""Normalize file path to use forward slashes"""
|
||||
return path.replace(os.sep, "/") if path else path
|
||||
|
||||
async def get_file_info(file_path: str) -> Optional[LoraMetadata]:
|
||||
"""Get basic file information as LoraMetadata object"""
|
||||
# First check if file actually exists and resolve symlinks
|
||||
try:
|
||||
real_path = os.path.realpath(file_path)
|
||||
if not os.path.exists(real_path):
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking file existence for {file_path}: {e}")
|
||||
return None
|
||||
|
||||
base_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
dir_path = os.path.dirname(file_path)
|
||||
|
||||
preview_url = _find_preview_file(base_name, dir_path)
|
||||
|
||||
try:
|
||||
metadata = LoraMetadata(
|
||||
file_name=base_name,
|
||||
model_name=base_name,
|
||||
file_path=normalize_path(file_path),
|
||||
size=os.path.getsize(real_path),
|
||||
modified=os.path.getmtime(real_path),
|
||||
sha256=await calculate_sha256(real_path),
|
||||
base_model="Unknown", # Will be updated later
|
||||
usage_tips="",
|
||||
notes="",
|
||||
from_civitai=True,
|
||||
preview_url=normalize_path(preview_url),
|
||||
)
|
||||
|
||||
# create metadata file
|
||||
base_model_info = await extract_lora_metadata(real_path)
|
||||
metadata.base_model = base_model_info['base_model']
|
||||
await save_metadata(file_path, metadata)
|
||||
|
||||
return metadata
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting file info for {file_path}: {e}")
|
||||
return None
|
||||
|
||||
async def save_metadata(file_path: str, metadata: LoraMetadata) -> None:
|
||||
"""Save metadata to .metadata.json file"""
|
||||
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
|
||||
try:
|
||||
metadata_dict = metadata.to_dict()
|
||||
metadata_dict['file_path'] = normalize_path(metadata_dict['file_path'])
|
||||
metadata_dict['preview_url'] = normalize_path(metadata_dict['preview_url'])
|
||||
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata_dict, f, indent=2, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
print(f"Error saving metadata to {metadata_path}: {str(e)}")
|
||||
|
||||
async def load_metadata(file_path: str) -> Optional[LoraMetadata]:
|
||||
"""Load metadata from .metadata.json file"""
|
||||
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
|
||||
try:
|
||||
if os.path.exists(metadata_path):
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
needs_update = False
|
||||
|
||||
if data['file_path'] != normalize_path(data['file_path']):
|
||||
data['file_path'] = normalize_path(data['file_path'])
|
||||
needs_update = True
|
||||
|
||||
preview_url = data.get('preview_url', '')
|
||||
if not preview_url or not os.path.exists(preview_url):
|
||||
base_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
dir_path = os.path.dirname(file_path)
|
||||
new_preview_url = normalize_path(_find_preview_file(base_name, dir_path))
|
||||
if new_preview_url != preview_url:
|
||||
data['preview_url'] = new_preview_url
|
||||
needs_update = True
|
||||
elif preview_url != normalize_path(preview_url):
|
||||
data['preview_url'] = normalize_path(preview_url)
|
||||
needs_update = True
|
||||
|
||||
if needs_update:
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
return LoraMetadata.from_dict(data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading metadata from {metadata_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
async def update_civitai_metadata(file_path: str, civitai_data: Dict) -> None:
|
||||
"""Update metadata file with Civitai data"""
|
||||
metadata = await load_metadata(file_path)
|
||||
metadata['civitai'] = civitai_data
|
||||
await save_metadata(file_path, metadata)
|
||||
16
py/utils/lora_metadata.py
Normal file
16
py/utils/lora_metadata.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from safetensors import safe_open
|
||||
from typing import Dict
|
||||
from .model_utils import determine_base_model
|
||||
|
||||
async def extract_lora_metadata(file_path: str) -> Dict:
|
||||
"""Extract essential metadata from safetensors file"""
|
||||
try:
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata:
|
||||
# Only extract base_model from ss_base_model_version
|
||||
base_model = determine_base_model(metadata.get("ss_base_model_version"))
|
||||
return {"base_model": base_model}
|
||||
except Exception as e:
|
||||
print(f"Error reading metadata from {file_path}: {str(e)}")
|
||||
return {"base_model": "Unknown"}
|
||||
25
py/utils/model_utils.py
Normal file
25
py/utils/model_utils.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import Optional
|
||||
|
||||
# Base model mapping based on version string
|
||||
BASE_MODEL_MAPPING = {
|
||||
"sd-v1-5": "SD1.5",
|
||||
"sd-v2-1": "SD2.1",
|
||||
"sdxl": "SDXL",
|
||||
"sd-v2": "SD2.0",
|
||||
"flux1": "Flux.1 D",
|
||||
"flux.1 d": "Flux.1 D",
|
||||
"illustrious": "IL",
|
||||
"pony": "Pony"
|
||||
}
|
||||
|
||||
def determine_base_model(version_string: Optional[str]) -> str:
|
||||
"""Determine base model from version string in safetensors metadata"""
|
||||
if not version_string:
|
||||
return "Unknown"
|
||||
|
||||
version_lower = version_string.lower()
|
||||
for key, value in BASE_MODEL_MAPPING.items():
|
||||
if key in version_lower:
|
||||
return value
|
||||
|
||||
return "Unknown"
|
||||
68
py/utils/models.py
Normal file
68
py/utils/models.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import Dict, Optional
|
||||
from datetime import datetime
|
||||
import os
|
||||
from .model_utils import determine_base_model
|
||||
|
||||
@dataclass
|
||||
class LoraMetadata:
|
||||
"""Represents the metadata structure for a Lora model"""
|
||||
file_name: str # The filename without extension of the lora
|
||||
model_name: str # The lora's name defined by the creator, initially same as file_name
|
||||
file_path: str # Full path to the safetensors file
|
||||
size: int # File size in bytes
|
||||
modified: float # Last modified timestamp
|
||||
sha256: str # SHA256 hash of the file
|
||||
base_model: str # Base model (SD1.5/SD2.1/SDXL/etc.)
|
||||
preview_url: str # Preview image URL
|
||||
usage_tips: str = "{}" # Usage tips for the model, json string
|
||||
notes: str = "" # Additional notes
|
||||
from_civitai: bool = True # Whether the lora is from Civitai
|
||||
civitai: Optional[Dict] = None # Civitai API data if available
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> 'LoraMetadata':
|
||||
"""Create LoraMetadata instance from dictionary"""
|
||||
# Create a copy of the data to avoid modifying the input
|
||||
data_copy = data.copy()
|
||||
return cls(**data_copy)
|
||||
|
||||
@classmethod
|
||||
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'LoraMetadata':
|
||||
"""Create LoraMetadata instance from Civitai version info"""
|
||||
file_name = file_info['name']
|
||||
base_model = determine_base_model(version_info.get('baseModel', ''))
|
||||
|
||||
return cls(
|
||||
file_name=os.path.splitext(file_name)[0],
|
||||
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
|
||||
file_path=save_path.replace(os.sep, '/'),
|
||||
size=file_info.get('sizeKB', 0) * 1024,
|
||||
modified=datetime.now().timestamp(),
|
||||
sha256=file_info['hashes'].get('SHA256', ''),
|
||||
base_model=base_model,
|
||||
preview_url=None, # Will be updated after preview download
|
||||
from_civitai=True,
|
||||
civitai=version_info
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert to dictionary for JSON serialization"""
|
||||
return asdict(self)
|
||||
|
||||
@property
|
||||
def modified_datetime(self) -> datetime:
|
||||
"""Convert modified timestamp to datetime object"""
|
||||
return datetime.fromtimestamp(self.modified)
|
||||
|
||||
def update_civitai_info(self, civitai_data: Dict) -> None:
|
||||
"""Update Civitai information"""
|
||||
self.civitai = civitai_data
|
||||
|
||||
def update_file_info(self, file_path: str) -> None:
|
||||
"""Update metadata with actual file information"""
|
||||
if os.path.exists(file_path):
|
||||
self.size = os.path.getsize(file_path)
|
||||
self.modified = os.path.getmtime(file_path)
|
||||
self.file_path = file_path.replace(os.sep, '/')
|
||||
|
||||
Reference in New Issue
Block a user