Files
aitrailer/src/core/config.py
T
2026-05-02 10:09:30 +02:00

388 lines
15 KiB
Python

"""
src/core/config.py — Configuration loader for AI Trailer Generator v2
Loads config.toml and exposes typed, nested dataclasses.
All CV thresholds, paths, and model settings are sourced exclusively here.
API keys are NEVER stored in config.toml; they are loaded from .env.
"""
from __future__ import annotations
import os
import tomllib
try:
from dotenv import load_dotenv as _load_dotenv
_HAS_DOTENV = True
except ImportError: # dotenv optional — falls back to existing env vars
_HAS_DOTENV = False
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal
# ---------------------------------------------------------------------------
# Leaf sections
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class PathsConfig:
    """Immutable bundle of the filesystem locations read from ``[paths]``.

    ``load_config`` builds this from the ``paths`` table of config.toml.
    Absolute entries are kept as written; relative entries are resolved
    against the directory that contains the config file.
    """

    source_movie: Path
    reference_trailer: Path
    output_dir: Path
    cache_dir: Path
    proxy_dir: Path
@dataclass(frozen=True)
class VideoConfig:
    """Immutable values taken from the ``[video]`` table of config.toml.

    All three keys are required; ``load_config`` coerces ``extract_fps`` to
    ``float`` and the two proxy dimensions to ``int``.
    """

    extract_fps: float
    proxy_width: int
    proxy_height: int
@dataclass(frozen=True)
class VibeCheckConfig:
    """Immutable values taken from the ``[cv.vibe_check]`` table.

    Every key is required in config.toml (``load_config`` indexes each one
    directly, without a fallback default).
    """

    top_k_candidates: int
    # NOTE(review): presumably a numeric comparison-method constant of the
    # CV library used by the consumer of this config — confirm there.
    hist_compare_method: int
    hist_bins_hue: int
    hist_bins_saturation: int
    phash_max_distance: int
    crop_top_fraction: float
    crop_bottom_fraction: float
@dataclass(frozen=True)
class DeepScanConfig:
    """Immutable values taken from the ``[cv.deep_scan]`` table.

    Only ``coarse_step_seconds``, ``match_threshold``, ``match_method`` and
    ``refine_step_seconds`` must be present in config.toml. Every other field
    is optional there: ``load_config`` supplies a built-in default when the
    key is absent (``coarse_candidate_threshold`` defaults to the configured
    ``match_threshold``).
    """

    coarse_step_seconds: float
    match_threshold: float
    provisional_match_threshold: float
    coarse_candidate_threshold: float

    sequence_score_weight: float
    span_score_weight: float
    coarse_score_weight: float
    duration_score_weight: float
    duration_tie_break_score_delta: float
    min_duration_coverage: float

    # Stored as a tuple so the frozen instance holds no mutable container.
    continuity_seed_offsets_s: tuple[float, ...]
    scene_seed_top_k: int
    scene_seed_points_per_scene: int
    content_rerank_candidate_count: int
    skip_coarse_scan_with_weighted_seeds: bool
    max_refine_candidates: int

    match_method: int
    refine_window_seconds: float
    refine_step_seconds: float

    content_align_window_seconds: float
    content_align_sample_step_s: float
    content_validation_weight: float
    provisional_content_threshold: float

    start_tie_break_score_delta: float
    start_preroll_frames: int
    sequence_candidate_count: int
    sequence_min_distance_s: float
    span_sample_step_s: float
    trim_tail_frames: int
    scene_boundary_epsilon_s: float

    scoreable_luma_mean_min: float
    scoreable_luma_p90_min: float
    scoreable_contrast_min: float
@dataclass(frozen=True)
class CVConfig:
    """Immutable grouping of the two ``[cv]`` sub-sections.

    Holds the ``[cv.vibe_check]`` settings and the ``[cv.deep_scan]``
    settings side by side.
    """

    vibe_check: VibeCheckConfig
    deep_scan: DeepScanConfig
