"""
src/core/config.py — Configuration loader for AI Trailer Generator v2.

Loads config.toml and exposes typed, nested, immutable dataclasses.
All CV thresholds, paths, and model settings are sourced exclusively here.
API keys are NEVER stored in config.toml; they are loaded from .env
(or from variables already present in the environment).
"""

from __future__ import annotations

import os
import tomllib
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal

try:
    from dotenv import load_dotenv as _load_dotenv

    _HAS_DOTENV = True
except ImportError:  # dotenv optional — falls back to existing env vars
    _HAS_DOTENV = False


# ---------------------------------------------------------------------------
# Leaf sections — one frozen dataclass per TOML table
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class PathsConfig:
    """Filesystem locations taken from the ``[paths]`` table of config.toml.

    ``load_config`` turns every entry into a ``Path``: relative entries are
    resolved against the directory that holds config.toml, absolute entries
    are kept as written.
    """

    # Source media
    source_movie: Path
    reference_trailer: Path

    # Working / output directories
    output_dir: Path
    cache_dir: Path
    proxy_dir: Path


@dataclass(frozen=True)
class VideoConfig:
    """Frame-extraction and proxy settings from the ``[video]`` table.

    All three keys are required by ``load_config``.
    """

    extract_fps: float
    # Proxy frame dimensions (presumably pixels — confirm against the encoder).
    proxy_width: int
    proxy_height: int


@dataclass(frozen=True)
class VibeCheckConfig:
    """Settings for the "vibe check" CV stage, from ``[cv.vibe_check]``.

    Every key is required by ``load_config`` (no fallback defaults).
    """

    top_k_candidates: int
    # NOTE(review): presumably an OpenCV ``cv2.HISTCMP_*`` constant — confirm.
    hist_compare_method: int
    # Histogram bin counts for the hue and saturation channels.
    hist_bins_hue: int
    hist_bins_saturation: int
    # NOTE(review): presumably a Hamming-distance cutoff between perceptual
    # hashes — confirm in the consumer.
    phash_max_distance: int
    # Crop fractions; the exact cropping semantics live in the consumer.
    crop_top_fraction: float
    crop_bottom_fraction: float


@dataclass(frozen=True)
class DeepScanConfig:
    """Settings for the deep-scan matching stage, from ``[cv.deep_scan]``.

    ``load_config`` requires ``coarse_step_seconds``, ``match_threshold``,
    ``match_method`` and ``refine_step_seconds``; every other field has a
    fallback default there. Fields suffixed ``_s`` / ``_seconds`` are
    presumably durations in seconds (confirm against the consumer).
    Field order is part of the positional constructor — do not reorder.
    """

    # Coarse pass and acceptance thresholds
    coarse_step_seconds: float
    match_threshold: float
    provisional_match_threshold: float
    coarse_candidate_threshold: float

    # Candidate-score blending weights
    sequence_score_weight: float
    span_score_weight: float
    coarse_score_weight: float
    duration_score_weight: float
    duration_tie_break_score_delta: float
    min_duration_coverage: float

    # Seeding
    continuity_seed_offsets_s: tuple[float, ...]
    scene_seed_top_k: int
    scene_seed_points_per_scene: int
    content_rerank_candidate_count: int
    skip_coarse_scan_with_weighted_seeds: bool

    # Refinement
    max_refine_candidates: int
    # NOTE(review): presumably an OpenCV ``cv2.TM_*`` constant — confirm.
    match_method: int
    refine_window_seconds: float
    refine_step_seconds: float

    # Content alignment / validation
    content_align_window_seconds: float
    content_align_sample_step_s: float
    content_validation_weight: float
    provisional_content_threshold: float

    # Start placement, sequence sampling and trimming
    start_tie_break_score_delta: float
    start_preroll_frames: int
    sequence_candidate_count: int
    sequence_min_distance_s: float
    span_sample_step_s: float
    trim_tail_frames: int
    scene_boundary_epsilon_s: float

    # Minimum frame statistics for a frame to be scored
    scoreable_luma_mean_min: float
    scoreable_luma_p90_min: float
    scoreable_contrast_min: float


@dataclass(frozen=True)
class CVConfig:
    """Container for the two computer-vision stages under the ``[cv]`` table."""

    vibe_check: VibeCheckConfig  # from [cv.vibe_check]
    deep_scan: DeepScanConfig  # from [cv.deep_scan]


