# File: aitrailer/src/llm/vision_cache.py
# Last modified: 2026-05-02 13:49:16 +02:00 (360 lines, 12 KiB, Python)

"""
Cached vision descriptions for ambiguous trailer/source matching.
This module is deliberately conservative: it never writes a final match and it
does not replace CV. It describes a small number of 3-frame beat/scene samples,
caches those descriptions, and returns extra source in-point seeds for the CV
scanner to verify.
"""
from __future__ import annotations
import base64
import json
import logging
import re
import time
import urllib.error
import urllib.request
from dataclasses import asdict
from pathlib import Path
from typing import Sequence
import cv2
from src.core.config import AppConfig
from src.core.models import Scene, TrailerBeat
logger = logging.getLogger(__name__)
# Schema version written into the cache file; _load_cache starts from an empty
# cache when the stored "version" differs from this value.
_CACHE_VERSION = 1
# Words removed by _terms before two descriptions are compared. The set mixes
# common English and German words with generic film vocabulary.
_STOPWORDS = {
"the", "and", "with", "from", "that", "this", "there", "their", "into",
"scene", "frame", "image", "shot", "video", "visible", "looks", "appears",
"eine", "einer", "einem", "einen", "und", "oder", "mit", "der", "die", "das",
}
# System message sent with every request made by _call_vision_model.
_SYSTEM_PROMPT = """You describe film shots for automatic matching.
Return only compact JSON with these keys:
subject, setting, composition, action_phase, distinctive_objects, lighting_color, negatives.
Focus on stable visual facts and spatial layout. Ignore timecode overlays, subtitles, logos, compression, aspect ratio, and color grading differences."""
# HTTP status codes that _call_vision_model retries after a delay.
_RETRYABLE_HTTP_CODES = {408, 409, 425, 429, 500, 502, 503, 504}
# Substrings looked for in the lower-cased HTTP error body; a match makes
# _call_vision_model fail immediately instead of retrying.
_CREDIT_ERROR_PATTERNS = (
"insufficient credit",
"insufficient credits",
"no credits",
"out of credits",
"billing",
"quota exceeded",
"payment required",
)
def _cache_path(cfg: AppConfig) -> Path:
    """Return the path of the JSON file that stores cached vision descriptions."""
    cache_dir = cfg.paths.cache_dir
    return cache_dir / "vision_descriptions.json"
def _load_cache(cfg: AppConfig) -> dict:
    """Load the description cache from disk, or start an empty one.

    Args:
        cfg: Application config; only the cache location is read from it.

    Returns:
        A dict shaped like ``{"version": _CACHE_VERSION, "items": {...}}``.
        An empty cache is returned when the file is missing, cannot be decoded
        as UTF-8 JSON, is not a JSON object, carries a different version, or
        has an ``"items"`` entry that is not a dict.
    """
    empty: dict = {"version": _CACHE_VERSION, "items": {}}
    path = _cache_path(cfg)
    if not path.exists():
        return empty
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError):
        # A truncated write can cut the file mid-token or mid-character;
        # both cases mean the same thing here: start over.
        logger.warning("Vision cache is unreadable; rebuilding: %s", path)
        return empty
    # Valid JSON is not necessarily an object (e.g. "[]" or "null"), so check
    # the type before calling dict methods on it.
    if (
        not isinstance(data, dict)
        or data.get("version") != _CACHE_VERSION
        or not isinstance(data.get("items"), dict)
    ):
        logger.info("Vision cache has an unexpected layout or version; rebuilding: %s", path)
        return empty
    return data
def _save_cache(cfg: AppConfig, cache: dict) -> None:
    """Write the description cache to disk as pretty-printed UTF-8 JSON.

    The data is written to a sibling ``.tmp`` file first and then moved over
    the real file, so an interrupted run cannot leave a half-written cache
    behind (which would be discarded as unreadable on the next load).

    Args:
        cfg: Application config; only the cache location is read from it.
        cache: The cache dict to serialise.
    """
    path = _cache_path(cfg)
    path.parent.mkdir(parents=True, exist_ok=True)
    # Serialise before touching the filesystem so a non-serialisable value
    # cannot clobber the existing file.
    payload = json.dumps(cache, indent=2, ensure_ascii=False)
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(payload, encoding="utf-8")
    tmp_path.replace(path)
