Skip to content

Commit

Permalink
✅🐛Fixed test_ema
Browse files Browse the repository at this point in the history
  • Loading branch information
carefree0910 committed Jan 5, 2025
1 parent 8cfe21c commit 3c50b82
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions tests/test_learn/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,8 @@ def test_ema(self) -> None:
gt = p1.data
gt = decay * gt + (1.0 - decay) * p2.data
gt = decay * gt + (1.0 - decay) * p3.data
for p in [p1, p1.data]:
if not isinstance(p, nn.Parameter):
p = p.clone()
else:
p = nn.Parameter(p.data.clone())
for i, p in enumerate([p1, p1.data]):
p = p.clone()
ema = cflearn.EMA(decay, [("test", p)])
p.data = p2.data
ema()
Expand All @@ -156,7 +153,7 @@ def test_ema(self) -> None:
ema.eval()
ema.train()
ema.eval()
torch.testing.assert_close(p.data, gt.data)
torch.testing.assert_close(p.data, gt.data if i == 0 else p3.data)
ema.train()
ema.eval()
ema.train()
Expand All @@ -169,7 +166,7 @@ def test_ema(self) -> None:
ema.train()
ema.eval()
ema.eval()
torch.testing.assert_close(p.data, gt.data)
torch.testing.assert_close(p.data, gt.data if i == 0 else p3.data)
ema.train()
ema.train()
ema.eval()
Expand All @@ -178,7 +175,7 @@ def test_ema(self) -> None:
ema.train()
torch.testing.assert_close(p.data, p3.data)
with cflearn.eval_context(ema):
torch.testing.assert_close(p.data, gt.data)
torch.testing.assert_close(p.data, gt.data if i == 0 else p3.data)
torch.testing.assert_close(p.data, p3.data)
str(ema)
ema = cflearn.EMA(decay, [("test", p)], use_num_updates=True)
Expand Down

0 comments on commit 3c50b82

Please sign in to comment.