Add optical flow Auto Segment button

Annotators can now press Auto Segment to replace the current mask with
an automatic river segmentation based on dense optical flow magnitude
and frame brightness. The result is pushed onto the undo stack, so it
can be refined or reverted like any other mask operation.

Parameters (norm_squared_threshold, gaussian_kernel, brightness_range)
live in a separate file, config/optical_flow_config.yaml; the button is
only enabled when optical_flow_config_file is set in config.yaml and
that file sets enabled: true.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-20 15:13:10 +02:00
parent 47432cec4f
commit 67c9a1152c
6 changed files with 130 additions and 13 deletions

View File

@@ -20,7 +20,8 @@ from PySide6.QtWidgets import (
)
from .clip_selector import ClipSelector
from .config import AppConfig
from .compute_optical_flow import compute_optical_flow_mask
from .config import AppConfig, load_optical_flow_config
from .mask_canvas import MaskCanvas
from .video_loader import load_frames
@@ -38,6 +39,11 @@ class Annotator(QMainWindow):
self.cfg = config
self.out_dir = Path(config.out_dir)
self.extras = extras
self.of_cfg = (
load_optical_flow_config(Path(config.optical_flow_config_file))
if config.optical_flow_config_file
else None
)
self.selector = ClipSelector(
data_dir=Path(config.data_dir),
@@ -114,9 +120,17 @@ class Annotator(QMainWindow):
btn_undo10 = QPushButton("Undo×10")
btn_redo = QPushButton("Redo")
btn_load_prev_mask = QPushButton("Load Prev Mask")
btn_auto_segment = QPushButton("Auto Segment")
btn_auto_segment.setEnabled(self.of_cfg is not None and self.of_cfg.enabled)
row1 = QHBoxLayout()
for b in [self.btn_prev, btn_next, btn_skip, btn_load_prev_mask]:
for b in [
self.btn_prev,
btn_next,
btn_skip,
btn_load_prev_mask,
btn_auto_segment,
]:
row1.addWidget(b)
row_tools = QHBoxLayout()
@@ -198,6 +212,7 @@ class Annotator(QMainWindow):
btn_undo10.clicked.connect(self.mc.undo10)
btn_redo.clicked.connect(self.mc.redo)
btn_load_prev_mask.clicked.connect(self.load_prev_mask)
btn_auto_segment.clicked.connect(self.run_optical_flow)
if self._pending_answers:
self._set_answers(self._pending_answers)
@@ -406,3 +421,13 @@ class Annotator(QMainWindow):
interpolation=cv2.INTER_NEAREST,
)
self.mc.set_mask(mask)
def run_optical_flow(self):
    """Replace the current mask with an optical-flow based segmentation.

    Calls ``compute_optical_flow_mask`` on ``self.frames`` with the
    parameters stored in ``self.of_cfg`` and passes the resulting mask to
    ``self.mc.set_mask``.

    The method is a no-op when no optical-flow config is loaded, when the
    config is not enabled, or when there are no frames, so invoking it
    from anywhere other than the Auto Segment button cannot crash on a
    missing config or an empty clip.
    """
    cfg = self.of_cfg
    # Same condition that enables the Auto Segment button.
    if cfg is None or not cfg.enabled:
        return

    frames = self.frames
    if frames is None or len(frames) == 0:
        return

    mask = compute_optical_flow_mask(
        frames,
        self.fps,
        cfg.norm_squared_threshold,
        cfg.gaussian_kernel,
        cfg.brightness_range,
    )
    self.mc.set_mask(mask)

View File

