Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
conansherry committed Nov 30, 2019
2 parents 7b1e756 + 5d09871 commit 72c935d
Show file tree
Hide file tree
Showing 65 changed files with 2,319 additions and 143 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,6 @@ dist/
.idea

# project dirs
/detectron2/model_zoo/configs
/datasets
/models
5 changes: 5 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ FROM nvidia/cuda:10.1-cudnn7-devel
# To use this Dockerfile:
# 1. `nvidia-docker build -t detectron2:v0 .`
# 2. `nvidia-docker run -it --name detectron2 detectron2:v0`
#
# To enable GUI support (Linux):
# 1. Grant the container temporary access to your x server (will be reverted at reboot of your host):
# `xhost +local:`docker inspect --format='{{ .Config.Hostname }}' detectron2``
# 2. `nvidia-docker run -it --name detectron2 --env="DISPLAY" --volume="/tmp/.X11-unix:/tmp/.X11-unix:rw" detectron2:v0`


ENV DEBIAN_FRONTEND noninteractive
Expand Down
6 changes: 4 additions & 2 deletions GETTING_STARTED.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,20 @@ For more advanced tutorials, refer to our [documentation](https://detectron2.rea
```
python demo/demo.py --config-file configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml \
--input input1.jpg input2.jpg \
[--other-options]
--opts MODEL.WEIGHTS detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl
```
The configs are made for training, therefore we need to specify `MODEL.WEIGHTS` to a model from model zoo for evaluation.
This command will run the inference and show visualizations in an OpenCV window.

For details of the command line arguments, see `demo.py -h`. Some common ones are:
* To run __on your webcam__, replace `--input files` with `--webcam`.
* To run __on a video__, replace `--input files` with `--video-input video.mp4`.
* To run __on cpu__, add `MODEL.DEVICE cpu` after `--opts`.
* To save outputs to a directory (for images) or a file (for webcam or video), use `--output`.


### Train a Standard Model
### Use Detectron2 in Command Line

We provide a script in "tools/train_net.py", that is made to train
all the configs provided in detectron2.
Expand All @@ -46,7 +48,7 @@ python tools/train_net.py --num-gpus 8 \
--config-file configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml
```

The configs are made for 8-GPU training. To train on 1 GPU, use:
The configs are made for 8-GPU training. To train on 1 GPU, change the batch size with:
```
python tools/train_net.py \
--config-file configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \
Expand Down
13 changes: 9 additions & 4 deletions INSTALL.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ The [Dockerfile](https://github.com/facebookresearch/detectron2/blob/master/Dock
also installs detectron2 with a few simple commands.

### Requirements
- Linux or macOS
- Python >= 3.6
- PyTorch 1.3
- [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation.
Expand All @@ -16,19 +17,19 @@ also installs detectron2 with a few simple commands.
- VS2019/CUDA10.1


### Build detectron2
### Build Detectron2

After having the above dependencies, run:
```
git clone git@github.com:facebookresearch/detectron2.git
git clone https://github.com/facebookresearch/detectron2.git
cd detectron2
python setup.py build develop
# or if you are on macOS
# MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py build develop
# or, as an alternative to `setup.py`, do
# pip install .
# pip install [--editable] .
```
Note: you may need to rebuild detectron2 after reinstalling a different build of PyTorch.

Expand Down Expand Up @@ -61,4 +62,8 @@ Note: you may need to rebuild detectron2 after reinstalling a different build of
```
print valid outputs at the time you build detectron2.

+ "invalid device function": you build detectron2 with one version of CUDA but run it with a different version.
+ "invalid device function" or "no kernel image is available for execution": two possibilities:
* You build detectron2 with one version of CUDA but run it with a different version.
* Detectron2 is not built with the correct compute compability for the GPU model.
The compute compability defaults to match the GPU found on the machine during building,
and can be controlled by `TORCH_CUDA_ARCH_LIST` environment variable during installation.
2 changes: 1 addition & 1 deletion datasets/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ coco/
panoptic_{train,val}2017.json
panoptic_{train,val}2017/
# png annotations
panoptic_stuff_{train,val}2017/ # generated by the script mentioned below
panoptic_stuff_{train,val}2017/ # generated by the script mentioned below
```

Install panopticapi by:
Expand Down
2 changes: 1 addition & 1 deletion demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_parser():
)
parser.add_argument(
"--opts",
help="Modify model config options using the command-line",
help="Modify config options using the command-line 'KEY VALUE' pairs",
default=[],
nargs=argparse.REMAINDER,
)
Expand Down
2 changes: 1 addition & 1 deletion detectron2/checkpoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# File:


from . import model_zoo as _UNUSED # register the handler
from . import catalog as _UNUSED # register the handler
from .detection_checkpoint import DetectionCheckpointer
from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer

Expand Down
File renamed without changes.
16 changes: 10 additions & 6 deletions detectron2/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@
_C.MODEL.DEVICE = "cuda"
_C.MODEL.META_ARCHITECTURE = "GeneralizedRCNN"

# If the WEIGHT starts with a catalog://, like :R-50, the code will look for
# the path in ModelCatalog. Else, it will use it as the specified absolute
# path
# Path (possibly with schema like catalog:// or detectron2://) to a checkpoint file
# to be loaded to the model. You can find available models in the model zoo.
_C.MODEL.WEIGHTS = ""

# Values to be used for image normalization (BGR order)
# Values to be used for image normalization (BGR order).
# To train on images of different number of channels, just set different mean & std.
# Default values are the mean pixel value from ImageNet: [103.53, 116.28, 123.675]
_C.MODEL.PIXEL_MEAN = [103.530, 116.280, 123.675]
# When using pre-trained models in Detectron1 or any MSRA models,
Expand Down Expand Up @@ -75,6 +75,8 @@
# with BGR being the one exception. One can set image format to BGR, we will
# internally use RGB for conversion and flip the channels over
_C.INPUT.FORMAT = "BGR"
# The ground truth mask format that the model will use.
# Mask R-CNN supports either "polygon" or "bitmask" as ground truth.
_C.INPUT.MASK_FORMAT = "polygon" # alternative: "bitmask"


Expand Down Expand Up @@ -547,8 +549,10 @@
# Set seed to positive to use a fixed seed. Note that a fixed seed does not
# guarantee fully deterministic behavior.
_C.SEED = -1
# Benchmark different cudnn algorithms. It has large overhead for about 10k
# iterations. It usually hurts total time, but can benefit for certain models.
# Benchmark different cudnn algorithms.
# If input images have very different sizes, this option will have large overhead
# for about 10k iterations. It usually hurts total time, but can benefit for certain models.
# If input images have the same or similar sizes, benchmark is often helpful.
_C.CUDNN_BENCHMARK = False

# global config is for quick hack purposes.
Expand Down
25 changes: 15 additions & 10 deletions detectron2/data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ def load_proposals_into_dataset(dataset_dicts, proposal_file):
"""
Load precomputed object proposals into the dataset.
The proposal file should be a pickled dict with the following keys:
- "ids": list[int] or list[str], the image ids
- "boxes": list[np.ndarray], each is an Nx4 array of boxes corresponding to the image id
- "objectness_logits": list[np.ndarray], each is an N sized array of objectness scores
corresponding to the boxes.
- "bbox_mode": the BoxMode of the boxes array. Defaults to ``BoxMode.XYXY_ABS``.
Args:
dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
proposal_file (str): file path of pre-computed proposals, in pkl format.
Expand All @@ -122,19 +129,17 @@ def load_proposals_into_dataset(dataset_dicts, proposal_file):
if key in proposals:
proposals[rename_keys[key]] = proposals.pop(key)

# Remove proposals whose ids are not in dataset
img_ids = set({entry["image_id"] for entry in dataset_dicts})
keep = [i for i, id in enumerate(proposals["ids"]) if id in img_ids]
# Sort proposals by ids following the image order in dataset
keep = sorted(keep)
for key in ["boxes", "ids", "objectness_logits"]:
proposals[key] = [proposals[key][i] for i in keep]
# Fetch the indexes of all proposals that are in the dataset
# Convert image_id to str since they could be int.
img_ids = set({str(record["image_id"]) for record in dataset_dicts})
id_to_index = {str(id): i for i, id in enumerate(proposals["ids"]) if str(id) in img_ids}

# Assuming default bbox_mode of precomputed proposals are 'XYXY_ABS'
bbox_mode = BoxMode(proposals["bbox_mode"]) if "bbox_mode" in proposals else BoxMode.XYXY_ABS

for i, record in enumerate(dataset_dicts):
# Sanity check that these proposals are for the correct image id
assert record["image_id"] == proposals["ids"][i]
for record in dataset_dicts:
# Get the index of the proposal
i = id_to_index[str(record["image_id"])]

boxes = proposals["boxes"][i]
objectness_logits = proposals["objectness_logits"][i]
Expand Down
3 changes: 3 additions & 0 deletions detectron2/data/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def register(name, func):
func (callable): a callable which takes no arguments and returns a list of dicts.
"""
assert callable(func), "You must register a function with `DatasetCatalog.register`!"
assert name not in DatasetCatalog._REGISTERED, "Dataset '{}' is already registered!".format(
name
)
DatasetCatalog._REGISTERED[name] = func

@staticmethod
Expand Down
19 changes: 10 additions & 9 deletions detectron2/data/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import datetime
import json
import numpy as np
import imagesize

from PIL import Image

Expand Down Expand Up @@ -262,12 +263,11 @@ def file2id(folder_path, file_path):

dataset_dicts = []
for (img_path, gt_path) in zip(input_files, gt_files):
local_path = PathManager.get_local_path(gt_path)
w, h = imagesize.get(local_path)
record = {}
record["file_name"] = img_path
record["sem_seg_file_name"] = gt_path
with PathManager.open(gt_path, "rb") as f:
img = Image.open(f)
w, h = img.size
record["height"] = h
record["width"] = w
dataset_dicts.append(record)
Expand Down Expand Up @@ -303,9 +303,9 @@ def convert_to_coco_dict(dataset_name):
coco_images = []
coco_annotations = []

for image_dict in dataset_dicts:
for image_id, image_dict in enumerate(dataset_dicts):
coco_image = {
"id": image_dict["image_id"],
"id": image_dict.get("image_id", image_id),
"width": image_dict["width"],
"height": image_dict["height"],
"file_name": image_dict["file_name"],
Expand All @@ -331,10 +331,11 @@ def convert_to_coco_dict(dataset_name):
area = polygons.area()[0].item()
else:
# Computing areas using bounding boxes
area = Boxes([bbox]).area()[0].item()
bbox_xy = BoxMode.convert(bbox, BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
area = Boxes([bbox_xy]).area()[0].item()

if "keypoints" in annotation:
keypoints = annotation["keypoints"] # list[int]
keypoints = annotation["keypoints"] # list[int]
for idx, v in enumerate(keypoints):
if idx % 3 != 2:
# COCO's segmentation coordinates are floating points in [0, H or W],
Expand All @@ -351,8 +352,8 @@ def convert_to_coco_dict(dataset_name):
# linking annotations to images
# "id" field must start with 1
coco_annotation["id"] = len(coco_annotations) + 1
coco_annotation["image_id"] = image_dict["image_id"]
coco_annotation["bbox"] = bbox
coco_annotation["image_id"] = coco_image["id"]
coco_annotation["bbox"] = [round(float(x), 3) for x in bbox]
coco_annotation["area"] = area
coco_annotation["category_id"] = annotation["category_id"]
coco_annotation["iscrowd"] = annotation.get("iscrowd", 0)
Expand Down
2 changes: 1 addition & 1 deletion detectron2/data/datasets/register_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def register_coco_panoptic_separated(
Args:
name (str): the name that identifies a dataset,
e.g. "coco_2017_train_panoptic"
metadata (str): extra metadata associated with this dataset.
metadata (dict): extra metadata associated with this dataset.
image_root (str): directory which contains all the images
panoptic_root (str): directory which contains panoptic annotation images
panoptic_json (str): path to the json panoptic annotation file
Expand Down
15 changes: 13 additions & 2 deletions detectron2/data/detection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ def read_image(file_name, format=None):
with PathManager.open(file_name, "rb") as f:
image = Image.open(f)

image = ImageOps.exif_transpose(image)
# capture and ignore this bug: https://github.com/python-pillow/Pillow/issues/3973
try:
image = ImageOps.exif_transpose(image)
except Exception:
pass

if format is not None:
# PIL only supports RGB, so convert to RGB and flip channels over below
Expand All @@ -73,7 +77,13 @@ def check_image_size(dataset_dict, image):
expected_wh = (dataset_dict["width"], dataset_dict["height"])
if not image_wh == expected_wh:
raise SizeMismatchError(
"mismatch (W,H), got {}, expect {}".format(image_wh, expected_wh)
"Mismatched (W,H){}, got {}, expect {}".format(
" for image " + dataset_dict["file_name"]
if "file_name" in dataset_dict
else "",
image_wh,
expected_wh,
)
)

# To ensure bbox always remap to original image size
Expand All @@ -82,6 +92,7 @@ def check_image_size(dataset_dict, image):
if "height" not in dataset_dict:
dataset_dict["height"] = image.shape[0]


def transform_proposals(dataset_dict, image_shape, transforms, min_box_side_len, proposal_topk):
"""
Apply transformations to the proposals in dataset_dict, if any.
Expand Down
41 changes: 35 additions & 6 deletions detectron2/engine/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,20 @@ class DefaultPredictor:
The predictor takes an BGR image, resizes it to the specified resolution,
runs the model and produces a dict of predictions.
This predictor takes care of model loading and input preprocessing for you.
If you'd like to do anything more fancy, please refer to its source code
as examples to build and use the model manually.
Attributes:
metadata (Metadata): the metadata of the underlying dataset, obtained from
cfg.DATASETS.TEST.
Examples:
.. code-block:: python
pred = DefaultPredictor(cfg)
outputs = pred(inputs)
"""

def __init__(self, cfg):
Expand Down Expand Up @@ -210,6 +221,14 @@ class DefaultTrainer(SimpleTrainer):
scheduler:
checkpointer (DetectionCheckpointer):
cfg (CfgNode):
Examples:
.. code-block:: python
trainer = DefaultTrainer(cfg)
trainer.resume_or_load() # load last checkpoint or MODEL.WEIGHTS
trainer.train()
"""

def __init__(self, cfg):
Expand Down Expand Up @@ -415,7 +434,10 @@ def build_evaluator(cls, cfg, dataset_name):
It is not implemented by default.
"""
raise NotImplementedError
raise NotImplementedError(
"Please either implement `build_evaluator()` in subclasses, or pass "
"your evaluator as arguments to `DefaultTrainer.test()`."
)

@classmethod
def test(cls, cfg, model, evaluators=None):
Expand Down Expand Up @@ -443,11 +465,18 @@ def test(cls, cfg, model, evaluators=None):
data_loader = cls.build_test_loader(cfg, dataset_name)
# When evaluators are passed in as arguments,
# implicitly assume that evaluators can be created before data_loader.
evaluator = (
evaluators[idx]
if evaluators is not None
else cls.build_evaluator(cfg, dataset_name)
)
if evaluators is not None:
evaluator = evaluators[idx]
else:
try:
evaluator = cls.build_evaluator(cfg, dataset_name)
except NotImplementedError:
logger.warn(
"No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
"or implement its `build_evaluator` method."
)
results[dataset_name] = {}
continue
results_i = inference_on_dataset(model, data_loader, evaluator)
results[dataset_name] = results_i
if comm.is_main_process():
Expand Down
Loading

0 comments on commit 72c935d

Please sign in to comment.