Skip to content

Commit

Permalink
Add tests for TriangulateSession
Browse files Browse the repository at this point in the history
  • Loading branch information
roomrys committed Oct 13, 2023
1 parent b53c103 commit 9771ff5
Showing 1 changed file with 350 additions and 6 deletions.
356 changes: 350 additions & 6 deletions tests/gui/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import sys
import time
from pathlib import Path, PurePath
from typing import List
from typing import Dict, List
import numpy as np

import pytest

Expand All @@ -17,9 +18,11 @@
RemoveVideo,
ReplaceVideo,
SaveProjectAs,
TriangulateSession,
get_new_version_filename,
)
from sleap.instance import Instance, LabeledFrame
from sleap.io.cameras import Camcorder
from sleap.io.convert import default_analysis_filename
from sleap.io.dataset import Labels
from sleap.io.format.adaptor import Adaptor
Expand All @@ -28,11 +31,11 @@
from sleap.io.video import Video
from sleap.util import get_package_file

# These imports cause trouble when running `pytest.main()` from within the file
# Comment out to debug tests file via VSCode's "Debug Python File"
from tests.info.test_h5 import extract_meta_hdf5
from tests.io.test_formats import read_nix_meta
from tests.io.test_video import assert_video_params
# # These imports cause trouble when running `pytest.main()` from within the file
# # Comment out to debug tests file via VSCode's "Debug Python File"
# from tests.info.test_h5 import extract_meta_hdf5
# from tests.io.test_formats import read_nix_meta
# from tests.io.test_video import assert_video_params


