pytorch-retinanet-master

Category: Artificial Intelligence / Neural Networks / Deep Learning
Development tool: Python
File size: 967KB
Downloads: 3
Upload date: 2020-06-30 22:56:37
Uploader: 1192483306
Description: RetinaNet, an object detection algorithm implemented in Python and PyTorch; it can be trained on your own dataset.

File list (size in bytes, date):
LICENSE (11357, 2019-12-31)
coco_validation.py (1252, 2019-12-31)
images (0, 2019-12-31)
images\1.jpg (107574, 2019-12-31)
images\3.jpg (151222, 2019-12-31)
images\4.jpg (151935, 2019-12-31)
images\5.jpg (166511, 2019-12-31)
images\6.jpg (114637, 2019-12-31)
images\7.jpg (187220, 2019-12-31)
images\8.jpg (86040, 2019-12-31)
retinanet (0, 2019-12-31)
retinanet\__init__.py (0, 2019-12-31)
retinanet\anchors.py (4095, 2019-12-31)
retinanet\coco_eval.py (2598, 2019-12-31)
retinanet\csv_eval.py (8989, 2019-12-31)
retinanet\dataloader.py (15108, 2019-12-31)
retinanet\losses.py (5342, 2019-12-31)
retinanet\model.py (11976, 2019-12-31)
retinanet\oid_dataset.py (9572, 2019-12-31)
retinanet\utils.py (4105, 2019-12-31)
train.py (6274, 2019-12-31)
visualize.py (3065, 2019-12-31)

# pytorch-retinanet

![img3](https://github.com/yhenon/pytorch-retinanet/blob/master/images/3.jpg)
![img5](https://github.com/yhenon/pytorch-retinanet/blob/master/images/5.jpg)

Pytorch implementation of RetinaNet object detection as described in [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) by Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He and Piotr Dollár. This implementation is primarily designed to be easy to read and simple to modify.

## Results

Currently, this repo achieves 33.5% mAP at 600px resolution with a Resnet-50 backbone. The published result is 34.0% mAP. The difference is likely due to the use of the Adam optimizer instead of SGD with weight decay.

## Installation

1) Clone this repo.

2) Install the required system packages:

```
apt-get install tk-dev python-tk
```

3) Install the Python packages:

```
pip install pandas
pip install pycocotools
pip install opencv-python
pip install requests
```

## Training

The network can be trained using the `train.py` script. Currently, two dataloaders are available: COCO and CSV.

For training on COCO, use:

```
python train.py --dataset coco --coco_path ../coco --depth 50
```

For training on a custom dataset with annotations in CSV format (see below), use:

```
python train.py --dataset csv --csv_train <path/to/train_annots.csv> --csv_classes <path/to/class_list.csv> --csv_val <path/to/val_annots.csv>
```

Note that the `--csv_val` argument is optional, in which case no validation will be performed.

## Pre-trained model

A pre-trained model is available at:

- https://drive.google.com/open?id=1yLmjq3JtXi841yXWBxst0coAgR26MNBS (this is a pytorch state dict)

The state dict model can be loaded using:

```
retinanet = model.resnet50(num_classes=dataset_train.num_classes(),)
retinanet.load_state_dict(torch.load(PATH_TO_WEIGHTS))
```

## Validation

Run `coco_validation.py` to validate the code on the COCO dataset. With the above model, run:

```
python coco_validation.py --coco_path ~/path/to/coco --model_path /path/to/model/coco_resnet_50_map_0_335_state_dict.pt
```

This produces the following results:

```
Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.335
Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.499
Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.357
Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.167
Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.369
Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.466
Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.282
Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.429
Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.458
Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.255
Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.508
Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.597
```

## Visualization

To visualize the network detections, use `visualize.py`:

```
python visualize.py --dataset coco --coco_path ../coco --model <path/to/model.pt>
```

This will visualize bounding boxes on the validation set. To visualize with a CSV dataset, use:

```
python visualize.py --dataset csv --csv_classes <path/to/class_list.csv> --csv_val <path/to/val_annots.csv> --model <path/to/model.pt>
```

## Model

The retinanet model uses a resnet backbone. You can set the depth of the resnet model using the `--depth` argument. Depth must be one of 18, 34, 50, 101 or 152. Note that deeper models are more accurate but are slower and use more memory.
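As a rough illustration of the depth choice and of loading the pre-trained checkpoint mentioned above, here is a minimal sketch. It assumes `retinanet/model.py` exposes `resnet18`/`resnet34`/`resnet101`/`resnet152` constructors parallel to the `resnet50(...)` call shown in this README (only `resnet50` is confirmed by the snippet above), and the helper `build_retinanet` is hypothetical, not part of the repo.

```python
# Sketch: pick a ResNet backbone depth and optionally load a state dict.
# Assumption: resnet18/34/101/152 constructors mirror the resnet50 call above.
import torch

from retinanet import model


def build_retinanet(num_classes, depth=50, weights_path=None):
    """Build a RetinaNet with the requested ResNet depth (18/34/50/101/152)."""
    constructors = {
        18: model.resnet18,
        34: model.resnet34,
        50: model.resnet50,
        101: model.resnet101,
        152: model.resnet152,
    }
    if depth not in constructors:
        raise ValueError('Depth must be one of 18, 34, 50, 101 or 152')
    retinanet = constructors[depth](num_classes=num_classes)
    if weights_path is not None:
        # e.g. the coco_resnet_50_map_0_335_state_dict.pt checkpoint above
        retinanet.load_state_dict(torch.load(weights_path, map_location='cpu'))
    retinanet.eval()
    return retinanet


# Example: the 80-class COCO model with the pre-trained weights.
# retinanet = build_retinanet(num_classes=80, depth=50,
#                             weights_path='coco_resnet_50_map_0_335_state_dict.pt')
```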
## CSV datasets

The `CSVGenerator` provides an easy way to define your own datasets. It uses two CSV files: one file containing annotations and one file containing a class name to ID mapping.

### Annotations format

The CSV file with annotations should contain one annotation per line. Images with multiple bounding boxes should use one row per bounding box. Note that indexing for pixel values starts at 0. The expected format of each line is:

```
path/to/image.jpg,x1,y1,x2,y2,class_name
```

Some images may not contain any labeled objects. To add these images to the dataset as negative examples, add an annotation where `x1`, `y1`, `x2`, `y2` and `class_name` are all empty:

```
path/to/image.jpg,,,,,
```

A full example:

```
/data/imgs/img_001.jpg,837,346,981,456,cow
/data/imgs/img_002.jpg,215,312,279,391,cat
/data/imgs/img_002.jpg,22,5,89,84,bird
/data/imgs/img_003.jpg,,,,,
```

This defines a dataset with 3 images. `img_001.jpg` contains a cow. `img_002.jpg` contains a cat and a bird. `img_003.jpg` contains no interesting objects/animals.

### Class mapping format

The class name to ID mapping file should contain one mapping per line. Each line should use the following format:

```
class_name,id
```

Indexing for classes starts at 0. Do not include a background class as it is implicit. For example:

```
cow,0
cat,1
bird,2
```

A small parsing sketch illustrating both CSV formats is given after the Examples section below.

## Acknowledgements

- Significant amounts of code are borrowed from the [keras retinanet implementation](https://github.com/fizyr/keras-retinanet)
- The NMS module used is from the [pytorch faster-rcnn implementation](https://github.com/ruotianluo/pytorch-faster-rcnn)

## Examples

![img1](https://github.com/yhenon/pytorch-retinanet/blob/master/images/1.jpg)
![img2](https://github.com/yhenon/pytorch-retinanet/blob/master/images/2.jpg)
![img4](https://github.com/yhenon/pytorch-retinanet/blob/master/images/4.jpg)
![img6](https://github.com/yhenon/pytorch-retinanet/blob/master/images/6.jpg)
![img7](https://github.com/yhenon/pytorch-retinanet/blob/master/images/7.jpg)
![img8](https://github.com/yhenon/pytorch-retinanet/blob/master/images/8.jpg)
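To make the annotation and class-mapping formats from the CSV datasets section concrete, here is a minimal, self-contained sketch that checks files against those rules using only the Python standard library. It is not part of this repo; the file names and the `load_class_map` / `check_annotations` helpers are illustrative only.

```python
# Illustrative check of the CSV formats described above (not part of this repo):
# annotations use path,x1,y1,x2,y2,class_name per row, with all five fields
# empty for a negative example; the class map uses class_name,id per row.
import csv


def load_class_map(class_csv):
    """Read 'class_name,id' rows into a dict, e.g. {'cow': 0, 'cat': 1}."""
    with open(class_csv, newline='') as f:
        return {name: int(idx) for name, idx in csv.reader(f)}


def check_annotations(annot_csv, class_map):
    """Yield (path, box, class_name) tuples, validating each annotation row."""
    with open(annot_csv, newline='') as f:
        for line_no, row in enumerate(csv.reader(f), start=1):
            path, x1, y1, x2, y2, name = row
            if (x1, y1, x2, y2, name) == ('', '', '', '', ''):
                yield path, None, None  # negative example: image with no objects
                continue
            box = tuple(map(int, (x1, y1, x2, y2)))
            if box[2] <= box[0] or box[3] <= box[1]:
                raise ValueError(f'line {line_no}: x2/y2 must exceed x1/y1')
            if name not in class_map:
                raise ValueError(f'line {line_no}: unknown class {name!r}')
            yield path, box, name


# Example usage (hypothetical file names):
# classes = load_class_map('class_list.csv')
# for path, box, name in check_annotations('train_annots.csv', classes):
#     print(path, box, name)
```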
