Skip to content

Commit

Permalink
Merge pull request #3 from TravisWheelerLab/develop
Browse files Browse the repository at this point in the history
Version 0.0.5
  • Loading branch information
isaacrobinson2000 authored Jun 6, 2023
2 parents b1e93c6 + 0c29a3c commit 5c44a2d
Show file tree
Hide file tree
Showing 34 changed files with 659 additions and 150 deletions.
5 changes: 3 additions & 2 deletions diplomat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
A tool providing multi-animal tracking capabilities on top of other Deep learning based tracking software.
"""

__version__ = "0.0.4"
__version__ = "0.0.5"
# Can be used by functions to determine if diplomat was invoked through it's CLI interface.
CLI_RUN = False

Expand Down Expand Up @@ -61,12 +61,13 @@ def _load_frontends():
if(hasattr(frontend, "__doc__")):
mod.__doc__ = frontend.__doc__

for (name, func) in asdict(res).items():
for (name, func) in res:
if(not name.startswith("_")):
func = replace_function_name_and_module(func, name, mod.__name__)
setattr(mod, name, func)
mod.__all__.append(name)

return frontends, loaded_funcs


_FRONTENDS, _LOADED_FRONTENDS = _load_frontends()
2 changes: 1 addition & 1 deletion diplomat/_cli_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_dynamic_cli_tree() -> dict:

for frontend_name, funcs in diplomat._LOADED_FRONTENDS.items():
frontend_commands = {
name: func for name, func in asdict(funcs).items() if(not name.startswith("_"))
name: func for name, func in funcs if(not name.startswith("_"))
}

doc_str = getattr(getattr(diplomat, frontend_name), "__doc__", None)
Expand Down
50 changes: 38 additions & 12 deletions diplomat/core_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import typing
from types import ModuleType
from diplomat.utils.tweak_ui import UIImportError
from diplomat.frontends import DIPLOMATContract, DIPLOMATCommands


class ArgumentError(CLIError):
Expand Down Expand Up @@ -47,14 +48,21 @@ def _get_casted_args(tc_func, extra_args, error_on_miss=True):
return new_args


def _find_frontend(config: os.PathLike, **kwargs: typing.Any) -> typing.Tuple[str, ModuleType]:
def _find_frontend(
contracts: Union[DIPLOMATContract, List[DIPLOMATContract]],
config: os.PathLike,
**kwargs: typing.Any
) -> typing.Tuple[str, ModuleType]:
from diplomat import _LOADED_FRONTENDS

contracts = [contracts] if(isinstance(contracts, DIPLOMATContract)) else contracts

for name, funcs in _LOADED_FRONTENDS.items():
if(funcs._verifier(
if(all(funcs.verify(
contract=c,
config=config,
**kwargs
)):
) for c in contracts)):
print(f"Frontend '{name}' selected.")
return (name, funcs)

Expand Down Expand Up @@ -178,6 +186,7 @@ def track(
from diplomat import CLI_RUN

selected_frontend_name, selected_frontend = _find_frontend(
contracts=[DIPLOMATCommands.analyze_videos, DIPLOMATCommands.analyze_videos],
config=config,
videos=videos,
frame_stores=frame_stores,
Expand Down Expand Up @@ -324,7 +333,12 @@ def annotate(
from diplomat import CLI_RUN

# Iterate the frontends, looking for one that actually matches our request...
selected_frontend_name, selected_frontend = _find_frontend(config=config, videos=videos, **extra_args)
selected_frontend_name, selected_frontend = _find_frontend(
contracts=DIPLOMATCommands.label_videos,
config=config,
videos=videos,
**extra_args
)

if(help_extra):
_display_help(selected_frontend_name, "video labeling", "diplomat annotate", selected_frontend.label_videos, CLI_RUN)
Expand All @@ -350,18 +364,25 @@ def tweak(
**extra_args
):
"""
Make modifications to DIPLOMAT produced tracking results created for a video using a limited version supervised labeling UI. Allows for touching
up and fixing any minor issues that may arise after tracking and saving results.
Make modifications to DIPLOMAT produced tracking results created for a video using a limited version supervised
labeling UI. Allows for touching up and fixing any minor issues that may arise after tracking and saving results.
:param config: The path to the configuration file for the project. The format of this argument will depend on the frontend.
:param config: The path to the configuration file for the project. The format of this argument will depend on the
frontend.
:param videos: A single path or list of paths to video files to tweak the tracks of.
:param help_extra: Boolean, if set to true print extra settings for the automatically selected frontend instead of showing the UI.
:param extra_args: Any additional arguments (if the CLI, flags starting with '--') are passed to the automatically selected frontend.
To see valid values, run tweak with extra_help flag set to true.
:param help_extra: Boolean, if set to true print extra settings for the automatically selected frontend instead of
showing the UI.
:param extra_args: Any additional arguments (if the CLI, flags starting with '--') are passed to the automatically
selected frontend. To see valid values, run tweak with extra_help flag set to true.
"""
from diplomat import CLI_RUN

selected_frontend_name, selected_frontend = _find_frontend(config=config, videos=videos, **extra_args)
selected_frontend_name, selected_frontend = _find_frontend(
contracts=DIPLOMATCommands.tweak_videos,
config=config,
videos=videos,
**extra_args
)

if(help_extra):
_display_help(selected_frontend_name, "label tweaking", "diplomat tweak", selected_frontend.tweak_videos, CLI_RUN)
Expand Down Expand Up @@ -404,7 +425,12 @@ def convert(
"""
from diplomat import CLI_RUN