def test_delete_user_dialog(centered_pair_predictions):
Expand Down Expand Up @@ -952,3 +955,344 @@ def test_AddSession(
assert len(labels.sessions) == 2
assert context.state["session"] is session
assert labels.sessions[1] is not session


def test_triangulate_session_get_all_views_at_frame(
multiview_min_session_labels: Labels,
):
labels = multiview_min_session_labels
session = labels.sessions[0]
lf = labels.labeled_frames[0]
frame_idx = lf.frame_idx

# Test with no cams_to_include, expect views from all linked cameras
views = TriangulateSession.get_all_views_at_frame(session, frame_idx)
assert len(views) == len(session.linked_cameras)
for cam in session.linked_cameras:
assert views[cam].frame_idx == frame_idx
assert views[cam].video == session[cam]

# Test with cams_to_include, expect views from only those cameras
cams_to_include = session.linked_cameras[0:2]
views = TriangulateSession.get_all_views_at_frame(
session, frame_idx, cams_to_include=cams_to_include
)
assert len(views) == len(cams_to_include)
for cam in cams_to_include:
assert views[cam].frame_idx == frame_idx
assert views[cam].video == session[cam]


def test_triangulate_session_get_instances_across_views(
multiview_min_session_labels: Labels,
):

labels = multiview_min_session_labels
session = labels.sessions[0]

# Test get_instances_across_views
lf: LabeledFrame = labels[0]
track = labels.tracks[0]
instances: Dict[
Camcorder, Instance
] = TriangulateSession.get_instances_across_views(
session=session, frame_idx=lf.frame_idx, track=track
)
assert len(instances) == len(session.videos)
for vid in session.videos:
cam = session[vid]
inst = instances[cam]
assert inst.frame_idx == lf.frame_idx
assert inst.track == track
assert inst.video == vid

# Try with excluding cam views
lf: LabeledFrame = labels[2]
track = labels.tracks[1]
cams_to_include = session.linked_cameras[:4]
videos_to_include: Dict[
Camcorder, Video
] = session.get_videos_from_selected_cameras(cams_to_include=cams_to_include)
assert len(cams_to_include) == 4
assert len(videos_to_include) == len(cams_to_include)
instances: Dict[
Camcorder, Instance
] = TriangulateSession.get_instances_across_views(
session=session,
frame_idx=lf.frame_idx,
track=track,
cams_to_include=cams_to_include,
)
assert len(instances) == len(
videos_to_include
) # May not be true if no instances at that frame
for cam, vid in videos_to_include.items():
inst = instances[cam]
assert inst.frame_idx == lf.frame_idx
assert inst.track == track
assert inst.video == vid

# Try with only a single view
cams_to_include = [session.linked_cameras[0]]
with pytest.raises(ValueError):
instances = TriangulateSession.get_instances_across_views(
session=session,
frame_idx=lf.frame_idx,
cams_to_include=cams_to_include,
track=track,
require_multiple_views=True,
)

# Try with multiple views, but not enough instances
track = labels.tracks[1]
cams_to_include = session.linked_cameras[4:6]
with pytest.raises(ValueError):
instances = TriangulateSession.get_instances_across_views(
session=session,
frame_idx=lf.frame_idx,
cams_to_include=cams_to_include,
track=track,
require_multiple_views=True,
)


def test_triangulate_session_get_and_verify_enough_instances(
multiview_min_session_labels: Labels,
caplog,
):
labels = multiview_min_session_labels
session = labels.sessions[0]
lf = labels.labeled_frames[0]
track = labels.tracks[1]

# Test with no cams_to_include, expect views from all linked cameras
instances = TriangulateSession.get_and_verify_enough_instances(
session=session, frame_idx=lf.frame_idx, track=track
)
assert len(instances) == 6 # Some views don't have an instance at this track
for cam in session.linked_cameras:
if cam.name in ["side", "sideL"]: # The views that don't have an instance
continue
assert instances[cam].frame_idx == lf.frame_idx
assert instances[cam].track == track
assert instances[cam].video == session[cam]

# Test with cams_to_include, expect views from only those cameras
cams_to_include = session.linked_cameras[-2:]
instances = TriangulateSession.get_and_verify_enough_instances(
session=session,
frame_idx=lf.frame_idx,
cams_to_include=cams_to_include,
track=track,
)
assert len(instances) == len(cams_to_include)
for cam in cams_to_include:
assert instances[cam].frame_idx == lf.frame_idx
assert instances[cam].track == track
assert instances[cam].video == session[cam]

# Test with not enough instances, expect views from only those cameras
cams_to_include = session.linked_cameras[0:2]
instances = TriangulateSession.get_and_verify_enough_instances(
session=session, frame_idx=lf.frame_idx, cams_to_include=cams_to_include
)
assert isinstance(instances, bool)
assert not instances
messages = "".join([rec.message for rec in caplog.records])
assert "One or less instances found for frame" in messages


def test_triangulate_session_verify_enough_views(
multiview_min_session_labels: Labels, caplog
):
labels = multiview_min_session_labels
session = labels.sessions[0]

# Test with enough views
enough_views = TriangulateSession.verify_enough_views(
session=session, show_dialog=False
)
assert enough_views
messages = "".join([rec.message for rec in caplog.records])
assert len(messages) == 0
caplog.clear()

# Test with not enough views
cams_to_include = [session.linked_cameras[0]]
enough_views = TriangulateSession.verify_enough_views(
session=session, cams_to_include=cams_to_include, show_dialog=False
)
assert not enough_views
messages = "".join([rec.message for rec in caplog.records])
assert "One or less cameras available." in messages


def test_triangulate_session_verify_views_and_instances(
multiview_min_session_labels: Labels,
):
labels = multiview_min_session_labels
session = labels.sessions[0]

# Test with enough views and instances
lf = labels.labeled_frames[0]
instance = lf.instances[0]

context = CommandContext.from_labels(labels)
params = {
"video": session.videos[0],
"session": session,
"frame_idx": lf.frame_idx,
"instance": instance,
"show_dialog": False,
}
enough_views = TriangulateSession.verify_views_and_instances(context, params)
assert enough_views
assert "instances" in params

# Test with not enough views
cams_to_include = [session.linked_cameras[0]]
params = {
"video": session.videos[0],
"session": session,
"frame_idx": lf.frame_idx,
"instance": instance,
"cams_to_include": cams_to_include,
"show_dialog": False,
}
enough_views = TriangulateSession.verify_views_and_instances(context, params)
assert not enough_views
assert "instances" not in params


def test_triangulate_session_calculate_reprojected_points(
multiview_min_session_labels: Labels,
):
"""Test `TriangulateSession.calculate_reprojected_points`."""

session = multiview_min_session_labels.sessions[0]
lf: LabeledFrame = multiview_min_session_labels[0]
track = multiview_min_session_labels.tracks[0]
instances: Dict[
Camcorder, Instance
] = TriangulateSession.get_instances_across_views(
session=session, frame_idx=lf.frame_idx, track=track
)
instances_and_coords = TriangulateSession.calculate_reprojected_points(
session=session, instances=instances
)

# Check that we get the same number of instances as input
assert len(instances) == len(list(instances_and_coords))

# Check that each instance has the same number of points
for inst, inst_coords in instances_and_coords:
assert inst_coords.shape[1] == len(inst.skeleton) # (1, 15, 2)


def test_triangulate_session_get_instances_matrices(
multiview_min_session_labels: Labels,
):
"""Test `TriangulateSession.get_instance_matrices`."""
labels = multiview_min_session_labels
session = labels.sessions[0]
lf: LabeledFrame = labels[0]
track = labels.tracks[0]
instances: Dict[
Camcorder, Instance
] = TriangulateSession.get_instances_across_views(
session=session, frame_idx=lf.frame_idx, track=track
)
instances_matrices = TriangulateSession.get_instances_matrices(
instances_ordered=instances.values()
)

# Verify shape
n_views = len(instances)
n_frames = 1
n_tracks = 1
n_nodes = len(labels.skeleton)
assert instances_matrices.shape == (n_views, n_frames, n_tracks, n_nodes, 2)


def test_triangulate_session_update_instances(multiview_min_session_labels: Labels):
"""Test `RecordingSession.update_instances`."""

# Test update_instances
session = multiview_min_session_labels.sessions[0]
lf: LabeledFrame = multiview_min_session_labels[0]
track = multiview_min_session_labels.tracks[0]
instances: Dict[
Camcorder, Instance
] = TriangulateSession.get_instances_across_views(
session=session, frame_idx=lf.frame_idx, track=track
)
instances_and_coordinates = TriangulateSession.calculate_reprojected_points(
session=session, instances=instances
)
for inst, inst_coords in instances_and_coordinates:
assert inst_coords.shape == (1, len(inst.skeleton), 2) # Tracks, Nodes, 2
# Assert coord are different from original
assert not np.array_equal(inst_coords, inst.points_array)

# Just run for code coverage testing, do not test output here (race condition)
# (see "functional core, imperative shell" pattern)
TriangulateSession.update_instances(session=session, instances=instances)


def test_triangulate_session_do_action(multiview_min_session_labels: Labels):
"""Test `TriangulateSession.do_action`."""

labels = multiview_min_session_labels
session = labels.sessions[0]

# Test with enough views and instances
lf = labels.labeled_frames[0]
instance = lf.instances[0]

context = CommandContext.from_labels(labels)
params = {
"video": session.videos[0],
"session": session,
"frame_idx": lf.frame_idx,
"instance": instance,
"ask_again": True,
}
TriangulateSession.do_action(context, params)

# Test with not enough views
cams_to_include = [session.linked_cameras[0]]
params = {
"video": session.videos[0],
"session": session,
"frame_idx": lf.frame_idx,
"instance": instance,
"cams_to_include": cams_to_include,
"ask_again": True,
}
TriangulateSession.do_action(context, params)


def test_triangulate_session(multiview_min_session_labels: Labels):
"""Test `TriangulateSession`, if"""

labels = multiview_min_session_labels
session = labels.sessions[0]
video = session.videos[0]
lf = labels.labeled_frames[0]
instance = lf.instances[0]
context = CommandContext.from_labels(labels)

# Test with enough views and instances so we don't get any GUI pop-ups
context.triangulateSession(
frame_idx=lf.frame_idx,
video=video,
instance=instance,
session=session,
)

# Test with using state to gather params
context.state["session"] = session
context.state["video"] = video
context.state["instance"] = instance
context.state["frame_idx"] = lf.frame_idx
context.triangulateSession()

0 comments on commit 9771ff5

Please sign in to comment.