@dataclass(frozen=True)
class SceneDetectionConfig:
    """Immutable values taken from the ``[scene_detection]`` table.

    Both keys are required in config.toml and are coerced to ``float`` by
    ``load_config``.
    """

    content_threshold: float
    min_scene_duration_s: float
@dataclass(frozen=True)
class WhisperConfig:
    """Immutable values taken from the ``[whisper]`` table.

    ``load_config`` copies the four values over unchanged, so the ``Literal``
    annotations describe the expected values for type checkers; they are not
    enforced when the config is loaded.
    """

    model: str
    language: str
    device: Literal["cuda", "cpu"]
    compute_type: Literal["float16", "int8", "float32"]
@dataclass(frozen=True)
class LLMConfig:
    """Immutable values taken from the ``[llm]`` table, plus the API key.

    Everything except ``api_key`` comes from config.toml. The key is filled
    in by ``load_config`` from environment variables and stays an empty
    string when none of them is set.
    """

    provider: Literal["ollama", "openai", "openrouter"]
    base_url: str
    model: str
    timeout_seconds: int
    temperature: float
    max_tokens: int

    # Loaded from .env — NEVER committed to version control
    api_key: str = ""
@dataclass(frozen=True)
class VisionConfig:
    """Immutable values taken from the optional ``[vision]`` table.

    The whole table may be omitted from config.toml: ``load_config`` supplies
    a default for every field. ``base_url``, ``model`` and ``timeout_seconds``
    default to the corresponding ``[llm]`` values, and ``enabled`` defaults to
    ``False``.
    """

    enabled: bool
    provider: Literal["openai", "openrouter"]
    base_url: str
    model: str
    timeout_seconds: int
    temperature: float
    max_tokens: int

    scene_candidate_top_k: int
    max_new_descriptions_per_run: int
    max_seed_scenes: int
    seed_points_per_scene: int
    seed_score: float
    max_refine_candidates: int

    local_scan_step_s: float
    local_scan_max_points_per_scene: int
    local_scan_top_candidates: int
    local_scan_tie_break_score_delta: float

    multi_shot_cut_corr_threshold: float
    multi_shot_boundary_tolerance_s: float

    fullscan_fallback: bool
    content_threshold: float
    similarity_threshold: float

    # Filled from environment variables by load_config; empty when unset.
    api_key: str = ""
@dataclass(frozen=True)
class ExportConfig:
    """Immutable values taken from the ``[export]`` table of config.toml.

    All three keys are required. ``output_format`` is copied over unchanged
    by ``load_config``; its ``Literal`` annotation is not enforced at load
    time.
    """

    fcpxml_version: str
    edl_frame_rate: float
    output_format: Literal["fcpxml", "edl", "both"]
# ---------------------------------------------------------------------------
# Root config — single object passed through the entire application
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class AppConfig:
    """Root configuration object returned by ``load_config``.

    The first three fields come from the ``[project]`` table (keys ``name``,
    ``version`` and ``log_level``); the remaining fields each hold the typed
    section object for one part of config.toml.
    """

    # [project]
    project_name: str
    version: str
    log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"]

    # One typed object per configuration section
    paths: PathsConfig
    video: VideoConfig
    cv: CVConfig
    scene_detection: SceneDetectionConfig
    whisper: WhisperConfig
    llm: LLMConfig
    vision: VisionConfig
    export: ExportConfig
