From b842660eed1d559f9af4bb20ef1869f7da7e7488 Mon Sep 17 00:00:00 2001 From: Isaac Robinson Date: Fri, 2 Jun 2023 16:33:29 -0400 Subject: [PATCH 1/7] Make sleap functions that don't need the model not load it. --- diplomat/__init__.py | 3 +- diplomat/_cli_runner.py | 2 +- diplomat/core_ops.py | 50 +++++++--- diplomat/frontend_ops.py | 2 +- diplomat/frontends/__init__.py | 91 +++++++++++++++++-- diplomat/frontends/deeplabcut/__init__.py | 6 +- diplomat/frontends/sleap/__init__.py | 6 +- diplomat/frontends/sleap/_verify_func.py | 7 +- .../frontends/sleap/convert_results_sleap.py | 9 +- .../frontends/sleap/label_videos_sleap.py | 8 +- .../frontends/sleap/predict_frames_sleap.py | 13 ++- diplomat/frontends/sleap/run_utils.py | 12 +++ diplomat/frontends/sleap/sleap_providers.py | 11 ++- .../frontends/sleap/tweak_results_sleap.py | 7 +- diplomat/utils/lazy_import.py | 4 +- docs/source/ext/plugin_docgen.py | 6 +- pyproject.toml | 4 +- 17 files changed, 174 insertions(+), 67 deletions(-) diff --git a/diplomat/__init__.py b/diplomat/__init__.py index 51232c1..a69cae9 100644 --- a/diplomat/__init__.py +++ b/diplomat/__init__.py @@ -61,7 +61,7 @@ def _load_frontends(): if(hasattr(frontend, "__doc__")): mod.__doc__ = frontend.__doc__ - for (name, func) in asdict(res).items(): + for (name, func) in res: if(not name.startswith("_")): func = replace_function_name_and_module(func, name, mod.__name__) setattr(mod, name, func) @@ -69,4 +69,5 @@ def _load_frontends(): return frontends, loaded_funcs + _FRONTENDS, _LOADED_FRONTENDS = _load_frontends() diff --git a/diplomat/_cli_runner.py b/diplomat/_cli_runner.py index 819599b..81ea3d8 100644 --- a/diplomat/_cli_runner.py +++ b/diplomat/_cli_runner.py @@ -42,7 +42,7 @@ def get_dynamic_cli_tree() -> dict: for frontend_name, funcs in diplomat._LOADED_FRONTENDS.items(): frontend_commands = { - name: func for name, func in asdict(funcs).items() if(not name.startswith("_")) + name: func for name, func in funcs if(not name.startswith("_")) } doc_str = getattr(getattr(diplomat, frontend_name), "__doc__", None) diff --git a/diplomat/core_ops.py b/diplomat/core_ops.py index 7c6ae22..286859c 100644 --- a/diplomat/core_ops.py +++ b/diplomat/core_ops.py @@ -7,6 +7,7 @@ import typing from types import ModuleType from diplomat.utils.tweak_ui import UIImportError +from diplomat.frontends import DIPLOMATContract, DIPLOMATCommands class ArgumentError(CLIError): @@ -47,14 +48,21 @@ def _get_casted_args(tc_func, extra_args, error_on_miss=True): return new_args -def _find_frontend(config: os.PathLike, **kwargs: typing.Any) -> typing.Tuple[str, ModuleType]: +def _find_frontend( + contracts: Union[DIPLOMATContract, List[DIPLOMATContract]], + config: os.PathLike, + **kwargs: typing.Any +) -> typing.Tuple[str, ModuleType]: from diplomat import _LOADED_FRONTENDS + contracts = [contracts] if(isinstance(contracts, DIPLOMATContract)) else contracts + for name, funcs in _LOADED_FRONTENDS.items(): - if(funcs._verifier( + if(all(funcs.verify( + contract=c, config=config, **kwargs - )): + ) for c in contracts)): print(f"Frontend '{name}' selected.") return (name, funcs) @@ -178,6 +186,7 @@ def track( from diplomat import CLI_RUN selected_frontend_name, selected_frontend = _find_frontend( + contracts=[DIPLOMATCommands.analyze_videos, DIPLOMATCommands.analyze_videos], config=config, videos=videos, frame_stores=frame_stores, @@ -324,7 +333,12 @@ def annotate( from diplomat import CLI_RUN # Iterate the frontends, looking for one that actually matches our request... - selected_frontend_name, selected_frontend = _find_frontend(config=config, videos=videos, **extra_args) + selected_frontend_name, selected_frontend = _find_frontend( + contracts=DIPLOMATCommands.label_videos, + config=config, + videos=videos, + **extra_args + ) if(help_extra): _display_help(selected_frontend_name, "video labeling", "diplomat annotate", selected_frontend.label_videos, CLI_RUN) @@ -350,18 +364,25 @@ def tweak( **extra_args ): """ - Make modifications to DIPLOMAT produced tracking results created for a video using a limited version supervised labeling UI. Allows for touching - up and fixing any minor issues that may arise after tracking and saving results. + Make modifications to DIPLOMAT produced tracking results created for a video using a limited version supervised + labeling UI. Allows for touching up and fixing any minor issues that may arise after tracking and saving results. - :param config: The path to the configuration file for the project. The format of this argument will depend on the frontend. + :param config: The path to the configuration file for the project. The format of this argument will depend on the + frontend. :param videos: A single path or list of paths to video files to tweak the tracks of. - :param help_extra: Boolean, if set to true print extra settings for the automatically selected frontend instead of showing the UI. - :param extra_args: Any additional arguments (if the CLI, flags starting with '--') are passed to the automatically selected frontend. - To see valid values, run tweak with extra_help flag set to true. + :param help_extra: Boolean, if set to true print extra settings for the automatically selected frontend instead of + showing the UI. + :param extra_args: Any additional arguments (if the CLI, flags starting with '--') are passed to the automatically + selected frontend. To see valid values, run tweak with extra_help flag set to true. """ from diplomat import CLI_RUN - selected_frontend_name, selected_frontend = _find_frontend(config=config, videos=videos, **extra_args) + selected_frontend_name, selected_frontend = _find_frontend( + contracts=DIPLOMATCommands.tweak_videos, + config=config, + videos=videos, + **extra_args + ) if(help_extra): _display_help(selected_frontend_name, "label tweaking", "diplomat tweak", selected_frontend.tweak_videos, CLI_RUN) @@ -404,7 +425,12 @@ def convert( """ from diplomat import CLI_RUN - selected_frontend_name, selected_frontend = _find_frontend(config=config, videos=videos, **extra_args) + selected_frontend_name, selected_frontend = _find_frontend( + contracts=DIPLOMATCommands.convert_results, + config=config, + videos=videos, + **extra_args + ) if(help_extra): _display_help( diff --git a/diplomat/frontend_ops.py b/diplomat/frontend_ops.py index 4d7365d..07aac0d 100644 --- a/diplomat/frontend_ops.py +++ b/diplomat/frontend_ops.py @@ -36,7 +36,7 @@ def list_loaded_frontends(): print("Description:") print(f"\t{frontend_docs[name]}") print("Supported Functions:") - for k, v in asdict(funcs).items(): + for k, v in funcs: if(k.startswith("_")): continue print(f"\t{k}") diff --git a/diplomat/frontends/__init__.py b/diplomat/frontends/__init__.py index f17e417..2f3ce2a 100644 --- a/diplomat/frontends/__init__.py +++ b/diplomat/frontends/__init__.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, asdict +from collections import OrderedDict from diplomat.processing.type_casters import StrictCallable, PathLike, Union, List, Dict, Any, Optional, TypeCaster, NoneType import typing - class Select(Union): def __eq__(self, other: TypeCaster): if(isinstance(other, Union)): @@ -62,29 +62,100 @@ def to_type_hint(self) -> typing.Type: ) -@dataclass(frozen=False) -class DIPLOMATBaselineCommands: +@dataclass(frozen=True) +class DIPLOMATContract: + """ + Represents a 'contract' + """ + method_name: str + method_type: StrictCallable + + +class CommandManager(type): + + __no_type_check__ = False + def __new__(cls, *args, **kwargs): + obj = super().__new__(cls, *args, **kwargs) + + annotations = typing.get_type_hints(obj) + + for name, annot in annotations.items(): + if(name in obj.__dict__): + raise TypeError(f"Command annotation '{name}' has default value, which is not allowed.") + + return obj + + def __getattr__(self, item): + annot = typing.get_type_hints(self)[item] + return DIPLOMATContract(item, annot) + + +def required(typecaster: TypeCaster) -> TypeCaster: + typecaster._required = True + return typecaster + + +class DIPLOMATCommands(metaclass=CommandManager): """ The baseline set of functions each DIPLOMAT backend must implement. Backends can add additional commands - by extending this base class... + by passing the methods to this classes constructor. """ - _verifier: VerifierFunction + _verifier: required(VerifierFunction) analyze_videos: AnalyzeVideosFunction(NoneType) analyze_frames: AnalyzeFramesFunction(NoneType) label_videos: LabelVideosFunction(NoneType) tweak_videos: LabelVideosFunction(NoneType) convert_results: ConvertResultsFunction(NoneType) - def __post_init__(self): + def __init__(self, **kwargs): + missing = object() + self._commands = OrderedDict() + annotations = typing.get_type_hints(type(self)) - for name, value in asdict(self).items(): - annot = annotations.get(name, None) + for name, annot in annotations.items(): + value = kwargs.get(name, missing) + if(value is missing): + if(getattr(annot, "_required", False)): + raise ValueError(f"Command '{name}' is required, but was not provided.") + continue if(annot is None or (not isinstance(annot, TypeCaster))): raise TypeError("DIPLOMAT Command Struct can only contain typecaster types.") - setattr(self, name, annot(value)) + self._commands[name] = annot(value) + + for name, value in kwargs.items(): + if(name not in annotations): + self._commands[name] = value + + def __iter__(self): + return iter(self._commands.items()) + + def __getattr__(self, item: str): + return self._commands.get(item) + + def verify(self, contract: DIPLOMATContract, config: Union[List[PathLike], PathLike], **kwargs: Any) -> bool: + """ + Verify this backend can handle the provided command type, config file, and arguments. + + :param contract: The contract for the command. Includes the name of the method and the type of the method, + which will typically be a strict callable. + :param config: The configuration file, checks if the backend can handle this configuration file. + :param kwargs: Any additional arguments to pass to the backends verifier. + + :return: A boolean, True if the backend can handle the provided command and arguments, otherwise False. + """ + if(contract.method_name in self._commands): + func = self._commands[contract.method_name] + try: + contract.method_type(func) + except Exception: + return False + + return self._verifier(config, **kwargs) + + return False class DIPLOMATFrontend(ABC): @@ -93,7 +164,7 @@ class DIPLOMATFrontend(ABC): """ @classmethod @abstractmethod - def init(cls) -> typing.Optional[DIPLOMATBaselineCommands]: + def init(cls) -> typing.Optional[DIPLOMATCommands]: """ Attempt to initialize the frontend, returning a list of api functions. If the backend can't initialize due to missing imports/requirements, this function should return None. diff --git a/diplomat/frontends/deeplabcut/__init__.py b/diplomat/frontends/deeplabcut/__init__.py index aab8b76..230238a 100644 --- a/diplomat/frontends/deeplabcut/__init__.py +++ b/diplomat/frontends/deeplabcut/__init__.py @@ -1,5 +1,5 @@ from typing import Optional -from diplomat.frontends import DIPLOMATFrontend, DIPLOMATBaselineCommands +from diplomat.frontends import DIPLOMATFrontend, DIPLOMATCommands class DEEPLABCUTFrontend(DIPLOMATFrontend): @@ -7,7 +7,7 @@ class DEEPLABCUTFrontend(DIPLOMATFrontend): The DEEPLABCUT frontend for DIPLOMAT. Contains functions for running DIPLOMAT on DEEPLABCUT projects. """ @classmethod - def init(cls) -> Optional[DIPLOMATBaselineCommands]: + def init(cls) -> Optional[DIPLOMATCommands]: try: from diplomat.frontends.deeplabcut._verify_func import _verify_dlc_like from diplomat.frontends.deeplabcut.predict_videos_dlc import analyze_videos @@ -18,7 +18,7 @@ def init(cls) -> Optional[DIPLOMATBaselineCommands]: except ImportError: return None - return DIPLOMATBaselineCommands( + return DIPLOMATCommands( _verifier=_verify_dlc_like, analyze_videos=analyze_videos, analyze_frames=analyze_frames, diff --git a/diplomat/frontends/sleap/__init__.py b/diplomat/frontends/sleap/__init__.py index 2ad160a..f785b2b 100644 --- a/diplomat/frontends/sleap/__init__.py +++ b/diplomat/frontends/sleap/__init__.py @@ -1,5 +1,5 @@ from typing import Optional -from diplomat.frontends import DIPLOMATFrontend, DIPLOMATBaselineCommands +from diplomat.frontends import DIPLOMATFrontend, DIPLOMATCommands class SLEAPFrontend(DIPLOMATFrontend): @@ -7,7 +7,7 @@ class SLEAPFrontend(DIPLOMATFrontend): The SLEAP frontend for DIPLOMAT. Contains functions for running DIPLOMAT on SLEAP projects. """ @classmethod - def init(cls) -> Optional[DIPLOMATBaselineCommands]: + def init(cls) -> Optional[DIPLOMATCommands]: try: from diplomat.frontends.sleap._verify_func import _verify_sleap_like from diplomat.frontends.sleap.predict_videos_sleap import analyze_videos @@ -18,7 +18,7 @@ def init(cls) -> Optional[DIPLOMATBaselineCommands]: except ImportError: return None - return DIPLOMATBaselineCommands( + return DIPLOMATCommands( _verifier=_verify_sleap_like, analyze_videos=analyze_videos, analyze_frames=analyze_frames, diff --git a/diplomat/frontends/sleap/_verify_func.py b/diplomat/frontends/sleap/_verify_func.py index bdada2b..45a2641 100644 --- a/diplomat/frontends/sleap/_verify_func.py +++ b/diplomat/frontends/sleap/_verify_func.py @@ -1,7 +1,8 @@ import sleap from diplomat.processing.type_casters import Union, List, PathLike, typecaster_function -from .run_utils import _paths_to_str +from .run_utils import _paths_to_str, _load_config + @typecaster_function def _verify_sleap_like( @@ -11,7 +12,7 @@ def _verify_sleap_like( try: # Config for sleap is always a sleap model, so try to load it... config = _paths_to_str(config) - __ = sleap.load_model(config) + _load_config(config) return True except: - return False \ No newline at end of file + return False diff --git a/diplomat/frontends/sleap/convert_results_sleap.py b/diplomat/frontends/sleap/convert_results_sleap.py index 74363f6..26b1e29 100644 --- a/diplomat/frontends/sleap/convert_results_sleap.py +++ b/diplomat/frontends/sleap/convert_results_sleap.py @@ -7,7 +7,8 @@ from .run_utils import ( _paths_to_str, - _to_diplomat_poses + _to_diplomat_poses, + _load_config ) @@ -24,10 +25,8 @@ def convert_results( zip file. :param videos: Paths to the sleap label files, or .slp files, to convert to csv files, NOT the video files. """ - model = sleap.load_model(_paths_to_str(config)) - - if (model is None): - raise ValueError("Model passed was invalid!") + # Load config just to verify it's valid... + _load_config(_paths_to_str(config)) label_paths = _paths_to_str(videos) label_paths = [label_paths] if(isinstance(label_paths, str)) else label_paths diff --git a/diplomat/frontends/sleap/label_videos_sleap.py b/diplomat/frontends/sleap/label_videos_sleap.py index efe82cb..ed994c7 100644 --- a/diplomat/frontends/sleap/label_videos_sleap.py +++ b/diplomat/frontends/sleap/label_videos_sleap.py @@ -13,7 +13,8 @@ from .visual_settings import FULL_VISUAL_SETTINGS from .run_utils import ( _paths_to_str, - _to_diplomat_poses + _to_diplomat_poses, + _load_config ) @@ -47,10 +48,7 @@ def label_videos( {extra_cli_args} """ - model = sleap.load_model(_paths_to_str(config)) - - if(model is None): - raise ValueError("Model passed was invalid!") + _load_config(_paths_to_str(config)) videos = _paths_to_str(videos) videos = [videos] if(isinstance(videos, str)) else videos diff --git a/diplomat/frontends/sleap/predict_frames_sleap.py b/diplomat/frontends/sleap/predict_frames_sleap.py index 6f1446e..d9b406d 100644 --- a/diplomat/frontends/sleap/predict_frames_sleap.py +++ b/diplomat/frontends/sleap/predict_frames_sleap.py @@ -9,8 +9,9 @@ from diplomat.utils.video_info import is_video from diplomat.utils import frame_store_fmt -from .run_utils import _get_default_value, _paths_to_str, _get_video_metadata, _get_predictor_settings, PoseLabels, Timer, _attach_run_info -from .sleap_providers import PredictorExtractor, SleapMetadata +from .run_utils import _get_default_value, _paths_to_str, _get_video_metadata, _get_predictor_settings, PoseLabels, \ + Timer, _attach_run_info, _load_config +from .sleap_providers import sleap_metadata_from_config, SleapMetadata from .visual_settings import VISUAL_SETTINGS import sleap @@ -57,11 +58,9 @@ def analyze_frames( batch_size = _get_default_value(sleap.load_model, "batch_size", 4) if (batch_size is None) else batch_size num_outputs = 1 if (num_outputs is None) else num_outputs - print("Loading Model...") - model = sleap.load_model(_paths_to_str(config), batch_size=batch_size) - # Get the model extractor... - mdl_extractor = PredictorExtractor(model, refinement_kernel_size) - mdl_metadata = mdl_extractor.get_metadata() + print("Loading Config...") + config = _load_config(_paths_to_str(config))[0] + mdl_metadata = sleap_metadata_from_config(config.data) predictor_cls = get_predictor("SegmentedFramePassEngine" if (predictor is None) else predictor) print(f"Using predictor: '{predictor_cls.get_name()}'") diff --git a/diplomat/frontends/sleap/run_utils.py b/diplomat/frontends/sleap/run_utils.py index 260c479..06eb4b5 100644 --- a/diplomat/frontends/sleap/run_utils.py +++ b/diplomat/frontends/sleap/run_utils.py @@ -48,6 +48,18 @@ def _paths_to_str(paths): return str(paths) +def _load_config(paths): + try: + paths = [paths] if(isinstance(paths, str)) else paths + + if(len(paths) < 1): + raise ValueError(f"No configuration files passed to open!") + + return [sleap.load_config(p) for p in paths] + except IOError as e: + raise type(e)(f"Unable to load provided sleap config: '{repr(e)}'") + + def _get_default_value(func, attr, fallback): param = signature(func).parameters.get(attr, None) return fallback if(param is None) else param.default diff --git a/diplomat/frontends/sleap/sleap_providers.py b/diplomat/frontends/sleap/sleap_providers.py index 4729836..a5732e8 100644 --- a/diplomat/frontends/sleap/sleap_providers.py +++ b/diplomat/frontends/sleap/sleap_providers.py @@ -2,6 +2,7 @@ from typing import Optional, Dict, Union, Iterator, Set, Type, List, Tuple import sleap.nn.data.resizing +from sleap.nn.config import DataConfig as SleapDataConfig from typing_extensions import TypedDict import numpy as np import tensorflow as tf @@ -20,10 +21,10 @@ class SleapMetadata(TypedDict): orig_skeleton: SleapSkeleton -def _extract_metadata(predictor: SleapPredictor) -> SleapMetadata: - skel_list = predictor.data_config.labels.skeletons +def sleap_metadata_from_config(config: SleapDataConfig) -> SleapMetadata: + skel_list = config.labels.skeletons - if(len(skel_list) < 1): + if (len(skel_list) < 1): raise ValueError("No part information for this SLEAP project, can't run diplomat!") skeleton1 = skel_list[0] @@ -31,7 +32,7 @@ def _extract_metadata(predictor: SleapPredictor) -> SleapMetadata: return SleapMetadata( bp_names=skeleton1.node_names, - skeleton=edge_name_list if(len(edge_name_list) > 0) else None, + skeleton=edge_name_list if (len(edge_name_list) > 0) else None, orig_skeleton=skeleton1 ) @@ -47,7 +48,7 @@ def __init__(self, model: SleapPredictor): self.__p = model def get_metadata(self) -> SleapMetadata: - return _extract_metadata(self.__p) + return sleap_metadata_from_config(self.__p.data_config) @abstractmethod def extract(self, data: Union[Provider]) -> Tuple[tf.Tensor, Optional[tf.Tensor], float]: diff --git a/diplomat/frontends/sleap/tweak_results_sleap.py b/diplomat/frontends/sleap/tweak_results_sleap.py index b806032..83ee724 100644 --- a/diplomat/frontends/sleap/tweak_results_sleap.py +++ b/diplomat/frontends/sleap/tweak_results_sleap.py @@ -13,7 +13,7 @@ _paths_to_str, _get_video_metadata, _to_diplomat_poses, - PoseLabels, + PoseLabels, _load_config, ) from .sleap_providers import SleapMetadata @@ -34,10 +34,7 @@ def tweak_videos( {extra_cli_args} """ - model = sleap.load_model(_paths_to_str(config)) - - if(model is None): - raise ValueError("Model passed was invalid!") + _load_config(_paths_to_str(config)) label_paths = _paths_to_str(videos) label_paths = [label_paths] if(isinstance(label_paths, str)) else label_paths diff --git a/diplomat/utils/lazy_import.py b/diplomat/utils/lazy_import.py index 7725d82..36ba140 100644 --- a/diplomat/utils/lazy_import.py +++ b/diplomat/utils/lazy_import.py @@ -71,7 +71,9 @@ def verify_existence_of(name: str): raise ValueError("Can only check top-level modules without attempting to import them.") try: - find_spec(name) + spec = find_spec(name) + if(spec is None): + raise ImportError(f"Unable to find package '{name}'.") except Exception as e: raise ImportError(str(e)) diff --git a/docs/source/ext/plugin_docgen.py b/docs/source/ext/plugin_docgen.py index a924798..edf4a69 100644 --- a/docs/source/ext/plugin_docgen.py +++ b/docs/source/ext/plugin_docgen.py @@ -27,7 +27,7 @@ import diplomat.predictors.fpe.frame_passes as frame_passes from diplomat.predictors.fpe.frame_pass import FramePass - from diplomat.frontends import DIPLOMATBaselineCommands + from diplomat.frontends import DIPLOMATCommands def load_plugins_with_mocks(module, clazz): from diplomat.utils.pluginloader import load_plugin_classes @@ -195,7 +195,7 @@ def get_frame_pass_rst(plugin: Type[FramePass]) -> str: settings = format_settings(plugin.get_config_options()) ) -def get_frontend_rst(name: str, methods: DIPLOMATBaselineCommands): +def get_frontend_rst(name: str, methods: DIPLOMATCommands): from dataclasses import asdict module_name = "diplomat." + name @@ -207,7 +207,7 @@ def get_frontend_rst(name: str, methods: DIPLOMATBaselineCommands): module_name_eqs = "=" * len(module_name), desc = clean_doc_str(doc), function_list = "\n".join( - f" ~{func.__module__}.{func.__name__}" for name, func in asdict(methods).items() if(not func.__name__.startswith("_")) + f" ~{func.__module__}.{func.__name__}" for name, func in methods if(not func.__name__.startswith("_")) ) ) diff --git a/pyproject.toml b/pyproject.toml index 5958bb6..fbeb179 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ diplomat = "diplomat._cli_runner:main" [project.optional-dependencies] # Latest version has build issues... gui = ["wxpython<4.2.0"] -dlc = ["deeplabcut[tf]"] +dlc = ["deeplabcut"] sleap = ["sleap"] -all = ["wxpython<4.2.0", "deeplabcut[tf]", "sleap"] +all = ["wxpython<4.2.0", "deeplabcut", "sleap"] From 9ae1aa01ac1ae21407c339be29e392fbc14f2081 Mon Sep 17 00:00:00 2001 From: Isaac Robinson Date: Mon, 5 Jun 2023 01:43:26 -0400 Subject: [PATCH 2/7] Make sleap imports lazy and silent. --- diplomat/frontends/sleap/_verify_func.py | 2 - .../frontends/sleap/convert_results_sleap.py | 2 +- .../frontends/sleap/label_videos_sleap.py | 2 +- .../frontends/sleap/predict_frames_sleap.py | 17 +++++-- .../frontends/sleap/predict_videos_sleap.py | 3 +- diplomat/frontends/sleap/run_utils.py | 2 +- diplomat/frontends/sleap/sleap_importer.py | 14 ++++++ diplomat/frontends/sleap/sleap_providers.py | 46 +++++++++++-------- .../frontends/sleap/tweak_results_sleap.py | 2 +- diplomat/frontends/sleap/visual_settings.py | 20 ++++---- diplomat/utils/lazy_import.py | 8 +++- 11 files changed, 79 insertions(+), 39 deletions(-) create mode 100644 diplomat/frontends/sleap/sleap_importer.py diff --git a/diplomat/frontends/sleap/_verify_func.py b/diplomat/frontends/sleap/_verify_func.py index 45a2641..8d8200d 100644 --- a/diplomat/frontends/sleap/_verify_func.py +++ b/diplomat/frontends/sleap/_verify_func.py @@ -1,5 +1,3 @@ -import sleap - from diplomat.processing.type_casters import Union, List, PathLike, typecaster_function from .run_utils import _paths_to_str, _load_config diff --git a/diplomat/frontends/sleap/convert_results_sleap.py b/diplomat/frontends/sleap/convert_results_sleap.py index 26b1e29..108c5db 100644 --- a/diplomat/frontends/sleap/convert_results_sleap.py +++ b/diplomat/frontends/sleap/convert_results_sleap.py @@ -1,6 +1,6 @@ from pathlib import Path -import sleap +from .sleap_importer import sleap import diplomat.processing.type_casters as tc from diplomat.utils.track_formats import to_diplomat_table, save_diplomat_table diff --git a/diplomat/frontends/sleap/label_videos_sleap.py b/diplomat/frontends/sleap/label_videos_sleap.py index ed994c7..daaac0a 100644 --- a/diplomat/frontends/sleap/label_videos_sleap.py +++ b/diplomat/frontends/sleap/label_videos_sleap.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import TypeVar, Type, Tuple import cv2 -import sleap +from .sleap_importer import sleap import diplomat.processing.type_casters as tc from diplomat.utils.cli_tools import extra_cli_args diff --git a/diplomat/frontends/sleap/predict_frames_sleap.py b/diplomat/frontends/sleap/predict_frames_sleap.py index d9b406d..7bc01fb 100644 --- a/diplomat/frontends/sleap/predict_frames_sleap.py +++ b/diplomat/frontends/sleap/predict_frames_sleap.py @@ -9,13 +9,21 @@ from diplomat.utils.video_info import is_video from diplomat.utils import frame_store_fmt -from .run_utils import _get_default_value, _paths_to_str, _get_video_metadata, _get_predictor_settings, PoseLabels, \ - Timer, _attach_run_info, _load_config +from .run_utils import ( + _get_default_value, + _paths_to_str, + _get_video_metadata, + _get_predictor_settings, + PoseLabels, + Timer, + _attach_run_info, + _load_config +) + +from .sleap_importer import sleap from .sleap_providers import sleap_metadata_from_config, SleapMetadata from .visual_settings import VISUAL_SETTINGS -import sleap - @dataclass class _DummyVideo: @@ -55,6 +63,7 @@ def analyze_frames( {extra_cli_args} """ + import sleap batch_size = _get_default_value(sleap.load_model, "batch_size", 4) if (batch_size is None) else batch_size num_outputs = 1 if (num_outputs is None) else num_outputs diff --git a/diplomat/frontends/sleap/predict_videos_sleap.py b/diplomat/frontends/sleap/predict_videos_sleap.py index b3e7061..f33a6be 100644 --- a/diplomat/frontends/sleap/predict_videos_sleap.py +++ b/diplomat/frontends/sleap/predict_videos_sleap.py @@ -5,7 +5,6 @@ from diplomat.utils.cli_tools import extra_cli_args, Flag from diplomat.processing.progress_bar import TQDMProgressBar from diplomat.processing import get_predictor, Config, Predictor -import sleap from .sleap_providers import PredictorExtractor from .visual_settings import VISUAL_SETTINGS @@ -58,6 +57,7 @@ def analyze_videos( """ _setup_gpus(use_cpu, gpu_index) + import sleap batch_size = _get_default_value(sleap.load_model, "batch_size", 4) if(batch_size is None) else batch_size num_outputs = 1 if(num_outputs is None) else num_outputs @@ -103,6 +103,7 @@ def _analyze_single_video( predictor_settings: Optional[dict], output_suffix: str ): + import sleap video_path = Path(video_path).resolve() video = sleap.load_video(str(video_path)) output_path = video_path.parent / (video_path.name + f".diplomat_{predictor_cls.get_name()}{output_suffix}.slp") diff --git a/diplomat/frontends/sleap/run_utils.py b/diplomat/frontends/sleap/run_utils.py index 06eb4b5..7544800 100644 --- a/diplomat/frontends/sleap/run_utils.py +++ b/diplomat/frontends/sleap/run_utils.py @@ -5,7 +5,7 @@ from inspect import signature from pathlib import Path from typing import Optional, Type, List, Tuple, Iterable, Dict -import sleap +from .sleap_importer import sleap import numpy as np from diplomat.processing import Predictor, Config, Pose from diplomat.utils.shapes import shape_iterator diff --git a/diplomat/frontends/sleap/sleap_importer.py b/diplomat/frontends/sleap/sleap_importer.py new file mode 100644 index 0000000..7fe1def --- /dev/null +++ b/diplomat/frontends/sleap/sleap_importer.py @@ -0,0 +1,14 @@ +from diplomat.utils.lazy_import import LazyImporter, verify_existence_of + +verify_existence_of("sleap") + + +sleap = LazyImporter("sleap") + +SleapDataConfig = LazyImporter("sleap.nn.config.DataConfig") +SleapVideo = LazyImporter("sleap.Video") +Provider = LazyImporter("sleap.nn.data.pipelines.Provider") +SleapVideoReader = LazyImporter("sleap.nn.data.providers.VideoReader") +SleapPredictor = LazyImporter("sleap.nn.inference.Predictor") +SleapInferenceLayer = LazyImporter("sleap.nn.inference.InferenceLayer") +SleapSkeleton = LazyImporter("sleap.skeleton.Skeleton") \ No newline at end of file diff --git a/diplomat/frontends/sleap/sleap_providers.py b/diplomat/frontends/sleap/sleap_providers.py index a5732e8..87f386e 100644 --- a/diplomat/frontends/sleap/sleap_providers.py +++ b/diplomat/frontends/sleap/sleap_providers.py @@ -1,17 +1,20 @@ from abc import ABC, abstractmethod from typing import Optional, Dict, Union, Iterator, Set, Type, List, Tuple -import sleap.nn.data.resizing -from sleap.nn.config import DataConfig as SleapDataConfig from typing_extensions import TypedDict import numpy as np import tensorflow as tf -from sleap import Video as SleapVideo -from sleap.nn.data.pipelines import Provider -from sleap.nn.data.providers import VideoReader as SleapVideoReader -from sleap.nn.inference import Predictor as SleapPredictor -from sleap.nn.inference import InferenceLayer as SleapInferenceLayer -from sleap.skeleton import Skeleton as SleapSkeleton + +from .sleap_importer import ( + SleapDataConfig, + SleapVideo, + Provider, + SleapVideoReader, + SleapPredictor, + SleapInferenceLayer, + SleapSkeleton +) + from diplomat.processing import TrackingData @@ -62,10 +65,12 @@ def _normalize_conf_map(conf_map: tf.Tensor) -> tf.Tensor: class BottomUpModelExtractor(SleapModelExtractor): - from sleap.nn.inference import BottomUpPredictor, BottomUpMultiClassPredictor - supported_models: Optional[Set[SleapPredictor]] = {BottomUpPredictor, BottomUpMultiClassPredictor} + @property + def supported_models(self) -> Set[SleapPredictor]: + from sleap.nn.inference import BottomUpPredictor, BottomUpMultiClassPredictor + return {BottomUpPredictor, BottomUpMultiClassPredictor} - def __init__(self, model: Union[BottomUpPredictor, BottomUpMultiClassPredictor]): + def __init__(self, model: SleapPredictor): super().__init__(model) self._predictor = model @@ -99,14 +104,17 @@ def _extract_model_outputs(inf_layer: SleapInferenceLayer, images: tf.Tensor) -> class TopDownModelExtractor(SleapModelExtractor): - from sleap.nn.inference import TopDownPredictor, TopDownMultiClassPredictor - supported_models: Optional[Set[SleapPredictor]] = {TopDownPredictor, TopDownMultiClassPredictor} + @property + def supported_models(self) -> Set[SleapPredictor]: + from sleap.nn.inference import TopDownPredictor, TopDownMultiClassPredictor + return {TopDownPredictor, TopDownMultiClassPredictor} - def __init__(self, model: Union[TopDownPredictor, TopDownMultiClassPredictor]): + def __init__(self, model: SleapPredictor): super().__init__(model) self._predictor = model # TODO: Eventually fix top down support to actually work. - raise NotImplementedError("SLEAP's top down model is currently not supported. Please train using a different model type to use DIPLOMAT.") + raise NotImplementedError("SLEAP's top down model is currently not supported. Please train using a " + "different model type to use DIPLOMAT.") @staticmethod def _merge_tiles( @@ -162,10 +170,12 @@ def extract(self, data: Union[Dict, np.ndarray]) -> Tuple[tf.Tensor, Optional[tf class SingleInstanceModelExtractor(SleapModelExtractor): - from sleap.nn.inference import SingleInstancePredictor - supported_models: Optional[Set[SleapPredictor]] = {SingleInstancePredictor} + @property + def supported_models(self) -> Set[SleapPredictor]: + from sleap.nn.inference import SingleInstancePredictor + return {SingleInstancePredictor} - def __init__(self, model: Union[SingleInstancePredictor]): + def __init__(self, model: SleapPredictor): super().__init__(model) self._predictor = model diff --git a/diplomat/frontends/sleap/tweak_results_sleap.py b/diplomat/frontends/sleap/tweak_results_sleap.py index 83ee724..350229f 100644 --- a/diplomat/frontends/sleap/tweak_results_sleap.py +++ b/diplomat/frontends/sleap/tweak_results_sleap.py @@ -1,6 +1,6 @@ from pathlib import Path -import sleap +from .sleap_importer import sleap import diplomat.processing.type_casters as tc from diplomat.utils.cli_tools import extra_cli_args diff --git a/diplomat/frontends/sleap/visual_settings.py b/diplomat/frontends/sleap/visual_settings.py index 34c1b9c..bfd73a8 100644 --- a/diplomat/frontends/sleap/visual_settings.py +++ b/diplomat/frontends/sleap/visual_settings.py @@ -5,27 +5,29 @@ import diplomat.processing.type_casters as tc import matplotlib.colors as mpl_colors from diplomat.utils.colormaps import to_colormap -from sleap.prefs import prefs -import numpy as np import cv2 + def cv2_fourcc_string(val) -> int: return int(cv2.VideoWriter_fourcc(*val)) Skeleton = tc.Union[tc.List[tc.Tuple[str, str]], tc.Dict[str, tc.List[str]], tc.List[str], tc.NoneType, bool] -_marker_size_fallback = prefs["marker size"] if(isinstance(prefs["marker size"], int)) else 4 VISUAL_SETTINGS: ConfigSpec = { "pcutoff": (0.1, tc.RangedFloat(0, 1), "The probability to cutoff results below."), - "dotsize": (_marker_size_fallback, int, "The size of the dots."), + "dotsize": (4, int, "The size of the dots."), "alphavalue": (0.7, tc.RangedFloat(0, 1), "The alpha value of the dots."), - "colormap": (None, to_colormap, "The colormap to use for tracked points in the video. Can be a matplotlib colormap or a list of matplotlib colors."), - "shape_list": (None, tc.Optional(tc.List(str)), "A list of shape names, shapes to use for drawing each individual's dots."), + "colormap": (None, to_colormap, "The colormap to use for tracked points in the video. Can be a matplotlib " + "colormap or a list of matplotlib colors."), + "shape_list": (None, tc.Optional(tc.List(str)), "A list of shape names, shapes to use for drawing each " + "individual's dots."), "line_thickness": (1, int, "Thickness of lines drawn."), - "skeleton": (None, Skeleton, "The skeleton to use for this this run of DIPLOMAT. Defaults to None, which uses the skeleton associated with the " - "sleap project. Can be a list of strings, a list of tuples of strings, a dictionary of strings to strings or lists" - "of strings, or True/False (True connects all parts, False connects no parts, disabling the skeleton).") + "skeleton": (None, Skeleton, "The skeleton to use for this this run of DIPLOMAT. Defaults to None, which uses " + "the skeleton associated with the sleap project. Can be a list of strings, a list of " + "tuples of strings, a dictionary of strings to strings or lists of strings, or " + "True/False (True connects all parts, False connects no parts, disabling the " + "skeleton).") } diff --git a/diplomat/utils/lazy_import.py b/diplomat/utils/lazy_import.py index 36ba140..8775f61 100644 --- a/diplomat/utils/lazy_import.py +++ b/diplomat/utils/lazy_import.py @@ -112,4 +112,10 @@ def __call__(self, *args, **kwargs): if(self._mod is self.NOTHING): self._mod = self._imp(self._name, self._pkg) - return self._mod(*args, **kwargs) \ No newline at end of file + return self._mod(*args, **kwargs) + + def __str__(self) -> str: + return repr(self) + + def __repr__(self) -> str: + return self._name From deb97a057b13385458ebe0f993f2d71ff08e3d40 Mon Sep 17 00:00:00 2001 From: Isaac Robinson Date: Mon, 5 Jun 2023 02:01:54 -0400 Subject: [PATCH 3/7] Further lazy importing... --- diplomat/frontends/deeplabcut/dlc_importer.py | 1 + diplomat/frontends/deeplabcut/predict_videos_dlc.py | 3 +-- diplomat/frontends/sleap/run_utils.py | 4 +++- diplomat/frontends/sleap/sleap_importer.py | 1 + diplomat/frontends/sleap/sleap_providers.py | 3 ++- 5 files changed, 8 insertions(+), 4 deletions(-) diff --git a/diplomat/frontends/deeplabcut/dlc_importer.py b/diplomat/frontends/deeplabcut/dlc_importer.py index ad2064b..e50e35e 100644 --- a/diplomat/frontends/deeplabcut/dlc_importer.py +++ b/diplomat/frontends/deeplabcut/dlc_importer.py @@ -6,6 +6,7 @@ verify_existence_of("deeplabcut") deeplabcut = LazyImporter("deeplabcut") +tf = LazyImporter("tensorflow") predict = LazyImporter("deeplabcut.pose_estimation_tensorflow.core.predict") checkcropping = LazyImporter("deeplabcut.pose_estimation_tensorflow.predict_videos.checkcropping") load_config = LazyImporter("deeplabcut.pose_estimation_tensorflow.config.load_config") diff --git a/diplomat/frontends/deeplabcut/predict_videos_dlc.py b/diplomat/frontends/deeplabcut/predict_videos_dlc.py index 5e61979..8704d43 100644 --- a/diplomat/frontends/deeplabcut/predict_videos_dlc.py +++ b/diplomat/frontends/deeplabcut/predict_videos_dlc.py @@ -12,8 +12,7 @@ from diplomat.utils.shapes import shape_iterator # DLC Imports -from .dlc_importer import predict, checkcropping, load_config, auxiliaryfunctions -import tensorflow as tf +from .dlc_importer import predict, checkcropping, load_config, auxiliaryfunctions, tf from tqdm import tqdm import pandas as pd diff --git a/diplomat/frontends/sleap/run_utils.py b/diplomat/frontends/sleap/run_utils.py index 7544800..2acd924 100644 --- a/diplomat/frontends/sleap/run_utils.py +++ b/diplomat/frontends/sleap/run_utils.py @@ -1,4 +1,4 @@ -import tensorflow as tf +from .sleap_importer import tf import platform import time from datetime import datetime @@ -16,6 +16,7 @@ def _frame_iter( skeleton: sleap.Skeleton, track_to_idx: Dict[sleap.Track, int] ) -> Iterable[sleap.PredictedInstance]: + import sleap for inst in frame.instances: if((inst.track is not None) and isinstance(inst, sleap.PredictedInstance) and (inst.skeleton == skeleton)): yield track_to_idx[inst.track], inst @@ -266,6 +267,7 @@ def _attach_run_info( command: List[str] ) -> sleap.Labels: import diplomat + import sleap labels.provenance.update({ "sleap_version": sleap.__version__, diff --git a/diplomat/frontends/sleap/sleap_importer.py b/diplomat/frontends/sleap/sleap_importer.py index 7fe1def..06d09e1 100644 --- a/diplomat/frontends/sleap/sleap_importer.py +++ b/diplomat/frontends/sleap/sleap_importer.py @@ -4,6 +4,7 @@ sleap = LazyImporter("sleap") +tf = LazyImporter("tensorflow") SleapDataConfig = LazyImporter("sleap.nn.config.DataConfig") SleapVideo = LazyImporter("sleap.Video") diff --git a/diplomat/frontends/sleap/sleap_providers.py b/diplomat/frontends/sleap/sleap_providers.py index 87f386e..0cf5b97 100644 --- a/diplomat/frontends/sleap/sleap_providers.py +++ b/diplomat/frontends/sleap/sleap_providers.py @@ -3,9 +3,9 @@ from typing_extensions import TypedDict import numpy as np -import tensorflow as tf from .sleap_importer import ( + tf, SleapDataConfig, SleapVideo, Provider, @@ -258,6 +258,7 @@ def _create_integral_offsets(cls, probs: tf.Tensor, stride: float, kernel_size: return tf.math.divide_no_nan(results[:, :, :, :, :2], results[:, :, :, :, 2:]) def extract(self, data: Union[Provider, SleapVideo]) -> Iterator[TrackingData]: + from sleap import Video as SleapVideo if(isinstance(data, SleapVideo)): data = SleapVideoReader(data) From 0014fbe4ba93bc89f663e359c8643b0ded6728b6 Mon Sep 17 00:00:00 2001 From: Isaac Robinson Date: Mon, 5 Jun 2023 15:20:48 -0400 Subject: [PATCH 4/7] A bunch of refactoring + new csv frontend. --- diplomat/frontends/csv/__init__.py | 29 ++++ diplomat/frontends/csv/_verify_func.py | 17 +++ diplomat/frontends/csv/csv_utils.py | 33 +++++ diplomat/frontends/csv/label_videos.py | 138 ++++++++++++++++++ diplomat/frontends/csv/tweak_results.py | 86 +++++++++++ .../frontends/sleap/label_videos_sleap.py | 36 +---- diplomat/utils/track_formats.py | 27 +++- diplomat/utils/tweak_ui.py | 2 +- diplomat/utils/video_io.py | 51 +++++++ 9 files changed, 382 insertions(+), 37 deletions(-) create mode 100644 diplomat/frontends/csv/__init__.py create mode 100644 diplomat/frontends/csv/_verify_func.py create mode 100644 diplomat/frontends/csv/csv_utils.py create mode 100644 diplomat/frontends/csv/label_videos.py create mode 100644 diplomat/frontends/csv/tweak_results.py create mode 100644 diplomat/utils/video_io.py diff --git a/diplomat/frontends/csv/__init__.py b/diplomat/frontends/csv/__init__.py new file mode 100644 index 0000000..ed0864a --- /dev/null +++ b/diplomat/frontends/csv/__init__.py @@ -0,0 +1,29 @@ +from typing import Optional +from diplomat.frontends import DIPLOMATFrontend, DIPLOMATCommands + + +class DEEPLABCUTFrontend(DIPLOMATFrontend): + """ + The CSV frontend for DIPLOMAT. Contains functions for running some DIPLOMAT operations on csv trajectory files. + Supports video creation, and tweak UI commands. + """ + @classmethod + def init(cls) -> Optional[DIPLOMATCommands]: + try: + from diplomat.frontends.csv._verify_func import _verify + from diplomat.frontends.csv.label_videos import label_videos + from diplomat.frontends.csv.tweak_results import tweak_videos + except ImportError: + return None + + return DIPLOMATCommands( + _verifier=_verify, + label_videos=label_videos, + tweak_videos=tweak_videos + ) + + @classmethod + def get_package_name(cls) -> str: + return "csv" + + diff --git a/diplomat/frontends/csv/_verify_func.py b/diplomat/frontends/csv/_verify_func.py new file mode 100644 index 0000000..ea29baf --- /dev/null +++ b/diplomat/frontends/csv/_verify_func.py @@ -0,0 +1,17 @@ +from diplomat.processing.type_casters import typecaster_function, Union, List, PathLike +from .csv_utils import _fix_paths, _header_check + + +@typecaster_function +def _verify( + config: Union[List[PathLike], PathLike], + **kwargs +) -> bool: + if("videos" not in kwargs): + return False + + try: + config, videos = _fix_paths(config, kwargs["videos"]) + return all(_header_check(c) for c in config) + except (IOError, ValueError): + return False diff --git a/diplomat/frontends/csv/csv_utils.py b/diplomat/frontends/csv/csv_utils.py new file mode 100644 index 0000000..8e33286 --- /dev/null +++ b/diplomat/frontends/csv/csv_utils.py @@ -0,0 +1,33 @@ + + +def _header_check(csv): + with open(csv, "r") as csv_handle: + first_lines = [csv_handle.readline().strip("\n").split(",") for i in range(3)] + + header_cols = len(first_lines[0]) + + if(not all(header_cols == len(line) for line in first_lines)): + return False + + last_header_line = first_lines[-1] + last_line_exp = ["x", "y", "likelihood"] * (len(last_header_line) // 3) + + if(last_header_line != last_line_exp): + return False + + return True + + +def _fix_paths(csvs, videos): + csvs = csvs if(isinstance(csvs, (tuple, list))) else [csvs] + videos = videos if(isinstance(videos, (tuple, list))) else [videos] + + if(len(csvs) == 1): + csvs = csvs * len(videos) + if(len(videos) == 1): + videos = videos * len(csvs) + + if(len(videos) != len(csvs)): + raise ValueError("Number of videos and csv files passes don't match!") + + return csvs, videos diff --git a/diplomat/frontends/csv/label_videos.py b/diplomat/frontends/csv/label_videos.py new file mode 100644 index 0000000..04804c8 --- /dev/null +++ b/diplomat/frontends/csv/label_videos.py @@ -0,0 +1,138 @@ +from pathlib import Path +from typing import Tuple + +from diplomat.utils.cli_tools import extra_cli_args +from diplomat.frontends.sleap.visual_settings import FULL_VISUAL_SETTINGS +import diplomat.processing.type_casters as tc +from diplomat.utils.track_formats import load_diplomat_table, to_diplomat_pose +from diplomat.utils.video_io import ContextVideoWriter, ContextVideoCapture +from diplomat.utils.shapes import CV2DotShapeDrawer, shape_iterator +from diplomat.processing import Config, TQDMProgressBar +from diplomat.utils.colormaps import iter_colormap + +import cv2 + +from .csv_utils import _fix_paths + + +@extra_cli_args(FULL_VISUAL_SETTINGS, auto_cast=False) +@tc.typecaster_function +def label_videos( + config: tc.Union[tc.List[tc.PathLike], tc.PathLike], + videos: tc.Union[tc.List[tc.PathLike], tc.PathLike], + body_parts_to_plot: tc.Optional[tc.List[str]] = None, + video_extension: str = "mp4", + **kwargs +): + """ + Labeled videos with arbitrary csv files in diplomat's csv format. + + :param config: The path (or list of paths) to the csv file(s) to label the videos with. + :param videos: Paths to video file(s) corresponding to the provided csv files. + :param body_parts_to_plot: A set or list of body part names to label, or None, indicating to label all parts. + :param video_extension: The file extension to use on the created labeled video, excluding the dot. + Defaults to 'mp4'. + :param kwargs: The following additional arguments are supported: + + {extra_cli_args} + """ + config, videos = _fix_paths(config, videos) + visual_settings = Config(kwargs, FULL_VISUAL_SETTINGS) + + for c, v in zip(config, videos): + _label_videos_single(str(c), str(v), body_parts_to_plot, video_extension, visual_settings) + + +class EverythingSet: + def __contains__(self, item): + return True + + +def _to_cv2_color(color: Tuple[float, float, float, float]) -> Tuple[int, int, int, int]: + r, g, b, a = [min(255, max(0, int(val * 256))) for val in color] + return (b, g, r, a) + + +def _label_videos_single( + csv: str, + video: str, + body_parts_to_plot: tc.Optional[tc.List[str]], + video_extension: str, + visual_settings: Config +): + pose_data = load_diplomat_table(csv) + poses, bp_names, num_outputs = to_diplomat_pose(pose_data) + video_extension = video_extension if(video_extension.startswith(".")) else f".{video_extension}" + video_path = Path(video) + + # Create the output path... + output_path = video_path.parent / (video_path.stem + "_labeled" + video_extension) + + body_parts_to_plot = EverythingSet() if(body_parts_to_plot is None) else set(body_parts_to_plot) + upscale = 1 if(visual_settings.upscale_factor is None) else visual_settings.upscale_factor + + with ContextVideoCapture(video) as in_video: + out_w, out_h = tuple( + int(dim * upscale) for dim in [ + in_video.get(cv2.CAP_PROP_FRAME_WIDTH), + in_video.get(cv2.CAP_PROP_FRAME_HEIGHT) + ] + ) + + with ContextVideoWriter( + str(output_path), + visual_settings.output_codec, + in_video.get(cv2.CAP_PROP_FPS), + (out_w, out_h) + ) as writer: + with TQDMProgressBar(total=poses.get_frame_count()) as p: + for f_i in range(poses.get_frame_count()): + retval, frame = in_video.read() + + if(not retval): + continue + + frame = frame[..., ::-1] + + if (visual_settings.upscale_factor is not None): + frame = cv2.resize( + frame, + (out_w, out_h), + interpolation=cv2.INTER_NEAREST + ) + + overlay = frame.copy() + + colors = iter_colormap(visual_settings.colormap, poses.get_bodypart_count()) + shapes = shape_iterator(visual_settings.shape_list, num_outputs) + + part_iter = zip( + [name for name in bp_names for _ in range(num_outputs)], + poses.get_x_at(f_i, slice(None)), + poses.get_y_at(f_i, slice(None)), + poses.get_prob_at(f_i, slice(None)), + colors, + shapes + ) + + for (name, x, y, prob, color, shape) in part_iter: + if (x != x or y != y): + continue + + if (name not in body_parts_to_plot): + continue + + shape_drawer = CV2DotShapeDrawer( + overlay, + _to_cv2_color(tuple(color[:3]) + (1,)), + -1 if (prob > visual_settings.pcutoff) else visual_settings.line_thickness, + cv2.LINE_AA if (visual_settings.antialiasing) else None + )[shape] + + if (prob > visual_settings.pcutoff or visual_settings.draw_hidden_tracks): + shape_drawer(int(x * upscale), int(y * upscale), int(visual_settings.dotsize * upscale)) + + writer.write(cv2.addWeighted( + overlay, visual_settings.alphavalue, frame, 1 - visual_settings.alphavalue, 0 + )) + p.update() diff --git a/diplomat/frontends/csv/tweak_results.py b/diplomat/frontends/csv/tweak_results.py new file mode 100644 index 0000000..9b51acb --- /dev/null +++ b/diplomat/frontends/csv/tweak_results.py @@ -0,0 +1,86 @@ +import cv2 +from diplomat.frontends.sleap.visual_settings import VISUAL_SETTINGS +from diplomat.processing import Config, Pose +from diplomat.utils.cli_tools import extra_cli_args +from diplomat.utils.tweak_ui import TweakUI +import diplomat.processing.type_casters as tc +from diplomat.utils.track_formats import load_diplomat_table, to_diplomat_pose, save_diplomat_table, to_diplomat_table +from diplomat.utils.video_io import ContextVideoCapture +from diplomat.utils.shapes import shape_iterator + +from .csv_utils import _fix_paths + + +@extra_cli_args(VISUAL_SETTINGS, auto_cast=False) +@tc.typecaster_function +def tweak_videos( + config: tc.PathLike, + videos: tc.Union[tc.List[tc.PathLike], tc.PathLike], + **kwargs +): + """ + Make minor modifications and tweaks to arbitrary csv files using DIPLOMAT's supervised UI. + + :param config: The path (or list of paths) to the csv file(s) to edit. + :param videos: Paths to video file(s) corresponding to the provided csv files. + :param kwargs: The following additional arguments are supported: + + {extra_cli_args} + """ + config, videos = _fix_paths(config, videos) + visual_cfg = Config(kwargs, VISUAL_SETTINGS) + + for c, v in zip(config, videos): + _tweak_video_single(str(c), str(v), visual_cfg) + + +def _get_video_meta( + video: str, + num_frames: int, + visual_settings: Config, + output_file: str, + num_outputs: int +): + with ContextVideoCapture(str(video)) as vid_cap: + fps = vid_cap.get(cv2.CAP_PROP_FPS) + w, h = vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH), vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT) + + return { + "fps": fps, + "duration": num_frames / fps, + "size": (h, w), + "output-file-path": str(output_file), + "orig-video-path": video, + "cropping-offset": None, + "dotsize": visual_settings.dotsize, + "colormap": visual_settings.colormap, + "shape_list": shape_iterator(visual_settings.shape_list, num_outputs), + "alphavalue": visual_settings.alphavalue, + "pcutoff": visual_settings.pcutoff, + "line_thickness": visual_settings.get("line_thickness", 1), + "skeleton": visual_settings.skeleton + } + + +def _tweak_video_single( + csv: str, + video: str, + visual_cfg: Config +): + print(f"Making modifications to: '{csv}' (video: '{video}')") + pose_table = load_diplomat_table(csv) + poses, bp_names, num_outputs = to_diplomat_pose(pose_table) + + ui_manager = TweakUI() + + def on_end(save: bool, p: Pose): + if(save): + print("Saving results...") + save_diplomat_table(to_diplomat_table(num_outputs, bp_names, p), csv) + print("Results saved!") + else: + print("Operation canceled...") + + all_names = [name if(i == 0) else f"{name}{i}" for name in bp_names for i in range(num_outputs)] + video_meta = _get_video_meta(video, poses.get_frame_count(), visual_cfg, csv, num_outputs) + ui_manager.tweak(None, video, poses, all_names, video_meta, num_outputs, None, on_end) diff --git a/diplomat/frontends/sleap/label_videos_sleap.py b/diplomat/frontends/sleap/label_videos_sleap.py index daaac0a..ccecc3c 100644 --- a/diplomat/frontends/sleap/label_videos_sleap.py +++ b/diplomat/frontends/sleap/label_videos_sleap.py @@ -1,6 +1,5 @@ -import functools from pathlib import Path -from typing import TypeVar, Type, Tuple +from typing import Tuple import cv2 from .sleap_importer import sleap @@ -8,6 +7,7 @@ from diplomat.utils.cli_tools import extra_cli_args from diplomat.processing import Config, TQDMProgressBar from diplomat.utils.colormaps import iter_colormap +from diplomat.utils.video_io import ContextVideoWriter from diplomat.utils.shapes import shape_iterator, CV2DotShapeDrawer from .visual_settings import FULL_VISUAL_SETTINGS @@ -59,38 +59,6 @@ def label_videos( _label_video_single(video, visual_settings, body_parts_to_plot, video_extension) -T = TypeVar("T") - - -@functools.lru_cache(None) -def _create_manager(clazz: Type[T]) -> Type[T]: - class cv2_context_manager(clazz): - def __enter__(self): - if(not self.isOpened()): - self.release() - raise IOError("Unable to open video capture...") - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.release() - - def read(self): - if(not self.isOpened()): - raise IOError("Video capture is not open.") - return super().read() - - def write(self, frame): - if (not self.isOpened()): - raise IOError("Video writer is not open.") - return super().write(frame) - - return cv2_context_manager - - -ContextVideoWriter = _create_manager(cv2.VideoWriter) -ContextVideoCapture = _create_manager(cv2.VideoCapture) - - def _label_video_single( label_path: str, visual_settings: Config, diff --git a/diplomat/utils/track_formats.py b/diplomat/utils/track_formats.py index 70cbe5b..85f05ee 100644 --- a/diplomat/utils/track_formats.py +++ b/diplomat/utils/track_formats.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import List, Union, Tuple import pandas as pd from io import BufferedWriter, BufferedReader from diplomat.processing import Pose @@ -52,4 +52,27 @@ def load_diplomat_table(path_or_buf: Union[str, BufferedReader]) -> pd.DataFrame :return: A pd.DataFrame, being the diplomat-pandas pose table stored at the given location. """ - return pd.read_csv(path_or_buf, index_col=False, header=[0, 1, 2]) + return pd.read_csv(path_or_buf, index_col=None, header=[0, 1, 2]) + + +def to_diplomat_pose(table: pd.DataFrame) -> Tuple[Pose, List[str], int]: + """ + Convert a diplomat pandas table to a diplomat Pose object. + + :param table: A diplomat pandas table, typically as loaded from a CSV. + + :return: A tuple containing a diplomat Pose object, a list of string giving the body part names in order, + and an integer giving the number of bodies, our outputs. + """ + num_outputs = len(table.columns.unique(0)) + names = table.columns.unique(1) + + poses_enc = table.to_numpy() + poses_enc = poses_enc.reshape( + (table.shape[0], num_outputs, len(names), 3) + ).transpose((0, 2, 1, 3)).reshape(table.shape) + + poses = Pose.empty_pose(table.shape[0], len(names) * num_outputs) + poses.get_all()[:] = poses_enc + + return (poses, names, num_outputs) diff --git a/diplomat/utils/tweak_ui.py b/diplomat/utils/tweak_ui.py index d9d6445..1a833a4 100644 --- a/diplomat/utils/tweak_ui.py +++ b/diplomat/utils/tweak_ui.py @@ -243,7 +243,7 @@ def __init__(self): def tweak( self, parent, - video_path: os.PathLike, + video_path: Union[os.PathLike, str], poses: Pose, bodypart_names: List[str], video_metadata: Dict[str, Any], diff --git a/diplomat/utils/video_io.py b/diplomat/utils/video_io.py new file mode 100644 index 0000000..39fb8c1 --- /dev/null +++ b/diplomat/utils/video_io.py @@ -0,0 +1,51 @@ +import functools +import cv2 +from typing import TypeVar, Type + + +T = TypeVar("T") + + +@functools.lru_cache(None) +def _create_cv2_manager(clazz: Type[T]) -> Type[T]: + """ + Create a context manager for a CV2 io writing class. Requires the class implements release for closing a + file resource. + + :param clazz: The cv2 class to subclass, adding support for python context managers to the class. + + :return: A new class, with support for with statements. + """ + class cv2_context_manager: + def __init__(self, *args, **kwargs): + self._inst = clazz(*args, **kwargs) + def __enter__(self): + if(not self.isOpened()): + self.release() + raise IOError("Unable to open video capture...") + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.release() + + def read(self): + if(not self.isOpened()): + raise IOError("Video capture is not open.") + return self._inst.read() + + def write(self, frame): + if (not self.isOpened()): + raise IOError("Video writer is not open.") + return self._inst.write(frame) + + def __getattr__(self, item: str): + return getattr(self._inst, item) + + return cv2_context_manager + + +""" An implementation of cv2.VideoWriter with support for context managers. """ +ContextVideoWriter = _create_cv2_manager(cv2.VideoWriter) + +""" An implementation of cv2.VideoWriter with support for context managers. """ +ContextVideoCapture = _create_cv2_manager(cv2.VideoCapture) From 6ea71c3a41066830165e139b92bfc5e58b0fb6ff Mon Sep 17 00:00:00 2001 From: Isaac Robinson Date: Mon, 5 Jun 2023 19:33:03 -0400 Subject: [PATCH 5/7] Fix csv video labeling to not invert colors. --- diplomat/frontends/csv/label_videos.py | 2 -- diplomat/frontends/sleap/sleap_providers.py | 18 ++++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/diplomat/frontends/csv/label_videos.py b/diplomat/frontends/csv/label_videos.py index 04804c8..a9ea88a 100644 --- a/diplomat/frontends/csv/label_videos.py +++ b/diplomat/frontends/csv/label_videos.py @@ -92,8 +92,6 @@ def _label_videos_single( if(not retval): continue - frame = frame[..., ::-1] - if (visual_settings.upscale_factor is not None): frame = cv2.resize( frame, diff --git a/diplomat/frontends/sleap/sleap_providers.py b/diplomat/frontends/sleap/sleap_providers.py index 0cf5b97..6c89632 100644 --- a/diplomat/frontends/sleap/sleap_providers.py +++ b/diplomat/frontends/sleap/sleap_providers.py @@ -44,7 +44,9 @@ class SleapModelExtractor(ABC): """ Takes a SLEAP Predictor, and modifies it so that it outputs TrackingData instead of SLEAP predictions. """ - supported_models: Optional[Set[Type[SleapPredictor]]] = None + @classmethod + def supported_models(cls) -> Set[SleapPredictor]: + return set() @abstractmethod def __init__(self, model: SleapPredictor): @@ -65,8 +67,8 @@ def _normalize_conf_map(conf_map: tf.Tensor) -> tf.Tensor: class BottomUpModelExtractor(SleapModelExtractor): - @property - def supported_models(self) -> Set[SleapPredictor]: + @classmethod + def supported_models(cls) -> Set[SleapPredictor]: from sleap.nn.inference import BottomUpPredictor, BottomUpMultiClassPredictor return {BottomUpPredictor, BottomUpMultiClassPredictor} @@ -104,8 +106,8 @@ def _extract_model_outputs(inf_layer: SleapInferenceLayer, images: tf.Tensor) -> class TopDownModelExtractor(SleapModelExtractor): - @property - def supported_models(self) -> Set[SleapPredictor]: + @classmethod + def supported_models(cls) -> Set[SleapPredictor]: from sleap.nn.inference import TopDownPredictor, TopDownMultiClassPredictor return {TopDownPredictor, TopDownMultiClassPredictor} @@ -170,8 +172,8 @@ def extract(self, data: Union[Dict, np.ndarray]) -> Tuple[tf.Tensor, Optional[tf class SingleInstanceModelExtractor(SleapModelExtractor): - @property - def supported_models(self) -> Set[SleapPredictor]: + @classmethod + def supported_models(cls) -> Set[SleapPredictor]: from sleap.nn.inference import SingleInstancePredictor return {SingleInstancePredictor} @@ -206,7 +208,7 @@ def __init__(self, predictor: SleapPredictor, refinement_kernel_size: int): self._refinement_kernel_size = refinement_kernel_size for model_extractor in EXTRACTORS: - if(type(predictor) in model_extractor.supported_models): + if(type(predictor) in model_extractor.supported_models()): self._model_extractor = model_extractor(self._predictor) break else: From 41b0b18995842b12ee0d753444dc8e41dae1a1a4 Mon Sep 17 00:00:00 2001 From: Isaac Robinson Date: Tue, 6 Jun 2023 18:01:25 -0400 Subject: [PATCH 6/7] Fix frame bug fixed for errant cases. Some default config values adjusted. Multiplier added to std finding pass. --- diplomat/predictors/fpe/frame_passes/cluster_frames.py | 2 +- diplomat/predictors/fpe/frame_passes/fix_frame.py | 9 +++++++++ diplomat/predictors/fpe/frame_passes/optimize_std.py | 6 +++++- diplomat/predictors/sfpe/segmented_frame_pass_engine.py | 2 +- 4 files changed, 16 insertions(+), 3 deletions(-) diff --git a/diplomat/predictors/fpe/frame_passes/cluster_frames.py b/diplomat/predictors/fpe/frame_passes/cluster_frames.py index e2b8ed6..7de5d30 100644 --- a/diplomat/predictors/fpe/frame_passes/cluster_frames.py +++ b/diplomat/predictors/fpe/frame_passes/cluster_frames.py @@ -296,7 +296,7 @@ def get_config_options(cls) -> ConfigSpec: return { "minimum_cluster_size": ( - 0.25, float, "The minimum size a cluster is allowed to be (As compared to average of all clusters)." + 0.10, float, "The minimum size a cluster is allowed to be (As compared to average of all clusters)." "If the cluster is smaller, it get thrown out and a forest is resolved using the rest of" "the data." ), diff --git a/diplomat/predictors/fpe/frame_passes/fix_frame.py b/diplomat/predictors/fpe/frame_passes/fix_frame.py index 8b5e8fc..e0d66e2 100644 --- a/diplomat/predictors/fpe/frame_passes/fix_frame.py +++ b/diplomat/predictors/fpe/frame_passes/fix_frame.py @@ -191,6 +191,15 @@ def create_fix_frame( # Copy over data to start, ignoring skeleton... for bp_i in range(fb_data.num_bodyparts): fixed_frame[bp_i] = fb_data.frames[frame_idx][bp_i].copy() + + __, __, prob, __, __ = fixed_frame[bp_i].src_data.unpack() + if(prob is None): + # Fallback fix frame: We just create a single cell with 0 probability, forcing viterbi to use entry + # states... + src_data = SparseTrackingData().pack([0], [0], [0], [0], [0]) + fixed_frame[bp_i].src_data = src_data + fb_data.frames[frame_idx][bp_i].src_data = src_data + fixed_frame[bp_i].disable_occluded = True if(skeleton is not None): diff --git a/diplomat/predictors/fpe/frame_passes/optimize_std.py b/diplomat/predictors/fpe/frame_passes/optimize_std.py index 8f88241..1fd4d49 100644 --- a/diplomat/predictors/fpe/frame_passes/optimize_std.py +++ b/diplomat/predictors/fpe/frame_passes/optimize_std.py @@ -49,7 +49,7 @@ def run_pass( result = super().run_pass(fb_data, prog_bar, in_place, reset_bar) - approx_std = self._histogram.get_quantile(0.5)[2] / self.MAGIC_CONST # Median... + approx_std = (self._histogram.get_quantile(0.5)[2] / self.MAGIC_CONST) * self.config.std_multiplier result.metadata.optimal_std = (*self._histogram.get_bin_for_value(approx_std)[:2], approx_std) if(self.config.DEBUG): @@ -117,5 +117,9 @@ def get_config_options(cls) -> ConfigSpec: "A decimal, the offset of the first bin used in the histogram for computing " "the mode, in pixels. Defaults to 1." ), + "std_multiplier": ( + 3, tc.RangedFloat(0, np.inf), + "A positive float, the computed standard deviation is multiplied by this value before " + ), "DEBUG": (False, bool, "Set to True to print the optimal standard deviation found...") } diff --git a/diplomat/predictors/sfpe/segmented_frame_pass_engine.py b/diplomat/predictors/sfpe/segmented_frame_pass_engine.py index a245fed..b4d2cb6 100644 --- a/diplomat/predictors/sfpe/segmented_frame_pass_engine.py +++ b/diplomat/predictors/sfpe/segmented_frame_pass_engine.py @@ -1294,7 +1294,7 @@ def get_settings(cls) -> ConfigSpec: "Whether or not to allow frame passes to utilize multithreading. Defaults to True." ), "segment_size": ( - 400, + 200, type_casters.RangedInteger(10, np.inf), "The size of the segments in frames to break the video into for tracking." ), From 0c29a3c2a84a4046a056844eb60148e653e7ef0c Mon Sep 17 00:00:00 2001 From: Isaac Robinson Date: Tue, 6 Jun 2023 18:03:19 -0400 Subject: [PATCH 7/7] Version bump. --- diplomat/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diplomat/__init__.py b/diplomat/__init__.py index a69cae9..e5f9038 100644 --- a/diplomat/__init__.py +++ b/diplomat/__init__.py @@ -2,7 +2,7 @@ A tool providing multi-animal tracking capabilities on top of other Deep learning based tracking software. """ -__version__ = "0.0.4" +__version__ = "0.0.5" # Can be used by functions to determine if diplomat was invoked through it's CLI interface. CLI_RUN = False