mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-06-09 20:39:25 -03:00
feat(utils): add AutoV2 and AutoV3 hash calculation functions
This commit is contained in:
@@ -1,7 +1,10 @@
|
|||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import struct
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from .constants import (
|
from .constants import (
|
||||||
CARD_PREVIEW_WIDTH,
|
CARD_PREVIEW_WIDTH,
|
||||||
@@ -31,7 +34,7 @@ def _get_hash_chunk_size_bytes() -> int:
|
|||||||
|
|
||||||
|
|
||||||
async def calculate_sha256(file_path: str) -> str:
|
async def calculate_sha256(file_path: str) -> str:
|
||||||
"""Calculate SHA256 hash of a file"""
|
"""Calculate SHA256 hash of a file (full file content)."""
|
||||||
sha256_hash = hashlib.sha256()
|
sha256_hash = hashlib.sha256()
|
||||||
chunk_size = _get_hash_chunk_size_bytes()
|
chunk_size = _get_hash_chunk_size_bytes()
|
||||||
with open(file_path, "rb") as f:
|
with open(file_path, "rb") as f:
|
||||||
@@ -39,6 +42,79 @@ async def calculate_sha256(file_path: str) -> str:
|
|||||||
sha256_hash.update(byte_block)
|
sha256_hash.update(byte_block)
|
||||||
return sha256_hash.hexdigest()
|
return sha256_hash.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_autov2(file_path: str) -> str:
|
||||||
|
"""Calculate CivitAI AutoV2 hash.
|
||||||
|
|
||||||
|
AutoV2 is the first 10 characters of the full file SHA256.
|
||||||
|
Used by CivitAI as a shortened file identifier.
|
||||||
|
|
||||||
|
Reference: https://developer.civitai.com/site/reference/model-versions
|
||||||
|
"""
|
||||||
|
full_hash = hashlib.sha256()
|
||||||
|
chunk_size = _get_hash_chunk_size_bytes()
|
||||||
|
with open(file_path, "rb") as f:
|
||||||
|
for byte_block in iter(lambda: f.read(chunk_size), b""):
|
||||||
|
full_hash.update(byte_block)
|
||||||
|
return full_hash.hexdigest()[:10]
|
||||||
|
|
||||||
|
|
||||||
|
def read_safetensors_metadata(file_path: str) -> dict[str, Any]:
|
||||||
|
"""Read the ``__metadata__`` dict from a safetensors file header.
|
||||||
|
|
||||||
|
Safetensors file format:
|
||||||
|
- 8 bytes: header length (little-endian 64-bit)
|
||||||
|
- N bytes: UTF-8 JSON header
|
||||||
|
- The header JSON contains a ``__metadata__`` key holding arbitrary metadata.
|
||||||
|
|
||||||
|
Returns an empty dict if the file is not a valid safetensors file or has no
|
||||||
|
metadata.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with open(file_path, "rb") as f:
|
||||||
|
header_len_bytes = f.read(8)
|
||||||
|
if len(header_len_bytes) < 8:
|
||||||
|
return {}
|
||||||
|
header_len = struct.unpack("<Q", header_len_bytes)[0]
|
||||||
|
header_bytes = f.read(header_len)
|
||||||
|
if len(header_bytes) < header_len:
|
||||||
|
return {}
|
||||||
|
header = json.loads(header_bytes.decode("utf-8"))
|
||||||
|
return header.get("__metadata__", {})
|
||||||
|
except (OSError, json.JSONDecodeError, UnicodeDecodeError, struct.error):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_autov3(file_path: str) -> str | None:
|
||||||
|
"""Calculate CivitAI AutoV3 hash from a safetensors file.
|
||||||
|
|
||||||
|
AutoV3 is extracted from the safetensors file's embedded metadata, not
|
||||||
|
computed from the file bytes directly. The orchestrator reads the
|
||||||
|
``sshs_model_hash`` (kohya-ss format) or ``modelspec.hash_sha256`` field
|
||||||
|
from the safetensors header and stores the first 12 characters.
|
||||||
|
|
||||||
|
The embedded hash itself is the SHA256 of the file after skipping the
|
||||||
|
8-byte header length + JSON header (a.k.a. the addnet hash / tensor-only
|
||||||
|
hash).
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
- CivitAI DB trigger: ``SUBSTRING(NEW.hash FROM 1 FOR 12)``
|
||||||
|
- https://developer.civitai.com/site/reference/model-versions
|
||||||
|
|
||||||
|
Returns ``None`` when no AutoV3 hash can be determined (e.g. the file is
|
||||||
|
not safetensors, or the metadata doesn't contain a recognised hash field).
|
||||||
|
"""
|
||||||
|
metadata = read_safetensors_metadata(file_path)
|
||||||
|
if not metadata:
|
||||||
|
return None
|
||||||
|
|
||||||
|
embedded_hash = metadata.get("sshs_model_hash") or metadata.get("modelspec.hash_sha256")
|
||||||
|
if embedded_hash and isinstance(embedded_hash, str) and len(embedded_hash) >= 12:
|
||||||
|
return embedded_hash[:12]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def find_preview_file(base_name: str, dir_path: str) -> str:
|
def find_preview_file(base_name: str, dir_path: str) -> str:
|
||||||
"""Find preview file for given base name in directory.
|
"""Find preview file for given base name in directory.
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user