-
Notifications
You must be signed in to change notification settings - Fork 5
/
evaluator.py
75 lines (65 loc) · 3.06 KB
/
evaluator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from typing import Optional
from wasabi import msg
import pytorch_lightning as pl
import torch.cuda
import typer
from omegaconf import OmegaConf
from callbacks import ConfidenceIntervalCallback
from clr_gat import CLRGAT
from dataloaders.dataloaders import UnlabelledDataModule
# Typer CLI application; the evaluation command below registers itself on it
# via the @app.command() decorator.
app: typer.Typer = typer.Typer()
# Global seeding is disabled by default. Uncomment to seed the Python, NumPy
# and torch RNGs through Lightning before evaluation.
# pl.seed_everything(72)
@app.command()
def clrgat(dataset: str, ckpt_path: str, datapath: str, eval_ways: int, eval_shots: int, query_shots: int,
           sup_finetune: str, config: Optional[str] = None, adapt: str = "ot", distance: str = "euclidean",
           limit_test_batches: int = 600, seed: Optional[int] = None):
    """Evaluate a CLR-GAT checkpoint on the test split of a few-shot dataset.

    Loads the checkpoint with a fixed set of hyperparameter overrides, sets
    the adaptation mode, builds the test datamodule and runs ``trainer.test``
    with a ``ConfidenceIntervalCallback`` attached.

    Args:
        dataset: Dataset name forwarded to ``UnlabelledDataModule``. The value
            ``"miniimagenet"`` selects an 84x84 original image size; anything
            else selects 28x28.
        ckpt_path: Path of the Lightning checkpoint to load.
        datapath: Dataset location forwarded to ``UnlabelledDataModule``.
        eval_ways: Number of ways per test episode. It is forwarded to both
            the model and the datamodule.
        eval_shots: Number of support shots per class, forwarded to the
            datamodule as ``eval_support_shots``.
        query_shots: Number of query shots, forwarded to the datamodule as
            ``eval_query_shots``.
        sup_finetune: Forwarded unchanged to ``CLRGAT``. Its semantics are
            defined by the model and are not visible here.
        config: Optional hparams file. It is parsed up front and then passed
            to ``load_from_checkpoint`` as ``hparams_file``.
        adapt: Value stored in ``model.mpnn_opts["adapt"]`` (default ``"ot"``).
        distance: Forwarded to ``CLRGAT`` as ``distance``.
        limit_test_batches: Cap on the number of test batches run by the
            Trainer. Defaults to the previously hard-coded 600.
        seed: If given, passed to ``pl.seed_everything`` so evaluation is
            reproducible. If omitted, no global seeding is done.
    """
    if seed is not None:
        pl.seed_everything(seed)

    use_cuda = torch.cuda.is_available()
    # Device string handed to the model as ``mpnn_dev``.
    map_location = "cuda" if use_cuda else "cpu"

    msg.divider("Eval Setting")
    msg.info(f"Eval Ways: {eval_ways}")
    msg.info(f"Eval Shots: {eval_shots}")
    msg.info(f"Query Shots: {query_shots}")

    if config is not None:
        # Parse the hparams file before the checkpoint load so that a missing
        # or malformed file fails fast. The parsed result itself is not needed.
        OmegaConf.load(config)

    msg.divider("Model Setup")
    with msg.loading("Loading model"):
        # The keyword arguments below override hyperparameters saved in the
        # checkpoint. For older checkpoints, uncomment the extra overrides.
        model = CLRGAT.load_from_checkpoint(checkpoint_path=ckpt_path,
                                            mpnn_dev=map_location,
                                            arch="conv4",
                                            # out_planes=64,
                                            # average_end=False,
                                            label_cleansing_opts={
                                                "use": False,
                                            },
                                            distance=distance,
                                            use_hms=False,
                                            use_projector=False,
                                            projector_h_dim=2048,
                                            projector_out_dim=256,
                                            eval_ways=eval_ways,
                                            sup_finetune=sup_finetune,
                                            sup_finetune_epochs=25,
                                            # map_location=map_location,
                                            hparams_file=config)
    model.mpnn_opts["adapt"] = adapt

    msg.divider("Adaptation Type")
    msg.info(f" Adaptation type in use: {model.mpnn_opts['adapt']}")

    with msg.loading("Loading datamodule"):
        datamodule = UnlabelledDataModule(dataset=dataset,
                                          datapath=datapath,
                                          split='test',
                                          img_size_orig=(84, 84) if dataset == 'miniimagenet' else (28, 28),
                                          img_size_crop=(60, 60),
                                          eval_ways=eval_ways,
                                          eval_support_shots=eval_shots,
                                          eval_query_shots=query_shots)

    msg.divider("Beginning testing")
    trainer = pl.Trainer(
        # gpus=-1 ("all GPUs") makes PyTorch Lightning 1.x raise "GPUs requested
        # but none are available" on CPU-only hosts. Fall back to CPU (gpus=0)
        # so the Trainer agrees with the device chosen for the model above.
        gpus=-1 if use_cuda else 0,
        limit_test_batches=limit_test_batches,
        callbacks=[ConfidenceIntervalCallback(log_to_wb=False)],
    )
    trainer.test(model=model, datamodule=datamodule)
if __name__ == "__main__":
    # Script entry point: let Typer parse the command line and dispatch to
    # the command registered on ``app``.
    app()