Skip to content

Commit

Permalink
docstring fix, fix labels of confusion matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-taylor committed Jan 25, 2024
1 parent 421bbdd commit b564531
Showing 1 changed file with 42 additions and 5 deletions.
47 changes: 42 additions & 5 deletions ISLP/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
"""

from os.path import join as pjoin
import pandas as pd, numpy as np
from importlib.resources import (as_file,
files)
import pandas as pd, numpy as np
from sklearn.metrics import confusion_matrix as _confusion_matrix
from sklearn.metrics._classification import unique_labels

# data originally saved via: [sm.datasets.get_rdataset(n, 'ISLR').data.to_csv('../ISLP/data/%s.csv' % n, index=False) for n in ['Carseats', 'College', 'Credit', 'Default', 'Hitters', 'Auto', 'OJ', 'Portfolio', 'Smarket', 'Wage', 'Weekly', 'Caravan']]

Expand Down Expand Up @@ -42,7 +44,15 @@ def _make_categorical(dataset):
}
_index = {'Auto':'name'}

_datasets = sorted(list(_unordered.keys()) +
list(_ordered.keys()) +
['NCI60',
'Khan',
'Bikeshare',
'NYSE'])

def load_data(dataset):

if dataset == 'NCI60':
with as_file(files('ISLP').joinpath('data', 'NCI60data.npy')) as features:
X = np.load(features)
Expand Down Expand Up @@ -103,19 +113,46 @@ def load_data(dataset):
return df.set_index('date')
else:
return _make_categorical(dataset)
load_data.__doc__ = f"""
Load dataset from ISLP package.
from sklearn.metrics import confusion_matrix as _confusion_matrix
Choices are: {_datasets}
Parameters
----------
dataset: str
Returns
-------
data: array-like or dict
Either a `pd.DataFrame` representing the dataset or a dictionary
containing different parts of the dataset.
"""

def confusion_table(predicted_labels,
true_labels):
true_labels,
labels=None):
"""
Return a data frame version of confusion
matrix with rows given by predicted label
and columns the truth.
Parameters
----------
predicted_labels: array-like
These will form rows of confusion matrix.
true_labels: array-like
These will form columns of confusion matrix.
"""

labels = sorted(np.unique(list(true_labels) +
list(predicted_labels)))
if labels is None:
labels = unique_labels(true_labels,
predicted_labels)
C = _confusion_matrix(true_labels,
predicted_labels,
labels=labels)
Expand Down

0 comments on commit b564531

Please sign in to comment.