Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow max tracking args for Kalman filter #1986

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
cef0da6
and note where `target_instance_count` is initialized
eberrigan Oct 8, 2024
dea7369
`target_instance_count` is not available in the GUI but `max_tracks` is
eberrigan Oct 8, 2024
43c0f0a
add note where `target_instance_count` is initialized
eberrigan Oct 8, 2024
60a40ab
add note since neither `target_instance_count` nor `pre_cull_to_targe…
eberrigan Oct 8, 2024
06ab653
accept either max_tracks or target_instance_count for compatibility w…
eberrigan Oct 8, 2024
4c70d27
TypeError: track() got an unexpected keyword argument 'img_hw' since …
eberrigan Oct 8, 2024
474d0d6
useful print statements
eberrigan Oct 8, 2024
1eafd5f
black
eberrigan Oct 8, 2024
1d6ed7c
np.bool is deprecated
eberrigan Dec 16, 2024
a3e64a9
debug
eberrigan Dec 17, 2024
24a442d
add params for testing kalman filter
eberrigan Dec 17, 2024
96a08cc
remove params because this function isn't used
eberrigan Dec 17, 2024
bd364cd
debugging
eberrigan Dec 17, 2024
8722813
test kalman filter tracking
eberrigan Dec 17, 2024
0fcce5c
add documentation
eberrigan Dec 18, 2024
7322b3c
kalman filter needs node indices, simple tracking and similarity anyt…
eberrigan Dec 18, 2024
266d8ae
add tests for every combination related to kalman args
eberrigan Dec 18, 2024
00cbdaf
add example to documentation
eberrigan Dec 18, 2024
6482f13
delete debug scripts
eberrigan Dec 18, 2024
cbf5ca8
delete print statements
eberrigan Dec 18, 2024
c29921b
black
eberrigan Dec 18, 2024
fa5a82b
add test for connect single breaks
eberrigan Dec 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion docs/guides/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ optional arguments:
--tracking.kf_node_indices TRACKING.KF_NODE_INDICES
For Kalman filter: Indices of nodes to track. (default: )
--tracking.kf_init_frame_count TRACKING.KF_INIT_FRAME_COUNT
For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used. (default: 0)
For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used. (default: 0) Kalman filters require TRACKING.KF_NODE_INDICES, TRACKING.MAX_TRACKING and TRACKING.MAX_TRACKS or TRACKING.TARGET_INSTANCE_COUNT, TRACKING.TRACKER to be simple or simplemaxtracks, and TRACKING.SIMILARITY to not be normalized_instance.
```

