Add vision prepass for targeted matches

This commit is contained in:
Melbar
2026-05-02 13:03:15 +02:00
parent f1173eacee
commit 884a0d4232
2 changed files with 125 additions and 11 deletions
+115 -11
View File
@@ -485,6 +485,71 @@ def _reference_scoreable_segments(beat, cfg) -> list[tuple[float, float]]:
return raw
def _trim_beats_to_single_visual_island(beats: list, cfg) -> tuple[list, dict[int, tuple[float, float]]]:
"""Use a single visible island as the primary match target for faded beats."""
from dataclasses import replace
trimmed = []
trims: dict[int, tuple[float, float]] = {}
frame_s = 1.0 / max(1.0, float(cfg.export.edl_frame_rate))
for beat in beats:
islands = _reference_scoreable_segments(beat, cfg)
if len(islands) == 1:
start_s, end_s = islands[0]
island_duration_s = max(0.0, end_s - start_s)
has_real_trim = (
start_s > frame_s * 1.5
or beat.duration_s - end_s > frame_s * 1.5
)
if island_duration_s > 0.0 and has_real_trim:
trimmed.append(
replace(
beat,
start_s=beat.start_s + start_s,
end_s=beat.start_s + end_s,
)
)
trims[beat.beat_id] = (start_s, island_duration_s)
continue
trimmed.append(beat)
return trimmed, trims
def _apply_single_island_segments(results: list, trims: dict[int, tuple[float, float]]) -> list:
"""Restore beat-relative segment metadata after matching a trimmed island."""
if not trims:
return results
from dataclasses import replace
from src.core.models import MatchSegment
expanded = []
for result in results:
trim = trims.get(result.beat_id)
if trim is None or getattr(result, "segments", ()):
expanded.append(result)
continue
trailer_offset_s, island_duration_s = trim
duration_s = min(max(0.0, island_duration_s), max(0.0, result.duration_s))
segment = MatchSegment(
trailer_offset_s=trailer_offset_s,
duration_s=duration_s,
scene_id=result.scene_id,
in_point_s=result.in_point_s,
out_point_s=result.in_point_s + duration_s,
match_score=result.match_score,
is_confirmed=result.is_confirmed,
)
expanded.append(
replace(
result,
out_point_s=result.in_point_s + duration_s,
segments=(segment,),
)
)
return expanded
def _attach_visual_segments(results: list, beats: list, cfg) -> list:
"""Attach automatic sub-shot matches for multi-island trailer beats."""
from dataclasses import replace
@@ -558,7 +623,13 @@ def _attach_visual_segments(results: list, beats: list, cfg) -> list:
return expanded
def _match_unmatched_visual_segments(results: list, beats: list, cached: list, cfg) -> list:
def _match_unmatched_visual_segments(
results: list,
beats: list,
cached: list,
cfg,
skip_global_segment_scan_for: set[int] | None = None,
) -> list:
"""Create segmented provisional matches when a whole beat has no single match."""
from dataclasses import replace
from src.core.models import MatchResult, MatchSegment
@@ -567,6 +638,7 @@ def _match_unmatched_visual_segments(results: list, beats: list, cached: list, c
matched_ids = {r.beat_id for r in results}
expanded = list(results)
skip_global_segment_scan_for = skip_global_segment_scan_for or set()
try:
fps = float(get_video_info(cfg.paths.source_movie)["fps"]) or cfg.export.edl_frame_rate
except Exception:
@@ -593,11 +665,13 @@ def _match_unmatched_visual_segments(results: list, beats: list, cached: list, c
cached + expanded,
cfg,
)
segment_matches = run_global_scan(
[segment_beat],
cfg,
seed_in_points=continuity,
)
segment_matches = []
if beat.beat_id not in skip_global_segment_scan_for:
segment_matches = run_global_scan(
[segment_beat],
cfg,
seed_in_points=continuity,
)
if not segment_matches:
local_segment = _local_same_scene_segment_match(
segment_beat,
@@ -725,18 +799,48 @@ def cmd_match(args: argparse.Namespace, cfg) -> list:
all_beats = _load_beats(cfg)
beats = _select_beats(all_beats, getattr(args, "beat", None))
cached = _normalize_cached_results(all_beats, _load_results(cfg), cfg) if _results_cache_path(cfg).exists() else []
scan_beats, single_island_trims = _trim_beats_to_single_visual_island(beats, cfg)
seed_in_points = (
_continuity_seed_in_points(args.beat, all_beats, cached, cfg)
if getattr(args, "beat", None) is not None
else None
)
results = run_matching(
cfg,
results = []
if cfg.vision.enabled:
fast_cfg = replace(
cfg,
cv=replace(
cfg.cv,
deep_scan=replace(cfg.cv.deep_scan, skip_coarse_scan_with_weighted_seeds=True),
),
vision=replace(cfg.vision, fullscan_fallback=False),
)
results = run_matching(
fast_cfg,
scan_beats,
force_reindex=args.force_reindex,
seed_in_points=seed_in_points,
)
if len(results) < len(scan_beats):
matched_ids = {r.beat_id for r in results}
remaining_beats = [b for b in scan_beats if b.beat_id not in matched_ids]
if remaining_beats:
full_results = run_matching(
cfg,
remaining_beats,
force_reindex=args.force_reindex,
seed_in_points=seed_in_points,
)
results = sorted([*results, *full_results], key=lambda r: r.beat_id)
results = _apply_single_island_segments(results, single_island_trims)
results = _match_unmatched_visual_segments(
results,
beats,
force_reindex=args.force_reindex,
seed_in_points=seed_in_points,
cached,
cfg,
skip_global_segment_scan_for=set(single_island_trims),
)
results = _match_unmatched_visual_segments(results, beats, cached, cfg)
results = _attach_visual_segments(results, beats, cfg)
# A targeted one-beat match should improve the cache without deleting