Commit ad704d2

internal change

PiperOrigin-RevId: 259301552

obachem committed Jul 22, 2019
1 parent 9a6fee7 commit ad704d2
Showing 2 changed files with 58 additions and 23 deletions.
56 changes: 33 additions & 23 deletions disentanglement_lib/evaluation/metrics/factor_vae.py
@@ -21,6 +21,7 @@
 from __future__ import division
 from __future__ import print_function
 from absl import logging
+from disentanglement_lib.evaluation.metrics import utils
 import numpy as np
 from six.moves import range
 import gin.tf
@@ -67,9 +68,10 @@ def compute_factor_vae(ground_truth_data,
     return scores_dict

   logging.info("Generating training set.")
-  training_votes = _generate_training_batch(
-      ground_truth_data, representation_function, batch_size, num_train,
-      random_state, global_variances, active_dims)
+  training_votes = _generate_training_batch(ground_truth_data,
+                                            representation_function, batch_size,
+                                            num_train, random_state,
+                                            global_variances, active_dims)
   classifier = np.argmax(training_votes, axis=0)
   other_index = np.arange(training_votes.shape[1])
@@ -79,31 +81,33 @@ def compute_factor_vae(ground_truth_data,
   logging.info("Training set accuracy: %.2g", train_accuracy)

   logging.info("Generating evaluation set.")
-  eval_votes = _generate_training_batch(
-      ground_truth_data, representation_function, batch_size, num_eval,
-      random_state, global_variances, active_dims)
+  eval_votes = _generate_training_batch(ground_truth_data,
+                                        representation_function, batch_size,
+                                        num_eval, random_state,
+                                        global_variances, active_dims)

   logging.info("Evaluate evaluation set accuracy.")
-  eval_accuracy = np.sum(
-      eval_votes[classifier, other_index]) * 1. / np.sum(eval_votes)
+  eval_accuracy = np.sum(eval_votes[classifier,
+                                    other_index]) * 1. / np.sum(eval_votes)
   logging.info("Evaluation set accuracy: %.2g", eval_accuracy)
   scores_dict["train_accuracy"] = train_accuracy
   scores_dict["eval_accuracy"] = eval_accuracy
   scores_dict["num_active_dims"] = len(active_dims)
   return scores_dict


-@gin.configurable(
-    "prune_dims",
-    blacklist=["variances"])
+@gin.configurable("prune_dims", blacklist=["variances"])
 def _prune_dims(variances, threshold=0.):
   """Mask for dimensions collapsed to the prior."""
   scale_z = np.sqrt(variances)
   return scale_z >= threshold


-def _compute_variances(ground_truth_data, representation_function, batch_size,
-                       random_state):
+def _compute_variances(ground_truth_data,
+                       representation_function,
+                       batch_size,
+                       random_state,
+                       eval_batch_size=64):
   """Computes the variance for each dimension of the representation.

   Args:
@@ -112,12 +116,18 @@ def _compute_variances(ground_truth_data, representation_function, batch_size,
       outputs a representation.
     batch_size: Number of points to be used to compute the variances.
     random_state: Numpy random state used for randomness.
+    eval_batch_size: Batch size used to eval representation.

   Returns:
     Vector with the variance of each dimension.
   """
   observations = ground_truth_data.sample_observations(batch_size, random_state)
-  representations = representation_function(observations)
+  representations = utils.obtain_representation(observations,
+                                                representation_function,
+                                                eval_batch_size)
+  representations = np.transpose(representations)
   assert representations.shape[0] == batch_size
+  assert representations.shape[1] == ground_truth_data.num_factors
   return np.var(representations, axis=0, ddof=1)
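_compute_variances and _prune_dims work as a pair: the variances estimated here are fed to _prune_dims, which masks out dimensions that collapsed to the prior (near-zero scale). A minimal sketch of that hand-off, with invented variances:

import numpy as np

# Hypothetical per-dimension variances of a 5-d representation.
global_variances = np.array([1.02, 0.95, 1e-4, 0.88, 2e-5])

# _prune_dims keeps dimensions with sqrt(variance) >= threshold; the
# threshold defaults to 0. and is gin-configurable under "prune_dims".
threshold = 0.05  # illustrative value, not a library default
active_dims = np.sqrt(global_variances) >= threshold
print(active_dims)  # [ True  True False  True False]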


@@ -151,9 +161,8 @@ def _generate_training_sample(ground_truth_data, representation_function,
       factors, random_state)
   representations = representation_function(observations)
   local_variances = np.var(representations, axis=0, ddof=1)
-  argmin = np.argmin(
-      local_variances[active_dims] / global_variances[active_dims]
-  )
+  argmin = np.argmin(local_variances[active_dims] /
+                     global_variances[active_dims])
   return factor_index, argmin
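Each call above casts a single vote: with one factor held fixed across the batch, the active dimension whose variance shrinks the most relative to its global variance becomes that factor's vote. The same computation on invented numbers:

import numpy as np

local_variances = np.array([0.91, 0.07, 1.05])   # within the fixed-factor batch
global_variances = np.array([1.00, 0.98, 1.02])  # over random samples
active_dims = np.array([True, True, True])

# Dimension 1 barely varies once the factor is fixed, so it wins the vote.
argmin = np.argmin(local_variances[active_dims] / global_variances[active_dims])
print(argmin)  # 1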


@@ -176,12 +185,13 @@ def _generate_training_batch(ground_truth_data, representation_function,
   Returns:
     (num_factors, dim_representation)-sized numpy array with votes.
   """
-  votes = np.zeros(
-      (ground_truth_data.num_factors, global_variances.shape[0]),
-      dtype=np.int64)
+  votes = np.zeros((ground_truth_data.num_factors, global_variances.shape[0]),
+                   dtype=np.int64)
   for _ in range(num_points):
-    factor_index, argmin = _generate_training_sample(
-        ground_truth_data, representation_function, batch_size, random_state,
-        global_variances, active_dims)
+    factor_index, argmin = _generate_training_sample(ground_truth_data,
+                                                     representation_function,
+                                                     batch_size, random_state,
+                                                     global_variances,
+                                                     active_dims)
     votes[factor_index, argmin] += 1
   return votes
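Putting the pieces together, the metric runs: estimate global variances, prune collapsed dimensions, accumulate fixed-factor votes, then score the majority-vote classifier. A self-contained paraphrase of that pipeline on synthetic data (a sketch of the algorithm, not the library's code):

import numpy as np

rng = np.random.RandomState(0)
num_factors, num_codes = 3, 3

def sample_factors(n):
  return rng.randint(0, 10, size=(n, num_factors)).astype(np.float64)

def representation(factors):
  # A perfectly disentangled toy code: one factor per dimension, plus noise.
  return factors + 0.01 * rng.randn(*factors.shape)

global_variances = np.var(representation(sample_factors(1000)), axis=0, ddof=1)
active_dims = np.sqrt(global_variances) >= 0.05

votes = np.zeros((num_factors, num_codes), dtype=np.int64)
for _ in range(200):
  k = rng.randint(num_factors)     # factor to hold fixed for this vote
  factors = sample_factors(64)
  factors[:, k] = factors[0, k]    # fix factor k across the batch
  local_variances = np.var(representation(factors), axis=0, ddof=1)
  votes[k, np.argmin(local_variances[active_dims] /
                     global_variances[active_dims])] += 1

classifier = np.argmax(votes, axis=0)
accuracy = np.sum(votes[classifier, np.arange(num_codes)]) / np.sum(votes)
print(accuracy)  # close to 1.0 for this disentangled toy representation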
25 changes: 25 additions & 0 deletions disentanglement_lib/evaluation/metrics/utils.py
@@ -85,6 +85,31 @@ def split_train_test(observations, train_percentage):
   return observations_train, observations_test


+def obtain_representation(observations, representation_function, batch_size):
+  """Obtain representations from observations.
+
+  Args:
+    observations: Observations for which we compute the representation.
+    representation_function: Function that takes observation as input and
+      outputs a representation.
+    batch_size: Batch size to compute the representation.
+  Returns:
+    representations: Codes (num_codes, num_points)-Numpy array.
+  """
+  representations = None
+  num_points = observations.shape[0]
+  i = 0
+  while i < num_points:
+    num_points_iter = min(num_points - i, batch_size)
+    current_observations = observations[i:i + num_points_iter]
+    if i == 0:
+      representations = representation_function(current_observations)
+    else:
+      representations = np.vstack((representations,
+                                   representation_function(
+                                       current_observations)))
+    i += num_points_iter
+  return np.transpose(representations)
+
+
 def discrete_mutual_info(mus, ys):
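The new helper simply chunks the forward pass so a large observation set never hits representation_function in one go, then returns the codes transposed to (num_codes, num_points). A usage sketch with a dummy encoder (the encoder and shapes are illustrative):

import numpy as np
from disentanglement_lib.evaluation.metrics import utils

# Dummy encoder: flattens each 4x4 observation and keeps a 10-d code.
def representation_function(batch):
  return batch.reshape(batch.shape[0], -1)[:, :10]

observations = np.random.rand(130, 4, 4)  # 130 points, encoded 64 at a time
codes = utils.obtain_representation(observations, representation_function, 64)
print(codes.shape)  # (10, 130): (num_codes, num_points), per the docstring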