269 lines
9.3 KiB
Python
269 lines
9.3 KiB
Python
|
|
import json
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
import cv2
|
||
|
|
import numpy as np
|
||
|
|
from PIL import Image
|
||
|
|
from PySide6.QtCore import QTimer
|
||
|
|
from PySide6.QtWidgets import (
|
||
|
|
QButtonGroup,
|
||
|
|
QGroupBox,
|
||
|
|
QHBoxLayout,
|
||
|
|
QLabel,
|
||
|
|
QMainWindow,
|
||
|
|
QPushButton,
|
||
|
|
QRadioButton,
|
||
|
|
QVBoxLayout,
|
||
|
|
QWidget,
|
||
|
|
)
|
||
|
|
|
||
|
|
from .clip_selector import ClipSelector
|
||
|
|
from .config import DEFAULTS, QUESTIONS, Config
|
||
|
|
from .mask_canvas import MaskCanvas
|
||
|
|
from .video_loader import load_frames
|
||
|
|
|
||
|
|
|
||
|
|
class Annotator(QMainWindow):
    """Main window for annotating river video clips.

    The left pane shows the current clip looping on a :class:`MaskCanvas`
    together with mask-painting controls; the right pane holds one
    radio-button group per question from ``QUESTIONS``. Results for each clip
    are written to ``out_dir / <clip stem>/`` (``mask.png``,
    ``metadata.json``, ``frame.png``, ``overlay.png`` and, when ``extras`` is
    set, ``mask_vis.png`` plus original/overlay GIFs).
    """

    # Rate used for playback and GIF frame delay when the loader reports a
    # non-positive fps, which would otherwise raise ZeroDivisionError.
    _FALLBACK_FPS = 25.0

    def __init__(
        self,
        data_dir: Path,
        out_dir: Path,
        clip: "str | None" = None,
        target_time: "str | None" = None,
        daily: bool = False,
        extras: bool = False,
        skip_existing_day: bool = False,
    ):
        """Create the window, load the first clip and start playback.

        Args:
            data_dir: Source directory handed to :class:`ClipSelector`.
            out_dir: Root directory under which per-clip results are saved.
            clip: Optional specific clip forwarded to ``ClipSelector.next``
                for the first load.
            target_time: Forwarded to :class:`ClipSelector`.
            daily: Forwarded to :class:`ClipSelector`; ``next_clip`` and
                ``skip_clip`` pass ``selector.daily`` as ``next_day``.
            extras: When true, ``save`` also writes ``mask_vis.png`` and GIFs.
            skip_existing_day: Forwarded to :class:`ClipSelector`.
        """
        super().__init__()

        self.out_dir = Path(out_dir)
        self.extras = extras

        self.selector = ClipSelector(
            data_dir=Path(data_dir),
            out_dir=self.out_dir,
            target_time=target_time,
            daily=daily,
            skip_existing_day=skip_existing_day,
        )

        self.setWindowTitle("River Annotator")
        self._load_clip(specific=clip)
        self._init_ui()
        self._init_timer()

    # ── clip loading ───────────────────────────────────────────────
    def _load_clip(self, specific: "str | None" = None, next_day: bool = False):
        """Ask the selector for the next clip and load its frames.

        Sets ``filename``, ``frames``, ``fps``, the canvas size ``dh``/``dw``
        and the full-resolution size ``h``/``w`` (used when saving the mask),
        and stashes any previously saved answers in ``_pending_answers`` so
        they can be applied once the question widgets exist.
        """
        self.filename = self.selector.next(specific=specific, next_day=next_day)
        self.frames, self.fps, self.dh, self.dw, self.h, self.w = load_frames(
            self.filename, Config.MAX_FRAMES
        )
        self._pending_answers = self._read_saved_answers()

    def _clip_out_dir(self) -> Path:
        """Return the output directory for the current clip."""
        return self.out_dir / self.filename.stem

    def _read_saved_mask(self):
        """Load the saved mask for the current clip at canvas resolution.

        Returns:
            A ``uint8`` array of 0/1 values with shape ``(dh, dw)``, or
            ``None`` if no ``mask.png`` has been saved for this clip.
        """
        mask_path = self._clip_out_dir() / "mask.png"
        if not mask_path.exists():
            return None
        # Context manager releases the PNG file handle deterministically.
        with Image.open(mask_path) as img:
            mask_full = np.array(img.convert("L"))
        # Saved masks are written as 0/255, so threshold at mid-grey.
        return cv2.resize(
            (mask_full > 127).astype(np.uint8),
            (self.dw, self.dh),
            interpolation=cv2.INTER_NEAREST,
        )

    def _read_saved_answers(self):
        """Return the parsed ``metadata.json`` for the current clip, or ``None``."""
        meta_path = self._clip_out_dir() / "metadata.json"
        if not meta_path.exists():
            return None
        with open(meta_path, encoding="utf-8") as f:
            return json.load(f)

    # ── UI setup ───────────────────────────────────────────────────
    def _init_ui(self):
        """Build the canvas, control buttons and question panel, and wire signals."""
        self.mc = MaskCanvas(self.frames, self.dh, self.dw)
        self.mc.set_title(self.filename.name)
        self.mc.reset(self._read_saved_mask())

        self.q_widgets = {}
        question_panel = self._build_question_panel()

        btn_save = QPushButton("Save")
        btn_next = QPushButton("Next")
        btn_skip = QPushButton("Skip")
        btn_clear = QPushButton("Clear")
        btn_undo = QPushButton("Undo")
        btn_reload = QPushButton("Reload Saved")

        row1 = QHBoxLayout()
        for b in [btn_save, btn_next, btn_skip]:
            row1.addWidget(b)

        row2 = QHBoxLayout()
        for b in [btn_clear, self.mc.btn_erase, btn_undo, btn_reload]:
            row2.addWidget(b)
        row2.addWidget(QLabel("Brush"))
        row2.addWidget(self.mc.brush_slider)

        left = QVBoxLayout()
        left.addWidget(self.mc.canvas)
        left.addLayout(row1)
        left.addLayout(row2)

        left_widget = QWidget()
        left_widget.setLayout(left)
        right_widget = QWidget()
        right_widget.setLayout(question_panel)

        main = QHBoxLayout()
        # Stretch factors: canvas side gets 3/5 of the width, questions 2/5.
        main.addWidget(left_widget, 3)
        main.addWidget(right_widget, 2)

        container = QWidget()
        container.setLayout(main)
        self.setCentralWidget(container)

        btn_save.clicked.connect(self.save)
        btn_next.clicked.connect(self.next_clip)
        btn_skip.clicked.connect(self.skip_clip)
        btn_clear.clicked.connect(self.mc.clear)
        btn_undo.clicked.connect(self.mc.undo)
        btn_reload.clicked.connect(self.reload_saved)

        self._apply_pending_answers()

    def _build_question_panel(self) -> QVBoxLayout:
        """Create one group box per section with a radio row per question.

        Each question's initial selection is ``DEFAULTS[key]``; when the key
        has no default, the last option is pre-selected. Widgets are recorded
        in ``self.q_widgets[key] = (button_group, buttons, options)``.
        """
        vbox = QVBoxLayout()
        for section, qs in QUESTIONS:
            group = QGroupBox(section)
            gvbox = QVBoxLayout()
            for key, label, options in qs:
                gvbox.addWidget(QLabel(label))
                btn_group = QButtonGroup(self)
                row = QHBoxLayout()
                buttons = []
                default_value = DEFAULTS.get(key)
                for opt in options:
                    btn = QRadioButton(opt)
                    btn_group.addButton(btn)
                    row.addWidget(btn)
                    buttons.append(btn)
                    if default_value == opt:
                        btn.setChecked(True)
                if default_value is None and buttons:
                    buttons[-1].setChecked(True)
                self.q_widgets[key] = (btn_group, buttons, options)
                gvbox.addLayout(row)
            group.setLayout(gvbox)
            vbox.addWidget(group)
        return vbox

    def _set_answers(self, answers: dict):
        """Select the radio button matching each value in *answers*.

        Keys that do not correspond to a known question are ignored.
        """
        for key, value in answers.items():
            widgets = self.q_widgets.get(key)
            if widgets is None:
                continue
            _, buttons, options = widgets
            for btn, opt in zip(buttons, options):
                btn.setChecked(opt == value)

    def _apply_pending_answers(self):
        """Apply answers stashed by ``_load_clip`` (if any), then clear them."""
        if self._pending_answers:
            self._set_answers(self._pending_answers)
        self._pending_answers = None

    def _frame_interval_ms(self) -> int:
        """Return the per-frame delay in milliseconds for the current clip.

        Uses ``_FALLBACK_FPS`` when ``fps`` is missing or not positive, and
        never returns less than 1 ms (a 0 ms QTimer fires continuously).
        """
        fps = self.fps if self.fps and self.fps > 0 else self._FALLBACK_FPS
        return max(1, int(1000 / fps))

    def _init_timer(self):
        """Start the playback timer that advances the canvas one frame per tick."""
        self.frame_i = 0
        # Parented to the window so the timer is destroyed with it.
        self.timer = QTimer(self)
        self.timer.timeout.connect(self._tick)
        self.timer.start(self._frame_interval_ms())

    def _tick(self):
        """Show the next frame, wrapping around to the first one."""
        # len() rather than truthiness: frames may be a numpy array.
        if len(self.frames) == 0:
            return
        self.frame_i = (self.frame_i + 1) % len(self.frames)
        self.mc.set_frame(self.frames[self.frame_i])

    # ── answers ────────────────────────────────────────────────────
    def get_answers(self) -> dict:
        """Return ``{question_key: selected_option}`` for each answered question."""
        out = {}
        for key, (_, buttons, options) in self.q_widgets.items():
            for btn, opt in zip(buttons, options):
                if btn.isChecked():
                    out[key] = opt
        return out

    # ── save helpers ───────────────────────────────────────────────
    def _make_overlay(self, frame, alpha=0.4):
        """Return a copy of *frame* with masked pixels blended toward green.

        Args:
            frame: Image array whose first two dimensions match the canvas
                mask. Channel 1 is tinted; frames are handed to PIL as-is,
                which reads 3-channel arrays as RGB, so that is green.
            alpha: Weight of the green tint in the blend.
        """
        overlay = frame.copy()
        green = np.zeros_like(frame)
        green[..., 1] = 255
        m = self.mc.mask.astype(bool)
        overlay[m] = (1 - alpha) * overlay[m] + alpha * green[m]
        return overlay.astype(np.uint8)

    def _save_gif(self, frames, out_path, scale=1.0):
        """Write *frames* as a looping GIF, resized by *scale*, at the clip's fps."""
        h, w = frames[0].shape[:2]
        nh, nw = max(1, int(h * scale)), max(1, int(w * scale))
        pil_frames = [Image.fromarray(cv2.resize(f, (nw, nh))) for f in frames]
        pil_frames[0].save(
            out_path,
            save_all=True,
            append_images=pil_frames[1:],
            duration=self._frame_interval_ms(),
            loop=0,
        )

    def _save_extras(self, out: Path, mask_bin):
        """Write the canvas-resolution mask image and original/overlay GIFs.

        Args:
            out: Output directory for the current clip.
            mask_bin: 0/1 ``uint8`` mask at canvas resolution.
        """
        Image.fromarray(mask_bin * 255).save(out / "mask_vis.png")
        overlay_frames = [self._make_overlay(f) for f in self.frames]
        self._save_gif(self.frames, out / "video_original_hires.gif", scale=1.0)
        self._save_gif(self.frames, out / "video_original_lowres.gif", scale=0.5)
        self._save_gif(overlay_frames, out / "video_overlay_hires.gif", scale=1.0)
        self._save_gif(overlay_frames, out / "video_overlay_lowres.gif", scale=0.5)

    # ── actions ────────────────────────────────────────────────────
    def save(self):
        """Write the mask, answers and preview images for the current clip."""
        out = self._clip_out_dir()
        out.mkdir(parents=True, exist_ok=True)

        # Any nonzero canvas value counts as masked, matching _make_overlay.
        mask_bin = (self.mc.mask != 0).astype(np.uint8)
        mask_full = cv2.resize(
            mask_bin,
            (self.w, self.h),
            interpolation=cv2.INTER_NEAREST,
        )
        Image.fromarray(mask_full * 255).save(out / "mask.png")

        with open(out / "metadata.json", "w", encoding="utf-8") as f:
            json.dump(self.get_answers(), f, indent=2)

        # The middle frame serves as the representative still.
        frame = self.frames[len(self.frames) // 2]
        Image.fromarray(frame).save(out / "frame.png")
        Image.fromarray(self._make_overlay(frame)).save(out / "overlay.png")

        if self.extras:
            self._save_extras(out, mask_bin)

        print("Saved:", out)

    def reload_saved(self):
        """Restore whatever mask and answers were saved for the current clip."""
        mask = self._read_saved_mask()
        if mask is not None:
            self.mc.reset(mask)
        answers = self._read_saved_answers()
        if answers:
            self._set_answers(answers)

    def _advance_clip(self, next_day: bool):
        """Load the next clip into the canvas and restart playback at frame 0."""
        self._load_clip(next_day=next_day)
        self.frame_i = 0
        # The new clip may have a different fps; keep playback speed in sync.
        self.timer.setInterval(self._frame_interval_ms())
        self.mc.load_clip(
            self.frames,
            self.dh,
            self.dw,
            mask=self._read_saved_mask(),
            title=self.filename.name,
        )
        # NOTE: when the new clip has no saved metadata, the radio buttons
        # keep the previous clip's selections.
        self._apply_pending_answers()

    def next_clip(self):
        """Save the current clip, then advance to the next one."""
        self.save()
        self._advance_clip(next_day=self.selector.daily)

    def skip_clip(self):
        """Advance to the next clip without saving the current one."""
        self._advance_clip(next_day=self.selector.daily)