forked from facebookresearch/detectron2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_corrosion.py
217 lines (174 loc) · 7.91 KB
/
train_corrosion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()
# import some common libraries
import numpy as np
import os, json, cv2, random
import pycocotools
import skimage.draw
from PIL import Image, ImageDraw
from progress.bar import Bar
import datetime
from detectron2.engine.hooks import HookBase
from detectron2.evaluation import inference_context
from detectron2.utils.logger import log_every_n_seconds
from detectron2.data import DatasetMapper
import detectron2.utils.comm as comm
import torch
import time
import logging
# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor, DefaultTrainer, launch, default_argument_parser
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog, build_detection_test_loader
from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.structures import BoxMode
from tools.darwin import *
# Category names of the corrosion datasets: registered as thing_classes for both
# splits in setup() and passed to get_darwin_dataset when the splits are loaded.
categories = ["Corrosion"]
# Adapted from https://medium.com/@apofeniaco/training-on-detectron2-with-a-validation-set-and-plot-loss-on-it-to-avoid-overfitting-6449418fbf4e
class LossEvalHook(HookBase):
    """Hook that periodically records the model's loss on a validation set.

    Every ``eval_period`` iterations, and after the final iteration, the model
    is run over every batch of ``data_loader``. The mean of the per-batch
    summed losses, gathered from all processes, is stored as the
    ``validation_loss`` scalar in the trainer's event storage.

    The model is called directly and its output is treated as a dict of
    losses. It must therefore be in training mode when the hook fires, and
    the data loader must yield annotated inputs.
    """

    def __init__(self, eval_period, model, data_loader):
        """
        Args:
            eval_period: period, in iterations, of the validation pass. A value
                ``<= 0`` disables the periodic pass; the pass after the final
                iteration still runs.
            model: the model being trained.
            data_loader: sized iterable of validation batches in the model's
                input format.
        """
        self._model = model
        self._period = eval_period
        self._data_loader = data_loader

    def _do_loss_eval(self):
        """Run one validation pass and record ``validation_loss``.

        Returns:
            The list of per-batch losses computed by this process.
        """
        # Progress / ETA logging follows detectron2's inference_on_dataset.
        total = len(self._data_loader)
        num_warmup = min(5, total - 1)
        start_time = time.perf_counter()
        total_compute_time = 0
        losses = []
        for idx, inputs in enumerate(self._data_loader):
            if idx == num_warmup:
                # Exclude the warm-up batches from the timing statistics.
                start_time = time.perf_counter()
                total_compute_time = 0
            # Time the forward pass itself so that the reported s / img
            # reflects actual compute, then wait for queued GPU work.
            start_compute_time = time.perf_counter()
            losses.append(self._get_loss(inputs))
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            total_compute_time += time.perf_counter() - start_compute_time
            iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
            seconds_per_img = total_compute_time / iters_after_start
            if idx >= num_warmup * 2 or seconds_per_img > 5:
                total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
                eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
                log_every_n_seconds(
                    logging.INFO,
                    "Loss on Validation done {}/{}. {:.4f} s / img. ETA={}".format(
                        idx + 1, total, seconds_per_img, str(eta)
                    ),
                    n=5,
                )
        # build_detection_test_loader shards the dataset across processes, so
        # each rank has only seen part of it. Gather every rank's losses before
        # averaging, so that the logged value covers the whole validation set.
        all_losses = [loss for rank_losses in comm.all_gather(losses) for loss in rank_losses]
        if all_losses:  # np.mean([]) would record NaN (with a RuntimeWarning)
            self.trainer.storage.put_scalar('validation_loss', np.mean(all_losses))
        comm.synchronize()
        return losses

    def _get_loss(self, data):
        """Return the sum of the model's loss terms on one batch, as a float."""
        # These losses are only logged, never back-propagated. Skip recording
        # the autograd graph, which would otherwise keep activations alive.
        with torch.no_grad():
            metrics_dict = self._model(data)
        return sum(
            v.item() if isinstance(v, torch.Tensor) else float(v)
            for v in metrics_dict.values()
        )

    def after_step(self):
        """Run the validation pass every ``eval_period`` iterations and after the last one."""
        next_iter = self.trainer.iter + 1
        is_final = next_iter == self.trainer.max_iter
        if is_final or (self._period > 0 and next_iter % self._period == 0):
            self._do_loss_eval()
