Files
clip-annotator/src/river_annotation_tool/annotator.py
asreva 67c9a1152c Add optical flow Auto Segment button
Annotators can now press Auto Segment to replace the current mask with
an automatic river segmentation based on dense optical flow magnitude
and frame brightness. The result is pushed onto the undo stack, so it
can be refined or reverted like any other mask operation.

Parameters (norm_squared_threshold, gaussian_kernel, brightness_range)
live in a separate config/optical_flow_config.yaml. The button is only
enabled when optical_flow_config_file is set in config.yaml and the
loaded optical-flow config has its `enabled` flag set.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-20 15:13:10 +02:00

434 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
from pathlib import Path
import cv2
import numpy as np
from PIL import Image
from PySide6.QtCore import Qt, QTimer
from PySide6.QtWidgets import (
QApplication,
QButtonGroup,
QGroupBox,
QHBoxLayout,
QLabel,
QMainWindow,
QMessageBox,
QPushButton,
QRadioButton,
QVBoxLayout,
QWidget,
)
from .clip_selector import ClipSelector
from .compute_optical_flow import compute_optical_flow_mask
from .config import AppConfig, load_optical_flow_config
from .mask_canvas import MaskCanvas
from .video_loader import load_frames
class Annotator(QMainWindow):
    """Main window for annotating river video clips.

    The current clip plays in a loop on a ``MaskCanvas`` where the user paints
    a mask. The user also answers the configured questions in a side panel and
    navigates between clips. Clips visited in this session are tracked in
    ``self.history`` / ``self.history_pos`` so "Previous" (and a later "Next")
    can walk back and forth through them.
    """

    def __init__(
        self,
        config: AppConfig,
        clip: str = None,
        extras: bool = False,
        skip_annotated: bool = True,
    ):
        """Build the window and load the first clip.

        Args:
            config: Application configuration (paths, output filenames,
                display/playback settings, questions).
            clip: Optional clip identifier, forwarded to
                ``ClipSelector.next(specific=...)`` for the first clip.
            extras: When true, ``save`` additionally writes a mask
                visualisation image and four GIFs.
            skip_annotated: Forwarded to ``ClipSelector``.
        """
        super().__init__()
        self.cfg = config
        self.out_dir = Path(config.out_dir)
        self.extras = extras
        # Optical-flow settings are optional; None disables "Auto Segment".
        self.of_cfg = (
            load_optical_flow_config(Path(config.optical_flow_config_file))
            if config.optical_flow_config_file
            else None
        )
        self.selector = ClipSelector(
            data_dir=Path(config.data_dir),
            out_dir=self.out_dir,
            clips_file=Path(config.clips_file),
            mask_filename=config.filenames.mask,
            zip_extension=config.filenames.zip_extension,
            skip_annotated=skip_annotated,
        )
        # Clips visited in this session, and the index of the one on screen.
        self.history: list[Path] = []
        self.history_pos: int = -1
        self.setWindowTitle("River Annotator")
        self._load_clip(specific=clip)
        self._history_push()
        self._init_ui()
        self._init_timer()

    # ── clip loading ───────────────────────────────────────────────
    def _load_clip(self, specific: str = None, path: Path = None):
        """Make a clip current: load its frames and any saved answers.

        Args:
            specific: Passed to ``ClipSelector.next`` when ``path`` is None.
            path: Explicit clip path (used when revisiting history); takes
                precedence over the selector.

        Exceptions from the selector or ``load_frames`` propagate;
        ``_advance_clip`` treats a ``RuntimeError`` as "no more clips".
        """
        if path is not None:
            self.filename = path
        else:
            self.filename = self.selector.next(specific=specific)
        self.frames, self.fps, self.dh, self.dw, self.h, self.w = load_frames(
            self.filename,
            self.cfg.max_frames,
            self.cfg.display_max,
            self.cfg.fps_fallback,
            self.cfg.filenames.video_in_zip,
            self.cfg.filenames.video_tmp_suffix,
        )
        # Applied to the radio buttons once the UI exists / is switched.
        self._pending_answers = self._read_saved_answers()

    def _history_push(self):
        """Drop any forward history and append the current clip."""
        del self.history[self.history_pos + 1 :]
        self.history.append(self.filename)
        self.history_pos = len(self.history) - 1

    def _mask_path_for(self, stem: str) -> Path:
        """Return the path where the mask for clip ``stem`` is saved."""
        return self.out_dir / stem / self.cfg.filenames.mask

    def _load_mask_file(self, mask_path: Path) -> np.ndarray:
        """Read a saved mask image and return it binarised at display size.

        Grey values above 127 become 1, everything else 0. The result is
        resized to ``(self.dh, self.dw)`` with nearest-neighbour sampling so
        it stays binary.
        """
        # Context manager guarantees the file handle is released.
        with Image.open(mask_path) as img:
            mask_full = np.array(img.convert("L"))
        return cv2.resize(
            (mask_full > 127).astype(np.uint8),
            (self.dw, self.dh),
            interpolation=cv2.INTER_NEAREST,
        )

    def _read_saved_mask(self):
        """Return the saved mask for the current clip, or None if absent."""
        mask_path = self._mask_path_for(self.filename.stem)
        if not mask_path.exists():
            return None
        return self._load_mask_file(mask_path)

    def _read_saved_answers(self):
        """Return the saved answers dict for the current clip, or None."""
        meta_path = self.out_dir / self.filename.stem / self.cfg.filenames.metadata
        if not meta_path.exists():
            return None
        with open(meta_path, encoding="utf-8") as f:
            return json.load(f)

    def _frame_interval_ms(self) -> int:
        """Return the per-frame delay in milliseconds for the current clip."""
        # Clamp so a very high fps never yields a 0 ms interval/duration.
        return max(1, int(1000 / self.fps))

    def _auto_segment_available(self) -> bool:
        """Return True when an optical-flow config is loaded and enabled."""
        return bool(self.of_cfg is not None and self.of_cfg.enabled)

    # ── UI setup ───────────────────────────────────────────────────
    def _init_ui(self):
        """Create the canvas, buttons, sliders and question panel."""
        self.mc = MaskCanvas(self.frames, self.dh, self.dw)
        self.mc.set_title(self.filename.name)
        self.mc.reset(self._read_saved_mask())
        self.q_widgets = {}
        question_panel = self._build_question_panel()
        self.btn_prev = QPushButton("Previous")
        self.btn_prev.setEnabled(False)
        btn_next = QPushButton("Next")
        btn_skip = QPushButton("Skip")
        btn_clear = QPushButton("Clear")
        btn_undo = QPushButton("Undo")
        btn_undo10 = QPushButton("Undo×10")
        btn_redo = QPushButton("Redo")
        btn_load_prev_mask = QPushButton("Load Prev Mask")
        btn_auto_segment = QPushButton("Auto Segment")
        btn_auto_segment.setEnabled(self._auto_segment_available())
        row1 = QHBoxLayout()
        for b in [
            self.btn_prev,
            btn_next,
            btn_skip,
            btn_load_prev_mask,
            btn_auto_segment,
        ]:
            row1.addWidget(b)
        row_tools = QHBoxLayout()
        for b in [
            self.mc.btn_brush,
            self.mc.btn_polygon,
            self.mc.btn_fill,
            self.mc.btn_del_shape,
            self.mc.btn_cancel_poly,
        ]:
            row_tools.addWidget(b)
        row2 = QHBoxLayout()
        for b in [
            btn_clear,
            self.mc.btn_erase,
            btn_undo,
            btn_undo10,
            btn_redo,
            self.mc.btn_mask,
        ]:
            row2.addWidget(b)
        row3 = QHBoxLayout()
        row3.addWidget(QLabel("Brush size"))
        row3.addWidget(self.mc.brush_slider)
        row3.addWidget(self.mc.brush_reset)
        row4 = QHBoxLayout()
        row4.addWidget(QLabel("Mask Alpha"))
        row4.addWidget(self.mc.alpha_slider)
        row4.addWidget(self.mc.alpha_reset)
        # Vertical image-adjustment sliders shown to the left of the canvas.
        vert_panel = QHBoxLayout()
        vert_panel.setContentsMargins(0, 0, 4, 0)
        for label_text, slider, reset_btn in [
            ("Brightness", self.mc.brightness_slider, self.mc.brightness_reset),
            ("Contrast", self.mc.contrast_slider, self.mc.contrast_reset),
            ("Gamma", self.mc.gamma_slider, self.mc.gamma_reset),
        ]:
            col = QVBoxLayout()
            lbl = QLabel(label_text)
            lbl.setAlignment(Qt.AlignmentFlag.AlignHCenter)
            col.addWidget(lbl)
            col.addWidget(slider, 1)
            col.addWidget(reset_btn)
            vert_panel.addLayout(col)
        canvas_row = QHBoxLayout()
        canvas_row.addLayout(vert_panel)
        canvas_row.addWidget(self.mc.canvas, 1)
        left = QVBoxLayout()
        left.addLayout(canvas_row)
        left.addLayout(row1)
        left.addLayout(row_tools)
        left.addLayout(row2)
        left.addLayout(row3)
        left.addLayout(row4)
        left_widget = QWidget()
        left_widget.setLayout(left)
        right_widget = QWidget()
        right_widget.setLayout(question_panel)
        main = QHBoxLayout()
        main.addWidget(left_widget, 3)
        main.addWidget(right_widget, 1)
        container = QWidget()
        container.setLayout(main)
        self.setCentralWidget(container)
        self.btn_prev.clicked.connect(self.prev_clip)
        btn_next.clicked.connect(self.next_clip)
        btn_skip.clicked.connect(self.skip_clip)
        btn_clear.clicked.connect(self.mc.clear)
        btn_undo.clicked.connect(self.mc.undo)
        btn_undo10.clicked.connect(self.mc.undo10)
        btn_redo.clicked.connect(self.mc.redo)
        btn_load_prev_mask.clicked.connect(self.load_prev_mask)
        btn_auto_segment.clicked.connect(self.run_optical_flow)
        if self._pending_answers:
            self._set_answers(self._pending_answers)
            self._pending_answers = None

    def _build_question_panel(self) -> QVBoxLayout:
        """Build one radio-button group per configured question.

        Fills ``self.q_widgets[key] = (button_group, buttons, options)``. The
        option equal to ``default`` starts checked; when ``default`` is None
        the last option is checked.
        """
        vbox = QVBoxLayout()
        for section, qs in self.cfg.get_questions():
            group = QGroupBox(section)
            gvbox = QVBoxLayout()
            for key, label, options, default in qs:
                gvbox.addWidget(QLabel(label))
                btn_group = QButtonGroup(self)
                row = QHBoxLayout()
                buttons = []
                for opt in options:
                    btn = QRadioButton(opt)
                    btn_group.addButton(btn)
                    row.addWidget(btn)
                    buttons.append(btn)
                    if default == opt:
                        btn.setChecked(True)
                if default is None and buttons:
                    buttons[-1].setChecked(True)
                self.q_widgets[key] = (btn_group, buttons, options)
                gvbox.addLayout(row)
            group.setLayout(gvbox)
            vbox.addWidget(group)
        return vbox

    def _set_answers(self, answers: dict):
        """Check the radio button matching each saved answer; unknown keys are ignored."""
        for key, value in answers.items():
            if key not in self.q_widgets:
                continue
            _, buttons, options = self.q_widgets[key]
            for i, btn in enumerate(buttons):
                btn.setChecked(options[i] == value)

    def _init_timer(self):
        """Start the playback timer at the current clip's frame rate."""
        self.frame_i = 0
        self.timer = QTimer()
        self.timer.timeout.connect(self._tick)
        self.timer.start(self._frame_interval_ms())

    def _tick(self):
        """Show the next frame, looping back to the first after the last."""
        self.frame_i = (self.frame_i + 1) % len(self.frames)
        self.mc.set_frame(self.frames[self.frame_i])

    # ── answers ────────────────────────────────────────────────────
    def get_answers(self) -> dict:
        """Return ``{question_key: checked_option}`` for every answered question."""
        out = {}
        for key, (_, buttons, options) in self.q_widgets.items():
            for i, btn in enumerate(buttons):
                if btn.isChecked():
                    out[key] = options[i]
        return out

    # ── save helpers ───────────────────────────────────────────────
    def _make_overlay(self, frame, alpha=0.4):
        """Return ``frame`` with masked pixels alpha-blended towards pure green.

        Assumes ``frame`` is channel-last with at least two channels and the
        same height/width as ``self.mc.mask``.
        """
        overlay = frame.copy()
        # A single green pixel broadcast over the masked pixels, instead of
        # allocating a whole green frame.
        green = np.zeros(frame.shape[-1], dtype=frame.dtype)
        green[1] = 255
        m = self.mc.mask.astype(bool)
        overlay[m] = (1 - alpha) * overlay[m] + alpha * green
        return overlay.astype(np.uint8)

    def _save_gif(self, frames, out_path, scale=1.0):
        """Write ``frames`` as a looping GIF, resized by ``scale``, at clip fps."""
        h, w = frames[0].shape[:2]
        nh, nw = max(1, int(h * scale)), max(1, int(w * scale))
        pil_frames = [Image.fromarray(cv2.resize(f, (nw, nh))) for f in frames]
        pil_frames[0].save(
            out_path,
            save_all=True,
            append_images=pil_frames[1:],
            duration=self._frame_interval_ms(),
            loop=0,
        )

    # ── actions ────────────────────────────────────────────────────
    def save(self):
        """Write the mask, answers, middle frame and overlay for the current clip.

        The mask is upscaled to the original video size ``(self.h, self.w)``.
        The frame and overlay are saved at display resolution. With
        ``self.extras`` a mask visualisation and four GIFs are also written.
        """
        out = self.out_dir / self.filename.stem
        out.mkdir(parents=True, exist_ok=True)
        # Binarise first: any non-zero pixel is "masked" (as in _make_overlay),
        # and 0/1 values keep the "* 255" below from overflowing uint8.
        mask_bin = self.mc.mask.astype(bool).astype(np.uint8)
        mask_full = cv2.resize(
            mask_bin,
            (self.w, self.h),
            interpolation=cv2.INTER_NEAREST,
        )
        fn = self.cfg.filenames
        Image.fromarray(mask_full * 255).save(out / fn.mask)
        with open(out / fn.metadata, "w", encoding="utf-8") as f:
            json.dump(self.get_answers(), f, indent=2)
        mid = len(self.frames) // 2
        frame = self.frames[mid]
        Image.fromarray(frame).save(out / fn.frame)
        Image.fromarray(self._make_overlay(frame)).save(out / fn.overlay)
        if self.extras:
            Image.fromarray(mask_bin * 255).save(out / fn.mask_vis)
            overlay_frames = [self._make_overlay(f) for f in self.frames]
            self._save_gif(self.frames, out / fn.gif_original_hires, scale=1.0)
            self._save_gif(self.frames, out / fn.gif_original_lowres, scale=0.5)
            self._save_gif(overlay_frames, out / fn.gif_overlay_hires, scale=1.0)
            self._save_gif(overlay_frames, out / fn.gif_overlay_lowres, scale=0.5)
        print("Saved:", out)

    def _switch_ui_to_clip(self):
        """Refresh the canvas, timer and answers after ``_load_clip``."""
        self.frame_i = 0
        # The new clip may have a different frame rate.
        self.timer.setInterval(self._frame_interval_ms())
        self.mc.load_clip(
            self.frames,
            self.dh,
            self.dw,
            mask=self._read_saved_mask(),
            title=self.filename.name,
        )
        # NOTE(review): when the new clip has no saved answers, the radio
        # buttons keep the previous clip's selections. Confirm this is intended.
        if self._pending_answers:
            self._set_answers(self._pending_answers)
            self._pending_answers = None
        self.btn_prev.setEnabled(self.history_pos > 0)

    def _advance_clip(self):
        """Move forward in history, or fetch a new clip; quit when none remain."""
        if self.history_pos < len(self.history) - 1:
            self.history_pos += 1
            self._load_clip(path=self.history[self.history_pos])
            self._switch_ui_to_clip()
            return
        try:
            self._load_clip()
        except RuntimeError:
            msg = QMessageBox(self)
            msg.setWindowTitle("All done!")
            msg.setText("You have reached the end of all clips.")
            msg.setStandardButtons(QMessageBox.StandardButton.Ok)
            msg.exec()
            QApplication.instance().quit()
            return
        self._history_push()
        self._switch_ui_to_clip()

    def prev_clip(self):
        """Save the current clip and go back one step in history."""
        if self.history_pos <= 0:
            return
        # NOTE(review): unlike next_clip, this overwrites an existing saved
        # annotation without asking. Confirm this is intended.
        self.save()
        self.history_pos -= 1
        self._load_clip(path=self.history[self.history_pos])
        self._switch_ui_to_clip()

    def next_clip(self):
        """Save (asking first if a saved mask exists) and advance to the next clip."""
        mask_path = self._mask_path_for(self.filename.stem)
        if mask_path.exists():
            msg = QMessageBox(self)
            msg.setWindowTitle("Existing annotation found")
            msg.setText(
                f"'{self.filename.stem}' already has a saved annotation.\n"
                "Replace it with your current work, or keep the existing save?"
            )
            btn_replace = msg.addButton(
                "Replace & Continue", QMessageBox.ButtonRole.AcceptRole
            )
            btn_keep = msg.addButton(
                "Keep Existing & Continue", QMessageBox.ButtonRole.AcceptRole
            )
            msg.addButton("Cancel", QMessageBox.ButtonRole.RejectRole)
            msg.setDefaultButton(btn_replace)
            msg.exec()
            clicked = msg.clickedButton()
            if clicked == btn_replace:
                self.save()
                self._advance_clip()
            elif clicked == btn_keep:
                self._advance_clip()
            # Cancel: do nothing
        else:
            self.save()
            self._advance_clip()

    def skip_clip(self):
        """Advance to the next clip without saving."""
        self._advance_clip()

    def load_prev_mask(self):
        """Replace the current mask with the saved mask of the preceding clip.

        "Preceding" means the entry before the current clip in
        ``self.selector.clips``, not the previous history entry.
        """
        try:
            idx = self.selector.clips.index(self.filename)
        except ValueError:
            QMessageBox.information(
                self,
                "Clip not in list",
                f"'{self.filename.stem}' is not in the clip list, "
                "so there is no previous clip.",
            )
            return
        if idx == 0:
            QMessageBox.information(
                self, "No previous clip", "This is the first clip in the list."
            )
            return
        prev_clip = self.selector.clips[idx - 1]
        mask_path = self._mask_path_for(prev_clip.stem)
        if not mask_path.exists():
            QMessageBox.information(
                self, "No mask found", f"No saved mask found for '{prev_clip.stem}'."
            )
            return
        self.mc.set_mask(self._load_mask_file(mask_path))

    def run_optical_flow(self):
        """Replace the current mask with an automatic optical-flow segmentation.

        Does nothing when no optical-flow config is loaded or it is disabled.
        """
        if not self._auto_segment_available():
            return
        # Runs synchronously on the GUI thread; show a busy cursor so the
        # unresponsive window is not mistaken for a hang.
        QApplication.setOverrideCursor(Qt.CursorShape.WaitCursor)
        try:
            mask = compute_optical_flow_mask(
                self.frames,
                self.fps,
                self.of_cfg.norm_squared_threshold,
                self.of_cfg.gaussian_kernel,
                self.of_cfg.brightness_range,
            )
        finally:
            QApplication.restoreOverrideCursor()
        self.mc.set_mask(mask)