Add optical flow Auto Segment button
Annotators can now press Auto Segment to replace the current mask with an automatic river segmentation based on dense optical flow magnitude and frame brightness. The result is pushed onto the undo stack, so it can be refined or reverted like any other mask operation. Parameters (norm_squared_threshold, gaussian_kernel, brightness_range) live in a separate config/optical_flow_config.yaml; the button is only enabled when optical_flow_config_file is set in config.yaml and that file sets enabled: true. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -20,7 +20,8 @@ from PySide6.QtWidgets import (
|
||||
)
|
||||
|
||||
from .clip_selector import ClipSelector
|
||||
from .config import AppConfig
|
||||
from .compute_optical_flow import compute_optical_flow_mask
|
||||
from .config import AppConfig, load_optical_flow_config
|
||||
from .mask_canvas import MaskCanvas
|
||||
from .video_loader import load_frames
|
||||
|
||||
@@ -38,6 +39,11 @@ class Annotator(QMainWindow):
|
||||
self.cfg = config
|
||||
self.out_dir = Path(config.out_dir)
|
||||
self.extras = extras
|
||||
self.of_cfg = (
|
||||
load_optical_flow_config(Path(config.optical_flow_config_file))
|
||||
if config.optical_flow_config_file
|
||||
else None
|
||||
)
|
||||
|
||||
self.selector = ClipSelector(
|
||||
data_dir=Path(config.data_dir),
|
||||
@@ -114,9 +120,17 @@ class Annotator(QMainWindow):
|
||||
btn_undo10 = QPushButton("Undo×10")
|
||||
btn_redo = QPushButton("Redo")
|
||||
btn_load_prev_mask = QPushButton("Load Prev Mask")
|
||||
btn_auto_segment = QPushButton("Auto Segment")
|
||||
btn_auto_segment.setEnabled(self.of_cfg is not None and self.of_cfg.enabled)
|
||||
|
||||
row1 = QHBoxLayout()
|
||||
for b in [self.btn_prev, btn_next, btn_skip, btn_load_prev_mask]:
|
||||
for b in [
|
||||
self.btn_prev,
|
||||
btn_next,
|
||||
btn_skip,
|
||||
btn_load_prev_mask,
|
||||
btn_auto_segment,
|
||||
]:
|
||||
row1.addWidget(b)
|
||||
|
||||
row_tools = QHBoxLayout()
|
||||
@@ -198,6 +212,7 @@ class Annotator(QMainWindow):
|
||||
btn_undo10.clicked.connect(self.mc.undo10)
|
||||
btn_redo.clicked.connect(self.mc.redo)
|
||||
btn_load_prev_mask.clicked.connect(self.load_prev_mask)
|
||||
btn_auto_segment.clicked.connect(self.run_optical_flow)
|
||||
|
||||
if self._pending_answers:
|
||||
self._set_answers(self._pending_answers)
|
||||
@@ -406,3 +421,13 @@ class Annotator(QMainWindow):
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
self.mc.set_mask(mask)
|
||||
|
||||
def run_optical_flow(self):
    """Replace the canvas mask with an automatic optical-flow segmentation.

    Passes the currently loaded frames, the frame rate and the optical-flow
    settings to ``compute_optical_flow_mask`` and hands the resulting mask
    to the mask canvas via ``set_mask``.
    """
    # NOTE(review): self.of_cfg is dereferenced unconditionally, so this
    # assumes an optical-flow config was loaded — confirm no caller can
    # reach this slot while it is None.
    settings = self.of_cfg
    auto_mask = compute_optical_flow_mask(
        frames=self.frames,
        fps=self.fps,
        norm_squared_threshold=settings.norm_squared_threshold,
        gaussian_kernel=settings.gaussian_kernel,
        brightness_range=settings.brightness_range,
    )
    self.mc.set_mask(auto_mask)
|
||||
|
||||
49
src/river_annotation_tool/compute_optical_flow.py
Normal file
49
src/river_annotation_tool/compute_optical_flow.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def compute_optical_flow_mask(
    frames: list[np.ndarray],
    fps: float,
    norm_squared_threshold: float,
    gaussian_kernel: tuple[int, int],
    brightness_range: tuple[int, int],
) -> np.ndarray:
    """Return a binary mask (uint8, values 0/1) from optical flow + brightness.

    A pixel is set to 1 when both conditions hold:

    * the squared norm of its median flow vector is at least
      ``norm_squared_threshold ** 2`` times the largest squared norm found
      anywhere in the image, and
    * its grayscale value in the Gaussian-blurred middle frame lies strictly
      between ``brightness_range[0]`` and ``brightness_range[1]``.

    Args:
        frames: Frames of one clip, all with the same shape. They are
            converted with ``cv2.COLOR_RGB2GRAY``, so 3-channel RGB input is
            expected.
        fps: Factor applied to every per-frame flow field (presumably to
            express the flow per second — confirm with the caller). Because
            the flow threshold is relative to the maximum, this factor does
            not change which pixels pass it, except that ``fps == 0`` yields
            an empty mask.
        norm_squared_threshold: Relative threshold; it is squared before
            being compared against the squared flow norm.
        gaussian_kernel: Kernel size handed to ``cv2.GaussianBlur``.
        brightness_range: Exclusive ``(low, high)`` grayscale bounds.

    Returns:
        A 2-D ``uint8`` array with the spatial shape of the frames. With a
        single frame no flow can be computed and the mask is all zeros.

    Raises:
        ValueError: If ``frames`` is empty, since the output shape cannot be
            determined.
    """
    if len(frames) == 0:
        # Previously this fell through to frames[0] and died with an opaque
        # IndexError; fail with a message that names the actual problem.
        raise ValueError(
            "compute_optical_flow_mask needs at least one frame to "
            "determine the mask shape"
        )
    if len(frames) < 2:
        return np.zeros(frames[0].shape[:2], dtype=np.uint8)

    # Remove the per-pixel temporal mean so that only the changing part of
    # the scene remains, then rescale the whole clip to the 0..255 range.
    frames_arr = np.stack(frames).astype(np.float64)
    frames_sub_mean = frames_arr - np.mean(frames_arr, axis=0)
    mn, mx = frames_sub_mean.min(), frames_sub_mean.max()
    if mx > mn:
        standardized = ((frames_sub_mean - mn) / (mx - mn) * 255).astype(np.uint8)
    else:
        # All frames are identical: nothing to rescale, use a black clip.
        standardized = np.zeros_like(frames_arr, dtype=np.uint8)

    N = len(standardized)
    gray = np.stack([cv2.cvtColor(f, cv2.COLOR_RGB2GRAY) for f in standardized])

    # One dense flow field per pair of consecutive frames.
    # NOTE: cv2.optflow is part of the opencv-contrib modules.
    flow_data = np.zeros((N - 1,) + gray.shape[1:] + (2,))
    for i in range(N - 1):
        flow_data[i] = fps * cv2.optflow.calcOpticalFlowSparseToDense(
            gray[i], gray[i + 1]
        )

    # The median over time suppresses frame pairs with outlier motion.
    optical_flow = np.median(flow_data, axis=0)

    flow_norm_sq = np.sum(optical_flow**2, axis=-1)
    max_norm = np.max(flow_norm_sq)
    if max_norm > 0:
        flow_mask = flow_norm_sq >= max_norm * norm_squared_threshold**2
    else:
        # No motion anywhere: a relative threshold would select every pixel.
        flow_mask = np.zeros(flow_norm_sq.shape, dtype=bool)

    # Brightness is judged on the original (not standardized) middle frame.
    reference_frame = frames[len(frames) // 2]
    smoothed = cv2.GaussianBlur(reference_frame, gaussian_kernel, 0)
    gray_ref = cv2.cvtColor(smoothed, cv2.COLOR_RGB2GRAY)
    brightness_mask = (gray_ref > brightness_range[0]) & (
        gray_ref < brightness_range[1]
    )

    return np.logical_and(brightness_mask, flow_mask).astype(np.uint8)
|
||||
@@ -28,6 +28,7 @@ class AppConfig:
|
||||
data_dir: str = "data/clips"
|
||||
out_dir: str = "data/annotation_results"
|
||||
clips_file: str = "config/clips.txt"
|
||||
optical_flow_config_file: str = ""
|
||||
questions: list = field(default_factory=list)
|
||||
filenames: FilenameConfig = field(default_factory=FilenameConfig)
|
||||
|
||||
@@ -51,6 +52,22 @@ class AppConfig:
|
||||
]
|
||||
|
||||
|
||||
@dataclass
class OpticalFlowConfig:
    """Settings for the optical-flow based automatic segmentation.

    Every field has a default, so a configuration only needs to list the
    values it wants to override.
    """

    # Switch for the feature as a whole; off unless explicitly turned on.
    enabled: bool = False
    # Relative flow threshold handed to the optical-flow mask computation.
    norm_squared_threshold: float = 0.3
    # Gaussian blur kernel size as a pair of ints.
    gaussian_kernel: tuple[int, int] = (5, 5)
    # (low, high) brightness bounds as a pair of ints.
    brightness_range: tuple[int, int] = (20, 235)
|
||||
|
||||
|
||||
def load_optical_flow_config(path: Path) -> OpticalFlowConfig:
    """Load the optical-flow settings from a YAML file.

    Keys that the file does not mention keep the defaults declared on
    ``OpticalFlowConfig``; an empty file therefore yields an all-default
    configuration.

    Args:
        path: Location of the YAML file.

    Returns:
        The populated ``OpticalFlowConfig``.

    Raises:
        OSError: If the file cannot be opened.
        TypeError: If the file contains a key that is not a field of
            ``OpticalFlowConfig``.
    """
    with open(path) as f:
        # An empty YAML document parses to None; treat it as "no overrides"
        # instead of failing on the subscripts below.
        data = yaml.safe_load(f) or {}
    # YAML sequences arrive as lists while the dataclass declares tuples.
    # Convert only the keys that are present so that omitted ones fall back
    # to the dataclass defaults rather than raising KeyError.
    for key in ("gaussian_kernel", "brightness_range"):
        if key in data:
            data[key] = tuple(data[key])
    return OpticalFlowConfig(**data)
|
||||
|
||||
|
||||
def load_config(path: Path) -> AppConfig:
|
||||
with open(path) as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
Reference in New Issue
Block a user