Skip to content

Commit

Permalink
Fix issue introduced by batching video for face detection
Browse files Browse the repository at this point in the history
  • Loading branch information
prouast committed Jul 20, 2024
1 parent b79b4bc commit 62de2df
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 56 deletions.
42 changes: 21 additions & 21 deletions tests/test_ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,13 @@ def test_enforce_temporal_consistency():
[[.5, .5, .75, .75 ], [.125, .5, .375, .75]],
[[.125, .625, .375, .875], [.625, .5, .875, .75]]])
info = np.array(
[[[0, 1, 1, 1, .99], [0, 1, 1, 1, .99]],
[[1, 1, 0, 0, .2 ], [1, 1, 1, 1, .99]],
[[2, 1, 1, 1, .99], [2, 1, 1, 1, .99]],
[[3, 1, 1, 1, .99], [3, 1, 1, 1, .99]],
[[4, 1, 1, 1, .99], [4, 1, 1, 1, .99]]])
[[[0, 1, 1, .99], [0, 1, 1, .99]],
[[1, 1, 0, .2 ], [1, 1, 1, .99]],
[[2, 1, 1, .99], [2, 1, 1, .99]],
[[3, 1, 1, .99], [3, 1, 1, .99]],
[[4, 1, 1, .99], [4, 1, 1, .99]]])
boxes_out, info_out = enforce_temporal_consistency(
boxes=boxes, info=info, inputs_shape=(5, 8, 8, 3))
boxes=boxes, info=info, n_frames=5)
np.testing.assert_equal(
boxes_out,
np.array(
Expand All @@ -98,11 +98,11 @@ def test_enforce_temporal_consistency():
np.testing.assert_equal(
info_out,
np.array(
[[[0, 1, 1, 1, .99], [0, 1, 1, 1, .99]],
[[1, 1, 1, 1, .99], [1, 1, 0, 0, .2 ]],
[[2, 1, 1, 1, .99], [2, 1, 1, 1, .99]],
[[3, 1, 1, 1, .99], [3, 1, 1, 1, .99]],
[[4, 1, 1, 1, .99], [4, 1, 1, 1, .99]]]))
[[[0, 1, 1, .99], [0, 1, 1, .99]],
[[1, 1, 1, .99], [1, 1, 0, .2 ]],
[[2, 1, 1, .99], [2, 1, 1, .99]],
[[3, 1, 1, .99], [3, 1, 1, .99]],
[[4, 1, 1, .99], [4, 1, 1, .99]]]))

def test_interpolate_unscanned_frames():
# Example with 2 moving faces, 3 time steps, no detection for face 1 in time step 2, faces swapped in time step 4
Expand All @@ -111,11 +111,11 @@ def test_interpolate_unscanned_frames():
[[.25, .5, .5, .75], [.125, .25, .375, .5 ]],
[[.375, .5, .625, .75], [.125, .375, .375, .625]]])
info = np.array(
[[[0, 1, 1, 1, .99], [0, 1, 1, 1, .99]],
[[1, 1, 1, 1, .99], [1, 1, 0, 0, .2 ]],
[[2, 1, 1, 1, .99], [2, 1, 1, 1, .99]]])
[[[0, 1, 1, .99], [0, 1, 1, .99]],
[[2, 1, 1, .99], [2, 1, 0, .2 ]],
[[4, 1, 1, .99], [4, 1, 1, .99]]])
boxes_out, info_out = interpolate_unscanned_frames(
boxes=boxes, info=info, scan_every=2, inputs_shape=(5, 8, 8, 3))
boxes=boxes, info=info, n_frames=5)
np.testing.assert_equal(
boxes_out,
np.array(
Expand All @@ -127,11 +127,11 @@ def test_interpolate_unscanned_frames():
np.testing.assert_equal(
info_out,
np.array(
[[[0, 1, 1, 1, .99], [0, 1, 1, 1, .99]],
[[1, 0, 0, 1, 0 ], [1, 0, 0, 1, 0 ]], # Imperfection of the implementation
[[2, 1, 1, 1, .99], [2, 1, 0, 0, .2 ]],
[[3, 0, 0, 1, 0 ], [3, 0, 0, 0, 0 ]],
[[4, 1, 1, 1, .99], [4, 1, 1, 1, .99]]]))
[[[0, 1, 1, .99], [0, 1, 1, .99]],
[[1, 0, 0, 0 ], [1, 0, 0, 0 ]], # Imperfection of the implementation
[[2, 1, 1, .99], [2, 1, 0, .2 ]],
[[3, 0, 0, 0 ], [3, 0, 0, 0 ]],
[[4, 1, 1, .99], [4, 1, 1, .99]]]))