def setup(args):
    """Register the corrosion datasets and build the training config.

    Registers ``corrosion_train`` and ``corrosion_val``, loaded through
    ``get_darwin_dataset``. Returns a Mask R-CNN R101-FPN config that
    fine-tunes from the COCO model-zoo weights for ``EPOCHS`` epochs, with
    ``TEST.EVAL_PERIOD`` (used by LossEvalHook) set to one epoch.

    Args:
        args: unused; accepted for compatibility with ``main(args)``.

    Returns:
        The populated detectron2 config.
    """
    # NOTE: GPU selection (CUDA_VISIBLE_DEVICES) is deliberately not done here.
    # setup() runs inside the worker processes started by detectron2's
    # launch(). Those workers have already initialised CUDA by the time main()
    # is called, so changing the variable at this point has no effect. It has
    # to be set before launch() is called.

    # REGISTER DATASETS
    dataset_directory = "/home/ndserv05/Documents/Data/Corrosion"
    for d in ["train", "val"]:
        # d=d binds the current split; a plain closure would only ever see "val".
        DatasetCatalog.register("corrosion_" + d, lambda d=d: get_darwin_dataset(dataset_directory, d, categories))
        MetadataCatalog.get("corrosion_" + d).set(thing_classes=categories)

    EPOCHS = 60  # number of passes over the training set
    TRAIN_SIZE = len(DatasetCatalog.get("corrosion_train"))

    # CONFIGURATION
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"))
    # NOTE(review): every worker process formats its own timestamp here. If
    # the workers straddle a minute boundary, they will disagree on OUTPUT_DIR.
    cfg.OUTPUT_DIR = "./output/" + "Corrosion_" + "{:%Y%m%dT%H%M}".format(datetime.datetime.now())
    cfg.INPUT.MASK_FORMAT = "bitmask"
    cfg.DATASETS.TRAIN = ("corrosion_train",)
    cfg.DATASETS.TEST = ()
    cfg.DATALOADER.NUM_WORKERS = 2
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml")  # initialise from the model zoo
    # SOLVER.IMS_PER_BATCH is the *total* number of images per iteration,
    # summed over all GPUs (with 2 GPUs this means 1 image per GPU).
    cfg.SOLVER.IMS_PER_BATCH = 2
    cfg.SOLVER.BASE_LR = 0.00025
    # Because IMS_PER_BATCH is already global, one epoch is
    # TRAIN_SIZE / IMS_PER_BATCH iterations, whatever the GPU count.
    iters_per_epoch = max(1, TRAIN_SIZE // cfg.SOLVER.IMS_PER_BATCH)
    cfg.SOLVER.MAX_ITER = iters_per_epoch * EPOCHS
    cfg.TEST.EVAL_PERIOD = iters_per_epoch  # validation loss once per epoch
    # NOTE(review): SOLVER.STEPS keeps the value from the merged 3x model-zoo
    # config. The LR only decays if MAX_ITER exceeds those steps; confirm this
    # is intended.
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512  # RoIs sampled per image for the head losses (512 is the default)
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # only one class (corrosion)
    return cfg
class myTrainer(DefaultTrainer):
    """DefaultTrainer that additionally tracks the loss on ``corrosion_val``.

    The only customisation is ``build_hooks``. It adds a LossEvalHook that
    runs every ``cfg.TEST.EVAL_PERIOD`` iterations.
    """

    def build_hooks(self):
        """Return DefaultTrainer's hooks with a LossEvalHook inserted before the last one."""
        hook_list = super().build_hooks()
        # The mapper is built in training mode (True) so that the mapped dicts
        # keep their annotations and the model can compute losses on them.
        # NOTE(review): training mode presumably also applies the training-time
        # augmentations to the validation images; confirm that is acceptable.
        val_mapper = DatasetMapper(self.cfg, True)
        val_loader = build_detection_test_loader(self.cfg, "corrosion_val", val_mapper)
        loss_hook = LossEvalHook(self.cfg.TEST.EVAL_PERIOD, self.model, val_loader)
        # Second-to-last position. Presumably this lets the final default hook
        # (the metrics writer, on the main process) pick up validation_loss in
        # the same step; verify against the detectron2 version in use.
        hook_list.insert(-1, loss_hook)
        return hook_list
def main(args):
    """Per-process training entry point run by detectron2's launch().

    Builds the config, creates the output directory, and trains starting from
    the configured weights (no resume).

    Args:
        args: forwarded to setup().

    Returns:
        Whatever ``trainer.train()`` returns.
    """
    cfg = setup(args)
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    # DefaultTrainer.__init__ already registers the hooks returned by
    # build_hooks() (including LossEvalHook). Calling build_hooks() again here
    # would only build and discard a second set, including another validation
    # data loader, so it is not called.
    trainer = myTrainer(cfg)
    trainer.resume_or_load(resume=False)
    return trainer.train()
if __name__ == '__main__':
    # GPUs must be selected before launch() spawns the worker processes. Each
    # worker initialises CUDA before main() runs, so setting the variable
    # later has no effect. setdefault lets a CUDA_VISIBLE_DEVICES given in the
    # shell take precedence over the default "0,1".
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1")
    launch(
        main,
        num_gpus_per_machine=2,
        num_machines=1,
        machine_rank=0,
        dist_url="auto",
        args=({},),
    )
# INFERENCE
# cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth") # path to the model we just trained
# cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.3 # set a custom testing threshold
# predictor = DefaultPredictor(cfg)
# # TEST INFERENCE
# dataset_dicts = get_darwin_dataset(dataset_directory, 'val')
# for d in dataset_dicts:
# im = cv2.imread(d["file_name"])
# # im = Image.open(d["file_name"])
# outputs = predictor(im) # format is documented at https://detectron2.readthedocs.io/tutorials/models.html#model-output-format
# print(outputs)
# v = Visualizer(im, metadata=corrosion_metadata, scale=1)
# out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
# # cv2_imshow(out.get_image()[:, :, ::-1])
# Image.fromarray(out.get_image()[:, :, ::-1]).save(str(d['image_id']) + '.jpg')
# # EVALUATE MODEL
# # In[9]:
# evaluator = COCOEvaluator("corrosion_val", ("bbox", "segm"), False, output_dir=cfg.OUTPUT_DIR)
# val_loader = build_detection_test_loader(cfg, "corrosion_val")
# print(inference_on_dataset(trainer.model, val_loader, evaluator))
# # another equivalent way to evaluate the model is to use `trainer.test`