Annotators can now press Auto Segment to replace the current mask with an automatic river segmentation based on dense optical flow magnitude and frame brightness. The result is pushed onto the undo stack, so it can be refined or reverted like any other mask operation. Parameters (norm_squared_threshold, gaussian_kernel, brightness_range) live in a separate config/optical_flow_config.yaml; the button is only enabled when optical_flow_config_file is set in config.yaml and the loaded optical-flow config has its enabled flag set.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
434 lines
15 KiB
Python
434 lines
15 KiB
Python
import json
|
||
from pathlib import Path
|
||
|
||
import cv2
|
||
import numpy as np
|
||
from PIL import Image
|
||
from PySide6.QtCore import Qt, QTimer
|
||
from PySide6.QtWidgets import (
|
||
QApplication,
|
||
QButtonGroup,
|
||
QGroupBox,
|
||
QHBoxLayout,
|
||
QLabel,
|
||
QMainWindow,
|
||
QMessageBox,
|
||
QPushButton,
|
||
QRadioButton,
|
||
QVBoxLayout,
|
||
QWidget,
|
||
)
|
||
|
||
from .clip_selector import ClipSelector
|
||
from .compute_optical_flow import compute_optical_flow_mask
|
||
from .config import AppConfig, load_optical_flow_config
|
||
from .mask_canvas import MaskCanvas
|
||
from .video_loader import load_frames
|
||
|
||
|
||
class Annotator(QMainWindow):
    """Main window for annotating clips with a mask and per-clip answers.

    The window loops the frames of the current clip on a ``MaskCanvas``,
    shows a panel of radio-button questions built from the app config, and
    writes the mask, the answers and preview images into
    ``out_dir / <clip stem>`` when the annotation is saved.

    Clips visited during the session are kept in ``self.history`` so the
    user can step backwards and forwards; ``self.history_pos`` is the index
    of the clip currently shown.
    """

    def __init__(
        self,
        config: AppConfig,
        clip: str = None,
        extras: bool = False,
        skip_annotated: bool = True,
    ):
        """Load the first clip and build the UI.

        Args:
            config: Application configuration (paths, filenames, limits,
                questions).
            clip: Optional specific clip to start with; forwarded to the
                clip selector.
            extras: When true, ``save`` also writes the mask visualisation
                and the original/overlay GIFs.
            skip_annotated: Forwarded to ``ClipSelector``.
        """
        super().__init__()

        self.cfg = config
        self.out_dir = Path(config.out_dir)
        self.extras = extras
        # The optical-flow config is optional; without it the
        # "Auto Segment" button stays disabled.
        self.of_cfg = (
            load_optical_flow_config(Path(config.optical_flow_config_file))
            if config.optical_flow_config_file
            else None
        )

        self.selector = ClipSelector(
            data_dir=Path(config.data_dir),
            out_dir=self.out_dir,
            clips_file=Path(config.clips_file),
            mask_filename=config.filenames.mask,
            zip_extension=config.filenames.zip_extension,
            skip_annotated=skip_annotated,
        )

        self.history: list[Path] = []
        self.history_pos: int = -1

        self.setWindowTitle("River Annotator")
        self._load_clip(specific=clip)
        self._history_push()
        self._init_ui()
        self._init_timer()

    # ── clip loading ───────────────────────────────────────────────
    def _load_clip(self, specific: str = None, path: Path = None):
        """Load frames and any saved answers for a clip.

        Args:
            specific: Passed to the selector when ``path`` is not given.
            path: Explicit clip path (used when revisiting history); when
                given, the selector is not consulted.
        """
        if path is not None:
            self.filename = path
        else:
            self.filename = self.selector.next(specific=specific)
        self.frames, self.fps, self.dh, self.dw, self.h, self.w = load_frames(
            self.filename,
            self.cfg.max_frames,
            self.cfg.display_max,
            self.cfg.fps_fallback,
            self.cfg.filenames.video_in_zip,
            self.cfg.filenames.video_tmp_suffix,
        )
        # Answers are applied later, once the question widgets exist.
        self._pending_answers = self._read_saved_answers()

    def _history_push(self):
        """Append the current clip to the history, dropping any forward entries."""
        del self.history[self.history_pos + 1 :]
        self.history.append(self.filename)
        self.history_pos = len(self.history) - 1

    def _mask_path(self, clip_path: Path) -> Path:
        """Return the path where the saved mask of ``clip_path`` lives."""
        return self.out_dir / clip_path.stem / self.cfg.filenames.mask

    def _load_display_mask(self, mask_path: Path):
        """Read a saved mask image and return it as a 0/1 uint8 array.

        The image is thresholded at 127 and resized to the display size
        ``(self.dw, self.dh)`` with nearest-neighbour interpolation so the
        result stays binary.
        """
        # Context manager so the underlying file handle is released
        # deterministically.
        with Image.open(mask_path) as img:
            mask_full = np.array(img.convert("L"))
        return cv2.resize(
            (mask_full > 127).astype(np.uint8),
            (self.dw, self.dh),
            interpolation=cv2.INTER_NEAREST,
        )

    def _read_saved_mask(self):
        """Return the current clip's saved mask at display size, or ``None``."""
        mask_path = self._mask_path(self.filename)
        if not mask_path.exists():
            return None
        return self._load_display_mask(mask_path)

    def _read_saved_answers(self):
        """Return the current clip's saved answers dict, or ``None`` if absent."""
        meta_path = self.out_dir / self.filename.stem / self.cfg.filenames.metadata
        if not meta_path.exists():
            return None
        with open(meta_path, encoding="utf-8") as f:
            return json.load(f)

    # ── UI setup ───────────────────────────────────────────────────
    def _init_ui(self):
        """Create the canvas, button rows and question panel, and wire signals."""
        self.mc = MaskCanvas(self.frames, self.dh, self.dw)
        self.mc.set_title(self.filename.name)
        self.mc.reset(self._read_saved_mask())

        self.q_widgets = {}
        question_panel = self._build_question_panel()

        self.btn_prev = QPushButton("Previous")
        # Nothing to go back to until a second clip has been visited.
        self.btn_prev.setEnabled(False)
        btn_next = QPushButton("Next")
        btn_skip = QPushButton("Skip")
        btn_clear = QPushButton("Clear")
        btn_undo = QPushButton("Undo")
        btn_undo10 = QPushButton("Undo×10")
        btn_redo = QPushButton("Redo")
        btn_load_prev_mask = QPushButton("Load Prev Mask")
        btn_auto_segment = QPushButton("Auto Segment")
        btn_auto_segment.setEnabled(self.of_cfg is not None and self.of_cfg.enabled)

        row1 = QHBoxLayout()
        for b in [
            self.btn_prev,
            btn_next,
            btn_skip,
            btn_load_prev_mask,
            btn_auto_segment,
        ]:
            row1.addWidget(b)

        row_tools = QHBoxLayout()
        for b in [
            self.mc.btn_brush,
            self.mc.btn_polygon,
            self.mc.btn_fill,
            self.mc.btn_del_shape,
            self.mc.btn_cancel_poly,
        ]:
            row_tools.addWidget(b)

        row2 = QHBoxLayout()
        for b in [
            btn_clear,
            self.mc.btn_erase,
            btn_undo,
            btn_undo10,
            btn_redo,
            self.mc.btn_mask,
        ]:
            row2.addWidget(b)

        row3 = QHBoxLayout()
        row3.addWidget(QLabel("Brush size"))
        row3.addWidget(self.mc.brush_slider)
        row3.addWidget(self.mc.brush_reset)

        row4 = QHBoxLayout()
        row4.addWidget(QLabel("Mask Alpha"))
        row4.addWidget(self.mc.alpha_slider)
        row4.addWidget(self.mc.alpha_reset)

        # Image-adjustment sliders sit in columns to the left of the canvas.
        vert_panel = QHBoxLayout()
        vert_panel.setContentsMargins(0, 0, 4, 0)
        for label_text, slider, reset_btn in [
            ("Brightness", self.mc.brightness_slider, self.mc.brightness_reset),
            ("Contrast", self.mc.contrast_slider, self.mc.contrast_reset),
            ("Gamma", self.mc.gamma_slider, self.mc.gamma_reset),
        ]:
            col = QVBoxLayout()
            lbl = QLabel(label_text)
            lbl.setAlignment(Qt.AlignmentFlag.AlignHCenter)
            col.addWidget(lbl)
            col.addWidget(slider, 1)
            col.addWidget(reset_btn)
            vert_panel.addLayout(col)

        canvas_row = QHBoxLayout()
        canvas_row.addLayout(vert_panel)
        canvas_row.addWidget(self.mc.canvas, 1)

        left = QVBoxLayout()
        left.addLayout(canvas_row)
        left.addLayout(row1)
        left.addLayout(row_tools)
        left.addLayout(row2)
        left.addLayout(row3)
        left.addLayout(row4)

        left_widget = QWidget()
        left_widget.setLayout(left)
        right_widget = QWidget()
        right_widget.setLayout(question_panel)

        main = QHBoxLayout()
        main.addWidget(left_widget, 3)
        main.addWidget(right_widget, 1)

        container = QWidget()
        container.setLayout(main)
        self.setCentralWidget(container)

        self.btn_prev.clicked.connect(self.prev_clip)
        btn_next.clicked.connect(self.next_clip)
        btn_skip.clicked.connect(self.skip_clip)
        btn_clear.clicked.connect(self.mc.clear)
        btn_undo.clicked.connect(self.mc.undo)
        btn_undo10.clicked.connect(self.mc.undo10)
        btn_redo.clicked.connect(self.mc.redo)
        btn_load_prev_mask.clicked.connect(self.load_prev_mask)
        btn_auto_segment.clicked.connect(self.run_optical_flow)

        if self._pending_answers:
            self._set_answers(self._pending_answers)
            self._pending_answers = None

    def _build_question_panel(self) -> QVBoxLayout:
        """Build one group box per question section and register its widgets.

        Each question becomes a row of radio buttons. ``self.q_widgets`` maps
        the question key to ``(button_group, buttons, options)``.
        """
        vbox = QVBoxLayout()
        for section, qs in self.cfg.get_questions():
            group = QGroupBox(section)
            gvbox = QVBoxLayout()
            for key, label, options, default in qs:
                gvbox.addWidget(QLabel(label))
                btn_group = QButtonGroup(self)
                row = QHBoxLayout()
                buttons = []
                for opt in options:
                    btn = QRadioButton(opt)
                    btn_group.addButton(btn)
                    row.addWidget(btn)
                    buttons.append(btn)
                    if default == opt:
                        btn.setChecked(True)
                # With no configured default, pre-select the last option.
                if default is None and buttons:
                    buttons[-1].setChecked(True)
                self.q_widgets[key] = (btn_group, buttons, options)
                gvbox.addLayout(row)
            group.setLayout(gvbox)
            vbox.addWidget(group)
        return vbox

    def _set_answers(self, answers: dict):
        """Check the radio button matching each saved answer.

        Keys that have no corresponding question widget are ignored.
        """
        for key, value in answers.items():
            if key not in self.q_widgets:
                continue
            _, buttons, options = self.q_widgets[key]
            for btn, opt in zip(buttons, options):
                btn.setChecked(opt == value)

    def _frame_interval_ms(self) -> int:
        """Return the per-frame delay in milliseconds for the current clip."""
        # Clamp to at least 1 ms so a very high fps never yields a 0 ms delay.
        return max(1, int(1000 / self.fps))

    def _init_timer(self):
        """Start the timer that advances the displayed frame."""
        self.frame_i = 0
        self.timer = QTimer()
        self.timer.timeout.connect(self._tick)
        self.timer.start(self._frame_interval_ms())

    def _tick(self):
        """Show the next frame of the clip, wrapping around at the end."""
        n_frames = len(self.frames)
        # len() rather than truthiness: frames may be an array-like whose
        # truth value is ambiguous. An empty clip has nothing to display.
        if n_frames == 0:
            return
        self.frame_i = (self.frame_i + 1) % n_frames
        self.mc.set_frame(self.frames[self.frame_i])

    # ── answers ────────────────────────────────────────────────────
    def get_answers(self) -> dict:
        """Return ``{question_key: selected_option}`` for every answered question."""
        out = {}
        for key, (_, buttons, options) in self.q_widgets.items():
            for btn, opt in zip(buttons, options):
                if btn.isChecked():
                    out[key] = opt
        return out

    # ── save helpers ───────────────────────────────────────────────
    def _make_overlay(self, frame, alpha=0.4):
        """Return a copy of ``frame`` with the current mask blended in green.

        Args:
            frame: Image array to overlay; the mask is applied as a boolean
                index into it.
            alpha: Weight of the green tint on masked pixels.
        """
        overlay = frame.copy()
        green = np.zeros_like(frame)
        green[..., 1] = 255
        m = self.mc.mask.astype(bool)
        overlay[m] = (1 - alpha) * overlay[m] + alpha * green[m]
        return overlay.astype(np.uint8)

    def _save_gif(self, frames, out_path, scale=1.0):
        """Write ``frames`` as a looping GIF, optionally rescaled by ``scale``."""
        h, w = frames[0].shape[:2]
        nh, nw = max(1, int(h * scale)), max(1, int(w * scale))
        pil_frames = [Image.fromarray(cv2.resize(f, (nw, nh))) for f in frames]
        pil_frames[0].save(
            out_path,
            save_all=True,
            append_images=pil_frames[1:],
            duration=self._frame_interval_ms(),
            loop=0,
        )

    # ── actions ────────────────────────────────────────────────────
    def save(self):
        """Write the current annotation to ``out_dir / <clip stem>``.

        Always writes the full-resolution mask, the answers as JSON, the
        middle frame and its overlay. With ``self.extras`` set, also writes
        the display-size mask visualisation and four GIFs.
        """
        out = self.out_dir / self.filename.stem
        out.mkdir(parents=True, exist_ok=True)

        # Scale the display-size mask back up to the clip's full resolution.
        mask_full = cv2.resize(
            self.mc.mask.astype(np.uint8),
            (self.w, self.h),
            interpolation=cv2.INTER_NEAREST,
        )
        fn = self.cfg.filenames
        Image.fromarray(mask_full * 255).save(out / fn.mask)

        with open(out / fn.metadata, "w", encoding="utf-8") as f:
            json.dump(self.get_answers(), f, indent=2)

        mid = len(self.frames) // 2
        frame = self.frames[mid]
        Image.fromarray(frame).save(out / fn.frame)
        Image.fromarray(self._make_overlay(frame)).save(out / fn.overlay)

        if self.extras:
            Image.fromarray((self.mc.mask * 255).astype(np.uint8)).save(
                out / fn.mask_vis
            )
            overlay_frames = [self._make_overlay(f) for f in self.frames]
            self._save_gif(self.frames, out / fn.gif_original_hires, scale=1.0)
            self._save_gif(self.frames, out / fn.gif_original_lowres, scale=0.5)
            self._save_gif(overlay_frames, out / fn.gif_overlay_hires, scale=1.0)
            self._save_gif(overlay_frames, out / fn.gif_overlay_lowres, scale=0.5)

        print("Saved:", out)

    def _switch_ui_to_clip(self):
        """Point the existing UI at the clip that ``_load_clip`` just loaded."""
        self.frame_i = 0
        # The new clip may have a different fps than the one the timer was
        # started with, so re-apply the playback interval.
        self.timer.setInterval(self._frame_interval_ms())
        self.mc.load_clip(
            self.frames,
            self.dh,
            self.dw,
            mask=self._read_saved_mask(),
            title=self.filename.name,
        )
        if self._pending_answers:
            self._set_answers(self._pending_answers)
            self._pending_answers = None
        self.btn_prev.setEnabled(self.history_pos > 0)

    def _advance_clip(self):
        """Move to the next clip: forward in history if possible, else a new one.

        When no new clip can be loaded, an "All done!" dialog is shown and
        the application quits.
        """
        if self.history_pos < len(self.history) - 1:
            self.history_pos += 1
            self._load_clip(path=self.history[self.history_pos])
            self._switch_ui_to_clip()
            return
        try:
            self._load_clip()
        except RuntimeError:
            # NOTE(review): presumably raised by the selector when no clips
            # remain; any other RuntimeError from loading would also land
            # here — confirm.
            msg = QMessageBox(self)
            msg.setWindowTitle("All done!")
            msg.setText("You have reached the end of all clips.")
            msg.setStandardButtons(QMessageBox.StandardButton.Ok)
            msg.exec()
            QApplication.instance().quit()
            return
        self._history_push()
        self._switch_ui_to_clip()

    def prev_clip(self):
        """Save the current annotation and go back one clip in the history."""
        if self.history_pos <= 0:
            return
        self.save()
        self.history_pos -= 1
        self._load_clip(path=self.history[self.history_pos])
        self._switch_ui_to_clip()

    def next_clip(self):
        """Save and advance, asking first if a saved annotation already exists."""
        if not self._mask_path(self.filename).exists():
            self.save()
            self._advance_clip()
            return

        msg = QMessageBox(self)
        msg.setWindowTitle("Existing annotation found")
        msg.setText(
            f"'{self.filename.stem}' already has a saved annotation.\n"
            "Replace it with your current work, or keep the existing save?"
        )
        btn_replace = msg.addButton(
            "Replace & Continue", QMessageBox.ButtonRole.AcceptRole
        )
        btn_keep = msg.addButton(
            "Keep Existing & Continue", QMessageBox.ButtonRole.AcceptRole
        )
        msg.addButton("Cancel", QMessageBox.ButtonRole.RejectRole)
        msg.setDefaultButton(btn_replace)
        msg.exec()
        clicked = msg.clickedButton()
        if clicked == btn_replace:
            self.save()
            self._advance_clip()
        elif clicked == btn_keep:
            self._advance_clip()
        # Cancel: do nothing

    def skip_clip(self):
        """Advance to the next clip without saving."""
        self._advance_clip()

    def load_prev_mask(self):
        """Replace the current mask with the saved mask of the preceding clip.

        "Preceding" refers to the selector's clip list, not the session
        history. Shows an information dialog when there is no preceding clip
        or it has no saved mask; does nothing if the current clip is not in
        the selector's list.
        """
        try:
            idx = self.selector.clips.index(self.filename)
        except ValueError:
            return
        if idx == 0:
            QMessageBox.information(
                self, "No previous clip", "This is the first clip in the list."
            )
            return
        prev_clip = self.selector.clips[idx - 1]
        mask_path = self._mask_path(prev_clip)
        if not mask_path.exists():
            QMessageBox.information(
                self, "No mask found", f"No saved mask found for '{prev_clip.stem}'."
            )
            return
        self.mc.set_mask(self._load_display_mask(mask_path))

    def run_optical_flow(self):
        """Replace the current mask with one computed by optical flow.

        Does nothing when no optical-flow config was loaded.
        """
        # The button is disabled in this case, but guard against direct calls
        # so a missing config cannot raise AttributeError.
        if self.of_cfg is None:
            return
        mask = compute_optical_flow_mask(
            self.frames,
            self.fps,
            self.of_cfg.norm_squared_threshold,
            self.of_cfg.gaussian_kernel,
            self.of_cfg.brightness_range,
        )
        self.mc.set_mask(mask)
|