train.py
import os
import os.path as osp
import torch
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from infection.datasets import SegDataModule, TransformersDataModule
from infection.models import SegModel
from infection.callbacks import WandbCallback, VisualizerCallback
from lightning import seed_everything
import hydra
from omegaconf import DictConfig, OmegaConf
seed_everything(2023)


@hydra.main(version_base=None)
def main(args: DictConfig):
    # MaskFormer / Mask2Former use the Transformers-based datamodule;
    # every other model uses the generic segmentation datamodule.
    if args.model.model_name not in ['maskformer', 'mask2former']:
        datamodule = SegDataModule(
            train_img_dir=args.data.train_img_dir,
            train_ann_dir=args.data.train_ann_dir,
            val_img_dir=args.data.val_img_dir,
            val_ann_dir=args.data.val_ann_dir,
            batch_size=args.data.batch_size,
            image_size=args.data.image_size,
            use_mosaic=args.data.use_mosaic,
        )
    else:
        datamodule = TransformersDataModule(
            model_name=args.model.model_name,
            train_img_dir=args.data.train_img_dir,
            train_ann_dir=args.data.train_ann_dir,
            val_img_dir=args.data.val_img_dir,
            val_ann_dir=args.data.val_ann_dir,
            batch_size=args.data.batch_size,
            image_size=args.data.image_size,
            use_mosaic=args.data.use_mosaic,
        )

    model = SegModel(
        model_name=args.model.model_name,
        loss_configs=args.loss,
        optimizer_configs=args.optimizer,
    )

    # Optionally warm-start from a previously trained checkpoint.
    if args.model.get('pretrained', None) is not None:
        checkpoint = torch.load(args.model.pretrained)
        model.load_state_dict(checkpoint['state_dict'])

    # Either freeze the backbone (train the head only) or train end to end.
    if args.model.get('freeze_backbone', False):
        model.freeze_backbone()
        print("Backbone frozen")
    else:
        model.unfreeze()

    # Keep the best checkpoint by validation IoU, plus the most recent one.
    checkpoint_callback = ModelCheckpoint(
        dirpath=osp.join(args.trainer.save_dir, 'checkpoints'),
        monitor="val_high_vegetation_IoU",
        mode="max",
        filename="best",
        save_top_k=1,
        save_last=True,
    )

    os.makedirs(args.trainer.save_dir, exist_ok=True)

    # wandb_callback = WandbCallback(
    #     username=args.logger.wandb.username,
    #     project_name=args.logger.wandb.project_name,
    #     group_name=args.logger.wandb.group_name,
    #     save_dir=args.trainer.save_dir,
    #     config_dict=args,
    # )
    visualizer_callback = VisualizerCallback(
        save_dir=args.trainer.save_dir,
    )

    # Save the resolved config next to the run outputs for reproducibility.
    with open(osp.join(args.trainer.save_dir, "pipeline.yaml"), "w") as f:
        OmegaConf.save(config=args, f=f)

    trainer = pl.Trainer(
        max_epochs=args.trainer.max_epochs,
        callbacks=[visualizer_callback, checkpoint_callback],  # add wandb_callback here to enable W&B logging
        log_every_n_steps=args.trainer.log_every_n_steps,
        default_root_dir=args.trainer.save_dir,
    )
    trainer.fit(model, datamodule=datamodule)


if __name__ == "__main__":
    main()
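
Because `@hydra.main(version_base=None)` is declared without a `config_path` or `config_name`, the configuration has to be supplied at launch time (for example via `--config-path`/`--config-name`). The sketch below only illustrates the config tree the script reads (`data.*`, `model.*`, `loss`, `optimizer`, `trainer.*`); every path, name, and numeric value is a hypothetical placeholder, not a default shipped with the repository.

# Hypothetical config sketch matching the keys accessed in train.py.
# All values are placeholders; adapt them to your own dataset and model.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "data": {
        "train_img_dir": "data/train/images",   # placeholder paths
        "train_ann_dir": "data/train/masks",
        "val_img_dir": "data/val/images",
        "val_ann_dir": "data/val/masks",
        "batch_size": 8,
        "image_size": 512,
        "use_mosaic": False,
    },
    "model": {
        "model_name": "mask2former",   # any other name selects SegDataModule
        "pretrained": None,            # or a path to a Lightning .ckpt file
        "freeze_backbone": False,
    },
    "loss": {},        # forwarded to SegModel as loss_configs
    "optimizer": {},   # forwarded to SegModel as optimizer_configs
    "trainer": {
        "save_dir": "runs/exp1",
        "max_epochs": 50,
        "log_every_n_steps": 10,
    },
    # "logger": {"wandb": {...}}  # only needed if WandbCallback is re-enabled
})

# Saved e.g. as configs/baseline.yaml, it could be used as:
#   python train.py --config-path configs --config-name baseline
print(OmegaConf.to_yaml(cfg))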