Skip to content

Commit

Permalink
add support for mamba2
Browse files Browse the repository at this point in the history
  • Loading branch information
MzeroMiko committed Jun 13, 2024
1 parent f7331c3 commit 9dfcfc9
Showing 1 changed file with 30 additions and 2 deletions.
32 changes: 30 additions & 2 deletions classification/models/vmamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -1739,6 +1739,8 @@ def vmamba_base_s1l20(channel_first=True):


if __name__ == "__main__":
model_ref = vmamba_tiny_s1l8()

model = VSSM(
depths=[2, 2, 5, 2], dims=96, drop_path_rate=0.2,
patch_size=4, in_chans=3, num_classes=1000,
Expand All @@ -1753,8 +1755,34 @@ def vmamba_base_s1l20(channel_first=True):
print(parameter_count(model)[""])
print(model.flops()) # wrong
model.cuda().train()
inp = torch.randn((128,3, 224, 224))
model(inp).sum().backward()
model_ref.cuda().train()

def bench(model, batch_size=128, warmup=30, iters=30):
    """Measure the average per-iteration latency of ``model``.

    Two phases are timed on a random ``(batch_size, 3, 224, 224)`` input:
    a forward pass alone, then a forward pass followed by
    ``.sum().backward()``. Each phase runs ``warmup`` untimed iterations
    before ``iters`` timed ones.

    The forward-only phase runs with autograd enabled (no ``no_grad``), and
    gradients are never zeroed, so they accumulate across iterations; neither
    affects the timing being measured.

    Args:
        model: Callable taking the input tensor and returning a tensor.
            The input is created on the device of the model's first
            parameter; a callable without parameters falls back to CUDA,
            which is where the input was always placed before.
        batch_size: Leading dimension of the random input.
        warmup: Untimed iterations run before each timed phase.
        iters: Timed iterations per phase; must be at least 1.

    Returns:
        Tuple ``(forward_seconds, forward_backward_seconds)``, each being
        the mean wall-clock time of one iteration.

    Raises:
        ValueError: If ``iters`` is smaller than 1.
    """
    import time

    if iters < 1:
        raise ValueError("iters must be >= 1")

    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda")

    def sync():
        # CUDA kernels launch asynchronously; block until they finish so
        # the wall clock reflects completed work. Nothing to do on CPU.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    def timed(step):
        for _ in range(warmup):
            step()
        # Drain the warm-up work BEFORE starting the clock. Starting the
        # timer first would count still-queued warm-up kernels as part of
        # the measured interval.
        sync()
        start = time.perf_counter()
        for _ in range(iters):
            step()
        sync()
        return (time.perf_counter() - start) / iters

    inp = torch.randn((batch_size, 3, 224, 224), device=device)

    forward_time = timed(lambda: model(inp))
    forward_backward_time = timed(lambda: model(inp).sum().backward())

    return forward_time, forward_backward_time

# Benchmark the reference model first, then the model built above; each call
# prints the pair of per-iteration timings (forward only, forward + backward)
# returned by bench.
print(bench(model_ref))
print(bench(model))

# Hand control to the debugger (builtin breakpoint() hook) once both timings
# have been printed.
breakpoint()


0 comments on commit 9dfcfc9

Please sign in to comment.