## v0.9.96

### New Features
Thanks to @mlopezantequera for adding the following features!
#### Testers: allow any combination of query and reference sets (#250)
To evaluate different combinations of query and reference sets, use the `splits_to_eval` argument for `tester.test()`.

For example, let's say your `dataset_dict` has two keys: `"dataset_a"` and `"train"`.
- The default `splits_to_eval = None` is equivalent to:
```python
splits_to_eval = [('dataset_a', ['dataset_a']), ('train', ['train'])]
```
- `dataset_a` as the query, and `train` as the reference:
```python
splits_to_eval = [('dataset_a', ['train'])]
```
- `dataset_a` as the query, and `dataset_a` + `train` as the reference:
```python
splits_to_eval = [('dataset_a', ['dataset_a', 'train'])]
```
Then pass `splits_to_eval` to `tester.test`:
```python
tester.test(dataset_dict, epoch, model, splits_to_eval=splits_to_eval)
```
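For context, here is a minimal end-to-end sketch. The model and datasets are hypothetical placeholders, and `GlobalEmbeddingSpaceTester` is used as one concrete tester class:
```python
import torch
from pytorch_metric_learning import testers

# Hypothetical stand-ins for your real trunk model and datasets.
model = torch.nn.Linear(128, 64)
dataset_a = ...      # any dataset yielding (data, label) pairs
train_dataset = ...  # same format as dataset_a

dataset_dict = {"dataset_a": dataset_a, "train": train_dataset}
tester = testers.GlobalEmbeddingSpaceTester()

# "dataset_a" as the query set, "dataset_a" + "train" as the reference set.
splits_to_eval = [("dataset_a", ["dataset_a", "train"])]
tester.test(dataset_dict, 0, model, splits_to_eval=splits_to_eval)
```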
Note that this new feature makes the old `reference_set` init argument obsolete, so `reference_set` has been removed.
#### AccuracyCalculator: allow arbitrary label comparison functions (#254)
AccuracyCalculator now has an optional init argument, `label_comparison_fn`, which is a function that compares two numpy arrays of labels and returns a boolean array. The default is `numpy.equal`. If a custom function is used, then you must exclude the clustering-based metrics ("NMI" and "AMI"). The following is an example of a custom function for two-dimensional labels. It returns `True` if the 0th column matches and the 1st column does not match:
```python
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

def example_label_comparison_fn(x, y):
    return (x[:, 0] == y[:, 0]) & (x[:, 1] != y[:, 1])

AccuracyCalculator(exclude=("NMI", "AMI"),
                   label_comparison_fn=example_label_comparison_fn)
```
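As a quick sanity check, here is what the custom function returns for some made-up two-dimensional labels (plain numpy, independent of AccuracyCalculator):
```python
import numpy as np

x = np.array([[0, 1], [0, 1], [2, 3]])
y = np.array([[0, 2], [1, 1], [2, 3]])

# Only the first pair has a matching 0th column and a differing 1st column.
example_label_comparison_fn(x, y)  # array([ True, False, False])
```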
### Other Changes
- BaseTrainer and BaseTester now take in an optional `dtype` argument. This is the type that the dataset output will be converted to, e.g. `torch.float16`. If set to the default value of `None`, then no type casting will be done. (See the sketch after this list.)
- Removed `self.dim_reduced_embeddings` from BaseTester and the associated code in `HookContainer`, due to lack of use.
- `tester.test()` now returns `all_accuracies`, whereas before, it returned nothing and you had to access `all_accuracies` either through the `end_of_testing_hook` or via `tester.all_accuracies`.
- `tester.embeddings_and_labels` is deleted at the end of `tester.test()` to free up memory.
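To illustrate the `dtype` argument and the new return value, here is a short sketch. It assumes `GlobalEmbeddingSpaceTester` accepts the new `dtype` keyword via BaseTester, and reuses the hypothetical `dataset_dict` and `model` from the earlier example:
```python
import torch
from pytorch_metric_learning import testers

# Dataset outputs will be cast to float16 before inference.
tester = testers.GlobalEmbeddingSpaceTester(dtype=torch.float16)

# test() now returns the accuracy metrics directly.
all_accuracies = tester.test(dataset_dict, 0, model)

# tester.embeddings_and_labels has been deleted internally at this point,
# so all_accuracies is the way to get the results.
```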