initial upload
rnepal2 authored Sep 20, 2021
1 parent 8137a8a commit f39d52c
Showing 2 changed files with 215 additions and 0 deletions.
117 changes: 117 additions & 0 deletions trainer.py
@@ -0,0 +1,117 @@
import math
import numpy as np
from tqdm.notebook import tqdm
from sklearn.metrics import roc_auc_score
# pytorch
import torch

# Custom Trainer class:
# trains and makes predictions with the GNN models
class Trainer:
    def __init__(self, model, optimizer, train_loader, valid_loader):
        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.valid_loader = valid_loader

    # train the model for one epoch
    def train_one_epoch(self, epoch):
        # set model to training mode
        self.model.train()

        t_targets, p_targets, losses = [], [], []
        tqdm_iter = tqdm(self.train_loader, total=len(self.train_loader))
        for i, data in enumerate(tqdm_iter):
            tqdm_iter.set_description(f"Epoch {epoch}")
            self.optimizer.zero_grad()
            # the model's forward pass returns both outputs and loss
            outputs, loss = self.model(data, data.edge_index, data.batch)
            targets = data.y
            loss.backward()
            self.optimizer.step()

            y_true = self.process_output(targets)             # for one batch
            y_proba = self.process_output(outputs.flatten())  # for one batch

            # note: roc_auc_score raises a ValueError for a batch
            # that contains only one class
            auc = roc_auc_score(y_true, y_proba)
            # continuous loss/auc update on the progress bar
            tqdm_iter.set_postfix(train_loss=round(loss.item(), 2), train_auc=round(auc, 2),
                                  valid_loss=None, valid_auc=None)

            losses.append(loss.item())
            t_targets.extend(list(y_true))
            p_targets.extend(list(y_proba))

        epoch_auc = roc_auc_score(t_targets, p_targets)
        epoch_loss = sum(losses) / len(losses)
        return epoch_loss, epoch_auc, tqdm_iter


    # detach a tensor from the graph and return it as a CPU numpy array
    def process_output(self, out):
        return out.cpu().detach().numpy()


    def validate_one_epoch(self, progress):
        progress_tracker = progress["tracker"]
        train_loss = progress["loss"]
        train_auc = progress["auc"]

        # set model to eval mode
        self.model.eval()

        t_targets, p_targets, losses = [], [], []
        with torch.no_grad():  # no gradients needed during validation
            for data in self.valid_loader:
                outputs, loss = self.model(data, data.edge_index, data.batch)
                outputs, targets = outputs.flatten(), data.y

                y_proba = self.process_output(outputs)  # for one batch
                y_true = self.process_output(targets)   # for one batch

                t_targets.extend(list(y_true))
                p_targets.extend(list(y_proba))
                losses.append(loss.item())

        epoch_auc = roc_auc_score(t_targets, p_targets)
        epoch_loss = sum(losses) / len(losses)
        # append the validation metrics to the epoch's training progress bar
        progress_tracker.set_postfix(train_loss=round(train_loss, 2), train_auc=round(train_auc, 2),
                                     valid_loss=round(epoch_loss, 2), valid_auc=round(epoch_auc, 2))
        progress_tracker.close()
        return epoch_loss, epoch_auc

    # runs training and validation for n_epochs
    def run(self, n_epochs=10):
        train_scores, train_losses = [], []
        valid_scores, valid_losses = [], []
        for e in range(1, n_epochs + 1):
            lt, at, progress_tracker = self.train_one_epoch(e)
            train_losses.append(lt)
            train_scores.append(at)

            # validate this epoch, passing the training progress tracker
            # so training and validation metrics share one bar
            progress = {"tracker": progress_tracker, "loss": lt, "auc": at}
            lv, av = self.validate_one_epoch(progress)
            valid_losses.append(lv)
            valid_scores.append(av)

        return (train_losses, train_scores), (valid_losses, valid_scores)


    # predict on a test set with the trained model
    def predict(self, test_loader):
        # set model to evaluation mode
        self.model.eval()
        predictions = []
        tqdm_iter = tqdm(test_loader, total=len(test_loader))
        for data in tqdm_iter:
            tqdm_iter.set_description("Making predictions")
            with torch.no_grad():
                o, _ = self.model(data, data.edge_index, data.batch)
                o = self.process_output(o.flatten())
                predictions.extend(list(o))
            tqdm_iter.set_postfix(stage="test dataloader")
        tqdm_iter.close()
        return np.array(predictions)
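
For reference, a minimal usage sketch of the Trainer class (not part of this commit). The model, datasets, and loaders here are hypothetical stand-ins; the only interface assumed from the code above is a forward pass that accepts (data, data.edge_index, data.batch) and returns (outputs, loss).

    # a minimal sketch; MyGNN, train_ds, valid_ds, and test_loader are
    # hypothetical placeholders, not names from this commit
    import torch
    from torch_geometric.loader import DataLoader

    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
    valid_loader = DataLoader(valid_ds, batch_size=32)

    model = MyGNN()  # any model whose forward(data, edge_index, batch) returns (outputs, loss)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    trainer = Trainer(model, optimizer, train_loader, valid_loader)
    (train_losses, train_aucs), (valid_losses, valid_aucs) = trainer.run(n_epochs=10)
    test_probabilities = trainer.predict(test_loader)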
98 changes: 98 additions & 0 deletions utils.py
@@ -0,0 +1,98 @@
# packages
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, accuracy_score
from sklearn.metrics import roc_auc_score, roc_curve, precision_score, recall_score


def optimal_cutoff(target, predicted):
    """Find the optimal probability cutoff point for classification.

    Parameters
    ----------
    target: true labels
    predicted: positive-class probabilities predicted by the model,
        i.e. model.predict_proba(X_test)[:, 1], NOT a 0/1 prediction array

    Returns
    -------
    The cut-off value.
    """
    fpr, tpr, threshold = roc_curve(target, predicted)
    i = np.arange(len(tpr))
    # pick the threshold where tpr is closest to 1 - fpr, i.e. where
    # sensitivity and specificity are (approximately) equal
    roc = pd.DataFrame({'tf': pd.Series(tpr - (1 - fpr), index=i),
                        'threshold': pd.Series(threshold, index=i)})
    roc_t = roc.iloc[(roc.tf - 0).abs().argsort()[:1]]

    return round(list(roc_t['threshold'])[0], 2)
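
# Usage sketch (hypothetical names): the returned cutoff can replace a fixed
# 0.5 threshold when converting probabilities into 0/1 predictions, e.g.
#   cutoff = optimal_cutoff(y_true, y_proba)
#   y_pred = (y_proba >= cutoff).astype(int)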

def plot_confusion_matrix(y_true, y_pred):
    # confusion matrix, transposed so that rows are predicted labels
    conf_matrix = confusion_matrix(y_true, y_pred)
    data = conf_matrix.transpose()

    _, ax = plt.subplots()
    ax.matshow(data, cmap="Blues")
    # print the exact count in each cell
    for (i, j), z in np.ndenumerate(data):
        ax.text(j, i, '{}'.format(z), ha='center', va='center')
    # axis formatting
    plt.xticks([])
    plt.yticks([])
    plt.title("True label\n 0 {} 1\n".format(" "*18), fontsize=14)
    plt.ylabel("Predicted label\n 1 {} 0".format(" "*18), fontsize=14)

def draw_roc_curve(y_true, y_proba):
    '''
    y_true: 0/1 true labels for the test set
    y_proba: model.predict_proba(X_test)[:, 1], i.e. predicted positive-class probabilities
    Returns:
        ROC curve plot with appropriate labels and legend
    '''
    fpr, tpr, _ = roc_curve(y_true, y_proba)

    _, ax = plt.subplots()
    ax.plot(fpr, tpr, color='r')
    ax.plot([0, 1], [0, 1], color='y', linestyle='--')  # chance diagonal
    ax.fill_between(fpr, tpr, label=f"AUC: {round(roc_auc_score(y_true, y_proba), 3)}")
    ax.set_aspect(0.90)
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_xlim(-0.02, 1.02)
    ax.set_ylim(-0.02, 1.02)
    plt.legend()
    plt.show()


def summerize_results(y_true, y_pred):
    '''
    Takes the true labels and the predicted 0/1 labels
    and prints some performance metrics.
    '''
    print("\n=========================")
    print("         RESULTS")
    print("=========================")

    print("Accuracy: ", accuracy_score(y_true, y_pred).round(2))
    conf_matrix = confusion_matrix(y_true, y_pred)
    # sensitivity (true positive rate) and specificity (true negative rate)
    sensitivity = round(conf_matrix[1, 1] / (conf_matrix[1, 1] + conf_matrix[1, 0]), 2)
    specificity = round(conf_matrix[0, 0] / (conf_matrix[0, 0] + conf_matrix[0, 1]), 2)

    # positive and negative predictive values
    ppv = round(conf_matrix[1, 1] / (conf_matrix[1, 1] + conf_matrix[0, 1]), 2)
    npv = round(conf_matrix[0, 0] / (conf_matrix[0, 0] + conf_matrix[1, 0]), 2)

    print("-------------------------")
    print("sensitivity: ", sensitivity)
    print("specificity: ", specificity)

    print("-------------------------")
    print("positive predictive value: ", ppv)
    print("negative predictive value: ", npv)

    print("-------------------------")
    print("precision: ", precision_score(y_true, y_pred).round(2))
    print("recall: ", recall_score(y_true, y_pred).round(2))