@dataclass(frozen=True)
class SceneDetectionConfig:
    """Scene-cut detection settings from ``[scene_detection]``.

    Both keys are required by ``load_config``.
    """

    # NOTE(review): presumably the threshold handed to a content-based scene
    # detector — confirm in the consumer.
    content_threshold: float
    min_scene_duration_s: float


@dataclass(frozen=True)
class WhisperConfig:
    """Speech-to-text settings from the ``[whisper]`` table.

    The ``Literal`` annotations document the accepted values but are not
    enforced at runtime; ``load_config`` passes the TOML values through.
    """

    model: str
    language: str
    device: Literal["cuda", "cpu"]
    compute_type: Literal["float16", "int8", "float32"]


@dataclass(frozen=True)
class LLMConfig:
    """Text-LLM client settings from the ``[llm]`` table.

    ``api_key`` is never read from config.toml; ``load_config`` fills it from
    environment variables (optionally populated from .env).
    """

    provider: Literal["ollama", "openai", "openrouter"]
    base_url: str
    model: str
    timeout_seconds: int
    temperature: float
    max_tokens: int
    # Secret — sourced from the environment only, never committed to VCS.
    api_key: str = ""


@dataclass(frozen=True)
class VisionConfig:
    """Vision-model settings from the optional ``[vision]`` table.

    When keys are absent, ``load_config`` falls back to the ``[llm]``
    base_url, model and timeout, and to hard-coded defaults for the rest.
    ``api_key`` comes from environment variables only.
    Field order is part of the positional constructor — do not reorder.
    """

    enabled: bool

    # Client / request settings
    provider: Literal["openai", "openrouter"]
    base_url: str
    model: str
    timeout_seconds: int
    temperature: float
    max_tokens: int

    # Scene description budget and seeding
    scene_candidate_top_k: int
    max_new_descriptions_per_run: int
    max_seed_scenes: int
    seed_points_per_scene: int
    seed_score: float
    max_refine_candidates: int

    # Local scan around seeded scenes
    local_scan_step_s: float
    local_scan_max_points_per_scene: int
    local_scan_top_candidates: int
    local_scan_tie_break_score_delta: float

    # Multi-shot handling
    multi_shot_cut_corr_threshold: float
    multi_shot_boundary_tolerance_s: float

    fullscan_fallback: bool
    content_threshold: float
    similarity_threshold: float

    # Secret — sourced from the environment only.
    api_key: str = ""


@dataclass(frozen=True)
class ExportConfig:
    """Timeline export settings from the ``[export]`` table.

    All three keys are required by ``load_config``; ``output_format`` is not
    validated against the ``Literal`` choices at runtime.
    """

    fcpxml_version: str
    edl_frame_rate: float
    output_format: Literal["fcpxml", "edl", "both"]


# ---------------------------------------------------------------------------
# Root config — single object passed through the entire application
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class AppConfig:
    """Root configuration object assembled by ``load_config``.

    The three scalar fields come from the ``[project]`` table; the rest are
    the per-section dataclasses defined above.
    """

    # [project]
    project_name: str
    version: str
    log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"]

    # Nested sections
    paths: PathsConfig
    video: VideoConfig
    cv: CVConfig
    scene_detection: SceneDetectionConfig
    whisper: WhisperConfig
    llm: LLMConfig
    vision: VisionConfig
    export: ExportConfig


# ---------------------------------------------------------------------------
# Loader
# ---------------------------------------------------------------------------


# This module lives at <project_root>/src/core/config.py, so the project root
# is the third ancestor of this file (core -> src -> root).
_PROJECT_ROOT = Path(__file__).parents[2]
_DEFAULT_CONFIG_PATH = _PROJECT_ROOT / "config.toml"
_DEFAULT_ENV_PATH = _PROJECT_ROOT / ".env"


# Provider-specific API-key environment variables, consulted before any
# generic fallback variable.
_PROVIDER_KEY_ENV_VARS: dict[str, str] = {
    "openrouter": "OPENROUTER_API_KEY",
    "openai": "OPENAI_API_KEY",
}


def _resolve_path(raw: str, base_dir: Path) -> Path:
    """Return *raw* as a Path, expanding ``~`` and anchoring relative paths.

    Absolute paths (after ``~`` expansion) are returned unchanged; relative
    ones are resolved against *base_dir* (the config file's directory) so the
    project stays relocatable.
    """
    path = Path(raw).expanduser()
    return path if path.is_absolute() else (base_dir / path).resolve()


