360 lines
12 KiB
Python
360 lines
12 KiB
Python
"""
|
|
Cached vision descriptions for ambiguous trailer/source matching.

This module is deliberately conservative: it never writes a final match and it
does not replace CV matching. It describes a small number of 3-frame beat/scene
samples, caches those descriptions, and returns extra source in-point seeds for
the CV scanner to verify.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import json
|
|
import logging
|
|
import re
|
|
import time
|
|
import urllib.error
|
|
import urllib.request
|
|
from dataclasses import asdict
|
|
from pathlib import Path
|
|
from typing import Sequence
|
|
|
|
import cv2
|
|
|
|
from src.core.config import AppConfig
|
|
from src.core.models import Scene, TrailerBeat
|
|
|
|
logger = logging.getLogger(__name__)

# Bump to invalidate every cached description: _load_cache discards a cache
# file whose "version" field differs from this value.
_CACHE_VERSION = 1

# Tokens ignored by the term-overlap similarity. Besides generic English and
# German filler and words nearly every shot description contains, this lists
# the field names and JSON literals of the response format requested in
# _SYSTEM_PROMPT: they occur in every description, so counting them would give
# every beat/scene pair the same baseline overlap and inflate all scores.
_STOPWORDS = {
    "the", "and", "with", "from", "that", "this", "there", "their", "into",
    "scene", "frame", "image", "shot", "video", "visible", "looks", "appears",
    "eine", "einer", "einem", "einen", "und", "oder", "mit", "der", "die", "das",
    # JSON response-schema keys and literals.
    "subject", "setting", "composition", "action_phase", "distinctive_objects",
    "lighting_color", "negatives", "null", "none", "true", "false",
}

# System message sent with every vision request.
_SYSTEM_PROMPT = """You describe film shots for automatic matching.
Return only compact JSON with these keys:
subject, setting, composition, action_phase, distinctive_objects, lighting_color, negatives.
Focus on stable visual facts and spatial layout. Ignore timecode overlays, subtitles, logos, compression, aspect ratio, and color grading differences."""

# HTTP statuses treated as transient and retried with back-off.
_RETRYABLE_HTTP_CODES = {408, 409, 425, 429, 500, 502, 503, 504}

# Substrings matched against the lowercased HTTP error body. A match marks an
# account/billing failure, which is raised immediately instead of retried.
_CREDIT_ERROR_PATTERNS = (
    "insufficient credit",
    "insufficient credits",
    "no credits",
    "out of credits",
    "billing",
    "quota exceeded",
    "payment required",
)
|
|
|
|
|
|
def _cache_path(cfg: AppConfig) -> Path:
    """Return the JSON file holding cached vision descriptions inside the cache dir."""
    cache_dir = cfg.paths.cache_dir
    return cache_dir.joinpath("vision_descriptions.json")
|
|
|
|
|
|
def _load_cache(cfg: AppConfig) -> dict:
    """Load the vision description cache, falling back to a fresh empty cache.

    Returns:
        A dict shaped ``{"version": _CACHE_VERSION, "items": {key: entry}}``.
        A missing file, an unreadable or non-UTF-8 file, malformed JSON, a
        top-level value that is not an object, or a version/shape mismatch all
        yield a new empty cache instead of raising. The cache is only an
        optimization, so it must never abort a run.
    """
    path = _cache_path(cfg)
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except FileNotFoundError:
        return {"version": _CACHE_VERSION, "items": {}}
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.warning("Vision cache is unreadable; rebuilding: %s", path)
        return {"version": _CACHE_VERSION, "items": {}}
    if (
        not isinstance(data, dict)
        or data.get("version") != _CACHE_VERSION
        or not isinstance(data.get("items"), dict)
    ):
        return {"version": _CACHE_VERSION, "items": {}}
    return data
|
|
|
|
|
|
def _save_cache(cfg: AppConfig, cache: dict) -> None:
    """Persist *cache* as pretty-printed UTF-8 JSON at the vision cache path.

    The JSON is written to a sibling ``.tmp`` file and then renamed over the
    real file. An interrupted write therefore cannot leave a truncated cache
    behind; a truncated file fails to parse on the next load, and every
    paid-for description in it would be lost.
    """
    path = _cache_path(cfg)
    path.parent.mkdir(parents=True, exist_ok=True)
    # Serialize first so a serialization error never touches the filesystem.
    payload = json.dumps(cache, indent=2, ensure_ascii=False)
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(payload, encoding="utf-8")
    # Path.replace overwrites an existing target (os.replace semantics).
    tmp_path.replace(path)
|
|
|
|
|
|
def _sample_times(start_s: float, end_s: float) -> list[float]:
    """Return three timestamps (seconds) for sampling the span ``[start_s, end_s]``.

    The span length is floored at 0.04 s, so reversed or empty spans still
    yield three times at or after ``start_s``:

    * first: 12% into the span, but never more than ``span - 0.04`` in;
    * middle: the midpoint;
    * last: backed off from the end by 12% of the span, at most 0.2 s.
    """
    span_s = max(0.04, end_s - start_s)
    lead_in_s = min(span_s * 0.12, max(0.0, span_s - 0.04))
    lead_out_s = min(span_s * 0.12, 0.20)
    first = start_s + lead_in_s
    middle = start_s + span_s * 0.50
    last = start_s + max(0.0, span_s - lead_out_s)
    return [first, middle, last]
|
|
|
|
|
|
def _frame_data_url(video_path: Path, t_s: float) -> str | None:
    """Grab one frame from *video_path* and return it as a base64 JPEG data URL.

    The frame is looked up at ``t_s`` seconds, with negative times clamped to 0.
    Frames wider than 640 px are downscaled to 640 px wide with INTER_AREA,
    keeping the aspect ratio, and then JPEG-encoded at quality 72.

    Returns ``None`` when the video cannot be opened, no frame is read at that
    position, or JPEG encoding fails. A fresh capture is opened for each call
    and always released.
    """
    capture = cv2.VideoCapture(str(video_path))
    try:
        if not capture.isOpened():
            return None
        position_ms = max(0.0, t_s) * 1000.0
        capture.set(cv2.CAP_PROP_POS_MSEC, position_ms)
        grabbed, image = capture.read()
        if not (grabbed and image is not None):
            return None

        height, width = image.shape[:2]
        if width > 640:
            scale = 640 / width
            image = cv2.resize(image, (640, int(height * scale)), interpolation=cv2.INTER_AREA)

        jpeg_params = [int(cv2.IMWRITE_JPEG_QUALITY), 72]
        encoded_ok, jpeg = cv2.imencode(".jpg", image, jpeg_params)
        if not encoded_ok:
            return None
        b64_text = base64.b64encode(jpeg.tobytes()).decode("ascii")
        return f"data:image/jpeg;base64,{b64_text}"
    finally:
        capture.release()
|
|
|
|
|
|
def _vision_request_body(label: str, image_urls: list[str], cfg: AppConfig) -> bytes:
    """Encode the chat-completions payload.

    The payload holds the system prompt plus one user message with the
    instruction text and every image as a low-detail ``image_url`` part.
    """
    vision = cfg.vision
    content: list[dict] = [{
        "type": "text",
        "text": (
            f"Describe this 3-frame sample for matching. Label: {label}. "
            "The frames are start, middle, and end of the same beat/scene."
        ),
    }]
    content.extend(
        {"type": "image_url", "image_url": {"url": url, "detail": "low"}}
        for url in image_urls
    )
    return json.dumps({
        "model": vision.model,
        "messages": [
            {"role": "system", "content": _SYSTEM_PROMPT},
            {"role": "user", "content": content},
        ],
        "temperature": vision.temperature,
        "max_tokens": vision.max_tokens,
    }).encode("utf-8")


def _vision_headers(cfg: AppConfig) -> dict[str, str]:
    """Build the HTTP headers for the chat-completions request."""
    vision = cfg.vision
    headers = {"Content-Type": "application/json"}
    # Only send a bearer token when one exists. Otherwise a keyless endpoint
    # would receive the literal header "Bearer " or "Bearer None".
    if vision.api_key:
        headers["Authorization"] = f"Bearer {vision.api_key}"
    if vision.provider == "openrouter":
        headers["HTTP-Referer"] = "https://github.com/ai-trailer-2026"
        headers["X-Title"] = "AI Trailer Generator v2"
    return headers


def _parse_vision_response(raw: bytes, url: str) -> str:
    """Return the stripped ``choices[0].message.content`` of a chat-completions body.

    Raises:
        RuntimeError: the body is not JSON or lacks ``choices[0].message.content``.
    """
    try:
        data = json.loads(raw.decode("utf-8"))
        content = data["choices"][0]["message"]["content"]
    except (ValueError, LookupError, TypeError) as exc:
        snippet = raw[:500].decode("utf-8", errors="replace")
        raise RuntimeError(f"Vision response from {url} is malformed:\n{snippet}") from exc
    # A null content (e.g. a refusal, or output budget exhausted) must become an
    # empty description. str(None) would cache the literal text "None".
    return "" if content is None else str(content).strip()


def _call_vision_model(label: str, image_urls: list[str], cfg: AppConfig) -> str:
    """Send one multi-frame sample to an OpenAI-compatible chat endpoint.

    Args:
        label: Human-readable sample label, used in the prompt and in log lines.
        image_urls: Image URLs (data URLs in this module) for the sampled frames.
        cfg: Application config; only ``cfg.vision`` is read.

    Returns:
        The stripped text of the first choice. This is ``""`` when the provider
        returns no content.

    Raises:
        RuntimeError: if any of these happens:
            * a hosted provider (``openai``/``openrouter``) has no API key;
            * the server returns HTTP 402 or a credit/billing error body;
            * the server returns a non-retryable HTTP status;
            * retries are exhausted;
            * the response body is malformed.

    Retryable HTTP statuses and network failures are retried after 8, 20, 45
    and 90 seconds, in that order.
    """
    vision = cfg.vision
    if vision.provider in ("openai", "openrouter") and not vision.api_key:
        raise RuntimeError(
            "Vision is enabled but no API key is available. Set VISION_API_KEY, "
            "OPENROUTER_API_KEY, OPENAI_API_KEY, or LLM_API_KEY."
        )

    url = f"{vision.base_url.rstrip('/')}/chat/completions"
    req = urllib.request.Request(
        url,
        data=_vision_request_body(label, image_urls, cfg),
        headers=_vision_headers(cfg),
        method="POST",
    )

    delays_s = (8.0, 20.0, 45.0, 90.0)
    for attempt in range(len(delays_s) + 1):
        out_of_retries = attempt >= len(delays_s)
        try:
            with urllib.request.urlopen(req, timeout=vision.timeout_seconds) as resp:
                raw = resp.read()
        except urllib.error.HTTPError as exc:
            body_text = exc.read().decode(errors="replace")
            lowered = body_text.lower()
            # Credit/billing failures will not resolve by waiting: fail fast.
            is_credit_error = exc.code == 402 or any(
                pattern in lowered for pattern in _CREDIT_ERROR_PATTERNS
            )
            if is_credit_error or exc.code not in _RETRYABLE_HTTP_CODES or out_of_retries:
                raise RuntimeError(f"Vision HTTP {exc.code} from {url}:\n{body_text}") from exc
            delay_s = delays_s[attempt]
            logger.warning(
                "Vision HTTP %d for %s; waiting %.0fs before retry %d/%d.",
                exc.code,
                label,
                delay_s,
                attempt + 1,
                len(delays_s),
            )
        except OSError as exc:
            # URLError (DNS/connect failures) is an OSError subclass. Catching
            # OSError also retries socket timeouts and connection resets raised
            # while reading the body, which urlopen does not wrap in URLError.
            if out_of_retries:
                raise RuntimeError(f"Vision request failed for {url}: {exc}") from exc
            delay_s = delays_s[attempt]
            logger.warning(
                "Vision request failed for %s (%s); waiting %.0fs before retry %d/%d.",
                label,
                getattr(exc, "reason", exc),
                delay_s,
                attempt + 1,
                len(delays_s),
            )
        else:
            return _parse_vision_response(raw, url)
        time.sleep(delay_s)

    raise RuntimeError(f"Vision request failed unexpectedly for {url}")
|
|
|
|
|
|
def _description_key(kind: str, item_id: int, start_s: float, end_s: float, cfg: AppConfig) -> str:
    """Build the cache key identifying one described sample.

    The key is ``kind:item_id:start:end:provider:model:mtime``. Start and end
    are formatted with millisecond precision. ``mtime`` is the whole-second
    modification time of the configured reference trailer (for
    ``kind == "beat"``) or source movie (any other kind), or ``0`` if that file
    cannot be stat'ed. Re-encoding the media or switching the model therefore
    invalidates old entries.

    NOTE(review): the mtime comes from ``cfg.paths``, not from the video that
    is actually sampled (``beat.trailer_path`` / ``scene.source_path`` in the
    caller). Confirm these always refer to the same file.
    """
    if kind == "beat":
        media_path = cfg.paths.reference_trailer
    else:
        media_path = cfg.paths.source_movie
    try:
        mtime_s = int(media_path.stat().st_mtime)
    except OSError:
        mtime_s = 0
    fields = (
        f"{kind}",
        f"{item_id}",
        f"{start_s:.3f}",
        f"{end_s:.3f}",
        f"{cfg.vision.provider}",
        f"{cfg.vision.model}",
        f"{mtime_s}",
    )
    return ":".join(fields)
|
|
|
|
|
|
def _describe_sample(
    *,
    kind: str,
    item_id: int,
    label: str,
    video_path: Path,
    start_s: float,
    end_s: float,
    cfg: AppConfig,
    cache: dict,
    budget: list[int],
) -> str | None:
    """Return a vision description for one beat/scene sample, preferring the cache.

    A cached entry is returned whatever the remaining budget. This includes an
    empty description, so such samples are never re-described. Otherwise, if
    ``budget[0]`` is still positive, three frames are grabbed from
    *video_path*. If at least two decode, the model is called and the result is
    stored in ``cache["items"]``. ``budget[0]`` (a one-element list shared
    across calls) is then decremented.

    Returns ``None`` when the budget is spent or fewer than two frames could be
    read. Errors from the vision call (``RuntimeError``) propagate to the caller.
    """
    items = cache["items"]
    key = _description_key(kind, item_id, start_s, end_s, cfg)
    hit = items.get(key)
    if hit:
        return str(hit.get("description", ""))
    if budget[0] <= 0:
        return None

    frames: list[str] = []
    for t in _sample_times(start_s, end_s):
        data_url = _frame_data_url(video_path, t)
        if data_url is not None:
            frames.append(data_url)
    if len(frames) < 2:
        return None

    description = _call_vision_model(label, frames, cfg)
    items[key] = {
        "kind": kind,
        "item_id": item_id,
        "start_s": start_s,
        "end_s": end_s,
        "label": label,
        "description": description,
    }
    budget[0] -= 1
    return description
|
|
|
|
|
|
def _terms(text: str) -> set[str]:
    """Return the lowercase content terms of *text*, minus ``_STOPWORDS``.

    A term is at least three characters long. It starts with a letter and
    continues with letters, digits, ``_``, ``'`` or ``-``. Letters are matched
    Unicode-aware, so non-ASCII words such as German "mädchen" stay whole.
    The former ASCII-only class split them into fragments like "dchen".
    """
    # [^\W\d_] == any Unicode letter; \w adds digits and underscore.
    words = re.findall(r"[^\W\d_][\w'-]{2,}", text.lower())
    return {w for w in words if w not in _STOPWORDS}
|
|
|
|
|
|
def _text_similarity(a: str, b: str) -> float:
    """Score the term overlap of two descriptions as a float in ``[0, 1]``.

    The score is ``len(A & B) / max(8, min(len(A), len(B)))`` over the term
    sets from ``_terms``. The floor of 8 on the denominator stops a couple of
    shared words between very short descriptions from producing a high score.
    Returns 0.0 when either description yields no terms.
    """
    left, right = _terms(a), _terms(b)
    if not (left and right):
        return 0.0
    shared = left & right
    denominator = max(8, min(len(left), len(right)))
    return float(len(shared) / denominator)
|
|
|
|
|
|
def _scene_seed_points(scene: Scene, max_points: int) -> list[float]:
    """Return evenly spaced candidate in-points (seconds) inside *scene*.

    The points run from ``scene.start_s`` to 0.2 s before ``scene.end_s``,
    both inclusive, and there are exactly ``max_points`` of them. A single
    point at ``scene.start_s`` is returned instead in three cases:

    * fewer than two points are requested;
    * the scene has no positive duration;
    * the scene is too short to leave any room after the 0.2 s tail margin.
    """
    first = scene.start_s
    if max_points <= 1 or scene.duration_s <= 0:
        return [first]
    last = max(first, scene.end_s - 0.2)
    if last <= first:
        return [first]
    spacing = (last - first) / max(1, max_points - 1)
    return [first + spacing * index for index in range(max_points)]
|
|
|
|
|
|
def _rank_candidate_scenes(
    beat_desc: str,
    hits: Sequence,
    scenes_by_id: dict[int, Scene],
    cfg: AppConfig,
    cache: dict,
    budget: list[int],
) -> list[tuple[float, Scene]]:
    """Describe each CV candidate scene and keep those similar enough to the beat.

    ``hits`` are the objects returned by ``run_vibe_check``; only their
    ``scene_id`` is read. Returns ``(similarity, scene)`` pairs with
    similarity >= ``cfg.vision.similarity_threshold``, best first. The sort is
    stable, so ties keep hit order.
    """
    ranked: list[tuple[float, Scene]] = []
    for hit in hits:
        scene = scenes_by_id.get(hit.scene_id)
        if scene is None:
            continue
        scene_desc = _describe_sample(
            kind="scene",
            item_id=scene.scene_id,
            label=f"source scene {scene.scene_id}",
            video_path=scene.source_path,
            start_s=scene.start_s,
            end_s=scene.end_s,
            cfg=cfg,
            cache=cache,
            budget=budget,
        )
        if not scene_desc:
            continue
        score = _text_similarity(beat_desc, scene_desc)
        if score >= cfg.vision.similarity_threshold:
            ranked.append((score, scene))
    ranked.sort(key=lambda item: item[0], reverse=True)
    return ranked


def _merge_seed_points(
    beat_id: int,
    ranked: list[tuple[float, Scene]],
    cfg: AppConfig,
) -> list[tuple[float, float]]:
    """Turn the top ranked scenes into in-point seeds.

    Only the first ``cfg.vision.max_seed_scenes`` scenes are used. Returns
    ``(in_point_s, seed_score)`` pairs sorted by in-point. In-points are
    clamped to >= 0 and rounded to milliseconds; duplicates keep their
    highest score.
    """
    merged: dict[float, float] = {}
    for score, scene in ranked[:cfg.vision.max_seed_scenes]:
        logger.info(
            "Beat %d: vision seed scene=%d score=%.3f",
            beat_id,
            scene.scene_id,
            score,
        )
        # The seed score scales from 75% to 100% of cfg.vision.seed_score with
        # similarity. It is capped at 0.98 and never below the deep-scan
        # coarse-candidate threshold.
        weighted_score = max(
            cfg.cv.deep_scan.coarse_candidate_threshold,
            min(0.98, cfg.vision.seed_score * (0.75 + min(1.0, score) * 0.25)),
        )
        for point in _scene_seed_points(scene, cfg.vision.seed_points_per_scene):
            key = round(max(0.0, point), 3)
            merged[key] = max(weighted_score, merged.get(key, 0.0))
    return sorted(merged.items())


def build_vision_seed_in_points(
    beats: Sequence[TrailerBeat],
    scenes: Sequence[Scene],
    cfg: AppConfig,
) -> dict[int, list[tuple[float, float]]]:
    """
    Return extra in-point seeds from cached vision descriptions.

    The function is intentionally small-budget. For each beat it:

    1. describes the beat's 3-frame sample;
    2. asks the CV vibe check for ``cfg.vision.scene_candidate_top_k``
       candidate scenes;
    3. describes those scenes;
    4. turns scenes whose description overlaps the beat's by at least
       ``cfg.vision.similarity_threshold`` into seeds.

    New descriptions (API calls) are capped across the whole call by
    ``cfg.vision.max_new_descriptions_per_run``. Descriptions already in the
    cache cost nothing.

    Returns:
        ``{beat_id: [(source_in_point_s, seed_score), ...]}`` sorted by
        in-point. Beats without any accepted scene are absent. The result is
        ``{}`` when vision is disabled or either input is empty.

    Raises:
        RuntimeError: propagated from the vision API call. The description
            cache is saved before the error propagates, so descriptions already
            paid for in this run are kept.
    """
    if not cfg.vision.enabled:
        return {}
    if not beats or not scenes:
        return {}

    from src.cv.vibe_check import run_vibe_check

    cache = _load_cache(cfg)
    # One-element list so _describe_sample can decrement the shared budget in place.
    budget = [cfg.vision.max_new_descriptions_per_run]
    scenes_by_id = {scene.scene_id: scene for scene in scenes}
    seeds: dict[int, list[tuple[float, float]]] = {}

    try:
        for beat in beats:
            beat_desc = _describe_sample(
                kind="beat",
                item_id=beat.beat_id,
                label=f"trailer beat {beat.beat_id}",
                video_path=beat.trailer_path,
                start_s=beat.start_s,
                end_s=beat.end_s,
                cfg=cfg,
                cache=cache,
                budget=budget,
            )
            if not beat_desc:
                continue

            # NOTE(review): 64 presumably equals the bit length of the pHash,
            # i.e. no pHash filtering here. Confirm against run_vibe_check.
            hits = run_vibe_check(
                beat,
                scenes,
                top_k=cfg.vision.scene_candidate_top_k,
                hist_method=cfg.cv.vibe_check.hist_compare_method,
                phash_max_distance=64,
            )
            ranked = _rank_candidate_scenes(beat_desc, hits, scenes_by_id, cfg, cache, budget)
            points = _merge_seed_points(beat.beat_id, ranked, cfg)
            if points:
                seeds[beat.beat_id] = points
    finally:
        # Persist new descriptions even if a vision call raised part-way
        # through (e.g. out of credits): each one cost an API call.
        _save_cache(cfg, cache)

    return seeds
|