diff --git a/README.md b/README.md
index 944094c..bbf0a45 100644
--- a/README.md
+++ b/README.md
@@ -85,4 +85,13 @@ sampled_actions = model(vision, commands, joint_state, trajectory_length = 32)
 # }
 ```
 
+```bibtex
+@inproceedings{Sadat2024EliminatingOA,
+    title = {Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion Models},
+    author = {Seyedmorteza Sadat and Otmar Hilliges and Romann M. Weber},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:273098845}
+}
+```
+
 [*dear alice*](https://www.youtube.com/watch?v=z-Ng5ZvrDm4)
diff --git a/pi_zero_pytorch/pi_zero.py b/pi_zero_pytorch/pi_zero.py
index 47277d4..5b33e7d 100644
--- a/pi_zero_pytorch/pi_zero.py
+++ b/pi_zero_pytorch/pi_zero.py
@@ -11,6 +11,7 @@ import torch.nn.functional as F
 
 from torch import pi, nn, tensor, is_tensor
 from torch.nn import Module, ModuleList
+from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 
 from torchdiffeq import odeint
 
@@ -91,6 +92,9 @@ def default(v, d):
 
 # tensor helpers
 
+def l2norm(t, dim = -1):
+    return F.normalize(t, dim = dim)
+
 def softclamp(t, value):
     if value <= 0.:
         return t
@@ -110,6 +114,36 @@ def inverse(out, inv_pattern = None):
 
     return packed, inverse
 
+def pack_one_with_inverse(t, pattern):
+    packed, inverse = pack_with_inverse([t], pattern)
+
+    def inverse_one(out, inv_pattern = None):
+        out, = inverse(out, inv_pattern)
+        return out
+
+    return packed, inverse_one
+
+def tree_flatten_with_inverse(input):
+    out, tree_spec = tree_flatten(input)
+
+    def inverse(output):
+        return tree_unflatten(output, tree_spec)
+
+    return out, inverse
+
+def project(x, y):
+    x, inverse = pack_one_with_inverse(x, 'b *')
+    y, _ = pack_one_with_inverse(y, 'b *')
+
+    dtype = x.dtype
+    x, y = x.double(), y.double()
+    unit = l2norm(y, dim = -1)
+
+    parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
+    orthogonal = x - parallel
+
+    return inverse(parallel).to(dtype), inverse(orthogonal).to(dtype)
+
 def pad_at_dim(
     t,
     pad: tuple[int, int],
@@ -556,7 +590,11 @@ def sample(
         trajectory_length: int,
         reward_tokens = None,
         steps = 18,
-        show_pbar = True
+        show_pbar = True,
+        cond_scale = 0.,
+        remove_parallel_component = True,
+        keep_parallel_frac = 0.,
+        cache_kv = True
     ):
         batch_size = token_ids.shape[0]
@@ -568,22 +606,31 @@ def sample(
         # ode step function
 
         cached_state_kv = None
+        null_cached_state_kv = None
 
         def ode_fn(timestep, denoised_actions):
             nonlocal cached_state_kv
+            nonlocal null_cached_state_kv
 
-            flow, cached_state_kv = self.forward(
+            flow, (new_cached_state_kv, new_null_cached_state_kv) = self.forward_with_reward_cfg(
                 images,
                 token_ids,
                 joint_states,
                 denoised_actions,
                 times = timestep,
                 reward_tokens = reward_tokens,
-                cached_state_keys_values = cached_state_kv,
+                cached_state_keys_values = (cached_state_kv, null_cached_state_kv),
+                cond_scale = cond_scale,
+                remove_parallel_component = remove_parallel_component,
+                keep_parallel_frac = keep_parallel_frac,
                 return_actions_flow = True,
                 return_state_keys_values = True
             )
 
+            if cache_kv:
+                cached_state_kv = new_cached_state_kv
+                null_cached_state_kv = new_null_cached_state_kv
+
             pbar.update(1)
 
             return flow
@@ -608,24 +655,56 @@ def ode_fn(timestep, denoised_actions):
 
         return sampled_actions
 
+    @torch.no_grad()
     def forward_with_reward_cfg(
         self,
         *args,
         reward_tokens: Float['b d'] | None = None,
-        cond_scale = 1.,
+        cached_state_keys_values = (None, None),
+        cond_scale = 0.,
+        remove_parallel_component = False,
+        keep_parallel_frac = 0.,
+        return_state_keys_values = True,
         **kwargs
     ):
+        assert return_state_keys_values, 'cached key values must be turned on'
+
+        with_reward_cache, without_reward_cache = cached_state_keys_values
 
-        out = self.forward(
+        maybe_reward_out = self.forward(
             *args,
             reward_tokens = reward_tokens,
+            cached_state_keys_values = with_reward_cache,
+            return_state_keys_values = return_state_keys_values,
             **kwargs
         )
 
-        if not exists(reward_tokens) or cond_scale == 1.:
-            return out
+        action_flow_with_reward, with_reward_cache_kv = maybe_reward_out
 
-        raise NotImplementedError
+        if not exists(reward_tokens) or cond_scale == 0.:
+            return action_flow_with_reward, (with_reward_cache_kv, None)
+
+        no_reward_out = self.forward(
+            *args,
+            cached_state_keys_values = without_reward_cache,
+            return_state_keys_values = return_state_keys_values,
+            **kwargs
+        )
+
+        action_flow_without_reward, without_reward_cache_kv = no_reward_out
+
+        update = action_flow_with_reward - action_flow_without_reward
+
+        if remove_parallel_component:
+            # from https://arxiv.org/abs/2410.02416
+
+            update_parallel, update_orthog = project(update, action_flow_with_reward)
+            update = update_orthog + update_parallel * keep_parallel_frac
+
+        flow_with_reward_cfg = action_flow_with_reward + cond_scale * update
+
+        return flow_with_reward_cfg, (with_reward_cache_kv, without_reward_cache_kv)
 
     def forward(
         self,
diff --git a/pyproject.toml b/pyproject.toml
index e58be61..0ac4772 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "pi-zero-pytorch"
-version = "0.0.17"
+version = "0.0.18"
 description = "π0 in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
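The guidance logic added in `forward_with_reward_cfg` above is classifier-free guidance over the reward tokens, plus the projection trick from the cited Sadat et al. paper (arXiv 2410.02416): the update `cond - uncond` has its component parallel to the conditional flow dropped (or attenuated by `keep_parallel_frac`) before being scaled in. Below is a minimal, self-contained sketch of that rule; `guide` and the plain-reshape `project` are illustrative stand-ins for the repo's versions, not the package API.

```python
import torch
import torch.nn.functional as F

def project(x, y):
    # decompose x into components parallel / orthogonal to y, per batch element
    # (stand-in for the repo's `project`, using reshape instead of einops packing)
    (batch, *rest), dtype = x.shape, x.dtype
    x, y = x.double().reshape(batch, -1), y.double().reshape(batch, -1)

    unit = F.normalize(y, dim = -1)
    parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
    orthogonal = x - parallel

    return tuple(t.reshape(batch, *rest).to(dtype) for t in (parallel, orthogonal))

def guide(flow_cond, flow_uncond, cond_scale = 3., remove_parallel_component = True, keep_parallel_frac = 0.):
    # guidance update, parametrized so that cond_scale = 0. returns flow_cond unchanged
    update = flow_cond - flow_uncond

    if remove_parallel_component:
        # adaptive projected guidance (https://arxiv.org/abs/2410.02416): drop (or attenuate)
        # the part of the update parallel to the conditional flow to curb oversaturation
        parallel, orthogonal = project(update, flow_cond)
        update = orthogonal + parallel * keep_parallel_frac

    return flow_cond + cond_scale * update

flow_cond, flow_uncond = torch.randn(2, 32, 6), torch.randn(2, 32, 6)

guided = guide(flow_cond, flow_uncond)
assert guided.shape == flow_cond.shape
```

Note the parametrization: the update is added on top of the conditional flow, so `cond_scale = 0.` (the new default) reduces to plain reward-conditioned sampling, which is why `forward_with_reward_cfg` skips the second, reward-free forward pass entirely in that case.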
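A hedged usage sketch of sampling with the new knobs. The positional arguments mirror the README example above, `reward_tokens`'s shape follows the `Float['b d']` annotation in the diff, and the embedding dimension here is a made-up placeholder.

```python
import torch

# hypothetical call - `model`, `vision`, `commands`, `joint_state` as in the README example

reward_tokens = torch.randn(1, 512)      # assumed (batch, dim); dim must match the model's

sampled_actions = model.sample(
    vision,
    commands,
    joint_state,
    trajectory_length = 32,
    reward_tokens = reward_tokens,
    cond_scale = 3.,                     # any nonzero scale triggers the second, reward-free pass
    remove_parallel_component = True,    # APG projection, on by default in `sample`
    keep_parallel_frac = 0.
)
```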
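The new `tree_flatten_with_inverse` helper is not exercised in the hunks shown, but its contract follows directly from `torch.utils._pytree`: flatten any nested container to a flat list of tensor leaves, transform them, then rebuild the original structure. A small sketch:

```python
import torch
from pi_zero_pytorch.pi_zero import tree_flatten_with_inverse

# a nested, cache-like structure of tensors
nested = (dict(keys = torch.randn(2, 4), values = torch.randn(2, 4)), [torch.randn(3)])

leaves, inverse = tree_flatten_with_inverse(nested)

# map over the flat leaves, then restore the original nesting
rebuilt = inverse([leaf.half() for leaf in leaves])

assert rebuilt[0]['keys'].dtype == torch.float16
```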