V1.5.0 dev #44

Merged · 7 commits · Aug 6, 2024
3 changes: 2 additions & 1 deletion scripts/data/create_nlst_metadata_json.py
@@ -120,7 +120,7 @@ def make_metadata_dict(dataframe, pid, timepoint, series_id, use_timepoint = Fal
'slice_number': [slicenumber],
'pixel_spacing': pixel_spacing,
'slice_thickness': slice_thickness,
-        'img_position': img_posn,
+        'img_position': [img_posn],
'series_data': make_metadata_dict(image_data, pid, timepoint, series_id, use_timepoint_and_studyinstance = True)
}

@@ -135,6 +135,7 @@ def make_metadata_dict(dataframe, pid, timepoint, series_id, use_timepoint = Fal
json_dataset[pt_idx]['accessions'][exam_idx]['image_series'][series_id]['paths'].append(path)
json_dataset[pt_idx]['accessions'][exam_idx]['image_series'][series_id]['slice_location'].append(slicelocation)
json_dataset[pt_idx]['accessions'][exam_idx]['image_series'][series_id]['slice_number'].append(slicenumber)
+            json_dataset[pt_idx]['accessions'][exam_idx]['image_series'][series_id]['img_position'].append(img_posn)
else:
exam_dict['image_series'] = {series_id: img_series_dict}
json_dataset[pt_idx]['accessions'].append(exam_dict)
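Taken together, the two hunks keep img_position in lockstep with the other per-slice lists: it now starts life as a one-element list and is appended to for every additional slice in the series. A sketch of the resulting series entry (keys from the diff above; values are hypothetical):

    # Hypothetical series entry; the per-slice lists grow in parallel.
    img_series_dict = {
        'paths': ['slice_0001.dcm', 'slice_0002.dcm'],
        'slice_location': [-150.0, -147.5],
        'slice_number': [1, 2],
        'pixel_spacing': [0.703125, 0.703125],
        'slice_thickness': 2.5,
        'img_position': [[-170.0, -160.0, -150.0],   # one position triplet per slice
                         [-170.0, -160.0, -147.5]],
        'series_data': {...},                        # nested per-series metadata
    }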
14 changes: 5 additions & 9 deletions sybil/model.py
@@ -7,10 +7,10 @@

import torch
import numpy as np
-import pickle

from sybil.serie import Serie
from sybil.models.sybil import SybilNet
+from sybil.models.calibrator import SimpleClassifierGroup
from sybil.utils.logging_utils import get_logger
from sybil.utils.device_utils import get_default_device, get_most_free_gpu, get_device_mem_info
from sybil.utils.metrics import get_survival_metrics
@@ -67,7 +67,7 @@
},
}

CHECKPOINT_URL = "https://github.com/reginabarzilaygroup/Sybil/releases/download/v1.0.3/sybil_checkpoints.zip"
CHECKPOINT_URL = os.getenv("SYBIL_CHECKPOINT_URL", "https://www.dropbox.com/scl/fi/45rtadfdci0bj8dbpotmr/sybil_checkpoints_v1.5.0.zip?rlkey=n8n7pvhb89pjoxgvm90mtbtuk&dl=1")


class Prediction(NamedTuple):
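Because the URL is now read through os.getenv, a checkpoint mirror can be used without code changes. A minimal sketch (the mirror URL is hypothetical); note the variable must be set before sybil.model is imported, since the default is resolved at import time:

    import os

    # Hypothetical mirror; set before importing sybil.model.
    os.environ["SYBIL_CHECKPOINT_URL"] = "https://example.com/mirror/sybil_checkpoints_v1.5.0.zip"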
@@ -91,7 +91,7 @@ def download_sybil(name, cache) -> Tuple[List[str], str]:
# Download models
model_files = NAME_TO_FILE[name]
checkpoints = model_files["checkpoint"]
-    download_calib_path = os.path.join(cache, f"{name}.p")
+    download_calib_path = os.path.join(cache, f"{name}_simple_calibrator.json")
have_all_files = os.path.exists(download_calib_path)

download_model_paths = []
@@ -187,7 +187,7 @@ def __init__(
self.to(self.device)

if calibrator_path is not None:
-            self.calibrator = pickle.load(open(calibrator_path, "rb"))
+            self.calibrator = SimpleClassifierGroup.from_json_grouped(calibrator_path)
else:
self.calibrator = None

@@ -227,8 +227,6 @@ def _calibrate(self, scores: np.ndarray) -> np.ndarray:

Parameters
----------
-        calibrator: Optional[dict]
-            Dictionary of sklearn.calibration.CalibratedClassifierCV for each year, otherwise None.
scores: np.ndarray
risk scores as numpy array

@@ -242,9 +240,7 @@
calibrated_scores = []
for YEAR in range(scores.shape[1]):
probs = scores[:, YEAR].reshape(-1, 1)
-            probs = self.calibrator["Year{}".format(YEAR + 1)].predict_proba(probs)[
-                :, 1
-            ]
+            probs = self.calibrator["Year{}".format(YEAR + 1)].predict_proba(probs)[:, -1]
calibrated_scores.append(probs)

return np.stack(calibrated_scores, axis=1)
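The switch from [:, 1] to [:, -1] selects the last column, which is the positive class both in the old two-column sklearn output and in the new calibrator's single-column output. An illustrative check (arrays are made up):

    import numpy as np

    sk_style = np.array([[0.9, 0.1], [0.2, 0.8]])   # sklearn layout: [negative, positive]
    new_style = np.array([[0.1], [0.8]])            # custom layout: positive class only

    assert np.allclose(sk_style[:, -1], [0.1, 0.8])
    assert np.allclose(new_style[:, -1], [0.1, 0.8])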
168 changes: 168 additions & 0 deletions sybil/models/calibrator.py
@@ -0,0 +1,168 @@
import json
import os
from typing import List

import numpy as np

"""
Calibrator for Sybil prediction models.

We calibrate probabilities using isotonic regression.
Previously this was done with scikit-learn; here we use a custom implementation to avoid versioning issues.
"""


class SimpleClassifierGroup:
"""
A class to represent a calibrator for prediction models.
Behavior and coefficients are taken from the sklearn.calibration.CalibratedClassifierCV class.
    We use a custom class to avoid sklearn versioning issues.
"""

def __init__(self, calibrators: List["SimpleIsotonicRegressor"]):
self.calibrators = calibrators

def predict_proba(self, X, expand=False):
"""
Predict class probabilities for X.

Parameters
----------
        X : array-like of shape (n_probabilities, 1)
The input probabilities to recalibrate.
expand : bool, default=False
            Whether to return the probabilities for each class separately.
            This is intended for binary classification, which can be done in 1D;
            expand=True returns the negative- and positive-class probabilities
            stacked along a new leading axis, i.e. [1 - p, p].

Returns
-------
        proba : ndarray
            The calibrated probability of the positive class, averaged over the
            group's calibrators (both classes when expand=True).
"""
proba = np.array([calibrator.transform(X) for calibrator in self.calibrators])
pos_prob = np.mean(proba, axis=0)
if expand and len(self.calibrators) == 1:
return np.array([1.-pos_prob, pos_prob])
else:
return pos_prob

def to_json(self):
return [calibrator.to_json() for calibrator in self.calibrators]

@classmethod
def from_json(cls, json_list):
return cls([SimpleIsotonicRegressor.from_json(json_dict) for json_dict in json_list])

@classmethod
def from_json_grouped(cls, json_path):
"""
        We store calibrators in a dictionary of {year (str): [calibrators]}.
This is a convenience method to load that dictionary from a file path.
"""
json_dict = json.load(open(json_path, "r"))
output_dict = {key: cls.from_json(json_list) for key, json_list in json_dict.items()}
return output_dict


class SimpleIsotonicRegressor:
def __init__(self, coef, intercept, x0, y0, x_min=-np.inf, x_max=np.inf):
self.coef = coef
self.intercept = intercept
self.x0 = x0
self.y0 = y0
self.x_min = x_min
self.x_max = x_max

def transform(self, X):
T = X
T = T @ self.coef + self.intercept
T = np.clip(T, self.x_min, self.x_max)
return np.interp(T, self.x0, self.y0)

@classmethod
    def from_classifier(cls, classifier: "_CalibratedClassifier"):
        assert len(classifier.calibrators) == 1, "Only one calibrator per classifier is supported."
        calibrator = classifier.calibrators[0]
        return cls(classifier.base_estimator.coef_, classifier.base_estimator.intercept_,
                   calibrator.f_.x, calibrator.f_.y, calibrator.X_min_, calibrator.X_max_)

def to_json(self):
return {
"coef": self.coef.tolist(),
"intercept": self.intercept.tolist(),
"x0": self.x0.tolist(),
"y0": self.y0.tolist(),
"x_min": self.x_min,
"x_max": self.x_max
}

@classmethod
def from_json(cls, json_dict):
return cls(
np.array(json_dict["coef"]),
np.array(json_dict["intercept"]),
np.array(json_dict["x0"]),
np.array(json_dict["y0"]),
json_dict["x_min"],
json_dict["x_max"]
)

def __repr__(self):
return f"SimpleIsotonicRegressor(x={self.x0}, y={self.y0})"


def export_calibrator(input_path, output_path):
import pickle
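    # sklearn must be importable so pickle can reconstruct the saved CalibratedClassifierCV objects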
import sklearn
sk_cal_dict = pickle.load(open(input_path, "rb"))
simple_cal_dict = dict()
for key, cal in sk_cal_dict.items():
calibrators = [SimpleIsotonicRegressor.from_classifier(classifier) for classifier in cal.calibrated_classifiers_]
simple_cal_dict[key] = SimpleClassifierGroup(calibrators).to_json()

json.dump(simple_cal_dict, open(output_path, "w"), indent=2)


def export_by_name(base_dir, model_name, overwrite=False):
sk_input_path = os.path.expanduser(f"{base_dir}/{model_name}.p")
simple_output_path = os.path.expanduser(f"{base_dir}/{model_name}_simple_calibrator.json")

version = "1.4.0"
scores_output_path = f"{base_dir}/{model_name}_v{version}_calibrations.json"

    if overwrite or not os.path.exists(scores_output_path):
        run_test_calibrations(sk_input_path, scores_output_path, overwrite=overwrite)

if overwrite or not os.path.exists(simple_output_path):
export_calibrator(sk_input_path, simple_output_path)


def export_all_default_calibrators(base_dir="~/.sybil", overwrite=False):
base_dir = os.path.expanduser(base_dir)
model_names = ["sybil_1", "sybil_2", "sybil_3", "sybil_4", "sybil_5", "sybil_ensemble"]
for model_name in model_names:
export_by_name(base_dir, model_name, overwrite=overwrite)


def run_test_calibrations(sk_input_path, scores_output_path, overwrite=False):
"""
For regression testing. Output calibrated probabilities for a range of input probabilities.
"""
import pickle
sk_cal_dict = pickle.load(open(sk_input_path, "rb"))

test_probs = np.arange(0, 1, 0.001).reshape(-1, 1)

output_dict = {"x": test_probs.flatten().tolist()}
for key, model in sk_cal_dict.items():
output_dict[key] = model.predict_proba(test_probs)[:, -1].flatten().tolist()

if overwrite or not os.path.exists(scores_output_path):
with open(scores_output_path, "w") as f:
json.dump(output_dict, f, indent=2)


if __name__ == "__main__":
export_all_default_calibrators(overwrite=False)
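As a usage sketch, an exported calibrator can be loaded and applied like this (the cache path is hypothetical; from_json_grouped returns the {year: SimpleClassifierGroup} dictionary described above):

    import os

    import numpy as np

    from sybil.models.calibrator import SimpleClassifierGroup

    # Hypothetical cache path; v1.5.0 downloads "<model>_simple_calibrator.json" files.
    path = os.path.expanduser("~/.sybil/sybil_ensemble_simple_calibrator.json")
    calibrators = SimpleClassifierGroup.from_json_grouped(path)

    raw = np.array([[0.02], [0.15], [0.60]])                 # raw scores, shape (n_samples, 1)
    year1 = calibrators["Year1"].predict_proba(raw)[:, -1]   # calibrated positive-class probabilities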
13 changes: 13 additions & 0 deletions sybil/models/sybil.py
@@ -1,3 +1,4 @@
+import torch
import torch.nn as nn
import torchvision
from sybil.models.cumulative_probability_layer import Cumulative_Probability_Layer
@@ -29,6 +30,7 @@ def forward(self, x, batch=None):
pool_output = self.aggregate_and_classify(x)
output["activ"] = x
output.update(pool_output)
output["prob"] = pool_output["logit"].sigmoid()

return output

@@ -41,6 +43,17 @@ def aggregate_and_classify(self, x):

return pool_output

+    @staticmethod
+    def load(path):
+        checkpoint = torch.load(path, map_location="cpu")
+        args = checkpoint["args"]
+        model = SybilNet(args)
+
+        # Strip the leading "model." prefix (6 characters) from parameter names
+        state_dict = {k[6:]: v for k, v in checkpoint["state_dict"].items()}
+        model.load_state_dict(state_dict)  # type: ignore
+        return model
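A short usage sketch for the new loader (the checkpoint path is hypothetical; the checkpoint is assumed to hold args plus a state_dict whose keys carry the "model." prefix):

    # Hypothetical path to a checkpoint from the downloaded bundle.
    model = SybilNet.load("path/to/sybil_1.ckpt")
    model.eval()   # forward() now also returns output["prob"] = sigmoid(logits)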


class RiskFactorPredictor(SybilNet):
def __init__(self, args):
3 changes: 2 additions & 1 deletion sybil/predict.py
@@ -49,7 +49,6 @@ def _get_parser():
help="Generate images with attention overlap. Sets --return-attentions (if not already set).",
)


parser.add_argument(
"--file-type",
default="auto",
@@ -90,6 +89,8 @@ def predict(
):
logger = sybil.utils.logging_utils.get_logger()

+    return_attentions |= write_attention_images

input_files = os.listdir(image_dir)
input_files = [os.path.join(image_dir, x) for x in input_files if not x.startswith(".")]
input_files = [x for x in input_files if os.path.isfile(x)]
2 changes: 2 additions & 0 deletions sybil/serie.py
@@ -1,3 +1,4 @@
+import functools
from typing import List, Optional, NamedTuple, Literal
from argparse import Namespace

@@ -137,6 +138,7 @@ def get_raw_images(self) -> List[np.ndarray]:
images = [i["input"] for i in input_dicts]
return images

+    @functools.lru_cache
def get_volume(self) -> torch.Tensor:
"""
        Load the 3D CT volume
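Worth noting: functools.lru_cache on an instance method keys the cache on self, so each Serie builds its volume once and later calls return the same cached tensor (the cache also keeps the instance alive until the entry is evicted). An illustrative check, assuming serie is a constructed Serie:

    volume_a = serie.get_volume()
    volume_b = serie.get_volume()   # cache hit; no second load
    assert volume_a is volume_b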