Fix multi-shot matching: Always use continuity seed for first island to prevent wrong scene jumps

This commit is contained in:
Melbar
2026-05-08 11:50:13 +02:00
parent 730b5ef3c0
commit 4fe1d35f1a
7 changed files with 7 additions and 631 deletions
-9
View File
@@ -1471,15 +1471,6 @@ def _match_unmatched_visual_segments(
start_s=beat.start_s + start_s,
end_s=beat.start_s + end_s,
)
if island_idx == 0:
# First island of an unmatched multi-shot beat: search globally
# without a continuity bias from the previous beat. Continuity
# assumes the shot follows the previous beat in the source, but
# the lead shot of a multi-shot beat is often an insert cut from
# a completely different scene. A wrong seed with score 0.92
# would push the real match out of the refinement candidate pool.
continuity = {}
else:
continuity = _continuity_seed_in_points(
beat.beat_id,
[b if b.beat_id != beat.beat_id else segment_beat for b in beats],
+1 -1
View File
@@ -86,7 +86,7 @@ span_score_weight = 0.15
coarse_score_weight = 0.10
duration_score_weight = 0.20
duration_tie_break_score_delta = 0.03
min_duration_coverage = 0.65
min_duration_coverage = 0.55
continuity_seed_offsets_s = [-1.0, 0.0, 0.5, 1.0, 1.5, 2.0, 3.0]
scene_seed_top_k = 30
scene_seed_points_per_scene = 6
-1
View File
@@ -1 +0,0 @@
# tests package
-144
View File
@@ -1,144 +0,0 @@
"""
tests/test_config.py — Smoke tests for config loading and model integrity.
Run with: pytest tests/test_config.py -v
"""
from pathlib import Path
import pytest
from src.core.config import load_config, AppConfig
from src.core.models import (
Scene, TrailerBeat, MatchResult, VibeHit,
EditClip, EditTimeline, BeatType, DialogueLine,
)
# config.toml located one directory above the directory containing this test file.
CONFIG_PATH = Path(__file__).parents[1] / "config.toml"
# ---------------------------------------------------------------------------
# Config loader
# ---------------------------------------------------------------------------
class TestConfigLoader:
    """Smoke checks for ``load_config`` run against ``CONFIG_PATH``."""

    def test_loads_without_error(self) -> None:
        """Loading the config file yields an ``AppConfig`` instance."""
        loaded = load_config(CONFIG_PATH)
        assert isinstance(loaded, AppConfig)

    def test_project_meta(self) -> None:
        """The version is pinned and the log level is one of the known names."""
        loaded = load_config(CONFIG_PATH)
        assert loaded.version == "2.0.0"
        assert loaded.log_level in ("DEBUG", "INFO", "WARNING", "ERROR")

    def test_cv_thresholds_in_range(self) -> None:
        """Deep-scan threshold lies strictly inside (0, 1); step is positive."""
        deep_scan = load_config(CONFIG_PATH).cv.deep_scan
        assert 0.0 < deep_scan.match_threshold < 1.0
        assert deep_scan.coarse_step_seconds > 0

    def test_vibe_check_crop_fractions(self) -> None:
        """Each crop fraction is in (0, 1) and together they stay below 1."""
        vibe_check = load_config(CONFIG_PATH).cv.vibe_check
        assert 0.0 < vibe_check.crop_top_fraction < 1.0
        assert 0.0 < vibe_check.crop_bottom_fraction < 1.0
        assert vibe_check.crop_top_fraction + vibe_check.crop_bottom_fraction < 1.0

    def test_missing_config_raises(self, tmp_path: Path) -> None:
        """A path that does not exist raises ``FileNotFoundError``."""
        missing = tmp_path / "nonexistent.toml"
        with pytest.raises(FileNotFoundError):
            load_config(missing)

    def test_paths_are_path_objects(self) -> None:
        """Configured media paths are exposed as ``Path`` objects."""
        paths = load_config(CONFIG_PATH).paths
        assert isinstance(paths.source_movie, Path)
        assert isinstance(paths.reference_trailer, Path)
# ---------------------------------------------------------------------------
# Data models — construction & properties
# ---------------------------------------------------------------------------
class TestSceneModel:
    """Construction and derived properties of ``Scene``."""

    def test_duration(self) -> None:
        """``duration_s`` and ``midpoint_s`` follow from start/end times."""
        scene = Scene(
            scene_id=0,
            source_path=Path("dummy.mp4"),
            start_s=10.0,
            end_s=25.5,
            start_frame=240,
            end_frame=612,
        )
        assert scene.duration_s == pytest.approx(15.5)
        assert scene.midpoint_s == pytest.approx(17.75)

    def test_immutable(self) -> None:
        """Assigning to a field of an existing ``Scene`` raises."""
        scene = Scene(
            scene_id=0,
            source_path=Path("x.mp4"),
            start_s=0.0,
            end_s=1.0,
            start_frame=0,
            end_frame=24,
        )
        # NOTE(review): presumably dataclasses.FrozenInstanceError — confirm
        # against the model definition before narrowing this exception type.
        with pytest.raises(Exception):
            scene.scene_id = 99  # type: ignore[misc]
class TestTrailerBeatModel:
    """Defaults of ``TrailerBeat``."""

    def test_beat_type_default(self) -> None:
        """A beat built without ``beat_type`` reports ``BeatType.UNKNOWN``."""
        beat = TrailerBeat(
            beat_id=0,
            trailer_path=Path("trailer.mp4"),
            start_s=0.0,
            end_s=3.0,
            start_frame=0,
            end_frame=72,
        )
        assert beat.beat_type == BeatType.UNKNOWN
class TestMatchResultModel:
    """Derived duration and ``repr`` of ``MatchResult``."""

    def test_duration_computed(self) -> None:
        """``duration_s`` equals out-point minus in-point (123.5 - 120.0)."""
        result = MatchResult(
            beat_id=0,
            scene_id=3,
            source_path=Path("movie.mp4"),
            in_point_s=120.0,
            out_point_s=123.5,
            in_point_frame=2880,
            match_score=0.87,
        )
        assert result.duration_s == pytest.approx(3.5)

    def test_repr_contains_key_info(self) -> None:
        """The repr mentions both the beat id and the scene id."""
        result = MatchResult(
            beat_id=1,
            scene_id=7,
            source_path=Path("movie.mp4"),
            in_point_s=60.0,
            out_point_s=63.0,
            in_point_frame=1440,
            match_score=0.91,
        )
        text = repr(result)
        assert "beat=1" in text
        assert "scene=7" in text
class TestEditTimeline:
    """Aggregate properties of ``EditTimeline``."""

    def _make_clip(self, idx: int, t_start: float, t_end: float) -> EditClip:
        """Build an ``EditClip`` occupying ``[t_start, t_end]`` on the timeline.

        The matched source range starts at 0.0 and has the same length as the
        timeline range.
        """
        span = t_end - t_start
        trailer_beat = TrailerBeat(
            beat_id=idx,
            trailer_path=Path("t.mp4"),
            start_s=t_start,
            end_s=t_end,
            start_frame=0,
            end_frame=1,
        )
        source_match = MatchResult(
            beat_id=idx,
            scene_id=0,
            source_path=Path("m.mp4"),
            in_point_s=0.0,
            out_point_s=span,
            in_point_frame=0,
            match_score=0.9,
        )
        return EditClip(
            clip_index=idx,
            beat=trailer_beat,
            match=source_match,
            timeline_start_s=t_start,
            timeline_end_s=t_end,
        )

    def test_total_duration(self) -> None:
        """Two back-to-back clips (0-5 s, 5-9 s) give 9 s and a count of 2."""
        first = self._make_clip(0, 0.0, 5.0)
        second = self._make_clip(1, 5.0, 9.0)
        timeline = EditTimeline(
            title="Test Trailer", frame_rate=23.976, clips=(first, second)
        )
        assert timeline.total_duration_s == pytest.approx(9.0)
        assert timeline.clip_count == 2

    def test_empty_timeline(self) -> None:
        """A timeline without clips has zero total duration."""
        timeline = EditTimeline(title="Empty", frame_rate=24.0, clips=())
        assert timeline.total_duration_s == 0.0
-140
View File
@@ -1,140 +0,0 @@
"""
tests/test_deep_scan.py — Unit tests for frame_extractor and deep_scan
Uses small synthetic videos written to a temp file (cv2.VideoWriter) so no real
video files are required. Tests cover the pure logic, not hardware decoding.
"""
from __future__ import annotations
import tempfile
from pathlib import Path
import cv2
import numpy as np
import pytest
from src.cv.frame_extractor import (
get_video_info,
grab_frame_at,
iter_frames_stepped,
open_video,
)
from src.cv.fingerprinting import text_safe_crop
# ---------------------------------------------------------------------------
# Helpers: build a tiny synthetic video on disk
# ---------------------------------------------------------------------------
# Parameters of the synthetic clip written by _make_synthetic_video.
FPS = 24  # frame rate handed to cv2.VideoWriter
WIDTH = 320  # frame width in pixels
HEIGHT = 240  # frame height in pixels
SECS = 3  # clip length in seconds; FPS * SECS frames are written
def _make_synthetic_video(path: Path, color_bgr: tuple[int, int, int] = (0, 128, 255)) -> Path:
    """Write a SECS-second single-colour video to *path* and return *path*.

    The clip is WIDTH x HEIGHT at FPS frames per second, encoded with the
    ``mp4v`` fourcc, and every frame is filled with *color_bgr*.

    Raises:
        RuntimeError: if ``cv2.VideoWriter`` could not be opened for *path*.
    """
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(str(path), fourcc, float(FPS), (WIDTH, HEIGHT))
    try:
        # An unopened writer ignores write() calls and leaves an unusable
        # file behind; fail here so the cause is visible in the test output.
        if not writer.isOpened():
            raise RuntimeError(f"cv2.VideoWriter could not open {path}")
        frame = np.full((HEIGHT, WIDTH, 3), color_bgr, dtype=np.uint8)
        for _ in range(FPS * SECS):
            writer.write(frame)
    finally:
        # Always release so the container is finalised and the handle closed,
        # even if writing a frame raised.
        writer.release()
    return path
@pytest.fixture
def synthetic_video(tmp_path: Path) -> Path:
    """Path to a freshly written synthetic clip inside pytest's ``tmp_path``."""
    target = tmp_path / "test.mp4"
    return _make_synthetic_video(target)
# ---------------------------------------------------------------------------
# open_video
# ---------------------------------------------------------------------------
class TestOpenVideo:
    """Behaviour of the ``open_video`` context manager."""

    def test_opens_valid_file(self, synthetic_video: Path) -> None:
        """An existing clip yields an opened capture object."""
        with open_video(synthetic_video) as capture:
            assert capture.isOpened()

    def test_raises_on_missing_file(self, tmp_path: Path) -> None:
        """A non-existent path raises ``FileNotFoundError``."""
        ghost = tmp_path / "ghost.mp4"
        with pytest.raises(FileNotFoundError), open_video(ghost):
            pass
# ---------------------------------------------------------------------------
# get_video_info
# ---------------------------------------------------------------------------
class TestGetVideoInfo:
    """Metadata reported by ``get_video_info`` for the synthetic clip."""

    def test_returns_correct_fps(self, synthetic_video: Path) -> None:
        """Reported fps matches FPS within 5 %."""
        fps = get_video_info(synthetic_video)["fps"]
        assert fps == pytest.approx(FPS, rel=0.05)

    def test_duration_approx(self, synthetic_video: Path) -> None:
        """Reported duration matches SECS within 10 %."""
        duration = get_video_info(synthetic_video)["duration_s"]
        assert duration == pytest.approx(SECS, rel=0.1)

    def test_resolution(self, synthetic_video: Path) -> None:
        """Reported width and height equal the values the clip was written with."""
        meta = get_video_info(synthetic_video)
        assert meta["width"] == WIDTH
        assert meta["height"] == HEIGHT
# ---------------------------------------------------------------------------
# grab_frame_at
# ---------------------------------------------------------------------------
class TestGrabFrameAt:
    """Single-frame grabs via ``grab_frame_at``."""

    def test_returns_ndarray(self, synthetic_video: Path) -> None:
        """A timestamp inside the clip yields a HEIGHT x WIDTH x 3 array."""
        with open_video(synthetic_video) as capture:
            grabbed = grab_frame_at(capture, 1.0)
        assert grabbed is not None
        assert isinstance(grabbed, np.ndarray)
        assert grabbed.shape == (HEIGHT, WIDTH, 3)

    def test_returns_none_past_end(self, synthetic_video: Path) -> None:
        """A timestamp far beyond the end must not raise."""
        with open_video(synthetic_video) as capture:
            grabbed = grab_frame_at(capture, 9999.0)
        # Depending on the codec the result may be None or a repeated last
        # frame; the only requirement is that no exception is raised.
        assert grabbed is None or isinstance(grabbed, np.ndarray)
# ---------------------------------------------------------------------------
# iter_frames_stepped
# ---------------------------------------------------------------------------
class TestIterFramesStepped:
    """Stepped iteration via ``iter_frames_stepped``."""

    def test_yields_correct_count(self, synthetic_video: Path) -> None:
        """Stepping 0.0 -> 1.0 by 0.5 yields three items (0.0, 0.5, 1.0)."""
        with open_video(synthetic_video) as capture:
            yielded = list(iter_frames_stepped(capture, 0.0, 1.0, 0.5))
        assert len(yielded) == 3

    def test_timestamps_increasing(self, synthetic_video: Path) -> None:
        """Yielded timestamps are already in ascending order."""
        with open_video(synthetic_video) as capture:
            yielded = list(iter_frames_stepped(capture, 0.0, 2.0, 0.5))
        stamps = [stamp for stamp, _ in yielded]
        assert stamps == sorted(stamps)

    def test_invalid_step_raises(self, synthetic_video: Path) -> None:
        """A zero step raises ``ValueError`` whose message mentions ``step_s``."""
        with open_video(synthetic_video) as capture:
            with pytest.raises(ValueError, match="step_s"):
                list(iter_frames_stepped(capture, 0.0, 1.0, 0.0))
# ---------------------------------------------------------------------------
# text_safe_crop integration (sanity: cropped height consistent)
# ---------------------------------------------------------------------------
class TestCropSanity:
    """Sanity check that ``text_safe_crop`` works on a decoded frame."""

    def test_crop_reduces_height(self, synthetic_video: Path) -> None:
        """Cropping top 0.15 / bottom 0.30 shrinks the height, not the width."""
        with open_video(synthetic_video) as capture:
            original = grab_frame_at(capture, 0.5)
        assert original is not None
        trimmed = text_safe_crop(original, 0.15, 0.30)
        assert trimmed.shape[0] < original.shape[0]
        assert trimmed.shape[1] == original.shape[1]  # width unchanged
-218
View File
@@ -1,218 +0,0 @@
"""
tests/test_export.py — Unit tests for timecode conversion and export writers
Tests use synthetic EditTimeline objects (no real video files needed).
"""
from __future__ import annotations
from pathlib import Path
import pytest
from src.export.timecode import (
seconds_to_fcpxml,
seconds_to_smpte,
fcpxml_frame_duration,
fcpxml_format_name,
seconds_to_frame_count,
)
# ---------------------------------------------------------------------------
# Timecode helpers
# ---------------------------------------------------------------------------
class TestSecondsToFcpxml:
    """Rational-time strings produced by ``seconds_to_fcpxml``."""

    def test_zero(self) -> None:
        """Zero seconds is rendered as the bare string ``"0s"``."""
        assert seconds_to_fcpxml(0.0, 24.0) == "0s"

    def test_one_second_at_24fps(self) -> None:
        """1.0 s at 24 fps is 24 frames, i.e. 24/24 s reduced to 1/1 s."""
        rendered = seconds_to_fcpxml(1.0, 24.0)
        assert rendered == "1/1s"

    def test_one_second_at_23976(self) -> None:
        """At 23.976 fps only the shape is checked: a fraction ending in ``s``."""
        rendered = seconds_to_fcpxml(1.0, 23.976)
        assert rendered.endswith("s")
        assert "/" in rendered

    def test_ten_seconds_at_25fps(self) -> None:
        """10 s at 25 fps is 250 frames, i.e. 250/25 s reduced to 10/1 s."""
        rendered = seconds_to_fcpxml(10.0, 25.0)
        assert rendered == "10/1s"

    def test_rational_is_reduced(self) -> None:
        """Numerator and denominator share no common factor (never 24/24s)."""
        import math

        rendered = seconds_to_fcpxml(1.0, 24.0)
        numerator, denominator = rendered.rstrip("s").split("/")
        assert math.gcd(int(numerator), int(denominator)) == 1
class TestSecondsToSmpte:
    """``HH:MM:SS:FF`` strings produced by ``seconds_to_smpte``."""

    def test_zero(self) -> None:
        """Zero seconds renders as all-zero timecode."""
        assert seconds_to_smpte(0.0, 24.0) == "00:00:00:00"

    def test_one_minute(self) -> None:
        """60 s at 25 fps rolls over into the minutes field."""
        assert seconds_to_smpte(60.0, 25.0) == "00:01:00:00"

    def test_one_hour(self) -> None:
        """3600 s at 24 fps rolls over into the hours field."""
        assert seconds_to_smpte(3600.0, 24.0) == "01:00:00:00"

    def test_frames_overflow(self) -> None:
        """26 frames at 25 fps is one second plus one frame."""
        timecode = seconds_to_smpte(26 / 25, 25.0)
        assert timecode == "00:00:01:01"

    def test_format_length(self) -> None:
        """The timecode always has four colon-separated two-character fields."""
        fields = seconds_to_smpte(123.456, 23.976).split(":")
        assert len(fields) == 4
        assert all(len(field) == 2 for field in fields)
class TestFcpxmlHelpers:
    """Frame-duration, format-name and frame-count helpers."""

    def test_frame_duration_24fps(self) -> None:
        """One frame at 24 fps lasts 1/24 s."""
        assert fcpxml_frame_duration(24.0) == "1/24s"

    def test_frame_duration_23976(self) -> None:
        """One frame at 23.976 fps is expressed as 1001/24000 s."""
        frame_duration = fcpxml_frame_duration(23.976)
        assert frame_duration == "1001/24000s"

    def test_format_name_1080p_2398(self) -> None:
        """The format name for 1920x1080 at 23.976 mentions 1080 and 2398."""
        format_name = fcpxml_format_name(23.976, 1920, 1080)
        assert "1080" in format_name
        assert "2398" in format_name

    def test_frame_count_roundtrip(self) -> None:
        """10 s at 25 fps is exactly 250 frames."""
        frame_total = seconds_to_frame_count(10.0, 25.0)
        assert frame_total == 250
# ---------------------------------------------------------------------------
# EDL writer (string output)
# ---------------------------------------------------------------------------
class TestEdlWriter:
    """Text content of the file written by ``write_edl``."""

    def _make_timeline(self) -> "src.core.models.EditTimeline":  # type: ignore
        """Build a one-clip, 25 fps timeline titled ``TestTrailer``."""
        from src.core.models import (
            BeatType, EditClip, EditTimeline, MatchResult, TrailerBeat,
        )

        hook_beat = TrailerBeat(
            beat_id=0,
            trailer_path=Path("trailer.mp4"),
            start_s=0.0,
            end_s=5.0,
            start_frame=0,
            end_frame=120,
            beat_type=BeatType.HOOK,
        )
        source_match = MatchResult(
            beat_id=0,
            scene_id=3,
            source_path=Path("movie.mp4"),
            in_point_s=30.0,
            out_point_s=35.0,
            in_point_frame=720,
            match_score=0.88,
        )
        only_clip = EditClip(
            clip_index=0,
            beat=hook_beat,
            match=source_match,
            timeline_start_s=0.0,
            timeline_end_s=5.0,
        )
        return EditTimeline(
            title="TestTrailer", frame_rate=25.0, clips=(only_clip,)
        )

    def _write_and_read(self, tmp_path: Path) -> str:
        """Write the sample timeline as an EDL and return the file's text."""
        from src.core.config import load_config
        from src.export.edl_writer import write_edl

        config = load_config()
        timeline = self._make_timeline()
        written = write_edl(timeline, config, output_path=tmp_path / "test.edl")
        return written.read_text(encoding="utf-8")

    def test_edl_contains_title(self, tmp_path: Path) -> None:
        """The EDL carries a ``TITLE:`` line with the timeline title."""
        assert "TITLE: TestTrailer" in self._write_and_read(tmp_path)

    def test_edl_has_event_line(self, tmp_path: Path) -> None:
        """The EDL contains the first event number and the ``AX`` reel name."""
        edl_text = self._write_and_read(tmp_path)
        assert "001" in edl_text  # event number
        assert "AX" in edl_text  # reel name
# ---------------------------------------------------------------------------
# FCPXML writer (XML structure)
# ---------------------------------------------------------------------------
class TestFcpxmlWriter:
    """XML structure of the file written by ``write_fcpxml``."""

    def _make_timeline(self) -> "src.core.models.EditTimeline":  # type: ignore
        """Build a one-clip, 25 fps timeline titled ``TestTrailer``."""
        from src.core.models import (
            BeatType, EditClip, EditTimeline, MatchResult, TrailerBeat,
        )

        hook_beat = TrailerBeat(
            beat_id=0,
            trailer_path=Path("trailer.mp4"),
            start_s=0.0,
            end_s=5.0,
            start_frame=0,
            end_frame=120,
            beat_type=BeatType.HOOK,
        )
        source_match = MatchResult(
            beat_id=0,
            scene_id=3,
            source_path=Path("B:/Proxy/movie.mp4"),
            in_point_s=30.0,
            out_point_s=35.0,
            in_point_frame=720,
            match_score=0.88,
        )
        only_clip = EditClip(
            clip_index=0,
            beat=hook_beat,
            match=source_match,
            timeline_start_s=0.0,
            timeline_end_s=5.0,
        )
        return EditTimeline(
            title="TestTrailer", frame_rate=25.0, clips=(only_clip,)
        )

    def _write_and_parse(self, tmp_path: Path):
        """Write the sample timeline as FCPXML and return the parsed root element.

        Lines starting with ``<!DOCTYPE`` are dropped before the text is
        handed to ``ElementTree.fromstring``.
        """
        from xml.etree import ElementTree as ET
        from src.core.config import load_config
        from src.export.fcpxml_writer import write_fcpxml

        config = load_config()
        timeline = self._make_timeline()
        written = write_fcpxml(
            timeline, config, output_path=tmp_path / "test.fcpxml"
        )
        raw_text = written.read_text(encoding="utf-8")
        kept_lines = [
            line
            for line in raw_text.splitlines()
            if not line.strip().startswith("<!DOCTYPE")
        ]
        return ET.fromstring("\n".join(kept_lines))

    def test_fcpxml_is_valid_xml(self, tmp_path: Path) -> None:
        """The output parses as XML and its root element is ``fcpxml``."""
        root = self._write_and_parse(tmp_path)
        # Strip a "{namespace}" prefix, if present, before comparing the tag.
        if "}" in root.tag:
            local_tag = root.tag.split("}")[-1]
        else:
            local_tag = root.tag
        assert local_tag == "fcpxml"

    def test_fcpxml_has_spine(self, tmp_path: Path) -> None:
        """The document contains a spine holding exactly one child element."""
        root = self._write_and_parse(tmp_path)
        # Map a prefix to the FCPXML namespace so find() can resolve the tag.
        namespaces = {"fcp": "http://www.apple.com/dt/FCPXML/1_10"}
        spine = root.find(".//fcp:spine", namespaces)
        assert spine is not None
        spine_children = list(spine)
        assert len(spine_children) == 1
-112
View File
@@ -1,112 +0,0 @@
"""
tests/test_fingerprinting.py — Unit tests for src/cv/fingerprinting.py
Tests run WITHOUT requiring real video files.
"""
from __future__ import annotations
import numpy as np
import pytest
from src.cv.fingerprinting import (
text_safe_crop,
extract_hs_histograms,
compare_histograms,
hist_to_bytes,
bytes_to_hist,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def solid_blue_frame() -> np.ndarray:
    """256×256 uint8 frame in which every pixel is BGR (255, 0, 0), i.e. blue."""
    return np.full((256, 256, 3), (255, 0, 0), dtype=np.uint8)
@pytest.fixture
def solid_red_frame() -> np.ndarray:
    """256×256 uint8 frame in which every pixel is BGR (0, 0, 255), i.e. red."""
    return np.full((256, 256, 3), (0, 0, 255), dtype=np.uint8)
# ---------------------------------------------------------------------------
# text_safe_crop
# ---------------------------------------------------------------------------
class TestTextSafeCrop:
    """Row cropping and argument validation of ``text_safe_crop``."""

    def test_removes_correct_rows(self, solid_blue_frame: np.ndarray) -> None:
        """Cropped height equals bottom cut row minus top cut row."""
        trimmed = text_safe_crop(solid_blue_frame, crop_top=0.15, crop_bottom=0.30)
        full_height = solid_blue_frame.shape[0]  # 256
        expected_height = int(full_height * (1.0 - 0.30)) - int(full_height * 0.15)
        assert trimmed.shape[0] == expected_height

    def test_zero_crop_returns_same_size(self, solid_blue_frame: np.ndarray) -> None:
        """Zero crop fractions leave the frame shape untouched."""
        trimmed = text_safe_crop(solid_blue_frame, crop_top=0.0, crop_bottom=0.0)
        assert trimmed.shape == solid_blue_frame.shape

    def test_invalid_top_raises(self, solid_blue_frame: np.ndarray) -> None:
        """``crop_top=1.0`` raises ``ValueError`` mentioning ``crop_top``."""
        with pytest.raises(ValueError, match="crop_top"):
            text_safe_crop(solid_blue_frame, crop_top=1.0, crop_bottom=0.0)

    def test_invalid_bottom_raises(self, solid_blue_frame: np.ndarray) -> None:
        """A negative ``crop_bottom`` raises ``ValueError`` mentioning it."""
        with pytest.raises(ValueError, match="crop_bottom"):
            text_safe_crop(solid_blue_frame, crop_top=0.0, crop_bottom=-0.1)

    def test_overlapping_crops_raise(self, solid_blue_frame: np.ndarray) -> None:
        """Fractions summing past 1.0 (0.6 + 0.5) raise ``ValueError``."""
        with pytest.raises(ValueError, match="must be < 1.0"):
            text_safe_crop(solid_blue_frame, crop_top=0.6, crop_bottom=0.5)
# ---------------------------------------------------------------------------
# Histograms
# ---------------------------------------------------------------------------
class TestHistograms:
    """Shape, normalisation and comparison of ``extract_hs_histograms`` output."""

    def test_output_shape(self, solid_blue_frame: np.ndarray) -> None:
        """The two histograms have ``bins_hue`` and ``bins_sat`` entries."""
        first_hist, second_hist = extract_hs_histograms(
            solid_blue_frame, bins_hue=50, bins_sat=60
        )
        assert first_hist.shape == (50,)
        assert second_hist.shape == (60,)

    def test_normalised(self, solid_blue_frame: np.ndarray) -> None:
        """Both histograms have an L2 norm of approximately 1.0."""
        first_hist, second_hist = extract_hs_histograms(
            solid_blue_frame, bins_hue=50, bins_sat=60
        )
        assert np.linalg.norm(first_hist) == pytest.approx(1.0, abs=1e-5)
        assert np.linalg.norm(second_hist) == pytest.approx(1.0, abs=1e-5)

    def test_same_frame_correl_is_one(self, solid_blue_frame: np.ndarray) -> None:
        """A histogram correlated with itself scores approximately 1.0."""
        import cv2

        first_hist, _ = extract_hs_histograms(
            solid_blue_frame, bins_hue=50, bins_sat=60
        )
        similarity = compare_histograms(
            first_hist, first_hist, method=cv2.HISTCMP_CORREL
        )
        assert similarity == pytest.approx(1.0, abs=1e-5)

    def test_different_frames_correl_lower(
        self,
        solid_blue_frame: np.ndarray,
        solid_red_frame: np.ndarray,
    ) -> None:
        """Blue and red frames correlate strictly below 1.0."""
        import cv2

        blue_hist, _ = extract_hs_histograms(solid_blue_frame, 50, 60)
        red_hist, _ = extract_hs_histograms(solid_red_frame, 50, 60)
        similarity = compare_histograms(
            blue_hist, red_hist, method=cv2.HISTCMP_CORREL
        )
        assert similarity < 1.0
# ---------------------------------------------------------------------------
# Serialisation round-trip
# ---------------------------------------------------------------------------
class TestSerialisation:
    """Round trip through ``hist_to_bytes`` / ``bytes_to_hist``."""

    def test_round_trip(self, solid_blue_frame: np.ndarray) -> None:
        """Serialising then deserialising a histogram reproduces its values."""
        original_hist, _ = extract_hs_histograms(solid_blue_frame, 50, 60)
        payload = hist_to_bytes(original_hist)
        recovered = bytes_to_hist(payload)
        np.testing.assert_array_almost_equal(original_hist, recovered)