def _as_bool(value: object, key: str) -> bool:
    """Interpret a TOML flag value as a bool.

    ``bool("false")`` is ``True`` in Python, so a flag accidentally quoted in
    config.toml would silently enable a feature. Native booleans and integers
    behave as before; strings are parsed explicitly.

    Raises:
        TypeError: If *value* is not a bool, int or str.
        ValueError: If *value* is a str that is not a recognised boolean word.
    """
    if isinstance(value, bool):
        return value
    if isinstance(value, int):
        return value != 0
    if isinstance(value, str):
        word = value.strip().lower()
        if word in {"true", "yes", "on", "1"}:
            return True
        if word in {"false", "no", "off", "0", ""}:
            return False
        raise ValueError(f"{key}: cannot interpret {value!r} as a boolean")
    raise TypeError(f"{key}: expected a boolean, got {type(value).__name__}")


def _resolve_api_key(provider: str, *fallback_vars: str) -> str:
    """Return the first non-empty API key found in the environment.

    The provider-specific variable (see ``_PROVIDER_KEY_ENV_VARS``) is checked
    first when *provider* has one, then *fallback_vars* in the given order.
    Returns ``""`` when none of them is set.
    """
    provider_var = _PROVIDER_KEY_ENV_VARS.get(provider)
    candidates = (provider_var, *fallback_vars) if provider_var else fallback_vars
    for name in candidates:
        value = os.environ.get(name, "")
        if value:
            return value
    return ""


def _build_deep_scan_config(ds: dict[str, Any]) -> DeepScanConfig:
    """Build DeepScanConfig from the ``[cv.deep_scan]`` table.

    ``coarse_step_seconds``, ``match_threshold``, ``match_method`` and
    ``refine_step_seconds`` are required; every other key has a default.
    """
    return DeepScanConfig(
        coarse_step_seconds=float(ds["coarse_step_seconds"]),
        match_threshold=float(ds["match_threshold"]),
        provisional_match_threshold=float(ds.get("provisional_match_threshold", 0.43)),
        # Defaults to the (required) match threshold when not set explicitly.
        coarse_candidate_threshold=float(ds.get("coarse_candidate_threshold", ds["match_threshold"])),
        sequence_score_weight=float(ds.get("sequence_score_weight", 0.55)),
        span_score_weight=float(ds.get("span_score_weight", 0.15)),
        coarse_score_weight=float(ds.get("coarse_score_weight", 0.10)),
        duration_score_weight=float(ds.get("duration_score_weight", 0.20)),
        duration_tie_break_score_delta=float(ds.get("duration_tie_break_score_delta", 0.03)),
        min_duration_coverage=float(ds.get("min_duration_coverage", 0.65)),
        continuity_seed_offsets_s=tuple(
            float(v)
            for v in ds.get("continuity_seed_offsets_s", [-1.0, 0.0, 0.5, 1.0, 1.5, 2.0, 3.0])
        ),
        scene_seed_top_k=int(ds.get("scene_seed_top_k", 30)),
        scene_seed_points_per_scene=int(ds.get("scene_seed_points_per_scene", 6)),
        content_rerank_candidate_count=int(ds.get("content_rerank_candidate_count", 100)),
        skip_coarse_scan_with_weighted_seeds=_as_bool(
            ds.get("skip_coarse_scan_with_weighted_seeds", False),
            "cv.deep_scan.skip_coarse_scan_with_weighted_seeds",
        ),
        max_refine_candidates=int(ds.get("max_refine_candidates", 6)),
        match_method=int(ds["match_method"]),
        refine_window_seconds=float(ds.get("refine_window_seconds", 0.6)),
        refine_step_seconds=float(ds["refine_step_seconds"]),
        content_align_window_seconds=float(ds.get("content_align_window_seconds", 0.48)),
        content_align_sample_step_s=float(ds.get("content_align_sample_step_s", 0.28)),
        content_validation_weight=float(ds.get("content_validation_weight", 0.35)),
        provisional_content_threshold=float(ds.get("provisional_content_threshold", 0.42)),
        start_tie_break_score_delta=float(ds.get("start_tie_break_score_delta", 0.015)),
        start_preroll_frames=int(ds.get("start_preroll_frames", 0)),
        sequence_candidate_count=int(ds.get("sequence_candidate_count", 240)),
        sequence_min_distance_s=float(ds.get("sequence_min_distance_s", 1.0)),
        span_sample_step_s=float(ds.get("span_sample_step_s", 0.08)),
        trim_tail_frames=int(ds.get("trim_tail_frames", 2)),
        scene_boundary_epsilon_s=float(ds.get("scene_boundary_epsilon_s", 0.12)),
        scoreable_luma_mean_min=float(ds.get("scoreable_luma_mean_min", 24.0)),
        scoreable_luma_p90_min=float(ds.get("scoreable_luma_p90_min", 58.0)),
        scoreable_contrast_min=float(ds.get("scoreable_contrast_min", 24.0)),
    )