selected_frontend_name, selected_frontend = _find_frontend(config=config, videos=videos, **extra_args)
selected_frontend_name, selected_frontend = _find_frontend(
contracts=DIPLOMATCommands.convert_results,
config=config,
videos=videos,
**extra_args
)

if(help_extra):
_display_help(
Expand Down
2 changes: 1 addition & 1 deletion diplomat/frontend_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def list_loaded_frontends():
print("Description:")
print(f"\t{frontend_docs[name]}")
print("Supported Functions:")
for k, v in asdict(funcs).items():
for k, v in funcs:
if(k.startswith("_")):
continue
print(f"\t{k}")
Expand Down
91 changes: 81 additions & 10 deletions diplomat/frontends/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict
from collections import OrderedDict
from diplomat.processing.type_casters import StrictCallable, PathLike, Union, List, Dict, Any, Optional, TypeCaster, NoneType
import typing


class Select(Union):
def __eq__(self, other: TypeCaster):
if(isinstance(other, Union)):
Expand Down Expand Up @@ -62,29 +62,100 @@ def to_type_hint(self) -> typing.Type:
)


@dataclass(frozen=False)
class DIPLOMATBaselineCommands:
@dataclass(frozen=True)
class DIPLOMATContract:
"""
Represents a 'contract'
"""
method_name: str
method_type: StrictCallable


class CommandManager(type):

__no_type_check__ = False
def __new__(cls, *args, **kwargs):
obj = super().__new__(cls, *args, **kwargs)

annotations = typing.get_type_hints(obj)

for name, annot in annotations.items():
if(name in obj.__dict__):
raise TypeError(f"Command annotation '{name}' has default value, which is not allowed.")

return obj

def __getattr__(self, item):
annot = typing.get_type_hints(self)[item]
return DIPLOMATContract(item, annot)


def required(typecaster: TypeCaster) -> TypeCaster:
typecaster._required = True
return typecaster


class DIPLOMATCommands(metaclass=CommandManager):
"""
The baseline set of functions each DIPLOMAT backend must implement. Backends can add additional commands
by extending this base class...
by passing the methods to this classes constructor.
"""
_verifier: VerifierFunction
_verifier: required(VerifierFunction)
analyze_videos: AnalyzeVideosFunction(NoneType)
analyze_frames: AnalyzeFramesFunction(NoneType)
label_videos: LabelVideosFunction(NoneType)
tweak_videos: LabelVideosFunction(NoneType)
convert_results: ConvertResultsFunction(NoneType)

def __post_init__(self):
def __init__(self, **kwargs):
missing = object()
self._commands = OrderedDict()

annotations = typing.get_type_hints(type(self))

for name, value in asdict(self).items():
annot = annotations.get(name, None)
for name, annot in annotations.items():
value = kwargs.get(name, missing)

