Skip to content

Commit

Permalink
Check shape and remove deprecated APIs in scheduling_ddpm_flax.py (#7703
Browse files Browse the repository at this point in the history
)

`model_output.shape` may only have rank 1.

There are warnings related to use of random keys.

```
tests/schedulers/test_scheduler_flax.py: 13 warnings
  /Users/phillypham/diffusers/src/diffusers/schedulers/scheduling_ddpm_flax.py:268: FutureWarning: normal accepts a single key, but was given a key array of shape (1, 2) != (). Use jax.vmap for batching. In a future JAX version, this will be an error.
    noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype)

tests/schedulers/test_scheduler_flax.py::FlaxDDPMSchedulerTest::test_betas
  /Users/phillypham/virtualenv/diffusers/lib/python3.9/site-packages/jax/_src/random.py:731: FutureWarning: uniform accepts a single key, but was given a key array of shape (1,) != (). Use jax.vmap for batching. In a future JAX version, this will be an error.
    u = uniform(key, shape, dtype, lo, hi)  # type: ignore[arg-type]
```
  • Loading branch information
ppham27 authored May 8, 2024
1 parent d50baf0 commit f29b934
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/diffusers/schedulers/scheduling_ddpm_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,13 @@ def step(
t = timestep

if key is None:
key = jax.random.PRNGKey(0)
key = jax.random.key(0)

if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
if (
len(model_output.shape) > 1
and model_output.shape[1] == sample.shape[1] * 2
and self.config.variance_type in ["learned", "learned_range"]
):
model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1)
else:
predicted_variance = None
Expand Down Expand Up @@ -264,7 +268,7 @@ def step(

# 6. Add noise
def random_variance():
split_key = jax.random.split(key, num=1)
split_key = jax.random.split(key, num=1)[0]
noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype)
return (self._get_variance(state, t, predicted_variance=predicted_variance) ** 0.5) * noise

Expand Down

0 comments on commit f29b934

Please sign in to comment.