@pytest.mark.parametrize("file", [True, False])
def test_FaceDetector(request, file):
Expand All @@ -151,7 +151,7 @@ def test_FaceDetector(request, file):
inputs_shape=test_video_ndarray.shape,
fps=test_video_fps)
assert boxes.shape == (360, 1, 4)
assert info.shape == (360, 1, 5)
assert info.shape == (360, 1, 4)
np.testing.assert_allclose(boxes[0,0],
[0.32223, 0.118318, 0.572684, 0.696835],
atol=0.01)
Expand Down
5 changes: 3 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,12 @@ def test_probe_video_inputs_wrong_type():
def test_parse_video_inputs(request, file, roi, target_size, target_fps):
if file:
test_video_path = request.getfixturevalue('test_video_path')
parsed, fps_in, video_shape_in, ds_factor = parse_video_inputs(
parsed, fps_in, video_shape_in, ds_factor, idxs = parse_video_inputs(
test_video_path, roi=roi, target_size=target_size, target_fps=target_fps)
else:
test_video_ndarray = request.getfixturevalue('test_video_ndarray')
test_video_fps = request.getfixturevalue('test_video_fps')
parsed, fps_in, video_shape_in, ds_factor = parse_video_inputs(
parsed, fps_in, video_shape_in, ds_factor, idxs = parse_video_inputs(
test_video_ndarray, fps=test_video_fps, roi=roi, target_size=target_size,
target_fps=target_fps)
assert parsed.shape == (360 if target_fps is None else 360 // 2,
Expand All @@ -105,6 +105,7 @@ def test_parse_video_inputs(request, file, roi, target_size, target_fps):
assert fps_in == 30
assert video_shape_in == (360, 480, 768, 3)
assert ds_factor == 1 if target_fps is None else 2
assert idxs == list(range(360)) if target_fps is None else list(range(0, 360, 2))

def test_parse_video_inputs_no_file():
with pytest.raises(Exception):
Expand Down
2 changes: 1 addition & 1 deletion vitallens/methods/simple_rppg_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __call__(
u_roi = merge_faces(faces)
faces = faces - [u_roi[0], u_roi[1], u_roi[0], u_roi[1]]
# Parse the inputs
frames_ds, fps, inputs_shape, ds_factor = parse_video_inputs(
frames_ds, fps, inputs_shape, ds_factor, _ = parse_video_inputs(
video=frames, fps=fps, target_size=None, roi=u_roi,
target_fps=override_fps_target if override_fps_target is not None else self.fps_target)
assert inputs_shape[0] == faces.shape[0], "Need same number of frames as face detections"
Expand Down
2 changes: 1 addition & 1 deletion vitallens/methods/vitallens.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __call__(
(faces[:,3] - faces[:,1]) * 0.5 < np.maximum(0, faces[:,1] - roi[1]) + np.maximum(0, faces[:,3] - roi[3]))):
logging.warn("Large face movement detected")
# Parse the inputs
frames_ds, fps, inputs_shape, ds_factor = parse_video_inputs(
frames_ds, fps, inputs_shape, ds_factor, _ = parse_video_inputs(
video=frames, fps=fps, target_size=self.input_size, roi=roi,
target_fps=override_fps_target if override_fps_target is not None else self.fps_target,
library='prpy', scale_algorithm='bilinear')
Expand Down
49 changes: 20 additions & 29 deletions vitallens/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,21 +99,21 @@ def nms(
def enforce_temporal_consistency(
boxes: np.ndarray,
info: np.ndarray,
inputs_shape: tuple
n_frames: int
) -> Tuple[np.ndarray, np.ndarray]:
"""Enforce temporal consistency by sorting faces along second axis to minimize spatial distances.
Args:
boxes: Detected boxes in point form [0, 1], shape (n_frames, n_faces, 4)
info: Detection info: idx, scanned, scan_found_face, interp_valid, confidence. Shape (n_frames, n_faces, 5)
inputs_shape: Shape of the inputs.
info: Detection info: idx, scanned, scan_found_face, confidence. Shape (n_frames, n_faces, 4)
n_frames: Number of frames in the original input.
Returns:
boxes: Processed boxes in point form [0, 1], shape (n_frames, n_faces, 4)
info: Processed info: idx, scanned, scan_found_face, interp_valid, confidence. Shape (n_frames, n_faces, 5)
info: Processed info: idx, scanned, scan_found_face, confidence. Shape (n_frames, n_faces, 4)
"""
# Make sure that enough frames are present
if inputs_shape[0] == 1:
logging.warning("Ignoring enforce_consistency since n_frames=={}".format(inputs_shape[0]))
if n_frames == 1:
logging.warning("Ignoring enforce_consistency since n_frames=={}".format(n_frames))
return boxes, info
# Determine the maximum number of detections in any frame
max_det_faces = int(np.max(np.sum(info[...,2], axis=-1)))
Expand Down Expand Up @@ -145,36 +145,32 @@ def distance_minimizing_idxs(boxes, info, max_det_faces):
boxes = np.take_along_axis(boxes, idxs, axis=1)
info = np.take_along_axis(info, idxs, axis=1)
# Sort second axis by total confidence
order = np.argsort(np.sum(info[...,4], axis=0))[::-1][np.newaxis]
order = np.argsort(np.sum(info[...,3], axis=0))[::-1][np.newaxis]
boxes = np.squeeze(np.take(boxes, order, axis=1), axis=1)
info = np.squeeze(np.take(info, order, axis=1), axis=1)
return boxes, info

def interpolate_unscanned_frames(
boxes: np.ndarray,
info: np.ndarray,
scan_every: int,
inputs_shape: tuple
n_frames: int
) -> Tuple[np.ndarray, np.ndarray]:
"""Interpolate values for frames that were not scanned.
Args:
boxes: Detected boxes in point form [0, 1], shape (n_frames, n_faces, 4)
info: Detection info: idx, scanned, scan_found_face, confidence. Shape (n_frames, n_faces, 4)
scan_every: Scalar indicating that only every xth frame was scanned
inputs_shape: Shape of the inputs.
n_frames: Number of frames in the original input.
Returns:
boxes: Processed boxes in point form [0, 1], shape (orig_n_frames, n_faces, 4)
info: Processed info: idx, scanned, scan_found_face, interp_valid, confidence. Shape (orig_n_frames, n_faces, 5)
info: Processed info: idx, scanned, scan_found_face, confidence. Shape (orig_n_frames, n_faces, 4)
"""
# Transform info to original frame idxs
info[:,:,0] *= scan_every
_, n_faces, _ = info.shape
# Add rows corresponding to unscanned frames
add_idxs = list(set.difference(set(range(inputs_shape[0])), set(info[:,0,0].astype(np.int32).tolist())))
add_idxs = list(set.difference(set(range(n_frames)), set(info[:,0,0].astype(np.int32).tolist())))
idxs = np.repeat(np.asarray(add_idxs)[:,np.newaxis], n_faces, axis=1)[...,np.newaxis]
# Info
add_info = np.r_['2', idxs, np.zeros_like(idxs, np.int32), np.zeros_like(idxs, np.int32), np.zeros_like(idxs, np.int32), np.zeros_like(idxs, np.int32)]
add_info = np.r_['2', idxs, np.zeros_like(idxs, np.int32), np.zeros_like(idxs, np.int32), np.zeros_like(idxs, np.int32)]
info = np.concatenate([info, add_info])
# Boxes
add_boxes = np.full([len(add_idxs), n_faces, 4], np.nan)
Expand All @@ -185,11 +181,6 @@ def interpolate_unscanned_frames(
info = np.take(info, sort_idxs, axis=0)
# Interpolation
boxes = np.apply_along_axis(interpolate_vals, 0, boxes)
# Set interp_valid: 1 if interpolation happened between two valid dets, otherwise 0
int_valid = np.reshape(np.tile(np.reshape(info[info[:,:,1]>0][:,2], (-1, n_faces)),(1, scan_every)),(-1, n_faces))
if inputs_shape[0] - int_valid.shape[0] < 0:
int_valid = int_valid[:(inputs_shape[0]-int_valid.shape[0])]
info[:,:,3] = int_valid
return boxes, info

class FaceDetector:
Expand Down Expand Up @@ -243,7 +234,8 @@ def __call__(
results = [self.scan_batch(inputs=inputs, batch=i, n_batches=n_batches, start=int(s), end=int(s+l), fps=fps) for i, (s, l) in enumerate(offsets_lengths)]
boxes = np.concatenate([r[0] for r in results], axis=0)
classes = np.concatenate([r[1] for r in results], axis=0)
scan_every = results[0][2]
scan_idxs = np.concatenate([r[2] for r in results], axis=0)
scan_every = int(np.max(np.diff(scan_idxs)))
n_frames_scan = boxes.shape[0]
# Non-max suppression
idxs, num_valid = nms(boxes=boxes,
Expand All @@ -263,17 +255,16 @@ def __call__(
idxs = np.repeat(np.arange(n_frames_scan, dtype=np.int32)[:,np.newaxis], max_valid, axis=1)[...,np.newaxis]
scanned = np.ones((n_frames_scan, max_valid, 1), dtype=np.int32)
scan_found_face = np.where(classes[...,1:2] < self.score_threshold, np.zeros([n_frames_scan, max_valid, 1], dtype=np.int32), scanned)
info = np.r_['2', idxs, scanned, scan_found_face, scan_found_face, classes[...,1:2]]
info = np.r_['2', scan_idxs[...,np.newaxis,np.newaxis], scanned, scan_found_face, classes[...,1:2]]
# Enforce temporal consistency
boxes, info = enforce_temporal_consistency(boxes=boxes, info=info, inputs_shape=inputs_shape)
boxes, info = enforce_temporal_consistency(boxes=boxes, info=info, n_frames=n_frames)
# Interpolate unscanned frames if necessary
if scan_every > 1:
# Set unsuccessful detections to nan
nan = info[:,:,2] == 0
boxes[nan] = np.nan
# Interpolate
boxes, info = interpolate_unscanned_frames(
boxes, info, scan_every, inputs_shape)
boxes, info = interpolate_unscanned_frames(boxes=boxes, info=info, n_frames=n_frames)
# Return
return boxes, info
def scan_batch(
Expand All @@ -298,11 +289,11 @@ def scan_batch(
Returns:
boxes: Scanned boxes in flat point form (n_frames, n_boxes, 4)
classes: Detection scores for boxes (n_frames, n_boxes, 2)
scan_every: Scalar indicating that only every xth frame was scanned
idxs: Indices of the scanned frames from the original video
"""
logging.debug("Batch {}/{}...".format(batch, n_batches))
# Parse the inputs
inputs, fps, _, scan_every = parse_video_inputs(
inputs, fps, _, _, idxs = parse_video_inputs(
video=inputs, fps=fps, target_size=INPUT_SIZE, target_fps=self.fs,
library='prpy', scale_algorithm='bilinear', trim=(start, end))
# Forward pass
Expand All @@ -311,5 +302,5 @@ def scan_batch(
boxes = onnx_outputs[..., -4:]
classes = onnx_outputs[..., 0:2]
# Return
return boxes, classes, scan_every
return boxes, classes, idxs

9 changes: 7 additions & 2 deletions vitallens/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def parse_video_inputs(
fps_in: Frame rate of original inputs
shape_in: Shape of original inputs in form (n, h, w, c)
ds_factor: Temporal downsampling factor applied
idxs: Indices in the original video of the frames that were returned
"""
# Check if input is array or file name
if isinstance(video, str):
Expand All @@ -125,7 +126,10 @@ def parse_video_inputs(
video, ds_factor = read_video_from_path(
path=video, target_fps=target_fps, crop=roi, scale=target_size, trim=trim,
pix_fmt='rgb24', dim_deltas=(1,1,1), scale_algorithm=scale_algorithm)
return video, fps, (n, h, w, 3), ds_factor
start_idx = max(0, trim[0]) if trim is not None else 0
end_idx = min(n, trim[1]) if trim is not None else n
idxs = list(range(start_idx, end_idx, ds_factor))
return video, fps, (n, h, w, 3), ds_factor, idxs
except Exception as e:
raise ValueError("Problem reading video from {}: {}".format(video, e))
else:
Expand All @@ -146,7 +150,8 @@ def parse_video_inputs(
video = crop_slice_resize(
inputs=video, target_size=target_size, roi=roi, target_idxs=target_idxs,
preserve_aspect_ratio=False, library=library, scale_algorithm=scale_algorithm)
return video, fps, video_shape_in, ds_factor
if target_idxs is None: target_idxs = list(range(video_shape_in[0]))
return video, fps, video_shape_in, ds_factor, target_idxs
else:
raise ValueError("Invalid video {}, type {}".format(video, type(video)))

Expand Down

0 comments on commit 62de2df

Please sign in to comment.