Improve segmented vision matching quality

This commit is contained in:
Melbar
2026-05-02 13:49:16 +02:00
parent 884a0d4232
commit e6bd0faf03
4 changed files with 114 additions and 22 deletions
+49 -14
View File
@@ -623,6 +623,47 @@ def _attach_visual_segments(results: list, beats: list, cfg) -> list:
return expanded
def _fast_vision_match_cfg(cfg):
"""Return a vision-seed prepass config that still keeps quality settings."""
from dataclasses import replace
return replace(
cfg,
cv=replace(
cfg.cv,
deep_scan=replace(cfg.cv.deep_scan, skip_coarse_scan_with_weighted_seeds=True),
),
vision=replace(
cfg.vision,
fullscan_fallback=False,
),
)
def _run_segment_match(segment_beat, continuity, cfg, allow_fullscan: bool = True):
"""Match one visual island with the same generic staged strategy as a beat."""
from src.pipeline.matcher import run_matching
if cfg.vision.enabled:
fast_cfg = _fast_vision_match_cfg(cfg)
fast_matches = run_matching(
fast_cfg,
[segment_beat],
seed_in_points=continuity,
)
if fast_matches:
return fast_matches
if not allow_fullscan:
return []
return run_matching(
cfg,
[segment_beat],
seed_in_points=continuity,
)
def _match_unmatched_visual_segments(
results: list,
beats: list,
@@ -634,7 +675,6 @@ def _match_unmatched_visual_segments(
from dataclasses import replace
from src.core.models import MatchResult, MatchSegment
from src.cv.frame_extractor import get_video_info
from src.cv.global_scan import run_global_scan
matched_ids = {r.beat_id for r in results}
expanded = list(results)
@@ -667,11 +707,7 @@ def _match_unmatched_visual_segments(
)
segment_matches = []
if beat.beat_id not in skip_global_segment_scan_for:
segment_matches = run_global_scan(
[segment_beat],
cfg,
seed_in_points=continuity,
)
segment_matches = _run_segment_match(segment_beat, continuity, cfg, allow_fullscan=True)
if not segment_matches:
local_segment = _local_same_scene_segment_match(
segment_beat,
@@ -799,7 +835,13 @@ def cmd_match(args: argparse.Namespace, cfg) -> list:
all_beats = _load_beats(cfg)
beats = _select_beats(all_beats, getattr(args, "beat", None))
cached = _normalize_cached_results(all_beats, _load_results(cfg), cfg) if _results_cache_path(cfg).exists() else []
multi_island_beat_ids = {
beat.beat_id
for beat in beats
if len(_reference_scoreable_segments(beat, cfg)) > 1
}
scan_beats, single_island_trims = _trim_beats_to_single_visual_island(beats, cfg)
scan_beats = [b for b in scan_beats if b.beat_id not in multi_island_beat_ids]
seed_in_points = (
_continuity_seed_in_points(args.beat, all_beats, cached, cfg)
if getattr(args, "beat", None) is not None
@@ -807,14 +849,7 @@ def cmd_match(args: argparse.Namespace, cfg) -> list:
)
results = []
if cfg.vision.enabled:
fast_cfg = replace(
cfg,
cv=replace(
cfg.cv,
deep_scan=replace(cfg.cv.deep_scan, skip_coarse_scan_with_weighted_seeds=True),
),
vision=replace(cfg.vision, fullscan_fallback=False),
)
fast_cfg = _fast_vision_match_cfg(cfg)
results = run_matching(
fast_cfg,
scan_beats,