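"""Model definitions and loaders for Concept Bottleneck Models (CBMs).

CBM_model passes backbone features through a linear concept-projection layer
before the final classifier and returns both class logits and normalized
concept activations; standard_model is the matching baseline without the
concept layer. load_cbm/load_std rebuild either model from a saved directory.
"""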
import os
import json
import torch
import data_utils
class CBM_model(torch.nn.Module):
    def __init__(self, backbone_name, W_c, W_g, b_g, proj_mean, proj_std, device="cuda"):
        super().__init__()
        model, _ = data_utils.get_target_model(backbone_name, device)
        # Remove the final fully connected layer so the backbone outputs features
        if "clip" in backbone_name:
            self.backbone = model
        elif "cub" in backbone_name:
            self.backbone = lambda x: model.features(x)
        else:
            self.backbone = torch.nn.Sequential(*list(model.children())[:-1])

        # Linear projection from backbone features to concept activations
        self.proj_layer = torch.nn.Linear(in_features=W_c.shape[1],
                                          out_features=W_c.shape[0], bias=False).to(device)
        self.proj_layer.load_state_dict({"weight": W_c})

        # Normalization statistics for the concept activations
        self.proj_mean = proj_mean
        self.proj_std = proj_std

        # Final linear layer from concepts to class logits
        self.final = torch.nn.Linear(in_features=W_g.shape[1], out_features=W_g.shape[0]).to(device)
        self.final.load_state_dict({"weight": W_g, "bias": b_g})
        self.concepts = None

    def forward(self, x):
        x = self.backbone(x)
        x = torch.flatten(x, 1)
        x = self.proj_layer(x)
        # Normalize concept activations before the final classifier
        proj_c = (x - self.proj_mean) / self.proj_std
        x = self.final(proj_c)
        return x, proj_c

class standard_model(torch.nn.Module):
    def __init__(self, backbone_name, W_g, b_g, proj_mean, proj_std, device="cuda"):
        super().__init__()
        model, _ = data_utils.get_target_model(backbone_name, device)
        # Remove the final fully connected layer so the backbone outputs features
        if "clip" in backbone_name:
            self.backbone = model
        elif "cub" in backbone_name:
            self.backbone = lambda x: model.features(x)
        else:
            self.backbone = torch.nn.Sequential(*list(model.children())[:-1])

        # Normalization statistics for the backbone features
        self.proj_mean = proj_mean
        self.proj_std = proj_std

        # Final linear layer directly from normalized backbone features to logits
        self.final = torch.nn.Linear(in_features=W_g.shape[1], out_features=W_g.shape[0]).to(device)
        self.final.load_state_dict({"weight": W_g, "bias": b_g})
        self.concepts = None

    def forward(self, x):
        x = self.backbone(x)
        x = torch.flatten(x, 1)
        proj_c = (x - self.proj_mean) / self.proj_std
        x = self.final(proj_c)
        return x, proj_c

def load_cbm(load_dir, device):
    """Rebuild a CBM_model from the weights and args saved in load_dir."""
    with open(os.path.join(load_dir, "args.txt"), "r") as f:
        args = json.load(f)

    W_c = torch.load(os.path.join(load_dir, "W_c.pt"), map_location=device)
    W_g = torch.load(os.path.join(load_dir, "W_g.pt"), map_location=device)
    b_g = torch.load(os.path.join(load_dir, "b_g.pt"), map_location=device)

    proj_mean = torch.load(os.path.join(load_dir, "proj_mean.pt"), map_location=device)
    proj_std = torch.load(os.path.join(load_dir, "proj_std.pt"), map_location=device)

    model = CBM_model(args["backbone"], W_c, W_g, b_g, proj_mean, proj_std, device)
    return model

def load_std(load_dir, device):
    """Rebuild a standard_model (no concept layer) from the weights and args saved in load_dir."""
    with open(os.path.join(load_dir, "args.txt"), "r") as f:
        args = json.load(f)

    W_g = torch.load(os.path.join(load_dir, "W_g.pt"), map_location=device)
    b_g = torch.load(os.path.join(load_dir, "b_g.pt"), map_location=device)

    proj_mean = torch.load(os.path.join(load_dir, "proj_mean.pt"), map_location=device)
    proj_std = torch.load(os.path.join(load_dir, "proj_std.pt"), map_location=device)

    model = standard_model(args["backbone"], W_g, b_g, proj_mean, proj_std, device)
    return model
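
# Minimal usage sketch (untested): assumes a checkpoint directory written by the
# accompanying training code. The "saved_models/example" path and the 224x224
# input size are placeholders, not values guaranteed by this file.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_cbm("saved_models/example", device)  # hypothetical directory
    model.eval()
    x = torch.randn(2, 3, 224, 224, device=device)  # dummy image batch
    with torch.no_grad():
        logits, concepts = model(x)
    # logits: (batch, n_classes); concepts: (batch, n_concepts)
    print(logits.shape, concepts.shape)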