Merge pull request #78 from shinning0821/main
This update adds MobileSAM support, along with other code optimizations.
shinning0821 authored Jan 14, 2024
2 parents 882249a + d5279ad commit 80c325b
Showing 300 changed files with 37,432 additions and 145 deletions.
Binary file modified __pycache__/cfg.cpython-37.pyc (binary diff not shown)
Binary file modified __pycache__/function.cpython-37.pyc (binary diff not shown)
Binary file modified __pycache__/utils.cpython-37.pyc (binary diff not shown)
1 change: 1 addition & 0 deletions cfg.py
@@ -5,6 +5,7 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('-net', type=str, default='sam', help='net type')
parser.add_argument('-baseline', type=str, default='unet', help='baseline net type')
+    parser.add_argument('-encoder', type=str, default='default', help='encoder type')
parser.add_argument('-seg_net', type=str, default='transunet', help='net type')
parser.add_argument('-mod', type=str, default='sam_adpt', help='mod type:seg,cls,val_ad')
parser.add_argument('-exp_name', default='msa_test_isic', type=str, help='net type')
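For orientation, here is a self-contained sketch of how the new -encoder option surfaces alongside -net. The option names and defaults come from the hunk above; the sample values (mobile_sam appears in function.py below, tiny_vit is illustrative only):

import argparse

# Stand-alone reproduction of the relevant cfg.py options (a sketch, not the full parser).
parser = argparse.ArgumentParser()
parser.add_argument('-net', type=str, default='sam', help='net type')
parser.add_argument('-encoder', type=str, default='default', help='encoder type')

# e.g. `python train.py -net mobile_sam -encoder tiny_vit` (hypothetical invocation)
args = parser.parse_args(['-net', 'mobile_sam', '-encoder', 'tiny_vit'])
assert args.net == 'mobile_sam' and args.encoder == 'tiny_vit'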
15 changes: 9 additions & 6 deletions function.py
@@ -124,16 +124,20 @@ def train_sam(args, net: nn.Module, optimizer, train_loader,
# imgs = imgs.to(dtype = mask_type,device = GPUdevice)

'''Train'''
if args.net == 'sam' or args.net == 'efficient_sam':
if args.mod == 'sam_adpt':
for n, value in net.image_encoder.named_parameters():
if "Adapter" not in n:
value.requires_grad = False
else:
value.requires_grad = True
else:
for n, value in net.image_encoder.named_parameters():
value.requires_grad = True

imge = net.image_encoder(imgs)

with torch.no_grad():
-            if args.net == 'sam':
+            if args.net == 'sam' or args.net == 'mobile_sam':
se, de = net.prompt_encoder(
points=pt,
boxes=None,
@@ -146,7 +150,7 @@ def train_sam(args, net: nn.Module, optimizer, train_loader,
labels=labels_torch,
)

-        if args.net == 'sam':
+        if args.net == 'sam' or args.net == 'mobile_sam':
pred, _ = net.mask_decoder(
image_embeddings=imge,
image_pe=net.prompt_encoder.get_dense_pe(),
@@ -276,8 +280,7 @@ def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True):
'''test'''
with torch.no_grad():
imge = net.image_encoder(imgs)

-            if args.net == 'sam':
+            if args.net == 'sam' or args.net == 'mobile_sam':
se, de = net.prompt_encoder(
points=pt,
boxes=None,
@@ -290,7 +293,7 @@ def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True):
labels=labels_torch,
)

-            if args.net == 'sam':
+            if args.net == 'sam' or args.net == 'mobile_sam':
pred, _ = net.mask_decoder(
image_embeddings=imge,
image_pe=net.prompt_encoder.get_dense_pe(),
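The adapter-tuning pattern in the train_sam hunk above can be read in isolation: when args.mod == 'sam_adpt', every image-encoder parameter is frozen except the Adapter layers. A minimal sketch of just that step (the helper name is mine, not the repo's):

import torch.nn as nn

def freeze_non_adapter(image_encoder: nn.Module) -> None:
    # Train only the Adapter layers; freeze everything else,
    # mirroring the sam_adpt branch in train_sam above.
    for name, param in image_encoder.named_parameters():
        param.requires_grad = 'Adapter' in name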
3 changes: 2 additions & 1 deletion models/ImageEncoder/__init__.py
@@ -1 +1,2 @@
-from .adapter_block import AdapterBlock
+from .tinyvit.tiny_vit import TinyViT
+from .vit import AdapterBlock, Block
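With this change both encoder families are importable from the package root; a usage sketch (assumes the repository root is on PYTHONPATH):

from models.ImageEncoder import TinyViT, AdapterBlock, Block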
Binary file modified models/ImageEncoder/__pycache__/__init__.cpython-37.pyc (binary diff not shown)
4 more binary files not shown.
213 changes: 213 additions & 0 deletions models/ImageEncoder/tinyvit/adapter_block.py
@@ -0,0 +1,213 @@
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...common import Adapter
from .utils import Conv2d_BN, DropPath, Mlp


class Attention(torch.nn.Module):
def __init__(self, dim, key_dim, num_heads=8,
attn_ratio=4,
resolution=(14, 14),
):
super().__init__()
# (h, w)
assert isinstance(resolution, tuple) and len(resolution) == 2
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * num_heads
self.attn_ratio = attn_ratio
h = self.dh + nh_kd * 2

self.norm = nn.LayerNorm(dim)
self.qkv = nn.Linear(dim, h)
self.proj = nn.Linear(self.dh, dim)

points = list(itertools.product(
range(resolution[0]), range(resolution[1])))
N = len(points)
attention_offsets = {}
idxs = []
for p1 in points:
for p2 in points:
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(
torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs',
torch.LongTensor(idxs).view(N, N),
persistent=False)

@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and hasattr(self, 'ab'):
del self.ab
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
# self.register_buffer('ab',
# self.attention_biases[:, self.attention_bias_idxs],
# persistent=False)

    def forward(self, x):  # x: (B, N, C)
B, N, _ = x.shape

# Normalization
x = self.norm(x)

qkv = self.qkv(x)
# (B, N, num_heads, d)
        q, k, v = qkv.view(B, N, self.num_heads, -1).split(
            [self.key_dim, self.key_dim, self.d], dim=3)
# (B, num_heads, N, d)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)

attn = (
(q @ k.transpose(-2, -1)) * self.scale
+
(self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab)
)
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
x = self.proj(x)
return x

class TinyViTAdapterBlock(nn.Module):
r""" TinyViT Block.
Args:
dim (int): Number of input channels.
        input_resolution (tuple[int, int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
drop (float, optional): Dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
local_conv_size (int): the kernel size of the convolution between
Attention and MLP. Default: 3
activation: the activation function. Default: nn.GELU
"""

def __init__(self, args, dim, input_resolution, num_heads, window_size=7,
mlp_ratio=4., drop=0., drop_path=0.,
local_conv_size=3,
activation=nn.GELU,
):
super().__init__()
        self.args = args
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
assert window_size > 0, 'window_size must be greater than 0'
self.window_size = window_size
self.mlp_ratio = mlp_ratio

self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()

assert dim % num_heads == 0, 'dim must be divisible by num_heads'
head_dim = dim // num_heads

window_resolution = (window_size, window_size)
self.attn = Attention(dim, head_dim, num_heads,
attn_ratio=1, resolution=window_resolution)

mlp_hidden_dim = int(dim * mlp_ratio)
mlp_activation = activation
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=mlp_activation, drop=drop)

self.MLP_Adapter = Adapter(dim, skip_connect=False) # MLP-adapter, no skip connection
self.Space_Adapter = Adapter(dim) # with skip connection
self.Depth_Adapter = Adapter(dim, skip_connect=False) # no skip connection

pad = local_conv_size // 2
self.local_conv = Conv2d_BN(
dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)

def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
res_x = x
if H == self.window_size and W == self.window_size:
x = self.attn(x)
else:
x = x.view(B, H, W, C)
            pad_b = (self.window_size - H % self.window_size) % self.window_size
            pad_r = (self.window_size - W % self.window_size) % self.window_size
padding = pad_b > 0 or pad_r > 0

if padding:
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))

pH, pW = H + pad_b, W + pad_r
nH = pH // self.window_size
nW = pW // self.window_size
# window partition
x = x.view(B, nH, self.window_size, nW, self.window_size, C).transpose(2, 3).reshape(
B * nH * nW, self.window_size * self.window_size, C)

            ## 3d branch
            if self.args.thd:
                from einops import rearrange
                hh, ww = x.shape[1], x.shape[2]
                depth = self.args.chunk
                xd = rearrange(x, '(b d) h w c -> (b h w) d c ', d=depth)
                # xd = rearrange(xd, '(b d) n c -> (b n) d c', d=self.in_chans)
                # Attention applies its own LayerNorm on entry, so no extra pre-norm here.
                dh, _ = closest_numbers(depth)
                xd = rearrange(xd, 'bhw (dh dw) c -> bhw dh dw c', dh=dh)
                xd = self.Depth_Adapter(self.attn(xd))
                xd = rearrange(xd, '(b n) dh dw c -> (b dh dw) n c', n=hh * ww)

x = self.attn(x)
x = self.Space_Adapter(x)

            if self.args.thd:
                xd = rearrange(xd, 'b (hh ww) c -> b hh ww c', hh=hh)
x = x + xd

# window reverse
x = x.view(B, nH, nW, self.window_size, self.window_size,
C).transpose(2, 3).reshape(B, pH, pW, C)

if padding:
x = x[:, :H, :W].contiguous()

x = x.view(B, L, C)

x = res_x + self.drop_path(x)

x = x.transpose(1, 2).reshape(B, C, H, W)
x = self.local_conv(x)
x = x.view(B, C, L).transpose(1, 2)

x = x + self.drop_path(self.mlp(x)) + 0.5 * self.MLP_Adapter(x)
return x

def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"

def closest_numbers(target):
a = int(target ** 0.5)
b = a + 1
while True:
if a * b == target:
return (a, b)
elif a * b < target:
b += 1
else:
a -= 1
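A quick smoke test of the new block, as a hedged sketch: it assumes the repository root is on PYTHONPATH, that args only needs the thd flag read in forward, and the dimensions are arbitrary. With thd=False and input_resolution equal to the window size, forward takes the plain 2-D path (no window partition, no depth branch):

from types import SimpleNamespace

import torch

from models.ImageEncoder.tinyvit.adapter_block import TinyViTAdapterBlock, closest_numbers

args = SimpleNamespace(thd=False)   # 2-D path only; the depth branch is skipped
block = TinyViTAdapterBlock(args, dim=64, input_resolution=(7, 7),
                            num_heads=4, window_size=7)
x = torch.randn(2, 49, 64)          # (B, H*W, C) with H = W = window_size = 7
print(block(x).shape)               # torch.Size([2, 49, 64])

# closest_numbers factors its target into two near-equal integers,
# as used for the depth chunking above:
print(closest_numbers(12))          # (3, 4)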
