complete reward token classifier free guidance combined with a recent disney research improvement
lucidrains committed Nov 18, 2024
1 parent ecdfb38 commit d70530c
Showing 3 changed files with 97 additions and 9 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -85,4 +85,13 @@ sampled_actions = model(vision, commands, joint_state, trajectory_length = 32) #
}
```

```bibtex
@inproceedings{Sadat2024EliminatingOA,
title = {Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion Models},
author = {Seyedmorteza Sadat and Otmar Hilliges and Romann M. Weber},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:273098845}
}
```

[*dear alice*](https://www.youtube.com/watch?v=z-Ng5ZvrDm4)
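
To make the `pi_zero.py` changes below easier to follow, here is a minimal, self-contained sketch of the guidance update this commit wires in: classifier free guidance over the reward conditioning, with the component of the guidance update parallel to the conditioned flow removed (or down-weighted), following the paper cited above. The shapes, scale values, and the standalone `project` helper are illustrative assumptions; the actual logic lives in `forward_with_reward_cfg`.

```python
import torch
import torch.nn.functional as F

def project(x, y):
    # split x into components parallel and orthogonal to y, per batch element
    x_flat, y_flat = x.flatten(1), y.flatten(1)
    unit = F.normalize(y_flat, dim = -1)
    parallel = (x_flat * unit).sum(dim = -1, keepdim = True) * unit
    orthogonal = x_flat - parallel
    return parallel.view_as(x), orthogonal.view_as(x)

# stand-ins for the flows from the reward conditioned and unconditioned passes
flow_with_reward = torch.randn(1, 32, 6)
flow_without_reward = torch.randn(1, 32, 6)

cond_scale = 3.          # guidance strength; 0. falls back to the conditioned flow
keep_parallel_frac = 0.  # 0. removes the parallel component entirely

update = flow_with_reward - flow_without_reward

# adaptive projected guidance (https://arxiv.org/abs/2410.02416):
# keep only the part of the update orthogonal to the conditioned flow
parallel, orthogonal = project(update, flow_with_reward)
update = orthogonal + parallel * keep_parallel_frac

guided_flow = flow_with_reward + cond_scale * update
```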
95 changes: 87 additions & 8 deletions pi_zero_pytorch/pi_zero.py
@@ -11,6 +11,7 @@
import torch.nn.functional as F
from torch import pi, nn, tensor, is_tensor
from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten

from torchdiffeq import odeint

@@ -91,6 +92,9 @@ def default(v, d):

# tensor helpers

def l2norm(t, dim = -1):
return F.normalize(t, dim = dim)

def softclamp(t, value):
if value <= 0.:
return t
@@ -110,6 +114,36 @@ def inverse(out, inv_pattern = None):

return packed, inverse

def pack_one_with_inverse(t, pattern):
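    # pack a single tensor according to the given pattern, returning an inverse that unpacks just that one output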
packed, inverse = pack_with_inverse([t], pattern)

def inverse_one(out, inv_pattern = None):
out, = inverse(out, inv_pattern)
return out

return packed, inverse_one

def tree_flatten_with_inverse(input):
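    # flatten a pytree into its leaves, returning an inverse that restores the original structure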
out, tree_spec = tree_flatten(input)

def inverse(output):
return tree_unflatten(output, tree_spec)

return out, inverse

def project(x, y):
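    # decompose x into components parallel and orthogonal to y (used below for projected guidance)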
x, inverse = pack_one_with_inverse(x, 'b *')
y, _ = pack_one_with_inverse(y, 'b *')

dtype = x.dtype
x, y = x.double(), y.double()
unit = l2norm(y, dim = -1)

parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
orthogonal = x - parallel

return inverse(parallel).to(dtype), inverse(orthogonal).to(dtype)

def pad_at_dim(
t,
pad: tuple[int, int],
@@ -556,7 +590,11 @@ def sample(
trajectory_length: int,
reward_tokens = None,
steps = 18,
show_pbar = True
show_pbar = True,
cond_scale = 0.,
remove_parallel_component = True,
keep_parallel_frac = 0.,
cache_kv = True
):
batch_size = token_ids.shape[0]

@@ -568,22 +606,31 @@
# ode step function

cached_state_kv = None
null_cached_state_kv = None

def ode_fn(timestep, denoised_actions):
nonlocal cached_state_kv
nonlocal null_cached_state_kv

flow, cached_state_kv = self.forward(
flow, (new_cached_state_kv, new_null_cached_state_kv) = self.forward_with_reward_cfg(
images,
token_ids,
joint_states,
denoised_actions,
times = timestep,
reward_tokens = reward_tokens,
cached_state_keys_values = cached_state_kv,
cached_state_keys_values = (cached_state_kv, null_cached_state_kv),
cond_scale = cond_scale,
remove_parallel_component = remove_parallel_component,
keep_parallel_frac = keep_parallel_frac,
return_actions_flow = True,
return_state_keys_values = True
)

if cache_kv:
cached_state_kv = new_cached_state_kv
null_cached_state_kv = new_null_cached_state_kv

pbar.update(1)

return flow
@@ -608,24 +655,56 @@ def ode_fn(timestep, denoised_actions):

return sampled_actions

@torch.no_grad()
def forward_with_reward_cfg(
self,
*args,
reward_tokens: Float['b d'] | None = None,
cond_scale = 1.,
cached_state_keys_values = (None, None),
cond_scale = 0.,
remove_parallel_component = False,
keep_parallel_frac = 0.,
return_state_keys_values = True,

**kwargs
):
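        # classifier free guidance over the reward tokens: run the model with and without reward
        # conditioning, then combine the two flows, optionally removing the component of the
        # guidance update that is parallel to the conditioned flow (https://arxiv.org/abs/2410.02416)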
assert return_state_keys_values, 'cached key values must be turned on'

with_reward_cache, without_reward_cache = cached_state_keys_values

out = self.forward(
maybe_reward_out = self.forward(
*args,
reward_tokens = reward_tokens,
cached_state_keys_values = with_reward_cache,
return_state_keys_values = return_state_keys_values,
**kwargs
)

if not exists(reward_tokens) or cond_scale == 1.:
return out
action_flow_with_reward, with_reward_cache_kv = maybe_reward_out

raise NotImplementedError
if not exists(reward_tokens) or cond_scale == 0.:
return action_flow_with_reward, (with_reward_cache_kv, None)

no_reward_out = self.forward(
*args,
cached_state_keys_values = without_reward_cache,
return_state_keys_values = return_state_keys_values,
**kwargs
)

action_flow_without_reward, without_reward_cache_kv = no_reward_out

update = action_flow_with_reward - action_flow_without_reward

if remove_parallel_component:
# from https://arxiv.org/abs/2410.02416

update_parallel, update_orthog = project(update, action_flow_with_reward)
update = update_orthog + update_parallel * keep_parallel_frac

flow_with_reward_cfg = action_flow_with_reward + cond_scale * update

return flow_with_reward_cfg, (with_reward_cache_kv, without_reward_cache_kv)

def forward(
self,
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "pi-zero-pytorch"
version = "0.0.17"
version = "0.0.18"
description = "π0 in Pytorch"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
