diff --git a/tests/test_ssd.py b/tests/test_ssd.py
index 3b8e38b..bccfc47 100644
--- a/tests/test_ssd.py
+++ b/tests/test_ssd.py
@@ -80,13 +80,13 @@ def test_enforce_temporal_consistency():
     [[.5, .5, .75, .75 ], [.125, .5, .375, .75]],
     [[.125, .625, .375, .875], [.625, .5, .875, .75]]])
   info = np.array(
-    [[[0, 1, 1, 1, .99], [0, 1, 1, 1, .99]],
-     [[1, 1, 0, 0, .2 ], [1, 1, 1, 1, .99]],
-     [[2, 1, 1, 1, .99], [2, 1, 1, 1, .99]],
-     [[3, 1, 1, 1, .99], [3, 1, 1, 1, .99]],
-     [[4, 1, 1, 1, .99], [4, 1, 1, 1, .99]]])
+    [[[0, 1, 1, .99], [0, 1, 1, .99]],
+     [[1, 1, 0, .2 ], [1, 1, 1, .99]],
+     [[2, 1, 1, .99], [2, 1, 1, .99]],
+     [[3, 1, 1, .99], [3, 1, 1, .99]],
+     [[4, 1, 1, .99], [4, 1, 1, .99]]])
   boxes_out, info_out = enforce_temporal_consistency(
-    boxes=boxes, info=info, inputs_shape=(5, 8, 8, 3))
+    boxes=boxes, info=info, n_frames=5)
   np.testing.assert_equal(
     boxes_out,
     np.array(
@@ -98,11 +98,11 @@ def test_enforce_temporal_consistency():
   np.testing.assert_equal(
     info_out,
     np.array(
-      [[[0, 1, 1, 1, .99], [0, 1, 1, 1, .99]],
-       [[1, 1, 1, 1, .99], [1, 1, 0, 0, .2 ]],
-       [[2, 1, 1, 1, .99], [2, 1, 1, 1, .99]],
-       [[3, 1, 1, 1, .99], [3, 1, 1, 1, .99]],
-       [[4, 1, 1, 1, .99], [4, 1, 1, 1, .99]]]))
+      [[[0, 1, 1, .99], [0, 1, 1, .99]],
+       [[1, 1, 1, .99], [1, 1, 0, .2 ]],
+       [[2, 1, 1, .99], [2, 1, 1, .99]],
+       [[3, 1, 1, .99], [3, 1, 1, .99]],
+       [[4, 1, 1, .99], [4, 1, 1, .99]]]))
 
 def test_interpolate_unscanned_frames():
   # Example with 2 moving faces, 3 time steps, no detection for face 1 in time step 2, faces swapped in time step 4
@@ -111,11 +111,11 @@ def test_interpolate_unscanned_frames():
     [[.25, .5, .5, .75], [.125, .25, .375, .5 ]],
     [[.375, .5, .625, .75], [.125, .375, .375, .625]]])
   info = np.array(
-    [[[0, 1, 1, 1, .99], [0, 1, 1, 1, .99]],
-     [[1, 1, 1, 1, .99], [1, 1, 0, 0, .2 ]],
-     [[2, 1, 1, 1, .99], [2, 1, 1, 1, .99]]])
+    [[[0, 1, 1, .99], [0, 1, 1, .99]],
+     [[2, 1, 1, .99], [2, 1, 0, .2 ]],
+     [[4, 1, 1, .99], [4, 1, 1, .99]]])
   boxes_out, info_out = interpolate_unscanned_frames(
-    boxes=boxes, info=info, scan_every=2, inputs_shape=(5, 8, 8, 3))
+    boxes=boxes, info=info, n_frames=5)
   np.testing.assert_equal(
     boxes_out,
     np.array(
@@ -127,11 +127,11 @@ def test_interpolate_unscanned_frames():
   np.testing.assert_equal(
     info_out,
     np.array(
-      [[[0, 1, 1, 1, .99], [0, 1, 1, 1, .99]],
-       [[1, 0, 0, 1, 0 ], [1, 0, 0, 1, 0 ]], # Imperfection of the implementation
-       [[2, 1, 1, 1, .99], [2, 1, 0, 0, .2 ]],
-       [[3, 0, 0, 1, 0 ], [3, 0, 0, 0, 0 ]],
-       [[4, 1, 1, 1, .99], [4, 1, 1, 1, .99]]]))
+      [[[0, 1, 1, .99], [0, 1, 1, .99]],
+       [[1, 0, 0, 0 ], [1, 0, 0, 0 ]], # Imperfection of the implementation
+       [[2, 1, 1, .99], [2, 1, 0, .2 ]],
+       [[3, 0, 0, 0 ], [3, 0, 0, 0 ]],
+       [[4, 1, 1, .99], [4, 1, 1, .99]]]))
 
 @pytest.mark.parametrize("file", [True, False])
 def test_FaceDetector(request, file):
@@ -151,7 +151,7 @@ def test_FaceDetector(request, file):
     inputs_shape=test_video_ndarray.shape,
     fps=test_video_fps)
   assert boxes.shape == (360, 1, 4)
-  assert info.shape == (360, 1, 5)
+  assert info.shape == (360, 1, 4)
   np.testing.assert_allclose(boxes[0,0],
                              [0.32223, 0.118318, 0.572684, 0.696835],
                              atol=0.01)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 75e0aed..ea62b0c 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -90,12 +90,12 @@ def test_probe_video_inputs_wrong_type():
 def test_parse_video_inputs(request, file, roi, target_size, target_fps):
   if file:
     test_video_path = request.getfixturevalue('test_video_path')
-    parsed, fps_in, video_shape_in, ds_factor = parse_video_inputs(
+    parsed, fps_in, video_shape_in, ds_factor, idxs = parse_video_inputs(
       test_video_path, roi=roi, target_size=target_size, target_fps=target_fps)
   else:
     test_video_ndarray = request.getfixturevalue('test_video_ndarray')
     test_video_fps = request.getfixturevalue('test_video_fps')
-    parsed, fps_in, video_shape_in, ds_factor = parse_video_inputs(
+    parsed, fps_in, video_shape_in, ds_factor, idxs = parse_video_inputs(
       test_video_ndarray, fps=test_video_fps, roi=roi, target_size=target_size,
       target_fps=target_fps)
   assert parsed.shape == (360 if target_fps is None else 360 // 2,
@@ -105,6 +105,7 @@ def test_parse_video_inputs(request, file, roi, target_size, target_fps):
   assert fps_in == 30
   assert video_shape_in == (360, 480, 768, 3)
   assert ds_factor == 1 if target_fps is None else 2
+  assert idxs == list(range(360)) if target_fps is None else list(range(0, 360, 2))
 
 def test_parse_video_inputs_no_file():
   with pytest.raises(Exception):
diff --git a/vitallens/methods/simple_rppg_method.py b/vitallens/methods/simple_rppg_method.py
index 43e2fea..d384d6c 100644
--- a/vitallens/methods/simple_rppg_method.py
+++ b/vitallens/methods/simple_rppg_method.py
@@ -76,7 +76,7 @@ def __call__(
     u_roi = merge_faces(faces)
     faces = faces - [u_roi[0], u_roi[1], u_roi[0], u_roi[1]]
     # Parse the inputs
-    frames_ds, fps, inputs_shape, ds_factor = parse_video_inputs(
+    frames_ds, fps, inputs_shape, ds_factor, _ = parse_video_inputs(
       video=frames, fps=fps, target_size=None, roi=u_roi,
       target_fps=override_fps_target if override_fps_target is not None else self.fps_target)
     assert inputs_shape[0] == faces.shape[0], "Need same number of frames as face detections"
diff --git a/vitallens/methods/vitallens.py b/vitallens/methods/vitallens.py
index d60ea1d..10d6732 100644
--- a/vitallens/methods/vitallens.py
+++ b/vitallens/methods/vitallens.py
@@ -79,7 +79,7 @@ def __call__(
                (faces[:,3] - faces[:,1]) * 0.5 < np.maximum(0, faces[:,1] - roi[1]) + np.maximum(0, faces[:,3] - roi[3]))):
       logging.warn("Large face movement detected")
     # Parse the inputs
-    frames_ds, fps, inputs_shape, ds_factor = parse_video_inputs(
+    frames_ds, fps, inputs_shape, ds_factor, _ = parse_video_inputs(
       video=frames, fps=fps, target_size=self.input_size, roi=roi,
       target_fps=override_fps_target if override_fps_target is not None else self.fps_target,
       library='prpy', scale_algorithm='bilinear')
diff --git a/vitallens/ssd.py b/vitallens/ssd.py
index 302e55d..bea3d29 100644
--- a/vitallens/ssd.py
+++ b/vitallens/ssd.py
@@ -99,21 +99,21 @@ def nms(
 def enforce_temporal_consistency(
     boxes: np.ndarray,
     info: np.ndarray,
-    inputs_shape: tuple
+    n_frames: int
   ) -> Tuple[np.ndarray, np.ndarray]:
   """Enforce temporal consistency by sorting faces along second axis to minimize spatial distances.
 
   Args:
     boxes: Detected boxes in point form [0, 1], shape (n_frames, n_faces, 4)
-    info: Detection info: idx, scanned, scan_found_face, interp_valid, confidence. Shape (n_frames, n_faces, 5)
-    inputs_shape: Shape of the inputs.
+    info: Detection info: idx, scanned, scan_found_face, confidence. Shape (n_frames, n_faces, 4)
+    n_frames: Number of frames in the original input.
   Returns:
     boxes: Processed boxes in point form [0, 1], shape (n_frames, n_faces, 4)
-    info: Processed info: idx, scanned, scan_found_face, interp_valid, confidence. Shape (n_frames, n_faces, 5)
+    info: Processed info: idx, scanned, scan_found_face, confidence. Shape (n_frames, n_faces, 4)
   """
   # Make sure that enough frames are present
-  if inputs_shape[0] == 1:
-    logging.warning("Ignoring enforce_consistency since n_frames=={}".format(inputs_shape[0]))
+  if n_frames == 1:
+    logging.warning("Ignoring enforce_consistency since n_frames=={}".format(n_frames))
     return boxes, info
   # Determine the maximum number of detections in any frame
   max_det_faces = int(np.max(np.sum(info[...,2], axis=-1)))
@@ -145,7 +145,7 @@ def distance_minimizing_idxs(boxes, info, max_det_faces):
   boxes = np.take_along_axis(boxes, idxs, axis=1)
   info = np.take_along_axis(info, idxs, axis=1)
   # Sort second axis by total confidence
-  order = np.argsort(np.sum(info[...,4], axis=0))[::-1][np.newaxis]
+  order = np.argsort(np.sum(info[...,3], axis=0))[::-1][np.newaxis]
   boxes = np.squeeze(np.take(boxes, order, axis=1), axis=1)
   info = np.squeeze(np.take(info, order, axis=1), axis=1)
   return boxes, info
@@ -153,28 +153,24 @@ def distance_minimizing_idxs(boxes, info, max_det_faces):
 def interpolate_unscanned_frames(
     boxes: np.ndarray,
     info: np.ndarray,
-    scan_every: int,
-    inputs_shape: tuple
+    n_frames: int
   ) -> Tuple[np.ndarray, np.ndarray]:
   """Interpolate values for frames that were not scanned.
 
   Args:
     boxes: Detected boxes in point form [0, 1], shape (n_frames, n_faces, 4)
     info: Detection info: idx, scanned, scan_found_face, interp_valid, confidence. Shape (n_frames, n_faces, 5)
-    scan_every: Scalar indicating that only every xth frame was scanned
-    inputs_shape: Shape of the inputs.
+    n_frames: Number of frames in the original input.
   Returns:
     boxes: Processed boxes in point form [0, 1], shape (orig_n_frames, n_faces, 4)
-    info: Processed info: idx, scanned, scan_found_face, interp_valid, confidence. Shape (orig_n_frames, n_faces, 5)
+    info: Processed info: idx, scanned, scan_found_face, confidence. Shape (orig_n_frames, n_faces, 4)
   """
-  # Transform info to original frame idxs
-  info[:,:,0] *= scan_every
   _, n_faces, _ = info.shape
   # Add rows corresponding to unscanned frames
-  add_idxs = list(set.difference(set(range(inputs_shape[0])), set(info[:,0,0].astype(np.int32).tolist())))
+  add_idxs = list(set.difference(set(range(n_frames)), set(info[:,0,0].astype(np.int32).tolist())))
   idxs = np.repeat(np.asarray(add_idxs)[:,np.newaxis], n_faces, axis=1)[...,np.newaxis]
   # Info
-  add_info = np.r_['2', idxs, np.zeros_like(idxs, np.int32), np.zeros_like(idxs, np.int32), np.zeros_like(idxs, np.int32), np.zeros_like(idxs, np.int32)]
+  add_info = np.r_['2', idxs, np.zeros_like(idxs, np.int32), np.zeros_like(idxs, np.int32), np.zeros_like(idxs, np.int32)]
   info = np.concatenate([info, add_info])
   # Boxes
   add_boxes = np.full([len(add_idxs), n_faces, 4], np.nan)
@@ -185,11 +181,6 @@ def interpolate_unscanned_frames(
   info = np.take(info, sort_idxs, axis=0)
   # Interpolation
   boxes = np.apply_along_axis(interpolate_vals, 0, boxes)
-  # Set interp_valid: 1 if interpolation happened between two valid dets, otherwise 0
-  int_valid = np.reshape(np.tile(np.reshape(info[info[:,:,1]>0][:,2], (-1, n_faces)),(1, scan_every)),(-1, n_faces))
-  if inputs_shape[0] - int_valid.shape[0] < 0:
-    int_valid = int_valid[:(inputs_shape[0]-int_valid.shape[0])]
-  info[:,:,3] = int_valid
   return boxes, info
 
 class FaceDetector:
@@ -243,7 +234,8 @@ def __call__(
     results = [self.scan_batch(inputs=inputs, batch=i, n_batches=n_batches, start=int(s), end=int(s+l), fps=fps) for i, (s, l) in enumerate(offsets_lengths)]
    boxes = np.concatenate([r[0] for r in results], axis=0)
    classes = np.concatenate([r[1] for r in results], axis=0)
-    scan_every = results[0][2]
+    scan_idxs = np.concatenate([r[2] for r in results], axis=0)
+    scan_every = int(np.max(np.diff(scan_idxs)))
    n_frames_scan = boxes.shape[0]
    # Non-max suppression
    idxs, num_valid = nms(boxes=boxes,
@@ -263,17 +255,16 @@ def __call__(
    idxs = np.repeat(np.arange(n_frames_scan, dtype=np.int32)[:,np.newaxis], max_valid, axis=1)[...,np.newaxis]
    scanned = np.ones((n_frames_scan, max_valid, 1), dtype=np.int32)
    scan_found_face = np.where(classes[...,1:2] < self.score_threshold, np.zeros([n_frames_scan, max_valid, 1], dtype=np.int32), scanned)
-    info = np.r_['2', idxs, scanned, scan_found_face, scan_found_face, classes[...,1:2]]
+    info = np.r_['2', scan_idxs[...,np.newaxis,np.newaxis], scanned, scan_found_face, classes[...,1:2]]
    # Enforce temporal consistency
-    boxes, info = enforce_temporal_consistency(boxes=boxes, info=info, inputs_shape=inputs_shape)
+    boxes, info = enforce_temporal_consistency(boxes=boxes, info=info, n_frames=n_frames)
    # Interpolate unscanned frames if necessary
    if scan_every > 1:
      # Set unsuccessful detections to nan
      nan = info[:,:,2] == 0
      boxes[nan] = np.nan
      # Interpolate
-      boxes, info = interpolate_unscanned_frames(
-        boxes, info, scan_every, inputs_shape)
+      boxes, info = interpolate_unscanned_frames(boxes=boxes, info=info, n_frames=n_frames)
    # Return
    return boxes, info
  def scan_batch(
@@ -298,11 +289,11 @@ def scan_batch(
    Returns:
      boxes: Scanned boxes in flat point form (n_frames, n_boxes, 4)
      classes: Detection scores for boxes (n_frames, n_boxes, 2)
-      scan_every: Scalar indicating that only every xth frame was scanned
+      idxs: Indices of the scanned frames from the original video
    """
    logging.debug("Batch {}/{}...".format(batch, n_batches))
    # Parse the inputs
-    inputs, fps, _, scan_every = parse_video_inputs(
+    inputs, fps, _, _, idxs = parse_video_inputs(
      video=inputs, fps=fps, target_size=INPUT_SIZE, target_fps=self.fs,
      library='prpy', scale_algorithm='bilinear', trim=(start, end))
    # Forward pass
@@ -311,5 +302,5 @@ def scan_batch(
    boxes = onnx_outputs[..., -4:]
    classes = onnx_outputs[..., 0:2]
    # Return
-    return boxes, classes, scan_every
+    return boxes, classes, idxs
\ No newline at end of file
diff --git a/vitallens/utils.py b/vitallens/utils.py
index d387292..99699fe 100644
--- a/vitallens/utils.py
+++ b/vitallens/utils.py
@@ -111,6 +111,7 @@ def parse_video_inputs(
    fps_in: Frame rate of original inputs
    shape_in: Shape of original inputs in form (n, h, w, c)
    ds_factor: Temporal downsampling factor applied
+    idxs: The frame indices returned from original video
  """
  # Check if input is array or file name
  if isinstance(video, str):
@@ -125,7 +126,10 @@ def parse_video_inputs(
      video, ds_factor = read_video_from_path(
        path=video, target_fps=target_fps, crop=roi, scale=target_size, trim=trim,
        pix_fmt='rgb24', dim_deltas=(1,1,1), scale_algorithm=scale_algorithm)
-      return video, fps, (n, h, w, 3), ds_factor
+      start_idx = max(0, trim[0]) if trim is not None else 0
+      end_idx = min(n, trim[1]) if trim is not None else n
+      idxs = list(range(start_idx, end_idx, ds_factor))
+      return video, fps, (n, h, w, 3), ds_factor, idxs
    except Exception as e:
      raise ValueError("Problem reading video from {}: {}".format(video, e))
  else:
@@ -146,7 +150,8 @@ def parse_video_inputs(
    video = crop_slice_resize(
      inputs=video, target_size=target_size, roi=roi, target_idxs=target_idxs,
      preserve_aspect_ratio=False, library=library, scale_algorithm=scale_algorithm)
-    return video, fps, video_shape_in, ds_factor
+    if target_idxs is None: target_idxs = list(range(video_shape_in[0]))
+    return video, fps, video_shape_in, ds_factor, target_idxs
  else:
    raise ValueError("Invalid video {}, type {}".format(video, type(video)))
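Illustrative sketch, not part of the patch: after this change parse_video_inputs returns a fifth value, idxs, listing the original-video indices of the frames it kept, and the detection info rows shrink to four columns [idx, scanned, scan_found_face, confidence] keyed by those original frame indices. The example values below are assumptions chosen to mirror the updated test_interpolate_unscanned_frames fixture.

    import numpy as np

    # Every 2nd frame of a 5-frame clip was scanned; entries are original-video indices.
    scan_idxs = np.array([0, 2, 4])
    # FaceDetector.__call__ now derives the scan stride from the largest index gap.
    scan_every = int(np.max(np.diff(scan_idxs)))  # -> 2
    # One info row per scanned frame and face: [idx, scanned, scan_found_face, confidence].
    info = np.array(
      [[[0, 1, 1, .99], [0, 1, 1, .99]],
       [[2, 1, 1, .99], [2, 1, 0, .2 ]],
       [[4, 1, 1, .99], [4, 1, 1, .99]]])
    assert info.shape == (3, 2, 4) and scan_every == 2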