diff --git a/README.md b/README.md index f0293b1..e53dddc 100644 --- a/README.md +++ b/README.md @@ -55,3 +55,12 @@ sampled_actions = model(vision, commands, joint_state, trajectory_length = 32) # url = {https://api.semanticscholar.org/CorpusID:273532030} } ``` + +```bibtex +@inproceedings{Yao2024FasterDiTTF, + title = {FasterDiT: Towards Faster Diffusion Transformers Training without Architecture Modification}, + author = {Jingfeng Yao and Wang Cheng and Wenyu Liu and Xinggang Wang}, + year = {2024}, + url = {https://api.semanticscholar.org/CorpusID:273346237} +} +``` diff --git a/pi_zero_pytorch/pi_zero.py b/pi_zero_pytorch/pi_zero.py index 612dcbd..ef07b8f 100644 --- a/pi_zero_pytorch/pi_zero.py +++ b/pi_zero_pytorch/pi_zero.py @@ -64,6 +64,11 @@ def softclamp(t, value): return (t / value).tanh() * value +# losses + +def direction_loss(pred, target, dim = -1): + return 0.5 * (1. - F.cosine_similarity(pred, target, dim = dim)) + # attention class Attention(Module): @@ -265,6 +270,7 @@ def __init__( ff_kwargs: dict = dict(), lm_loss_weight = 1., flow_loss_weight = 1., + direction_loss_weight = 0., odeint_kwargs: dict = dict( atol = 1e-5, rtol = 1e-5, @@ -340,10 +346,15 @@ def __init__( self.lm_loss_weight = lm_loss_weight self.flow_loss_weight = flow_loss_weight + self.has_direction_loss = direction_loss_weight > 0. 
+ self.direction_loss_weight = direction_loss_weight + # sampling related self.odeint_fn = partial(odeint, **odeint_kwargs) + self.register_buffer('zero', torch.tensor(0.), persistent = False) + @property def device(self): return next(self.parameters()).device @@ -541,6 +552,13 @@ def forward( flow_loss = F.mse_loss(flow, pred_actions_flow) + # maybe direction loss + + dir_loss = self.zero + + if self.has_direction_loss: + dir_loss = direction_loss(flow, pred_actions_flow) + # language cross entropy loss language_logits = self.state_to_logits(tokens) @@ -550,14 +568,19 @@ def forward( labels ) + # loss breakdown + + loss_breakdown = (language_loss, flow_loss, dir_loss) + # total loss and return breakdown total_loss = ( language_loss * self.lm_loss_weight + - flow_loss * self.flow_loss_weight + flow_loss * self.flow_loss_weight + + dir_loss * self.direction_loss_weight ) - return total_loss, (language_loss, flow_loss) + return total_loss, loss_breakdown # fun diff --git a/pyproject.toml b/pyproject.toml index 4261e47..4cb54c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "pi-zero-pytorch" -version = "0.0.1" +version = "0.0.2" description = "π0 in Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }