Merge pull request #121 from BrainLesion/101
Aggregator Functionality
Hendrik-code authored Aug 28, 2024
2 parents 743b33c + fad4d57 commit 319aa5d
Showing 18 changed files with 889 additions and 48 deletions.
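In short: this PR adds a Panoptica_Aggregator that wraps a Panoptica_Evaluator, appends one row of per-sample metrics to a TSV file, and can turn the collected rows into a Panoptica_Statistic. A minimal usage sketch, condensed from the example script added below (the zero-filled arrays are placeholders; real use loads segmentation masks):

import numpy as np
from pathlib import Path

from panoptica import Panoptica_Evaluator, Panoptica_Aggregator

aggregator = Panoptica_Aggregator(
    Panoptica_Evaluator.load_from_config_name("panoptica_evaluator_unmatched_instance"),
    Path("results.tsv"),  # per-sample metrics are appended here
)

pred = np.zeros((16, 16, 16), dtype=np.uint8)  # placeholder masks
ref = np.zeros((16, 16, 16), dtype=np.uint8)
aggregator.evaluate(pred, ref, "subject_01")  # one call per case, safe to parallelize

statistic = aggregator.make_statistic()
statistic.print_summary()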
5 changes: 5 additions & 0 deletions .gitignore
@@ -162,3 +162,8 @@ cython_debug/
.DS_Store

.vscode
+*.png
+*.tsv
+*.csv
+*.nii.gz
+*.json
81 changes: 81 additions & 0 deletions examples/example_spine_statistics.py
@@ -0,0 +1,81 @@
from auxiliary.nifti.io import read_nifti
from auxiliary.turbopath import turbopath

from panoptica import Panoptica_Evaluator, Panoptica_Aggregator
from panoptica.panoptica_statistics import make_curve_over_setups
from pathlib import Path
from panoptica.utils import NonDaemonicPool
from joblib import delayed, Parallel
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
from multiprocessing import set_start_method


# set_start_method("fork")

directory = turbopath(__file__).parent

reference_mask = read_nifti(directory + "/spine_seg/matched_instance/ref.nii.gz")
prediction_mask = read_nifti(directory + "/spine_seg/matched_instance/pred.nii.gz")

aggregator = Panoptica_Aggregator(
Panoptica_Evaluator.load_from_config_name("panoptica_evaluator_unmatched_instance"),
Path(__file__).parent.joinpath("spine_example.tsv"),
)


if __name__ == "__main__":
    parallel_opt = "none"  # one of: none, pool, joblib, future

if parallel_opt == "pool":
args = [
(prediction_mask, reference_mask, "sample1"),
(prediction_mask, reference_mask, "sample2"),
(prediction_mask, reference_mask, "sample3"),
(prediction_mask, reference_mask, "sample4"),
]
with NonDaemonicPool() as pool:
            pool.starmap(aggregator.evaluate, args)
elif parallel_opt == "none":
for i in range(4):
            results = aggregator.evaluate(prediction_mask, reference_mask, f"sample{i}")
elif parallel_opt == "joblib":
Parallel(n_jobs=4, backend="threading")(
            delayed(aggregator.evaluate)(prediction_mask, reference_mask, f"sample{i}")
for i in range(4)
)
elif parallel_opt == "future":
with ProcessPoolExecutor() as executor:
futures = {
            executor.submit(aggregator.evaluate, prediction_mask, reference_mask, f"sample{i}")
for i in range(4)
}
for future in tqdm(
as_completed(futures), total=len(futures), desc="Panoptica Evaluation"
):
result = future.result()
if result is not None:
print("Done")

    panoptic_statistic = aggregator.make_statistic()
panoptic_statistic.print_summary()

fig = panoptic_statistic.get_summary_figure("sq_dsc", horizontal=True)
out_figure = str(Path(__file__).parent.joinpath("example_sq_dsc_figure.png"))
fig.write_image(out_figure)

fig2 = make_curve_over_setups(
{
"t0.5": panoptic_statistic,
"bad": panoptic_statistic,
"good classifier": panoptic_statistic,
2.0: panoptic_statistic,
},
groups=None,
metric="pq",
)

out_figure = str(Path(__file__).parent.joinpath("example_multiple_statistics.png"))
fig2.savefig(out_figure)
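Whichever branch runs, results are streamed into spine_example.tsv as they finish; with the aggregator's default continue_file=True, re-running the script skips subjects that are already in the file, so interrupted runs can be resumed.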
2 changes: 2 additions & 0 deletions panoptica/__init__.py
@@ -3,6 +3,8 @@
CCABackend,
)
from panoptica.instance_matcher import NaiveThresholdMatching
+from panoptica.panoptica_statistics import Panoptica_Statistic
+from panoptica.panoptica_aggregator import Panoptica_Aggregator
from panoptica.panoptica_evaluator import Panoptica_Evaluator
from panoptica.panoptica_result import PanopticaResult
from panoptica.utils.processing_pair import (
2 changes: 1 addition & 1 deletion panoptica/_functionals.py
@@ -133,7 +133,7 @@ def _connected_components(
else:
raise NotImplementedError(cca_backend)

-    return cc_arr.astype(array.dtype), n_instances
+    return cc_arr, n_instances


def _get_paired_crop(
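(The astype cast removed above is not lost: it moves into panoptica/instance_approximator.py later in this diff, where both arrays are cast once to the smallest unsigned integer dtype that fits the largest instance label.)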
30 changes: 30 additions & 0 deletions panoptica/configs/panoptica_evaluator_BRATS.yaml
@@ -0,0 +1,30 @@
!Panoptica_Evaluator
decision_metric: null
decision_threshold: null
edge_case_handler: !EdgeCaseHandler
empty_list_std: !EdgeCaseResult NAN
listmetric_zeroTP_handling:
!Metric DSC: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO,
empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN,
normal: !EdgeCaseResult ZERO}
!Metric clDSC: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO,
empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN,
normal: !EdgeCaseResult ZERO}
!Metric IOU: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO,
empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN,
normal: !EdgeCaseResult ZERO}
!Metric ASSD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult INF,
empty_reference_result: !EdgeCaseResult INF, no_instances_result: !EdgeCaseResult NAN,
normal: !EdgeCaseResult INF}
!Metric RVD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult NAN,
empty_reference_result: !EdgeCaseResult NAN, no_instances_result: !EdgeCaseResult NAN,
normal: !EdgeCaseResult NAN}
instance_metrics: [!Metric DSC, !Metric IOU, !Metric ASSD, !Metric RVD]
global_metrics: [!Metric DSC, !Metric RVD, !Metric IOU]
expected_input: !InputType SEMANTIC
instance_approximator: !ConnectedComponentsInstanceApproximator {cca_backend: null}
instance_matcher: !NaiveThresholdMatching {allow_many_to_one: false, matching_metric: !Metric IOU,
matching_threshold: 0.5}
log_times: true
segmentation_class_groups: null
verbose: false
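Assuming the preset name matches the file stem, the new BRATS setup should be loadable via the same config-name loader the example script uses (the exact name is inferred from the file name, not confirmed by the diff):

from panoptica import Panoptica_Evaluator

# Preset name inferred from the config file name above.
evaluator = Panoptica_Evaluator.load_from_config_name("panoptica_evaluator_BRATS")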
30 changes: 30 additions & 0 deletions panoptica/configs/panoptica_evaluator_ISLES.yaml
@@ -0,0 +1,30 @@
!Panoptica_Evaluator
decision_metric: null
decision_threshold: null
edge_case_handler: !EdgeCaseHandler
empty_list_std: !EdgeCaseResult NAN
listmetric_zeroTP_handling:
!Metric DSC: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO,
empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN,
normal: !EdgeCaseResult ZERO}
!Metric clDSC: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO,
empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN,
normal: !EdgeCaseResult ZERO}
!Metric IOU: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO,
empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN,
normal: !EdgeCaseResult ZERO}
!Metric ASSD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult INF,
empty_reference_result: !EdgeCaseResult INF, no_instances_result: !EdgeCaseResult NAN,
normal: !EdgeCaseResult INF}
!Metric RVD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult NAN,
empty_reference_result: !EdgeCaseResult NAN, no_instances_result: !EdgeCaseResult NAN,
normal: !EdgeCaseResult NAN}
instance_metrics: [!Metric DSC, !Metric IOU, !Metric ASSD, !Metric RVD]
global_metrics: [!Metric DSC, !Metric RVD, !Metric IOU]
expected_input: !InputType SEMANTIC
instance_approximator: !ConnectedComponentsInstanceApproximator {cca_backend: null}
instance_matcher: !NaiveThresholdMatching {allow_many_to_one: false, matching_metric: !Metric IOU,
matching_threshold: 0.5}
log_times: true
segmentation_class_groups: null
verbose: false
9 changes: 7 additions & 2 deletions panoptica/instance_approximator.py
@@ -165,9 +165,14 @@ def _approximate_instances(
if not empty_reference
else (semantic_pair._reference_arr, 0)
            )
+
+        dtype = _get_smallest_fitting_uint(
+            max(prediction_arr.max(), reference_arr.max())
+        )
+
        return UnmatchedInstancePair(
-            prediction_arr=prediction_arr,
-            reference_arr=reference_arr,
+            prediction_arr=prediction_arr.astype(dtype),
+            reference_arr=reference_arr.astype(dtype),
n_prediction_instance=n_prediction_instance,
n_reference_instance=n_reference_instance,
)
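The _get_smallest_fitting_uint helper referenced here is not part of this diff; a plausible sketch of such a helper, assuming it simply returns the narrowest numpy unsigned integer dtype that can hold the given value:

import numpy as np

def _get_smallest_fitting_uint(max_value: int):
    # Hypothetical sketch, not the library's actual implementation:
    # pick the narrowest unsigned dtype whose range covers max_value.
    for dtype in (np.uint8, np.uint16, np.uint32, np.uint64):
        if max_value <= np.iinfo(dtype).max:
            return dtype
    raise ValueError(f"value {max_value} exceeds the uint64 range")

Shrinking both arrays to a shared small dtype keeps memory usage down for volumes with few instances.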
1 change: 1 addition & 0 deletions panoptica/instance_evaluator.py
@@ -47,6 +47,7 @@ def evaluate_matched_instance(
_evaluate_instance, instance_pairs
)

+    # TODO if instance matcher already gives matching metric, adapt here!
for metric_dict in metric_dicts:
if decision_metric is None or (
decision_threshold is not None
1 change: 1 addition & 0 deletions panoptica/instance_matcher.py
@@ -212,6 +212,7 @@ def _match_instances(
and not self._allow_many_to_one
):
                continue  # -> doesn't make a speed difference
+            # TODO always go in here, but add the matching score to the pair (so evaluation over multiple thresholds becomes easy)
if self._matching_metric.score_beats_threshold(
matching_score, self._matching_threshold
):
184 changes: 184 additions & 0 deletions panoptica/panoptica_aggregator.py
@@ -0,0 +1,184 @@
import numpy as np
from panoptica.panoptica_statistics import Panoptica_Statistic
from panoptica.panoptica_evaluator import Panoptica_Evaluator
from panoptica.panoptica_result import PanopticaResult
from pathlib import Path
from multiprocessing import Lock, set_start_method
import csv
import os
import atexit

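# "fork" lets worker processes inherit the module-level locks below; with the
# default "spawn" start method each process would get fresh, unrelated locks.
# Note that "fork" is only available on POSIX systems.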
set_start_method("fork")
filelock = Lock()
inevalfilelock = Lock()


class Panoptica_Aggregator:
    """Aggregator that runs evaluations and saves the resulting metrics per sample, from which statistics can later be built."""

def __init__(
self,
panoptica_evaluator: Panoptica_Evaluator,
output_file: Path,
continue_file: bool = True,
):
"""
Args:
panoptica_evaluator (Panoptica_Evaluator): The Panoptica_Evaluator used for the pipeline.
output_file (Path | None, optional): If given, will stream the sample results into this file. If the file is existent, will append results if not already there. Defaults to None.
"""
self.__panoptica_evaluator = panoptica_evaluator
self.__class_group_names = panoptica_evaluator.segmentation_class_groups_names
self.__output_file = None
self.__output_buffer_file = None
self.__evaluation_metrics = panoptica_evaluator.resulting_metric_keys

# uses tsv
assert (
output_file.parent.exists()
), f"Directory {str(output_file.parent)} does not exist"

out_file_path = str(output_file)
if not out_file_path.endswith(".tsv"):
out_file_path += ".tsv"

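        # The buffer file records subjects that are finished or currently in
        # flight, so parallel workers do not pick up duplicates (see evaluate()).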
out_buffer_file: Path = Path(out_file_path).parent.joinpath(
"panoptica_aggregator_tmp.tsv"
)
self.__output_buffer_file = out_buffer_file

Path(out_file_path).parent.mkdir(parents=True, exist_ok=True)
self.__output_file = out_file_path

header = ["subject_name"] + [
f"{g}-{m}"
for g in self.__class_group_names
for m in self.__evaluation_metrics
]
header_hash = hash("+".join(header))

        if not Path(self.__output_file).exists():
            # write header
            _write_content(self.__output_file, [header])
        else:
            header_list = _read_first_row(self.__output_file)
# TODO should also hash panoptica_evaluator just to make sure! and then save into header of file
assert header_hash == hash(
"+".join(header_list)
), "Hash of header not the same! You are using a different setup!"

if out_buffer_file.exists():
os.remove(out_buffer_file)
open(out_buffer_file, "a").close()

if continue_file:
with inevalfilelock:
with filelock:
id_list = _load_first_column_entries(self.__output_file)
_write_content(self.__output_buffer_file, [[s] for s in id_list])

        atexit.register(self.__exit_handler)

    def __exit_handler(self):
        # Remove the temporary buffer file when the process exits.
        if self.__output_buffer_file is not None and os.path.exists(self.__output_buffer_file):
            os.remove(self.__output_buffer_file)

def make_statistic(self) -> Panoptica_Statistic:
with filelock:
obj = Panoptica_Statistic.from_file(self.__output_file)
return obj

def evaluate(
self,
prediction_arr: np.ndarray,
reference_arr: np.ndarray,
subject_name: str,
):
"""Evaluates one case
Args:
prediction_arr (np.ndarray): Prediction array
reference_arr (np.ndarray): reference array
subject_name (str | None, optional): Unique name of the sample. If none, will give it a name based on count. Defaults to None.
skip_already_existent (bool): If true, will skip subjects which were already evaluated instead of crashing. Defaults to False.
verbose (bool | None, optional): Verbose. Defaults to None.
"""
# Read tmp file to see which sample names are blocked
with inevalfilelock:
id_list = _load_first_column_entries(self.__output_buffer_file)

if subject_name in id_list:
print(
f"Subject '{subject_name}' evaluated or in process {self.__output_file}, do not add duplicates to your evaluation!",
flush=True,
)
return
_write_content(self.__output_buffer_file, [[subject_name]])

# Run Evaluation (allowed in parallel)
res = self.__panoptica_evaluator.evaluate(
prediction_arr,
reference_arr,
result_all=True,
verbose=False,
log_times=False,
)

# Add to file
self._save_one_subject(subject_name, res)

def _save_one_subject(self, subject_name, result_grouped):
with filelock:
content = [subject_name]
for groupname in self.__class_group_names:
result: PanopticaResult = result_grouped[groupname][0]
result_dict = result.to_dict()
del result

for e in self.__evaluation_metrics:
                    mvalue = result_dict.get(e, "")
content.append(mvalue)
_write_content(self.__output_file, [content])
print(f"Saved entry {subject_name} into {str(self.__output_file)}")


def _read_first_row(file: str):
# NOT THREAD SAFE BY ITSELF!
with open(str(file), "r", encoding="utf8", newline="") as tsvfile:
rd = csv.reader(tsvfile, delimiter="\t", lineterminator="\n")

rows = [row for row in rd]
if len(rows) == 0:
row = []
else:
row = rows[0]

return row


def _load_first_column_entries(file: str):
# NOT THREAD SAFE BY ITSELF!
with open(str(file), "r", encoding="utf8", newline="") as tsvfile:
rd = csv.reader(tsvfile, delimiter="\t", lineterminator="\n")

rows = [row for row in rd]
if len(rows) == 0:
id_list = []
else:
id_list = list([row[0] for row in rows])

n_id = len(id_list)
assert n_id == len(list(set(id_list))), "file has duplicate entries!"

return id_list


def _write_content(file: str, content: list[list[str]]):
# NOT THREAD SAFE BY ITSELF!
with open(str(file), "a", encoding="utf8", newline="") as tsvfile:
writer = csv.writer(tsvfile, delimiter="\t", lineterminator="\n")
for c in content:
writer.writerow(c)
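Because make_statistic() simply re-reads the output file under the lock, an existing TSV can also be analyzed after the fact without re-running any evaluation; a minimal sketch, assuming the file produced by the example script exists:

from panoptica.panoptica_statistics import Panoptica_Statistic

# Same call that make_statistic() performs internally, minus the file lock.
statistic = Panoptica_Statistic.from_file("spine_example.tsv")
statistic.print_summary()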