
Merge pull request #129 from BrainLesion/speedup
Speedup
Hendrik-code authored Sep 5, 2024
2 parents 7d0b12a + b524f8b commit ceaf057
Showing 9 changed files with 202 additions and 61 deletions.
1 change: 1 addition & 0 deletions examples/example_spine_semantic.py
@@ -21,6 +21,7 @@
instance_approximator=ConnectedComponentsInstanceApproximator(),
instance_matcher=NaiveThresholdMatching(),
verbose=True,
log_times=True,
)


11 changes: 10 additions & 1 deletion examples/example_spine_statistics.py
@@ -1,7 +1,14 @@
from auxiliary.nifti.io import read_nifti
from auxiliary.turbopath import turbopath

from panoptica import Panoptica_Evaluator, Panoptica_Aggregator
from panoptica import (
Panoptica_Evaluator,
Panoptica_Aggregator,
InputType,
NaiveThresholdMatching,
Metric,
)
from panoptica.utils import SegmentationClassGroups, LabelGroup
from panoptica.panoptica_statistics import make_curve_over_setups
from pathlib import Path
from panoptica.utils import NonDaemonicPool
@@ -21,8 +28,10 @@
evaluator = Panoptica_Aggregator(
Panoptica_Evaluator.load_from_config_name("panoptica_evaluator_unmatched_instance"),
Path(__file__).parent.joinpath("spine_example.tsv"),
log_times=True,
)

evaluator.panoptica_evaluator.set_log_group_times(True)

if __name__ == "__main__":
parallel_opt = "future" # none, pool, joblib, future
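With log_times=True on the aggregator and set_log_group_times(True) on the wrapped evaluator, each saved subject row gains a per-group computation_time entry. The example's parallel_opt = "future" option points at a concurrent.futures backend; the following is only a rough sketch of such a fan-out over subjects, where evaluate_one is a hypothetical placeholder for loading a subject's prediction/reference pair and handing it to the aggregator (the script's actual loop is not shown in this hunk).

from concurrent.futures import ProcessPoolExecutor, as_completed

def evaluate_one(subject_id: str) -> None:
    # Hypothetical placeholder: read the subject's prediction/reference
    # volumes and pass them to the aggregator.
    ...

if __name__ == "__main__":
    subjects = ["sub-01", "sub-02", "sub-03"]
    with ProcessPoolExecutor() as executor:
        futures = {executor.submit(evaluate_one, s): s for s in subjects}
        for future in as_completed(futures):
            future.result()  # re-raises any worker exception
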
12 changes: 11 additions & 1 deletion panoptica/_functionals.py
@@ -2,7 +2,7 @@
from multiprocessing import Pool

import numpy as np

import math
from panoptica.utils.constants import CCABackend
from panoptica.utils.numpy_utils import _get_bbox_nd

@@ -147,3 +147,13 @@ def _get_paired_crop(
if combined.sum() == 0:
combined += 1
return _get_bbox_nd(combined, px_dist=px_pad)


def _round_to_n(value: float | int, n_significant_digits: int = 2):
return (
value
if value == 0
else round(
value, -int(math.floor(math.log10(abs(value)))) + (n_significant_digits - 1)
)
)
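
The new _round_to_n helper rounds a value to a fixed number of significant digits rather than decimal places. A quick illustration of its behaviour (the import from a private module is for demonstration only):

from panoptica._functionals import _round_to_n

_round_to_n(0.012345)     # 0.012 (two significant digits by default)
_round_to_n(1234)         # 1200
_round_to_n(0.012345, 3)  # 0.0123
_round_to_n(0)            # 0, the early return avoids math.log10(0)
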
13 changes: 13 additions & 0 deletions panoptica/instance_evaluator.py
@@ -3,6 +3,7 @@

from panoptica.metrics import Metric
from panoptica.utils.processing_pair import MatchedInstancePair, EvaluateInstancePair
from panoptica._functionals import _get_paired_crop


def evaluate_matched_instance(
@@ -42,6 +43,8 @@ def evaluate_matched_instance(
(reference_arr, prediction_arr, ref_idx, eval_metrics)
for ref_idx in ref_matched_labels
]

# metric_dicts: list[dict[Metric, float]] = [_evaluate_instance(*i) for i in instance_pairs]
with Pool() as pool:
metric_dicts: list[dict[Metric, float]] = pool.starmap(
_evaluate_instance, instance_pairs
@@ -89,6 +92,16 @@ def _evaluate_instance(
"""
ref_arr = reference_arr == ref_idx
pred_arr = prediction_arr == ref_idx

# Crop down for speedup
crop = _get_paired_crop(
pred_arr,
ref_arr,
)

ref_arr = ref_arr[crop]
pred_arr = pred_arr[crop]

result: dict[Metric, float] = {}
if ref_arr.sum() == 0 or pred_arr.sum() == 0:
return result
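The per-instance speedup comes from cropping both binary masks to a shared bounding box before any metric is computed, so a small instance no longer pays for the full volume. A self-contained NumPy sketch of that idea follows; it mirrors what _get_paired_crop and _get_bbox_nd provide but is not panoptica's actual implementation.

import numpy as np

def paired_bbox_crop(pred: np.ndarray, ref: np.ndarray, pad: int = 2) -> tuple[slice, ...]:
    # Bounding box of the union of both masks, padded by a few voxels per axis.
    combined = pred | ref
    if not combined.any():
        return tuple(slice(0, dim) for dim in combined.shape)  # degenerate case: keep everything
    nonzero = np.nonzero(combined)
    return tuple(
        slice(max(int(idx.min()) - pad, 0), min(int(idx.max()) + pad + 1, dim))
        for idx, dim in zip(nonzero, combined.shape)
    )

pred = np.zeros((128, 128, 128), dtype=bool)
ref = np.zeros_like(pred)
pred[40:44, 40:44, 40:44] = True
ref[41:45, 41:45, 41:45] = True

crop = paired_bbox_crop(pred, ref)
pred_c, ref_c = pred[crop], ref[crop]  # tiny arrays, identical overlap statistics
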
51 changes: 38 additions & 13 deletions panoptica/panoptica_aggregator.py
@@ -12,6 +12,8 @@
filelock = Lock()
inevalfilelock = Lock()

COMPUTATION_TIME_KEY = "computation_time"


#
class Panoptica_Aggregator:
@@ -23,6 +25,7 @@ def __init__(
self,
panoptica_evaluator: Panoptica_Evaluator,
output_file: Path | str,
log_times: bool = False,
continue_file: bool = True,
):
"""
@@ -32,10 +35,11 @@ def __init__(
"""
self.__panoptica_evaluator = panoptica_evaluator
self.__class_group_names = panoptica_evaluator.segmentation_class_groups_names
self.__output_file = None
self.__output_buffer_file = None
self.__evaluation_metrics = panoptica_evaluator.resulting_metric_keys

if log_times:
self.__evaluation_metrics.append(COMPUTATION_TIME_KEY)

if isinstance(output_file, str):
output_file = Path(output_file)
# uses tsv
@@ -44,8 +48,16 @@ def __init__(
), f"Directory {str(output_file.parent)} does not exist"

out_file_path = str(output_file)
if not out_file_path.endswith(".tsv"):
out_file_path += ".tsv"

# extension
if "." in out_file_path:
# extension exists
extension = out_file_path.split(".")[-1]
assert (
extension == "tsv"
), f"You gave the extension {extension}, but currently only .tsv is supported. Either delete it or give .tsv as extension"
else:
out_file_path += ".tsv" # add extension

out_buffer_file: Path = Path(out_file_path).parent.joinpath(
"panoptica_aggregator_tmp.tsv"
@@ -67,10 +79,15 @@ def __init__(
_write_content(output_file, [header])
else:
header_list = _read_first_row(output_file)
# TODO should also hash panoptica_evaluator just to make sure! and then save into header of file
assert header_hash == hash(
"+".join(header_list)
), "Hash of header not the same! You are using a different setup!"
if len(header_list) == 0:
# empty file
print("Output file given is empty, will start with header")
continue_file = True
else:
# TODO should also hash panoptica_evaluator just to make sure! and then save into header of file
assert header_hash == hash(
"+".join(header_list)
), "Hash of header not the same! You are using a different setup!"

if out_buffer_file.exists():
os.remove(out_buffer_file)
@@ -85,8 +102,8 @@ def __init__(
atexit.register(self.__exist_handler)

def __exist_handler(self):
if os.path.exists(self.__output_buffer_file):
os.remove(self.__output_buffer_file)
if self.__output_buffer_file is not None and self.__output_buffer_file.exists():
os.remove(str(self.__output_buffer_file))

def make_statistic(self) -> Panoptica_Statistic:
with filelock:
@@ -140,6 +157,8 @@ def _save_one_subject(self, subject_name, result_grouped):
for groupname in self.__class_group_names:
result: PanopticaResult = result_grouped[groupname][0]
result_dict = result.to_dict()
if result.computation_time is not None:
result_dict[COMPUTATION_TIME_KEY] = result.computation_time
del result

for e in self.__evaluation_metrics:
@@ -153,7 +172,9 @@ def panoptica_evaluator(self):
return self.__panoptica_evaluator


def _read_first_row(file: str):
def _read_first_row(file: str | Path):
if isinstance(file, Path):
file = str(file)
# NOT THREAD SAFE BY ITSELF!
with open(str(file), "r", encoding="utf8", newline="") as tsvfile:
rd = csv.reader(tsvfile, delimiter="\t", lineterminator="\n")
@@ -167,8 +188,10 @@ def _read_first_row(file: str):
return row


def _load_first_column_entries(file: str):
def _load_first_column_entries(file: str | Path):
# NOT THREAD SAFE BY ITSELF!
if isinstance(file, Path):
file = str(file)
with open(str(file), "r", encoding="utf8", newline="") as tsvfile:
rd = csv.reader(tsvfile, delimiter="\t", lineterminator="\n")

@@ -184,7 +207,9 @@ def _load_first_column_entries(file: str):
return id_list


def _write_content(file: str, content: list[list[str]]):
def _write_content(file: str | Path, content: list[list[str]]):
if isinstance(file, Path):
file = str(file)
# NOT THREAD SAFE BY ITSELF!
with open(str(file), "a", encoding="utf8", newline="") as tsvfile:
writer = csv.writer(tsvfile, delimiter="\t", lineterminator="\n")
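When continuing an existing TSV, the aggregator compares a hash of the expected header against the header found in the file, so results from a different evaluator setup are never appended to the same sheet; with this change an empty file is treated as a fresh start instead of tripping that assertion. A rough sketch of the shape of that check (names and structure here are illustrative, the real logic lives in Panoptica_Aggregator.__init__):

from pathlib import Path

def header_is_compatible(expected_header: list[str], output_file: Path) -> bool:
    # Empty or missing file: start fresh and write a new header.
    text = output_file.read_text(encoding="utf8") if output_file.exists() else ""
    if not text.strip():
        return True
    found_header = text.splitlines()[0].split("\t")
    # The same in-process hash of the joined header means the same column setup.
    return hash("+".join(found_header)) == hash("+".join(expected_header))
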
54 changes: 30 additions & 24 deletions panoptica/panoptica_evaluator.py
@@ -19,9 +19,11 @@
)
import numpy as np
from panoptica.utils.config import SupportsConfig
from panoptica.utils.segmentation_class import SegmentationClassGroups, LabelGroup

NO_GROUP_KEY = "ungrouped"
from panoptica.utils.segmentation_class import (
SegmentationClassGroups,
LabelGroup,
_NoSegmentationClassGroups,
)


class Panoptica_Evaluator(SupportsConfig):
@@ -42,6 +44,7 @@ def __init__(
global_metrics: list[Metric] = [Metric.DSC],
decision_metric: Metric | None = None,
decision_threshold: float | None = None,
save_group_times: bool = False,
log_times: bool = False,
verbose: bool = False,
) -> None:
@@ -70,7 +73,10 @@ def __init__(
self.__decision_metric = decision_metric
self.__decision_threshold = decision_threshold
self.__resulting_metric_keys = None
self.__save_group_times = save_group_times

if segmentation_class_groups is None:
segmentation_class_groups = _NoSegmentationClassGroups()
self.__segmentation_class_groups = segmentation_class_groups

self.__edge_case_handler = (
@@ -96,6 +102,7 @@ def _yaml_repr(cls, node) -> dict:
"global_metrics": node.__global_metrics,
"decision_metric": node.__decision_metric,
"decision_threshold": node.__decision_threshold,
"save_group_times": node.__save_group_times,
"log_times": node.__log_times,
"verbose": node.__verbose,
}
@@ -107,6 +114,7 @@ def evaluate(
prediction_arr: np.ndarray,
reference_arr: np.ndarray,
result_all: bool = True,
save_group_times: bool | None = None,
log_times: bool | None = None,
verbose: bool | None = None,
) -> dict[str, tuple[PanopticaResult, IntermediateStepsData]]:
@@ -115,49 +123,40 @@
processing_pair, self.__expected_input.value
), f"input not of expected type {self.__expected_input}"

if self.__segmentation_class_groups is None:
return {
NO_GROUP_KEY: panoptic_evaluate(
input_pair=processing_pair,
edge_case_handler=self.__edge_case_handler,
instance_approximator=self.__instance_approximator,
instance_matcher=self.__instance_matcher,
instance_metrics=self.__eval_metrics,
global_metrics=self.__global_metrics,
decision_metric=self.__decision_metric,
decision_threshold=self.__decision_threshold,
result_all=result_all,
log_times=self.__log_times if log_times is None else log_times,
verbose=True if verbose is None else verbose,
verbose_calc=self.__verbose if verbose is None else verbose,
)
}

self.__segmentation_class_groups.has_defined_labels_for(
processing_pair.prediction_arr, raise_error=True
)
self.__segmentation_class_groups.has_defined_labels_for(
processing_pair.reference_arr, raise_error=True
)

result_grouped = {}
result_grouped: dict[str, tuple[PanopticaResult, IntermediateStepsData]] = {}
for group_name, label_group in self.__segmentation_class_groups.items():
result_grouped[group_name] = self._evaluate_group(
group_name,
label_group,
processing_pair,
result_all,
save_group_times=(
self.__save_group_times
if save_group_times is None
else save_group_times
),
log_times=log_times,
verbose=verbose,
)[1:]
return result_grouped

@property
def segmentation_class_groups_names(self) -> list[str]:
if self.__segmentation_class_groups is None:
return [NO_GROUP_KEY]
return self.__segmentation_class_groups.keys()

def set_log_group_times(self, should_save: bool):
self.__save_group_times = should_save

def _set_instance_approximator(self, instance_approximator: InstanceApproximator):
self.__instance_approximator = instance_approximator

def _set_instance_matcher(self, matcher: InstanceMatchingAlgorithm):
self.__instance_matcher = matcher

@@ -172,6 +171,7 @@ def resulting_metric_keys(self) -> list[str]:
label_group=LabelGroup(1, single_instance=False),
processing_pair=dummy_input,
result_all=True,
save_group_times=False,
log_times=False,
verbose=False,
)
@@ -187,8 +187,11 @@ def _evaluate_group(
result_all: bool = True,
verbose: bool | None = None,
log_times: bool | None = None,
save_group_times: bool = False,
):
assert isinstance(label_group, LabelGroup)
if self.__save_group_times:
start_time = perf_counter()

prediction_arr_grouped = label_group(processing_pair.prediction_arr)
reference_arr_grouped = label_group(processing_pair.reference_arr)
@@ -219,6 +222,9 @@
verbose=True if verbose is None else verbose,
verbose_calc=self.__verbose if verbose is None else verbose,
)
if save_group_times:
duration = perf_counter() - start_time
result.computation_time = duration
return group_name, result, intermediate_steps_data


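Per-group timing is a perf_counter pair around the group evaluation, with the elapsed time attached to the result as computation_time so the aggregator can write it into its computation_time column. The underlying pattern is just the following generic sketch, not panoptica-specific code:

from time import perf_counter

def timed(fn, *args, **kwargs):
    # Run fn and return (result, elapsed seconds).
    start = perf_counter()
    result = fn(*args, **kwargs)
    return result, perf_counter() - start

result, duration = timed(sum, range(1_000_000))
print(f"summing took {duration:.4f}s")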