test: fix duplicate pytest import

This commit is contained in:
pixelpaws
2025-09-25 16:44:43 +08:00
parent 66abf736c9
commit fc9db4510f
6 changed files with 76 additions and 122 deletions

View File

@@ -1,13 +1,45 @@
import asyncio
import importlib.util
import inspect
import sys
import types import types
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence from typing import Any, Dict, List, Optional, Sequence
import asyncio
import inspect
from unittest import mock from unittest import mock
import sys
import pytest import pytest
REPO_ROOT = Path(__file__).resolve().parents[1]
PY_INIT = REPO_ROOT / "py" / "__init__.py"
def _load_repo_package(name: str) -> types.ModuleType:
"""Ensure the repository's ``py`` package is importable under *name*."""
module = sys.modules.get(name)
if module and getattr(module, "__file__", None) == str(PY_INIT):
return module
spec = importlib.util.spec_from_file_location(
name,
PY_INIT,
submodule_search_locations=[str(PY_INIT.parent)],
)
if spec is None or spec.loader is None: # pragma: no cover - initialization guard
raise ImportError(f"Unable to load repository package for alias '{name}'")
package = importlib.util.module_from_spec(spec)
spec.loader.exec_module(package) # type: ignore[attr-defined]
package.__path__ = [str(PY_INIT.parent)] # type: ignore[attr-defined]
sys.modules[name] = package
return package
_repo_package = _load_repo_package("py")
sys.modules.setdefault("py_local", _repo_package)
# Mock ComfyUI modules before any imports from the main project # Mock ComfyUI modules before any imports from the main project
server_mock = types.SimpleNamespace() server_mock = types.SimpleNamespace()
server_mock.PromptServer = mock.MagicMock() server_mock.PromptServer = mock.MagicMock()

View File