def _sample_times(start_s: float, end_s: float) -> list[float]:
    """Return three timestamps (early, middle, late) inside ``[start_s, end_s]``.

    The span is treated as at least 0.04 s long. The early sample sits 12 % of
    the span after the start (clamped so it stays 0.04 s before the end), the
    late sample sits 12 % of the span, but at most 0.20 s, before the end.
    """
    span_s = max(0.04, end_s - start_s)
    lead_in_s = min(span_s * 0.12, max(0.0, span_s - 0.04))
    tail_margin_s = min(span_s * 0.12, 0.20)

    early_t = start_s + lead_in_s
    middle_t = start_s + span_s * 0.50
    late_t = start_s + max(0.0, span_s - tail_margin_s)
    return [early_t, middle_t, late_t]
def _frame_data_url(video_path: Path, t_s: float) -> str | None:
    """Read one frame at ``t_s`` seconds and return it as a JPEG ``data:`` URL.

    Frames wider than 640 px are scaled down to 640 px wide with the height
    scaled by the same factor; the JPEG is encoded at quality 72.

    Returns:
        The ``data:image/jpeg;base64,...`` string, or ``None`` when the video
        cannot be opened, no frame can be read, or JPEG encoding fails.
    """
    capture = cv2.VideoCapture(str(video_path))
    try:
        if not capture.isOpened():
            return None

        seek_ms = max(0.0, t_s) * 1000.0
        capture.set(cv2.CAP_PROP_POS_MSEC, seek_ms)
        grabbed, image = capture.read()
        if not grabbed or image is None:
            return None

        height, width = image.shape[:2]
        if width > 640:
            scaled_height = int(height * (640 / width))
            image = cv2.resize(image, (640, scaled_height), interpolation=cv2.INTER_AREA)

        jpeg_params = [int(cv2.IMWRITE_JPEG_QUALITY), 72]
        encoded_ok, jpeg = cv2.imencode(".jpg", image, jpeg_params)
        if not encoded_ok:
            return None

        b64_text = base64.b64encode(jpeg.tobytes()).decode("ascii")
        return f"data:image/jpeg;base64,{b64_text}"
    finally:
        # Always hand the capture back, including on the early returns above.
        capture.release()
