Skip to content

Commit

Permalink
Merge pull request #29 from DeepLabCut/niels/sa_detectors
Browse files Browse the repository at this point in the history
SuperAnimal Model Updates
  • Loading branch information
n-poulsen authored Oct 11, 2024
2 parents b7d9dde + c4237cb commit 5a22726
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 11 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@ model_dir.mkdir()
download_huggingface_model("superanimal_quadruped", model_dir)
```

PyTorch models available for a given dataset (compatible with DeepLabCut>=3.0) can be
listed using the `dlclibrary.get_available_detectors` and
`dlclibrary.get_available_models` methods. Example use:

```python
>>> import dlclibrary
>>> dlclibrary.get_available_detectors("superanimal_bird")
['fasterrcnn_mobilenet_v3_large_fpn', 'ssdlite']

>>> dlclibrary.get_available_models("superanimal_bird")
['resnet_50']
```


## How to add a new model?

Pick a good model_name. Follow the (novel) naming convention (modeltype_species), e.g. ```superanimal_topviewmouse```.
Expand Down
2 changes: 2 additions & 0 deletions dlclibrary/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

from dlclibrary.dlcmodelzoo.modelzoo_download import (
download_huggingface_model,
get_available_detectors,
get_available_models,
parse_available_supermodels,
)
from dlclibrary.version import __version__, VERSION
68 changes: 57 additions & 11 deletions dlclibrary/dlcmodelzoo/modelzoo_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pathlib import Path

from huggingface_hub import hf_hub_download
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedBase

# just expand this list when adding new models:
Expand All @@ -27,12 +28,9 @@
"mouse_pupil_vclose",
"horse_sideview",
"full_macaque",
"superanimal_topviewmouse_dlcrnet",
"superanimal_quadruped_dlcrnet",
"superanimal_topviewmouse_hrnetw32",
"superanimal_quadruped_hrnetw32",
"superanimal_topviewmouse", # DeepLabCut 2.X backwards compatibility
"superanimal_quadruped", # DeepLabCut 2.X backwards compatibility
"superanimal_bird",
"superanimal_quadruped",
"superanimal_topviewmouse",
]


Expand All @@ -43,20 +41,66 @@ def _get_dlclibrary_path():
return os.path.split(importlib.util.find_spec("dlclibrary").origin)[0]


def _load_model_names():
def _load_pytorch_models() -> dict[str, dict[str, dict[str, str]]]:
"""Load URLs and commit hashes for available models."""
urls = Path(_get_dlclibrary_path()) / "dlcmodelzoo" / "modelzoo_urls_pytorch.yaml"
with open(urls) as file:
data = YAML(pure=True).load(file)

return data


def _load_pytorch_dataset_models(dataset: str) -> dict[str, dict[str, str]]:
"""Load URLs and commit hashes for available models."""
from ruamel.yaml import YAML
models = _load_pytorch_models()
if not dataset in models:
raise ValueError(
f"Could not find any models for {dataset}. Models are available for "
f"{list(models.keys())}"
)

return models[dataset]


def _load_model_names():
"""Load URLs and commit hashes for available models."""
fn = os.path.join(_get_dlclibrary_path(), "dlcmodelzoo", "modelzoo_urls.yaml")
with open(fn) as file:
return YAML().load(file)
model_names = YAML().load(file)

# add PyTorch models
for dataset, model_types in _load_pytorch_models().items():
for model_type, models in model_types.items():
for model, url in models.items():
model_names[f"{dataset}_{model}"] = url

return model_names


def parse_available_supermodels():
libpath = _get_dlclibrary_path()
json_path = os.path.join(libpath, "dlcmodelzoo", "superanimal_models.json")
with open(json_path) as file:
return json.load(file)
super_animal_models = json.load(file)
return super_animal_models


def get_available_detectors(dataset: str) -> list[str]:
""" Only for PyTorch models.
Returns:
The detectors available for the dataset.
"""
return list(_load_pytorch_dataset_models(dataset)["detectors"].keys())


def get_available_models(dataset: str) -> list[str]:
""" Only for PyTorch models.
Returns:
The pose models available for the dataset.
"""
return list(_load_pytorch_dataset_models(dataset)["pose_models"].keys())


def _handle_downloaded_file(
Expand Down Expand Up @@ -103,7 +147,9 @@ def download_huggingface_model(
"""
net_urls = _load_model_names()
if model_name not in net_urls:
raise ValueError(f"`modelname` should be one of: {', '.join(net_urls)}.")
raise ValueError(
f"`modelname={model_name}` should be one of: {', '.join(net_urls)}."
)

print("Loading....", model_name)
urls = net_urls[model_name]
Expand Down
24 changes: 24 additions & 0 deletions dlclibrary/dlcmodelzoo/modelzoo_urls_pytorch.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# DeepLabCut 3.0: SuperAnimal detectors and pose model URLS

superanimal_bird:
detectors:
fasterrcnn_mobilenet_v3_large_fpn: DeepLabCut/DeepLabCutModelZoo-SuperAnimal-Bird/superanimal_bird_fasterrcnn_mobilenet_v3_large_fpn.pt
ssdlite: DeepLabCut/DeepLabCutModelZoo-SuperAnimal-Bird/superanimal_bird_ssdlite.pt
pose_models:
resnet_50: DeepLabCut/DeepLabCutModelZoo-SuperAnimal-Bird/superanimal_bird_resnet_50.pt

superanimal_topviewmouse:
detectors:
fasterrcnn_mobilenet_v3_large_fpn: mwmathis/DeepLabCutModelZoo-SuperAnimal-TopViewMouse/superanimal_topviewmouse_fasterrcnn_mobilenet_v3_large_fpn.pt
fasterrcnn_resnet50_fpn_v2: mwmathis/DeepLabCutModelZoo-SuperAnimal-TopViewMouse/superanimal_topviewmouse_fasterrcnn_resnet50_fpn_v2.pt
pose_models:
hrnet_w32: mwmathis/DeepLabCutModelZoo-SuperAnimal-TopViewMouse/superanimal_topviewmouse_hrnet_w32.pt
resnet_50: mwmathis/DeepLabCutModelZoo-SuperAnimal-TopViewMouse/superanimal_topviewmouse_resnet_50.pt

superanimal_quadruped:
detectors:
fasterrcnn_mobilenet_v3_large_fpn: mwmathis/DeepLabCutModelZoo-SuperAnimal-Quadruped/superanimal_quadruped_fasterrcnn_mobilenet_v3_large_fpn.pt
fasterrcnn_resnet50_fpn_v2: mwmathis/DeepLabCutModelZoo-SuperAnimal-Quadruped/superanimal_quadruped_fasterrcnn_resnet50_fpn_v2.pt
pose_models:
hrnet_w32: mwmathis/DeepLabCutModelZoo-SuperAnimal-Quadruped/superanimal_quadruped_hrnet_w32.pt
resnet_50: mwmathis/DeepLabCutModelZoo-SuperAnimal-Quadruped/superanimal_quadruped_resnet_50.pt
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"dlclibrary",
[
"dlclibrary/dlcmodelzoo/modelzoo_urls.yaml",
"dlclibrary/dlcmodelzoo/modelzoo_urls_pytorch.yaml",
"dlclibrary/dlcmodelzoo/superanimal_models.json",
"dlclibrary/dlcmodelzoo/superanimal_configs/superquadruped.yaml",
"dlclibrary/dlcmodelzoo/superanimal_configs/supertopview.yaml",
Expand Down
47 changes: 47 additions & 0 deletions tests/test_pytorch_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#
# DeepLabCut Toolbox (deeplabcut.org)
# © A. & M.W. Mathis Labs
# https://github.com/DeepLabCut/DeepLabCut
#
# Please see AUTHORS for contributors.
# https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
#
# Licensed under GNU Lesser General Public License v3.0
#
import os
import pytest

import dlclibrary
import dlclibrary.dlcmodelzoo.modelzoo_download as modelzoo


@pytest.mark.parametrize(
"data",
[
("superanimal_bird", ["ssdlite"]),
("superanimal_topviewmouse", ["fasterrcnn_resnet50_fpn_v2"]),
("superanimal_quadruped", ["fasterrcnn_resnet50_fpn_v2"]),
]
)
def test_get_super_animal_detectors(data: tuple[str, list[str]]):
dataset, expected_detectors = data
detectors = modelzoo.get_available_detectors(dataset)
assert len(detectors) >= len(expected_detectors)
for det in expected_detectors:
assert det in detectors


@pytest.mark.parametrize(
"data",
[
("superanimal_bird", ["resnet_50"]),
("superanimal_topviewmouse", ["hrnet_w32"]),
("superanimal_quadruped", ["hrnet_w32"]),
]
)
def test_get_super_animal_pose_models(data: tuple[str, list[str]]):
dataset, expected_pose_models = data
pose_models = modelzoo.get_available_models(dataset)
assert len(pose_models) >= len(expected_pose_models)
for pose_model in expected_pose_models:
assert pose_model in pose_models

0 comments on commit 5a22726

Please sign in to comment.