Made project river-agnostic
This commit is contained in:
49
src/clip_annotator/compute_optical_flow.py
Normal file
49
src/clip_annotator/compute_optical_flow.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def compute_optical_flow_mask(
|
||||
frames: list[np.ndarray],
|
||||
fps: float,
|
||||
norm_squared_threshold: float,
|
||||
gaussian_kernel: tuple[int, int],
|
||||
brightness_range: tuple[int, int],
|
||||
) -> np.ndarray:
|
||||
"""Return a binary mask (uint8, values 0/1) from optical flow + brightness."""
|
||||
if len(frames) < 2:
|
||||
return np.zeros(frames[0].shape[:2], dtype=np.uint8)
|
||||
|
||||
frames_arr = np.stack(frames).astype(np.float64)
|
||||
frames_sub_mean = frames_arr - np.mean(frames_arr, axis=0)
|
||||
mn, mx = frames_sub_mean.min(), frames_sub_mean.max()
|
||||
if mx > mn:
|
||||
standardized = ((frames_sub_mean - mn) / (mx - mn) * 255).astype(np.uint8)
|
||||
else:
|
||||
standardized = np.zeros_like(frames_arr, dtype=np.uint8)
|
||||
|
||||
N = len(standardized)
|
||||
gray = np.stack([cv2.cvtColor(f, cv2.COLOR_RGB2GRAY) for f in standardized])
|
||||
|
||||
flow_data = np.zeros((N - 1,) + gray.shape[1:] + (2,))
|
||||
for i in range(N - 1):
|
||||
flow_data[i] = fps * cv2.optflow.calcOpticalFlowSparseToDense(
|
||||
gray[i], gray[i + 1]
|
||||
)
|
||||
|
||||
optical_flow = np.median(flow_data, axis=0)
|
||||
|
||||
flow_norm_sq = np.sum(optical_flow**2, axis=-1)
|
||||
max_norm = np.max(flow_norm_sq)
|
||||
if max_norm > 0:
|
||||
flow_mask = flow_norm_sq >= max_norm * norm_squared_threshold**2
|
||||
else:
|
||||
flow_mask = np.zeros(flow_norm_sq.shape, dtype=bool)
|
||||
|
||||
reference_frame = frames[len(frames) // 2]
|
||||
smoothed = cv2.GaussianBlur(reference_frame, gaussian_kernel, 0)
|
||||
gray_ref = cv2.cvtColor(smoothed, cv2.COLOR_RGB2GRAY)
|
||||
brightness_mask = (gray_ref > brightness_range[0]) & (
|
||||
gray_ref < brightness_range[1]
|
||||
)
|
||||
|
||||
return np.logical_and(brightness_mask, flow_mask).astype(np.uint8)
|
||||
Reference in New Issue
Block a user