Skip to content

Commit

Permalink
cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
sronilsson committed Sep 20, 2024
1 parent 842386a commit 6ff22bc
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 22 deletions.
Binary file not shown.
6 changes: 5 additions & 1 deletion simba/data_processors/cuda/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ def is_inside_polygon(x: np.ndarray, y: np.ndarray) -> np.ndarray:
the polygon defined by the vertices in `y`. The result is an array where each element indicates whether
the corresponding point is inside the polygon.
.. image:: _static/img/simba.data_processors.cuda.geometry.is_inside_polygon.webp
:width: 450
:align: center
.. csv-table::
:header: EXPECTED RUNTIMES
:file: ../../../docs/tables/is_inside_polygon.csv
Expand All @@ -144,7 +148,7 @@ def is_inside_polygon(x: np.ndarray, y: np.ndarray) -> np.ndarray:
:header-rows: 1
.. seealso::
For numba CPU function see :func:`~simba.mixins.feature_extraction_mixin.FeatureExtractionMixin.framewise_inside_polygon_roi`
For jitted CPU function see :func:`~simba.mixins.feature_extraction_mixin.FeatureExtractionMixin.framewise_inside_polygon_roi`
:param np.ndarray x: An array of shape (N, 2) where each row represents a point in 2D space. The points are checked against the polygon.
:param np.ndarray y: An array of shape (M, 2) where each row represents a vertex of the polygon in 2D space.
Expand Down
27 changes: 6 additions & 21 deletions simba/mixins/train_model_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1492,19 +1492,9 @@ def _read_data_file_helper(file_path: str,
if clf_names != None:
for clf_name in clf_names:
if not clf_name in df.columns:
raise ColumnNotFoundError(
column_name=clf_name,
file_name=file_path,
source=TrainModelMixin._read_data_file_helper.__name__,
)
elif (
len(set(df[clf_name].unique()) - {0, 1}) > 0
and raise_bool_clf_error
):
raise InvalidInputError(
msg=f"The annotation column for a classifier should contain only 0 or 1 values. However, in file {file_path} the {clf_name} field contains additional value(s): {list(set(df[clf_name].unique()) - {0, 1})}.",
source=TrainModelMixin._read_data_file_helper.__name__,
)
raise MissingColumnsError(msg=f'The SimBA project specifies a classifier named "{clf_name}" that could not be found in your dataset for file {file_path}. Make sure that your project_config.ini is created correctly.', source=TrainModelMixin._read_data_file_helper.__name__)
elif (len(set(df[clf_name].unique()) - {0, 1}) > 0 and raise_bool_clf_error):
raise InvalidInputError(msg=f"The annotation column for a classifier should contain only 0 or 1 values. However, in file {file_path} the {clf_name} field column contains additional value(s): {list(set(df[clf_name].unique()) - {0, 1})}.", source=TrainModelMixin._read_data_file_helper.__name__)
timer.stop_timer()
print(f"Reading complete {vid_name} (elapsed time: {timer.elapsed_time_str}s)...")

Expand Down Expand Up @@ -1599,14 +1589,9 @@ def _read_data_file_helper_futures(
if clf_names != None:
for clf_name in clf_names:
if not clf_name in df.columns:
raise ColumnNotFoundError(column_name=clf_name, file_name=file_path)
elif (
len(set(df[clf_name].unique()) - {0, 1}) > 0
and raise_bool_clf_error
):
raise InvalidInputError(
msg=f"The annotation column for a classifier should contain only 0 or 1 values. However, in file {file_path} the {clf_name} field contains additional value(s): {list(set(df[clf_name].unique()) - {0, 1})}."
)
raise MissingColumnsError(msg=f'The SimBA project specifies a classifier named "{clf_name}" that could not be found in your dataset for file {file_path}. Make sure that your project_config.ini is created correctly.')
elif (len(set(df[clf_name].unique()) - {0, 1}) > 0 and raise_bool_clf_error):
raise InvalidInputError(msg=f"The annotation column for a classifier should contain only 0 or 1 values. However, in file {file_path} the {clf_name} field contains additional value(s): {list(set(df[clf_name].unique()) - {0, 1})}.")
timer.stop_timer()
return df, vid_name, timer.elapsed_time_str, frm_numbers

Expand Down

0 comments on commit 6ff22bc

Please sign in to comment.