Restructure of VAE models
crstngc committed Jan 26, 2025
1 parent 3105bdf commit e2973d1
Showing 4 changed files with 205 additions and 169 deletions.
36 changes: 34 additions & 2 deletions scico/flax/autoencoders/state.py
@@ -35,7 +35,7 @@ def initialize(key: KeyArray, model: ModuleDef, ishape: Shape) -> Tuple[PyTree,
that no batch dimension is included.
Returns:
-        Initial model parameters (including `batch_stats`).
+        Initial model parameters (including `batch_stats` if applicable).
"""
input_shape = (1, ishape[0], ishape[1], model.channels)
key, model_key = jax.random.split(key)
@@ -49,6 +49,35 @@ def init(*args):
return variables["params"]


def initialize_class_conditional(
key: KeyArray, model: ModuleDef, ishape: Shape, num_classes: int
) -> Tuple[PyTree, ...]:
"""Initialize Flax model conditioned on class labels.
Args:
key: A PRNGKey used as the random key.
model: Flax model to train.
ishape: Shape of signal (image) to process by `model`. Make sure
that no batch dimension is included.
num_classes: Number of classes in the dataset.
Returns:
Initial model parameters (including `batch_stats` if applicable).
"""
input_shape = (1, ishape[0], ishape[1], model.channels)
key, model_key = jax.random.split(key)
key, call_key = jax.random.split(key)

@jax.jit
def init(*args):
return model.init(*args)

fakex = jnp.ones(input_shape, model.dtype) # Expected input shape
fakec = jnp.zeros((1, num_classes)) # Expected class specification
variables = init({"params": model_key}, fakex, call_key, fakec)
return variables["params"]
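
Rough usage sketch for the new initializer (not part of the commit). `ToyCondVAE` is a hypothetical stand-in with the interface the function expects: `channels` and `dtype` attributes plus a `(x, key, c)` call returning `(output, mean, logvar)`; the import path follows the file changed above.

```python
from typing import Any

import jax
import jax.numpy as jnp
from flax import linen as nn

from scico.flax.autoencoders.state import initialize_class_conditional


class ToyCondVAE(nn.Module):
    channels: int = 1
    latent_dim: int = 8
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x, key, c):
        # Flatten image and concatenate the one-hot class code.
        h = nn.Dense(32)(jnp.concatenate([x.reshape((x.shape[0], -1)), c], axis=-1))
        mean = nn.Dense(self.latent_dim)(h)
        logvar = nn.Dense(self.latent_dim)(h)
        # Reparameterization trick: sample z ~ N(mean, exp(logvar)).
        z = mean + jnp.exp(0.5 * logvar) * jax.random.normal(key, mean.shape)
        out = nn.Dense(x.shape[1] * x.shape[2] * x.shape[3])(
            jnp.concatenate([z, c], axis=-1)
        )
        return out.reshape(x.shape), mean, logvar


key = jax.random.PRNGKey(0)
params = initialize_class_conditional(key, ToyCondVAE(), ishape=(28, 28), num_classes=10)
```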


def create_vae_train_state(
key: KeyArray,
config: ConfigDict,
@@ -77,7 +106,10 @@ def create_vae_train_state(
"""
batch_stats = None
if variables0 is None:
-        aux = initialize(key, model, ishape)
+        if model.cond_width > 0:  # Class conditional model constructed
+            aux = initialize_class_conditional(key, model, ishape, config["num_classes"])
+        else:
+            aux = initialize(key, model, ishape)
if isinstance(aux, tuple):
params, batch_stats = aux
else:
59 changes: 51 additions & 8 deletions scico/flax/autoencoders/steps.py
@@ -126,7 +126,6 @@ def loss_fn(params: PyTree, x: ArrayLike, key: ArrayLike):
def train_step_vae_class_conditional(
state: TrainState,
batch: ArrayLike,
-    batch_c: ArrayLike,
num_classes: int,
key: ArrayLike,
kl_weight: float,
@@ -143,10 +142,8 @@ def train_step_vae_class_conditional(
Args:
state: Flax train state which includes the model apply function,
the model parameters and an Optax optimizer.
-        batch: Sharded and batched training data. Only output data is
-            passed.
-        batch_c: Sharded and batched training conditional data associated
-            to class of samples.
+        batch: Sharded and batched training data. Data as well as class
+            labels are passed.
num_classes: Number of classes in dataset.
key: Key for random generation.
kl_weight: Weight of the KL divergence term in the total training loss.
@@ -168,7 +165,6 @@ def train_step_vae_class_conditional(

def loss_fn(params: PyTree, x: ArrayLike, c: ArrayLike, key: ArrayLike):
"""Loss function used for training."""
-        reduce_dims = list(range(1, len(x.shape)))
c = jax.nn.one_hot(c, num_classes).squeeze() # one hot encode the class index
output, mean, logvar = state.apply_fn(
{
@@ -178,7 +174,7 @@ def loss_fn(params: PyTree, x: ArrayLike, c: ArrayLike, key: ArrayLike):
key,
c,
)

+        reduce_dims = list(range(1, len(x.shape)))
mse_loss = criterion(output, x).sum(axis=reduce_dims).mean()
# KL loss term to keep encoder output close to standard
# normal distribution.
@@ -195,7 +191,7 @@ def loss_fn(params: PyTree, x: ArrayLike, c: ArrayLike, key: ArrayLike):
lr = learning_rate_fn(step)

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
-    aux, grads = grad_fn(state.params, batch_["image"], batch_c, step_key)
+    aux, grads = grad_fn(state.params, batch["image"], batch["label"], step_key)
losses = aux[1]
# Re-use same axis_name as in call to pmap
grads = lax.pmean(grads, axis_name="batch")
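
For reference, a standalone numeric sketch of the loss assembled in this step (toy shapes and values are assumptions, and squared error stands in for the generic `criterion`): per-sample reconstruction error summed over signal dimensions, plus the KL divergence of N(mean, exp(logvar)) from N(0, I).

```python
import jax
import jax.numpy as jnp

num_classes, kl_weight = 10, 0.1
x = jnp.ones((4, 28, 28, 1))                    # batch["image"]
c = jax.nn.one_hot(jnp.arange(4), num_classes)  # one-hot batch["label"] the model would receive

# Stand-ins for the model outputs (output, mean, logvar).
output, mean, logvar = jnp.zeros_like(x), jnp.zeros((4, 8)), jnp.zeros((4, 8))

# MSE summed over signal dimensions, averaged over the batch.
reduce_dims = list(range(1, len(x.shape)))
mse_loss = ((output - x) ** 2).sum(axis=reduce_dims).mean()
# KL term keeping the encoder output close to a standard normal.
kl_loss = jnp.mean(
    -0.5 * jnp.sum(1 + logvar - mean**2 - jnp.exp(logvar), axis=list(range(1, len(mean.shape))))
)
loss = mse_loss + kl_weight * kl_loss  # here: 784.0 + 0.1 * 0.0
```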
@@ -230,6 +226,8 @@ def eval_step_vae(
and the model parameters.
batch: Sharded and batched training data.
criterion: Loss function.
+        key: Key for random generation.
+        kl_weight: Weight of the KL divergence term in the total training loss.
Returns:
Current diagnostic statistics.
@@ -250,6 +248,51 @@ def eval_step_vae(
return metrics


def eval_step_vae_class_conditional(
state: TrainState,
batch: ArrayLike,
num_classes: int,
criterion: Callable,
key: ArrayLike,
kl_weight: float,
**kwargs,
) -> VAEMetricsDict:
"""Evaluate current model state using class
conditional information.
Assumes sharded batched data. This function is intended to be used
via :class:`~scico.flax.BasicFlaxTrainer` or
:meth:`~scico.flax.only_evaluate`, not directly.
Args:
state: Flax train state which includes the model apply function
and the model parameters.
batch: Sharded and batched training data.
num_classes: Number of classes in dataset.
criterion: Loss function.
key: Key for random generation.
kl_weight: Weight of the KL divergence term in the total training loss.
Returns:
Current diagnostic statistics.
"""
variables = {
"params": state.params,
}
key, step_key = jax.random.split(key)
c = jax.nn.one_hot(batch["label"], num_classes).squeeze() # one hot encode the class index
output, mean, logvar = state.apply_fn(variables, batch["image"], step_key, c)
reduce_dims = list(range(1, len(batch["image"].shape)))
mse_loss = criterion(output, batch["image"]).sum(axis=reduce_dims).mean()
# KL loss term to keep encoder output close to standard
# normal distribution.
reduce_dims = list(range(1, len(mean.shape)))
kl_loss = jnp.mean(-0.5 * jnp.sum(1 + logvar - mean**2 - jnp.exp(logvar), axis=reduce_dims))
loss = mse_loss + kl_weight * kl_loss
metrics: VAEMetricsDict = {"loss": loss, "mse": mse_loss, "kl": kl_loss}
return metrics
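
A toy illustration (values are assumptions) of the label encoding used in this step: `jax.nn.one_hot` followed by `squeeze` maps integer labels of shape (N, 1) to one-hot rows of shape (N, num_classes).

```python
import jax
import jax.numpy as jnp

labels = jnp.array([[2], [0], [1]])      # shape (3, 1), as in batch["label"]
c = jax.nn.one_hot(labels, 3).squeeze()  # (3, 1, 3) -> (3, 3)
# c == [[0., 0., 1.],
#       [1., 0., 0.],
#       [0., 1., 0.]]
```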


def build_sample_fn(model: Callable, params: PyTree):
"""Function to generate samples from model.
(diffs for the remaining two changed files not shown)