#### Examples:
Expand Down Expand Up @@ -285,6 +285,12 @@ sleap-track --gpu 1 ...
sleap-track -m "models/my_model" --frames 1000-2000 "input_video.mp4"
```

**9. Use Kalman tracker (not recommended since flow is preferred):**

```none
sleap-track -m "models/my_model" --tracking.similarity instance --tracking.tracker simplemaxtracks --tracking.max_tracking 1 --tracking.max_tracks 4 --tracking.kf_init_frame_count 10 --tracking.kf_node_indices 0,1 -o "output_predictions.slp" "input_video.mp4"
```

## Dataset files

(sleap-convert)=
Expand Down
12 changes: 9 additions & 3 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1129,9 +1129,11 @@ def export_model(
info["predicted_tensors"] = tensors

full_model = tf.function(
lambda x: sleap.nn.data.utils.unrag_example(model(x), numpy=False)
if unrag_outputs
else model(x)
lambda x: (
sleap.nn.data.utils.unrag_example(model(x), numpy=False)
if unrag_outputs
else model(x)
)
)

full_model = full_model.get_concrete_function(
Expand Down Expand Up @@ -5717,3 +5719,7 @@ def main(args: Optional[list] = None):
"To retrack on predictions, must specify tracker. "
"Use \"sleap-track --tracking.tracker ...' to specify tracker to use."
)


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion sleap/nn/tracker/kalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ def remove_second_bests_from_cost_matrix(
cost matrix with invalid matches set to specified invalid value.
"""

valid_match_mask = np.full_like(cost_matrix, True, dtype=np.bool)
valid_match_mask = np.full_like(cost_matrix, True, dtype=bool)

rows, columns = cost_matrix.shape

Expand Down
54 changes: 46 additions & 8 deletions sleap/nn/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ class Tracker(BaseTracker):
max_tracking: bool = False # To enable maximum tracking.

cleaner: Optional[Callable] = None # TODO: deprecate
target_instance_count: int = 0
target_instance_count: int = 0 # TODO: deprecate
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Implement proper deprecation warning for target_instance_count.

Since target_instance_count is marked for deprecation but still being used, implement a proper deprecation warning to notify users.

Add this warning in the __init__ method:

def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    if self.target_instance_count:
        import warnings
        warnings.warn(
            "target_instance_count is deprecated and will be removed in a future version. "
            "Use max_tracks instead.",
            DeprecationWarning,
            stacklevel=2
        )

pre_cull_function: Optional[Callable] = None
post_connect_single_breaks: bool = False
robust_best_instance: float = 1.0
Expand Down Expand Up @@ -824,8 +824,15 @@ def final_pass(self, frames: List[LabeledFrame]):
# "tracking."
# )
self.cleaner.run(frames)
elif self.target_instance_count and self.post_connect_single_breaks:
elif (
self.target_instance_count or self.max_tracks
) and self.post_connect_single_breaks:
if not self.target_instance_count:
# If target_instance_count is not set, use max_tracks instead
# target_instance_count not available in the GUI
self.target_instance_count = self.max_tracks
connect_single_track_breaks(frames, self.target_instance_count)
print("Connecting single track breaks.")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Use logging instead of print statements for better practice.

The print statement for status messages can be replaced with the logging module to provide configurable logging levels and outputs.

Apply this diff:

+import logging
...
+logger = logging.getLogger(__name__)
...
-print("Connecting single track breaks.")
+logger.info("Connecting single track breaks.")

Committable suggestion was skipped due to low confidence.


def get_name(self):
tracker_name = self.candidate_maker.__class__.__name__
Expand All @@ -850,7 +857,7 @@ def make_tracker_by_name(
of_max_levels: int = 3,
save_shifted_instances: bool = False,
# Pre-tracking options to cull instances
target_instance_count: int = 0,
target_instance_count: int = 0, # TODO: deprecate target_instance_count
pre_cull_to_target: bool = False,
pre_cull_iou_threshold: Optional[float] = None,
# Post-tracking options to connect broken tracks
Expand Down Expand Up @@ -921,6 +928,7 @@ def make_tracker_by_name(

pre_cull_function = None
if target_instance_count and pre_cull_to_target:
# Right now this is not accessible from the GUI

def pre_cull_function(inst_list):
cull_frame_instances(
Expand All @@ -940,11 +948,34 @@ def pre_cull_function(inst_list):
pre_cull_function=pre_cull_function,
max_tracking=max_tracking,
max_tracks=max_tracks,
target_instance_count=target_instance_count,
target_instance_count=target_instance_count, # TODO: deprecate target_instance_count
post_connect_single_breaks=post_connect_single_breaks,
)

if target_instance_count and kf_init_frame_count:
# Kalman filter requires deprecated target_instance_count
if (max_tracks or target_instance_count) and kf_init_frame_count:
if not kf_node_indices:
raise ValueError(
"Kalman filter requires node indices for instance tracking."
)

if tracker == "flow" or tracker == "flowmaxtracks":
# Tracking with Kalman filter requires initial tracker object to be simple
raise ValueError(
"Kalman filter requires simple tracker for initial tracking."
)

if similarity == "normalized_instance":
# Kalman filter doesnot support normalized_instance_similarity
raise ValueError(
"Kalman filter does not support normalized_instance_similarity."
)

if not target_instance_count:
# If target_instance_count is not set, use max_tracks instead
# target_instance_count not available in the GUI
target_instance_count = max_tracks

kalman_obj = KalmanTracker.make_tracker(
init_tracker=tracker_obj,
init_frame_count=kf_init_frame_count,
Expand All @@ -954,8 +985,10 @@ def pre_cull_function(inst_list):
)

return kalman_obj
elif kf_init_frame_count and not target_instance_count:
raise ValueError("Kalman filter requires target instance count.")
elif kf_init_frame_count and not (max_tracks or target_instance_count):
raise ValueError(
"Kalman filter requires max tracks or target instance count."
)
else:
return tracker_obj

Expand Down Expand Up @@ -1369,6 +1402,10 @@ def cull_function(inst_list):
if init_tracker.pre_cull_function is None:
init_tracker.pre_cull_function = cull_function

print(
f"Using {init_tracker.get_name()} to track {init_frame_count} frames for Kalman filters."
)

return cls(
init_tracker=init_tracker,
kalman_tracker=kalman_tracker,
Expand All @@ -1386,6 +1423,7 @@ def track(
untracked_instances: List[InstanceType],
img: Optional[np.ndarray] = None,
t: int = None,
**kwargs,
) -> List[InstanceType]:
"""Tracks individual frame, using Kalman filters if possible."""

Expand Down Expand Up @@ -1420,7 +1458,7 @@ def track(
# Initialize the Kalman filters
self.kalman_tracker.init_filters(self.init_set.instances)

# print(f"Kalman filters initialized (frame {t})")
print(f"Kalman filters initialized (frame {t})")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Use logging instead of print statements for initialization messages.

Replace the print statement with the logging module to improve flexibility and control over log outputs.

Apply this diff:

+import logging
...
+logger = logging.getLogger(__name__)
...
-print(f"Kalman filters initialized (frame {t})")
+logger.info(f"Kalman filters initialized (frame {t})")

Committable suggestion was skipped due to low confidence.


# Clear the data used to init filters, so that if the filters
# stop tracking and we need to re-init, we won't re-use the
Expand Down
194 changes: 193 additions & 1 deletion tests/nn/test_tracking_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,205 @@
import operator
import os
import time

import pytest
import sleap
from sleap.nn.inference import main as inference_cli
import sleap.nn.tracker.components
from sleap.io.dataset import Labels, LabeledFrame


similarity_args = [
"instance",
"normalized_instance",
"object_keypoint",
"centroid",
"iou",
]
match_args = ["hungarian", "greedy"]


@pytest.mark.parametrize(
"tracker_name", ["simple", "simplemaxtracks", "flow", "flowmaxtracks"]
)
@pytest.mark.parametrize("similarity", similarity_args)
@pytest.mark.parametrize("match", match_args)
def test_kalman_tracker(
tmpdir, centered_pair_predictions_slp_path, tracker_name, similarity, match
):

if tracker_name == "flow" or tracker_name == "flowmaxtracks":
# Expecting ValueError for "flow" or "flowmaxtracks" due to Kalman filter requiring a simple tracker
with pytest.raises(
ValueError,
match="Kalman filter requires simple tracker for initial tracking.",
):
cli = (
f"--tracking.tracker {tracker_name} "
"--tracking.max_tracking 1 --tracking.max_tracks 2 "
f"--tracking.similarity {similarity} "
f"--tracking.match {match} "
"--tracking.track_window 5 "
"--tracking.kf_init_frame_count 10 "
"--tracking.kf_node_indices 0,1 "
f"-o {tmpdir}/{tracker_name}.slp "
f"{centered_pair_predictions_slp_path}"
)
inference_cli(cli.split(" "))
else:
# For simple or simplemaxtracks, continue with other tests
# Check for ValueError when similarity is "normalized_instance"
if similarity == "normalized_instance":
with pytest.raises(
ValueError,
match="Kalman filter does not support normalized_instance_similarity.",
):
cli = (
f"--tracking.tracker {tracker_name} "
"--tracking.max_tracking 1 --tracking.max_tracks 2 "
f"--tracking.similarity {similarity} "
f"--tracking.match {match} "
"--tracking.track_window 5 "
"--tracking.kf_init_frame_count 10 "
"--tracking.kf_node_indices 0,1 "
f"-o {tmpdir}/{tracker_name}.slp "
f"{centered_pair_predictions_slp_path}"
)
inference_cli(cli.split(" "))
return

# Check for ValueError when kf_node_indices is None which is the default
with pytest.raises(
ValueError,
match="Kalman filter requires node indices for instance tracking.",
):
cli = (
f"--tracking.tracker {tracker_name} "
"--tracking.max_tracking 1 --tracking.max_tracks 2 "
f"--tracking.similarity {similarity} "
f"--tracking.match {match} "
"--tracking.track_window 5 "
"--tracking.kf_init_frame_count 10 "
f"-o {tmpdir}/{tracker_name}.slp "
f"{centered_pair_predictions_slp_path}"
)
inference_cli(cli.split(" "))

# Test for missing max_tracks and target_instance_count with kf_init_frame_count
with pytest.raises(
ValueError,
match="Kalman filter requires max tracks or target instance count.",
):
cli = (
f"--tracking.tracker {tracker_name} "
f"--tracking.similarity {similarity} "
f"--tracking.match {match} "
"--tracking.track_window 5 "
"--tracking.kf_init_frame_count 10 "
"--tracking.kf_node_indices 0,1 "
f"-o {tmpdir}/{tracker_name}.slp "
f"{centered_pair_predictions_slp_path}"
)
inference_cli(cli.split(" "))

# Test with target_instance_count and without max_tracks
cli = (
f"--tracking.tracker {tracker_name} "
f"--tracking.similarity {similarity} "
f"--tracking.match {match} "
"--tracking.track_window 5 "
"--tracking.kf_init_frame_count 10 "
"--tracking.kf_node_indices 0,1 "
"--tracking.target_instance_count 2 "
f"-o {tmpdir}/{tracker_name}_target_instance_count.slp "
f"{centered_pair_predictions_slp_path}"
)
inference_cli(cli.split(" "))

labels = sleap.load_file(f"{tmpdir}/{tracker_name}_target_instance_count.slp")
assert len(labels.tracks) == 2

# Test with target_instance_count and with max_tracks
cli = (
f"--tracking.tracker {tracker_name} "
"--tracking.max_tracking 1 --tracking.max_tracks 2 "
f"--tracking.similarity {similarity} "
f"--tracking.match {match} "
"--tracking.track_window 5 "
"--tracking.kf_init_frame_count 10 "
"--tracking.kf_node_indices 0,1 "
"--tracking.target_instance_count 2 "
f"-o {tmpdir}/{tracker_name}_max_tracks_target_instance_count.slp "
f"{centered_pair_predictions_slp_path}"
)
inference_cli(cli.split(" "))

labels = sleap.load_file(
f"{tmpdir}/{tracker_name}_max_tracks_target_instance_count.slp"
)
assert len(labels.tracks) == 2

# Test with "--tracking.pre_cull_iou_threshold", "0.8"
cli = (
f"--tracking.tracker {tracker_name} "
"--tracking.max_tracking 1 --tracking.max_tracks 2 "
f"--tracking.similarity {similarity} "
f"--tracking.match {match} "
"--tracking.track_window 5 "
"--tracking.kf_init_frame_count 10 "
"--tracking.kf_node_indices 0,1 "
"--tracking.target_instance_count 2 "
"--tracking.pre_cull_iou_threshold 0.8 "
f"-o {tmpdir}/{tracker_name}_max_tracks_target_instance_count_iou.slp "
f"{centered_pair_predictions_slp_path}"
)
inference_cli(cli.split(" "))

labels = sleap.load_file(
f"{tmpdir}/{tracker_name}_max_tracks_target_instance_count_iou.slp"
)
assert len(labels.tracks) == 2

# Test with "--tracking.pre_cull_to_target", "1"
cli = (
f"--tracking.tracker {tracker_name} "
"--tracking.max_tracking 1 --tracking.max_tracks 2 "
f"--tracking.similarity {similarity} "
f"--tracking.match {match} "
"--tracking.track_window 5 "
"--tracking.kf_init_frame_count 10 "
"--tracking.kf_node_indices 0,1 "
"--tracking.target_instance_count 2 "
"--tracking.pre_cull_to_target 1 "
f"-o {tmpdir}/{tracker_name}_max_tracks_target_instance_count_to_target.slp "
f"{centered_pair_predictions_slp_path}"
)
inference_cli(cli.split(" "))
labels = sleap.load_file(
f"{tmpdir}/{tracker_name}_max_tracks_target_instance_count_to_target.slp"
)
assert len(labels.tracks) == 2

# Test with 'tracking.post_connect_single_breaks': 0
cli = (
f"--tracking.tracker {tracker_name} "
"--tracking.max_tracking 1 --tracking.max_tracks 2 "
f"--tracking.similarity {similarity} "
f"--tracking.match {match} "
"--tracking.track_window 5 "
"--tracking.kf_init_frame_count 10 "
"--tracking.kf_node_indices 0,1 "
"--tracking.target_instance_count 2 "
"--tracking.post_connect_single_breaks 0 "
f"-o {tmpdir}/{tracker_name}_max_tracks_target_instance_count_single_breaks.slp "
f"{centered_pair_predictions_slp_path}"
)
inference_cli(cli.split(" "))
labels = sleap.load_file(
f"{tmpdir}/{tracker_name}_max_tracks_target_instance_count_single_breaks.slp"
)
assert len(labels.tracks) == 2


def test_simple_tracker(tmpdir, centered_pair_predictions_slp_path):
cli = (
"--tracking.tracker simple "
Expand Down
Loading