def _call_vision_model(label: str, image_urls: list[str], cfg: AppConfig) -> str:
    """Ask the configured vision model to describe a set of frames.

    Posts one chat-completions request to ``{base_url}/chat/completions``
    containing the system prompt, a text part naming ``label`` and one
    low-detail image part per entry of ``image_urls``.

    Retryable HTTP statuses and transport errors (including timeouts while
    reading the response) are retried after 8, 20, 45 and 90 seconds, i.e. up
    to five attempts. HTTP 402, or an error body matching one of the
    credit/billing patterns, fails immediately.

    Args:
        label: Human-readable name of the sample, used in the prompt and logs.
        image_urls: Image URLs (data URLs in this module) to attach.
        cfg: Application config; ``cfg.vision`` supplies provider, model,
            key, endpoint, sampling and timeout settings.

    Returns:
        The stripped message text of the first choice in the response.

    Raises:
        RuntimeError: No API key for a provider that needs one, a
            non-retryable or exhausted HTTP/transport failure, or a response
            body that does not contain message content.
    """
    vision = cfg.vision
    if vision.provider in ("openai", "openrouter") and not vision.api_key:
        raise RuntimeError(
            "Vision is enabled but no API key is available. Set VISION_API_KEY, "
            "OPENROUTER_API_KEY, OPENAI_API_KEY, or LLM_API_KEY."
        )

    content: list[dict] = [{
        "type": "text",
        "text": (
            f"Describe this 3-frame sample for matching. Label: {label}. "
            "The frames are start, middle, and end of the same beat/scene."
        ),
    }]
    content.extend(
        {"type": "image_url", "image_url": {"url": image_url, "detail": "low"}}
        for image_url in image_urls
    )

    headers = {"Content-Type": "application/json"}
    # Providers that run without a key must not receive "Bearer None".
    if vision.api_key:
        headers["Authorization"] = f"Bearer {vision.api_key}"
    if vision.provider == "openrouter":
        headers["HTTP-Referer"] = "https://github.com/ai-trailer-2026"
        headers["X-Title"] = "AI Trailer Generator v2"

    body = json.dumps({
        "model": vision.model,
        "messages": [
            {"role": "system", "content": _SYSTEM_PROMPT},
            {"role": "user", "content": content},
        ],
        "temperature": vision.temperature,
        "max_tokens": vision.max_tokens,
    }).encode("utf-8")
    url = f"{vision.base_url.rstrip('/')}/chat/completions"
    req = urllib.request.Request(url, data=body, headers=headers, method="POST")

    delays_s = (8.0, 20.0, 45.0, 90.0)
    for attempt in range(len(delays_s) + 1):
        try:
            with urllib.request.urlopen(req, timeout=vision.timeout_seconds) as resp:
                raw = resp.read()
        except urllib.error.HTTPError as exc:
            body_text = exc.read().decode(errors="replace")
            lowered = body_text.lower()
            out_of_credit = exc.code == 402 or any(
                pattern in lowered for pattern in _CREDIT_ERROR_PATTERNS
            )
            exhausted = attempt >= len(delays_s)
            if out_of_credit or exc.code not in _RETRYABLE_HTTP_CODES or exhausted:
                raise RuntimeError(f"Vision HTTP {exc.code} from {url}:\n{body_text}") from exc
            delay_s = delays_s[attempt]
            logger.warning(
                "Vision HTTP %d for %s; waiting %.0fs before retry %d/%d.",
                exc.code,
                label,
                delay_s,
                attempt + 1,
                len(delays_s),
            )
            time.sleep(delay_s)
            continue
        except OSError as exc:
            # URLError is an OSError; catching the base class also covers
            # timeouts and connection resets raised while reading the body,
            # which urllib does not wrap in URLError.
            if attempt >= len(delays_s):
                raise RuntimeError(f"Vision request failed for {url}: {exc}") from exc
            delay_s = delays_s[attempt]
            logger.warning(
                "Vision request failed for %s (%s); waiting %.0fs before retry %d/%d.",
                label,
                getattr(exc, "reason", exc),
                delay_s,
                attempt + 1,
                len(delays_s),
            )
            time.sleep(delay_s)
            continue

        try:
            data = json.loads(raw.decode("utf-8"))
            message_content = data["choices"][0]["message"]["content"]
        except (ValueError, KeyError, IndexError, TypeError) as exc:
            raise RuntimeError(f"Vision response from {url} was malformed: {exc!r}") from exc
        if message_content is None:
            # str(None) would otherwise be cached as the description "None".
            raise RuntimeError(f"Vision response from {url} contained no message content")
        return str(message_content).strip()

    # Every loop iteration returns, raises or continues; this is a safety net.
    raise RuntimeError(f"Vision request failed unexpectedly for {url}")
def _description_key(kind: str, item_id: int, start_s: float, end_s: float, cfg: AppConfig) -> str:
    """Build the cache key for one described sample.

    The key combines kind, id, the time range (millisecond precision), the
    vision provider and model, and the whole-second mtime of the video that
    ``kind`` refers to: the reference trailer for ``"beat"``, the source movie
    for anything else. The mtime part is 0 when the file cannot be stat'ed.
    """
    if kind == "beat":
        video = cfg.paths.reference_trailer
    else:
        video = cfg.paths.source_movie

    try:
        mtime = int(video.stat().st_mtime)
    except OSError:
        mtime = 0

    identity = f"{kind}:{item_id}"
    span = f"{start_s:.3f}:{end_s:.3f}"
    model_tag = f"{cfg.vision.provider}:{cfg.vision.model}"
    return f"{identity}:{span}:{model_tag}:{mtime}"