def _build_vision_config(vision_raw: dict[str, Any], llm: LLMConfig) -> VisionConfig:
    """Build VisionConfig from the optional ``[vision]`` table.

    Missing client settings fall back to the already-built *llm* section.
    The provider defaults to the LLM provider when that is "openai" or
    "openrouter", otherwise to "openrouter".
    """
    provider = vision_raw.get(
        "provider",
        llm.provider if llm.provider in ("openai", "openrouter") else "openrouter",
    )
    return VisionConfig(
        enabled=_as_bool(vision_raw.get("enabled", False), "vision.enabled"),
        provider=provider,
        base_url=str(vision_raw.get("base_url", llm.base_url)),
        model=str(vision_raw.get("model", llm.model)),
        timeout_seconds=int(vision_raw.get("timeout_seconds", llm.timeout_seconds)),
        temperature=float(vision_raw.get("temperature", 0.0)),
        max_tokens=int(vision_raw.get("max_tokens", 350)),
        scene_candidate_top_k=int(vision_raw.get("scene_candidate_top_k", 8)),
        max_new_descriptions_per_run=int(vision_raw.get("max_new_descriptions_per_run", 12)),
        max_seed_scenes=int(vision_raw.get("max_seed_scenes", 3)),
        seed_points_per_scene=int(vision_raw.get("seed_points_per_scene", 12)),
        seed_score=float(vision_raw.get("seed_score", 0.88)),
        max_refine_candidates=int(vision_raw.get("max_refine_candidates", 6)),
        local_scan_step_s=float(vision_raw.get("local_scan_step_s", 0.12)),
        local_scan_max_points_per_scene=int(vision_raw.get("local_scan_max_points_per_scene", 180)),
        local_scan_top_candidates=int(vision_raw.get("local_scan_top_candidates", 18)),
        local_scan_tie_break_score_delta=float(vision_raw.get("local_scan_tie_break_score_delta", 0.08)),
        multi_shot_cut_corr_threshold=float(vision_raw.get("multi_shot_cut_corr_threshold", 0.20)),
        multi_shot_boundary_tolerance_s=float(vision_raw.get("multi_shot_boundary_tolerance_s", 0.20)),
        fullscan_fallback=_as_bool(vision_raw.get("fullscan_fallback", False), "vision.fullscan_fallback"),
        content_threshold=float(vision_raw.get("content_threshold", 0.22)),
        similarity_threshold=float(vision_raw.get("similarity_threshold", 0.18)),
        # Provider-specific key first, then VISION_API_KEY, then LLM_API_KEY.
        api_key=_resolve_api_key(provider, "VISION_API_KEY", "LLM_API_KEY"),
    )