@@ -1,5 +1,4 @@
import asyncio import asyncio
import importlib.util
import json import json
import os import os
import sys import sys
@@ -14,25 +13,13 @@ import pytest
from aiohttp import FormData, web from aiohttp import FormData, web
from aiohttp.test_utils import TestClient, TestServer from aiohttp.test_utils import TestClient, TestServer
REPO_ROOT = Path(__file__).resolve().parents[2] from py.config import config
PY_PACKAGE_PATH = REPO_ROOT / "py" from py.routes.base_model_routes import BaseModelRoutes
from py.services import model_file_service
spec = importlib.util.spec_from_file_location( from py.services.model_file_service import AutoOrganizeResult
"py_local", from py.services.service_registry import ServiceRegistry
PY_PACKAGE_PATH / "__init__.py", from py.services.websocket_manager import ws_manager
submodule_search_locations=[str(PY_PACKAGE_PATH)], from py.utils.exif_utils import ExifUtils
)
py_local = importlib.util.module_from_spec(spec)
assert spec.loader is not None # for mypy/static analyzers
spec.loader.exec_module(py_local)
sys.modules.setdefault("py_local", py_local)
from py_local.routes.base_model_routes import BaseModelRoutes
from py_local.services.model_file_service import AutoOrganizeResult
from py_local.services.service_registry import ServiceRegistry
from py_local.services.websocket_manager import ws_manager
from py_local.utils.exif_utils import ExifUtils
from py_local.config import config
class DummyRoutes(BaseModelRoutes): class DummyRoutes(BaseModelRoutes):
@@ -345,7 +332,7 @@ def test_auto_organize_route_emits_progress(mock_service, monkeypatch: pytest.Mo
return result return result
monkeypatch.setattr( monkeypatch.setattr(
py_local.services.model_file_service.ModelFileService, model_file_service.ModelFileService,
"auto_organize_models", "auto_organize_models",
fake_auto_organize, fake_auto_organize,
) )

View File

@@ -8,41 +8,23 @@ logic.
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import importlib.util
import sys
import types import types
from collections import Counter from collections import Counter
from pathlib import Path
from typing import Any, Awaitable, Callable, Dict from typing import Any, Awaitable, Callable, Dict
import pytest import pytest
from aiohttp import web from aiohttp import web
from py.routes import base_recipe_routes, recipe_route_registrar, recipe_routes
REPO_ROOT = Path(__file__).resolve().parents[2] from py.services import service_registry
PY_PACKAGE_PATH = REPO_ROOT / "py" from py.services.server_i18n import server_i18n
spec = importlib.util.spec_from_file_location(
"py_local",
PY_PACKAGE_PATH / "__init__.py",
submodule_search_locations=[str(PY_PACKAGE_PATH)],
)
py_local = importlib.util.module_from_spec(spec)
assert spec.loader is not None
spec.loader.exec_module(py_local)
sys.modules.setdefault("py_local", py_local)
base_routes_module = importlib.import_module("py_local.routes.base_recipe_routes")
recipe_routes_module = importlib.import_module("py_local.routes.recipe_routes")
registrar_module = importlib.import_module("py_local.routes.recipe_route_registrar")
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def reset_service_registry(monkeypatch: pytest.MonkeyPatch): def reset_service_registry(monkeypatch: pytest.MonkeyPatch):
"""Ensure each test starts from a clean registry state.""" """Ensure each test starts from a clean registry state."""
services_module = importlib.import_module("py_local.services.service_registry") registry = service_registry.ServiceRegistry
registry = services_module.ServiceRegistry
previous_services = dict(registry._services) previous_services = dict(registry._services)
previous_locks = dict(registry._locks) previous_locks = dict(registry._locks)
registry._services.clear() registry._services.clear()
@@ -74,10 +56,7 @@ def _make_stub_scanner():
def test_attach_dependencies_resolves_services_once(monkeypatch: pytest.MonkeyPatch): def test_attach_dependencies_resolves_services_once(monkeypatch: pytest.MonkeyPatch):
base_module = base_routes_module registry = service_registry.ServiceRegistry
services_module = importlib.import_module("py_local.services.service_registry")
registry = services_module.ServiceRegistry
server_i18n = importlib.import_module("py_local.services.server_i18n").server_i18n
scanner = _make_stub_scanner() scanner = _make_stub_scanner()
civitai_client = object() civitai_client = object()
@@ -98,7 +77,7 @@ def test_attach_dependencies_resolves_services_once(monkeypatch: pytest.MonkeyPa
monkeypatch.setattr(server_i18n, "create_template_filter", fake_create_filter) monkeypatch.setattr(server_i18n, "create_template_filter", fake_create_filter)
async def scenario(): async def scenario():
routes = base_module.BaseRecipeRoutes() routes = base_recipe_routes.BaseRecipeRoutes()
await routes.attach_dependencies() await routes.attach_dependencies()
await routes.attach_dependencies() # idempotent await routes.attach_dependencies() # idempotent
@@ -113,7 +92,7 @@ def test_attach_dependencies_resolves_services_once(monkeypatch: pytest.MonkeyPa
def test_register_startup_hooks_appends_once(): def test_register_startup_hooks_appends_once():
routes = base_routes_module.BaseRecipeRoutes() routes = base_recipe_routes.BaseRecipeRoutes()
app = web.Application() app = web.Application()
routes.register_startup_hooks(app) routes.register_startup_hooks(app)
@@ -141,7 +120,7 @@ def test_to_route_mapping_uses_handler_set():
return {"render_page": render_page} return {"render_page": render_page}
class DummyRoutes(base_routes_module.BaseRecipeRoutes): class DummyRoutes(base_recipe_routes.BaseRecipeRoutes):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.created = 0 self.created = 0
@@ -184,11 +163,11 @@ def test_recipe_route_registrar_binds_every_route():
self.router = FakeRouter() self.router = FakeRouter()
app = FakeApp() app = FakeApp()
registrar = registrar_module.RecipeRouteRegistrar(app) registrar = recipe_route_registrar.RecipeRouteRegistrar(app)
handler_mapping = { handler_mapping = {
definition.handler_name: object() definition.handler_name: object()
for definition in registrar_module.ROUTE_DEFINITIONS for definition in recipe_route_registrar.ROUTE_DEFINITIONS
} }
registrar.register_routes(handler_mapping) registrar.register_routes(handler_mapping)
@@ -196,7 +175,7 @@ def test_recipe_route_registrar_binds_every_route():
assert { assert {
(method, path) (method, path)
for method, path, _ in app.router.calls for method, path, _ in app.router.calls
} == {(d.method, d.path) for d in registrar_module.ROUTE_DEFINITIONS} } == {(d.method, d.path) for d in recipe_route_registrar.ROUTE_DEFINITIONS}
def test_recipe_routes_setup_routes_uses_registrar(monkeypatch: pytest.MonkeyPatch): def test_recipe_routes_setup_routes_uses_registrar(monkeypatch: pytest.MonkeyPatch):
@@ -209,28 +188,28 @@ def test_recipe_routes_setup_routes_uses_registrar(monkeypatch: pytest.MonkeyPat
def register_routes(self, mapping): def register_routes(self, mapping):
registered_mappings.append(mapping) registered_mappings.append(mapping)
monkeypatch.setattr(recipe_routes_module, "RecipeRouteRegistrar", DummyRegistrar) monkeypatch.setattr(recipe_routes, "RecipeRouteRegistrar", DummyRegistrar)
expected_mapping = {name: object() for name in ("render_page", "list_recipes")} expected_mapping = {name: object() for name in ("render_page", "list_recipes")}
def fake_to_route_mapping(self): def fake_to_route_mapping(self):
return expected_mapping return expected_mapping
monkeypatch.setattr(base_routes_module.BaseRecipeRoutes, "to_route_mapping", fake_to_route_mapping) monkeypatch.setattr(base_recipe_routes.BaseRecipeRoutes, "to_route_mapping", fake_to_route_mapping)
monkeypatch.setattr( monkeypatch.setattr(
base_routes_module.BaseRecipeRoutes, base_recipe_routes.BaseRecipeRoutes,
"_HANDLER_NAMES", "_HANDLER_NAMES",
tuple(expected_mapping.keys()), tuple(expected_mapping.keys()),
) )
app = web.Application() app = web.Application()
recipe_routes_module.RecipeRoutes.setup_routes(app) recipe_routes.RecipeRoutes.setup_routes(app)
assert registered_mappings == [expected_mapping] assert registered_mappings == [expected_mapping]
recipe_callbacks = { recipe_callbacks = {
cb cb
for cb in app.on_startup for cb in app.on_startup
if isinstance(getattr(cb, "__self__", None), recipe_routes_module.RecipeRoutes) if isinstance(getattr(cb, "__self__", None), recipe_routes.RecipeRoutes)
} }
assert {type(cb.__self__) for cb in recipe_callbacks} == {recipe_routes_module.RecipeRoutes} assert {type(cb.__self__) for cb in recipe_callbacks} == {recipe_routes.RecipeRoutes}
assert {cb.__name__ for cb in recipe_callbacks} == {"attach_dependencies", "prewarm_cache"} assert {cb.__name__ for cb in recipe_callbacks} == {"attach_dependencies", "prewarm_cache"}

View File

@@ -1,35 +1,13 @@
import pytest import pytest
import importlib from py.services.base_model_service import BaseModelService
import importlib.util from py.services.model_query import (
import sys ModelCacheRepository,
from pathlib import Path ModelFilterSet,
SearchStrategy,
ROOT = Path(__file__).resolve().parents[2] SortParams,
if str(ROOT) not in sys.path: )
sys.path.insert(0, str(ROOT)) from py.utils.models import BaseModelMetadata
def import_from(module_name: str):
existing = sys.modules.get("py")
if existing is None or getattr(existing, "__file__", "") != str(ROOT / "py/__init__.py"):
sys.modules.pop("py", None)
spec = importlib.util.spec_from_file_location("py", ROOT / "py/__init__.py")
module = importlib.util.module_from_spec(spec)
assert spec and spec.loader
spec.loader.exec_module(module) # type: ignore[union-attr]
module.__path__ = [str(ROOT / "py")]
sys.modules["py"] = module
return importlib.import_module(module_name)
BaseModelService = import_from("py.services.base_model_service").BaseModelService
model_query_module = import_from("py.services.model_query")
ModelCacheRepository = model_query_module.ModelCacheRepository
ModelFilterSet = model_query_module.ModelFilterSet
SearchStrategy = model_query_module.SearchStrategy
SortParams = model_query_module.SortParams
BaseModelMetadata = import_from("py.utils.models").BaseModelMetadata
class StubSettings: class StubSettings:

View File

@@ -1,37 +1,15 @@
import asyncio import asyncio
import json import json
import os import os
import sys
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List from typing import Any, Dict, List
ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
import importlib
import importlib.util
import pytest import pytest
from py.services.download_coordinator import DownloadCoordinator
def import_from(module_name: str): from py.services.metadata_sync_service import MetadataSyncService
existing = sys.modules.get("py") from py.services.preview_asset_service import PreviewAssetService
if existing is None or getattr(existing, "__file__", "") != str(ROOT / "py/__init__.py"): from py.services.tag_update_service import TagUpdateService
sys.modules.pop("py", None)
spec = importlib.util.spec_from_file_location("py", ROOT / "py/__init__.py")
module = importlib.util.module_from_spec(spec)
assert spec and spec.loader
spec.loader.exec_module(module) # type: ignore[union-attr]
module.__path__ = [str(ROOT / "py")]
sys.modules["py"] = module
return importlib.import_module(module_name)
DownloadCoordinator = import_from("py.services.download_coordinator").DownloadCoordinator
MetadataSyncService = import_from("py.services.metadata_sync_service").MetadataSyncService
PreviewAssetService = import_from("py.services.preview_asset_service").PreviewAssetService
TagUpdateService = import_from("py.services.tag_update_service").TagUpdateService
class DummySettings: class DummySettings:

View File

@@ -5,8 +5,8 @@ from typing import Any, Dict, List, Optional
import pytest import pytest
from py_local.services.model_file_service import AutoOrganizeResult from py.services.model_file_service import AutoOrganizeResult
from py_local.services.use_cases import ( from py.services.use_cases import (
AutoOrganizeInProgressError, AutoOrganizeInProgressError,
AutoOrganizeUseCase, AutoOrganizeUseCase,
BulkMetadataRefreshUseCase, BulkMetadataRefreshUseCase,
@@ -19,12 +19,12 @@ from py_local.services.use_cases import (
ImportExampleImagesUseCase, ImportExampleImagesUseCase,
ImportExampleImagesValidationError, ImportExampleImagesValidationError,
) )
from py_local.utils.example_images_download_manager import ( from py.utils.example_images_download_manager import (
DownloadConfigurationError, DownloadConfigurationError,
DownloadInProgressError, DownloadInProgressError,
ExampleImagesDownloadError, ExampleImagesDownloadError,
) )
from py_local.utils.example_images_processor import ( from py.utils.example_images_processor import (
ExampleImagesImportError, ExampleImagesImportError,
ExampleImagesValidationError, ExampleImagesValidationError,
) )