def _describe_sample(
    *,
    kind: str,
    item_id: int,
    label: str,
    video_path: Path,
    start_s: float,
    end_s: float,
    cfg: AppConfig,
    cache: dict,
    budget: list[int],
) -> str | None:
    """Return the description of one beat/scene sample, using the cache first.

    A cached entry is returned without touching the budget. Otherwise frames
    are extracted at the sample times, sent to the vision model, and the
    result is stored in ``cache["items"]``.

    Args:
        kind: ``"beat"`` or ``"scene"``; part of the cache key.
        item_id: Id of the beat or scene; part of the cache key.
        label: Text passed to the vision model to name the sample.
        video_path: Video file the frames are read from.
        start_s: Start of the sample in seconds.
        end_s: End of the sample in seconds.
        cfg: Application config.
        cache: Loaded cache dict; mutated in place when a new entry is added.
        budget: One-element list used as a mutable counter of remaining model
            calls; decremented after each successful call.

    Returns:
        The description text, or ``None`` when the budget is used up or fewer
        than two frames could be extracted.
    """
    cache_key = _description_key(kind, item_id, start_s, end_s, cfg)
    entry = cache["items"].get(cache_key)
    if entry:
        return str(entry.get("description", ""))

    if budget[0] <= 0:
        return None

    frames: list[str] = []
    for t_s in _sample_times(start_s, end_s):
        data_url = _frame_data_url(video_path, t_s)
        if data_url is not None:
            frames.append(data_url)
    if len(frames) < 2:
        return None

    description = _call_vision_model(label, frames, cfg)
    record = {
        "kind": kind,
        "item_id": item_id,
        "start_s": start_s,
        "end_s": end_s,
        "label": label,
        "description": description,
    }
    cache["items"][cache_key] = record
    budget[0] -= 1
    return description
def _terms(text: str) -> set[str]:
    """Return the distinct lower-cased keywords of ``text`` without stopwords.

    A keyword starts with an ASCII letter and is at least three characters
    long; later characters may also be digits, ``_``, ``'`` or ``-``.
    """
    tokens = re.findall(r"[a-zA-Z][a-zA-Z0-9_'-]{2,}", text.lower())
    return set(tokens) - _STOPWORDS
def _text_similarity(a: str, b: str) -> float:
    """Score the keyword overlap of two descriptions.

    The score is the number of shared keywords divided by the size of the
    smaller keyword set, with the divisor floored at 8. It is 0.0 when either
    text has no keywords.
    """
    terms_a, terms_b = _terms(a), _terms(b)
    if not (terms_a and terms_b):
        return 0.0
    shared = terms_a.intersection(terms_b)
    divisor = max(8, min(len(terms_a), len(terms_b)))
    return float(len(shared) / divisor)
def _scene_seed_points(scene: Scene, max_points: int) -> list[float]:
    """Return up to ``max_points`` evenly spaced in-points inside ``scene``.

    Points run from the scene start to 0.2 s before its end. Only the scene
    start is returned when a single point is requested, the scene has no
    duration, or the scene is too short to leave room before the end.
    """
    origin = scene.start_s
    if max_points <= 1 or scene.duration_s <= 0:
        return [origin]

    last_usable = max(origin, scene.end_s - 0.2)
    if last_usable <= origin:
        return [origin]

    spacing = (last_usable - origin) / max(1, max_points - 1)
    seeds: list[float] = []
    for n in range(max_points):
        seeds.append(origin + spacing * n)
    return seeds
