Files
clip-annotator/src/clip_annotator/annotator.py

507 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import io
import json
from pathlib import Path
import cv2
import numpy as np
from PIL import Image
from PySide6.QtCore import Qt, QTimer
from PySide6.QtWidgets import (
QApplication,
QButtonGroup,
QGroupBox,
QHBoxLayout,
QLabel,
QMainWindow,
QMessageBox,
QPushButton,
QRadioButton,
QVBoxLayout,
QWidget,
)
from .clip_selector import ClipSelector
from .compute_optical_flow import compute_optical_flow_mask
from .config import AppConfig, load_optical_flow_config
from .filesystem import fsjoin, fsname, fsstem
from .mask_canvas import MaskCanvas
from .video_loader import load_frames
class Annotator(QMainWindow):
    """Main window for mask + questionnaire annotation of video clips.

    The current clip loops on a ``MaskCanvas`` where the user paints a mask,
    while a side panel holds radio-button questions from ``config``. Results
    are written under ``<out_dir>/<clip stem>/`` by :meth:`save`. Visited
    clips are recorded in ``self.history`` so "Previous" can return to them.
    """

    def __init__(
        self,
        config: AppConfig,
        clip: str = None,
        extras: bool = False,
        skip_annotated: bool = True,
        fs=None,
    ):
        """Build the window and open the first clip.

        Args:
            config: Application configuration (paths, filenames, limits).
            clip: Optional clip to open first; passed to
                ``ClipSelector.next`` as ``specific``.
            extras: If true, :meth:`save` also writes a mask visualisation
                and four GIFs.
            skip_annotated: Forwarded to ``ClipSelector``.
            fs: Optional filesystem object providing ``exists``, ``makedirs``,
                ``open`` and ``pipe``; ``None`` uses the local filesystem.
        """
        super().__init__()
        self.cfg = config
        self.fs = fs
        self.out_dir = config.out_dir
        self.extras = extras
        self.of_cfg = load_optical_flow_config(Path(config.optical_flow_config_file))
        self.selector = ClipSelector(
            data_dir=config.data_dir,
            out_dir=self.out_dir,
            clips_file=Path(config.clips_file),
            mask_filename=config.filenames.mask,
            zip_extension=config.filenames.zip_extension,
            skip_annotated=skip_annotated,
            fs=fs,
        )
        # Linear browsing history of clip paths; history_pos indexes the clip
        # currently shown.
        self.history: list[str] = []
        self.history_pos: int = -1
        self.setWindowTitle("Clip Annotator")
        self._load_clip(specific=clip)
        self._history_push()
        self._init_ui()
        self._init_timer()

    # ── filesystem helpers ─────────────────────────────────────────
    def _out_path(self, *parts: str) -> str:
        """Return ``<out_dir>/<current clip stem>/<parts...>``."""
        return fsjoin(self.out_dir, fsstem(self.filename), *parts)

    def _fs_exists(self, path: str) -> bool:
        """Return whether ``path`` exists on the configured filesystem."""
        if self.fs is None:
            return Path(path).exists()
        return self.fs.exists(path)

    def _fs_makedirs(self, path: str):
        """Create directory ``path`` (and parents) if it does not exist."""
        if self.fs is None:
            Path(path).mkdir(parents=True, exist_ok=True)
        else:
            self.fs.makedirs(path, exist_ok=True)

    def _pil_open(self, path: str) -> Image.Image:
        """Open an image from the configured filesystem."""
        if self.fs is None:
            return Image.open(path)
        # Read fully into memory so the remote file handle can be closed.
        with self.fs.open(path, "rb") as f:
            return Image.open(io.BytesIO(f.read()))

    def _pil_save(self, img: Image.Image, path: str):
        """Save ``img`` to ``path``; the format is chosen from the extension."""
        if self.fs is None:
            img.save(path)
            return
        # PIL cannot infer a format when writing to a BytesIO, so look the
        # extension up in Pillow's own registry (".jpg" -> "JPEG",
        # ".tif" -> "TIFF", ...), falling back to the upper-cased extension.
        ext = str(path).rsplit(".", 1)[-1].lower()
        fmt = Image.registered_extensions().get("." + ext, ext.upper())
        buf = io.BytesIO()
        img.save(buf, format=fmt)
        self.fs.pipe(path, buf.getvalue())

    def _json_read(self, path: str):
        """Load and return JSON from ``path``."""
        if self.fs is None:
            with open(path) as f:
                return json.load(f)
        with self.fs.open(path, "r") as f:
            return json.load(f)

    def _json_write(self, data, path: str):
        """Write ``data`` as indented JSON to ``path``."""
        if self.fs is None:
            with open(path, "w") as f:
                json.dump(data, f, indent=2)
        else:
            self.fs.pipe(path, json.dumps(data, indent=2).encode())

    # ── clip loading ───────────────────────────────────────────────
    def _load_clip(self, specific: str = None, path: str = None):
        """Load a clip's frames and any saved answers into window state.

        Args:
            specific: Passed to ``ClipSelector.next`` when ``path`` is None.
            path: Load this clip directly instead of asking the selector.

        Window state is only replaced after ``load_frames`` succeeds, so a
        clip that fails to load cannot leave ``self.filename`` pointing at a
        clip whose frames were never loaded (which would make a later save
        write the old clip's mask under the new clip's name).
        """
        if path is not None:
            filename = path
        else:
            filename = self.selector.next(specific=specific)
        frames, fps, dh, dw, h, w = load_frames(
            filename,
            self.cfg.max_frames,
            self.cfg.display_max,
            self.cfg.fps_fallback,
            self.cfg.filenames.video_in_zip,
            self.cfg.filenames.video_tmp_suffix,
            fs=self.fs,
        )
        self.filename = filename
        self.frames, self.fps = frames, fps
        self.dh, self.dw, self.h, self.w = dh, dw, h, w
        # Needs self.filename, hence after the commit above.
        self._pending_answers = self._read_saved_answers()

    def _history_push(self):
        """Append the current clip to history, dropping any forward entries."""
        del self.history[self.history_pos + 1 :]
        self.history.append(self.filename)
        self.history_pos = len(self.history) - 1

    def _read_mask(self, mask_path: str):
        """Load a saved mask image as a 0/1 uint8 array at display size.

        The stored image is converted to greyscale, thresholded at 127 and
        resized to ``(self.dh, self.dw)`` with nearest-neighbour sampling.
        """
        mask_full = np.array(self._pil_open(mask_path).convert("L"))
        return cv2.resize(
            (mask_full > 127).astype(np.uint8),
            (self.dw, self.dh),
            interpolation=cv2.INTER_NEAREST,
        )

    def _read_saved_mask(self):
        """Return the current clip's saved mask, or None if there is none."""
        mask_path = self._out_path(self.cfg.filenames.mask)
        if not self._fs_exists(mask_path):
            return None
        return self._read_mask(mask_path)

    def _read_saved_answers(self):
        """Return the current clip's saved metadata JSON, or None."""
        meta_path = self._out_path(self.cfg.filenames.metadata)
        if not self._fs_exists(meta_path):
            return None
        return self._json_read(meta_path)

    # ── UI setup ───────────────────────────────────────────────────
    def _init_ui(self):
        """Build the canvas, controls and question panel and wire signals."""
        self.mc = MaskCanvas(self.frames, self.dh, self.dw)
        self.mc.set_title(fsname(self.filename))
        self.mc.reset(self._read_saved_mask())
        self.q_widgets = {}
        question_panel = self._build_question_panel()

        self.btn_prev = QPushButton("Previous")
        self.btn_prev.setEnabled(False)
        self.btn_next = QPushButton("Next")
        # Kept on self so _save_locked can disable it while saving.
        self.btn_skip = QPushButton("Skip")
        btn_clear = QPushButton("Clear")
        btn_undo = QPushButton("Undo")
        btn_undo10 = QPushButton("Undo×10")
        btn_redo = QPushButton("Redo")
        btn_load_prev_mask = QPushButton("Load Prev Mask")
        btn_auto_segment = QPushButton("Auto Segment")
        btn_auto_segment.setEnabled(self.of_cfg.enabled)

        row1 = QHBoxLayout()
        for b in [
            self.btn_prev,
            self.btn_next,
            self.btn_skip,
            btn_load_prev_mask,
            btn_auto_segment,
        ]:
            row1.addWidget(b)

        row_tools = QHBoxLayout()
        for b in [
            self.mc.btn_brush,
            self.mc.btn_polygon,
            self.mc.btn_fill,
            self.mc.btn_del_shape,
            self.mc.btn_cancel_poly,
        ]:
            row_tools.addWidget(b)

        row2 = QHBoxLayout()
        for b in [
            btn_clear,
            self.mc.btn_erase,
            btn_undo,
            btn_undo10,
            btn_redo,
            self.mc.btn_mask,
        ]:
            row2.addWidget(b)

        row3 = QHBoxLayout()
        row3.addWidget(QLabel("Brush size"))
        row3.addWidget(self.mc.brush_slider)
        row3.addWidget(self.mc.brush_reset)

        row4 = QHBoxLayout()
        row4.addWidget(QLabel("Mask Alpha"))
        row4.addWidget(self.mc.alpha_slider)
        row4.addWidget(self.mc.alpha_reset)

        # Vertical image-adjustment sliders to the left of the canvas.
        vert_panel = QHBoxLayout()
        vert_panel.setContentsMargins(0, 0, 4, 0)
        for label_text, slider, reset_btn in [
            ("Brightness", self.mc.brightness_slider, self.mc.brightness_reset),
            ("Contrast", self.mc.contrast_slider, self.mc.contrast_reset),
            ("Gamma", self.mc.gamma_slider, self.mc.gamma_reset),
        ]:
            col = QVBoxLayout()
            lbl = QLabel(label_text)
            lbl.setAlignment(Qt.AlignmentFlag.AlignHCenter)
            col.addWidget(lbl)
            col.addWidget(slider, 1)
            col.addWidget(reset_btn)
            vert_panel.addLayout(col)

        canvas_row = QHBoxLayout()
        canvas_row.addLayout(vert_panel)
        canvas_row.addWidget(self.mc.canvas, 1)

        left = QVBoxLayout()
        left.addLayout(canvas_row)
        left.addLayout(row1)
        left.addLayout(row_tools)
        left.addLayout(row2)
        left.addLayout(row3)
        left.addLayout(row4)
        left_widget = QWidget()
        left_widget.setLayout(left)

        right_widget = QWidget()
        right_widget.setLayout(question_panel)

        main = QHBoxLayout()
        main.addWidget(left_widget, 3)
        main.addWidget(right_widget, 1)
        container = QWidget()
        container.setLayout(main)
        self.setCentralWidget(container)

        self.btn_prev.clicked.connect(self.prev_clip)
        self.btn_next.clicked.connect(self.next_clip)
        self.btn_skip.clicked.connect(self.skip_clip)
        btn_clear.clicked.connect(self.mc.clear)
        btn_undo.clicked.connect(self.mc.undo)
        btn_undo10.clicked.connect(self.mc.undo10)
        btn_redo.clicked.connect(self.mc.redo)
        btn_load_prev_mask.clicked.connect(self.load_prev_mask)
        btn_auto_segment.clicked.connect(self.run_optical_flow)

        self._apply_pending_answers()

    def _build_question_panel(self) -> QVBoxLayout:
        """Create one radio-button group per configured question.

        Each question's widgets are registered in ``self.q_widgets`` under its
        key as ``(button_group, buttons, options)``. The option equal to the
        question's default is pre-checked; with no default, the last option is.
        """
        vbox = QVBoxLayout()
        for section, qs in self.cfg.get_questions():
            group = QGroupBox(section)
            gvbox = QVBoxLayout()
            for key, label, options, default in qs:
                gvbox.addWidget(QLabel(label))
                btn_group = QButtonGroup(self)
                row = QHBoxLayout()
                buttons = []
                for opt in options:
                    btn = QRadioButton(opt)
                    btn_group.addButton(btn)
                    row.addWidget(btn)
                    buttons.append(btn)
                    if default == opt:
                        btn.setChecked(True)
                if default is None and buttons:
                    buttons[-1].setChecked(True)
                self.q_widgets[key] = (btn_group, buttons, options)
                gvbox.addLayout(row)
            group.setLayout(gvbox)
            vbox.addWidget(group)
        return vbox

    def _set_answers(self, answers: dict):
        """Check the radio buttons matching ``answers``; unknown keys are ignored."""
        for key, value in answers.items():
            if key not in self.q_widgets:
                continue
            _, buttons, options = self.q_widgets[key]
            for i, btn in enumerate(buttons):
                btn.setChecked(options[i] == value)

    def _apply_pending_answers(self):
        """Apply answers loaded by ``_load_clip`` (if any) and clear them."""
        if self._pending_answers:
            self._set_answers(self._pending_answers)
        self._pending_answers = None

    def _frame_interval_ms(self) -> int:
        """Return the per-frame delay in ms for the current clip (at least 1)."""
        return max(1, int(1000 / self.fps))

    def _init_timer(self):
        """Start the timer that loops the current clip on the canvas."""
        self.frame_i = 0
        self.timer = QTimer()
        self.timer.timeout.connect(self._tick)
        self.timer.start(self._frame_interval_ms())

    def _tick(self):
        """Advance to the next frame, wrapping around at the end."""
        self.frame_i = (self.frame_i + 1) % len(self.frames)
        self.mc.set_frame(self.frames[self.frame_i])

    # ── answers ────────────────────────────────────────────────────
    def get_answers(self) -> dict:
        """Return ``{question_key: checked_option}`` for answered questions."""
        out = {}
        for key, (_, buttons, options) in self.q_widgets.items():
            for i, btn in enumerate(buttons):
                if btn.isChecked():
                    out[key] = options[i]
        return out

    # ── save helpers ───────────────────────────────────────────────
    def _make_overlay(self, frame, alpha=0.4):
        """Blend green into ``frame`` where the canvas mask is set.

        Channel 1 is set to 255 in the blend colour; assumes ``frame`` is a
        3-channel uint8 image matching the mask's size — TODO confirm.
        """
        overlay = frame.copy()
        green = np.zeros_like(frame)
        green[..., 1] = 255
        m = self.mc.mask.astype(bool)
        overlay[m] = (1 - alpha) * overlay[m] + alpha * green[m]
        return overlay.astype(np.uint8)

    def _save_gif(self, frames, out_path: str, scale=1.0):
        """Write ``frames`` as a looping GIF at the clip's frame rate.

        Frames are resized by ``scale`` (each side at least 1 px).
        """
        h, w = frames[0].shape[:2]
        nh, nw = max(1, int(h * scale)), max(1, int(w * scale))
        pil_frames = [Image.fromarray(cv2.resize(f, (nw, nh))) for f in frames]
        save_kwargs = dict(
            save_all=True,
            append_images=pil_frames[1:],
            duration=self._frame_interval_ms(),
            loop=0,
        )
        if self.fs is None:
            pil_frames[0].save(out_path, **save_kwargs)
        else:
            buf = io.BytesIO()
            pil_frames[0].save(buf, format="GIF", **save_kwargs)
            self.fs.pipe(out_path, buf.getvalue())

    # ── actions ────────────────────────────────────────────────────
    def _save_locked(self):
        """Run :meth:`save` with the clip-navigation buttons disabled.

        ``processEvents`` repaints the disabled buttons before the (possibly
        slow) save. Skip is disabled too: a queued Skip click processed here
        would otherwise change ``self.filename`` before ``save`` runs.
        """
        nav_buttons = (self.btn_next, self.btn_prev, self.btn_skip)
        for b in nav_buttons:
            b.setEnabled(False)
        QApplication.processEvents()
        try:
            self.save()
        finally:
            self.btn_next.setEnabled(True)
            self.btn_skip.setEnabled(True)
            self.btn_prev.setEnabled(self.history_pos > 0)

    def save(self):
        """Write the current clip's annotation to its output directory.

        Always writes the full-resolution mask, the answers JSON, the middle
        frame and its overlay; with ``extras`` also a display-size mask image
        and original/overlay GIFs at full and half scale.
        """
        out = self._out_path()
        self._fs_makedirs(out)
        # Upscale the display-size mask back to the clip's full resolution.
        mask_full = cv2.resize(
            self.mc.mask.astype(np.uint8),
            (self.w, self.h),
            interpolation=cv2.INTER_NEAREST,
        )
        fn = self.cfg.filenames
        self._pil_save(Image.fromarray(mask_full * 255), fsjoin(out, fn.mask))
        self._json_write(self.get_answers(), fsjoin(out, fn.metadata))
        mid = len(self.frames) // 2
        frame = self.frames[mid]
        self._pil_save(Image.fromarray(frame), fsjoin(out, fn.frame))
        self._pil_save(
            Image.fromarray(self._make_overlay(frame)), fsjoin(out, fn.overlay)
        )
        if self.extras:
            self._pil_save(
                Image.fromarray((self.mc.mask * 255).astype(np.uint8)),
                fsjoin(out, fn.mask_vis),
            )
            overlay_frames = [self._make_overlay(f) for f in self.frames]
            self._save_gif(self.frames, fsjoin(out, fn.gif_original_hires), scale=1.0)
            self._save_gif(self.frames, fsjoin(out, fn.gif_original_lowres), scale=0.5)
            self._save_gif(overlay_frames, fsjoin(out, fn.gif_overlay_hires), scale=1.0)
            self._save_gif(
                overlay_frames, fsjoin(out, fn.gif_overlay_lowres), scale=0.5
            )
        print("Saved:", out)

    def _switch_ui_to_clip(self):
        """Push the freshly loaded clip into the canvas, timer and questions."""
        self.frame_i = 0
        # Clips can have different frame rates; keep playback speed in sync.
        self.timer.setInterval(self._frame_interval_ms())
        self.mc.load_clip(
            self.frames,
            self.dh,
            self.dw,
            mask=self._read_saved_mask(),
            title=fsname(self.filename),
        )
        # NOTE(review): when the new clip has no saved metadata, the radio
        # buttons keep the previous clip's selections — confirm intended.
        self._apply_pending_answers()
        self.btn_prev.setEnabled(self.history_pos > 0)

    def _advance_clip(self):
        """Show the next clip: from forward history if any, else the selector.

        When the selector has no more clips (signalled by ``RuntimeError``
        from ``ClipSelector.next``), show a dialog and quit the application.
        """
        if self.history_pos < len(self.history) - 1:
            target = self.history_pos + 1
            self._load_clip(path=self.history[target])
            self.history_pos = target
            self._switch_ui_to_clip()
            return
        # Only the selector's RuntimeError means "no more clips"; a failure
        # inside load_frames must not be reported as reaching the end.
        try:
            next_path = self.selector.next(specific=None)
        except RuntimeError:
            msg = QMessageBox(self)
            msg.setWindowTitle("All done!")
            msg.setText("You have reached the end of all clips.")
            msg.setStandardButtons(QMessageBox.StandardButton.Ok)
            msg.exec()
            QApplication.instance().quit()
            return
        self._load_clip(path=next_path)
        self._history_push()
        self._switch_ui_to_clip()

    def prev_clip(self):
        """Save the current clip, then go back one step in history."""
        if self.history_pos <= 0:
            return
        # NOTE(review): unlike next_clip, this saves without asking whether
        # to overwrite an existing annotation.
        self._save_locked()
        target = self.history_pos - 1
        self._load_clip(path=self.history[target])
        self.history_pos = target
        self._switch_ui_to_clip()

    def next_clip(self):
        """Save (asking first if an annotation exists) and advance."""
        mask_path = self._out_path(self.cfg.filenames.mask)
        if not self._fs_exists(mask_path):
            self._save_locked()
            self._advance_clip()
            return
        msg = QMessageBox(self)
        msg.setWindowTitle("Existing annotation found")
        msg.setText(
            f"'{fsstem(self.filename)}' already has a saved annotation.\n"
            "Replace it with your current work, or keep the existing save?"
        )
        btn_replace = msg.addButton(
            "Replace & Continue", QMessageBox.ButtonRole.AcceptRole
        )
        btn_keep = msg.addButton(
            "Keep Existing & Continue", QMessageBox.ButtonRole.AcceptRole
        )
        msg.addButton("Cancel", QMessageBox.ButtonRole.RejectRole)
        msg.setDefaultButton(btn_replace)
        msg.exec()
        clicked = msg.clickedButton()
        if clicked == btn_replace:
            self._save_locked()
            self._advance_clip()
        elif clicked == btn_keep:
            self._advance_clip()
        # Cancel: do nothing

    def skip_clip(self):
        """Advance to the next clip without saving the current one."""
        self._advance_clip()

    def load_prev_mask(self):
        """Copy the mask saved for the preceding clip (in selector order)."""
        try:
            idx = self.selector.clips.index(self.filename)
        except ValueError:
            QMessageBox.information(
                self,
                "No previous clip",
                f"'{fsstem(self.filename)}' is not in the clip list.",
            )
            return
        if idx == 0:
            QMessageBox.information(
                self, "No previous clip", "This is the first clip in the list."
            )
            return
        prev_clip = self.selector.clips[idx - 1]
        mask_path = fsjoin(self.out_dir, fsstem(prev_clip), self.cfg.filenames.mask)
        if not self._fs_exists(mask_path):
            QMessageBox.information(
                self,
                "No mask found",
                f"No saved mask found for '{fsstem(prev_clip)}'.",
            )
            return
        self.mc.set_mask(self._read_mask(mask_path))

    def run_optical_flow(self):
        """Replace the mask with one from ``compute_optical_flow_mask``."""
        mask = compute_optical_flow_mask(
            self.frames,
            self.fps,
            self.of_cfg.norm_squared_threshold,
            self.of_cfg.gaussian_kernel,
            self.of_cfg.brightness_range,
        )
        self.mc.set_mask(mask)