Skip to content

Commit

Permalink
add a new direction loss for the velocities, proposed by a group of r…
Browse files Browse the repository at this point in the history
…esearchers out of Wuhan China
  • Loading branch information
lucidrains committed Nov 5, 2024
1 parent 598d08e commit dc6ba3d
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 1 deletion.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,12 @@ trainer()
url = {https://api.semanticscholar.org/CorpusID:270878436}
}
```

```bibtex
@inproceedings{Yao2024FasterDiTTF,
title = {FasterDiT: Towards Faster Diffusion Transformers Training without Architecture Modification},
author = {Jingfeng Yao and Wang Cheng and Wenyu Liu and Xinggang Wang},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:273346237}
}
```
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "rectified-flow-pytorch"
version = "0.1.10"
version = "0.1.11"
description = "Rectified Flow in Pytorch"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
Expand Down
21 changes: 21 additions & 0 deletions rectified_flow_pytorch/rectified_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,23 @@ class MSELoss(Module):
def forward(self, pred, target, **kwargs):
return F.mse_loss(pred, target)

class MSEAndDirectionLoss(Module):
"""
Figure 7 - https://arxiv.org/abs/2410.10356
"""

def __init__(self, cosine_sim_dim: int = 1):
super().__init__()
assert cosine_sim_dim > 0, 'cannot be batch dimension'
self.cosine_sim_dim = cosine_sim_dim

def forward(self, pred, target, **kwargs):
mse_loss = F.mse_loss(pred, target)

direction_loss = (1. - F.cosine_similarity(pred, target, dim = self.cosine_sim_dim)).mean()

return mse_loss + direction_loss

# loss breakdown

LossBreakdown = namedtuple('LossBreakdown', ['total', 'main', 'data_match', 'velocity_match'])
Expand All @@ -135,6 +152,7 @@ def __init__(
predict: Literal['flow', 'noise'] = 'flow',
loss_fn: Literal[
'mse',
'mse_and_direction',
'pseudo_huber',
'pseudo_huber_with_lpips'
] | Module = 'mse',
Expand Down Expand Up @@ -179,6 +197,9 @@ def __init__(
if loss_fn == 'mse':
loss_fn = MSELoss()

elif loss_fn == 'mse_and_direction':
loss_fn = MSEAndDirectionLoss(**loss_fn_kwargs)

elif loss_fn == 'pseudo_huber':
assert predict == 'flow'

Expand Down

0 comments on commit dc6ba3d

Please sign in to comment.