def load_config(
    config_path: str | os.PathLike[str] = _DEFAULT_CONFIG_PATH,
    env_path: str | os.PathLike[str] = _DEFAULT_ENV_PATH,
) -> AppConfig:
    """Parse config.toml and return a fully-typed, immutable AppConfig.

    API keys are read from the .env file (when python-dotenv is installed)
    or from variables already present in the environment; they are never
    stored in config.toml. Variables already set in the environment are not
    overridden by .env.

    Relative entries in ``[paths]`` are resolved against the directory that
    contains the config file; a leading ``~`` is expanded.

    Args:
        config_path: Path to the TOML file (``str`` or path-like).
            Defaults to <project_root>/config.toml.
        env_path: Path to the .env file (``str`` or path-like).
            Defaults to <project_root>/.env.

    Returns:
        The assembled AppConfig.

    Raises:
        FileNotFoundError: If the TOML file does not exist.
        tomllib.TOMLDecodeError: If the TOML file is malformed.
        KeyError: If a required table or key is missing.
        TypeError / ValueError: If a value cannot be converted to the
            expected type (e.g. a non-numeric threshold, or a flag that is
            not a recognisable boolean).
    """
    # Accept plain strings as well as Path objects.
    config_path = Path(config_path)
    env_path = Path(env_path)

    # Load .env first so os.environ is populated before API keys are read.
    # override=False: variables already in the environment win over .env.
    if _HAS_DOTENV:
        _load_dotenv(dotenv_path=env_path, override=False)

    if not config_path.exists():
        raise FileNotFoundError(
            f"Config file not found: {config_path}\n"
            "Copy config.toml.example to config.toml and adjust your paths."
        )

    with config_path.open("rb") as fh:
        raw: dict[str, Any] = tomllib.load(fh)

    # Fetch every required table up front so a missing one fails fast.
    project = raw["project"]
    paths_raw = raw["paths"]
    video_raw = raw["video"]
    cv_raw = raw["cv"]
    sd_raw = raw["scene_detection"]
    whisper_raw = raw["whisper"]
    llm_raw = raw["llm"]
    vision_raw = raw.get("vision", {})  # the [vision] table is optional
    export_raw = raw["export"]

    base_dir = config_path.parent
    paths = PathsConfig(
        source_movie=_resolve_path(paths_raw["source_movie"], base_dir),
        reference_trailer=_resolve_path(paths_raw["reference_trailer"], base_dir),
        output_dir=_resolve_path(paths_raw["output_dir"], base_dir),
        cache_dir=_resolve_path(paths_raw["cache_dir"], base_dir),
        proxy_dir=_resolve_path(paths_raw["proxy_dir"], base_dir),
    )

    video = VideoConfig(
        extract_fps=float(video_raw["extract_fps"]),
        proxy_width=int(video_raw["proxy_width"]),
        proxy_height=int(video_raw["proxy_height"]),
    )

    vc = cv_raw["vibe_check"]
    vibe_check = VibeCheckConfig(
        top_k_candidates=int(vc["top_k_candidates"]),
        hist_compare_method=int(vc["hist_compare_method"]),
        hist_bins_hue=int(vc["hist_bins_hue"]),
        hist_bins_saturation=int(vc["hist_bins_saturation"]),
        phash_max_distance=int(vc["phash_max_distance"]),
        crop_top_fraction=float(vc["crop_top_fraction"]),
        crop_bottom_fraction=float(vc["crop_bottom_fraction"]),
    )

    deep_scan = _build_deep_scan_config(cv_raw["deep_scan"])

    scene_detection = SceneDetectionConfig(
        content_threshold=float(sd_raw["content_threshold"]),
        min_scene_duration_s=float(sd_raw["min_scene_duration_s"]),
    )

    whisper = WhisperConfig(
        model=str(whisper_raw["model"]),
        language=str(whisper_raw["language"]),
        device=whisper_raw["device"],
        compute_type=whisper_raw["compute_type"],
    )

    provider = llm_raw["provider"]
    llm = LLMConfig(
        provider=provider,
        base_url=str(llm_raw["base_url"]),
        model=str(llm_raw["model"]),
        timeout_seconds=int(llm_raw["timeout_seconds"]),
        temperature=float(llm_raw["temperature"]),
        max_tokens=int(llm_raw["max_tokens"]),
        # Provider-specific key (OPENROUTER_API_KEY / OPENAI_API_KEY) first,
        # then the universal LLM_API_KEY fallback.
        api_key=_resolve_api_key(provider, "LLM_API_KEY"),
    )

    vision = _build_vision_config(vision_raw, llm)

    export = ExportConfig(
        fcpxml_version=str(export_raw["fcpxml_version"]),
        edl_frame_rate=float(export_raw["edl_frame_rate"]),
        output_format=export_raw["output_format"],
    )

    return AppConfig(
        project_name=str(project["name"]),
        version=str(project["version"]),
        log_level=project["log_level"],
        paths=paths,
        video=video,
        cv=CVConfig(vibe_check=vibe_check, deep_scan=deep_scan),
        scene_detection=scene_detection,
        whisper=whisper,
        llm=llm,
        vision=vision,
        export=export,
    )