diff --git a/examples/example_spine_semantic.py b/examples/example_spine_semantic.py
index e2f45b1..0ec88f6 100644
--- a/examples/example_spine_semantic.py
+++ b/examples/example_spine_semantic.py
@@ -21,6 +21,7 @@
     instance_approximator=ConnectedComponentsInstanceApproximator(),
     instance_matcher=NaiveThresholdMatching(),
     verbose=True,
+    log_times=True,
 )
diff --git a/examples/example_spine_statistics.py b/examples/example_spine_statistics.py
index 47da69b..1c215c4 100644
--- a/examples/example_spine_statistics.py
+++ b/examples/example_spine_statistics.py
@@ -1,7 +1,14 @@
 from auxiliary.nifti.io import read_nifti
 from auxiliary.turbopath import turbopath
-from panoptica import Panoptica_Evaluator, Panoptica_Aggregator
+from panoptica import (
+    Panoptica_Evaluator,
+    Panoptica_Aggregator,
+    InputType,
+    NaiveThresholdMatching,
+    Metric,
+)
+from panoptica.utils import SegmentationClassGroups, LabelGroup
 from panoptica.panoptica_statistics import make_curve_over_setups
 from pathlib import Path
 from panoptica.utils import NonDaemonicPool
@@ -21,8 +28,10 @@
 evaluator = Panoptica_Aggregator(
     Panoptica_Evaluator.load_from_config_name("panoptica_evaluator_unmatched_instance"),
     Path(__file__).parent.joinpath("spine_example.tsv"),
+    log_times=True,
 )
+evaluator.panoptica_evaluator.set_log_group_times(True)
 
 if __name__ == "__main__":
     parallel_opt = "future"  # none, pool, joblib, future
diff --git a/panoptica/_functionals.py b/panoptica/_functionals.py
index 7ea63a5..3b6a6b1 100644
--- a/panoptica/_functionals.py
+++ b/panoptica/_functionals.py
@@ -2,7 +2,7 @@
 from multiprocessing import Pool
 
 import numpy as np
-
+import math
 from panoptica.utils.constants import CCABackend
 from panoptica.utils.numpy_utils import _get_bbox_nd
@@ -147,3 +147,13 @@ def _get_paired_crop(
     if combined.sum() == 0:
         combined += 1
     return _get_bbox_nd(combined, px_dist=px_pad)
+
+
+def _round_to_n(value: float | int, n_significant_digits: int = 2):
+    return (
+        value
+        if value == 0
+        else round(
+            value, -int(math.floor(math.log10(abs(value)))) + (n_significant_digits - 1)
+        )
+    )
diff --git a/panoptica/instance_evaluator.py b/panoptica/instance_evaluator.py
index 8fc6999..caa96ea 100644
--- a/panoptica/instance_evaluator.py
+++ b/panoptica/instance_evaluator.py
@@ -3,6 +3,7 @@
 from panoptica.metrics import Metric
 from panoptica.utils.processing_pair import MatchedInstancePair, EvaluateInstancePair
+from panoptica._functionals import _get_paired_crop
 
 
 def evaluate_matched_instance(
@@ -42,6 +43,8 @@ def evaluate_matched_instance(
         (reference_arr, prediction_arr, ref_idx, eval_metrics)
         for ref_idx in ref_matched_labels
     ]
+
+    # metric_dicts: list[dict[Metric, float]] = [_evaluate_instance(*i) for i in instance_pairs]
     with Pool() as pool:
         metric_dicts: list[dict[Metric, float]] = pool.starmap(
             _evaluate_instance, instance_pairs
         )
@@ -89,6 +92,16 @@
     """
     ref_arr = reference_arr == ref_idx
     pred_arr = prediction_arr == ref_idx
+
+    # Crop down for speedup
+    crop = _get_paired_crop(
+        pred_arr,
+        ref_arr,
+    )
+
+    ref_arr = ref_arr[crop]
+    pred_arr = pred_arr[crop]
+
     result: dict[Metric, float] = {}
     if ref_arr.sum() == 0 or pred_arr.sum() == 0:
         return result
diff --git a/panoptica/panoptica_aggregator.py b/panoptica/panoptica_aggregator.py
index 0e373d7..5fc8234 100644
--- a/panoptica/panoptica_aggregator.py
+++ b/panoptica/panoptica_aggregator.py
@@ -12,6 +12,8 @@
 filelock = Lock()
 inevalfilelock = Lock()
 
+COMPUTATION_TIME_KEY = "computation_time"
+
 
 #
 class Panoptica_Aggregator:
@@ -23,6 +25,7 @@ def __init__(
         self,
         panoptica_evaluator: Panoptica_Evaluator,
         output_file: Path | str,
+        log_times: bool = False,
         continue_file: bool = True,
     ):
         """
@@ -32,10 +35,11 @@
         """
         self.__panoptica_evaluator = panoptica_evaluator
         self.__class_group_names = panoptica_evaluator.segmentation_class_groups_names
-        self.__output_file = None
-        self.__output_buffer_file = None
         self.__evaluation_metrics = panoptica_evaluator.resulting_metric_keys
 
+        if log_times:
+            self.__evaluation_metrics.append(COMPUTATION_TIME_KEY)
+
         if isinstance(output_file, str):
             output_file = Path(output_file)
         # uses tsv
@@ -44,8 +48,16 @@
         ), f"Directory {str(output_file.parent)} does not exist"
 
         out_file_path = str(output_file)
-        if not out_file_path.endswith(".tsv"):
-            out_file_path += ".tsv"
+
+        # extension
+        if "." in out_file_path:
+            # extension exists
+            extension = out_file_path.split(".")[-1]
+            assert (
+                extension == "tsv"
+            ), f"You gave the extension {extension}, but currently only .tsv is supported. Either delete it or give .tsv as extension"
+        else:
+            out_file_path += ".tsv"  # add extension
 
         out_buffer_file: Path = Path(out_file_path).parent.joinpath(
             "panoptica_aggregator_tmp.tsv"
         )
@@ -67,10 +79,15 @@ def __init__(
             _write_content(output_file, [header])
         else:
             header_list = _read_first_row(output_file)
-            # TODO should also hash panoptica_evaluator just to make sure! and then save into header of file
-            assert header_hash == hash(
-                "+".join(header_list)
-            ), "Hash of header not the same! You are using a different setup!"
+            if len(header_list) == 0:
+                # empty file
+                print("Output file given is empty, will start with header")
+                continue_file = True
+            else:
+                # TODO should also hash panoptica_evaluator just to make sure! and then save into header of file
+                assert header_hash == hash(
+                    "+".join(header_list)
+                ), "Hash of header not the same! You are using a different setup!"
 
         if out_buffer_file.exists():
             os.remove(out_buffer_file)
@@ -85,8 +102,8 @@ def __init__(
         atexit.register(self.__exist_handler)
 
     def __exist_handler(self):
-        if os.path.exists(self.__output_buffer_file):
-            os.remove(self.__output_buffer_file)
+        if self.__output_buffer_file is not None and self.__output_buffer_file.exists():
+            os.remove(str(self.__output_buffer_file))
 
     def make_statistic(self) -> Panoptica_Statistic:
         with filelock:
@@ -140,6 +157,8 @@ def _save_one_subject(self, subject_name, result_grouped):
         for groupname in self.__class_group_names:
             result: PanopticaResult = result_grouped[groupname][0]
             result_dict = result.to_dict()
+            if result.computation_time is not None:
+                result_dict[COMPUTATION_TIME_KEY] = result.computation_time
             del result
 
             for e in self.__evaluation_metrics:
@@ -153,7 +172,9 @@ def panoptica_evaluator(self):
         return self.__panoptica_evaluator
 
 
-def _read_first_row(file: str):
+def _read_first_row(file: str | Path):
+    if isinstance(file, Path):
+        file = str(file)
     # NOT THREAD SAFE BY ITSELF!
     with open(str(file), "r", encoding="utf8", newline="") as tsvfile:
         rd = csv.reader(tsvfile, delimiter="\t", lineterminator="\n")
@@ -167,8 +188,10 @@ def _read_first_row(file: str):
     return row
 
 
-def _load_first_column_entries(file: str):
+def _load_first_column_entries(file: str | Path):
     # NOT THREAD SAFE BY ITSELF!
+    if isinstance(file, Path):
+        file = str(file)
     with open(str(file), "r", encoding="utf8", newline="") as tsvfile:
         rd = csv.reader(tsvfile, delimiter="\t", lineterminator="\n")
@@ -184,7 +207,9 @@ def _load_first_column_entries(file: str):
     return id_list
 
 
-def _write_content(file: str, content: list[list[str]]):
+def _write_content(file: str | Path, content: list[list[str]]):
+    if isinstance(file, Path):
+        file = str(file)
     # NOT THREAD SAFE BY ITSELF!
     with open(str(file), "a", encoding="utf8", newline="") as tsvfile:
         writer = csv.writer(tsvfile, delimiter="\t", lineterminator="\n")
diff --git a/panoptica/panoptica_evaluator.py b/panoptica/panoptica_evaluator.py
index 36ec399..9730501 100644
--- a/panoptica/panoptica_evaluator.py
+++ b/panoptica/panoptica_evaluator.py
@@ -19,9 +19,11 @@
 )
 import numpy as np
 from panoptica.utils.config import SupportsConfig
-from panoptica.utils.segmentation_class import SegmentationClassGroups, LabelGroup
-
-NO_GROUP_KEY = "ungrouped"
+from panoptica.utils.segmentation_class import (
+    SegmentationClassGroups,
+    LabelGroup,
+    _NoSegmentationClassGroups,
+)
 
 
 class Panoptica_Evaluator(SupportsConfig):
@@ -42,6 +44,7 @@ def __init__(
         global_metrics: list[Metric] = [Metric.DSC],
         decision_metric: Metric | None = None,
         decision_threshold: float | None = None,
+        save_group_times: bool = False,
         log_times: bool = False,
         verbose: bool = False,
     ) -> None:
@@ -70,7 +73,10 @@ def __init__(
         self.__decision_metric = decision_metric
         self.__decision_threshold = decision_threshold
         self.__resulting_metric_keys = None
+        self.__save_group_times = save_group_times
 
+        if segmentation_class_groups is None:
+            segmentation_class_groups = _NoSegmentationClassGroups()
         self.__segmentation_class_groups = segmentation_class_groups
 
         self.__edge_case_handler = (
@@ -96,6 +102,7 @@ def _yaml_repr(cls, node) -> dict:
             "global_metrics": node.__global_metrics,
             "decision_metric": node.__decision_metric,
             "decision_threshold": node.__decision_threshold,
+            "save_group_times": node.__save_group_times,
             "log_times": node.__log_times,
             "verbose": node.__verbose,
         }
@@ -107,6 +114,7 @@ def evaluate(
         prediction_arr: np.ndarray,
         reference_arr: np.ndarray,
         result_all: bool = True,
+        save_group_times: bool | None = None,
         log_times: bool | None = None,
         verbose: bool | None = None,
     ) -> dict[str, tuple[PanopticaResult, IntermediateStepsData]]:
@@ -115,24 +123,6 @@
             processing_pair, self.__expected_input.value
         ), f"input not of expected type {self.__expected_input}"
 
-        if self.__segmentation_class_groups is None:
-            return {
-                NO_GROUP_KEY: panoptic_evaluate(
-                    input_pair=processing_pair,
-                    edge_case_handler=self.__edge_case_handler,
-                    instance_approximator=self.__instance_approximator,
-                    instance_matcher=self.__instance_matcher,
-                    instance_metrics=self.__eval_metrics,
-                    global_metrics=self.__global_metrics,
-                    decision_metric=self.__decision_metric,
-                    decision_threshold=self.__decision_threshold,
-                    result_all=result_all,
-                    log_times=self.__log_times if log_times is None else log_times,
-                    verbose=True if verbose is None else verbose,
-                    verbose_calc=self.__verbose if verbose is None else verbose,
-                )
-            }
-
         self.__segmentation_class_groups.has_defined_labels_for(
             processing_pair.prediction_arr, raise_error=True
         )
@@ -140,13 +130,18 @@
             processing_pair.reference_arr, raise_error=True
         )
 
-        result_grouped = {}
+        result_grouped: dict[str, tuple[PanopticaResult, IntermediateStepsData]] = {}
         for group_name, label_group in self.__segmentation_class_groups.items():
             result_grouped[group_name] = self._evaluate_group(
                 group_name,
                 label_group,
                 processing_pair,
                 result_all,
+                save_group_times=(
+                    self.__save_group_times
+                    if save_group_times is None
+                    else save_group_times
+                ),
                 log_times=log_times,
                 verbose=verbose,
             )[1:]
@@ -154,10 +149,14 @@
     @property
     def segmentation_class_groups_names(self) -> list[str]:
-        if self.__segmentation_class_groups is None:
-            return [NO_GROUP_KEY]
         return self.__segmentation_class_groups.keys()
 
+    def set_log_group_times(self, should_save: bool):
+        self.__save_group_times = should_save
+
+    def _set_instance_approximator(self, instance_approximator: InstanceApproximator):
+        self.__instance_approximator = instance_approximator
+
     def _set_instance_matcher(self, matcher: InstanceMatchingAlgorithm):
         self.__instance_matcher = matcher
@@ -172,6 +171,7 @@ def resulting_metric_keys(self) -> list[str]:
             label_group=LabelGroup(1, single_instance=False),
             processing_pair=dummy_input,
             result_all=True,
+            save_group_times=False,
             log_times=False,
             verbose=False,
         )
@@ -187,8 +187,11 @@ def _evaluate_group(
         result_all: bool = True,
         verbose: bool | None = None,
         log_times: bool | None = None,
+        save_group_times: bool = False,
     ):
         assert isinstance(label_group, LabelGroup)
+        # use the per-call flag here so start_time is always defined when the
+        # duration is computed below (checking self.__save_group_times instead
+        # could raise a NameError when the caller overrides it to True)
+        if save_group_times:
+            start_time = perf_counter()
         prediction_arr_grouped = label_group(processing_pair.prediction_arr)
         reference_arr_grouped = label_group(processing_pair.reference_arr)
@@ -219,6 +222,9 @@ def _evaluate_group(
             verbose=True if verbose is None else verbose,
             verbose_calc=self.__verbose if verbose is None else verbose,
         )
+        if save_group_times:
+            duration = perf_counter() - start_time
+            result.computation_time = duration
         return group_name, result, intermediate_steps_data
diff --git a/panoptica/panoptica_result.py b/panoptica/panoptica_result.py
index 9386a5c..2123981 100644
--- a/panoptica/panoptica_result.py
+++ b/panoptica/panoptica_result.py
@@ -16,6 +16,7 @@
 
 
 class PanopticaResult(object):
+
     def __init__(
         self,
         reference_arr: np.ndarray,
@@ -26,6 +27,7 @@ def __init__(
         list_metrics: dict[Metric, list[float]],
         edge_case_handler: EdgeCaseHandler,
         global_metrics: list[Metric] = [],
+        computation_time: float | None = None,
     ):
         """Result object for Panoptica, contains all calculatable metrics
 
@@ -38,13 +40,14 @@ def __init__(
            list_metrics (dict[Metric, list[float]]): dictionary containing the metrics for each TP
            edge_case_handler (EdgeCaseHandler): EdgeCaseHandler object that handles various forms of edge cases
        """
+        self._evaluation_metrics: dict[str, Evaluation_Metric] = {}
         self._edge_case_handler = edge_case_handler
         empty_list_std = self._edge_case_handler.handle_empty_list_std().value
         self._global_metrics: list[Metric] = global_metrics
+        self.computation_time = computation_time
         ######################
         # Evaluation Metrics #
         ######################
-        self._evaluation_metrics: dict[str, Evaluation_Metric] = {}
         #
         # region Already Calculated
         self.num_ref_instances: int
@@ -433,19 +436,24 @@ def __getattribute__(self, __name: str) -> Any:
                 pass
             else:
                 raise e
-        if attr is None:
-            if self._evaluation_metrics[__name]._error:
-                raise MetricCouldNotBeComputedException(
-                    f"Requested metric {__name} that could not be computed"
-                )
-            elif not self._evaluation_metrics[__name]._was_calculated:
-                value = self._calc_metric(__name)
-                setattr(self, __name, value)
-                if isinstance(value, MetricCouldNotBeComputedException):
-                    raise value
-                return value
-        else:
+        if __name == "_evaluation_metrics":
             return attr
+        if (
+            object.__getattribute__(self, "_evaluation_metrics") is not None
+            and __name in self._evaluation_metrics.keys()
+        ):
+            if attr is None:
+                if self._evaluation_metrics[__name]._error:
+                    raise MetricCouldNotBeComputedException(
+                        f"Requested metric {__name} that could not be computed"
+                    )
+                elif not self._evaluation_metrics[__name]._was_calculated:
+                    value = self._calc_metric(__name)
+                    setattr(self, __name, value)
+                    if isinstance(value, MetricCouldNotBeComputedException):
+                        raise value
+                    return value
+        return attr
 
 
     #########################
diff --git a/panoptica/utils/label_group.py b/panoptica/utils/label_group.py
index 8bde861..c37f00e 100644
--- a/panoptica/utils/label_group.py
+++ b/panoptica/utils/label_group.py
@@ -79,11 +79,42 @@ def _yaml_repr(cls, node):
             "single_instance": node.single_instance,
         }
 
-    # @classmethod
-    # def to_yaml(cls, representer, node):
-    #     return representer.represent_mapping("!" + cls.__name__, cls._yaml_repr(node))
-
-    # @classmethod
-    # def from_yaml(cls, constructor, node):
-    #     data = constructor.construct_mapping(node, deep=True)
-    #     return cls(**data)
+
+class _LabelGroupAny(LabelGroup):
+    def __init__(self) -> None:
+        pass
+
+    @property
+    def value_labels(self) -> list[int]:
+        raise AssertionError("LabelGroupAny has no value_labels, it is all labels")
+
+    @property
+    def single_instance(self) -> bool:
+        return False
+
+    def __call__(
+        self,
+        array: np.ndarray,
+        set_to_binary: bool = False,
+    ) -> np.ndarray:
+        """Extracts the labels of this class
+
+        Args:
+            array (np.ndarray): Array to extract the segmentation group labels from
+            set_to_binary (bool, optional): If true, will output a binary array. Defaults to False.
+
+        Returns:
+            np.ndarray: Array containing only the labels of this segmentation group
+        """
+        array = array.copy()
+        return array
+
+    def __str__(self) -> str:
+        return f"LabelGroupAny"
+
+    def __repr__(self) -> str:
+        return str(self)
+
+    @classmethod
+    def _yaml_repr(cls, node):
+        return {}
diff --git a/panoptica/utils/segmentation_class.py b/panoptica/utils/segmentation_class.py
index 96f380e..e968761 100644
--- a/panoptica/utils/segmentation_class.py
+++ b/panoptica/utils/segmentation_class.py
@@ -1,7 +1,9 @@
 import numpy as np
 from pathlib import Path
 from panoptica.utils.config import SupportsConfig
-from panoptica.utils.label_group import LabelGroup
+from panoptica.utils.label_group import LabelGroup, _LabelGroupAny
+
+NO_GROUP_KEY = "ungrouped"
 
 
 class SegmentationClassGroups(SupportsConfig):
@@ -96,3 +98,39 @@ def list_duplicates(seq):
     seen_twice = set(x for x in seq if x in seen or seen_add(x))
     # turn the set into a list (as requested)
     return list(seen_twice)
+
+
+class _NoSegmentationClassGroups(SegmentationClassGroups):
+    def __init__(self) -> None:
+        self.__group_dictionary = {NO_GROUP_KEY: _LabelGroupAny()}
+
+    def has_defined_labels_for(
+        self, arr: np.ndarray | list[int], raise_error: bool = False
+    ):
+        return True
+
+    def __str__(self) -> str:
+        text = "NoSegmentationClassGroups"
+        return text
+
+    def __contains__(self, item):
+        return item in self.__group_dictionary
+
+    def __getitem__(self, key):
+        return self.__group_dictionary[key]
+
+    def __iter__(self):
+        yield from self.__group_dictionary
+
+    def keys(self) -> list[str]:
+        return list(self.__group_dictionary.keys())
+
+    @property
+    def labels(self):
+        raise Exception(
+            "_NoSegmentationClassGroups has no explicit definition of labels"
+        )
+
+    @classmethod
+    def _yaml_repr(cls, node):
+        return {}
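
Usage sketch (not part of the diff): a minimal illustration of how the timing additions fit together, mirroring examples/example_spine_statistics.py. The config name and output path are taken from that example; prediction_arr and reference_arr mentioned in the trailing comment are hypothetical inputs.

from pathlib import Path

from panoptica import Panoptica_Aggregator, Panoptica_Evaluator

# log_times=True appends the "computation_time" column to the aggregator's TSV header.
aggregator = Panoptica_Aggregator(
    Panoptica_Evaluator.load_from_config_name("panoptica_evaluator_unmatched_instance"),
    Path("spine_example.tsv"),
    log_times=True,
)

# Enable per-group timing on the wrapped evaluator: each group's PanopticaResult then
# carries the measured runtime in result.computation_time, which _save_one_subject
# writes under COMPUTATION_TIME_KEY.
aggregator.panoptica_evaluator.set_log_group_times(True)

# Timing can also be requested per call on the evaluator itself, e.g.
# evaluator.evaluate(prediction_arr, reference_arr, save_group_times=True);
# an evaluator constructed without segmentation_class_groups now falls back to
# _NoSegmentationClassGroups and reports its results under the "ungrouped" key.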