Skip to content

Commit

Permalink
add lvis
Browse files Browse the repository at this point in the history
  • Loading branch information
hhaAndroid committed Dec 14, 2023
1 parent 3dbd62e commit cae8cb3
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,11 @@
)

val_evaluator_Flickr30k = dict(
type='Flickr30kMetric',
ann_file=data_root+'mdetr_annotations/final_flickr_separateGT_val.json'
type='Flickr30kMetric'
)

test_evaluator_Flickr30k = dict(
type='Flickr30kMetric',
ann_file=data_root+'mdetr_annotations/final_flickr_separateGT_test.json'
type='Flickr30kMetric'
)

# ----------Config---------- #
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@
sampler=dict(_delete_=True, type='DefaultSampler', shuffle=True),
batch_sampler=dict(type='AspectRatioBatchSampler'),
dataset=dict(
_delete_=True,
type='RepeatDataset',
times=10,
dataset=dict(
_delete_=True,
type='CocoDataset',
data_root=data_root,
metainfo=metainfo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,11 @@
)

val_evaluator_Flickr30k = dict(
type='Flickr30kMetric',
ann_file=data_root+'mdetr_annotations/final_flickr_separateGT_val.json'
type='Flickr30kMetric'
)

test_evaluator_Flickr30k = dict(
type='Flickr30kMetric',
ann_file=data_root+'mdetr_annotations/final_flickr_separateGT_test.json'
type='Flickr30kMetric'
)

# ----------Config---------- #
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
_base_ = '../grounding_dino_swin-t_pretrain_obj365.py'

data_root = 'data/coco/'

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='RandomFlip', prob=0.5),
dict(
type='RandomChoice',
transforms=[
[
dict(
type='RandomChoiceResize',
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
keep_ratio=True)
],
[
dict(
type='RandomChoiceResize',
# The radio of all image in train dataset < 7
# follow the original implement
scales=[(400, 4200), (500, 4200), (600, 4200)],
keep_ratio=True),
dict(
type='RandomCrop',
crop_type='absolute_range',
crop_size=(384, 600),
allow_negative_crop=True),
dict(
type='RandomChoiceResize',
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
keep_ratio=True)
]
]),
dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)),
dict(
type='RandomSamplingNegPos',
tokenizer_name=_base_.lang_model_name,
num_sample_negative=85,
# change this
label_map_file='data/coco/annotations/lvis_v1_label_map.json',
max_tokens=256),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'flip', 'flip_direction', 'text',
'custom_entities', 'tokens_positive', 'dataset_mode'))
]

train_dataloader = dict(
dataset=dict(_delete_=True,
type='ClassBalancedDataset',
oversample_thr=1e-3,
dataset=dict(
type='ODVGDataset',
data_root=data_root,
need_text=False,
label_map_file='annotations/lvis_v1_label_map.json',
ann_file='annotations/lvis_v1_train_od.json',
data_prefix=dict(img='train2017/'),
filter_cfg=dict(filter_empty_gt=False, min_size=32),
pipeline=train_pipeline)))

val_dataloader = dict(
dataset=dict(
data_root=data_root,
type='LVISV1Dataset',
ann_file='annotations/lvis_v1_minival_inserted_image_name.json',
data_prefix=dict(img='')))
test_dataloader = val_dataloader

val_evaluator = dict(
_delete_=True,
type='LVISFixedAPMetric',
ann_file=data_root +
'annotations/lvis_v1_minival_inserted_image_name.json')
test_evaluator = val_evaluator

optim_wrapper = dict(
_delete_=True,
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.0001),
clip_grad=dict(max_norm=0.1, norm_type=2),
paramwise_cfg=dict(custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'backbone': dict(lr_mult=0.1),
# 'language_model': dict(lr_mult=0),
}))

# learning policy
max_epochs = 12
param_scheduler = [
dict(
type='MultiStepLR',
begin=0,
end=max_epochs,
by_epoch=True,
milestones=[11],
gamma=0.1)
]
train_cfg = dict(max_epochs=max_epochs, val_interval=1)

default_hooks = dict(checkpoint=dict(max_keep_ckpts=1, save_best='auto'))

load_from = ''
98 changes: 98 additions & 0 deletions tools/dataset_converters/lvis2odvg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import argparse
import json
import os.path

import jsonlines
from lvis import LVIS
from tqdm import tqdm


key_list_lvis = [i for i in range(1203)]
val_list_lvis = [i for i in range(1, 1204)]


def dump_lvis_label_map(args):
with open(args.input, 'r') as f:
j = json.load(f)
o_dict = {}
for category in j['categories']:
index = str(int(category['id']) - 1)
name = category['name']
o_dict[index] = name
if args.output is None:
output = os.path.dirname(args.input) + '/lvis_v1_label_map.json'
else:
output = os.path.dirname(args.output) + '/lvis_v1_label_map.json'
with open(output, 'w') as f:
json.dump(o_dict, f)


def lvis2odvg(args):
lvis = LVIS(args.input)
cats = lvis.load_cats(lvis.get_cat_ids())
nms = {cat['id']: cat['name'] for cat in cats}
metas = []
if args.output is None:
out_path = args.input[:-5] + '_od.json'
else:
out_path = args.output

key_list = key_list_lvis
val_list = val_list_lvis
dump_lvis_label_map(args)

for img_id, img_info in tqdm(lvis.imgs.items()):
file_name = img_info['coco_url'].replace(
'http://images.cocodataset.org/', '')
ann_ids = lvis.get_ann_ids(img_ids=[img_id])
raw_ann_info = lvis.load_anns(ann_ids)
instance_list = []
for ann in raw_ann_info:
if ann.get('ignore', False):
print(f'invalid ignore box of {ann}')
continue
x1, y1, w, h = ann['bbox']
inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
if inter_w * inter_h == 0:
print(f'invalid wh box of {ann}')
continue
if ann['area'] <= 0 or w < 1 or h < 1:
print(f'invalid area box of {ann}, w={img_info["width"]}, h={img_info["height"]}')
continue

if ann.get('iscrowd', False):
print(f'invalid iscrowd box of {ann}')
continue

bbox_xyxy = [x1, y1, x1 + w, y1 + h]
label = ann['category_id']
category = nms[label]
ind = val_list.index(label)
label_trans = key_list[ind]
instance_list.append({
'bbox': bbox_xyxy,
'label': label_trans,
'category': category
})
metas.append({
'filename': file_name,
'height': img_info['height'],
'width': img_info['width'],
'detection': {
'instances': instance_list
}
})

with jsonlines.open(out_path, mode='w') as writer:
writer.write_all(metas)

print('save to {}'.format(out_path))


if __name__ == '__main__':
parser = argparse.ArgumentParser('lvis to odvg format.', add_help=True)
parser.add_argument('input', type=str, help='input list name')
parser.add_argument("--output", "-o", type=str, help='input list name')
args = parser.parse_args()
lvis2odvg(args)

0 comments on commit cae8cb3

Please sign in to comment.