
v0.9.96

@KevinMusgrave released this 12 Jan 14:49

New Features

Thanks to @mlopezantequera for adding the following features!

Testers: allow any combination of query and reference sets (#250)

To evaluate different combinations of query and reference sets, use the splits_to_eval argument for tester.test().

For example, let's say your dataset_dict has two keys: "dataset_a" and "train".

  • The default splits_to_eval = None is equivalent to:
    splits_to_eval = [('dataset_a', ['dataset_a']), ('train', ['train'])]
  • dataset_a as the query, and train as the reference:
    splits_to_eval = [('dataset_a', ['train'])]
  • dataset_a as the query, and dataset_a + train as the reference:
    splits_to_eval = [('dataset_a', ['dataset_a', 'train'])]

Then pass splits_to_eval to tester.test:

tester.test(dataset_dict, epoch, model, splits_to_eval=splits_to_eval)

Note that this new feature makes the old reference_set init argument obsolete, so reference_set has been removed.
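As a minimal end-to-end sketch (the tester class and the dataset variables below are assumptions for illustration, not something this release prescribes):

from pytorch_metric_learning import testers

tester = testers.GlobalEmbeddingSpaceTester()

# dataset_a and train_dataset are assumed to be existing PyTorch datasets
dataset_dict = {'dataset_a': dataset_a, 'train': train_dataset}

# evaluate dataset_a queries against the combined dataset_a + train references
splits_to_eval = [('dataset_a', ['dataset_a', 'train'])]
all_accuracies = tester.test(dataset_dict, epoch, model, splits_to_eval=splits_to_eval)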

AccuracyCalculator: allow arbitrary label comparison functions (#254)

AccuracyCalculator now has an optional init argument, label_comparison_fn, which is a function that compares two numpy arrays of labels and returns a boolean array. The default is numpy.equal. If a custom function is used, then you must exclude clustering-based metrics ("NMI" and "AMI"). The following is an example of a custom function for two-dimensional labels. It returns True where the 0th column matches and the 1st column does not:

from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

# True where column 0 matches and column 1 does not
def example_label_comparison_fn(x, y):
    return (x[:, 0] == y[:, 0]) & (x[:, 1] != y[:, 1])

# clustering-based metrics must be excluded when using a custom function
AccuracyCalculator(exclude=("NMI", "AMI"),
                   label_comparison_fn=example_label_comparison_fn)
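As a quick illustration (with made-up labels), applying the function above to two small arrays:

import numpy as np

# hypothetical two-dimensional labels: column 0 might be a class id,
# column 1 an instance id
x = np.array([[0, 1], [2, 3]])
y = np.array([[0, 2], [2, 3]])

example_label_comparison_fn(x, y)  # array([ True, False])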

Other Changes

  • BaseTrainer and BaseTester now take in an optional dtype argument. This is the type that the dataset output will be converted to, e.g. torch.float16. If set to the default value of None, no type casting is done. (See the sketch after this list.)
  • Removed self.dim_reduced_embeddings from BaseTester and the associated code in HookContainer, due to lack of use.
  • tester.test() now returns all_accuracies. Previously it returned nothing, and all_accuracies had to be retrieved either through the end_of_testing_hook or from tester.all_accuracies.
  • tester.embeddings_and_labels is deleted at the end of tester.test() to free up memory.
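As a rough sketch of the new dtype argument (assuming it is accepted by tester subclasses such as GlobalEmbeddingSpaceTester, which inherit from BaseTester):

import torch
from pytorch_metric_learning import testers

# dataset outputs will be cast to float16 before being passed to the model
tester = testers.GlobalEmbeddingSpaceTester(dtype=torch.float16)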