def build_vision_seed_in_points(
    beats: Sequence[TrailerBeat],
    scenes: Sequence[Scene],
    cfg: AppConfig,
) -> dict[int, list[tuple[float, float]]]:
    """
    Return extra in-point seeds from cached vision descriptions.

    The function is intentionally small-budget: for each beat it describes the
    beat once and only a few top scene-level candidates. Existing descriptions
    are read from cache and cost nothing.

    Args:
        beats: Trailer beats to find seeds for.
        scenes: Source scenes that may match a beat.
        cfg: Application config; ``cfg.vision`` controls budget, thresholds
            and seed scoring.

    Returns:
        A mapping ``beat_id -> [(in_point_s, score), ...]`` sorted by in-point.
        Beats without a description or without a scene above the similarity
        threshold are left out. Empty when vision is disabled or either input
        is empty.

    Raises:
        RuntimeError: Propagated from the vision request. Descriptions
            obtained before the failure are still written to the cache.
    """
    if not cfg.vision.enabled:
        return {}
    if not beats or not scenes:
        return {}

    # Imported here rather than at module level; kept that way on purpose
    # (presumably to avoid an import cycle or a heavy import - confirm).
    from src.cv.vibe_check import run_vibe_check

    cache = _load_cache(cfg)
    budget = [cfg.vision.max_new_descriptions_per_run]
    scenes_by_id = {scene.scene_id: scene for scene in scenes}
    seeds: dict[int, list[tuple[float, float]]] = {}

    try:
        for beat in beats:
            beat_desc = _describe_sample(
                kind="beat",
                item_id=beat.beat_id,
                label=f"trailer beat {beat.beat_id}",
                video_path=beat.trailer_path,
                start_s=beat.start_s,
                end_s=beat.end_s,
                cfg=cfg,
                cache=cache,
                budget=budget,
            )
            if not beat_desc:
                continue

            hits = run_vibe_check(
                beat,
                scenes,
                top_k=cfg.vision.scene_candidate_top_k,
                hist_method=cfg.cv.vibe_check.hist_compare_method,
                phash_max_distance=64,
            )

            ranked: list[tuple[float, Scene]] = []
            for hit in hits:
                scene = scenes_by_id.get(hit.scene_id)
                if scene is None:
                    continue
                scene_desc = _describe_sample(
                    kind="scene",
                    item_id=scene.scene_id,
                    label=f"source scene {scene.scene_id}",
                    video_path=scene.source_path,
                    start_s=scene.start_s,
                    end_s=scene.end_s,
                    cfg=cfg,
                    cache=cache,
                    budget=budget,
                )
                if not scene_desc:
                    continue
                score = _text_similarity(beat_desc, scene_desc)
                if score >= cfg.vision.similarity_threshold:
                    ranked.append((score, scene))
            ranked.sort(key=lambda item: item[0], reverse=True)

            # Best score per in-point (rounded to the millisecond), so scenes
            # that yield the same point do not produce duplicate seeds.
            merged: dict[float, float] = {}
            for score, scene in ranked[:cfg.vision.max_seed_scenes]:
                logger.info(
                    "Beat %d: vision seed scene=%d score=%.3f",
                    beat.beat_id,
                    scene.scene_id,
                    score,
                )
                # Scale the configured seed score by text similarity, capped at
                # 0.98 and never below the coarse candidate threshold.
                weighted_score = max(
                    cfg.cv.deep_scan.coarse_candidate_threshold,
                    min(0.98, cfg.vision.seed_score * (0.75 + min(1.0, score) * 0.25)),
                )
                for point in _scene_seed_points(scene, cfg.vision.seed_points_per_scene):
                    key = round(max(0.0, point), 3)
                    merged[key] = max(weighted_score, merged.get(key, 0.0))
            if merged:
                seeds[beat.beat_id] = sorted(merged.items())
    finally:
        # Save even when a later vision call fails: every description already
        # in the cache was paid for and must not be requested again.
        _save_cache(cfg, cache)
    return seeds