import io
import json
from pathlib import Path

import cv2
import numpy as np
from PIL import Image
from PySide6.QtCore import Qt, QTimer
from PySide6.QtWidgets import (
    QApplication,
    QButtonGroup,
    QGroupBox,
    QHBoxLayout,
    QLabel,
    QMainWindow,
    QMessageBox,
    QPushButton,
    QRadioButton,
    QVBoxLayout,
    QWidget,
)

from .clip_selector import ClipSelector
from .compute_optical_flow import compute_optical_flow_mask
from .config import AppConfig, load_optical_flow_config
from .filesystem import fsjoin, fsname, fsstem
from .mask_canvas import MaskCanvas
from .video_loader import load_frames


class Annotator(QMainWindow):
    """Main window for stepping through video clips and annotating them.

    For each clip, the window loops the clip's frames on a ``MaskCanvas``.
    The user paints a mask and answers the configured questions. Results are
    written under ``<out_dir>/<clip stem>/``.

    All storage goes through the optional ``fs`` object when one is given
    (its ``exists``/``makedirs``/``open``/``pipe`` methods are used).
    Otherwise the local filesystem is used directly.
    """

    def __init__(
        self,
        config: AppConfig,
        clip: str = None,
        extras: bool = False,
        skip_annotated: bool = True,
        fs=None,
    ):
        """Build the window and load the first clip.

        Args:
            config: Application configuration (paths, output filenames, limits).
            clip: Optional clip passed as ``specific`` to ``ClipSelector.next``.
            extras: When True, ``save`` also writes a mask preview and four GIFs.
            skip_annotated: Forwarded unchanged to ``ClipSelector``.
            fs: Optional filesystem object; ``None`` means the local disk.
        """
        super().__init__()
        self.cfg = config
        self.fs = fs
        self.out_dir = config.out_dir
        self.extras = extras
        self.of_cfg = load_optical_flow_config(Path(config.optical_flow_config_file))
        self.selector = ClipSelector(
            data_dir=config.data_dir,
            out_dir=self.out_dir,
            clips_file=Path(config.clips_file),
            mask_filename=config.filenames.mask,
            zip_extension=config.filenames.zip_extension,
            skip_annotated=skip_annotated,
            fs=fs,
        )
        # Clips visited so far, in order; history_pos indexes the one on screen.
        self.history: list[str] = []
        self.history_pos: int = -1

        self.setWindowTitle("Clip Annotator")
        self._load_clip(specific=clip)
        self._history_push()
        self._init_ui()
        self._init_timer()

    # ── filesystem helpers ─────────────────────────────────────────

    def _out_path(self, *parts: str) -> str:
        """Return a path inside the current clip's output directory."""
        return fsjoin(self.out_dir, fsstem(self.filename), *parts)

    def _fs_exists(self, path: str) -> bool:
        """Return whether *path* exists on local disk or in ``self.fs``."""
        if self.fs is None:
            return Path(path).exists()
        return self.fs.exists(path)

    def _fs_makedirs(self, path: str) -> None:
        """Create directory *path* (and parents); an existing one is fine."""
        if self.fs is None:
            Path(path).mkdir(parents=True, exist_ok=True)
        else:
            self.fs.makedirs(path, exist_ok=True)

    def _read_bytes(self, path: str) -> bytes:
        """Return the whole content of *path* from local disk or ``self.fs``."""
        if self.fs is None:
            return Path(path).read_bytes()
        with self.fs.open(path, "rb") as f:
            return f.read()

    def _write_bytes(self, path: str, data: bytes) -> None:
        """Write *data* to *path* on local disk or in ``self.fs``."""
        if self.fs is None:
            Path(path).write_bytes(data)
        else:
            self.fs.pipe(path, data)

    def _pil_open(self, path: str) -> Image.Image:
        """Open an image from an in-memory copy of the file.

        ``Image.open`` on a path decodes lazily and keeps the OS file handle
        open until the image is loaded or closed. Reading the bytes up front
        means no handle is left behind, for either storage backend.
        """
        return Image.open(io.BytesIO(self._read_bytes(path)))

    @staticmethod
    def _pil_format(path: str) -> str:
        """Return the Pillow format name for *path*'s extension.

        Uses Pillow's own extension registry, so ``.jpg`` maps to ``JPEG`` and
        ``.tif`` to ``TIFF``. Unknown extensions fall back to the upper-cased
        extension.
        """
        ext = str(path).rsplit(".", 1)[-1].lower()
        return Image.registered_extensions().get("." + ext, ext.upper())

    def _pil_save(self, img: Image.Image, path: str) -> None:
        """Save *img* to *path*, choosing the format from the extension."""
        if self.fs is None:
            # Pillow infers the format from the extension for real paths.
            img.save(path)
            return
        buf = io.BytesIO()
        img.save(buf, format=self._pil_format(path))
        self.fs.pipe(path, buf.getvalue())

    def _json_read(self, path: str):
        """Parse the JSON document at *path*.

        ``json.loads`` on bytes detects UTF-8/16/32 itself, so the result does
        not depend on the platform's default text encoding.
        """
        return json.loads(self._read_bytes(path))

    def _json_write(self, data, path: str) -> None:
        """Write *data* as indented UTF-8 JSON to *path*."""
        self._write_bytes(path, json.dumps(data, indent=2).encode("utf-8"))

    # ── clip loading ───────────────────────────────────────────────

    def _load_clip(self, specific: str = None, path: str = None):
        """Load a clip's frames and any previously saved answers.

        Args:
            specific: Passed to ``ClipSelector.next`` when *path* is None.
            path: Load exactly this clip (used when moving through history).

        Any exception from ``ClipSelector.next`` propagates before
        ``self.filename`` is changed. ``_advance_clip`` treats a RuntimeError
        here as "no clips left".
        """
        if path is not None:
            self.filename = path
        else:
            self.filename = self.selector.next(specific=specific)
        # (dh, dw) is the size the canvas/mask uses; (h, w) is the size the
        # mask is resized back to in ``save``.
        self.frames, self.fps, self.dh, self.dw, self.h, self.w = load_frames(
            self.filename,
            self.cfg.max_frames,
            self.cfg.display_max,
            self.cfg.fps_fallback,
            self.cfg.filenames.video_in_zip,
            self.cfg.filenames.video_tmp_suffix,
            fs=self.fs,
        )
        self._pending_answers = self._read_saved_answers()

    def _history_push(self):
        """Append the current clip to history, dropping any 'forward' entries."""
        del self.history[self.history_pos + 1 :]
        self.history.append(self.filename)
        self.history_pos = len(self.history) - 1

    def _read_mask(self, mask_path: str) -> np.ndarray:
        """Read a saved mask image as a 0/1 uint8 array at canvas size.

        Grayscale pixels above 127 count as foreground. Nearest-neighbour
        resizing keeps the result binary.
        """
        mask_full = np.array(self._pil_open(mask_path).convert("L"))
        return cv2.resize(
            (mask_full > 127).astype(np.uint8),
            (self.dw, self.dh),
            interpolation=cv2.INTER_NEAREST,
        )

    def _read_saved_mask(self):
        """Return the current clip's saved mask, or None if there is none."""
        mask_path = self._out_path(self.cfg.filenames.mask)
        if not self._fs_exists(mask_path):
            return None
        return self._read_mask(mask_path)

    def _read_saved_answers(self):
        """Return the current clip's saved answers, or None if unavailable.

        An unreadable file (bad JSON or bad encoding; both are ValueErrors) is
        reported and ignored, so the questions keep their defaults instead of
        the whole app crashing on clip load.
        """
        meta_path = self._out_path(self.cfg.filenames.metadata)
        if not self._fs_exists(meta_path):
            return None
        try:
            return self._json_read(meta_path)
        except ValueError as err:
            print(f"Ignoring unreadable answers file {meta_path}: {err}")
            return None

    # ── UI setup ───────────────────────────────────────────────────

    def _init_ui(self):
        """Create the canvas, buttons, slider panels and question panel."""
        self.mc = MaskCanvas(self.frames, self.dh, self.dw)
        self.mc.set_title(fsname(self.filename))
        self.mc.reset(self._read_saved_mask())

        self.q_widgets = {}
        question_panel = self._build_question_panel()

        self.btn_prev = QPushButton("Previous")
        self.btn_prev.setEnabled(False)
        self.btn_next = QPushButton("Next")
        btn_skip = QPushButton("Skip")
        btn_clear = QPushButton("Clear")
        btn_undo = QPushButton("Undo")
        btn_undo10 = QPushButton("Undo×10")
        btn_redo = QPushButton("Redo")
        btn_load_prev_mask = QPushButton("Load Prev Mask")
        btn_auto_segment = QPushButton("Auto Segment")
        btn_auto_segment.setEnabled(self.of_cfg.enabled)

        row1 = QHBoxLayout()
        for b in [
            self.btn_prev,
            self.btn_next,
            btn_skip,
            btn_load_prev_mask,
            btn_auto_segment,
        ]:
            row1.addWidget(b)

        row_tools = QHBoxLayout()
        for b in [
            self.mc.btn_brush,
            self.mc.btn_polygon,
            self.mc.btn_fill,
            self.mc.btn_del_shape,
            self.mc.btn_cancel_poly,
        ]:
            row_tools.addWidget(b)

        row2 = QHBoxLayout()
        for b in [
            btn_clear,
            self.mc.btn_erase,
            btn_undo,
            btn_undo10,
            btn_redo,
            self.mc.btn_mask,
        ]:
            row2.addWidget(b)

        row3 = QHBoxLayout()
        row3.addWidget(QLabel("Brush size"))
        row3.addWidget(self.mc.brush_slider)
        row3.addWidget(self.mc.brush_reset)

        row4 = QHBoxLayout()
        row4.addWidget(QLabel("Mask Alpha"))
        row4.addWidget(self.mc.alpha_slider)
        row4.addWidget(self.mc.alpha_reset)

        # Vertical brightness / contrast / gamma columns to the canvas's left.
        vert_panel = QHBoxLayout()
        vert_panel.setContentsMargins(0, 0, 4, 0)
        for label_text, slider, reset_btn in [
            ("B", self.mc.brightness_slider, self.mc.brightness_reset),
            ("C", self.mc.contrast_slider, self.mc.contrast_reset),
            ("G", self.mc.gamma_slider, self.mc.gamma_reset),
        ]:
            col = QVBoxLayout()
            lbl = QLabel(label_text)
            lbl.setAlignment(Qt.AlignmentFlag.AlignHCenter)
            col.addWidget(lbl)
            col.addWidget(slider, 1)
            col.addWidget(reset_btn)
            vert_panel.addLayout(col)

        canvas_row = QHBoxLayout()
        canvas_row.addLayout(vert_panel)
        canvas_row.addWidget(self.mc.canvas, 1)

        left = QVBoxLayout()
        left.addLayout(canvas_row)
        left.addLayout(row1)
        left.addLayout(row_tools)
        left.addLayout(row2)
        left.addLayout(row3)
        left.addLayout(row4)

        left_widget = QWidget()
        left_widget.setLayout(left)
        right_widget = QWidget()
        right_widget.setLayout(question_panel)
        right_widget.setMaximumWidth(160)

        main = QHBoxLayout()
        main.addWidget(left_widget, 1)
        main.addWidget(right_widget, 0)
        container = QWidget()
        container.setLayout(main)
        self.setCentralWidget(container)

        self.btn_prev.clicked.connect(self.prev_clip)
        self.btn_next.clicked.connect(self.next_clip)
        btn_skip.clicked.connect(self.skip_clip)
        btn_clear.clicked.connect(self.mc.clear)
        btn_undo.clicked.connect(self.mc.undo)
        btn_undo10.clicked.connect(self.mc.undo10)
        btn_redo.clicked.connect(self.mc.redo)
        btn_load_prev_mask.clicked.connect(self.load_prev_mask)
        btn_auto_segment.clicked.connect(self.run_optical_flow)

        # Answers loaded by _load_clip can only be applied once widgets exist.
        if self._pending_answers:
            self._set_answers(self._pending_answers)
        self._pending_answers = None

    def _build_question_panel(self) -> QVBoxLayout:
        """Build one radio-button group per configured question.

        Fills ``self.q_widgets[key] = (button_group, buttons, options)``. The
        option equal to the question's default is pre-checked. With no
        default (None), the last option is checked.
        """
        vbox = QVBoxLayout()
        for section, qs in self.cfg.get_questions():
            group = QGroupBox(section)
            gvbox = QVBoxLayout()
            for key, label, options, default in qs:
                gvbox.addWidget(QLabel(label))
                btn_group = QButtonGroup(self)
                col = QVBoxLayout()
                buttons = []
                for opt in options:
                    btn = QRadioButton(opt)
                    btn_group.addButton(btn)
                    col.addWidget(btn)
                    buttons.append(btn)
                    if default == opt:
                        btn.setChecked(True)
                if default is None and buttons:
                    buttons[-1].setChecked(True)
                self.q_widgets[key] = (btn_group, buttons, options)
                gvbox.addLayout(col)
            group.setLayout(gvbox)
            vbox.addWidget(group)
        return vbox

    def _set_answers(self, answers: dict):
        """Check the radio button matching each saved answer.

        Keys that are not configured questions are ignored.
        """
        for key, value in answers.items():
            widgets = self.q_widgets.get(key)
            if widgets is None:
                continue
            _, buttons, options = widgets
            for btn, opt in zip(buttons, options):
                btn.setChecked(opt == value)

    def _frame_interval_ms(self) -> int:
        """Per-frame duration in ms for the current clip's fps (at least 1)."""
        return max(1, int(1000 / self.fps))

    def _init_timer(self):
        """Start the timer that loops the current clip's frames."""
        self.frame_i = 0
        self.timer = QTimer(self)
        self.timer.timeout.connect(self._tick)
        self.timer.start(self._frame_interval_ms())

    def _tick(self):
        """Advance to the next frame, wrapping back to the first."""
        self.frame_i = (self.frame_i + 1) % len(self.frames)
        self.mc.set_frame(self.frames[self.frame_i])

    # ── answers ────────────────────────────────────────────────────

    def get_answers(self) -> dict:
        """Return ``{question_key: checked_option}`` for answered questions."""
        out = {}
        for key, (_, buttons, options) in self.q_widgets.items():
            for btn, opt in zip(buttons, options):
                if btn.isChecked():
                    out[key] = opt
        return out

    # ── save helpers ───────────────────────────────────────────────

    def _make_overlay(self, frame, alpha=0.4):
        """Return *frame* with the current mask blended in green at *alpha*."""
        overlay = frame.copy()
        green = np.zeros_like(frame)
        green[..., 1] = 255
        m = self.mc.mask.astype(bool)
        overlay[m] = (1 - alpha) * overlay[m] + alpha * green[m]
        return overlay.astype(np.uint8)

    def _save_gif(self, frames, out_path: str, scale=1.0):
        """Write *frames* as an endlessly looping GIF, resized by *scale*.

        Frame duration follows the current clip's fps. The GIF is encoded in
        memory and then written through the same path for both storage
        backends.
        """
        h, w = frames[0].shape[:2]
        nh, nw = max(1, int(h * scale)), max(1, int(w * scale))
        pil_frames = [Image.fromarray(cv2.resize(f, (nw, nh))) for f in frames]
        buf = io.BytesIO()
        pil_frames[0].save(
            buf,
            format="GIF",
            save_all=True,
            append_images=pil_frames[1:],
            duration=self._frame_interval_ms(),
            loop=0,
        )
        self._write_bytes(out_path, buf.getvalue())

    # ── actions ────────────────────────────────────────────────────

    def _save_locked(self):
        """Run ``save`` with Previous/Next disabled, re-enabling them after."""
        self.btn_next.setEnabled(False)
        self.btn_prev.setEnabled(False)
        QApplication.processEvents()  # repaint the disabled buttons first
        try:
            self.save()
        finally:
            self.btn_next.setEnabled(True)
            self.btn_prev.setEnabled(self.history_pos > 0)

    def save(self):
        """Write the current clip's mask, answers and preview images.

        The mask is saved as 0/255 at the clip's (w, h) size. Any non-zero
        canvas value counts as foreground, matching ``_make_overlay``. The
        middle frame and its overlay are saved too. With ``extras`` enabled,
        a mask preview plus original/overlay GIFs at full and half scale are
        also written.
        """
        out = self._out_path()
        self._fs_makedirs(out)
        fn = self.cfg.filenames

        # Binarise first so *255 cannot overflow uint8 for non-0/1 masks.
        binary = self.mc.mask.astype(bool).astype(np.uint8)
        mask_full = cv2.resize(
            binary,
            (self.w, self.h),
            interpolation=cv2.INTER_NEAREST,
        )
        self._pil_save(Image.fromarray(mask_full * 255), fsjoin(out, fn.mask))
        self._json_write(self.get_answers(), fsjoin(out, fn.metadata))

        frame = self.frames[len(self.frames) // 2]
        self._pil_save(Image.fromarray(frame), fsjoin(out, fn.frame))
        self._pil_save(
            Image.fromarray(self._make_overlay(frame)), fsjoin(out, fn.overlay)
        )

        if self.extras:
            self._pil_save(Image.fromarray(binary * 255), fsjoin(out, fn.mask_vis))
            overlay_frames = [self._make_overlay(f) for f in self.frames]
            self._save_gif(self.frames, fsjoin(out, fn.gif_original_hires), scale=1.0)
            self._save_gif(self.frames, fsjoin(out, fn.gif_original_lowres), scale=0.5)
            self._save_gif(overlay_frames, fsjoin(out, fn.gif_overlay_hires), scale=1.0)
            self._save_gif(
                overlay_frames, fsjoin(out, fn.gif_overlay_lowres), scale=0.5
            )

        print("Saved:", out)

    def _switch_ui_to_clip(self):
        """Refresh canvas, answers, playback rate and buttons for a new clip."""
        self.frame_i = 0
        # Clips can have different fps; keep the playback rate in sync.
        self.timer.setInterval(self._frame_interval_ms())
        self.mc.load_clip(
            self.frames,
            self.dh,
            self.dw,
            mask=self._read_saved_mask(),
            title=fsname(self.filename),
        )
        if self._pending_answers:
            self._set_answers(self._pending_answers)
        self._pending_answers = None
        self.btn_prev.setEnabled(self.history_pos > 0)

    def _advance_clip(self):
        """Move forward: replay history if possible, else ask for a new clip.

        When the selector raises RuntimeError, show a "done" dialog and quit.
        """
        if self.history_pos < len(self.history) - 1:
            self.history_pos += 1
            self._load_clip(path=self.history[self.history_pos])
            self._switch_ui_to_clip()
            return
        try:
            self._load_clip()
        except RuntimeError:
            msg = QMessageBox(self)
            msg.setWindowTitle("All done!")
            msg.setText("You have reached the end of all clips.")
            msg.setStandardButtons(QMessageBox.StandardButton.Ok)
            msg.exec()
            QApplication.instance().quit()
            return
        self._history_push()
        self._switch_ui_to_clip()

    def prev_clip(self):
        """Save the current clip, then go back one step in history."""
        if self.history_pos <= 0:
            return
        self._save_locked()
        self.history_pos -= 1
        self._load_clip(path=self.history[self.history_pos])
        self._switch_ui_to_clip()

    def next_clip(self):
        """Save (after confirming if a saved mask exists), then advance.

        If a mask file already exists, the user can replace it, keep it
        (advance without saving), or cancel (stay on this clip).
        """
        mask_path = self._out_path(self.cfg.filenames.mask)
        if not self._fs_exists(mask_path):
            self._save_locked()
            self._advance_clip()
            return

        msg = QMessageBox(self)
        msg.setWindowTitle("Existing annotation found")
        msg.setText(
            f"'{fsstem(self.filename)}' already has a saved annotation.\n"
            "Replace it with your current work, or keep the existing save?"
        )
        btn_replace = msg.addButton(
            "Replace & Continue", QMessageBox.ButtonRole.AcceptRole
        )
        btn_keep = msg.addButton(
            "Keep Existing & Continue", QMessageBox.ButtonRole.AcceptRole
        )
        msg.addButton("Cancel", QMessageBox.ButtonRole.RejectRole)
        msg.setDefaultButton(btn_replace)
        msg.exec()

        clicked = msg.clickedButton()
        if clicked == btn_replace:
            self._save_locked()
            self._advance_clip()
        elif clicked == btn_keep:
            self._advance_clip()
        # Cancel: stay on the current clip.

    def skip_clip(self):
        """Advance without saving."""
        self._advance_clip()

    def load_prev_mask(self):
        """Replace the canvas mask with the saved mask of the preceding clip.

        "Preceding" means the previous entry in ``selector.clips``, not the
        previous history entry. Shows an info dialog when there is no such
        clip or it has no saved mask.
        """
        try:
            idx = self.selector.clips.index(self.filename)
        except ValueError:
            return
        if idx == 0:
            QMessageBox.information(
                self, "No previous clip", "This is the first clip in the list."
            )
            return
        prev_clip = self.selector.clips[idx - 1]
        mask_path = fsjoin(self.out_dir, fsstem(prev_clip), self.cfg.filenames.mask)
        if not self._fs_exists(mask_path):
            QMessageBox.information(
                self,
                "No mask found",
                f"No saved mask found for '{fsstem(prev_clip)}'.",
            )
            return
        self.mc.set_mask(self._read_mask(mask_path))

    def run_optical_flow(self):
        """Replace the canvas mask with one computed by optical flow."""
        mask = compute_optical_flow_mask(
            self.frames,
            self.fps,
            self.of_cfg.norm_squared_threshold,
            self.of_cfg.gaussian_kernel,
            self.of_cfg.brightness_range,
        )
        self.mc.set_mask(mask)