# ---------------------------------------------------------------------------
# Loader
# ---------------------------------------------------------------------------
# Default file locations. Both live in the same directory:
# Path(__file__).parents[2], i.e. two levels above this module's own directory.
_DEFAULT_CONFIG_PATH = Path(__file__).parents[2].joinpath("config.toml")
_DEFAULT_ENV_PATH = _DEFAULT_CONFIG_PATH.with_name(".env")
def load_config(
    config_path: str | os.PathLike[str] = _DEFAULT_CONFIG_PATH,
    env_path: Path = _DEFAULT_ENV_PATH,
) -> AppConfig:
    """
    Parse config.toml and return a fully-typed, immutable AppConfig.

    API keys are read from the .env file (or existing environment variables);
    they are never stored in config.toml.

    The tables ``project``, ``paths``, ``video``, ``cv`` (with ``vibe_check``
    and ``deep_scan``), ``scene_detection``, ``whisper``, ``llm`` and
    ``export`` are required. The ``vision`` table is optional, as are most
    ``cv.deep_scan`` keys; built-in defaults are used for anything omitted.

    Args:
        config_path: Absolute or relative path to the TOML file. A ``Path``,
            ``str`` or other ``os.PathLike`` is accepted.
            Defaults to <project_root>/config.toml.
        env_path: Path to the .env file. It is only used when python-dotenv
            is installed; values already present in the environment are not
            overridden.
            Defaults to <project_root>/.env.

    Returns:
        The populated AppConfig.

    Raises:
        FileNotFoundError: If the TOML file does not exist.
        tomllib.TOMLDecodeError: If the file is not valid TOML.
        KeyError: If a required table or key is missing.
        TypeError / ValueError: If a value cannot be converted to the
            numeric type its field expects.
    """
    # Normalise once so plain strings work with the Path methods used below.
    config_path = Path(config_path)

    # Load .env first so os.environ is populated before we read it below.
    if _HAS_DOTENV:
        _load_dotenv(dotenv_path=env_path, override=False)

    if not config_path.exists():
        raise FileNotFoundError(
            f"Config file not found: {config_path}\n"
            "Copy config.toml.example to config.toml and adjust your paths."
        )

    with config_path.open("rb") as fh:
        raw: dict = tomllib.load(fh)

    project = raw["project"]
    paths_raw = raw["paths"]
    video_raw = raw["video"]
    cv_raw = raw["cv"]
    sd_raw = raw["scene_detection"]
    whisper_raw = raw["whisper"]
    llm_raw = raw["llm"]
    vision_raw = raw.get("vision", {})
    export_raw = raw["export"]

    # Resolve paths relative to the config file's parent directory so the
    # project is relocatable, but keep absolute paths as-is.
    def _resolve(p: str) -> Path:
        path = Path(p)
        return path if path.is_absolute() else (config_path.parent / path).resolve()

    paths = PathsConfig(
        source_movie=_resolve(paths_raw["source_movie"]),
        reference_trailer=_resolve(paths_raw["reference_trailer"]),
        output_dir=_resolve(paths_raw["output_dir"]),
        cache_dir=_resolve(paths_raw["cache_dir"]),
        proxy_dir=_resolve(paths_raw["proxy_dir"]),
    )

    video = VideoConfig(
        extract_fps=float(video_raw["extract_fps"]),
        proxy_width=int(video_raw["proxy_width"]),
        proxy_height=int(video_raw["proxy_height"]),
    )

    # Sub-tables are looked up once, right where they are first needed, so a
    # missing table is reported at the same point as before.
    vibe_raw = cv_raw["vibe_check"]
    vibe_check = VibeCheckConfig(
        top_k_candidates=int(vibe_raw["top_k_candidates"]),
        hist_compare_method=int(vibe_raw["hist_compare_method"]),
        hist_bins_hue=int(vibe_raw["hist_bins_hue"]),
        hist_bins_saturation=int(vibe_raw["hist_bins_saturation"]),
        phash_max_distance=int(vibe_raw["phash_max_distance"]),
        crop_top_fraction=float(vibe_raw["crop_top_fraction"]),
        crop_bottom_fraction=float(vibe_raw["crop_bottom_fraction"]),
    )

    deep_raw = cv_raw["deep_scan"]
    deep_scan = DeepScanConfig(
        coarse_step_seconds=float(deep_raw["coarse_step_seconds"]),
        match_threshold=float(deep_raw["match_threshold"]),
        provisional_match_threshold=float(deep_raw.get("provisional_match_threshold", 0.43)),
        # Falls back to the (required) match_threshold when not configured.
        coarse_candidate_threshold=float(
            deep_raw.get("coarse_candidate_threshold", deep_raw["match_threshold"])
        ),
        sequence_score_weight=float(deep_raw.get("sequence_score_weight", 0.55)),
        span_score_weight=float(deep_raw.get("span_score_weight", 0.15)),
        coarse_score_weight=float(deep_raw.get("coarse_score_weight", 0.10)),
        duration_score_weight=float(deep_raw.get("duration_score_weight", 0.20)),
        duration_tie_break_score_delta=float(deep_raw.get("duration_tie_break_score_delta", 0.03)),
        min_duration_coverage=float(deep_raw.get("min_duration_coverage", 0.65)),
        # Converted to a tuple so the frozen config holds no mutable list.
        continuity_seed_offsets_s=tuple(
            float(v)
            for v in deep_raw.get(
                "continuity_seed_offsets_s",
                [-1.0, 0.0, 0.5, 1.0, 1.5, 2.0, 3.0],
            )
        ),
        scene_seed_top_k=int(deep_raw.get("scene_seed_top_k", 30)),
        scene_seed_points_per_scene=int(deep_raw.get("scene_seed_points_per_scene", 6)),
        content_rerank_candidate_count=int(deep_raw.get("content_rerank_candidate_count", 100)),
        skip_coarse_scan_with_weighted_seeds=bool(
            deep_raw.get("skip_coarse_scan_with_weighted_seeds", False)
        ),
        max_refine_candidates=int(deep_raw.get("max_refine_candidates", 6)),
        match_method=int(deep_raw["match_method"]),
        refine_window_seconds=float(deep_raw.get("refine_window_seconds", 0.6)),
        refine_step_seconds=float(deep_raw["refine_step_seconds"]),
        content_align_window_seconds=float(deep_raw.get("content_align_window_seconds", 0.48)),
        content_align_sample_step_s=float(deep_raw.get("content_align_sample_step_s", 0.28)),
        content_validation_weight=float(deep_raw.get("content_validation_weight", 0.35)),
        provisional_content_threshold=float(deep_raw.get("provisional_content_threshold", 0.42)),
        start_tie_break_score_delta=float(deep_raw.get("start_tie_break_score_delta", 0.015)),
        start_preroll_frames=int(deep_raw.get("start_preroll_frames", 0)),
        sequence_candidate_count=int(deep_raw.get("sequence_candidate_count", 240)),
        sequence_min_distance_s=float(deep_raw.get("sequence_min_distance_s", 1.0)),
        span_sample_step_s=float(deep_raw.get("span_sample_step_s", 0.08)),
        trim_tail_frames=int(deep_raw.get("trim_tail_frames", 2)),
        scene_boundary_epsilon_s=float(deep_raw.get("scene_boundary_epsilon_s", 0.12)),
        scoreable_luma_mean_min=float(deep_raw.get("scoreable_luma_mean_min", 24.0)),
        scoreable_luma_p90_min=float(deep_raw.get("scoreable_luma_p90_min", 58.0)),
        scoreable_contrast_min=float(deep_raw.get("scoreable_contrast_min", 24.0)),
    )

    scene_detection = SceneDetectionConfig(
        content_threshold=float(sd_raw["content_threshold"]),
        min_scene_duration_s=float(sd_raw["min_scene_duration_s"]),
    )

    # Whisper values are passed through as-is (no coercion or validation).
    whisper = WhisperConfig(
        model=whisper_raw["model"],
        language=whisper_raw["language"],
        device=whisper_raw["device"],
        compute_type=whisper_raw["compute_type"],
    )

    # The LLM API key comes only from the environment (populated from .env
    # above); config.toml is never consulted for it.
    # Lookup order:
    #   1. provider-specific variable:
    #        OPENROUTER_API_KEY  → for provider = openrouter
    #        OPENAI_API_KEY      → for provider = openai
    #   2. LLM_API_KEY           → universal fallback (the only one used for
    #                              any other provider, e.g. ollama)
    # An unset or empty variable falls through to the next one.
    _provider = llm_raw["provider"]
    if _provider == "openrouter":
        _provider_key = os.environ.get("OPENROUTER_API_KEY", "")
    elif _provider == "openai":
        _provider_key = os.environ.get("OPENAI_API_KEY", "")
    else:
        _provider_key = ""
    _api_key = _provider_key or os.environ.get("LLM_API_KEY", "")

    llm = LLMConfig(
        provider=_provider,
        base_url=llm_raw["base_url"],
        model=llm_raw["model"],
        timeout_seconds=int(llm_raw["timeout_seconds"]),
        temperature=float(llm_raw["temperature"]),
        max_tokens=int(llm_raw["max_tokens"]),
        api_key=_api_key,
    )

    # Vision provider: explicit [vision].provider wins; otherwise reuse the
    # LLM provider when it is openai/openrouter, else default to openrouter.
    vision_provider = vision_raw.get(
        "provider",
        _provider if _provider in ("openai", "openrouter") else "openrouter",
    )

    # Vision API key lookup order:
    #   1. OPENROUTER_API_KEY when the vision provider is openrouter,
    #      OPENAI_API_KEY for any other vision provider
    #   2. VISION_API_KEY
    #   3. LLM_API_KEY
    if vision_provider == "openrouter":
        _vision_provider_key = os.environ.get("OPENROUTER_API_KEY", "")
    else:
        _vision_provider_key = os.environ.get("OPENAI_API_KEY", "")
    vision_api_key = (
        _vision_provider_key
        or os.environ.get("VISION_API_KEY", "")
        or os.environ.get("LLM_API_KEY", "")
    )

    vision = VisionConfig(
        enabled=bool(vision_raw.get("enabled", False)),
        provider=vision_provider,
        # Connection settings default to the [llm] values when not overridden.
        base_url=str(vision_raw.get("base_url", llm.base_url)),
        model=str(vision_raw.get("model", llm.model)),
        timeout_seconds=int(vision_raw.get("timeout_seconds", llm.timeout_seconds)),
        temperature=float(vision_raw.get("temperature", 0.0)),
        max_tokens=int(vision_raw.get("max_tokens", 350)),
        scene_candidate_top_k=int(vision_raw.get("scene_candidate_top_k", 8)),
        max_new_descriptions_per_run=int(vision_raw.get("max_new_descriptions_per_run", 12)),
        max_seed_scenes=int(vision_raw.get("max_seed_scenes", 3)),
        seed_points_per_scene=int(vision_raw.get("seed_points_per_scene", 12)),
        seed_score=float(vision_raw.get("seed_score", 0.88)),
        max_refine_candidates=int(vision_raw.get("max_refine_candidates", 6)),
        local_scan_step_s=float(vision_raw.get("local_scan_step_s", 0.12)),
        local_scan_max_points_per_scene=int(vision_raw.get("local_scan_max_points_per_scene", 180)),
        local_scan_top_candidates=int(vision_raw.get("local_scan_top_candidates", 18)),
        local_scan_tie_break_score_delta=float(vision_raw.get("local_scan_tie_break_score_delta", 0.08)),
        multi_shot_cut_corr_threshold=float(vision_raw.get("multi_shot_cut_corr_threshold", 0.20)),
        multi_shot_boundary_tolerance_s=float(vision_raw.get("multi_shot_boundary_tolerance_s", 0.20)),
        fullscan_fallback=bool(vision_raw.get("fullscan_fallback", False)),
        content_threshold=float(vision_raw.get("content_threshold", 0.22)),
        similarity_threshold=float(vision_raw.get("similarity_threshold", 0.18)),
        api_key=vision_api_key,
    )

    export = ExportConfig(
        fcpxml_version=str(export_raw["fcpxml_version"]),
        edl_frame_rate=float(export_raw["edl_frame_rate"]),
        output_format=export_raw["output_format"],
    )

    return AppConfig(
        project_name=project["name"],
        version=project["version"],
        log_level=project["log_level"],
        paths=paths,
        video=video,
        cv=CVConfig(vibe_check=vibe_check, deep_scan=deep_scan),
        scene_detection=scene_detection,
        whisper=whisper,
        llm=llm,
        vision=vision,
        export=export,
    )