@@ -0,0 +1,49 @@
import cv2
import numpy as np
def compute_optical_flow_mask(
    frames: list[np.ndarray],
    fps: float,
    norm_squared_threshold: float,
    gaussian_kernel: tuple[int, int],
    brightness_range: tuple[int, int],
) -> np.ndarray:
    """Return a binary mask (uint8, values 0/1) from optical flow + brightness.

    A pixel is set when both conditions hold:

    * the squared magnitude of its median optical flow (median taken over
      all consecutive frame pairs) is at least
      ``norm_squared_threshold ** 2`` times the largest squared magnitude
      found in the image;
    * its grey value in the Gaussian-blurred middle frame lies strictly
      between ``brightness_range[0]`` and ``brightness_range[1]``.

    Args:
        frames: Equally shaped colour frames. They are converted with
            ``cv2.COLOR_RGB2GRAY``, so they are presumably RGB uint8 images
            (confirm against the frame loader).
        fps: Factor applied to every flow field. Because the flow threshold
            is relative to the peak magnitude, this scale only changes the
            result when it is 0, in which case the mask is empty.
        norm_squared_threshold: Relative threshold; note that it is squared
            again before being applied to the squared magnitudes.
        gaussian_kernel: Kernel size handed to ``cv2.GaussianBlur``.
        brightness_range: Exclusive ``(low, high)`` grey-value bounds.

    Returns:
        A 2-D uint8 array with the frames' height and width. It is all zeros
        when fewer than two frames are given or when the frames show no
        variation over time.

    Raises:
        ValueError: If ``frames`` is empty.
    """
    if len(frames) == 0:
        raise ValueError("compute_optical_flow_mask needs at least one frame")

    mask_shape = frames[0].shape[:2]
    if len(frames) < 2:
        return np.zeros(mask_shape, dtype=np.uint8)

    # Remove the static background (per-pixel temporal mean) and stretch the
    # remainder to the full 0..255 range so the flow estimator sees motion only.
    stack = np.stack(frames).astype(np.float64)
    centred = stack - np.mean(stack, axis=0)
    lo, hi = centred.min(), centred.max()
    if not hi > lo:
        # Nothing changes between frames: there is no motion to detect, so
        # skip the (expensive) flow computation on all-zero images.
        return np.zeros(mask_shape, dtype=np.uint8)
    standardized = ((centred - lo) / (hi - lo) * 255).astype(np.uint8)

    gray = np.stack(
        [cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) for frame in standardized]
    )
    flow_data = np.zeros((len(gray) - 1,) + gray.shape[1:] + (2,))
    for i, (previous, following) in enumerate(zip(gray[:-1], gray[1:])):
        flow_data[i] = fps * cv2.optflow.calcOpticalFlowSparseToDense(
            previous, following
        )

    # The median over time suppresses flow outliers of individual frame pairs.
    optical_flow = np.median(flow_data, axis=0)
    flow_norm_sq = np.sum(optical_flow**2, axis=-1)
    max_norm_sq = np.max(flow_norm_sq)
    if max_norm_sq > 0:
        flow_mask = flow_norm_sq >= max_norm_sq * norm_squared_threshold**2
    else:
        flow_mask = np.zeros(flow_norm_sq.shape, dtype=bool)

    # Brightness is judged on the original (not standardized) middle frame.
    reference_frame = frames[len(frames) // 2]
    smoothed = cv2.GaussianBlur(reference_frame, gaussian_kernel, 0)
    gray_ref = cv2.cvtColor(smoothed, cv2.COLOR_RGB2GRAY)
    brightness_mask = (gray_ref > brightness_range[0]) & (
        gray_ref < brightness_range[1]
    )

    return np.logical_and(brightness_mask, flow_mask).astype(np.uint8)

View File

@@ -28,6 +28,7 @@ class AppConfig:
data_dir: str = "data/clips"
out_dir: str = "data/annotation_results"
clips_file: str = "config/clips.txt"
optical_flow_config_file: str = ""
questions: list = field(default_factory=list)
filenames: FilenameConfig = field(default_factory=FilenameConfig)
@@ -51,6 +52,22 @@ class AppConfig:
]
@dataclass
class OpticalFlowConfig:
    """Settings for the optical-flow based Auto Segment feature.

    Instances are built by ``load_optical_flow_config`` from a YAML mapping
    whose keys are these field names.
    """

    # The Auto Segment button is only enabled when this is true.
    enabled: bool = False
    # Passed to compute_optical_flow_mask as its relative flow threshold.
    norm_squared_threshold: float = 0.3
    # Kernel size passed on for the Gaussian blur of the reference frame.
    gaussian_kernel: tuple[int, int] = (5, 5)
    # (low, high) grey-value bounds passed on for the brightness test.
    brightness_range: tuple[int, int] = (20, 235)
def load_optical_flow_config(path: Path) -> OpticalFlowConfig:
    """Load an ``OpticalFlowConfig`` from the YAML file at ``path``.

    Keys missing from the file fall back to the dataclass defaults; an
    empty file therefore yields a default (disabled) configuration.
    ``gaussian_kernel`` and ``brightness_range`` are converted from YAML
    lists to tuples when present.

    Args:
        path: Location of the optical-flow YAML file.

    Returns:
        The populated configuration object.

    Raises:
        OSError: If the file cannot be opened.
        TypeError: If the top-level YAML value is not a mapping, or if it
            contains keys that are not ``OpticalFlowConfig`` fields.
    """
    with open(path, encoding="utf-8") as f:
        data = yaml.safe_load(f)

    # safe_load returns None for an empty document.
    if data is None:
        data = {}
    if not isinstance(data, dict):
        raise TypeError(
            f"{path}: expected a YAML mapping, got {type(data).__name__}"
        )

    # YAML sequences arrive as lists; the config (and OpenCV) expect tuples.
    for key in ("gaussian_kernel", "brightness_range"):
        if key in data:
            data[key] = tuple(data[key])

    return OpticalFlowConfig(**data)
def load_config(path: Path) -> AppConfig:
with open(path) as f:
data = yaml.safe_load(f)