Made project river-agnostic
This commit is contained in:
1
src/clip_annotator/__init__.py
Normal file
1
src/clip_annotator/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = "0.1.0"
|
||||
76
src/clip_annotator/annotation_script.py
Normal file
76
src/clip_annotator/annotation_script.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from matplotlib import use
|
||||
|
||||
|
||||
use("QtAgg")
|
||||
|
||||
from PySide6.QtWidgets import QApplication, QMessageBox
|
||||
|
||||
from .annotator import Annotator
|
||||
from .config import load_config
|
||||
from .filesystem import make_fs
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
default="config/config.yaml",
|
||||
help="Path to config YAML file (default: config/config.yaml)",
|
||||
)
|
||||
parser.add_argument("--data", default=None, help="Override data_dir from config")
|
||||
parser.add_argument("--out", default=None, help="Override out_dir from config")
|
||||
parser.add_argument("--clips", default=None, help="Override clips_file from config")
|
||||
parser.add_argument(
|
||||
"--clip", default=None, help="Stem name of a specific clip to load"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--extras",
|
||||
action="store_true",
|
||||
help="Also save GIFs, frame PNG, overlay PNG, and mask_vis PNG alongside the mask.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-skip",
|
||||
action="store_true",
|
||||
help="Show already-annotated clips instead of skipping them.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
args = parse_args()
|
||||
|
||||
cfg = load_config(Path(args.config))
|
||||
if args.data:
|
||||
cfg.data_dir = args.data
|
||||
if args.out:
|
||||
cfg.out_dir = args.out
|
||||
if args.clips:
|
||||
cfg.clips_file = args.clips
|
||||
|
||||
fs = make_fs(cfg.storage)
|
||||
|
||||
app = QApplication([])
|
||||
try:
|
||||
win = Annotator(
|
||||
cfg,
|
||||
clip=args.clip,
|
||||
extras=args.extras,
|
||||
skip_annotated=not args.no_skip,
|
||||
fs=fs,
|
||||
)
|
||||
except RuntimeError as e:
|
||||
QMessageBox.information(None, "No clips", str(e))
|
||||
sys.exit(0)
|
||||
win.show()
|
||||
app.exec()
|
||||
496
src/clip_annotator/annotator.py
Normal file
496
src/clip_annotator/annotator.py
Normal file
@@ -0,0 +1,496 @@
|
||||
import io
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from PySide6.QtCore import Qt, QTimer
|
||||
from PySide6.QtWidgets import (
|
||||
QApplication,
|
||||
QButtonGroup,
|
||||
QGroupBox,
|
||||
QHBoxLayout,
|
||||
QLabel,
|
||||
QMainWindow,
|
||||
QMessageBox,
|
||||
QPushButton,
|
||||
QRadioButton,
|
||||
QVBoxLayout,
|
||||
QWidget,
|
||||
)
|
||||
|
||||
from .clip_selector import ClipSelector
|
||||
from .compute_optical_flow import compute_optical_flow_mask
|
||||
from .config import AppConfig, load_optical_flow_config
|
||||
from .filesystem import fsjoin, fsname, fsstem
|
||||
from .mask_canvas import MaskCanvas
|
||||
from .video_loader import load_frames
|
||||
|
||||
|
||||
class Annotator(QMainWindow):
|
||||
def __init__(
|
||||
self,
|
||||
config: AppConfig,
|
||||
clip: str = None,
|
||||
extras: bool = False,
|
||||
skip_annotated: bool = True,
|
||||
fs=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.cfg = config
|
||||
self.fs = fs
|
||||
self.out_dir = config.out_dir
|
||||
self.extras = extras
|
||||
self.of_cfg = load_optical_flow_config(Path(config.optical_flow_config_file))
|
||||
|
||||
self.selector = ClipSelector(
|
||||
data_dir=config.data_dir,
|
||||
out_dir=self.out_dir,
|
||||
clips_file=Path(config.clips_file),
|
||||
mask_filename=config.filenames.mask,
|
||||
zip_extension=config.filenames.zip_extension,
|
||||
skip_annotated=skip_annotated,
|
||||
fs=fs,
|
||||
)
|
||||
|
||||
self.history: list[str] = []
|
||||
self.history_pos: int = -1
|
||||
|
||||
self.setWindowTitle("Clip Annotator")
|
||||
self._load_clip(specific=clip)
|
||||
self._history_push()
|
||||
self._init_ui()
|
||||
self._init_timer()
|
||||
|
||||
# ── filesystem helpers ─────────────────────────────────────────
|
||||
def _out_path(self, *parts: str) -> str:
|
||||
return fsjoin(self.out_dir, fsstem(self.filename), *parts)
|
||||
|
||||
def _fs_exists(self, path: str) -> bool:
|
||||
if self.fs is None:
|
||||
return Path(path).exists()
|
||||
return self.fs.exists(path)
|
||||
|
||||
def _fs_makedirs(self, path: str):
|
||||
if self.fs is None:
|
||||
Path(path).mkdir(parents=True, exist_ok=True)
|
||||
else:
|
||||
self.fs.makedirs(path, exist_ok=True)
|
||||
|
||||
def _pil_open(self, path: str) -> Image.Image:
|
||||
if self.fs is None:
|
||||
return Image.open(path)
|
||||
with self.fs.open(path, "rb") as f:
|
||||
return Image.open(io.BytesIO(f.read()))
|
||||
|
||||
def _pil_save(self, img: Image.Image, path: str):
|
||||
if self.fs is None:
|
||||
img.save(path)
|
||||
else:
|
||||
ext = str(path).rsplit(".", 1)[-1].lower()
|
||||
fmt = "JPEG" if ext in ("jpg", "jpeg") else ext.upper()
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format=fmt)
|
||||
self.fs.pipe(path, buf.getvalue())
|
||||
|
||||
def _json_read(self, path: str):
|
||||
if self.fs is None:
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
||||
with self.fs.open(path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
def _json_write(self, data, path: str):
|
||||
if self.fs is None:
|
||||
with open(path, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
else:
|
||||
self.fs.pipe(path, json.dumps(data, indent=2).encode())
|
||||
|
||||
# ── clip loading ───────────────────────────────────────────────
|
||||
def _load_clip(self, specific: str = None, path: str = None):
|
||||
if path is not None:
|
||||
self.filename = path
|
||||
else:
|
||||
self.filename = self.selector.next(specific=specific)
|
||||
self.frames, self.fps, self.dh, self.dw, self.h, self.w = load_frames(
|
||||
self.filename,
|
||||
self.cfg.max_frames,
|
||||
self.cfg.display_max,
|
||||
self.cfg.fps_fallback,
|
||||
self.cfg.filenames.video_in_zip,
|
||||
self.cfg.filenames.video_tmp_suffix,
|
||||
fs=self.fs,
|
||||
)
|
||||
self._pending_answers = self._read_saved_answers()
|
||||
|
||||
def _history_push(self):
|
||||
del self.history[self.history_pos + 1 :]
|
||||
self.history.append(self.filename)
|
||||
self.history_pos = len(self.history) - 1
|
||||
|
||||
def _read_saved_mask(self):
|
||||
mask_path = self._out_path(self.cfg.filenames.mask)
|
||||
if not self._fs_exists(mask_path):
|
||||
return None
|
||||
mask_full = np.array(self._pil_open(mask_path).convert("L"))
|
||||
return cv2.resize(
|
||||
(mask_full > 127).astype(np.uint8),
|
||||
(self.dw, self.dh),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
|
||||
def _read_saved_answers(self):
|
||||
meta_path = self._out_path(self.cfg.filenames.metadata)
|
||||
if not self._fs_exists(meta_path):
|
||||
return None
|
||||
return self._json_read(meta_path)
|
||||
|
||||
# ── UI setup ───────────────────────────────────────────────────
|
||||
def _init_ui(self):
|
||||
self.mc = MaskCanvas(self.frames, self.dh, self.dw)
|
||||
self.mc.set_title(fsname(self.filename))
|
||||
self.mc.reset(self._read_saved_mask())
|
||||
|
||||
self.q_widgets = {}
|
||||
question_panel = self._build_question_panel()
|
||||
|
||||
self.btn_prev = QPushButton("Previous")
|
||||
self.btn_prev.setEnabled(False)
|
||||
btn_next = QPushButton("Next")
|
||||
btn_skip = QPushButton("Skip")
|
||||
btn_clear = QPushButton("Clear")
|
||||
btn_undo = QPushButton("Undo")
|
||||
btn_undo10 = QPushButton("Undo×10")
|
||||
btn_redo = QPushButton("Redo")
|
||||
btn_load_prev_mask = QPushButton("Load Prev Mask")
|
||||
btn_auto_segment = QPushButton("Auto Segment")
|
||||
btn_auto_segment.setEnabled(self.of_cfg.enabled)
|
||||
|
||||
row1 = QHBoxLayout()
|
||||
for b in [
|
||||
self.btn_prev,
|
||||
btn_next,
|
||||
btn_skip,
|
||||
btn_load_prev_mask,
|
||||
btn_auto_segment,
|
||||
]:
|
||||
row1.addWidget(b)
|
||||
|
||||
row_tools = QHBoxLayout()
|
||||
for b in [
|
||||
self.mc.btn_brush,
|
||||
self.mc.btn_polygon,
|
||||
self.mc.btn_fill,
|
||||
self.mc.btn_del_shape,
|
||||
self.mc.btn_cancel_poly,
|
||||
]:
|
||||
row_tools.addWidget(b)
|
||||
|
||||
row2 = QHBoxLayout()
|
||||
for b in [
|
||||
btn_clear,
|
||||
self.mc.btn_erase,
|
||||
btn_undo,
|
||||
btn_undo10,
|
||||
btn_redo,
|
||||
self.mc.btn_mask,
|
||||
]:
|
||||
row2.addWidget(b)
|
||||
|
||||
row3 = QHBoxLayout()
|
||||
row3.addWidget(QLabel("Brush size"))
|
||||
row3.addWidget(self.mc.brush_slider)
|
||||
row3.addWidget(self.mc.brush_reset)
|
||||
|
||||
row4 = QHBoxLayout()
|
||||
row4.addWidget(QLabel("Mask Alpha"))
|
||||
row4.addWidget(self.mc.alpha_slider)
|
||||
row4.addWidget(self.mc.alpha_reset)
|
||||
|
||||
vert_panel = QHBoxLayout()
|
||||
vert_panel.setContentsMargins(0, 0, 4, 0)
|
||||
for label_text, slider, reset_btn in [
|
||||
("Brightness", self.mc.brightness_slider, self.mc.brightness_reset),
|
||||
("Contrast", self.mc.contrast_slider, self.mc.contrast_reset),
|
||||
("Gamma", self.mc.gamma_slider, self.mc.gamma_reset),
|
||||
]:
|
||||
col = QVBoxLayout()
|
||||
lbl = QLabel(label_text)
|
||||
lbl.setAlignment(Qt.AlignmentFlag.AlignHCenter)
|
||||
col.addWidget(lbl)
|
||||
col.addWidget(slider, 1)
|
||||
col.addWidget(reset_btn)
|
||||
vert_panel.addLayout(col)
|
||||
|
||||
canvas_row = QHBoxLayout()
|
||||
canvas_row.addLayout(vert_panel)
|
||||
canvas_row.addWidget(self.mc.canvas, 1)
|
||||
|
||||
left = QVBoxLayout()
|
||||
left.addLayout(canvas_row)
|
||||
left.addLayout(row1)
|
||||
left.addLayout(row_tools)
|
||||
left.addLayout(row2)
|
||||
left.addLayout(row3)
|
||||
left.addLayout(row4)
|
||||
|
||||
left_widget = QWidget()
|
||||
left_widget.setLayout(left)
|
||||
right_widget = QWidget()
|
||||
right_widget.setLayout(question_panel)
|
||||
|
||||
main = QHBoxLayout()
|
||||
main.addWidget(left_widget, 3)
|
||||
main.addWidget(right_widget, 1)
|
||||
|
||||
container = QWidget()
|
||||
container.setLayout(main)
|
||||
self.setCentralWidget(container)
|
||||
|
||||
self.btn_prev.clicked.connect(self.prev_clip)
|
||||
btn_next.clicked.connect(self.next_clip)
|
||||
btn_skip.clicked.connect(self.skip_clip)
|
||||
btn_clear.clicked.connect(self.mc.clear)
|
||||
btn_undo.clicked.connect(self.mc.undo)
|
||||
btn_undo10.clicked.connect(self.mc.undo10)
|
||||
btn_redo.clicked.connect(self.mc.redo)
|
||||
btn_load_prev_mask.clicked.connect(self.load_prev_mask)
|
||||
btn_auto_segment.clicked.connect(self.run_optical_flow)
|
||||
|
||||
if self._pending_answers:
|
||||
self._set_answers(self._pending_answers)
|
||||
self._pending_answers = None
|
||||
|
||||
def _build_question_panel(self) -> QVBoxLayout:
|
||||
vbox = QVBoxLayout()
|
||||
for section, qs in self.cfg.get_questions():
|
||||
group = QGroupBox(section)
|
||||
gvbox = QVBoxLayout()
|
||||
for key, label, options, default in qs:
|
||||
gvbox.addWidget(QLabel(label))
|
||||
btn_group = QButtonGroup(self)
|
||||
row = QHBoxLayout()
|
||||
buttons = []
|
||||
for opt in options:
|
||||
btn = QRadioButton(opt)
|
||||
btn_group.addButton(btn)
|
||||
row.addWidget(btn)
|
||||
buttons.append(btn)
|
||||
if default == opt:
|
||||
btn.setChecked(True)
|
||||
if default is None and buttons:
|
||||
buttons[-1].setChecked(True)
|
||||
self.q_widgets[key] = (btn_group, buttons, options)
|
||||
gvbox.addLayout(row)
|
||||
group.setLayout(gvbox)
|
||||
vbox.addWidget(group)
|
||||
return vbox
|
||||
|
||||
def _set_answers(self, answers: dict):
|
||||
for key, value in answers.items():
|
||||
if key not in self.q_widgets:
|
||||
continue
|
||||
_, buttons, options = self.q_widgets[key]
|
||||
for i, btn in enumerate(buttons):
|
||||
btn.setChecked(options[i] == value)
|
||||
|
||||
def _init_timer(self):
|
||||
self.frame_i = 0
|
||||
self.timer = QTimer()
|
||||
self.timer.timeout.connect(self._tick)
|
||||
self.timer.start(int(1000 / self.fps))
|
||||
|
||||
def _tick(self):
|
||||
self.frame_i = (self.frame_i + 1) % len(self.frames)
|
||||
self.mc.set_frame(self.frames[self.frame_i])
|
||||
|
||||
# ── answers ────────────────────────────────────────────────────
|
||||
def get_answers(self) -> dict:
|
||||
out = {}
|
||||
for key, (_, buttons, options) in self.q_widgets.items():
|
||||
for i, btn in enumerate(buttons):
|
||||
if btn.isChecked():
|
||||
out[key] = options[i]
|
||||
return out
|
||||
|
||||
# ── save helpers ───────────────────────────────────────────────
|
||||
def _make_overlay(self, frame, alpha=0.4):
|
||||
overlay = frame.copy()
|
||||
green = np.zeros_like(frame)
|
||||
green[..., 1] = 255
|
||||
m = self.mc.mask.astype(bool)
|
||||
overlay[m] = (1 - alpha) * overlay[m] + alpha * green[m]
|
||||
return overlay.astype(np.uint8)
|
||||
|
||||
def _save_gif(self, frames, out_path: str, scale=1.0):
|
||||
h, w = frames[0].shape[:2]
|
||||
nh, nw = max(1, int(h * scale)), max(1, int(w * scale))
|
||||
pil_frames = [Image.fromarray(cv2.resize(f, (nw, nh))) for f in frames]
|
||||
if self.fs is None:
|
||||
pil_frames[0].save(
|
||||
out_path,
|
||||
save_all=True,
|
||||
append_images=pil_frames[1:],
|
||||
duration=int(1000 / self.fps),
|
||||
loop=0,
|
||||
)
|
||||
else:
|
||||
buf = io.BytesIO()
|
||||
pil_frames[0].save(
|
||||
buf,
|
||||
format="GIF",
|
||||
save_all=True,
|
||||
append_images=pil_frames[1:],
|
||||
duration=int(1000 / self.fps),
|
||||
loop=0,
|
||||
)
|
||||
self.fs.pipe(out_path, buf.getvalue())
|
||||
|
||||
# ── actions ────────────────────────────────────────────────────
|
||||
def save(self):
|
||||
out = fsjoin(self.out_dir, fsstem(self.filename))
|
||||
self._fs_makedirs(out)
|
||||
|
||||
mask_full = cv2.resize(
|
||||
self.mc.mask.astype(np.uint8),
|
||||
(self.w, self.h),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
fn = self.cfg.filenames
|
||||
self._pil_save(Image.fromarray(mask_full * 255), fsjoin(out, fn.mask))
|
||||
self._json_write(self.get_answers(), fsjoin(out, fn.metadata))
|
||||
|
||||
mid = len(self.frames) // 2
|
||||
frame = self.frames[mid]
|
||||
self._pil_save(Image.fromarray(frame), fsjoin(out, fn.frame))
|
||||
self._pil_save(
|
||||
Image.fromarray(self._make_overlay(frame)), fsjoin(out, fn.overlay)
|
||||
)
|
||||
|
||||
if self.extras:
|
||||
self._pil_save(
|
||||
Image.fromarray((self.mc.mask * 255).astype(np.uint8)),
|
||||
fsjoin(out, fn.mask_vis),
|
||||
)
|
||||
overlay_frames = [self._make_overlay(f) for f in self.frames]
|
||||
self._save_gif(self.frames, fsjoin(out, fn.gif_original_hires), scale=1.0)
|
||||
self._save_gif(self.frames, fsjoin(out, fn.gif_original_lowres), scale=0.5)
|
||||
self._save_gif(overlay_frames, fsjoin(out, fn.gif_overlay_hires), scale=1.0)
|
||||
self._save_gif(
|
||||
overlay_frames, fsjoin(out, fn.gif_overlay_lowres), scale=0.5
|
||||
)
|
||||
|
||||
print("Saved:", out)
|
||||
|
||||
def _switch_ui_to_clip(self):
|
||||
self.frame_i = 0
|
||||
self.mc.load_clip(
|
||||
self.frames,
|
||||
self.dh,
|
||||
self.dw,
|
||||
mask=self._read_saved_mask(),
|
||||
title=fsname(self.filename),
|
||||
)
|
||||
if self._pending_answers:
|
||||
self._set_answers(self._pending_answers)
|
||||
self._pending_answers = None
|
||||
self.btn_prev.setEnabled(self.history_pos > 0)
|
||||
|
||||
def _advance_clip(self):
|
||||
if self.history_pos < len(self.history) - 1:
|
||||
self.history_pos += 1
|
||||
self._load_clip(path=self.history[self.history_pos])
|
||||
self._switch_ui_to_clip()
|
||||
return
|
||||
try:
|
||||
self._load_clip()
|
||||
except RuntimeError:
|
||||
msg = QMessageBox(self)
|
||||
msg.setWindowTitle("All done!")
|
||||
msg.setText("You have reached the end of all clips.")
|
||||
msg.setStandardButtons(QMessageBox.StandardButton.Ok)
|
||||
msg.exec()
|
||||
QApplication.instance().quit()
|
||||
return
|
||||
self._history_push()
|
||||
self._switch_ui_to_clip()
|
||||
|
||||
def prev_clip(self):
|
||||
if self.history_pos <= 0:
|
||||
return
|
||||
self.save()
|
||||
self.history_pos -= 1
|
||||
self._load_clip(path=self.history[self.history_pos])
|
||||
self._switch_ui_to_clip()
|
||||
|
||||
def next_clip(self):
|
||||
mask_path = self._out_path(self.cfg.filenames.mask)
|
||||
if self._fs_exists(mask_path):
|
||||
msg = QMessageBox(self)
|
||||
msg.setWindowTitle("Existing annotation found")
|
||||
msg.setText(
|
||||
f"'{fsstem(self.filename)}' already has a saved annotation.\n"
|
||||
"Replace it with your current work, or keep the existing save?"
|
||||
)
|
||||
btn_replace = msg.addButton(
|
||||
"Replace & Continue", QMessageBox.ButtonRole.AcceptRole
|
||||
)
|
||||
btn_keep = msg.addButton(
|
||||
"Keep Existing & Continue", QMessageBox.ButtonRole.AcceptRole
|
||||
)
|
||||
msg.addButton("Cancel", QMessageBox.ButtonRole.RejectRole)
|
||||
msg.setDefaultButton(btn_replace)
|
||||
msg.exec()
|
||||
clicked = msg.clickedButton()
|
||||
if clicked == btn_replace:
|
||||
self.save()
|
||||
self._advance_clip()
|
||||
elif clicked == btn_keep:
|
||||
self._advance_clip()
|
||||
# Cancel: do nothing
|
||||
else:
|
||||
self.save()
|
||||
self._advance_clip()
|
||||
|
||||
def skip_clip(self):
|
||||
self._advance_clip()
|
||||
|
||||
def load_prev_mask(self):
|
||||
try:
|
||||
idx = self.selector.clips.index(self.filename)
|
||||
except ValueError:
|
||||
return
|
||||
if idx == 0:
|
||||
QMessageBox.information(
|
||||
self, "No previous clip", "This is the first clip in the list."
|
||||
)
|
||||
return
|
||||
prev_clip = self.selector.clips[idx - 1]
|
||||
mask_path = fsjoin(self.out_dir, fsstem(prev_clip), self.cfg.filenames.mask)
|
||||
if not self._fs_exists(mask_path):
|
||||
QMessageBox.information(
|
||||
self,
|
||||
"No mask found",
|
||||
f"No saved mask found for '{fsstem(prev_clip)}'.",
|
||||
)
|
||||
return
|
||||
mask_full = np.array(self._pil_open(mask_path).convert("L"))
|
||||
mask = cv2.resize(
|
||||
(mask_full > 127).astype(np.uint8),
|
||||
(self.dw, self.dh),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
self.mc.set_mask(mask)
|
||||
|
||||
def run_optical_flow(self):
|
||||
mask = compute_optical_flow_mask(
|
||||
self.frames,
|
||||
self.fps,
|
||||
self.of_cfg.norm_squared_threshold,
|
||||
self.of_cfg.gaussian_kernel,
|
||||
self.of_cfg.brightness_range,
|
||||
)
|
||||
self.mc.set_mask(mask)
|
||||
75
src/clip_annotator/clip_selector.py
Normal file
75
src/clip_annotator/clip_selector.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from pathlib import Path
|
||||
|
||||
from .filesystem import fsjoin, fsstem
|
||||
|
||||
|
||||
class ClipSelector:
|
||||
def __init__(
|
||||
self,
|
||||
data_dir,
|
||||
out_dir,
|
||||
clips_file: Path,
|
||||
mask_filename: str = "mask.png",
|
||||
zip_extension: str = ".zip",
|
||||
skip_annotated: bool = True,
|
||||
fs=None,
|
||||
):
|
||||
self.data_dir = str(data_dir)
|
||||
self.out_dir = str(out_dir)
|
||||
self.mask_filename = mask_filename
|
||||
self.zip_extension = zip_extension
|
||||
self.skip_annotated = skip_annotated
|
||||
self.fs = fs
|
||||
self.clips = self._load_clips(clips_file)
|
||||
self.index = 0
|
||||
|
||||
def _load_clips(self, clips_file: Path) -> list:
|
||||
lines = clips_file.read_text().splitlines()
|
||||
return [
|
||||
fsjoin(self.data_dir, name.strip())
|
||||
for name in lines
|
||||
if name.strip() and not name.strip().startswith("#")
|
||||
]
|
||||
|
||||
def is_annotated(self, path) -> bool:
|
||||
mask_path = fsjoin(self.out_dir, fsstem(path), self.mask_filename)
|
||||
if self.fs is None:
|
||||
return Path(mask_path).exists()
|
||||
return self.fs.exists(mask_path)
|
||||
|
||||
def next(self, specific: str = None) -> str:
|
||||
if specific:
|
||||
return self._resolve_specific(specific)
|
||||
return self._pick_next()
|
||||
|
||||
def _resolve_specific(self, specific: str) -> str:
|
||||
if self.fs is None:
|
||||
data_dir = Path(self.data_dir)
|
||||
matches = list(data_dir.glob(f"{specific}{self.zip_extension}"))
|
||||
if not matches:
|
||||
p = data_dir / specific
|
||||
matches = [p] if p.exists() else []
|
||||
if not matches:
|
||||
raise FileNotFoundError(
|
||||
f"Clip '{specific}' not found in {self.data_dir}"
|
||||
)
|
||||
return str(matches[0])
|
||||
else:
|
||||
pattern = fsjoin(self.data_dir, f"{specific}{self.zip_extension}")
|
||||
matches = self.fs.glob(pattern)
|
||||
if not matches:
|
||||
p = fsjoin(self.data_dir, specific)
|
||||
matches = [p] if self.fs.exists(p) else []
|
||||
if not matches:
|
||||
raise FileNotFoundError(
|
||||
f"Clip '{specific}' not found in {self.data_dir}"
|
||||
)
|
||||
return matches[0]
|
||||
|
||||
def _pick_next(self) -> str:
|
||||
while self.index < len(self.clips):
|
||||
clip = self.clips[self.index]
|
||||
self.index += 1
|
||||
if not self.skip_annotated or not self.is_annotated(clip):
|
||||
return clip
|
||||
raise RuntimeError("No remaining clips to annotate")
|
||||
49
src/clip_annotator/compute_optical_flow.py
Normal file
49
src/clip_annotator/compute_optical_flow.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def compute_optical_flow_mask(
|
||||
frames: list[np.ndarray],
|
||||
fps: float,
|
||||
norm_squared_threshold: float,
|
||||
gaussian_kernel: tuple[int, int],
|
||||
brightness_range: tuple[int, int],
|
||||
) -> np.ndarray:
|
||||
"""Return a binary mask (uint8, values 0/1) from optical flow + brightness."""
|
||||
if len(frames) < 2:
|
||||
return np.zeros(frames[0].shape[:2], dtype=np.uint8)
|
||||
|
||||
frames_arr = np.stack(frames).astype(np.float64)
|
||||
frames_sub_mean = frames_arr - np.mean(frames_arr, axis=0)
|
||||
mn, mx = frames_sub_mean.min(), frames_sub_mean.max()
|
||||
if mx > mn:
|
||||
standardized = ((frames_sub_mean - mn) / (mx - mn) * 255).astype(np.uint8)
|
||||
else:
|
||||
standardized = np.zeros_like(frames_arr, dtype=np.uint8)
|
||||
|
||||
N = len(standardized)
|
||||
gray = np.stack([cv2.cvtColor(f, cv2.COLOR_RGB2GRAY) for f in standardized])
|
||||
|
||||
flow_data = np.zeros((N - 1,) + gray.shape[1:] + (2,))
|
||||
for i in range(N - 1):
|
||||
flow_data[i] = fps * cv2.optflow.calcOpticalFlowSparseToDense(
|
||||
gray[i], gray[i + 1]
|
||||
)
|
||||
|
||||
optical_flow = np.median(flow_data, axis=0)
|
||||
|
||||
flow_norm_sq = np.sum(optical_flow**2, axis=-1)
|
||||
max_norm = np.max(flow_norm_sq)
|
||||
if max_norm > 0:
|
||||
flow_mask = flow_norm_sq >= max_norm * norm_squared_threshold**2
|
||||
else:
|
||||
flow_mask = np.zeros(flow_norm_sq.shape, dtype=bool)
|
||||
|
||||
reference_frame = frames[len(frames) // 2]
|
||||
smoothed = cv2.GaussianBlur(reference_frame, gaussian_kernel, 0)
|
||||
gray_ref = cv2.cvtColor(smoothed, cv2.COLOR_RGB2GRAY)
|
||||
brightness_mask = (gray_ref > brightness_range[0]) & (
|
||||
gray_ref < brightness_range[1]
|
||||
)
|
||||
|
||||
return np.logical_and(brightness_mask, flow_mask).astype(np.uint8)
|
||||
94
src/clip_annotator/config.py
Normal file
94
src/clip_annotator/config.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilenameConfig:
|
||||
video_in_zip: str = "left.mp4"
|
||||
video_tmp_suffix: str = ".mp4"
|
||||
zip_extension: str = ".zip"
|
||||
mask: str = "mask.png"
|
||||
metadata: str = "metadata.json"
|
||||
frame: str = "frame.png"
|
||||
overlay: str = "overlay.png"
|
||||
mask_vis: str = "mask_vis.png"
|
||||
gif_original_hires: str = "video_original_hires.gif"
|
||||
gif_original_lowres: str = "video_original_lowres.gif"
|
||||
gif_overlay_hires: str = "video_overlay_hires.gif"
|
||||
gif_overlay_lowres: str = "video_overlay_lowres.gif"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AppConfig:
|
||||
storage: str
|
||||
data_dir: str
|
||||
out_dir: str
|
||||
optical_flow_config_file: str
|
||||
questions_config_file: str
|
||||
display_max: int = 480
|
||||
fps_fallback: int = 25
|
||||
max_frames: int = 100
|
||||
clips_file: str = "config/clips.txt"
|
||||
filenames: FilenameConfig = field(default_factory=FilenameConfig)
|
||||
questions: list = field(default_factory=list, init=False)
|
||||
|
||||
def get_questions(self):
|
||||
return [
|
||||
(
|
||||
s["section"],
|
||||
[
|
||||
(
|
||||
item["key"],
|
||||
item["label"],
|
||||
[str(o) for o in item["options"]],
|
||||
str(item["default"])
|
||||
if item.get("default") is not None
|
||||
else None,
|
||||
)
|
||||
for item in s["items"]
|
||||
],
|
||||
)
|
||||
for s in self.questions
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpticalFlowConfig:
|
||||
enabled: bool = False
|
||||
norm_squared_threshold: float = 0.3
|
||||
gaussian_kernel: tuple[int, int] = (5, 5)
|
||||
brightness_range: tuple[int, int] = (20, 235)
|
||||
|
||||
|
||||
def load_optical_flow_config(path: Path) -> OpticalFlowConfig:
|
||||
with open(path) as f:
|
||||
data = yaml.safe_load(f)
|
||||
data["gaussian_kernel"] = tuple(data["gaussian_kernel"])
|
||||
data["brightness_range"] = tuple(data["brightness_range"])
|
||||
return OpticalFlowConfig(**data)
|
||||
|
||||
|
||||
def load_questions_config(path: Path) -> list:
|
||||
with open(path) as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
|
||||
def load_config(path: Path) -> AppConfig:
|
||||
with open(path) as f:
|
||||
data = yaml.safe_load(f)
|
||||
for required in (
|
||||
"storage",
|
||||
"data_dir",
|
||||
"out_dir",
|
||||
"optical_flow_config_file",
|
||||
"questions_config_file",
|
||||
):
|
||||
if not data.get(required):
|
||||
raise ValueError(f"{path}: missing required field '{required}'.")
|
||||
fn_data = data.pop("filenames", {})
|
||||
cfg = AppConfig(**data)
|
||||
cfg.filenames = FilenameConfig(**fn_data)
|
||||
cfg.questions = load_questions_config(Path(cfg.questions_config_file))
|
||||
return cfg
|
||||
35
src/clip_annotator/filesystem.py
Normal file
35
src/clip_annotator/filesystem.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import os
|
||||
|
||||
|
||||
_DEFAULT_ENDPOINT = "https://os.zhdk.cloud.switch.ch"
|
||||
|
||||
|
||||
def make_fs(storage: str):
|
||||
"""Return an S3FileSystem for storage='s3', or None for local."""
|
||||
if storage != "s3":
|
||||
return None
|
||||
import s3fs
|
||||
|
||||
return s3fs.S3FileSystem(
|
||||
key=os.environ["S3_ACCESS_KEY"],
|
||||
secret=os.environ["S3_SECRET_ACCESS_KEY"],
|
||||
client_kwargs={
|
||||
"endpoint_url": os.environ.get("S3_ENDPOINT_URL", _DEFAULT_ENDPOINT)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def fsjoin(base, *parts: str) -> str:
|
||||
"""Join path segments with forward slashes (works for both local and S3)."""
|
||||
return "/".join([str(base).rstrip("/"), *[str(p).strip("/") for p in parts if p]])
|
||||
|
||||
|
||||
def fsstem(path) -> str:
|
||||
"""Filename stem (no extension) for local Path or S3 string."""
|
||||
name = str(path).replace("\\", "/").split("/")[-1]
|
||||
return name.rsplit(".", 1)[0] if "." in name else name
|
||||
|
||||
|
||||
def fsname(path) -> str:
|
||||
"""Filename component (with extension) for local Path or S3 string."""
|
||||
return str(path).replace("\\", "/").split("/")[-1]
|
||||
452
src/clip_annotator/mask_canvas.py
Normal file
452
src/clip_annotator/mask_canvas.py
Normal file
@@ -0,0 +1,452 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas
|
||||
from matplotlib.figure import Figure
|
||||
from matplotlib.patches import Circle
|
||||
from PySide6.QtCore import Qt
|
||||
from PySide6.QtWidgets import QPushButton, QSlider
|
||||
|
||||
|
||||
class MaskCanvas:
|
||||
"""Matplotlib canvas with brush/polygon mask drawing, undo/redo, and erase."""
|
||||
|
||||
_BRUSH_DEFAULT = 5
|
||||
_ALPHA_DEFAULT = 15
|
||||
_BRIGHTNESS_DEFAULT = 0
|
||||
_CONTRAST_DEFAULT = 0
|
||||
_GAMMA_DEFAULT = 100
|
||||
_CLOSE_THRESHOLD = 15 # image-pixel distance to first vertex that closes a polygon
|
||||
|
||||
def __init__(self, frames, dh: int, dw: int):
|
||||
self.dh = dh
|
||||
self.dw = dw
|
||||
|
||||
self.mask = np.zeros((dh, dw), dtype=np.uint8)
|
||||
self.history: list[np.ndarray] = []
|
||||
self.redo_stack: list[np.ndarray] = []
|
||||
self.erase_mode = False
|
||||
self.drawing = False
|
||||
self.mask_visible = True
|
||||
self._current_frame = frames[0]
|
||||
|
||||
self.tool_mode = "brush"
|
||||
self._shapes: list[list[tuple]] = []
|
||||
self._current_poly: list[tuple] = []
|
||||
self._poly_artists: list = []
|
||||
self._mouse_pos: tuple | None = None
|
||||
|
||||
self._build_figure(frames)
|
||||
self._build_controls()
|
||||
self._connect_events()
|
||||
|
||||
def _build_figure(self, frames):
|
||||
self.fig = Figure(figsize=(self.dw / 80, self.dh / 80))
|
||||
self.canvas = FigureCanvas(self.fig)
|
||||
self.ax = self.fig.add_subplot(111)
|
||||
self.ax.axis("off")
|
||||
self.img_artist = self.ax.imshow(frames[0])
|
||||
self.mask_artist = self.ax.imshow(np.zeros((self.dh, self.dw, 4)))
|
||||
self.title_text = self.ax.set_title("", fontsize=10, pad=4)
|
||||
self.brush_circle = Circle(
|
||||
(0, 0), radius=5, fill=False, color="white", linewidth=1.5, visible=False
|
||||
)
|
||||
self.ax.add_patch(self.brush_circle)
|
||||
self.ax.autoscale(False) # prevent polygon plot() calls from expanding the view
|
||||
|
||||
def _build_controls(self):
|
||||
self.btn_erase = QPushButton("Eraser")
|
||||
self.btn_mask = QPushButton("Hide Mask")
|
||||
|
||||
self.btn_brush = QPushButton("Brush")
|
||||
self.btn_brush.setStyleSheet("background-color: #4488ff; color: white;")
|
||||
self.btn_polygon = QPushButton("Polygon")
|
||||
self.btn_fill = QPushButton("Fill")
|
||||
self.btn_fill.setEnabled(False)
|
||||
self.btn_del_shape = QPushButton("Del Shape")
|
||||
self.btn_del_shape.setEnabled(False)
|
||||
self.btn_cancel_poly = QPushButton("Cancel Current Poly")
|
||||
|
||||
self.brush_slider = QSlider(Qt.Horizontal)
|
||||
self.brush_slider.setRange(2, 50)
|
||||
self.brush_slider.setValue(self._BRUSH_DEFAULT)
|
||||
self.brush_reset = QPushButton("↺")
|
||||
self.brush_reset.setFixedWidth(28)
|
||||
|
||||
self.alpha_slider = QSlider(Qt.Horizontal)
|
||||
self.alpha_slider.setRange(0, 100)
|
||||
self.alpha_slider.setValue(self._ALPHA_DEFAULT)
|
||||
self.alpha_reset = QPushButton("↺")
|
||||
self.alpha_reset.setFixedWidth(28)
|
||||
|
||||
self.brightness_slider = QSlider(Qt.Vertical)
|
||||
self.brightness_slider.setRange(-100, 100)
|
||||
self.brightness_slider.setValue(self._BRIGHTNESS_DEFAULT)
|
||||
self.brightness_reset = QPushButton("↺")
|
||||
self.brightness_reset.setFixedWidth(28)
|
||||
|
||||
self.contrast_slider = QSlider(Qt.Vertical)
|
||||
self.contrast_slider.setRange(-100, 100)
|
||||
self.contrast_slider.setValue(self._CONTRAST_DEFAULT)
|
||||
self.contrast_reset = QPushButton("↺")
|
||||
self.contrast_reset.setFixedWidth(28)
|
||||
|
||||
self.gamma_slider = QSlider(Qt.Vertical)
|
||||
self.gamma_slider.setRange(10, 300)
|
||||
self.gamma_slider.setValue(self._GAMMA_DEFAULT)
|
||||
self.gamma_reset = QPushButton("↺")
|
||||
self.gamma_reset.setFixedWidth(28)
|
||||
|
||||
def _connect_events(self):
|
||||
self.canvas.mpl_connect("button_press_event", self._on_press)
|
||||
self.canvas.mpl_connect("motion_notify_event", self._on_move)
|
||||
self.canvas.mpl_connect("button_release_event", self._on_release)
|
||||
self.canvas.mpl_connect("axes_leave_event", self._on_axes_leave)
|
||||
self.btn_erase.clicked.connect(self.toggle_erase)
|
||||
self.btn_mask.clicked.connect(self.toggle_mask)
|
||||
self.btn_brush.clicked.connect(lambda: self.set_tool_mode("brush"))
|
||||
self.btn_polygon.clicked.connect(lambda: self.set_tool_mode("polygon"))
|
||||
self.btn_fill.clicked.connect(lambda: self.set_tool_mode("fill"))
|
||||
self.btn_del_shape.clicked.connect(self.delete_last_shape)
|
||||
self.btn_cancel_poly.clicked.connect(self.cancel_polygon)
|
||||
self.alpha_slider.valueChanged.connect(self.redraw)
|
||||
self.brightness_slider.valueChanged.connect(self._refresh_frame)
|
||||
self.contrast_slider.valueChanged.connect(self._refresh_frame)
|
||||
self.gamma_slider.valueChanged.connect(self._refresh_frame)
|
||||
self.brush_reset.clicked.connect(
|
||||
lambda: self.brush_slider.setValue(self._BRUSH_DEFAULT)
|
||||
)
|
||||
self.alpha_reset.clicked.connect(
|
||||
lambda: self.alpha_slider.setValue(self._ALPHA_DEFAULT)
|
||||
)
|
||||
self.brightness_reset.clicked.connect(
|
||||
lambda: self.brightness_slider.setValue(self._BRIGHTNESS_DEFAULT)
|
||||
)
|
||||
self.contrast_reset.clicked.connect(
|
||||
lambda: self.contrast_slider.setValue(self._CONTRAST_DEFAULT)
|
||||
)
|
||||
self.gamma_reset.clicked.connect(
|
||||
lambda: self.gamma_slider.setValue(self._GAMMA_DEFAULT)
|
||||
)
|
||||
|
||||
# ── clip transition ────────────────────────────────────────────
|
||||
def load_clip(self, frames, dh: int, dw: int, mask=None, title: str = ""):
|
||||
self.dh = dh
|
||||
self.dw = dw
|
||||
self.mask = mask if mask is not None else np.zeros((dh, dw), dtype=np.uint8)
|
||||
self.history = []
|
||||
self.redo_stack = []
|
||||
self._current_frame = frames[0]
|
||||
self._clear_poly_state()
|
||||
self.img_artist.set_data(self._apply_image_adjustments(frames[0]))
|
||||
self.ax.set_xlim(-0.5, dw - 0.5)
|
||||
self.ax.set_ylim(dh - 0.5, -0.5)
|
||||
self.set_title(title)
|
||||
self.redraw()
|
||||
|
||||
def _clear_poly_state(self):
|
||||
self._shapes = []
|
||||
self._current_poly = []
|
||||
self._mouse_pos = None
|
||||
for a in self._poly_artists:
|
||||
a.remove()
|
||||
self._poly_artists = []
|
||||
self._update_poly_buttons()
|
||||
|
||||
# ── frame / title ──────────────────────────────────────────────
|
||||
def set_frame(self, frame):
|
||||
self._current_frame = frame
|
||||
self.img_artist.set_data(self._apply_image_adjustments(frame))
|
||||
self.canvas.draw_idle()
|
||||
|
||||
# ── image adjustments ──────────────────────────────────────────
|
||||
def _apply_image_adjustments(self, frame):
|
||||
img = frame.astype(np.float32)
|
||||
img += self.brightness_slider.value()
|
||||
c = self.contrast_slider.value() / 100.0
|
||||
img = (1.0 + c) * (img - 128.0) + 128.0
|
||||
np.clip(img, 0, 255, out=img)
|
||||
g = self.gamma_slider.value() / 100.0
|
||||
img = (img / 255.0) ** (1.0 / g) * 255.0
|
||||
return np.clip(img, 0, 255).astype(np.uint8)
|
||||
|
||||
def _refresh_frame(self):
|
||||
if self._current_frame is not None:
|
||||
self.img_artist.set_data(self._apply_image_adjustments(self._current_frame))
|
||||
self.canvas.draw_idle()
|
||||
|
||||
def set_title(self, text: str):
|
||||
self.title_text.set_text(text)
|
||||
|
||||
# ── mask ops ───────────────────────────────────────────────────
|
||||
def reset(self, mask=None):
|
||||
self.mask = (
|
||||
mask if mask is not None else np.zeros((self.dh, self.dw), dtype=np.uint8)
|
||||
)
|
||||
self.history = []
|
||||
self.redo_stack = []
|
||||
self.redraw()
|
||||
|
||||
def set_mask(self, mask):
|
||||
"""Replace the mask and push the previous state onto the undo stack."""
|
||||
self.history.append(self.mask.copy())
|
||||
self.redo_stack.clear()
|
||||
self.mask = mask
|
||||
self.redraw()
|
||||
|
||||
def redraw(self):
|
||||
if self.mask_visible:
|
||||
alpha = self.alpha_slider.value() / 100.0
|
||||
rgba = np.zeros((self.dh, self.dw, 4))
|
||||
rgba[..., 1] = self.mask * 0.7
|
||||
rgba[..., 3] = self.mask * alpha
|
||||
else:
|
||||
rgba = np.zeros((self.dh, self.dw, 4))
|
||||
self.mask_artist.set_data(rgba)
|
||||
self.canvas.draw_idle()
|
||||
|
||||
def clear(self):
|
||||
self.mask[:] = 0
|
||||
self.redraw()
|
||||
|
||||
def undo(self):
|
||||
if self.history:
|
||||
self.redo_stack.append(self.mask.copy())
|
||||
self.mask = self.history.pop()
|
||||
self.redraw()
|
||||
|
||||
def undo10(self):
|
||||
for _ in range(10):
|
||||
if not self.history:
|
||||
break
|
||||
self.redo_stack.append(self.mask.copy())
|
||||
self.mask = self.history.pop()
|
||||
self.redraw()
|
||||
|
||||
def redo(self):
|
||||
if self.redo_stack:
|
||||
self.history.append(self.mask.copy())
|
||||
self.mask = self.redo_stack.pop()
|
||||
self.redraw()
|
||||
|
||||
def toggle_erase(self):
|
||||
self.erase_mode = not self.erase_mode
|
||||
if self.erase_mode:
|
||||
self.btn_erase.setText("Eraser ON")
|
||||
self.btn_erase.setStyleSheet("background-color: orange; color: black;")
|
||||
else:
|
||||
self.btn_erase.setText("Eraser")
|
||||
self.btn_erase.setStyleSheet("")
|
||||
|
||||
def toggle_mask(self):
|
||||
self.mask_visible = not self.mask_visible
|
||||
if self.mask_visible:
|
||||
self.btn_mask.setText("Hide Mask")
|
||||
self.btn_mask.setStyleSheet("")
|
||||
else:
|
||||
self.btn_mask.setText("Show Mask")
|
||||
self.btn_mask.setStyleSheet("background-color: red; color: white;")
|
||||
self.redraw()
|
||||
|
||||
def stamp(self, x, y):
|
||||
if x is None or y is None:
|
||||
return
|
||||
self.history.append(self.mask.copy())
|
||||
self.redo_stack.clear()
|
||||
r = self.brush_slider.value()
|
||||
ix, iy = int(x), int(y)
|
||||
y0, y1 = max(0, iy - r), min(self.dh, iy + r + 1)
|
||||
x0, x1 = max(0, ix - r), min(self.dw, ix + r + 1)
|
||||
Y, X = np.ogrid[y0:y1, x0:x1]
|
||||
circle = (X - ix) ** 2 + (Y - iy) ** 2 <= r**2
|
||||
self.mask[y0:y1, x0:x1][circle] = 0 if self.erase_mode else 1
|
||||
self.redraw()
|
||||
|
||||
# ── tool mode ──────────────────────────────────────────────────
|
||||
def set_tool_mode(self, mode: str):
|
||||
self.tool_mode = mode
|
||||
active = "background-color: #4488ff; color: white;"
|
||||
self.btn_brush.setStyleSheet(active if mode == "brush" else "")
|
||||
self.btn_polygon.setStyleSheet(active if mode == "polygon" else "")
|
||||
self.btn_fill.setStyleSheet(active if mode == "fill" else "")
|
||||
if mode != "brush":
|
||||
self.brush_circle.set_visible(False)
|
||||
self.canvas.draw_idle()
|
||||
|
||||
# ── polygon ops ────────────────────────────────────────────────
|
||||
def _near_first(self, x: float, y: float) -> bool:
|
||||
if not self._current_poly:
|
||||
return False
|
||||
fx, fy = self._current_poly[0]
|
||||
return (x - fx) ** 2 + (y - fy) ** 2 <= self._CLOSE_THRESHOLD**2
|
||||
|
||||
def _update_poly_buttons(self):
|
||||
has = bool(self._shapes)
|
||||
self.btn_fill.setEnabled(has)
|
||||
self.btn_del_shape.setEnabled(has)
|
||||
|
||||
def _draw_polygon_overlay(self, mouse_pos=None):
|
||||
for a in self._poly_artists:
|
||||
a.remove()
|
||||
self._poly_artists.clear()
|
||||
|
||||
# Completed shapes — thick closed outline
|
||||
for shape in self._shapes:
|
||||
xs = [p[0] for p in shape] + [shape[0][0]]
|
||||
ys = [p[1] for p in shape] + [shape[0][1]]
|
||||
(line,) = self.ax.plot(xs, ys, color="cyan", linewidth=3, zorder=5)
|
||||
(dots,) = self.ax.plot(
|
||||
[p[0] for p in shape],
|
||||
[p[1] for p in shape],
|
||||
"o",
|
||||
color="cyan",
|
||||
markersize=4,
|
||||
zorder=6,
|
||||
)
|
||||
self._poly_artists.extend([line, dots])
|
||||
|
||||
# In-progress polygon
|
||||
if self._current_poly:
|
||||
xs = [p[0] for p in self._current_poly]
|
||||
ys = [p[1] for p in self._current_poly]
|
||||
|
||||
if len(self._current_poly) > 1:
|
||||
(edge,) = self.ax.plot(xs, ys, color="yellow", linewidth=1.5, zorder=5)
|
||||
self._poly_artists.append(edge)
|
||||
|
||||
(verts,) = self.ax.plot(xs, ys, "o", color="yellow", markersize=5, zorder=6)
|
||||
# Red dot on first vertex as close-target indicator
|
||||
(first,) = self.ax.plot(
|
||||
[xs[0]], [ys[0]], "o", color="red", markersize=8, zorder=7
|
||||
)
|
||||
self._poly_artists.extend([verts, first])
|
||||
|
||||
# Rubber-band line from last vertex to cursor
|
||||
if mouse_pos:
|
||||
mx, my = mouse_pos
|
||||
near = len(self._current_poly) >= 3 and self._near_first(mx, my)
|
||||
clr = "lime" if near else "yellow"
|
||||
(rband,) = self.ax.plot(
|
||||
[xs[-1], mx], [ys[-1], my], "--", color=clr, linewidth=1, zorder=5
|
||||
)
|
||||
self._poly_artists.append(rband)
|
||||
|
||||
self.canvas.draw_idle()
|
||||
|
||||
def cancel_polygon(self):
|
||||
self._current_poly = []
|
||||
self._draw_polygon_overlay(mouse_pos=self._mouse_pos)
|
||||
|
||||
def delete_last_shape(self):
|
||||
if self._shapes:
|
||||
self._shapes.pop()
|
||||
self._update_poly_buttons()
|
||||
self._draw_polygon_overlay(mouse_pos=self._mouse_pos)
|
||||
|
||||
def _fill_shape_at(self, x: float, y: float):
|
||||
if not self._shapes:
|
||||
return
|
||||
|
||||
polys = [
|
||||
np.array(
|
||||
[(int(round(px)), int(round(py))) for px, py in shape], dtype=np.int32
|
||||
)
|
||||
for shape in self._shapes
|
||||
]
|
||||
|
||||
# Find all shapes that contain the click point
|
||||
containing = []
|
||||
for i, poly in enumerate(polys):
|
||||
poly_f32 = poly.reshape(-1, 1, 2).astype(np.float32)
|
||||
if cv2.pointPolygonTest(poly_f32, (x, y), False) >= 0:
|
||||
containing.append((i, poly))
|
||||
|
||||
if not containing:
|
||||
return # click was outside all shapes
|
||||
|
||||
# Pick the innermost (smallest area) shape that contains the click
|
||||
containing.sort(key=lambda t: cv2.contourArea(t[1]))
|
||||
target_idx, target_poly = containing[0]
|
||||
|
||||
self.history.append(self.mask.copy())
|
||||
self.redo_stack.clear()
|
||||
|
||||
temp = np.zeros((self.dh, self.dw), dtype=np.uint8)
|
||||
cv2.fillPoly(temp, [target_poly], 1)
|
||||
|
||||
# Punch holes for any shapes completely inside the target
|
||||
target_f32 = target_poly.reshape(-1, 1, 2).astype(np.float32)
|
||||
for i, poly in enumerate(polys):
|
||||
if i == target_idx:
|
||||
continue
|
||||
cx = float(np.mean(poly[:, 0]))
|
||||
cy = float(np.mean(poly[:, 1]))
|
||||
if cv2.pointPolygonTest(target_f32, (cx, cy), False) > 0:
|
||||
cv2.fillPoly(temp, [poly], 0)
|
||||
|
||||
self.mask |= temp
|
||||
self.redraw()
|
||||
|
||||
# ── brush preview ──────────────────────────────────────────────
|
||||
def _update_brush_preview(self, e):
|
||||
if e.inaxes == self.ax and e.xdata is not None:
|
||||
self.brush_circle.center = (e.xdata, e.ydata)
|
||||
self.brush_circle.set_radius(self.brush_slider.value())
|
||||
self.brush_circle.set_visible(True)
|
||||
else:
|
||||
self.brush_circle.set_visible(False)
|
||||
self.canvas.draw_idle()
|
||||
|
||||
def _on_axes_leave(self, _):
|
||||
self.brush_circle.set_visible(False)
|
||||
if self.tool_mode == "polygon":
|
||||
self._mouse_pos = None
|
||||
self._draw_polygon_overlay()
|
||||
else:
|
||||
self.canvas.draw_idle()
|
||||
|
||||
# ── mouse events ───────────────────────────────────────────────
|
||||
def _on_press(self, e):
|
||||
if e.xdata is None:
|
||||
return
|
||||
if self.tool_mode == "brush":
|
||||
self.drawing = True
|
||||
self.stamp(e.xdata, e.ydata)
|
||||
elif self.tool_mode == "polygon":
|
||||
self._handle_polygon_click(e)
|
||||
elif self.tool_mode == "fill" and e.button == 1:
|
||||
self._fill_shape_at(e.xdata, e.ydata)
|
||||
|
||||
def _handle_polygon_click(self, e):
|
||||
if e.button == 3: # right-click: remove last vertex
|
||||
if self._current_poly:
|
||||
self._current_poly.pop()
|
||||
self._draw_polygon_overlay(mouse_pos=self._mouse_pos)
|
||||
return
|
||||
if e.button != 1:
|
||||
return
|
||||
x, y = e.xdata, e.ydata
|
||||
if len(self._current_poly) >= 3 and self._near_first(x, y):
|
||||
self._shapes.append(list(self._current_poly))
|
||||
self._current_poly = []
|
||||
self._update_poly_buttons()
|
||||
self._draw_polygon_overlay(mouse_pos=self._mouse_pos)
|
||||
else:
|
||||
self._current_poly.append((x, y))
|
||||
self._draw_polygon_overlay(mouse_pos=self._mouse_pos)
|
||||
|
||||
def _on_move(self, e):
|
||||
if self.tool_mode == "brush":
|
||||
self._update_brush_preview(e)
|
||||
if self.drawing:
|
||||
self.stamp(e.xdata, e.ydata)
|
||||
elif self.tool_mode == "polygon":
|
||||
self.brush_circle.set_visible(False)
|
||||
if e.inaxes == self.ax and e.xdata is not None:
|
||||
self._mouse_pos = (e.xdata, e.ydata)
|
||||
self._draw_polygon_overlay(mouse_pos=self._mouse_pos)
|
||||
else:
|
||||
self._mouse_pos = None
|
||||
self._draw_polygon_overlay()
|
||||
|
||||
def _on_release(self, _):
|
||||
self.drawing = False
|
||||
56
src/clip_annotator/video_loader.py
Normal file
56
src/clip_annotator/video_loader.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
import zipfile
|
||||
|
||||
import cv2
|
||||
|
||||
|
||||
def load_frames(
|
||||
zip_path,
|
||||
max_frames: int,
|
||||
display_max: int,
|
||||
fps_fallback: int,
|
||||
video_in_zip: str = "left.mp4",
|
||||
video_tmp_suffix: str = ".mp4",
|
||||
fs=None,
|
||||
):
|
||||
if fs is None:
|
||||
video_bytes = zipfile.ZipFile(zip_path).read(video_in_zip)
|
||||
else:
|
||||
with fs.open(str(zip_path), "rb") as f:
|
||||
video_bytes = zipfile.ZipFile(io.BytesIO(f.read())).read(video_in_zip)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=video_tmp_suffix, delete=False) as f:
|
||||
f.write(video_bytes)
|
||||
tmp_path = f.name
|
||||
|
||||
cap = cv2.VideoCapture(tmp_path)
|
||||
fps = cap.get(cv2.CAP_PROP_FPS) or fps_fallback
|
||||
|
||||
total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
step = max(1, total // max_frames)
|
||||
|
||||
frames = []
|
||||
i = 0
|
||||
while len(frames) < max_frames:
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, i)
|
||||
ok, frame = cap.read()
|
||||
if not ok:
|
||||
break
|
||||
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
||||
i += step
|
||||
|
||||
cap.release()
|
||||
os.unlink(tmp_path)
|
||||
|
||||
if not frames:
|
||||
raise RuntimeError(f"No frames found in {zip_path}")
|
||||
|
||||
h, w = frames[0].shape[:2]
|
||||
scale = display_max / max(h, w)
|
||||
dh, dw = int(h * scale), int(w * scale)
|
||||
|
||||
frames = [cv2.resize(f, (dw, dh)) for f in frames]
|
||||
|
||||
return frames, fps, dh, dw, h, w
|
||||
Reference in New Issue
Block a user