From ad704d243f3e5f6dea7e148abed85d9653137898 Mon Sep 17 00:00:00 2001
From: Olivier Bachem
Date: Mon, 22 Jul 2019 13:40:32 +0200
Subject: [PATCH] internal change

PiperOrigin-RevId: 259301552
---
 .../evaluation/metrics/factor_vae.py | 55 +++++++++++--------
 .../evaluation/metrics/utils.py      | 27 +++++++++
 2 files changed, 59 insertions(+), 23 deletions(-)
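Reviewer note (this comment sits below the --- marker, so git am drops it):
the main functional change is that _compute_variances now evaluates the
representation through the new utils.obtain_representation helper in
mini-batches of eval_batch_size instead of one monolithic call. A minimal
sketch of the helper's shape contract, with a toy encoder standing in for a
real representation_function (all names below are illustrative, not from
this patch):

import numpy as np

def toy_representation_function(observations):
  # Stand-in for a trained encoder: maps (num_points, 4) observations to
  # (num_points, 2) codes. Any callable with this contract works.
  return observations[:, :2] * 2.0

observations = np.random.RandomState(0).rand(10, 4)

# Batched evaluation as in obtain_representation with batch_size=3:
# encode chunk by chunk, then stack the per-chunk codes.
chunks = [toy_representation_function(observations[i:i + 3])
          for i in range(0, observations.shape[0], 3)]
codes = np.transpose(np.vstack(chunks))  # (num_codes, num_points)

# Identical to a single full-batch call, but with bounded peak memory.
assert np.allclose(
    codes, np.transpose(toy_representation_function(observations)))

The final transpose matches the (num_codes, num_points) layout documented in
the helper's docstring; _compute_variances transposes back before taking
per-dimension variances.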
""" observations = ground_truth_data.sample_observations(batch_size, random_state) - representations = representation_function(observations) + representations = utils.obtain_representation(observations, + representation_function, + eval_batch_size) + representations = np.transpose(representations) + assert representations.shape[0] == batch_size + assert representations.shape[1] == ground_truth_data.num_factors return np.var(representations, axis=0, ddof=1) @@ -151,9 +161,8 @@ def _generate_training_sample(ground_truth_data, representation_function, factors, random_state) representations = representation_function(observations) local_variances = np.var(representations, axis=0, ddof=1) - argmin = np.argmin( - local_variances[active_dims] / global_variances[active_dims] - ) + argmin = np.argmin(local_variances[active_dims] / + global_variances[active_dims]) return factor_index, argmin @@ -176,12 +185,13 @@ def _generate_training_batch(ground_truth_data, representation_function, Returns: (num_factors, dim_representation)-sized numpy array with votes. """ - votes = np.zeros( - (ground_truth_data.num_factors, global_variances.shape[0]), - dtype=np.int64) + votes = np.zeros((ground_truth_data.num_factors, global_variances.shape[0]), + dtype=np.int64) for _ in range(num_points): - factor_index, argmin = _generate_training_sample( - ground_truth_data, representation_function, batch_size, random_state, - global_variances, active_dims) + factor_index, argmin = _generate_training_sample(ground_truth_data, + representation_function, + batch_size, random_state, + global_variances, + active_dims) votes[factor_index, argmin] += 1 return votes diff --git a/disentanglement_lib/evaluation/metrics/utils.py b/disentanglement_lib/evaluation/metrics/utils.py index 406acf1a..d2e12c2c 100644 --- a/disentanglement_lib/evaluation/metrics/utils.py +++ b/disentanglement_lib/evaluation/metrics/utils.py @@ -85,6 +85,31 @@ def split_train_test(observations, train_percentage): return observations_train, observations_test +def obtain_representation(observations, representation_function, batch_size): + """"Obtain representations from observations. + + Args: + observations: Observations for which we compute the representation. + representation_function: Function that takes observation as input and + outputs a representation. + batch_size: Batch size to compute the representation. + Returns: + representations: Codes (num_codes, num_points)-Numpy array. + """ + representations = None + num_points = observations.shape[0] + i = 0 + while i < num_points: + num_points_iter = min(num_points - i, batch_size) + current_observations = observations[i:i + num_points_iter] + if i == 0: + representations = representation_function(current_observations) + else: + representations = np.vstack((representations, + representation_function( + current_observations))) + i += num_points_iter + return np.transpose(representations) def discrete_mutual_info(mus, ys):