if(value is missing):
if(getattr(annot, "_required", False)):
raise ValueError(f"Command '{name}' is required, but was not provided.")
continue
if(annot is None or (not isinstance(annot, TypeCaster))):
raise TypeError("DIPLOMAT Command Struct can only contain typecaster types.")

setattr(self, name, annot(value))
self._commands[name] = annot(value)

for name, value in kwargs.items():
if(name not in annotations):
self._commands[name] = value

def __iter__(self):
return iter(self._commands.items())

def __getattr__(self, item: str):
return self._commands.get(item)

def verify(self, contract: DIPLOMATContract, config: Union[List[PathLike], PathLike], **kwargs: Any) -> bool:
"""
Verify this backend can handle the provided command type, config file, and arguments.
:param contract: The contract for the command. Includes the name of the method and the type of the method,
which will typically be a strict callable.
:param config: The configuration file, checks if the backend can handle this configuration file.
:param kwargs: Any additional arguments to pass to the backends verifier.
:return: A boolean, True if the backend can handle the provided command and arguments, otherwise False.
"""
if(contract.method_name in self._commands):
func = self._commands[contract.method_name]
try:
contract.method_type(func)
except Exception:
return False

return self._verifier(config, **kwargs)

return False


class DIPLOMATFrontend(ABC):
Expand All @@ -93,7 +164,7 @@ class DIPLOMATFrontend(ABC):
"""
@classmethod
@abstractmethod
def init(cls) -> typing.Optional[DIPLOMATBaselineCommands]:
def init(cls) -> typing.Optional[DIPLOMATCommands]:
"""
Attempt to initialize the frontend, returning a list of api functions. If the backend can't initialize due to missing imports/requirements,
this function should return None.
Expand Down
29 changes: 29 additions & 0 deletions diplomat/frontends/csv/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Optional
from diplomat.frontends import DIPLOMATFrontend, DIPLOMATCommands


class DEEPLABCUTFrontend(DIPLOMATFrontend):
"""
The CSV frontend for DIPLOMAT. Contains functions for running some DIPLOMAT operations on csv trajectory files.
Supports video creation, and tweak UI commands.
"""
@classmethod
def init(cls) -> Optional[DIPLOMATCommands]:
try:
from diplomat.frontends.csv._verify_func import _verify
from diplomat.frontends.csv.label_videos import label_videos
from diplomat.frontends.csv.tweak_results import tweak_videos
except ImportError:
return None

return DIPLOMATCommands(
_verifier=_verify,
label_videos=label_videos,
tweak_videos=tweak_videos
)

@classmethod
def get_package_name(cls) -> str:
return "csv"


17 changes: 17 additions & 0 deletions diplomat/frontends/csv/_verify_func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from diplomat.processing.type_casters import typecaster_function, Union, List, PathLike
from .csv_utils import _fix_paths, _header_check


@typecaster_function
def _verify(
config: Union[List[PathLike], PathLike],
**kwargs
) -> bool:
if("videos" not in kwargs):
return False

try:
config, videos = _fix_paths(config, kwargs["videos"])
return all(_header_check(c) for c in config)
except (IOError, ValueError):
return False
33 changes: 33 additions & 0 deletions diplomat/frontends/csv/csv_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@


def _header_check(csv):
with open(csv, "r") as csv_handle:
first_lines = [csv_handle.readline().strip("\n").split(",") for i in range(3)]

header_cols = len(first_lines[0])

if(not all(header_cols == len(line) for line in first_lines)):
return False

last_header_line = first_lines[-1]
last_line_exp = ["x", "y", "likelihood"] * (len(last_header_line) // 3)

if(last_header_line != last_line_exp):
return False

return True


def _fix_paths(csvs, videos):
csvs = csvs if(isinstance(csvs, (tuple, list))) else [csvs]
videos = videos if(isinstance(videos, (tuple, list))) else [videos]

if(len(csvs) == 1):
csvs = csvs * len(videos)
if(len(videos) == 1):
videos = videos * len(csvs)

if(len(videos) != len(csvs)):
raise ValueError("Number of videos and csv files passes don't match!")

return csvs, videos
Loading

0 comments on commit 5c44a2d

Please sign in to comment.