From 46f5ff894795a983bf8f1c981890f2ea25267bad Mon Sep 17 00:00:00 2001 From: iback Date: Fri, 9 Aug 2024 14:10:38 +0000 Subject: [PATCH 01/11] added aggregator class, unfinished --- examples/example_spine_statistics.py | 17 +++++ panoptica/__init__.py | 1 + panoptica/instance_evaluator.py | 18 ++--- panoptica/panoptica_aggregator.py | 63 ++++++++++++++++ panoptica/panoptica_evaluator.py | 104 +++++++++++++-------------- 5 files changed, 133 insertions(+), 70 deletions(-) create mode 100644 examples/example_spine_statistics.py create mode 100644 panoptica/panoptica_aggregator.py diff --git a/examples/example_spine_statistics.py b/examples/example_spine_statistics.py new file mode 100644 index 0000000..78c1ff8 --- /dev/null +++ b/examples/example_spine_statistics.py @@ -0,0 +1,17 @@ +from auxiliary.nifti.io import read_nifti +from auxiliary.turbopath import turbopath + +from panoptica import Panoptica_Evaluator, Panoptica_Aggregator + +directory = turbopath(__file__).parent + +reference_mask = read_nifti(directory + "/spine_seg/matched_instance/ref.nii.gz") +prediction_mask = read_nifti(directory + "/spine_seg/matched_instance/pred.nii.gz") + +evaluator = Panoptica_Aggregator( + Panoptica_Evaluator.load_from_config_name("panoptica_evaluator_unmatched_instance"), +) + + +if __name__ == "__main__": + results = evaluator.evaluate(prediction_mask, reference_mask, verbose=False) diff --git a/panoptica/__init__.py b/panoptica/__init__.py index 2f02659..04a89d6 100644 --- a/panoptica/__init__.py +++ b/panoptica/__init__.py @@ -3,6 +3,7 @@ CCABackend, ) from panoptica.instance_matcher import NaiveThresholdMatching +from panoptica.panoptica_aggregator import Panoptica_Aggregator from panoptica.panoptica_evaluator import Panoptica_Evaluator from panoptica.panoptica_result import PanopticaResult from panoptica.utils.processing_pair import ( diff --git a/panoptica/instance_evaluator.py b/panoptica/instance_evaluator.py index 303abf0..9b784e7 100644 --- a/panoptica/instance_evaluator.py +++ b/panoptica/instance_evaluator.py @@ -28,9 +28,7 @@ def evaluate_matched_instance( >>> result = map_instance_labels(unmatched_instance_pair, labelmap) """ if decision_metric is not None: - assert decision_metric.name in [ - v.name for v in eval_metrics - ], "decision metric not contained in eval_metrics" + assert decision_metric.name in [v.name for v in eval_metrics], "decision metric not contained in eval_metrics" assert decision_threshold is not None, "decision metric set but no threshold" # Initialize variables for True Positives (tp) tp = len(matched_instance_pair.matched_instances) @@ -42,21 +40,13 @@ def evaluate_matched_instance( ) ref_matched_labels = matched_instance_pair.matched_instances - instance_pairs = [ - (reference_arr, prediction_arr, ref_idx, eval_metrics) - for ref_idx in ref_matched_labels - ] + instance_pairs = [(reference_arr, prediction_arr, ref_idx, eval_metrics) for ref_idx in ref_matched_labels] with Pool() as pool: - metric_dicts: list[dict[Metric, float]] = pool.starmap( - _evaluate_instance, instance_pairs - ) + metric_dicts: list[dict[Metric, float]] = pool.starmap(_evaluate_instance, instance_pairs) for metric_dict in metric_dicts: if decision_metric is None or ( - decision_threshold is not None - and decision_metric.score_beats_threshold( - metric_dict[decision_metric], decision_threshold - ) + decision_threshold is not None and decision_metric.score_beats_threshold(metric_dict[decision_metric], decision_threshold) ): for k, v in metric_dict.items(): score_dict[k].append(v) diff --git 
a/panoptica/panoptica_aggregator.py b/panoptica/panoptica_aggregator.py new file mode 100644 index 0000000..c297bb9 --- /dev/null +++ b/panoptica/panoptica_aggregator.py @@ -0,0 +1,63 @@ +import numpy as np +from panoptica.panoptica_evaluator import Panoptica_Evaluator +from panoptica.panoptica_result import PanopticaResult +from dataclasses import dataclass + + +@dataclass +class NamedPanopticaResultGroup: + name: str + group2result: dict[str, PanopticaResult] + + +# Mean over instances +# mean over subjects +# give below/above percentile of metric (the names) +# make plot with metric dots +# make auc curve as plot +class Panoptica_Aggregator: + """Aggregator that calls evaluations and saves the resulting metrics per sample. Can be used to create statistics, ...""" + + def __init__(self, panoptica_evaluator: Panoptica_Evaluator): + self._panoptica_evaluator = panoptica_evaluator + self._group2named_results: dict[str, list[NamedPanopticaResultGroup]] = {} + self._n_samples = 0 + + def evaluate( + self, + prediction_arr: np.ndarray, + reference_arr: np.ndarray, + subject_name: str | None = None, + verbose: bool | None = None, + ): + """Evaluates one case + + Args: + prediction_arr (np.ndarray): Prediction array + reference_arr (np.ndarray): reference array + subject_name (str | None, optional): Unique name of the sample. If none, will give it a name based on count. Defaults to None. + verbose (bool | None, optional): Verbose. Defaults to None. + """ + if subject_name is None: + subject_name = f"Sample_{self._n_samples}" + + res = self._panoptica_evaluator.evaluate( + prediction_arr, + reference_arr, + result_all=True, + verbose=verbose, + ) + for k, v in res.items(): + if k not in self._group2named_results: + self._group2named_results[k] = [] + result_obj, _ = v + self._group2named_results[k].append(NamedPanopticaResultGroup(subject_name, result_obj)) + + self._n_samples += 1 + + def save_results(): + # save to excel + pass + + def load_results(): + pass diff --git a/panoptica/panoptica_evaluator.py b/panoptica/panoptica_evaluator.py index f0522f8..e0e7f04 100644 --- a/panoptica/panoptica_evaluator.py +++ b/panoptica/panoptica_evaluator.py @@ -1,5 +1,4 @@ from time import perf_counter -from typing import Type from panoptica.instance_approximator import InstanceApproximator from panoptica.instance_evaluator import evaluate_matched_instance @@ -70,13 +69,9 @@ def __init__( self.__segmentation_class_groups = segmentation_class_groups - self.__edge_case_handler = ( - edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() - ) + self.__edge_case_handler = edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() if self.__decision_metric is not None: - assert ( - self.__decision_threshold is not None - ), "decision metric set but no decision threshold for it" + assert self.__decision_threshold is not None, "decision metric set but no decision threshold for it" # self.__log_times = log_times self.__verbose = verbose @@ -107,9 +102,7 @@ def evaluate( verbose: bool | None = None, ) -> dict[str, tuple[PanopticaResult, dict[str, _ProcessingPair]]]: processing_pair = self.__expected_input(prediction_arr, reference_arr) - assert isinstance( - processing_pair, self.__expected_input.value - ), f"input not of expected type {self.__expected_input}" + assert isinstance(processing_pair, self.__expected_input.value), f"input not of expected type {self.__expected_input}" if self.__segmentation_class_groups is None: return { @@ -129,53 +122,56 @@ def evaluate( ) } - 
self.__segmentation_class_groups.has_defined_labels_for( - processing_pair.prediction_arr, raise_error=True - ) - self.__segmentation_class_groups.has_defined_labels_for( - processing_pair.reference_arr, raise_error=True - ) + self.__segmentation_class_groups.has_defined_labels_for(processing_pair.prediction_arr, raise_error=True) + self.__segmentation_class_groups.has_defined_labels_for(processing_pair.reference_arr, raise_error=True) result_grouped = {} for group_name, label_group in self.__segmentation_class_groups.items(): - assert isinstance(label_group, LabelGroup) - - prediction_arr_grouped = label_group(processing_pair.prediction_arr) - reference_arr_grouped = label_group(processing_pair.reference_arr) - - single_instance_mode = label_group.single_instance - processing_pair_grouped = processing_pair.__class__(prediction_arr=prediction_arr_grouped, reference_arr=reference_arr_grouped) # type: ignore - decision_threshold = self.__decision_threshold - if single_instance_mode and not isinstance( - processing_pair, MatchedInstancePair - ): - processing_pair_grouped = MatchedInstancePair( - prediction_arr=processing_pair_grouped.prediction_arr, - reference_arr=processing_pair_grouped.reference_arr, - ) - decision_threshold = 0.0 - - result_grouped[group_name] = panoptic_evaluate( - processing_pair=processing_pair_grouped, - edge_case_handler=self.__edge_case_handler, - instance_approximator=self.__instance_approximator, - instance_matcher=self.__instance_matcher, - instance_metrics=self.__eval_metrics, - global_metrics=self.__global_metrics, - decision_metric=self.__decision_metric, - decision_threshold=decision_threshold, - result_all=result_all, - log_times=self.__log_times, - verbose=True if verbose is None else verbose, - verbose_calc=self.__verbose if verbose is None else verbose, - ) + result_grouped[group_name] = self._evaluate_group(group_name, label_group, processing_pair, result_all, verbose)[1:] return result_grouped + def _evaluate_group( + self, + group_name: str, + label_group: LabelGroup, + processing_pair, + result_all: bool = True, + verbose: bool | None = None, + ): + assert isinstance(label_group, LabelGroup) + + prediction_arr_grouped = label_group(processing_pair.prediction_arr) + reference_arr_grouped = label_group(processing_pair.reference_arr) + + single_instance_mode = label_group.single_instance + processing_pair_grouped = processing_pair.__class__(prediction_arr=prediction_arr_grouped, reference_arr=reference_arr_grouped) # type: ignore + decision_threshold = self.__decision_threshold + if single_instance_mode and not isinstance(processing_pair, MatchedInstancePair): + processing_pair_grouped = MatchedInstancePair( + prediction_arr=processing_pair_grouped.prediction_arr, + reference_arr=processing_pair_grouped.reference_arr, + ) + decision_threshold = 0.0 + + result, debug_data = panoptic_evaluate( + processing_pair=processing_pair_grouped, + edge_case_handler=self.__edge_case_handler, + instance_approximator=self.__instance_approximator, + instance_matcher=self.__instance_matcher, + instance_metrics=self.__eval_metrics, + global_metrics=self.__global_metrics, + decision_metric=self.__decision_metric, + decision_threshold=decision_threshold, + result_all=result_all, + log_times=self.__log_times, + verbose=True if verbose is None else verbose, + verbose_calc=self.__verbose if verbose is None else verbose, + ) + return group_name, result, debug_data + def panoptic_evaluate( - processing_pair: ( - SemanticPair | UnmatchedInstancePair | MatchedInstancePair | 
PanopticaResult - ), + processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult, instance_approximator: InstanceApproximator | None = None, instance_matcher: InstanceMatchingAlgorithm | None = None, instance_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], @@ -232,9 +228,7 @@ def panoptic_evaluate( processing_pair.crop_data() if isinstance(processing_pair, SemanticPair): - assert ( - instance_approximator is not None - ), "Got SemanticPair but not InstanceApproximator" + assert instance_approximator is not None, "Got SemanticPair but not InstanceApproximator" if verbose: print("-- Got SemanticPair, will approximate instances") start = perf_counter() @@ -255,9 +249,7 @@ def panoptic_evaluate( if isinstance(processing_pair, UnmatchedInstancePair): if verbose: print("-- Got UnmatchedInstancePair, will match instances") - assert ( - instance_matcher is not None - ), "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" + assert instance_matcher is not None, "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" start = perf_counter() processing_pair = instance_matcher.match_instances( processing_pair, From 7c4a2da0c649c25c515e33b427b75f3b5cf11169 Mon Sep 17 00:00:00 2001 From: iback Date: Wed, 28 Aug 2024 10:13:30 +0000 Subject: [PATCH 02/11] updated gitignore to ignore png, tsv, csv, niftis, and jsons --- .gitignore | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.gitignore b/.gitignore index e955bf1..27873e8 100644 --- a/.gitignore +++ b/.gitignore @@ -162,3 +162,8 @@ cython_debug/ .DS_Store .vscode +*.png +*.tsv +*.csv +*.nii.gz +*.json \ No newline at end of file From bb7be7c2c6c39cd5d9849c918c9f1ed5ebe21e27 Mon Sep 17 00:00:00 2001 From: iback Date: Wed, 28 Aug 2024 10:15:00 +0000 Subject: [PATCH 03/11] fixed global metrics using instance edge case handling --- panoptica/panoptica_result.py | 51 +++++++++++++---------------------- 1 file changed, 19 insertions(+), 32 deletions(-) diff --git a/panoptica/panoptica_result.py b/panoptica/panoptica_result.py index b1dd671..1b7b60d 100644 --- a/panoptica/panoptica_result.py +++ b/panoptica/panoptica_result.py @@ -252,25 +252,19 @@ def __init__( num_pred_instances=self.num_pred_instances, num_ref_instances=self.num_ref_instances, ) - self._list_metrics[m] = Evaluation_List_Metric( - m, empty_list_std, list_metrics[m], is_edge_case, edge_case_result - ) + self._list_metrics[m] = Evaluation_List_Metric(m, empty_list_std, list_metrics[m], is_edge_case, edge_case_result) # even if not available, set the global vars default_value = None was_calculated = False if m in self._global_metrics and arrays_present: - default_value = self._calc_global_bin_metric( - m, pred_binary, ref_binary, do_binarize=False - ) + default_value = self._calc_global_bin_metric(m, pred_binary, ref_binary, do_binarize=False) was_calculated = True self._add_metric( f"global_bin_{m.name.lower()}", MetricType.GLOBAL, - lambda x: MetricCouldNotBeComputedException( - f"Global Metric {m} not set" - ), + lambda x: MetricCouldNotBeComputedException(f"Global Metric {m} not set"), long_name="Global Binary " + m.value.long_name, default_value=default_value, was_calculated=was_calculated, @@ -285,12 +279,6 @@ def _calc_global_bin_metric( ): if metric not in self._global_metrics: raise MetricCouldNotBeComputedException(f"Global Metric {metric} not set") - if self.tp == 0: - is_edgecase, result = self._edge_case_handler.handle_zero_tp( - metric, self.tp, self.num_pred_instances, self.num_ref_instances - ) - if 
is_edgecase: - return result if do_binarize: pred_binary = prediction_arr.copy() @@ -301,6 +289,13 @@ def _calc_global_bin_metric( pred_binary = prediction_arr ref_binary = reference_arr + prediction_empty = pred_binary.sum() == 0 + reference_empty = ref_binary.sum() == 0 + if prediction_empty or reference_empty: + is_edgecase, result = self._edge_case_handler.handle_zero_tp(metric, 0, int(prediction_empty), int(reference_empty)) + if is_edgecase: + return result + return metric( reference_arr=ref_binary, prediction_arr=pred_binary, @@ -383,19 +378,17 @@ def __str__(self) -> str: return text def to_dict(self) -> dict: - return { - k: getattr(self, v.id) - for k, v in self._evaluation_metrics.items() - if (v._error == False and v._was_calculated) - } + return {k: getattr(self, v.id) for k, v in self._evaluation_metrics.items() if (v._error == False and v._was_calculated)} + + @property + def evaluation_metrics(self): + return self._evaluation_metrics def get_list_metric(self, metric: Metric, mode: MetricMode): if metric in self._list_metrics: return self._list_metrics[metric][mode] else: - raise MetricCouldNotBeComputedException( - f"{metric} could not be found, have you set it in eval_metrics during evaluation?" - ) + raise MetricCouldNotBeComputedException(f"{metric} could not be found, have you set it in eval_metrics during evaluation?") def _calc_metric(self, metric_name: str, supress_error: bool = False): if metric_name in self._evaluation_metrics: @@ -411,9 +404,7 @@ def _calc_metric(self, metric_name: str, supress_error: bool = False): self._evaluation_metrics[metric_name]._was_calculated = True return value else: - raise MetricCouldNotBeComputedException( - f"could not find metric with name {metric_name}" - ) + raise MetricCouldNotBeComputedException(f"could not find metric with name {metric_name}") def __getattribute__(self, __name: str) -> Any: attr = None @@ -428,9 +419,7 @@ def __getattribute__(self, __name: str) -> Any: raise e if attr is None: if self._evaluation_metrics[__name]._error: - raise MetricCouldNotBeComputedException( - f"Requested metric {__name} that could not be computed" - ) + raise MetricCouldNotBeComputedException(f"Requested metric {__name} that could not be computed") elif not self._evaluation_metrics[__name]._was_calculated: value = self._calc_metric(__name) setattr(self, __name, value) @@ -556,9 +545,7 @@ def function_template(res: PanopticaResult): if metric not in res._global_metrics: raise MetricCouldNotBeComputedException(f"Global Metric {metric} not set") if res.tp == 0: - is_edgecase, result = res._edge_case_handler.handle_zero_tp( - metric, res.tp, res.num_pred_instances, res.num_ref_instances - ) + is_edgecase, result = res._edge_case_handler.handle_zero_tp(metric, res.tp, res.num_pred_instances, res.num_ref_instances) if is_edgecase: return result pred_binary = res._prediction_arr.copy() From e891ae551bae817bb5e182e4956e60af5d667977 Mon Sep 17 00:00:00 2001 From: iback Date: Wed, 28 Aug 2024 10:16:52 +0000 Subject: [PATCH 04/11] fixed issue where instance approximation needs bigger uint dtype --- panoptica/instance_approximator.py | 41 ++++++++++-------------------- 1 file changed, 13 insertions(+), 28 deletions(-) diff --git a/panoptica/instance_approximator.py b/panoptica/instance_approximator.py index f8e061f..78f0f7e 100644 --- a/panoptica/instance_approximator.py +++ b/panoptica/instance_approximator.py @@ -42,9 +42,7 @@ class InstanceApproximator(SupportsConfig, metaclass=ABCMeta): """ @abstractmethod - def _approximate_instances( - self, 
semantic_pair: SemanticPair, **kwargs - ) -> UnmatchedInstancePair | MatchedInstancePair: + def _approximate_instances(self, semantic_pair: SemanticPair, **kwargs) -> UnmatchedInstancePair | MatchedInstancePair: """ Abstract method to be implemented by subclasses for instance approximation. @@ -58,9 +56,7 @@ def _approximate_instances( pass def _yaml_repr(cls, node) -> dict: - raise NotImplementedError( - f"Tried to get yaml representation of abstract class {cls.__name__}" - ) + raise NotImplementedError(f"Tried to get yaml representation of abstract class {cls.__name__}") return {} def approximate_instances( @@ -81,19 +77,11 @@ def approximate_instances( """ # Check validity pred_labels, ref_labels = semantic_pair._pred_labels, semantic_pair._ref_labels - pred_label_range = ( - (np.min(pred_labels), np.max(pred_labels)) - if len(pred_labels) > 0 - else (0, 0) - ) - ref_label_range = ( - (np.min(ref_labels), np.max(ref_labels)) if len(ref_labels) > 0 else (0, 0) - ) + pred_label_range = (np.min(pred_labels), np.max(pred_labels)) if len(pred_labels) > 0 else (0, 0) + ref_label_range = (np.min(ref_labels), np.max(ref_labels)) if len(ref_labels) > 0 else (0, 0) # min_value = min(np.min(pred_label_range[0]), np.min(ref_label_range[0])) - assert ( - min_value >= 0 - ), "There are negative values in the semantic maps. This is not allowed!" + assert min_value >= 0, "There are negative values in the semantic maps. This is not allowed!" # Set dtype to smalles fitting uint max_value = max(np.max(pred_label_range[1]), np.max(ref_label_range[1])) dtype = _get_smallest_fitting_uint(max_value) @@ -133,9 +121,7 @@ def __init__(self, cca_backend: CCABackend | None = None) -> None: """ self.cca_backend = cca_backend - def _approximate_instances( - self, semantic_pair: SemanticPair, **kwargs - ) -> UnmatchedInstancePair: + def _approximate_instances(self, semantic_pair: SemanticPair, **kwargs) -> UnmatchedInstancePair: """ Approximate instances using the connected components algorithm. 
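The dtype change applied in the next hunk exists because connected-component labelling can produce more instance labels than the original semantic dtype can hold (a binary uint8 mask can easily yield several hundred components). The fix picks the smallest unsigned integer type that fits the largest label via _get_smallest_fitting_uint; the snippet below is a minimal sketch of that selection logic, not panoptica's actual helper, whose implementation may differ in detail.

import numpy as np

def smallest_fitting_uint(max_value: int) -> type:
    # Narrowest unsigned integer dtype that can still represent max_value.
    for dtype in (np.uint8, np.uint16, np.uint32, np.uint64):
        if max_value <= np.iinfo(dtype).max:
            return dtype
    raise ValueError(f"{max_value} does not fit into any numpy uint dtype")

# 300 connected components overflow uint8, so the instance arrays are cast up.
print(smallest_fitting_uint(300))  # <class 'numpy.uint16'>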
@@ -148,9 +134,7 @@ def _approximate_instances( """ cca_backend = self.cca_backend if cca_backend is None: - cca_backend = ( - CCABackend.cc3d if semantic_pair.n_dim >= 3 else CCABackend.scipy - ) + cca_backend = CCABackend.cc3d if semantic_pair.n_dim >= 3 else CCABackend.scipy assert cca_backend is not None empty_prediction = len(semantic_pair._pred_labels) == 0 @@ -161,13 +145,14 @@ def _approximate_instances( else (semantic_pair._prediction_arr, 0) ) reference_arr, n_reference_instance = ( - _connected_components(semantic_pair._reference_arr, cca_backend) - if not empty_reference - else (semantic_pair._reference_arr, 0) + _connected_components(semantic_pair._reference_arr, cca_backend) if not empty_reference else (semantic_pair._reference_arr, 0) ) + + dtype = _get_smallest_fitting_uint(max(prediction_arr.max(), reference_arr.max())) + return UnmatchedInstancePair( - prediction_arr=prediction_arr, - reference_arr=reference_arr, + prediction_arr=prediction_arr.astype(dtype), + reference_arr=reference_arr.astype(dtype), n_prediction_instance=n_prediction_instance, n_reference_instance=n_reference_instance, ) From d57d89b3756b9c634f29e9afaf9d90dcfc30d5cf Mon Sep 17 00:00:00 2001 From: iback Date: Wed, 28 Aug 2024 10:18:02 +0000 Subject: [PATCH 05/11] fixed issue where instance approximation needs bigger uint dtype --- panoptica/_functionals.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/panoptica/_functionals.py b/panoptica/_functionals.py index 5344d20..e3c81bc 100644 --- a/panoptica/_functionals.py +++ b/panoptica/_functionals.py @@ -33,11 +33,7 @@ def _calc_overlapping_labels( # instance_pairs = [(reference_arr, prediction_arr, i, j) for i, j in overlapping_indices] # (ref, pred) - return [ - (int(i % (max_ref)), int(i // (max_ref))) - for i in np.unique(overlap_arr) - if i > max_ref - ] + return [(int(i % (max_ref)), int(i // (max_ref))) for i in np.unique(overlap_arr) if i > max_ref] def _calc_matching_metric_of_overlapping_labels( @@ -67,13 +63,8 @@ def _calc_matching_metric_of_overlapping_labels( with Pool() as pool: mm_values = pool.starmap(matching_metric.value, instance_pairs) - mm_pairs = [ - (i, (instance_pairs[idx][2], instance_pairs[idx][3])) - for idx, i in enumerate(mm_values) - ] - mm_pairs = sorted( - mm_pairs, key=lambda x: x[0], reverse=not matching_metric.decreasing - ) + mm_pairs = [(i, (instance_pairs[idx][2], instance_pairs[idx][3])) for idx, i in enumerate(mm_values)] + mm_pairs = sorted(mm_pairs, key=lambda x: x[0], reverse=not matching_metric.decreasing) return mm_pairs @@ -133,7 +124,7 @@ def _connected_components( else: raise NotImplementedError(cca_backend) - return cc_arr.astype(array.dtype), n_instances + return cc_arr, n_instances def _get_paired_crop( From 3084a58f2ce721e01436281be82bf6026a932af7 Mon Sep 17 00:00:00 2001 From: iback Date: Wed, 28 Aug 2024 10:18:36 +0000 Subject: [PATCH 06/11] removed comment --- panoptica/instance_evaluator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/panoptica/instance_evaluator.py b/panoptica/instance_evaluator.py index 6ad16da..cdb3e77 100644 --- a/panoptica/instance_evaluator.py +++ b/panoptica/instance_evaluator.py @@ -40,6 +40,7 @@ def evaluate_matched_instance( with Pool() as pool: metric_dicts: list[dict[Metric, float]] = pool.starmap(_evaluate_instance, instance_pairs) + # TODO if instance matcher already gives matching metric, adapt here! 
for metric_dict in metric_dicts: if decision_metric is None or ( decision_threshold is not None and decision_metric.score_beats_threshold(metric_dict[decision_metric], decision_threshold) From 1ed2fa05b78e0e6946fce62f647a3d8581ea0acf Mon Sep 17 00:00:00 2001 From: iback Date: Wed, 28 Aug 2024 10:19:28 +0000 Subject: [PATCH 07/11] fixed EdgeCaseResult NAN not pickable bug, refined aggregator to first working prototype that can also make figures. --- examples/example_spine_semantic.py | 12 +- examples/example_spine_statistics.py | 59 +++- panoptica/__init__.py | 1 + .../configs/panoptica_evaluator_BRATS.yaml | 30 ++ .../configs/panoptica_evaluator_ISLES.yaml | 30 ++ panoptica/instance_matcher.py | 1 + panoptica/panoptica_aggregator.py | 175 +++++++++-- panoptica/panoptica_evaluator.py | 42 ++- panoptica/panoptica_statistics.py | 286 ++++++++++++++++++ panoptica/utils/__init__.py | 1 + panoptica/utils/edge_case_handling.py | 58 ++-- panoptica/utils/parallel_processing.py | 36 +++ panoptica/utils/processing_pair.py | 84 ++--- unit_tests/test_metrics.py | 52 ++++ unit_tests/test_panoptic_evaluator.py | 24 ++ 15 files changed, 749 insertions(+), 142 deletions(-) create mode 100644 panoptica/configs/panoptica_evaluator_BRATS.yaml create mode 100644 panoptica/configs/panoptica_evaluator_ISLES.yaml create mode 100644 panoptica/panoptica_statistics.py create mode 100644 panoptica/utils/parallel_processing.py diff --git a/examples/example_spine_semantic.py b/examples/example_spine_semantic.py index e2f45b1..888f618 100644 --- a/examples/example_spine_semantic.py +++ b/examples/example_spine_semantic.py @@ -26,9 +26,7 @@ def main(): with cProfile.Profile() as pr: - result, intermediate_steps_data = evaluator.evaluate( - prediction_mask, reference_mask - )["ungrouped"] + result, intermediate_steps_data = evaluator.evaluate(prediction_mask, reference_mask)["ungrouped"] # To print the results, just call print print(result) @@ -37,12 +35,8 @@ def main(): intermediate_steps_data.original_prediction_arr # Input prediction array, untouched intermediate_steps_data.original_reference_arr # Input reference array, untouched - intermediate_steps_data.prediction_arr( - InputType.MATCHED_INSTANCE - ) # Prediction array after instances have been matched - intermediate_steps_data.reference_arr( - InputType.MATCHED_INSTANCE - ) # Reference array after instances have been matched + intermediate_steps_data.prediction_arr(InputType.MATCHED_INSTANCE) # Prediction array after instances have been matched + intermediate_steps_data.reference_arr(InputType.MATCHED_INSTANCE) # Reference array after instances have been matched pr.dump_stats(directory + "/semantic_example.log") return result, intermediate_steps_data diff --git a/examples/example_spine_statistics.py b/examples/example_spine_statistics.py index 78c1ff8..c7ca6aa 100644 --- a/examples/example_spine_statistics.py +++ b/examples/example_spine_statistics.py @@ -1,7 +1,16 @@ from auxiliary.nifti.io import read_nifti from auxiliary.turbopath import turbopath -from panoptica import Panoptica_Evaluator, Panoptica_Aggregator +from panoptica import Panoptica_Evaluator, Panoptica_Aggregator, make_curve_over_setups +from pathlib import Path +from panoptica.utils import NonDaemonicPool +from joblib import delayed, Parallel +from concurrent.futures import ProcessPoolExecutor, as_completed +from tqdm import tqdm +from multiprocessing import set_start_method + + +# set_start_method("fork") directory = turbopath(__file__).parent @@ -10,8 +19,54 @@ evaluator = Panoptica_Aggregator( 
Panoptica_Evaluator.load_from_config_name("panoptica_evaluator_unmatched_instance"), + Path(__file__).parent.joinpath("spine_example.tsv"), ) if __name__ == "__main__": - results = evaluator.evaluate(prediction_mask, reference_mask, verbose=False) + parallel_opt = "None" # none, pool, joblib, future + # + parallel_opt = parallel_opt.lower() + + if parallel_opt == "pool": + args = [ + (prediction_mask, reference_mask, "sample1"), + (prediction_mask, reference_mask, "sample2"), + (prediction_mask, reference_mask, "sample3"), + (prediction_mask, reference_mask, "sample4"), + ] + with NonDaemonicPool() as pool: + pool.starmap(evaluator.evaluate, args) + elif parallel_opt == "none": + for i in range(4): + results = evaluator.evaluate(prediction_mask, reference_mask, f"sample{i}") + elif parallel_opt == "joblib": + Parallel(n_jobs=4, backend="threading")(delayed(evaluator.evaluate)(prediction_mask, reference_mask) for i in range(4)) + elif parallel_opt == "future": + with ProcessPoolExecutor() as executor: + futures = {executor.submit(evaluator.evaluate, prediction_mask, reference_mask) for i in range(4)} + for future in tqdm(as_completed(futures), total=len(futures), desc="Panoptica Evaluation"): + result = future.result() + if result is not None: + print("Done") + + panoptic_statistic = evaluator.make_statistic() + panoptic_statistic.print_summary() + + fig = panoptic_statistic.get_summary_figure("sq_dsc", horizontal=True) + out_figure = str(Path(__file__).parent.joinpath("example_sq_dsc_figure.png")) + fig.write_image(out_figure) + + fig2 = make_curve_over_setups( + { + "t0.5": panoptic_statistic, + "bad": panoptic_statistic, + "good classifier": panoptic_statistic, + 2.0: panoptic_statistic, + }, + groups=None, + metric="pq", + ) + + out_figure = str(Path(__file__).parent.joinpath("example_multiple_statistics.png")) + fig2.savefig(out_figure) diff --git a/panoptica/__init__.py b/panoptica/__init__.py index 04a89d6..41648f0 100644 --- a/panoptica/__init__.py +++ b/panoptica/__init__.py @@ -3,6 +3,7 @@ CCABackend, ) from panoptica.instance_matcher import NaiveThresholdMatching +from panoptica.panoptica_statistics import Panoptica_Statistic, make_curve_over_setups from panoptica.panoptica_aggregator import Panoptica_Aggregator from panoptica.panoptica_evaluator import Panoptica_Evaluator from panoptica.panoptica_result import PanopticaResult diff --git a/panoptica/configs/panoptica_evaluator_BRATS.yaml b/panoptica/configs/panoptica_evaluator_BRATS.yaml new file mode 100644 index 0000000..64573f8 --- /dev/null +++ b/panoptica/configs/panoptica_evaluator_BRATS.yaml @@ -0,0 +1,30 @@ +!Panoptica_Evaluator +decision_metric: null +decision_threshold: null +edge_case_handler: !EdgeCaseHandler + empty_list_std: !EdgeCaseResult NAN + listmetric_zeroTP_handling: + !Metric DSC: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO, + empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN, + normal: !EdgeCaseResult ZERO} + !Metric clDSC: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO, + empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN, + normal: !EdgeCaseResult ZERO} + !Metric IOU: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO, + empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN, + normal: !EdgeCaseResult ZERO} + !Metric ASSD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult INF, + 
empty_reference_result: !EdgeCaseResult INF, no_instances_result: !EdgeCaseResult NAN, + normal: !EdgeCaseResult INF} + !Metric RVD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult NAN, + empty_reference_result: !EdgeCaseResult NAN, no_instances_result: !EdgeCaseResult NAN, + normal: !EdgeCaseResult NAN} +instance_metrics: [!Metric DSC, !Metric IOU, !Metric ASSD, !Metric RVD] +global_metrics: [!Metric DSC, !Metric RVD, !Metric IOU] +expected_input: !InputType SEMANTIC +instance_approximator: !ConnectedComponentsInstanceApproximator {cca_backend: null} +instance_matcher: !NaiveThresholdMatching {allow_many_to_one: false, matching_metric: !Metric IOU, + matching_threshold: 0.5} +log_times: true +segmentation_class_groups: null +verbose: false diff --git a/panoptica/configs/panoptica_evaluator_ISLES.yaml b/panoptica/configs/panoptica_evaluator_ISLES.yaml new file mode 100644 index 0000000..64573f8 --- /dev/null +++ b/panoptica/configs/panoptica_evaluator_ISLES.yaml @@ -0,0 +1,30 @@ +!Panoptica_Evaluator +decision_metric: null +decision_threshold: null +edge_case_handler: !EdgeCaseHandler + empty_list_std: !EdgeCaseResult NAN + listmetric_zeroTP_handling: + !Metric DSC: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO, + empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN, + normal: !EdgeCaseResult ZERO} + !Metric clDSC: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO, + empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN, + normal: !EdgeCaseResult ZERO} + !Metric IOU: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO, + empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN, + normal: !EdgeCaseResult ZERO} + !Metric ASSD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult INF, + empty_reference_result: !EdgeCaseResult INF, no_instances_result: !EdgeCaseResult NAN, + normal: !EdgeCaseResult INF} + !Metric RVD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult NAN, + empty_reference_result: !EdgeCaseResult NAN, no_instances_result: !EdgeCaseResult NAN, + normal: !EdgeCaseResult NAN} +instance_metrics: [!Metric DSC, !Metric IOU, !Metric ASSD, !Metric RVD] +global_metrics: [!Metric DSC, !Metric RVD, !Metric IOU] +expected_input: !InputType SEMANTIC +instance_approximator: !ConnectedComponentsInstanceApproximator {cca_backend: null} +instance_matcher: !NaiveThresholdMatching {allow_many_to_one: false, matching_metric: !Metric IOU, + matching_threshold: 0.5} +log_times: true +segmentation_class_groups: null +verbose: false diff --git a/panoptica/instance_matcher.py b/panoptica/instance_matcher.py index 5bedd9e..25534c0 100644 --- a/panoptica/instance_matcher.py +++ b/panoptica/instance_matcher.py @@ -212,6 +212,7 @@ def _match_instances( and not self._allow_many_to_one ): continue # -> doesnt make speed difference + # TODO always go in here, but add the matching score to the pair (so evaluation over multiple thresholds becomes easy) if self._matching_metric.score_beats_threshold( matching_score, self._matching_threshold ): diff --git a/panoptica/panoptica_aggregator.py b/panoptica/panoptica_aggregator.py index c297bb9..269b770 100644 --- a/panoptica/panoptica_aggregator.py +++ b/panoptica/panoptica_aggregator.py @@ -1,34 +1,90 @@ import numpy as np +from panoptica.panoptica_statistics import Panoptica_Statistic from panoptica.panoptica_evaluator 
import Panoptica_Evaluator from panoptica.panoptica_result import PanopticaResult -from dataclasses import dataclass +from pathlib import Path +from multiprocessing import Lock, set_start_method +import csv +import os +import atexit +set_start_method("fork") +filelock = Lock() +inevalfilelock = Lock() -@dataclass -class NamedPanopticaResultGroup: - name: str - group2result: dict[str, PanopticaResult] - -# Mean over instances -# mean over subjects -# give below/above percentile of metric (the names) -# make plot with metric dots -# make auc curve as plot +# class Panoptica_Aggregator: + # internal_list_lock = Lock() + # """Aggregator that calls evaluations and saves the resulting metrics per sample. Can be used to create statistics, ...""" - def __init__(self, panoptica_evaluator: Panoptica_Evaluator): - self._panoptica_evaluator = panoptica_evaluator - self._group2named_results: dict[str, list[NamedPanopticaResultGroup]] = {} - self._n_samples = 0 + def __init__( + self, + panoptica_evaluator: Panoptica_Evaluator, + output_file: Path, + continue_file: bool = True, + ): + """ + Args: + panoptica_evaluator (Panoptica_Evaluator): The Panoptica_Evaluator used for the pipeline. + output_file (Path | None, optional): If given, will stream the sample results into this file. If the file is existent, will append results if not already there. Defaults to None. + """ + self.__panoptica_evaluator = panoptica_evaluator + self.__class_group_names = panoptica_evaluator.segmentation_class_groups_names + self.__output_file = None + self.__output_buffer_file = None + self.__evaluation_metrics = panoptica_evaluator.resulting_metric_keys + + # uses tsv + assert output_file.parent.exists(), f"Directory {str(output_file.parent)} does not exist" + + out_file_path = str(output_file) + if not out_file_path.endswith(".tsv"): + out_file_path += ".tsv" + + out_buffer_file: Path = Path(out_file_path).parent.joinpath("panoptica_aggregator_tmp.tsv") + self.__output_buffer_file = out_buffer_file + + Path(out_file_path).parent.mkdir(parents=True, exist_ok=True) + self.__output_file = out_file_path + + header = ["subject_name"] + [f"{g}-{m}" for g in self.__class_group_names for m in self.__evaluation_metrics] + header_hash = hash("+".join(header)) + + if not output_file.exists(): + # write header + _write_content(output_file, [header]) + else: + header_list = _read_first_row(output_file) + # TODO should also hash panoptica_evaluator just to make sure! and then save into header of file + assert header_hash == hash("+".join(header_list)), "Hash of header not the same! You are using a different setup!" + + if out_buffer_file.exists(): + os.remove(out_buffer_file) + open(out_buffer_file, "a").close() + + if continue_file: + with inevalfilelock: + with filelock: + id_list = _load_first_column_entries(self.__output_file) + _write_content(self.__output_buffer_file, [[s] for s in id_list]) + + atexit.register(self.__exist_handler) + + def __exist_handler(self): + os.remove(self.__output_buffer_file) + + def make_statistic(self) -> Panoptica_Statistic: + with filelock: + obj = Panoptica_Statistic.from_file(self.__output_file) + return obj def evaluate( self, prediction_arr: np.ndarray, reference_arr: np.ndarray, - subject_name: str | None = None, - verbose: bool | None = None, + subject_name: str, ): """Evaluates one case @@ -36,28 +92,83 @@ def evaluate( prediction_arr (np.ndarray): Prediction array reference_arr (np.ndarray): reference array subject_name (str | None, optional): Unique name of the sample. 
If none, will give it a name based on count. Defaults to None. + skip_already_existent (bool): If true, will skip subjects which were already evaluated instead of crashing. Defaults to False. verbose (bool | None, optional): Verbose. Defaults to None. """ - if subject_name is None: - subject_name = f"Sample_{self._n_samples}" + # Read tmp file to see which sample names are blocked + with inevalfilelock: + id_list = _load_first_column_entries(self.__output_buffer_file) - res = self._panoptica_evaluator.evaluate( + if subject_name in id_list: + print( + f"Subject '{subject_name}' evaluated or in process {self.__output_file}, do not add duplicates to your evaluation!", + flush=True, + ) + return + _write_content(self.__output_buffer_file, [[subject_name]]) + + # Run Evaluation (allowed in parallel) + res = self.__panoptica_evaluator.evaluate( prediction_arr, reference_arr, result_all=True, - verbose=verbose, + verbose=False, + log_times=False, ) - for k, v in res.items(): - if k not in self._group2named_results: - self._group2named_results[k] = [] - result_obj, _ = v - self._group2named_results[k].append(NamedPanopticaResultGroup(subject_name, result_obj)) - self._n_samples += 1 + # Add to file + self._save_one_subject(subject_name, res) + + def _save_one_subject(self, subject_name, result_grouped): + with filelock: + # + content = [subject_name] + for groupname in self.__class_group_names: + result: PanopticaResult = result_grouped[groupname][0] + result_dict = result.to_dict() + del result + + for e in self.__evaluation_metrics: + mvalue = result_dict[e] if e in result_dict else "" + content.append(mvalue) + _write_content(self.__output_file, [content]) + print(f"Saved entry {subject_name} into {str(self.__output_file)}") + + +def _read_first_row(file: str): + # NOT THREAD SAFE BY ITSELF! + with open(str(file), "r", encoding="utf8", newline="") as tsvfile: + rd = csv.reader(tsvfile, delimiter="\t", lineterminator="\n") + + rows = [row for row in rd] + if len(rows) == 0: + row = [] + else: + row = rows[0] + + return row + + +def _load_first_column_entries(file: str): + # NOT THREAD SAFE BY ITSELF! + with open(str(file), "r", encoding="utf8", newline="") as tsvfile: + rd = csv.reader(tsvfile, delimiter="\t", lineterminator="\n") + + rows = [row for row in rd] + if len(rows) == 0: + id_list = [] + else: + id_list = list([row[0] for row in rows]) + + n_id = len(id_list) + assert n_id == len(list(set(id_list))), "file has duplicate entries!" + + return id_list - def save_results(): - # save to excel - pass - def load_results(): - pass +def _write_content(file: str, content: list[list[str]]): + # NOT THREAD SAFE BY ITSELF! 
+ with open(str(file), "a", encoding="utf8", newline="") as tsvfile: + writer = csv.writer(tsvfile, delimiter="\t", lineterminator="\n") + for c in content: + writer.writerow(c) diff --git a/panoptica/panoptica_evaluator.py b/panoptica/panoptica_evaluator.py index ccc245e..d2fe8cf 100644 --- a/panoptica/panoptica_evaluator.py +++ b/panoptica/panoptica_evaluator.py @@ -21,6 +21,8 @@ from panoptica.utils.config import SupportsConfig from panoptica.utils.segmentation_class import SegmentationClassGroups, LabelGroup +NO_GROUP_KEY = "ungrouped" + class Panoptica_Evaluator(SupportsConfig): @@ -67,6 +69,7 @@ def __init__( self.__global_metrics = global_metrics self.__decision_metric = decision_metric self.__decision_threshold = decision_threshold + self.__resulting_metric_keys = None self.__segmentation_class_groups = segmentation_class_groups @@ -100,6 +103,7 @@ def evaluate( prediction_arr: np.ndarray, reference_arr: np.ndarray, result_all: bool = True, + log_times: bool | None = None, verbose: bool | None = None, ) -> dict[str, tuple[PanopticaResult, IntermediateStepsData]]: processing_pair = self.__expected_input(prediction_arr, reference_arr) @@ -107,7 +111,7 @@ def evaluate( if self.__segmentation_class_groups is None: return { - "ungrouped": panoptic_evaluate( + NO_GROUP_KEY: panoptic_evaluate( input_pair=processing_pair, edge_case_handler=self.__edge_case_handler, instance_approximator=self.__instance_approximator, @@ -117,7 +121,7 @@ def evaluate( decision_metric=self.__decision_metric, decision_threshold=self.__decision_threshold, result_all=result_all, - log_times=self.__log_times, + log_times=self.__log_times if log_times is None else log_times, verbose=True if verbose is None else verbose, verbose_calc=self.__verbose if verbose is None else verbose, ) @@ -128,9 +132,38 @@ def evaluate( result_grouped = {} for group_name, label_group in self.__segmentation_class_groups.items(): - result_grouped[group_name] = self._evaluate_group(group_name, label_group, processing_pair, result_all, verbose)[1:] + result_grouped[group_name] = self._evaluate_group( + group_name, + label_group, + processing_pair, + result_all, + log_times=log_times, + verbose=verbose, + )[1:] return result_grouped + @property + def segmentation_class_groups_names(self) -> list[str]: + if self.__segmentation_class_groups is None: + return [NO_GROUP_KEY] + return self.__segmentation_class_groups.keys() + + @property + def resulting_metric_keys(self) -> list[str]: + if self.__resulting_metric_keys is None: + dummy_input = MatchedInstancePair(np.ones((1, 1, 1), dtype=np.uint8), np.ones((1, 1, 1), dtype=np.uint8)) + _, res, _ = self._evaluate_group( + group_name="", + label_group=LabelGroup(1, single_instance=False), + processing_pair=dummy_input, + result_all=True, + log_times=False, + verbose=False, + ) + self.__resulting_metric_keys = list(res.to_dict().keys()) + return self.__resulting_metric_keys + # panoptic_evaluate + def _evaluate_group( self, group_name: str, @@ -138,6 +171,7 @@ def _evaluate_group( processing_pair, result_all: bool = True, verbose: bool | None = None, + log_times: bool | None = None, ): assert isinstance(label_group, LabelGroup) @@ -164,7 +198,7 @@ def _evaluate_group( decision_metric=self.__decision_metric, decision_threshold=decision_threshold, result_all=result_all, - log_times=self.__log_times, + log_times=self.__log_times if log_times is None else log_times, verbose=True if verbose is None else verbose, verbose_calc=self.__verbose if verbose is None else verbose, ) diff --git 
a/panoptica/panoptica_statistics.py b/panoptica/panoptica_statistics.py new file mode 100644 index 0000000..b21dd52 --- /dev/null +++ b/panoptica/panoptica_statistics.py @@ -0,0 +1,286 @@ +import csv +import numpy as np +from pathlib import Path +import numpy as np + +try: + import pandas as pd + import matplotlib.pyplot as plt + import plotly.express as px + import plotly.graph_objects as go +except Exception as e: + print(e) + print("OPTIONAL PACKAGE MISSING") + assert False + + +class Panoptica_Statistic: + + def __init__( + self, + subj_names: list[str], + value_dict: dict[str, dict[str, list[float]]], + ) -> None: + self.__subj_names = subj_names + self.__value_dict = value_dict + + self.__groupnames = list(value_dict.keys()) + self.__metricnames = list(value_dict[self.__groupnames[0]].keys()) + + @property + def groupnames(self): + return self.__groupnames + + @property + def metricnames(self): + return self.__metricnames + + @classmethod + def from_file(cls, file: str): + # check integrity of header and so on + with open(str(file), "r", encoding="utf8", newline="") as tsvfile: + rd = csv.reader(tsvfile, delimiter="\t", lineterminator="\n") + + rows = [row for row in rd] + + header = rows[0] + assert header[0] == "subject_name", "First column is not subject_names, something wrong with the file?" + + keys_in_order = list([tuple(c.split("-")) for c in header[1:]]) + metric_names = [] + for k in keys_in_order: + if k[1] not in metric_names: + metric_names.append(k[1]) + group_names = list(set([k[0] for k in keys_in_order])) + + print(f"Found {len(rows)-1} entries") + print(f"Found metrics: {metric_names}") + print(f"Found groups: {group_names}") + + # initialize collection + subj_names = [] + # list of floats in order fo subject_names + # from group to metric to list of values + value_dict: dict[str, dict[str, list[float]]] = {} + + # now load entries + for r in rows[1:]: + sn = r[0] # subject_name + subj_names.append(sn) + + for idx, value in enumerate(r[1:]): + group_name, metric_name = keys_in_order[idx] + if group_name not in value_dict: + value_dict[group_name] = {m: [] for m in metric_names} + + value_dict[group_name][metric_name].append(float(value)) + + return Panoptica_Statistic(subj_names=subj_names, value_dict=value_dict) + + def _assertgroup(self, group): + assert group in self.__groupnames, f"group {group} not existent, only got groups {self.__groupnames}" + + def _assertmetric(self, metric): + assert metric in self.__metricnames, f"metric {metric} not existent, only got metrics {self.__metricnames}" + + def _assertsubject(self, subjectname): + assert subjectname in self.__subj_names, f"subject {subjectname} not in list of subjects, got {self.__subj_names}" + + def get(self, group, metric) -> list[float]: + """Returns the list of values for given group and metric + + Args: + group (_type_): _description_ + metric (_type_): _description_ + + Returns: + list[float]: _description_ + """ + self._assertgroup(group) + self._assertmetric(metric) + + assert ( + group in self.__value_dict and metric in self.__value_dict[group] + ), f"Values not found for group {group} and metric {metric} evem though they should!" 
+ return self.__value_dict[group][metric] + + def get_one_subject(self, subjectname: str): + """Gets the values for ONE subject for each group and metric + + Args: + subjectname (str): _description_ + + Returns: + _type_: _description_ + """ + self._assertsubject(subjectname) + sidx = self.__subj_names.index(subjectname) + return {g: {m: self.get(g, m)[sidx] for m in self.__metricnames} for g in self.__groupnames} + + def get_across_groups(self, metric): + """Given metric, gives list of all values (even across groups!) Treat with care! + + Args: + metric (_type_): _description_ + + Returns: + _type_: _description_ + """ + values = [] + for g in self.__groupnames: + values.append(self.get(g, metric)) + return values + + def get_summary_dict(self): + return {g: {m: self.get_summary(g, m) for m in self.__metricnames} for g in self.__groupnames} + + def get_summary(self, group, metric): + # TODO maybe more here? range, stuff like that + return self.avg_std(group, metric) + + def avg_std(self, group, metric) -> tuple[float, float]: + values = self.get(group, metric) + avg = float(np.average(values)) + std = float(np.std(values)) + return (avg, std) + + def print_summary(self): + summary = self.get_summary_dict() + print() + for g in self.__groupnames: + print(f"Group {g}:") + for m in self.__metricnames: + avg, std = summary[g][m] + print(m, ":", avg, "+-", std) + print() + + def get_summary_figure( + self, + metric: str, + horizontal: bool = True, + # title overwrite? + ): + """Returns a figure object that shows the given metric for each group and its std + + Args: + metric (str): _description_ + horizontal (bool, optional): _description_. Defaults to True. + + Returns: + _type_: _description_ + """ + orientation = "h" if horizontal else "v" + data_plot = {g: np.asarray(self.get(g, metric)) for g in self.__groupnames} + return plot_box( + data=data_plot, + orientation=orientation, + score=metric, + ) + + # groupwise or in total + # Mean over instances + # mean over subjects + # give below/above percentile of metric (the names) + # make auc curve as plot + + +def make_curve_over_setups( + statistics_dict: dict[str | int | float, Panoptica_Statistic], + metric: str, + groups: list[str] | str | None = None, +): + if groups is None: + groups = list(statistics_dict.values())[0].groupnames + # + if isinstance(groups, str): + groups = [groups] + # + for setupname, stat in statistics_dict.items(): + assert metric in stat.metricnames, f"metric {metric} not in statistic obj {setupname}" + + setupnames = list(statistics_dict.keys()) + convert_x_to_digit = True + for s in setupnames: + if not str(s).isdigit(): + convert_x_to_digit = False + break + + # If X (setupnames) are digits only, plot as digits + if convert_x_to_digit: + X = [float(s) for s in setupnames] + else: + X = range(len(setupnames)) + + fig = plt.figure() + if not convert_x_to_digit: + plt.xticks(X, setupnames) + + plt.ylabel("average " + metric) + plt.grid("major") + # Y values are average metric values in that group and metric + for g in groups: + Y = [stat.avg_std(g, metric)[0] for stat in statistics_dict.values()] + plt.plot(X, Y, label=g) + + plt.legend() + return fig + + +def _flatten_extend(matrix): + flat_list = [] + for row in matrix: + flat_list.extend(row) + return flat_list + + +def plot_box( + data: dict[str, np.ndarray], + sort=True, + orientation="h", + # graph_name: str = "Structure", + score: str = "Dice-Score", + width=850, + height=1200, + yaxis_title=None, + xaxis_title=None, +): + graph_name: str = "Structure" + + if 
xaxis_title is None: + xaxis_title = score if orientation == "h" else graph_name + if yaxis_title is None: + yaxis_title = score if orientation != "h" else graph_name + + data = {e.replace("_", " "): v for e, v in data.items()} + df_data = pd.DataFrame( + { + graph_name: _flatten_extend([([e] * len(y0)) for e, y0 in data.items()]), + score: np.concatenate([*data.values()], 0), + } + ) + if sort: + df_by_spec_count = df_data.groupby(graph_name).mean() + df_by_spec_count = dict(df_by_spec_count[score].items()) + df_data["mean"] = df_data[graph_name].apply(lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1)) + df_data = df_data.sort_values(by="mean") + if orientation == "v": + fig = px.strip(df_data, x=graph_name, y=score, stripmode="overlay", orientation=orientation) + fig.update_traces(marker={"size": 5, "color": "#555555"}) + for e in data.keys(): + fig.add_trace(go.Box(y=df_data.query(f'{graph_name} == "{e}"')[score], name=e, orientation=orientation)) + else: + fig = px.strip(df_data, y=graph_name, x=score, stripmode="overlay", orientation=orientation) + fig.update_traces(marker={"size": 5, "color": "#555555"}) + for e in data.keys(): + fig.add_trace(go.Box(x=df_data.query(f'{graph_name} == "{e}"')[score], name=e, orientation=orientation, boxpoints=False)) + fig.update_layout( + autosize=False, + width=width, + height=height, + showlegend=False, + yaxis_title=yaxis_title, + xaxis_title=xaxis_title, + font={"family": "Arial"}, + ) + fig.update_traces(orientation=orientation) + return fig diff --git a/panoptica/utils/__init__.py b/panoptica/utils/__init__.py index 0b72f2a..2c0d43e 100644 --- a/panoptica/utils/__init__.py +++ b/panoptica/utils/__init__.py @@ -19,3 +19,4 @@ SegmentationClassGroups, LabelGroup, ) +from panoptica.utils.parallel_processing import NonDaemonicPool diff --git a/panoptica/utils/edge_case_handling.py b/panoptica/utils/edge_case_handling.py index 0fe85de..a865683 100644 --- a/panoptica/utils/edge_case_handling.py +++ b/panoptica/utils/edge_case_handling.py @@ -5,11 +5,27 @@ class EdgeCaseResult(_Enum_Compare): - INF = np.inf - NAN = np.nan - ZERO = 0.0 - ONE = 1.0 - NONE = None + INF = auto() # np.inf + NAN = auto() # np.nan + ZERO = auto() # 0.0 + ONE = auto() # 1.0 + NONE = auto() # None + + @property + def value(self): + return self() + + def __call__(self): + transfer_dict = { + EdgeCaseResult.INF.name: np.inf, + EdgeCaseResult.NAN.name: np.nan, + EdgeCaseResult.ZERO.name: 0.0, + EdgeCaseResult.ONE.name: 1.0, + EdgeCaseResult.NONE.name: None, + } + if self.name in transfer_dict: + return transfer_dict[self.name] + raise KeyError(f"No defined value for EdgeCaseResult {str(self)}") class EdgeCaseZeroTP(_Enum_Compare): @@ -41,26 +57,12 @@ def __init__( self._default_result = default_result self._edgecase_dict: dict[EdgeCaseZeroTP, EdgeCaseResult] = {} - self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = ( - empty_prediction_result - if empty_prediction_result is not None - else default_result - ) - self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = ( - empty_reference_result - if empty_reference_result is not None - else default_result - ) - self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = ( - no_instances_result if no_instances_result is not None else default_result - ) - self._edgecase_dict[EdgeCaseZeroTP.NORMAL] = ( - normal if normal is not None else default_result - ) + self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = empty_prediction_result if empty_prediction_result is not None else default_result + self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = 
empty_reference_result if empty_reference_result is not None else default_result + self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = no_instances_result if no_instances_result is not None else default_result + self._edgecase_dict[EdgeCaseZeroTP.NORMAL] = normal if normal is not None else default_result - def __call__( - self, tp: int, num_pred_instances, num_ref_instances - ) -> tuple[bool, float | None]: + def __call__(self, tp: int, num_pred_instances, num_ref_instances) -> tuple[bool, float | None]: if tp != 0: return False, EdgeCaseResult.NONE.value # @@ -131,9 +133,7 @@ def __init__( }, empty_list_std: EdgeCaseResult = EdgeCaseResult.NAN, ) -> None: - self.__listmetric_zeroTP_handling: dict[ - Metric, MetricZeroTPEdgeCaseHandling - ] = listmetric_zeroTP_handling + self.__listmetric_zeroTP_handling: dict[Metric, MetricZeroTPEdgeCaseHandling] = listmetric_zeroTP_handling self.__empty_list_std: EdgeCaseResult = empty_list_std def handle_zero_tp( @@ -160,9 +160,7 @@ def handle_zero_tp( if tp != 0: return False, EdgeCaseResult.NONE.value if metric not in self.__listmetric_zeroTP_handling: - raise NotImplementedError( - f"Metric {metric} encountered zero TP, but no edge handling available" - ) + raise NotImplementedError(f"Metric {metric} encountered zero TP, but no edge handling available") return self.__listmetric_zeroTP_handling[metric]( tp=tp, diff --git a/panoptica/utils/parallel_processing.py b/panoptica/utils/parallel_processing.py new file mode 100644 index 0000000..50e3691 --- /dev/null +++ b/panoptica/utils/parallel_processing.py @@ -0,0 +1,36 @@ +import multiprocessing.pool +from multiprocessing import Pool, Process +from typing import Callable + + +class NoDaemonProcess(Process): + def __init__( + self, + group: None = None, + target: Callable[..., object] | None = None, + name: str | None = None, + args=None, + kwargs=None, + *, + daemon: bool | None = None, + ) -> None: + if kwargs is None: + kwargs = {} + if args is None: + args = [] + super().__init__(None, target, name, args, kwargs, daemon=daemon) + + # make 'daemon' attribute always return False + def _get_daemon(self): + return False + + def _set_daemon(self, value): + pass + + daemon = property(_get_daemon, _set_daemon) + + +# We sub-class multiprocessing.pool.Pool instead of multiprocessing.Pool +# because the latter is only a wrapper function, not a proper class. +class NonDaemonicPool(multiprocessing.pool.Pool): + Process = NoDaemonProcess diff --git a/panoptica/utils/processing_pair.py b/panoptica/utils/processing_pair.py index 6ef4640..6020365 100644 --- a/panoptica/utils/processing_pair.py +++ b/panoptica/utils/processing_pair.py @@ -25,9 +25,7 @@ class _ProcessingPair(ABC): _pred_labels: tuple[int, ...] n_dim: int - def __init__( - self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None - ) -> None: + def __init__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None) -> None: """Initializes a general Processing Pair Args: @@ -40,12 +38,8 @@ def __init__( self._reference_arr = reference_arr self.dtype = dtype self.n_dim = reference_arr.ndim - self._ref_labels: tuple[int, ...] = tuple( - _unique_without_zeros(reference_arr) - ) # type:ignore - self._pred_labels: tuple[int, ...] = tuple( - _unique_without_zeros(prediction_arr) - ) # type:ignore + self._ref_labels: tuple[int, ...] = tuple(_unique_without_zeros(reference_arr)) # type:ignore + self._pred_labels: tuple[int, ...] = tuple(_unique_without_zeros(prediction_arr)) # type:ignore self.crop: tuple[slice, ...] 
= None self.is_cropped: bool = False self.uncropped_shape: tuple[int, ...] = reference_arr.shape @@ -62,41 +56,25 @@ def crop_data(self, verbose: bool = False): self._prediction_arr = self._prediction_arr[self.crop] self._reference_arr = self._reference_arr[self.crop] - ( - print( - f"-- Cropped from {self.uncropped_shape} to {self._prediction_arr.shape}" - ) - if verbose - else None - ) + (print(f"-- Cropped from {self.uncropped_shape} to {self._prediction_arr.shape}") if verbose else None) self.is_cropped = True def uncrop_data(self, verbose: bool = False): if self.is_cropped == False: return - assert ( - self.uncropped_shape is not None - ), "Calling uncrop_data() without having cropped first" + assert self.uncropped_shape is not None, "Calling uncrop_data() without having cropped first" prediction_arr = np.zeros(self.uncropped_shape) prediction_arr[self.crop] = self._prediction_arr self._prediction_arr = prediction_arr reference_arr = np.zeros(self.uncropped_shape) reference_arr[self.crop] = self._reference_arr - ( - print( - f"-- Uncropped from {self._reference_arr.shape} to {self.uncropped_shape}" - ) - if verbose - else None - ) + (print(f"-- Uncropped from {self._reference_arr.shape} to {self.uncropped_shape}") if verbose else None) self._reference_arr = reference_arr self.is_cropped = False def set_dtype(self, type): - assert np.issubdtype( - type, int_type - ), "set_dtype: tried to set dtype to something other than integers" + assert np.issubdtype(type, int_type), "set_dtype: tried to set dtype to something other than integers" self._prediction_arr = self._prediction_arr.astype(type) self._reference_arr = self._reference_arr.astype(type) @@ -175,9 +153,7 @@ def copy(self): ) # type:ignore -def _check_array_integrity( - prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None = None -): +def _check_array_integrity(prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None = None): """ Check the integrity of two numpy arrays. 
@@ -199,12 +175,8 @@ def _check_array_integrity( assert isinstance(prediction_arr, np.ndarray) and isinstance( reference_arr, np.ndarray ), "prediction and/or reference are not numpy arrays" - assert ( - prediction_arr.shape == reference_arr.shape - ), f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}" - assert ( - prediction_arr.dtype == reference_arr.dtype - ), f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}" + assert prediction_arr.shape == reference_arr.shape, f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}" + assert prediction_arr.dtype == reference_arr.dtype, f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}" if dtype is not None: assert ( np.issubdtype(prediction_arr.dtype, dtype) @@ -287,15 +259,11 @@ def __init__( self.matched_instances = matched_instances if missed_reference_labels is None: - missed_reference_labels = list( - [i for i in self._ref_labels if i not in self._pred_labels] - ) + missed_reference_labels = list([i for i in self._ref_labels if i not in self._pred_labels]) self.missed_reference_labels = missed_reference_labels if missed_prediction_labels is None: - missed_prediction_labels = list( - [i for i in self._pred_labels if i not in self._ref_labels] - ) + missed_prediction_labels = list([i for i in self._pred_labels if i not in self._ref_labels]) self.missed_prediction_labels = missed_prediction_labels @property @@ -332,9 +300,7 @@ class InputType(_Enum_Compare): UNMATCHED_INSTANCE = UnmatchedInstancePair MATCHED_INSTANCE = MatchedInstancePair - def __call__( - self, prediction_arr: np.ndarray, reference_arr: np.ndarray - ) -> _ProcessingPair: + def __call__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray) -> _ProcessingPair: return self.value(prediction_arr, reference_arr) @@ -343,9 +309,7 @@ def __init__(self, original_input: _ProcessingPair | None): self._original_input = original_input self._intermediatesteps: dict[str, _ProcessingPair] = {} - def add_intermediate_arr_data( - self, processing_pair: _ProcessingPair, inputtype: InputType - ): + def add_intermediate_arr_data(self, processing_pair: _ProcessingPair, inputtype: InputType): type_name = inputtype.name self.add_intermediate_data(type_name, processing_pair) @@ -355,36 +319,26 @@ def add_intermediate_data(self, key, value): @property def original_prediction_arr(self): - assert ( - self._original_input is not None - ), "Original prediction_arr is None, there are no intermediate steps" + assert self._original_input is not None, "Original prediction_arr is None, there are no intermediate steps" return self._original_input.prediction_arr @property def original_reference_arr(self): - assert ( - self._original_input is not None - ), "Original reference_arr is None, there are no intermediate steps" + assert self._original_input is not None, "Original reference_arr is None, there are no intermediate steps" return self._original_input.reference_arr def prediction_arr(self, inputtype: InputType): type_name = inputtype.name procpair = self[type_name] - assert isinstance( - procpair, _ProcessingPair - ), f"step {type_name} is not a processing pair, error" + assert isinstance(procpair, _ProcessingPair), f"step {type_name} is not a processing pair, error" return procpair.prediction_arr def reference_arr(self, inputtype: InputType): type_name = inputtype.name procpair = self[type_name] - assert isinstance( - procpair, _ProcessingPair - ), f"step {type_name} is not a processing pair, error" + assert isinstance(procpair, 
_ProcessingPair), f"step {type_name} is not a processing pair, error" return procpair.reference_arr def __getitem__(self, key): - assert ( - key in self._intermediatesteps - ), f"key {key} not in intermediate steps, maybe the step was skipped?" + assert key in self._intermediatesteps, f"key {key} not in intermediate steps, maybe the step was skipped?" return self._intermediatesteps[key] diff --git a/unit_tests/test_metrics.py b/unit_tests/test_metrics.py index 2800187..9e945d2 100644 --- a/unit_tests/test_metrics.py +++ b/unit_tests/test_metrics.py @@ -149,6 +149,58 @@ def test_dsc_case_simple_underpredicted(self): self.assertEqual(dsc, 0.5714285714285714) +class Test_ASSD(unittest.TestCase): + def setUp(self) -> None: + os.environ["PANOPTICA_CITATION_REMINDER"] = "False" + return super().setUp() + + def test_st_case_simple_identical(self): + pred_arr, ref_arr = case_simple_identical() + st = Metric.ASSD(reference_arr=ref_arr, prediction_arr=pred_arr) + self.assertEqual(st, 0.0) + + def test_st_case_simple_nooverlap(self): + pred_arr, ref_arr = case_simple_nooverlap() + st = Metric.ASSD(reference_arr=ref_arr, prediction_arr=pred_arr) + self.assertEqual(st, 1.05) + + def test_st_case_simple_overpredicted(self): + pred_arr, ref_arr = case_simple_overpredicted() + st = Metric.ASSD(reference_arr=ref_arr, prediction_arr=pred_arr) + self.assertEqual(st, 0.625) + + def test_st_case_simple_underpredicted(self): + pred_arr, ref_arr = case_simple_underpredicted() + st = Metric.ASSD(reference_arr=ref_arr, prediction_arr=pred_arr) + self.assertEqual(st, 0.625) + + +class Test_clDSC(unittest.TestCase): + def setUp(self) -> None: + os.environ["PANOPTICA_CITATION_REMINDER"] = "False" + return super().setUp() + + def test_st_case_simple_identical(self): + pred_arr, ref_arr = case_simple_identical() + st = Metric.clDSC(reference_arr=ref_arr, prediction_arr=pred_arr) + self.assertEqual(st, 1.0) + + def test_st_case_simple_nooverlap(self): + pred_arr, ref_arr = case_simple_nooverlap() + st = Metric.clDSC(reference_arr=ref_arr, prediction_arr=pred_arr) + self.assertEqual(np.isnan(st), True) + + def test_st_case_simple_overpredicted(self): + pred_arr, ref_arr = case_simple_overpredicted() + st = Metric.clDSC(reference_arr=ref_arr, prediction_arr=pred_arr) + self.assertEqual(st, 1.0) + + def test_st_case_simple_underpredicted(self): + pred_arr, ref_arr = case_simple_underpredicted() + st = Metric.clDSC(reference_arr=ref_arr, prediction_arr=pred_arr) + self.assertEqual(st, 1.0) + + # class Test_ST(unittest.TestCase): # def setUp(self) -> None: # os.environ["PANOPTICA_CITATION_REMINDER"] = "False" diff --git a/unit_tests/test_panoptic_evaluator.py b/unit_tests/test_panoptic_evaluator.py index c843875..1038723 100644 --- a/unit_tests/test_panoptic_evaluator.py +++ b/unit_tests/test_panoptic_evaluator.py @@ -199,7 +199,31 @@ def test_pred_empty(self): self.assertEqual(result.fn, 1) self.assertEqual(result.sq, 0.0) self.assertEqual(result.pq, 0.0) + self.assertEqual(result.global_bin_dsc, 0.0) + self.assertEqual(result.sq_assd, np.inf) + + def test_no_TP_but_overlap(self): + a = np.zeros([50, 50], np.uint16) + b = a.copy() + a[20:40, 10:20] = 1 + b[20:25, 10:15] = 2 + + evaluator = Panoptica_Evaluator( + expected_input=InputType.SEMANTIC, + instance_approximator=ConnectedComponentsInstanceApproximator(), + instance_matcher=NaiveThresholdMatching(), + ) + + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] + print(result) + self.assertEqual(result.tp, 0) + self.assertEqual(result.fp, 1) + 
self.assertEqual(result.fn, 1) + self.assertEqual(result.sq, 0.0) + self.assertEqual(result.pq, 0.0) + self.assertAlmostEqual(result.global_bin_dsc, 0.22222222222) self.assertEqual(result.sq_assd, np.inf) + self.assertTrue(np.isnan(result.sq_rvd)) def test_ref_empty(self): a = np.zeros([50, 50], np.uint16) From 9f1288a53664024eea137887af7a96a09f029e6b Mon Sep 17 00:00:00 2001 From: "brainless-bot[bot]" <153751247+brainless-bot[bot]@users.noreply.github.com> Date: Wed, 28 Aug 2024 10:20:29 +0000 Subject: [PATCH 08/11] Autoformat with black --- examples/example_spine_semantic.py | 12 +++- examples/example_spine_statistics.py | 14 ++++- panoptica/_functionals.py | 15 ++++- panoptica/instance_approximator.py | 38 +++++++++--- panoptica/instance_evaluator.py | 18 ++++-- panoptica/panoptica_aggregator.py | 18 ++++-- panoptica/panoptica_evaluator.py | 56 +++++++++++++----- panoptica/panoptica_result.py | 38 +++++++++--- panoptica/panoptica_statistics.py | 59 +++++++++++++++---- panoptica/utils/edge_case_handling.py | 32 +++++++--- panoptica/utils/processing_pair.py | 84 +++++++++++++++++++++------ 11 files changed, 298 insertions(+), 86 deletions(-) diff --git a/examples/example_spine_semantic.py b/examples/example_spine_semantic.py index 888f618..e2f45b1 100644 --- a/examples/example_spine_semantic.py +++ b/examples/example_spine_semantic.py @@ -26,7 +26,9 @@ def main(): with cProfile.Profile() as pr: - result, intermediate_steps_data = evaluator.evaluate(prediction_mask, reference_mask)["ungrouped"] + result, intermediate_steps_data = evaluator.evaluate( + prediction_mask, reference_mask + )["ungrouped"] # To print the results, just call print print(result) @@ -35,8 +37,12 @@ def main(): intermediate_steps_data.original_prediction_arr # Input prediction array, untouched intermediate_steps_data.original_reference_arr # Input reference array, untouched - intermediate_steps_data.prediction_arr(InputType.MATCHED_INSTANCE) # Prediction array after instances have been matched - intermediate_steps_data.reference_arr(InputType.MATCHED_INSTANCE) # Reference array after instances have been matched + intermediate_steps_data.prediction_arr( + InputType.MATCHED_INSTANCE + ) # Prediction array after instances have been matched + intermediate_steps_data.reference_arr( + InputType.MATCHED_INSTANCE + ) # Reference array after instances have been matched pr.dump_stats(directory + "/semantic_example.log") return result, intermediate_steps_data diff --git a/examples/example_spine_statistics.py b/examples/example_spine_statistics.py index c7ca6aa..8073acd 100644 --- a/examples/example_spine_statistics.py +++ b/examples/example_spine_statistics.py @@ -41,11 +41,19 @@ for i in range(4): results = evaluator.evaluate(prediction_mask, reference_mask, f"sample{i}") elif parallel_opt == "joblib": - Parallel(n_jobs=4, backend="threading")(delayed(evaluator.evaluate)(prediction_mask, reference_mask) for i in range(4)) + Parallel(n_jobs=4, backend="threading")( + delayed(evaluator.evaluate)(prediction_mask, reference_mask) + for i in range(4) + ) elif parallel_opt == "future": with ProcessPoolExecutor() as executor: - futures = {executor.submit(evaluator.evaluate, prediction_mask, reference_mask) for i in range(4)} - for future in tqdm(as_completed(futures), total=len(futures), desc="Panoptica Evaluation"): + futures = { + executor.submit(evaluator.evaluate, prediction_mask, reference_mask) + for i in range(4) + } + for future in tqdm( + as_completed(futures), total=len(futures), desc="Panoptica Evaluation" + ): result = 
future.result() if result is not None: print("Done") diff --git a/panoptica/_functionals.py b/panoptica/_functionals.py index e3c81bc..7ea63a5 100644 --- a/panoptica/_functionals.py +++ b/panoptica/_functionals.py @@ -33,7 +33,11 @@ def _calc_overlapping_labels( # instance_pairs = [(reference_arr, prediction_arr, i, j) for i, j in overlapping_indices] # (ref, pred) - return [(int(i % (max_ref)), int(i // (max_ref))) for i in np.unique(overlap_arr) if i > max_ref] + return [ + (int(i % (max_ref)), int(i // (max_ref))) + for i in np.unique(overlap_arr) + if i > max_ref + ] def _calc_matching_metric_of_overlapping_labels( @@ -63,8 +67,13 @@ def _calc_matching_metric_of_overlapping_labels( with Pool() as pool: mm_values = pool.starmap(matching_metric.value, instance_pairs) - mm_pairs = [(i, (instance_pairs[idx][2], instance_pairs[idx][3])) for idx, i in enumerate(mm_values)] - mm_pairs = sorted(mm_pairs, key=lambda x: x[0], reverse=not matching_metric.decreasing) + mm_pairs = [ + (i, (instance_pairs[idx][2], instance_pairs[idx][3])) + for idx, i in enumerate(mm_values) + ] + mm_pairs = sorted( + mm_pairs, key=lambda x: x[0], reverse=not matching_metric.decreasing + ) return mm_pairs diff --git a/panoptica/instance_approximator.py b/panoptica/instance_approximator.py index 78f0f7e..aff01cb 100644 --- a/panoptica/instance_approximator.py +++ b/panoptica/instance_approximator.py @@ -42,7 +42,9 @@ class InstanceApproximator(SupportsConfig, metaclass=ABCMeta): """ @abstractmethod - def _approximate_instances(self, semantic_pair: SemanticPair, **kwargs) -> UnmatchedInstancePair | MatchedInstancePair: + def _approximate_instances( + self, semantic_pair: SemanticPair, **kwargs + ) -> UnmatchedInstancePair | MatchedInstancePair: """ Abstract method to be implemented by subclasses for instance approximation. @@ -56,7 +58,9 @@ def _approximate_instances(self, semantic_pair: SemanticPair, **kwargs) -> Unmat pass def _yaml_repr(cls, node) -> dict: - raise NotImplementedError(f"Tried to get yaml representation of abstract class {cls.__name__}") + raise NotImplementedError( + f"Tried to get yaml representation of abstract class {cls.__name__}" + ) return {} def approximate_instances( @@ -77,11 +81,19 @@ def approximate_instances( """ # Check validity pred_labels, ref_labels = semantic_pair._pred_labels, semantic_pair._ref_labels - pred_label_range = (np.min(pred_labels), np.max(pred_labels)) if len(pred_labels) > 0 else (0, 0) - ref_label_range = (np.min(ref_labels), np.max(ref_labels)) if len(ref_labels) > 0 else (0, 0) + pred_label_range = ( + (np.min(pred_labels), np.max(pred_labels)) + if len(pred_labels) > 0 + else (0, 0) + ) + ref_label_range = ( + (np.min(ref_labels), np.max(ref_labels)) if len(ref_labels) > 0 else (0, 0) + ) # min_value = min(np.min(pred_label_range[0]), np.min(ref_label_range[0])) - assert min_value >= 0, "There are negative values in the semantic maps. This is not allowed!" + assert ( + min_value >= 0 + ), "There are negative values in the semantic maps. This is not allowed!" 
# Set dtype to smalles fitting uint max_value = max(np.max(pred_label_range[1]), np.max(ref_label_range[1])) dtype = _get_smallest_fitting_uint(max_value) @@ -121,7 +133,9 @@ def __init__(self, cca_backend: CCABackend | None = None) -> None: """ self.cca_backend = cca_backend - def _approximate_instances(self, semantic_pair: SemanticPair, **kwargs) -> UnmatchedInstancePair: + def _approximate_instances( + self, semantic_pair: SemanticPair, **kwargs + ) -> UnmatchedInstancePair: """ Approximate instances using the connected components algorithm. @@ -134,7 +148,9 @@ def _approximate_instances(self, semantic_pair: SemanticPair, **kwargs) -> Unmat """ cca_backend = self.cca_backend if cca_backend is None: - cca_backend = CCABackend.cc3d if semantic_pair.n_dim >= 3 else CCABackend.scipy + cca_backend = ( + CCABackend.cc3d if semantic_pair.n_dim >= 3 else CCABackend.scipy + ) assert cca_backend is not None empty_prediction = len(semantic_pair._pred_labels) == 0 @@ -145,10 +161,14 @@ def _approximate_instances(self, semantic_pair: SemanticPair, **kwargs) -> Unmat else (semantic_pair._prediction_arr, 0) ) reference_arr, n_reference_instance = ( - _connected_components(semantic_pair._reference_arr, cca_backend) if not empty_reference else (semantic_pair._reference_arr, 0) + _connected_components(semantic_pair._reference_arr, cca_backend) + if not empty_reference + else (semantic_pair._reference_arr, 0) ) - dtype = _get_smallest_fitting_uint(max(prediction_arr.max(), reference_arr.max())) + dtype = _get_smallest_fitting_uint( + max(prediction_arr.max(), reference_arr.max()) + ) return UnmatchedInstancePair( prediction_arr=prediction_arr.astype(dtype), diff --git a/panoptica/instance_evaluator.py b/panoptica/instance_evaluator.py index cdb3e77..8fc6999 100644 --- a/panoptica/instance_evaluator.py +++ b/panoptica/instance_evaluator.py @@ -24,7 +24,9 @@ def evaluate_matched_instance( """ if decision_metric is not None: - assert decision_metric.name in [v.name for v in eval_metrics], "decision metric not contained in eval_metrics" + assert decision_metric.name in [ + v.name for v in eval_metrics + ], "decision metric not contained in eval_metrics" assert decision_threshold is not None, "decision metric set but no threshold" # Initialize variables for True Positives (tp) tp = len(matched_instance_pair.matched_instances) @@ -36,14 +38,22 @@ def evaluate_matched_instance( ) ref_matched_labels = matched_instance_pair.matched_instances - instance_pairs = [(reference_arr, prediction_arr, ref_idx, eval_metrics) for ref_idx in ref_matched_labels] + instance_pairs = [ + (reference_arr, prediction_arr, ref_idx, eval_metrics) + for ref_idx in ref_matched_labels + ] with Pool() as pool: - metric_dicts: list[dict[Metric, float]] = pool.starmap(_evaluate_instance, instance_pairs) + metric_dicts: list[dict[Metric, float]] = pool.starmap( + _evaluate_instance, instance_pairs + ) # TODO if instance matcher already gives matching metric, adapt here! 
for metric_dict in metric_dicts: if decision_metric is None or ( - decision_threshold is not None and decision_metric.score_beats_threshold(metric_dict[decision_metric], decision_threshold) + decision_threshold is not None + and decision_metric.score_beats_threshold( + metric_dict[decision_metric], decision_threshold + ) ): for k, v in metric_dict.items(): score_dict[k].append(v) diff --git a/panoptica/panoptica_aggregator.py b/panoptica/panoptica_aggregator.py index 269b770..6460d6d 100644 --- a/panoptica/panoptica_aggregator.py +++ b/panoptica/panoptica_aggregator.py @@ -37,19 +37,27 @@ def __init__( self.__evaluation_metrics = panoptica_evaluator.resulting_metric_keys # uses tsv - assert output_file.parent.exists(), f"Directory {str(output_file.parent)} does not exist" + assert ( + output_file.parent.exists() + ), f"Directory {str(output_file.parent)} does not exist" out_file_path = str(output_file) if not out_file_path.endswith(".tsv"): out_file_path += ".tsv" - out_buffer_file: Path = Path(out_file_path).parent.joinpath("panoptica_aggregator_tmp.tsv") + out_buffer_file: Path = Path(out_file_path).parent.joinpath( + "panoptica_aggregator_tmp.tsv" + ) self.__output_buffer_file = out_buffer_file Path(out_file_path).parent.mkdir(parents=True, exist_ok=True) self.__output_file = out_file_path - header = ["subject_name"] + [f"{g}-{m}" for g in self.__class_group_names for m in self.__evaluation_metrics] + header = ["subject_name"] + [ + f"{g}-{m}" + for g in self.__class_group_names + for m in self.__evaluation_metrics + ] header_hash = hash("+".join(header)) if not output_file.exists(): @@ -58,7 +66,9 @@ def __init__( else: header_list = _read_first_row(output_file) # TODO should also hash panoptica_evaluator just to make sure! and then save into header of file - assert header_hash == hash("+".join(header_list)), "Hash of header not the same! You are using a different setup!" + assert header_hash == hash( + "+".join(header_list) + ), "Hash of header not the same! You are using a different setup!" 
if out_buffer_file.exists(): os.remove(out_buffer_file) diff --git a/panoptica/panoptica_evaluator.py b/panoptica/panoptica_evaluator.py index d2fe8cf..cb6c46a 100644 --- a/panoptica/panoptica_evaluator.py +++ b/panoptica/panoptica_evaluator.py @@ -73,9 +73,13 @@ def __init__( self.__segmentation_class_groups = segmentation_class_groups - self.__edge_case_handler = edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() + self.__edge_case_handler = ( + edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() + ) if self.__decision_metric is not None: - assert self.__decision_threshold is not None, "decision metric set but no decision threshold for it" + assert ( + self.__decision_threshold is not None + ), "decision metric set but no decision threshold for it" # self.__log_times = log_times self.__verbose = verbose @@ -107,7 +111,9 @@ def evaluate( verbose: bool | None = None, ) -> dict[str, tuple[PanopticaResult, IntermediateStepsData]]: processing_pair = self.__expected_input(prediction_arr, reference_arr) - assert isinstance(processing_pair, self.__expected_input.value), f"input not of expected type {self.__expected_input}" + assert isinstance( + processing_pair, self.__expected_input.value + ), f"input not of expected type {self.__expected_input}" if self.__segmentation_class_groups is None: return { @@ -127,8 +133,12 @@ def evaluate( ) } - self.__segmentation_class_groups.has_defined_labels_for(processing_pair.prediction_arr, raise_error=True) - self.__segmentation_class_groups.has_defined_labels_for(processing_pair.reference_arr, raise_error=True) + self.__segmentation_class_groups.has_defined_labels_for( + processing_pair.prediction_arr, raise_error=True + ) + self.__segmentation_class_groups.has_defined_labels_for( + processing_pair.reference_arr, raise_error=True + ) result_grouped = {} for group_name, label_group in self.__segmentation_class_groups.items(): @@ -151,7 +161,9 @@ def segmentation_class_groups_names(self) -> list[str]: @property def resulting_metric_keys(self) -> list[str]: if self.__resulting_metric_keys is None: - dummy_input = MatchedInstancePair(np.ones((1, 1, 1), dtype=np.uint8), np.ones((1, 1, 1), dtype=np.uint8)) + dummy_input = MatchedInstancePair( + np.ones((1, 1, 1), dtype=np.uint8), np.ones((1, 1, 1), dtype=np.uint8) + ) _, res, _ = self._evaluate_group( group_name="", label_group=LabelGroup(1, single_instance=False), @@ -181,7 +193,9 @@ def _evaluate_group( single_instance_mode = label_group.single_instance processing_pair_grouped = processing_pair.__class__(prediction_arr=prediction_arr_grouped, reference_arr=reference_arr_grouped) # type: ignore decision_threshold = self.__decision_threshold - if single_instance_mode and not isinstance(processing_pair, MatchedInstancePair): + if single_instance_mode and not isinstance( + processing_pair, MatchedInstancePair + ): processing_pair_grouped = MatchedInstancePair( prediction_arr=processing_pair_grouped.prediction_arr, reference_arr=processing_pair_grouped.reference_arr, @@ -258,12 +272,22 @@ def panoptic_evaluate( # Crops away unecessary space of zeroes input_pair.crop_data() - processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | EvaluateInstancePair | PanopticaResult = input_pair.copy() + processing_pair: ( + SemanticPair + | UnmatchedInstancePair + | MatchedInstancePair + | EvaluateInstancePair + | PanopticaResult + ) = input_pair.copy() # First Phase: Instance Approximation if isinstance(processing_pair, SemanticPair): - 
intermediate_steps_data.add_intermediate_arr_data(processing_pair.copy(), InputType.SEMANTIC) - assert instance_approximator is not None, "Got SemanticPair but not InstanceApproximator" + intermediate_steps_data.add_intermediate_arr_data( + processing_pair.copy(), InputType.SEMANTIC + ) + assert ( + instance_approximator is not None + ), "Got SemanticPair but not InstanceApproximator" if verbose: print("-- Got SemanticPair, will approximate instances") start = perf_counter() @@ -273,7 +297,9 @@ def panoptic_evaluate( # Second Phase: Instance Matching if isinstance(processing_pair, UnmatchedInstancePair): - intermediate_steps_data.add_intermediate_arr_data(processing_pair.copy(), InputType.UNMATCHED_INSTANCE) + intermediate_steps_data.add_intermediate_arr_data( + processing_pair.copy(), InputType.UNMATCHED_INSTANCE + ) processing_pair = _handle_zero_instances_cases( processing_pair, eval_metrics=instance_metrics, @@ -284,7 +310,9 @@ def panoptic_evaluate( if isinstance(processing_pair, UnmatchedInstancePair): if verbose: print("-- Got UnmatchedInstancePair, will match instances") - assert instance_matcher is not None, "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" + assert ( + instance_matcher is not None + ), "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" start = perf_counter() processing_pair = instance_matcher.match_instances( processing_pair, @@ -294,7 +322,9 @@ def panoptic_evaluate( # Third Phase: Instance Evaluation if isinstance(processing_pair, MatchedInstancePair): - intermediate_steps_data.add_intermediate_arr_data(processing_pair.copy(), InputType.MATCHED_INSTANCE) + intermediate_steps_data.add_intermediate_arr_data( + processing_pair.copy(), InputType.MATCHED_INSTANCE + ) processing_pair = _handle_zero_instances_cases( processing_pair, eval_metrics=instance_metrics, diff --git a/panoptica/panoptica_result.py b/panoptica/panoptica_result.py index 1b7b60d..9386a5c 100644 --- a/panoptica/panoptica_result.py +++ b/panoptica/panoptica_result.py @@ -252,19 +252,25 @@ def __init__( num_pred_instances=self.num_pred_instances, num_ref_instances=self.num_ref_instances, ) - self._list_metrics[m] = Evaluation_List_Metric(m, empty_list_std, list_metrics[m], is_edge_case, edge_case_result) + self._list_metrics[m] = Evaluation_List_Metric( + m, empty_list_std, list_metrics[m], is_edge_case, edge_case_result + ) # even if not available, set the global vars default_value = None was_calculated = False if m in self._global_metrics and arrays_present: - default_value = self._calc_global_bin_metric(m, pred_binary, ref_binary, do_binarize=False) + default_value = self._calc_global_bin_metric( + m, pred_binary, ref_binary, do_binarize=False + ) was_calculated = True self._add_metric( f"global_bin_{m.name.lower()}", MetricType.GLOBAL, - lambda x: MetricCouldNotBeComputedException(f"Global Metric {m} not set"), + lambda x: MetricCouldNotBeComputedException( + f"Global Metric {m} not set" + ), long_name="Global Binary " + m.value.long_name, default_value=default_value, was_calculated=was_calculated, @@ -292,7 +298,9 @@ def _calc_global_bin_metric( prediction_empty = pred_binary.sum() == 0 reference_empty = ref_binary.sum() == 0 if prediction_empty or reference_empty: - is_edgecase, result = self._edge_case_handler.handle_zero_tp(metric, 0, int(prediction_empty), int(reference_empty)) + is_edgecase, result = self._edge_case_handler.handle_zero_tp( + metric, 0, int(prediction_empty), int(reference_empty) + ) if is_edgecase: return result @@ -378,7 +386,11 @@ def __str__(self) 
-> str: return text def to_dict(self) -> dict: - return {k: getattr(self, v.id) for k, v in self._evaluation_metrics.items() if (v._error == False and v._was_calculated)} + return { + k: getattr(self, v.id) + for k, v in self._evaluation_metrics.items() + if (v._error == False and v._was_calculated) + } @property def evaluation_metrics(self): @@ -388,7 +400,9 @@ def get_list_metric(self, metric: Metric, mode: MetricMode): if metric in self._list_metrics: return self._list_metrics[metric][mode] else: - raise MetricCouldNotBeComputedException(f"{metric} could not be found, have you set it in eval_metrics during evaluation?") + raise MetricCouldNotBeComputedException( + f"{metric} could not be found, have you set it in eval_metrics during evaluation?" + ) def _calc_metric(self, metric_name: str, supress_error: bool = False): if metric_name in self._evaluation_metrics: @@ -404,7 +418,9 @@ def _calc_metric(self, metric_name: str, supress_error: bool = False): self._evaluation_metrics[metric_name]._was_calculated = True return value else: - raise MetricCouldNotBeComputedException(f"could not find metric with name {metric_name}") + raise MetricCouldNotBeComputedException( + f"could not find metric with name {metric_name}" + ) def __getattribute__(self, __name: str) -> Any: attr = None @@ -419,7 +435,9 @@ def __getattribute__(self, __name: str) -> Any: raise e if attr is None: if self._evaluation_metrics[__name]._error: - raise MetricCouldNotBeComputedException(f"Requested metric {__name} that could not be computed") + raise MetricCouldNotBeComputedException( + f"Requested metric {__name} that could not be computed" + ) elif not self._evaluation_metrics[__name]._was_calculated: value = self._calc_metric(__name) setattr(self, __name, value) @@ -545,7 +563,9 @@ def function_template(res: PanopticaResult): if metric not in res._global_metrics: raise MetricCouldNotBeComputedException(f"Global Metric {metric} not set") if res.tp == 0: - is_edgecase, result = res._edge_case_handler.handle_zero_tp(metric, res.tp, res.num_pred_instances, res.num_ref_instances) + is_edgecase, result = res._edge_case_handler.handle_zero_tp( + metric, res.tp, res.num_pred_instances, res.num_ref_instances + ) if is_edgecase: return result pred_binary = res._prediction_arr.copy() diff --git a/panoptica/panoptica_statistics.py b/panoptica/panoptica_statistics.py index b21dd52..66394f4 100644 --- a/panoptica/panoptica_statistics.py +++ b/panoptica/panoptica_statistics.py @@ -44,7 +44,9 @@ def from_file(cls, file: str): rows = [row for row in rd] header = rows[0] - assert header[0] == "subject_name", "First column is not subject_names, something wrong with the file?" + assert ( + header[0] == "subject_name" + ), "First column is not subject_names, something wrong with the file?" 
keys_in_order = list([tuple(c.split("-")) for c in header[1:]]) metric_names = [] @@ -78,13 +80,19 @@ def from_file(cls, file: str): return Panoptica_Statistic(subj_names=subj_names, value_dict=value_dict) def _assertgroup(self, group): - assert group in self.__groupnames, f"group {group} not existent, only got groups {self.__groupnames}" + assert ( + group in self.__groupnames + ), f"group {group} not existent, only got groups {self.__groupnames}" def _assertmetric(self, metric): - assert metric in self.__metricnames, f"metric {metric} not existent, only got metrics {self.__metricnames}" + assert ( + metric in self.__metricnames + ), f"metric {metric} not existent, only got metrics {self.__metricnames}" def _assertsubject(self, subjectname): - assert subjectname in self.__subj_names, f"subject {subjectname} not in list of subjects, got {self.__subj_names}" + assert ( + subjectname in self.__subj_names + ), f"subject {subjectname} not in list of subjects, got {self.__subj_names}" def get(self, group, metric) -> list[float]: """Returns the list of values for given group and metric @@ -115,7 +123,10 @@ def get_one_subject(self, subjectname: str): """ self._assertsubject(subjectname) sidx = self.__subj_names.index(subjectname) - return {g: {m: self.get(g, m)[sidx] for m in self.__metricnames} for g in self.__groupnames} + return { + g: {m: self.get(g, m)[sidx] for m in self.__metricnames} + for g in self.__groupnames + } def get_across_groups(self, metric): """Given metric, gives list of all values (even across groups!) Treat with care! @@ -132,7 +143,10 @@ def get_across_groups(self, metric): return values def get_summary_dict(self): - return {g: {m: self.get_summary(g, m) for m in self.__metricnames} for g in self.__groupnames} + return { + g: {m: self.get_summary(g, m) for m in self.__metricnames} + for g in self.__groupnames + } def get_summary(self, group, metric): # TODO maybe more here? 
range, stuff like that @@ -196,7 +210,9 @@ def make_curve_over_setups( groups = [groups] # for setupname, stat in statistics_dict.items(): - assert metric in stat.metricnames, f"metric {metric} not in statistic obj {setupname}" + assert ( + metric in stat.metricnames + ), f"metric {metric} not in statistic obj {setupname}" setupnames = list(statistics_dict.keys()) convert_x_to_digit = True @@ -261,18 +277,37 @@ def plot_box( if sort: df_by_spec_count = df_data.groupby(graph_name).mean() df_by_spec_count = dict(df_by_spec_count[score].items()) - df_data["mean"] = df_data[graph_name].apply(lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1)) + df_data["mean"] = df_data[graph_name].apply( + lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1) + ) df_data = df_data.sort_values(by="mean") if orientation == "v": - fig = px.strip(df_data, x=graph_name, y=score, stripmode="overlay", orientation=orientation) + fig = px.strip( + df_data, x=graph_name, y=score, stripmode="overlay", orientation=orientation + ) fig.update_traces(marker={"size": 5, "color": "#555555"}) for e in data.keys(): - fig.add_trace(go.Box(y=df_data.query(f'{graph_name} == "{e}"')[score], name=e, orientation=orientation)) + fig.add_trace( + go.Box( + y=df_data.query(f'{graph_name} == "{e}"')[score], + name=e, + orientation=orientation, + ) + ) else: - fig = px.strip(df_data, y=graph_name, x=score, stripmode="overlay", orientation=orientation) + fig = px.strip( + df_data, y=graph_name, x=score, stripmode="overlay", orientation=orientation + ) fig.update_traces(marker={"size": 5, "color": "#555555"}) for e in data.keys(): - fig.add_trace(go.Box(x=df_data.query(f'{graph_name} == "{e}"')[score], name=e, orientation=orientation, boxpoints=False)) + fig.add_trace( + go.Box( + x=df_data.query(f'{graph_name} == "{e}"')[score], + name=e, + orientation=orientation, + boxpoints=False, + ) + ) fig.update_layout( autosize=False, width=width, diff --git a/panoptica/utils/edge_case_handling.py b/panoptica/utils/edge_case_handling.py index a865683..e7dd5d1 100644 --- a/panoptica/utils/edge_case_handling.py +++ b/panoptica/utils/edge_case_handling.py @@ -57,12 +57,26 @@ def __init__( self._default_result = default_result self._edgecase_dict: dict[EdgeCaseZeroTP, EdgeCaseResult] = {} - self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = empty_prediction_result if empty_prediction_result is not None else default_result - self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = empty_reference_result if empty_reference_result is not None else default_result - self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = no_instances_result if no_instances_result is not None else default_result - self._edgecase_dict[EdgeCaseZeroTP.NORMAL] = normal if normal is not None else default_result + self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = ( + empty_prediction_result + if empty_prediction_result is not None + else default_result + ) + self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = ( + empty_reference_result + if empty_reference_result is not None + else default_result + ) + self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = ( + no_instances_result if no_instances_result is not None else default_result + ) + self._edgecase_dict[EdgeCaseZeroTP.NORMAL] = ( + normal if normal is not None else default_result + ) - def __call__(self, tp: int, num_pred_instances, num_ref_instances) -> tuple[bool, float | None]: + def __call__( + self, tp: int, num_pred_instances, num_ref_instances + ) -> tuple[bool, float | None]: if tp != 0: return False, 
EdgeCaseResult.NONE.value # @@ -133,7 +147,9 @@ def __init__( }, empty_list_std: EdgeCaseResult = EdgeCaseResult.NAN, ) -> None: - self.__listmetric_zeroTP_handling: dict[Metric, MetricZeroTPEdgeCaseHandling] = listmetric_zeroTP_handling + self.__listmetric_zeroTP_handling: dict[ + Metric, MetricZeroTPEdgeCaseHandling + ] = listmetric_zeroTP_handling self.__empty_list_std: EdgeCaseResult = empty_list_std def handle_zero_tp( @@ -160,7 +176,9 @@ def handle_zero_tp( if tp != 0: return False, EdgeCaseResult.NONE.value if metric not in self.__listmetric_zeroTP_handling: - raise NotImplementedError(f"Metric {metric} encountered zero TP, but no edge handling available") + raise NotImplementedError( + f"Metric {metric} encountered zero TP, but no edge handling available" + ) return self.__listmetric_zeroTP_handling[metric]( tp=tp, diff --git a/panoptica/utils/processing_pair.py b/panoptica/utils/processing_pair.py index 6020365..6ef4640 100644 --- a/panoptica/utils/processing_pair.py +++ b/panoptica/utils/processing_pair.py @@ -25,7 +25,9 @@ class _ProcessingPair(ABC): _pred_labels: tuple[int, ...] n_dim: int - def __init__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None) -> None: + def __init__( + self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None + ) -> None: """Initializes a general Processing Pair Args: @@ -38,8 +40,12 @@ def __init__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: self._reference_arr = reference_arr self.dtype = dtype self.n_dim = reference_arr.ndim - self._ref_labels: tuple[int, ...] = tuple(_unique_without_zeros(reference_arr)) # type:ignore - self._pred_labels: tuple[int, ...] = tuple(_unique_without_zeros(prediction_arr)) # type:ignore + self._ref_labels: tuple[int, ...] = tuple( + _unique_without_zeros(reference_arr) + ) # type:ignore + self._pred_labels: tuple[int, ...] = tuple( + _unique_without_zeros(prediction_arr) + ) # type:ignore self.crop: tuple[slice, ...] = None self.is_cropped: bool = False self.uncropped_shape: tuple[int, ...] 
= reference_arr.shape @@ -56,25 +62,41 @@ def crop_data(self, verbose: bool = False): self._prediction_arr = self._prediction_arr[self.crop] self._reference_arr = self._reference_arr[self.crop] - (print(f"-- Cropped from {self.uncropped_shape} to {self._prediction_arr.shape}") if verbose else None) + ( + print( + f"-- Cropped from {self.uncropped_shape} to {self._prediction_arr.shape}" + ) + if verbose + else None + ) self.is_cropped = True def uncrop_data(self, verbose: bool = False): if self.is_cropped == False: return - assert self.uncropped_shape is not None, "Calling uncrop_data() without having cropped first" + assert ( + self.uncropped_shape is not None + ), "Calling uncrop_data() without having cropped first" prediction_arr = np.zeros(self.uncropped_shape) prediction_arr[self.crop] = self._prediction_arr self._prediction_arr = prediction_arr reference_arr = np.zeros(self.uncropped_shape) reference_arr[self.crop] = self._reference_arr - (print(f"-- Uncropped from {self._reference_arr.shape} to {self.uncropped_shape}") if verbose else None) + ( + print( + f"-- Uncropped from {self._reference_arr.shape} to {self.uncropped_shape}" + ) + if verbose + else None + ) self._reference_arr = reference_arr self.is_cropped = False def set_dtype(self, type): - assert np.issubdtype(type, int_type), "set_dtype: tried to set dtype to something other than integers" + assert np.issubdtype( + type, int_type + ), "set_dtype: tried to set dtype to something other than integers" self._prediction_arr = self._prediction_arr.astype(type) self._reference_arr = self._reference_arr.astype(type) @@ -153,7 +175,9 @@ def copy(self): ) # type:ignore -def _check_array_integrity(prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None = None): +def _check_array_integrity( + prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None = None +): """ Check the integrity of two numpy arrays. 
@@ -175,8 +199,12 @@ def _check_array_integrity(prediction_arr: np.ndarray, reference_arr: np.ndarray assert isinstance(prediction_arr, np.ndarray) and isinstance( reference_arr, np.ndarray ), "prediction and/or reference are not numpy arrays" - assert prediction_arr.shape == reference_arr.shape, f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}" - assert prediction_arr.dtype == reference_arr.dtype, f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}" + assert ( + prediction_arr.shape == reference_arr.shape + ), f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}" + assert ( + prediction_arr.dtype == reference_arr.dtype + ), f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}" if dtype is not None: assert ( np.issubdtype(prediction_arr.dtype, dtype) @@ -259,11 +287,15 @@ def __init__( self.matched_instances = matched_instances if missed_reference_labels is None: - missed_reference_labels = list([i for i in self._ref_labels if i not in self._pred_labels]) + missed_reference_labels = list( + [i for i in self._ref_labels if i not in self._pred_labels] + ) self.missed_reference_labels = missed_reference_labels if missed_prediction_labels is None: - missed_prediction_labels = list([i for i in self._pred_labels if i not in self._ref_labels]) + missed_prediction_labels = list( + [i for i in self._pred_labels if i not in self._ref_labels] + ) self.missed_prediction_labels = missed_prediction_labels @property @@ -300,7 +332,9 @@ class InputType(_Enum_Compare): UNMATCHED_INSTANCE = UnmatchedInstancePair MATCHED_INSTANCE = MatchedInstancePair - def __call__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray) -> _ProcessingPair: + def __call__( + self, prediction_arr: np.ndarray, reference_arr: np.ndarray + ) -> _ProcessingPair: return self.value(prediction_arr, reference_arr) @@ -309,7 +343,9 @@ def __init__(self, original_input: _ProcessingPair | None): self._original_input = original_input self._intermediatesteps: dict[str, _ProcessingPair] = {} - def add_intermediate_arr_data(self, processing_pair: _ProcessingPair, inputtype: InputType): + def add_intermediate_arr_data( + self, processing_pair: _ProcessingPair, inputtype: InputType + ): type_name = inputtype.name self.add_intermediate_data(type_name, processing_pair) @@ -319,26 +355,36 @@ def add_intermediate_data(self, key, value): @property def original_prediction_arr(self): - assert self._original_input is not None, "Original prediction_arr is None, there are no intermediate steps" + assert ( + self._original_input is not None + ), "Original prediction_arr is None, there are no intermediate steps" return self._original_input.prediction_arr @property def original_reference_arr(self): - assert self._original_input is not None, "Original reference_arr is None, there are no intermediate steps" + assert ( + self._original_input is not None + ), "Original reference_arr is None, there are no intermediate steps" return self._original_input.reference_arr def prediction_arr(self, inputtype: InputType): type_name = inputtype.name procpair = self[type_name] - assert isinstance(procpair, _ProcessingPair), f"step {type_name} is not a processing pair, error" + assert isinstance( + procpair, _ProcessingPair + ), f"step {type_name} is not a processing pair, error" return procpair.prediction_arr def reference_arr(self, inputtype: InputType): type_name = inputtype.name procpair = self[type_name] - assert isinstance(procpair, _ProcessingPair), f"step {type_name} is not a processing pair, error" 
+ assert isinstance( + procpair, _ProcessingPair + ), f"step {type_name} is not a processing pair, error" return procpair.reference_arr def __getitem__(self, key): - assert key in self._intermediatesteps, f"key {key} not in intermediate steps, maybe the step was skipped?" + assert ( + key in self._intermediatesteps + ), f"key {key} not in intermediate steps, maybe the step was skipped?" return self._intermediatesteps[key] From c49969320e0bd3e8b9668feec5799dc465b32c65 Mon Sep 17 00:00:00 2001 From: iback Date: Wed, 28 Aug 2024 10:21:48 +0000 Subject: [PATCH 09/11] fixed import issue --- examples/example_spine_statistics.py | 3 ++- panoptica/__init__.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/example_spine_statistics.py b/examples/example_spine_statistics.py index c7ca6aa..d89e0ab 100644 --- a/examples/example_spine_statistics.py +++ b/examples/example_spine_statistics.py @@ -1,7 +1,8 @@ from auxiliary.nifti.io import read_nifti from auxiliary.turbopath import turbopath -from panoptica import Panoptica_Evaluator, Panoptica_Aggregator, make_curve_over_setups +from panoptica import Panoptica_Evaluator, Panoptica_Aggregator +from panoptica.panoptica_statistics import make_curve_over_setups from pathlib import Path from panoptica.utils import NonDaemonicPool from joblib import delayed, Parallel diff --git a/panoptica/__init__.py b/panoptica/__init__.py index 41648f0..4589be4 100644 --- a/panoptica/__init__.py +++ b/panoptica/__init__.py @@ -3,7 +3,7 @@ CCABackend, ) from panoptica.instance_matcher import NaiveThresholdMatching -from panoptica.panoptica_statistics import Panoptica_Statistic, make_curve_over_setups +from panoptica.panoptica_statistics import Panoptica_Statistic from panoptica.panoptica_aggregator import Panoptica_Aggregator from panoptica.panoptica_evaluator import Panoptica_Evaluator from panoptica.panoptica_result import PanopticaResult From 03a25ea6f3b5f25a98d218c38d80b0b39d2e006b Mon Sep 17 00:00:00 2001 From: iback Date: Wed, 28 Aug 2024 10:37:22 +0000 Subject: [PATCH 10/11] removed import assert --- panoptica/panoptica_statistics.py | 43 +++++++------------------------ 1 file changed, 10 insertions(+), 33 deletions(-) diff --git a/panoptica/panoptica_statistics.py b/panoptica/panoptica_statistics.py index 66394f4..1edc100 100644 --- a/panoptica/panoptica_statistics.py +++ b/panoptica/panoptica_statistics.py @@ -11,7 +11,6 @@ except Exception as e: print(e) print("OPTIONAL PACKAGE MISSING") - assert False class Panoptica_Statistic: @@ -44,9 +43,7 @@ def from_file(cls, file: str): rows = [row for row in rd] header = rows[0] - assert ( - header[0] == "subject_name" - ), "First column is not subject_names, something wrong with the file?" + assert header[0] == "subject_name", "First column is not subject_names, something wrong with the file?" 
keys_in_order = list([tuple(c.split("-")) for c in header[1:]]) metric_names = [] @@ -80,19 +77,13 @@ def from_file(cls, file: str): return Panoptica_Statistic(subj_names=subj_names, value_dict=value_dict) def _assertgroup(self, group): - assert ( - group in self.__groupnames - ), f"group {group} not existent, only got groups {self.__groupnames}" + assert group in self.__groupnames, f"group {group} not existent, only got groups {self.__groupnames}" def _assertmetric(self, metric): - assert ( - metric in self.__metricnames - ), f"metric {metric} not existent, only got metrics {self.__metricnames}" + assert metric in self.__metricnames, f"metric {metric} not existent, only got metrics {self.__metricnames}" def _assertsubject(self, subjectname): - assert ( - subjectname in self.__subj_names - ), f"subject {subjectname} not in list of subjects, got {self.__subj_names}" + assert subjectname in self.__subj_names, f"subject {subjectname} not in list of subjects, got {self.__subj_names}" def get(self, group, metric) -> list[float]: """Returns the list of values for given group and metric @@ -123,10 +114,7 @@ def get_one_subject(self, subjectname: str): """ self._assertsubject(subjectname) sidx = self.__subj_names.index(subjectname) - return { - g: {m: self.get(g, m)[sidx] for m in self.__metricnames} - for g in self.__groupnames - } + return {g: {m: self.get(g, m)[sidx] for m in self.__metricnames} for g in self.__groupnames} def get_across_groups(self, metric): """Given metric, gives list of all values (even across groups!) Treat with care! @@ -143,10 +131,7 @@ def get_across_groups(self, metric): return values def get_summary_dict(self): - return { - g: {m: self.get_summary(g, m) for m in self.__metricnames} - for g in self.__groupnames - } + return {g: {m: self.get_summary(g, m) for m in self.__metricnames} for g in self.__groupnames} def get_summary(self, group, metric): # TODO maybe more here? 
range, stuff like that @@ -210,9 +195,7 @@ def make_curve_over_setups( groups = [groups] # for setupname, stat in statistics_dict.items(): - assert ( - metric in stat.metricnames - ), f"metric {metric} not in statistic obj {setupname}" + assert metric in stat.metricnames, f"metric {metric} not in statistic obj {setupname}" setupnames = list(statistics_dict.keys()) convert_x_to_digit = True @@ -277,14 +260,10 @@ def plot_box( if sort: df_by_spec_count = df_data.groupby(graph_name).mean() df_by_spec_count = dict(df_by_spec_count[score].items()) - df_data["mean"] = df_data[graph_name].apply( - lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1) - ) + df_data["mean"] = df_data[graph_name].apply(lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1)) df_data = df_data.sort_values(by="mean") if orientation == "v": - fig = px.strip( - df_data, x=graph_name, y=score, stripmode="overlay", orientation=orientation - ) + fig = px.strip(df_data, x=graph_name, y=score, stripmode="overlay", orientation=orientation) fig.update_traces(marker={"size": 5, "color": "#555555"}) for e in data.keys(): fig.add_trace( @@ -295,9 +274,7 @@ def plot_box( ) ) else: - fig = px.strip( - df_data, y=graph_name, x=score, stripmode="overlay", orientation=orientation - ) + fig = px.strip(df_data, y=graph_name, x=score, stripmode="overlay", orientation=orientation) fig.update_traces(marker={"size": 5, "color": "#555555"}) for e in data.keys(): fig.add_trace( From fad4d576acab4fae68ff2f65c1bc6bf94506b47b Mon Sep 17 00:00:00 2001 From: "brainless-bot[bot]" <153751247+brainless-bot[bot]@users.noreply.github.com> Date: Wed, 28 Aug 2024 10:38:53 +0000 Subject: [PATCH 11/11] Autoformat with black --- panoptica/panoptica_statistics.py | 42 +++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/panoptica/panoptica_statistics.py b/panoptica/panoptica_statistics.py index 1edc100..fd998be 100644 --- a/panoptica/panoptica_statistics.py +++ b/panoptica/panoptica_statistics.py @@ -43,7 +43,9 @@ def from_file(cls, file: str): rows = [row for row in rd] header = rows[0] - assert header[0] == "subject_name", "First column is not subject_names, something wrong with the file?" + assert ( + header[0] == "subject_name" + ), "First column is not subject_names, something wrong with the file?" 
keys_in_order = list([tuple(c.split("-")) for c in header[1:]]) metric_names = [] @@ -77,13 +79,19 @@ def from_file(cls, file: str): return Panoptica_Statistic(subj_names=subj_names, value_dict=value_dict) def _assertgroup(self, group): - assert group in self.__groupnames, f"group {group} not existent, only got groups {self.__groupnames}" + assert ( + group in self.__groupnames + ), f"group {group} not existent, only got groups {self.__groupnames}" def _assertmetric(self, metric): - assert metric in self.__metricnames, f"metric {metric} not existent, only got metrics {self.__metricnames}" + assert ( + metric in self.__metricnames + ), f"metric {metric} not existent, only got metrics {self.__metricnames}" def _assertsubject(self, subjectname): - assert subjectname in self.__subj_names, f"subject {subjectname} not in list of subjects, got {self.__subj_names}" + assert ( + subjectname in self.__subj_names + ), f"subject {subjectname} not in list of subjects, got {self.__subj_names}" def get(self, group, metric) -> list[float]: """Returns the list of values for given group and metric @@ -114,7 +122,10 @@ def get_one_subject(self, subjectname: str): """ self._assertsubject(subjectname) sidx = self.__subj_names.index(subjectname) - return {g: {m: self.get(g, m)[sidx] for m in self.__metricnames} for g in self.__groupnames} + return { + g: {m: self.get(g, m)[sidx] for m in self.__metricnames} + for g in self.__groupnames + } def get_across_groups(self, metric): """Given metric, gives list of all values (even across groups!) Treat with care! @@ -131,7 +142,10 @@ def get_across_groups(self, metric): return values def get_summary_dict(self): - return {g: {m: self.get_summary(g, m) for m in self.__metricnames} for g in self.__groupnames} + return { + g: {m: self.get_summary(g, m) for m in self.__metricnames} + for g in self.__groupnames + } def get_summary(self, group, metric): # TODO maybe more here? range, stuff like that @@ -195,7 +209,9 @@ def make_curve_over_setups( groups = [groups] # for setupname, stat in statistics_dict.items(): - assert metric in stat.metricnames, f"metric {metric} not in statistic obj {setupname}" + assert ( + metric in stat.metricnames + ), f"metric {metric} not in statistic obj {setupname}" setupnames = list(statistics_dict.keys()) convert_x_to_digit = True @@ -260,10 +276,14 @@ def plot_box( if sort: df_by_spec_count = df_data.groupby(graph_name).mean() df_by_spec_count = dict(df_by_spec_count[score].items()) - df_data["mean"] = df_data[graph_name].apply(lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1)) + df_data["mean"] = df_data[graph_name].apply( + lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1) + ) df_data = df_data.sort_values(by="mean") if orientation == "v": - fig = px.strip(df_data, x=graph_name, y=score, stripmode="overlay", orientation=orientation) + fig = px.strip( + df_data, x=graph_name, y=score, stripmode="overlay", orientation=orientation + ) fig.update_traces(marker={"size": 5, "color": "#555555"}) for e in data.keys(): fig.add_trace( @@ -274,7 +294,9 @@ def plot_box( ) ) else: - fig = px.strip(df_data, y=graph_name, x=score, stripmode="overlay", orientation=orientation) + fig = px.strip( + df_data, y=graph_name, x=score, stripmode="overlay", orientation=orientation + ) fig.update_traces(marker={"size": 5, "color": "#555555"}) for e in data.keys(): fig.add_trace(