Skip to content

Commit

Permalink
take first measure for straightening flow issue
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 17, 2024
1 parent 84735d9 commit a52717b
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 0 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,15 @@ sampled_actions = model(vision, commands, joint_state, trajectory_length = 32) #
}
```

```bibtex
@article{Li2024ImmiscibleDA,
title = {Immiscible Diffusion: Accelerating Diffusion Training with Noise Assignment},
author = {Yiheng Li and Heyang Jiang and Akio Kodaira and Masayoshi Tomizuka and Kurt Keutzer and Chenfeng Xu},
journal = {ArXiv},
year = {2024},
volume = {abs/2406.12303},
url = {https://api.semanticscholar.org/CorpusID:270562607}
}
```

[*dear alice*](https://www.youtube.com/watch?v=z-Ng5ZvrDm4)
20 changes: 20 additions & 0 deletions pi_zero_pytorch/pi_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from torchdiffeq import odeint

from scipy.optimize import linear_sum_assignment

from rotary_embedding_torch import (
RotaryEmbedding,
apply_rotary_emb
Expand Down Expand Up @@ -117,6 +119,15 @@ def pad_at_dim(
zeros = ((0, 0) * dims_from_right)
return F.pad(t, (*zeros, *pad), value = value)

# flow related

def noise_assignment(data, noise):
device = data.device
data, noise = tuple(rearrange(t, 'b ... -> b (...)') for t in (data, noise))
dist = torch.cdist(data, noise)
_, assign = linear_sum_assignment(dist.cpu())
return torch.from_numpy(assign).to(device)

# losses

def direction_loss(pred, target, dim = -1):
Expand Down Expand Up @@ -413,6 +424,7 @@ def __init__(
lm_loss_weight = 1.,
flow_loss_weight = 1.,
direction_loss_weight = 0.,
immiscible_flow = False, # https://arxiv.org/abs/2406.12303
odeint_kwargs: dict = dict(
atol = 1e-5,
rtol = 1e-5,
Expand Down Expand Up @@ -495,6 +507,10 @@ def __init__(

self.lm_pad_id = lm_pad_id

# flow related

self.immiscible_flow = immiscible_flow

# loss related

self.lm_loss_weight = lm_loss_weight
Expand Down Expand Up @@ -615,6 +631,10 @@ def forward(
if not return_actions_flow:
noise = torch.randn_like(actions)

if self.immiscible_flow:
assignment = noise_assignment(actions, noise)
noise = noise[assignment]

flow = actions - noise
padded_times = rearrange(times, 'b -> b 1 1')

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"einops>=0.8.0",
"jaxtyping",
"rotary-embedding-torch>=0.8.5",
'scipy',
"torch>=2.5",
'torchdiffeq',
"tqdm"
Expand Down

0 comments on commit a52717b

Please sign in to comment.