diff --git a/README.md b/README.md index 13480ffd0..89d256587 100644 --- a/README.md +++ b/README.md @@ -98,6 +98,7 @@ A summary can be found in the [Model Zoo](docs/en/model_zoo.md) page. - [x] [Rotated RetinaNet-OBB/HBB](configs/rotated_retinanet/README.md) (ICCV'2017) - [x] [Rotated FasterRCNN-OBB](configs/rotated_faster_rcnn/README.md) (TPAMI'2017) - [x] [Rotated RepPoints-OBB](configs/rotated_reppoints/README.md) (ICCV'2019) +- [x] [Rotated FCOS](configs/rotated_fcos/README.md) (ICCV'2019) - [x] [RoI Transformer](configs/roi_trans/README.md) (CVPR'2019) - [x] [Gliding Vertex](configs/gliding_vertex/README.md) (TPAMI'2020) - [x] [Rotated ATSS-OBB](configs/rotated_atss/README.md) (CVPR'2020) diff --git a/README_zh-CN.md b/README_zh-CN.md index 0eee7c116..6c7910824 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -95,6 +95,7 @@ MMRotate 也提供了其他更详细的教程: - [x] [Rotated RetinaNet-OBB/HBB](configs/rotated_retinanet/README.md) (ICCV'2017) - [x] [Rotated FasterRCNN-OBB](configs/rotated_faster_rcnn/README.md) (TPAMI'2017) - [x] [Rotated RepPoints-OBB](configs/rotated_reppoints/README.md) (ICCV'2019) +- [x] [Rotated FCOS](configs/rotated_fcos/README.md) (ICCV'2019) - [x] [RoI Transformer](configs/roi_trans/README.md) (CVPR'2019) - [x] [Gliding Vertex](configs/gliding_vertex/README.md) (TPAMI'2020) - [x] [Rotated ATSS-OBB](configs/rotated_atss/README.md) (CVPR'2020) diff --git a/configs/rotated_fcos/README.md b/configs/rotated_fcos/README.md new file mode 100644 index 000000000..010b36de2 --- /dev/null +++ b/configs/rotated_fcos/README.md @@ -0,0 +1,54 @@ +# Rotated FCOS + +> [FCOS: Fully Convolutional One-Stage Object Detection](https://arxiv.org/abs/1904.01355) + + + +## Abstract + +
+
+
+ +We propose a fully convolutional one-stage object detector (FCOS) to solve object detection in a per-pixel prediction +fashion, analogue to semantic segmentation. Almost all state-of-the-art object detectors such as RetinaNet, SSD, YOLOv3, +and Faster R-CNN rely on pre-defined anchor boxes. In contrast, our proposed detector FCOS is anchor box free, as well +as proposal free. By eliminating the predefined set of anchor boxes, FCOS completely avoids the complicated computation +related to anchor boxes such as calculating overlapping during training. More importantly, we also avoid all +hyper-parameters related to anchor boxes, which are often very sensitive to the final detection performance. With the +only post-processing non-maximum suppression (NMS), FCOS with ResNeXt-64x4d-101 achieves 44.7% in AP with single-model +and single-scale testing, surpassing previous one-stage detectors with the advantage of being much simpler. For the +first time, we demonstrate a much simpler and flexible detection framework achieving improved detection accuracy. We +hope that the proposed FCOS framework can serve as a simple and strong alternative for many other instance-level tasks. 
+ +## Results and Models + +DOTA1.0 + +| Backbone | mAP | Angle | Separate Angle | Tricks | lr schd | Mem (GB) | Inf Time (fps) | Aug | Batch Size | Configs | Download | +| :----------------------: | :---: | :---: | :------------: | :----: | :-----: | :------: | :------------: | :-: | :--------: | :---------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| ResNet50 (1024,1024,200) | 70.70 | le90 | Y | Y | 1x | 4.18 | 26.4 | - | 2 | [rotated_fcos_sep_angle_r50_fpn_1x_dota_le90](./rotated_fcos_sep_angle_r50_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90-0be71a0c.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90_20220409_023250.log.json) | +| ResNet50 (1024,1024,200) | 71.28 | le90 | N | Y | 1x | 4.18 | 25.9 | - | 2 | [rotated_fcos_r50_fpn_1x_dota_le90](./rotated_fcos_r50_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_r50_fpn_1x_dota_le90/rotated_fcos_r50_fpn_1x_dota_le90-d87568ed.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_r50_fpn_1x_dota_le90/rotated_fcos_r50_fpn_1x_dota_le90_20220413_163526.log.json) | +| ResNet50 (1024,1024,200) | 71.76 | le90 | Y | Y | 1x | 4.23 | 25.7 | - | 2 | [rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90](./rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90.py) | 
[model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90-4e044ad2.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90_20220409_080616.log.json) |
+| ResNet50 (1024,1024,200) | 71.89 | le90 | N | Y | 1x | 4.18 | 26.2 | - | 2 | [rotated_fcos_kld_r50_fpn_1x_dota_le90](./rotated_fcos_kld_r50_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_kld_r50_fpn_1x_dota_le90/rotated_fcos_kld_r50_fpn_1x_dota_le90-ecafdb2b.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_kld_r50_fpn_1x_dota_le90/rotated_fcos_kld_r50_fpn_1x_dota_le90_20220409_202939.log.json) |
+
+**Notes:**
+
+- `MS` means multi-scale image split.
+- `RR` means random rotation.
+- `Rotated IoU Loss` needs mmcv version 1.5.0 or above.
+- `Separate Angle` means the angle loss is calculated separately.
+  In that case the bbox loss uses a horizontal bbox loss such as `IoULoss` or `GIoULoss`.
+- `Tricks` means setting `norm_on_bbox`, `centerness_on_reg` and `center_sampling` to `True`.
+- Inf time was tested on a single RTX 3090.
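The `Separate Angle` column maps directly onto head config fields. A minimal sketch, with the field values copied from the two config files added in this PR (the plain dicts below are illustrative, not a runnable MMRotate config):

```python
# `Separate Angle = N` (rotated_fcos_r50_fpn_1x_dota_le90.py): the head
# supervises the rotated box jointly with a rotated IoU loss.
joint_head = dict(
    separate_angle=False,
    loss_bbox=dict(type='RotatedIoULoss', loss_weight=1.0))

# `Separate Angle = Y` (rotated_fcos_sep_angle_r50_fpn_1x_dota_le90.py):
# the horizontal box and the angle each get their own loss term.
sep_head = dict(
    separate_angle=True,
    loss_bbox=dict(type='GIoULoss', loss_weight=1.0),  # horizontal bbox loss
    loss_angle=dict(type='L1Loss', loss_weight=0.2))   # dedicated angle loss

# Only the separate-angle variant carries an explicit angle loss.
print('loss_angle' in joint_head, 'loss_angle' in sep_head)  # False True
```

The KLD and CSL variants in the table follow the same pattern, swapping in `GDLoss_v1` for `loss_bbox` and `SmoothFocalLoss` for `loss_angle` respectively.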
+ +## Citation + +``` +@article{tian2019fcos, + title={FCOS: Fully Convolutional One-Stage Object Detection}, + author={Tian, Zhi and Shen, Chunhua and Chen, Hao and He, Tong}, + journal={arXiv preprint arXiv:1904.01355}, + year={2019} +} +``` diff --git a/configs/rotated_fcos/metafile.yml b/configs/rotated_fcos/metafile.yml new file mode 100644 index 000000000..d2cf2db9c --- /dev/null +++ b/configs/rotated_fcos/metafile.yml @@ -0,0 +1,63 @@ +Collections: +- Name: rotated_fcos + Metadata: + Training Data: DOTAv1.0 + Training Techniques: + - SGD with Momentum + - Weight Decay + Training Resources: 1x Tesla V100 + Architecture: + - ResNet + Paper: + URL: https://arxiv.org/abs/1904.01355 + Title: 'FCOS: Fully Convolutional One-Stage Object Detection' + README: configs/rotated_fcos/README.md + +Models: + - Name: rotated_fcos_sep_angle_r50_fpn_1x_dota_le90 + In Collection: rotated_fcos + Config: configs/rotated_fcos/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90.py + Metadata: + Training Data: DOTAv1.0 + Results: + - Task: Oriented Object Detection + Dataset: DOTAv1.0 + Metrics: + mAP: 70.70 + Weights: https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90-0be71a0c.pth + + - Name: rotated_fcos_r50_fpn_1x_dota_le90 + In Collection: rotated_fcos + Config: configs/rotated_fcos/rotated_fcos_r50_fpn_1x_dota_le90.py + Metadata: + Training Data: DOTAv1.0 + Results: + - Task: Oriented Object Detection + Dataset: DOTAv1.0 + Metrics: + mAP: 71.28 + Weights: https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_r50_fpn_1x_dota_le90/rotated_fcos_r50_fpn_1x_dota_le90-d87568ed.pth + + - Name: rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90 + In Collection: rotated_fcos + Config: configs/rotated_fcos/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90.py + Metadata: + Training Data: DOTAv1.0 + Results: + - Task: Oriented Object Detection + Dataset: DOTAv1.0 + Metrics: + mAP: 71.76 + 
Weights: https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90-4e044ad2.pth + + - Name: rotated_fcos_kld_r50_fpn_1x_dota_le90 + In Collection: rotated_fcos + Config: configs/rotated_fcos/rotated_fcos_kld_r50_fpn_1x_dota_le90.py + Metadata: + Training Data: DOTAv1.0 + Results: + - Task: Oriented Object Detection + Dataset: DOTAv1.0 + Metrics: + mAP: 71.89 + Weights: https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_kld_r50_fpn_1x_dota_le90/rotated_fcos_kld_r50_fpn_1x_dota_le90-ecafdb2b.pth diff --git a/configs/rotated_fcos/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90.py b/configs/rotated_fcos/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90.py new file mode 100644 index 000000000..0d0ad2b1c --- /dev/null +++ b/configs/rotated_fcos/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90.py @@ -0,0 +1,30 @@ +_base_ = 'rotated_fcos_sep_angle_r50_fpn_1x_dota_le90.py' +angle_version = 'le90' + +# model settings +model = dict( + bbox_head=dict( + type='CSLRFCOSHead', + center_sampling=True, + center_sample_radius=1.5, + norm_on_bbox=True, + centerness_on_reg=True, + separate_angle=True, + scale_angle=False, + angle_coder=dict( + type='CSLCoder', + angle_version=angle_version, + omega=1, + window='gaussian', + radius=1), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=1.0), + loss_centerness=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_angle=dict( + type='SmoothFocalLoss', gamma=2.0, alpha=0.25, loss_weight=0.2)), ) diff --git a/configs/rotated_fcos/rotated_fcos_kld_r50_fpn_1x_dota_le90.py b/configs/rotated_fcos/rotated_fcos_kld_r50_fpn_1x_dota_le90.py new file mode 100644 index 000000000..a69757f00 --- /dev/null +++ b/configs/rotated_fcos/rotated_fcos_kld_r50_fpn_1x_dota_le90.py @@ -0,0 +1,11 @@ +_base_ = 
'rotated_fcos_r50_fpn_1x_dota_le90.py' + +model = dict( + bbox_head=dict( + loss_bbox=dict( + _delete_=True, + type='GDLoss_v1', + loss_type='kld', + fun='log1p', + tau=1, + loss_weight=1.0)), ) diff --git a/configs/rotated_fcos/rotated_fcos_r50_fpn_1x_dota_le90.py b/configs/rotated_fcos/rotated_fcos_r50_fpn_1x_dota_le90.py new file mode 100644 index 000000000..17ec673fe --- /dev/null +++ b/configs/rotated_fcos/rotated_fcos_r50_fpn_1x_dota_le90.py @@ -0,0 +1,81 @@ +_base_ = [ + '../_base_/datasets/dotav1.py', '../_base_/schedules/schedule_1x.py', + '../_base_/default_runtime.py' +] +angle_version = 'le90' + +# model settings +model = dict( + type='RotatedFCOS', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + zero_init_residual=False, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_output', # use P5 + num_outs=5, + relu_before_extra_convs=True), + bbox_head=dict( + type='RotatedFCOSHead', + num_classes=15, + in_channels=256, + stacked_convs=4, + feat_channels=256, + strides=[8, 16, 32, 64, 128], + center_sampling=True, + center_sample_radius=1.5, + norm_on_bbox=True, + centerness_on_reg=True, + separate_angle=False, + scale_angle=True, + bbox_coder=dict( + type='DistanceAnglePointCoder', angle_version=angle_version), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='RotatedIoULoss', loss_weight=1.0), + loss_centerness=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)), + # training and testing settings + train_cfg=None, + test_cfg=dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(iou_thr=0.1), + max_per_img=2000)) + +img_norm_cfg = dict( + mean=[123.675, 
116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='RResize', img_scale=(1024, 1024)), + dict( + type='RRandomFlip', + flip_ratio=[0.25, 0.25, 0.25], + direction=['horizontal', 'vertical', 'diagonal'], + version=angle_version), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) +] +data = dict( + train=dict(pipeline=train_pipeline, version=angle_version), + val=dict(version=angle_version), + test=dict(version=angle_version)) diff --git a/configs/rotated_fcos/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90.py b/configs/rotated_fcos/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90.py new file mode 100644 index 000000000..fbaecae0d --- /dev/null +++ b/configs/rotated_fcos/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90.py @@ -0,0 +1,83 @@ +_base_ = [ + '../_base_/datasets/dotav1.py', '../_base_/schedules/schedule_1x.py', + '../_base_/default_runtime.py' +] +angle_version = 'le90' + +# model settings +model = dict( + type='RotatedFCOS', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + zero_init_residual=False, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_output', # use P5 + num_outs=5, + relu_before_extra_convs=True), + bbox_head=dict( + type='RotatedFCOSHead', + num_classes=15, + in_channels=256, + stacked_convs=4, + feat_channels=256, + strides=[8, 16, 32, 64, 128], + center_sampling=True, + center_sample_radius=1.5, + norm_on_bbox=True, + centerness_on_reg=True, + separate_angle=True, + scale_angle=True, + bbox_coder=dict( + 
type='DistanceAnglePointCoder', angle_version=angle_version), + h_bbox_coder=dict(type='DistancePointBBoxCoder'), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=1.0), + loss_angle=dict(type='L1Loss', loss_weight=0.2), + loss_centerness=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)), + # training and testing settings + train_cfg=None, + test_cfg=dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(iou_thr=0.1), + max_per_img=2000)) + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='RResize', img_scale=(1024, 1024)), + dict( + type='RRandomFlip', + flip_ratio=[0.25, 0.25, 0.25], + direction=['horizontal', 'vertical', 'diagonal'], + version=angle_version), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) +] +data = dict( + train=dict(pipeline=train_pipeline, version=angle_version), + val=dict(version=angle_version), + test=dict(version=angle_version)) diff --git a/docs/en/model_zoo.md b/docs/en/model_zoo.md index cbdefe79f..da14d730e 100644 --- a/docs/en/model_zoo.md +++ b/docs/en/model_zoo.md @@ -3,6 +3,7 @@ - [Rotated RetinaNet-OBB/HBB](https://github.com/open-mmlab/mmrotate/tree/main/configs/rotated_retinanet/README.md) (ICCV'2017) - [Rotated FasterRCNN-OBB](https://github.com/open-mmlab/mmrotate/tree/main/configs/rotated_faster_rcnn/README.md) (TPAMI'2017) - [Rotated RepPoints-OBB](https://github.com/open-mmlab/mmrotate/tree/main/configs/rotated_reppoints/README.md) (ICCV'2019) +- [Rotated FCOS](https://github.com/open-mmlab/mmrotate/tree/main/configs/rotated_fcos/README.md) (ICCV'2019) - [RoI 
Transformer](https://github.com/open-mmlab/mmrotate/tree/main/configs/roi_trans/README.md) (CVPR'2019) - [Gliding Vertex](https://github.com/open-mmlab/mmrotate/tree/main/configs/gliding_vertex/README.md) (TPAMI'2020) - [Rotated ATSS-OBB](https://github.com/open-mmlab/mmrotate/tree/main/configs/rotated_atss/README.md) (CVPR'2020) @@ -40,7 +41,11 @@ | ResNet50 (1024,1024,200) | 69.94 | oc | 1x | 3.39 | 15.6 | - | 2 | [rotated_retinanet_hbb_kld_r50_fpn_1x_dota_oc](../../configs/kld/rotated_retinanet_hbb_kld_r50_fpn_1x_dota_oc.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/kld/rotated_retinanet_hbb_kld_r50_fpn_1x_dota_oc/rotated_retinanet_hbb_kld_r50_fpn_1x_dota_oc-49c1f937.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/kld/rotated_retinanet_hbb_kld_r50_fpn_1x_dota_oc/rotated_retinanet_hbb_kld_r50_fpn_1x_dota_oc_20220125_201832.log.json) | | ResNet50 (1024,1024,200) | 70.18 | oc | 1x | 3.23 | 15.6 | - | 2 | [r3det_tiny_r50_fpn_1x_dota_oc](../../configs/r3det/r3det_tiny_r50_fpn_1x_dota_oc.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/r3det/r3det_tiny_r50_fpn_1x_dota_oc/r3det_tiny_r50_fpn_1x_dota_oc-c98a616c.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/r3det/r3det_tiny_r50_fpn_1x_dota_oc/r3det_tiny_r50_fpn_1x_dota_oc_20220209_171624.log.json) | | ResNet50 (1024,1024,200) | 70.64 | le90 | 1x | 3.12 | 18.2 | - | 2 | [rotated_atss_obb_r50_fpn_1x_dota_le90](../../configs/rotated_atss/rotated_atss_obb_r50_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_atss/rotated_atss_obb_r50_fpn_1x_dota_le90/rotated_atss_obb_r50_fpn_1x_dota_le90-e029ca06.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_atss/rotated_atss_obb_r50_fpn_1x_dota_le90/rotated_atss_obb_r50_fpn_1x_dota_le90_20220402_002048.log.json) | +| ResNet50 (1024,1024,200) | 70.70 | le90 | 1x | 4.18 | | - | 2 | 
[rotated_fcos_sep_angle_r50_fpn_1x_dota_le90](../../configs/rotated_fcos/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90-0be71a0c.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90_20220409_023250.log.json) | +| ResNet50 (1024,1024,200) | 71.28 | le90 | 1x | 4.18 | | - | 2 | [rotated_fcos_r50_fpn_1x_dota_le90](../../configs/rotated_fcos/rotated_fcos_r50_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_r50_fpn_1x_dota_le90/rotated_fcos_r50_fpn_1x_dota_le90-d87568ed.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_r50_fpn_1x_dota_le90/rotated_fcos_r50_fpn_1x_dota_le90_20220413_163526.log.json) | +| ResNet50 (1024,1024,200) | 71.76 | le90 | 1x | 4.23 | | - | 2 | [rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90](../../configs/rotated_fcos/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90-4e044ad2.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90_20220409_080616.log.json) | | ResNet50 (1024,1024,200) | 71.83 | oc | 1x | 3.54 | 12.4 | - | 2 | [r3det_kld_r50_fpn_1x_dota_oc](../../configs/kld/r3det_kld_r50_fpn_1x_dota_oc.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/kld/r3det_kld_r50_fpn_1x_dota_oc/r3det_kld_r50_fpn_1x_dota_oc-31866226.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/kld/r3det_kld_r50_fpn_1x_dota_oc/r3det_kld_r50_fpn_1x_dota_oc_20220210_114049.log.json) | +| ResNet50 (1024,1024,200) | 71.89 | le90 
| 1x | 4.18 | | - | 2 | [rotated_fcos_kld_r50_fpn_1x_dota_le90](../../configs/rotated_fcos/rotated_fcos_kld_r50_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_kld_r50_fpn_1x_dota_le90/rotated_fcos_kld_r50_fpn_1x_dota_le90-ecafdb2b.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_kld_r50_fpn_1x_dota_le90/rotated_fcos_kld_r50_fpn_1x_dota_le90_20220409_202939.log.json) | | ResNet50 (1024,1024,200) | 72.29 | le135 | 1x | 3.19 | 18.8 | - | 2 | [rotated_atss_obb_r50_fpn_1x_dota_le135](../../configs/rotated_atss/rotated_atss_obb_r50_fpn_1x_dota_le135.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_atss/rotated_atss_obb_r50_fpn_1x_dota_le135/rotated_atss_obb_r50_fpn_1x_dota_le135-eab7bc12.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_atss/rotated_atss_obb_r50_fpn_1x_dota_le135/rotated_atss_obb_r50_fpn_1x_dota_le135_20220402_002138.log.json) | | ResNet50 (1024,1024,200) | 72.68 | oc | 1x | 3.62 | 12.2 | - | 2 | [r3det_kfiou_ln_r50_fpn_1x_dota_oc](../../configs/kfiou/r3det_kfiou_ln_r50_fpn_1x_dota_oc.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/kfiou/r3det_kfiou_ln_r50_fpn_1x_dota_oc/r3det_kfiou_ln_r50_fpn_1x_dota_oc-8e7f049d.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/kfiou/r3det_kfiou_ln_r50_fpn_1x_dota_oc/r3det_kfiou_ln_r50_fpn_1x_dota_oc_20220123_074507.log.json) | | ResNet50 (1024,1024,200) | 72.76 | oc | 1x | 3.44 | 14.0 | - | 2 | [r3det_tiny_kld_r50_fpn_1x_dota_oc](../../configs/kld/r3det_tiny_kld_r50_fpn_1x_dota_oc.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/kld/r3det_tiny_kld_r50_fpn_1x_dota_oc/r3det_tiny_kld_r50_fpn_1x_dota_oc-589e142a.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/kld/r3det_tiny_kld_r50_fpn_1x_dota_oc/r3det_tiny_kld_r50_fpn_1x_dota_oc_20220209_172917.log.json) | diff --git a/docs/zh_cn/model_zoo.md b/docs/zh_cn/model_zoo.md index 
d9062c3f0..d7d5c2d37 100644 --- a/docs/zh_cn/model_zoo.md +++ b/docs/zh_cn/model_zoo.md @@ -1,8 +1,12 @@ ## 基准和模型库 -- [Rotated RetinaNet-OBB/HBB](https://github.com/open-mmlab/mmrotate/tree/main/configs/rotated_retinanet/README.md) (ICCV'2017) -- [Rotated FasterRCNN-OBB](https://github.com/open-mmlab/mmrotate/tree/main/configs/rotated_faster_rcnn/README.md) (TPAMI'2017) -- [Rotated RepPoints-OBB](https://github.com/open-mmlab/mmrotate/tree/main/configs/rotated_reppoints/README.md) (ICCV'2019) +- [Rotated RetinaNet-OBB/HBB](https://github.com/open-mmlab/mmrotate/tree/main/configs/rotated_retinanet/README.md) ( + ICCV'2017) +- [Rotated FasterRCNN-OBB](https://github.com/open-mmlab/mmrotate/tree/main/configs/rotated_faster_rcnn/README.md) ( + TPAMI'2017) +- [Rotated RepPoints-OBB](https://github.com/open-mmlab/mmrotate/tree/main/configs/rotated_reppoints/README.md) ( + ICCV'2019) +- [Rotated FCOS](https://github.com/open-mmlab/mmrotate/tree/main/configs/rotated_fcos/README.md) (ICCV'2019) - [RoI Transformer](https://github.com/open-mmlab/mmrotate/tree/main/configs/roi_trans/README.md) (CVPR'2019) - [Gliding Vertex](https://github.com/open-mmlab/mmrotate/tree/main/configs/gliding_vertex/README.md) (TPAMI'2020) - [Rotated ATSS-OBB](https://github.com/open-mmlab/mmrotate/tree/main/configs/rotated_atss/README.md) (CVPR'2020) @@ -40,7 +44,11 @@ | ResNet50 (1024,1024,200) | 69.94 | oc | 1x | 3.39 | 15.6 | - | 2 | [rotated_retinanet_hbb_kld_r50_fpn_1x_dota_oc](../../configs/kld/rotated_retinanet_hbb_kld_r50_fpn_1x_dota_oc.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/kld/rotated_retinanet_hbb_kld_r50_fpn_1x_dota_oc/rotated_retinanet_hbb_kld_r50_fpn_1x_dota_oc-49c1f937.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/kld/rotated_retinanet_hbb_kld_r50_fpn_1x_dota_oc/rotated_retinanet_hbb_kld_r50_fpn_1x_dota_oc_20220125_201832.log.json) | | ResNet50 (1024,1024,200) | 70.18 
| oc | 1x | 3.23 | 15.6 | - | 2 | [r3det_tiny_r50_fpn_1x_dota_oc](../../configs/r3det/r3det_tiny_r50_fpn_1x_dota_oc.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/r3det/r3det_tiny_r50_fpn_1x_dota_oc/r3det_tiny_r50_fpn_1x_dota_oc-c98a616c.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/r3det/r3det_tiny_r50_fpn_1x_dota_oc/r3det_tiny_r50_fpn_1x_dota_oc_20220209_171624.log.json) | | ResNet50 (1024,1024,200) | 70.64 | le90 | 1x | 3.12 | 18.2 | - | 2 | [rotated_atss_obb_r50_fpn_1x_dota_le90](../../configs/rotated_atss/rotated_atss_obb_r50_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_atss/rotated_atss_obb_r50_fpn_1x_dota_le90/rotated_atss_obb_r50_fpn_1x_dota_le90-e029ca06.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_atss/rotated_atss_obb_r50_fpn_1x_dota_le90/rotated_atss_obb_r50_fpn_1x_dota_le90_20220402_002048.log.json) | +| ResNet50 (1024,1024,200) | 70.70 | le90 | 1x | 4.18 | | - | 2 | [rotated_fcos_sep_angle_r50_fpn_1x_dota_le90](../../configs/rotated_fcos/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90-0be71a0c.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90/rotated_fcos_sep_angle_r50_fpn_1x_dota_le90_20220409_023250.log.json) | +| ResNet50 (1024,1024,200) | 71.28 | le90 | 1x | 4.18 | | - | 2 | [rotated_fcos_r50_fpn_1x_dota_le90](../../configs/rotated_fcos/rotated_fcos_r50_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_r50_fpn_1x_dota_le90/rotated_fcos_r50_fpn_1x_dota_le90-d87568ed.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_r50_fpn_1x_dota_le90/rotated_fcos_r50_fpn_1x_dota_le90_20220413_163526.log.json) | +| ResNet50 (1024,1024,200) | 71.76 | le90 | 1x | 
4.23 | | - | 2 | [rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90](../../configs/rotated_fcos/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90-4e044ad2.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90/rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90_20220409_080616.log.json) | | ResNet50 (1024,1024,200) | 71.83 | oc | 1x | 3.54 | 12.4 | - | 2 | [r3det_kld_r50_fpn_1x_dota_oc](../../configs/kld/r3det_kld_r50_fpn_1x_dota_oc.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/kld/r3det_kld_r50_fpn_1x_dota_oc/r3det_kld_r50_fpn_1x_dota_oc-31866226.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/kld/r3det_kld_r50_fpn_1x_dota_oc/r3det_kld_r50_fpn_1x_dota_oc_20220210_114049.log.json) | +| ResNet50 (1024,1024,200) | 71.89 | le90 | 1x | 4.18 | | - | 2 | [rotated_fcos_kld_r50_fpn_1x_dota_le90](../../configs/rotated_fcos/rotated_fcos_kld_r50_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_kld_r50_fpn_1x_dota_le90/rotated_fcos_kld_r50_fpn_1x_dota_le90-ecafdb2b.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_fcos/rotated_fcos_kld_r50_fpn_1x_dota_le90/rotated_fcos_kld_r50_fpn_1x_dota_le90_20220409_202939.log.json) | | ResNet50 (1024,1024,200) | 72.29 | le135 | 1x | 3.19 | 18.8 | - | 2 | [rotated_atss_obb_r50_fpn_1x_dota_le135](../../configs/rotated_atss/rotated_atss_obb_r50_fpn_1x_dota_le135.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_atss/rotated_atss_obb_r50_fpn_1x_dota_le135/rotated_atss_obb_r50_fpn_1x_dota_le135-eab7bc12.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_atss/rotated_atss_obb_r50_fpn_1x_dota_le135/rotated_atss_obb_r50_fpn_1x_dota_le135_20220402_002138.log.json) | | ResNet50 (1024,1024,200) 
| 72.68 | oc | 1x | 3.62 | 12.2 | - | 2 | [r3det_kfiou_ln_r50_fpn_1x_dota_oc](../../configs/kfiou/r3det_kfiou_ln_r50_fpn_1x_dota_oc.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/kfiou/r3det_kfiou_ln_r50_fpn_1x_dota_oc/r3det_kfiou_ln_r50_fpn_1x_dota_oc-8e7f049d.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/kfiou/r3det_kfiou_ln_r50_fpn_1x_dota_oc/r3det_kfiou_ln_r50_fpn_1x_dota_oc_20220123_074507.log.json) | | ResNet50 (1024,1024,200) | 72.76 | oc | 1x | 3.44 | 14.0 | - | 2 | [r3det_tiny_kld_r50_fpn_1x_dota_oc](../../configs/kld/r3det_tiny_kld_r50_fpn_1x_dota_oc.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/kld/r3det_tiny_kld_r50_fpn_1x_dota_oc/r3det_tiny_kld_r50_fpn_1x_dota_oc-589e142a.pth) \| [log](https://download.openmmlab.com/mmrotate/v0.1.0/kld/r3det_tiny_kld_r50_fpn_1x_dota_oc/r3det_tiny_kld_r50_fpn_1x_dota_oc_20220209_172917.log.json) | diff --git a/mmrotate/core/bbox/coder/__init__.py b/mmrotate/core/bbox/coder/__init__.py index cfd7fb1a1..4111880be 100644 --- a/mmrotate/core/bbox/coder/__init__.py +++ b/mmrotate/core/bbox/coder/__init__.py @@ -3,9 +3,10 @@ from .delta_midpointoffset_rbbox_coder import MidpointOffsetCoder from .delta_xywha_hbbox_coder import DeltaXYWHAHBBoxCoder from .delta_xywha_rbbox_coder import DeltaXYWHAOBBoxCoder +from .distance_angle_point_coder import DistanceAnglePointCoder from .gliding_vertex_coder import GVFixCoder, GVRatioCoder __all__ = [ 'DeltaXYWHAOBBoxCoder', 'DeltaXYWHAHBBoxCoder', 'MidpointOffsetCoder', - 'GVFixCoder', 'GVRatioCoder', 'CSLCoder' + 'GVFixCoder', 'GVRatioCoder', 'CSLCoder', 'DistanceAnglePointCoder' ] diff --git a/mmrotate/core/bbox/coder/distance_angle_point_coder.py b/mmrotate/core/bbox/coder/distance_angle_point_coder.py new file mode 100644 index 000000000..3bca158ad --- /dev/null +++ b/mmrotate/core/bbox/coder/distance_angle_point_coder.py @@ -0,0 +1,111 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch
+from mmdet.core import BaseBBoxCoder
+
+from mmrotate.core.bbox.transforms import norm_angle
+from ..builder import BBOX_CODERS
+
+
+@BBOX_CODERS.register_module()
+class DistanceAnglePointCoder(BaseBBoxCoder):
+    """Distance Angle Point BBox coder.
+
+    This coder encodes gt bboxes (x, y, w, h, angle) into (left, top, right,
+    bottom, angle) and decodes them back to the original format.
+
+    Args:
+        clip_border (bool, optional): Whether to clip the objects outside
+            the border of the image. Defaults to True.
+        angle_version (str, optional): Angle representation. Defaults to
+            'oc'.
+    """
+
+    def __init__(self, clip_border=True, angle_version='oc'):
+        super(BaseBBoxCoder, self).__init__()
+        self.clip_border = clip_border
+        self.angle_version = angle_version
+
+    def encode(self, points, gt_bboxes, max_dis=None, eps=0.1):
+        """Encode bounding box to distances.
+
+        Args:
+            points (Tensor): Shape (N, 2). The format is [x, y].
+            gt_bboxes (Tensor): Shape (N, 5). The format is "xywha".
+            max_dis (float): Upper bound of the distance. Default None.
+            eps (float): A small value to ensure target < max_dis instead
+                of <=. Default 0.1.
+
+        Returns:
+            Tensor: Box transformation deltas. The shape is (N, 5).
+        """
+        assert points.size(0) == gt_bboxes.size(0)
+        assert points.size(-1) == 2
+        assert gt_bboxes.size(-1) == 5
+        return self.obb2distance(points, gt_bboxes, max_dis, eps)
+
+    def decode(self, points, pred_bboxes, max_shape=None):
+        """Decode distance prediction to bounding box.
+
+        Args:
+            points (Tensor): Shape (B, N, 2) or (N, 2).
+            pred_bboxes (Tensor): Distance from the given point to 4
+                boundaries and angle (left, top, right, bottom, angle).
+                Shape (B, N, 5) or (N, 5).
+            max_shape (Sequence[int] or torch.Tensor or Sequence[
+                Sequence[int]], optional): Maximum bounds for boxes,
+                specifies (H, W, C) or (H, W). If points shape is (B, N, 2),
+                then the max_shape should be a Sequence[Sequence[int]],
+                and the length of max_shape should also be B.
+                Default None.
+ Returns: + Tensor: Boxes with shape (N, 5) or (B, N, 5) + """ + assert points.size(0) == pred_bboxes.size(0) + assert points.size(-1) == 2 + assert pred_bboxes.size(-1) == 5 + if self.clip_border is False: + max_shape = None + return self.distance2obb(points, pred_bboxes, max_shape, + self.angle_version) + + def obb2distance(self, points, distance, max_dis=None, eps=None): + ctr, wh, angle = torch.split(distance, [2, 2, 1], dim=1) + + cos_angle, sin_angle = torch.cos(angle), torch.sin(angle) + rot_matrix = torch.cat([cos_angle, sin_angle, -sin_angle, cos_angle], + dim=1).reshape(-1, 2, 2) + + offset = points - ctr + offset = torch.matmul(rot_matrix, offset[..., None]) + offset = offset.squeeze(-1) + + w, h = wh[..., 0], wh[..., 1] + offset_x, offset_y = offset[..., 0], offset[..., 1] + left = w / 2 + offset_x + right = w / 2 - offset_x + top = h / 2 + offset_y + bottom = h / 2 - offset_y + if max_dis is not None: + left = left.clamp(min=0, max=max_dis - eps) + top = top.clamp(min=0, max=max_dis - eps) + right = right.clamp(min=0, max=max_dis - eps) + bottom = bottom.clamp(min=0, max=max_dis - eps) + return torch.stack((left, top, right, bottom, angle.squeeze(-1)), -1) + + def distance2obb(self, + points, + distance, + max_shape=None, + angle_version='oc'): + distance, angle = distance.split([4, 1], dim=1) + + cos_angle, sin_angle = torch.cos(angle), torch.sin(angle) + rot_matrix = torch.cat([cos_angle, -sin_angle, sin_angle, cos_angle], + dim=1).reshape(-1, 2, 2) + + wh = distance[:, :2] + distance[:, 2:] + offset_t = (distance[:, 2:] - distance[:, :2]) / 2 + offset_t = offset_t.unsqueeze(2) + offset = torch.bmm(rot_matrix, offset_t).squeeze(2) + ctr = points + offset + + angle_regular = norm_angle(angle, angle_version) + return torch.cat([ctr, wh, angle_regular], dim=-1) diff --git a/mmrotate/models/dense_heads/__init__.py b/mmrotate/models/dense_heads/__init__.py index 014a9a94f..6d20fda89 100644 --- a/mmrotate/models/dense_heads/__init__.py +++ 
b/mmrotate/models/dense_heads/__init__.py
@@ -1,12 +1,15 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .csl_rotated_fcos_head import CSLRFCOSHead
 from .csl_rotated_retina_head import CSLRRetinaHead
 from .kfiou_odm_refine_head import KFIoUODMRefineHead
 from .kfiou_rotate_retina_head import KFIoURRetinaHead
 from .kfiou_rotate_retina_refine_head import KFIoURRetinaRefineHead
 from .odm_refine_head import ODMRefineHead
 from .oriented_rpn_head import OrientedRPNHead
+from .rotated_anchor_free_head import RotatedAnchorFreeHead
 from .rotated_anchor_head import RotatedAnchorHead
 from .rotated_atss_head import RotatedATSSHead
+from .rotated_fcos_head import RotatedFCOSHead
 from .rotated_reppoints_head import RotatedRepPointsHead
 from .rotated_retina_head import RotatedRetinaHead
 from .rotated_retina_refine_head import RotatedRetinaRefineHead
@@ -18,5 +21,6 @@
     'OrientedRPNHead', 'RotatedRetinaRefineHead', 'ODMRefineHead',
     'KFIoURRetinaHead', 'KFIoURRetinaRefineHead', 'KFIoUODMRefineHead',
     'RotatedRepPointsHead', 'SAMRepPointsHead', 'CSLRRetinaHead',
-    'RotatedATSSHead'
+    'RotatedATSSHead', 'RotatedAnchorFreeHead', 'RotatedFCOSHead',
+    'CSLRFCOSHead'
 ]
diff --git a/mmrotate/models/dense_heads/csl_rotated_fcos_head.py b/mmrotate/models/dense_heads/csl_rotated_fcos_head.py
new file mode 100644
index 000000000..a29f5d1e4
--- /dev/null
+++ b/mmrotate/models/dense_heads/csl_rotated_fcos_head.py
@@ -0,0 +1,334 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import Scale
+from mmcv.runner import force_fp32
+from mmdet.core import reduce_mean
+
+from mmrotate.core import build_bbox_coder, multiclass_nms_rotated
+from ..builder import ROTATED_HEADS
+from .rotated_anchor_free_head import RotatedAnchorFreeHead
+from .rotated_fcos_head import RotatedFCOSHead
+
+INF = 1e8
+
+
+@ROTATED_HEADS.register_module()
+class CSLRFCOSHead(RotatedFCOSHead):
+    """Use `Circular Smooth Label (CSL)
+    <https://arxiv.org/abs/2003.05597>`_ in
+    `FCOS <https://arxiv.org/abs/1904.01355>`_.
+
+    Args:
+        separate_angle (bool): If true, angle prediction is separated from
+            bbox regression loss. CSL only supports True. Default: True.
+        scale_angle (bool): If true, add scale to the angle pred branch.
+            CSL only supports False. Default: False.
+        angle_coder (dict): Config of angle coder.
+    """  # noqa: E501
+
+    def __init__(self,
+                 separate_angle=True,
+                 scale_angle=False,
+                 angle_coder=dict(
+                     type='CSLCoder',
+                     angle_version='le90',
+                     omega=1,
+                     window='gaussian',
+                     radius=6),
+                 **kwargs):
+        self.angle_coder = build_bbox_coder(angle_coder)
+        assert separate_angle, 'CSL only supports separate_angle=True'
+        assert scale_angle is False, 'CSL only supports scale_angle=False'
+        self.coding_len = self.angle_coder.coding_len
+        super().__init__(
+            separate_angle=separate_angle, scale_angle=scale_angle, **kwargs)
+
+    def _init_layers(self):
+        """Initialize layers of the head."""
+        RotatedAnchorFreeHead._init_layers(self)
+        self.conv_centerness = nn.Conv2d(self.feat_channels, 1, 3, padding=1)
+        self.conv_angle = nn.Conv2d(
+            self.feat_channels, self.coding_len, 3, padding=1)
+        self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])
+
+    @force_fp32(
+        apply_to=('cls_scores', 'bbox_preds', 'angle_preds', 'centernesses'))
+    def loss(self,
+             cls_scores,
+             bbox_preds,
+             angle_preds,
+             centernesses,
+             gt_bboxes,
+             gt_labels,
+             img_metas,
+             gt_bboxes_ignore=None):
+        """Compute loss of the head.
+
+        Args:
+            cls_scores (list[Tensor]): Box scores for each scale level,
+                each is a 4D-tensor, the channel number is
+                num_points * num_classes.
+            bbox_preds (list[Tensor]): Box energies / deltas for each scale
+                level, each is a 4D-tensor, the channel number is
+                num_points * 4.
+            angle_preds (list[Tensor]): Box angle for each scale level,
+                each is a 4D-tensor, the channel number is
+                num_points * coding_len.
+            centernesses (list[Tensor]): centerness for each scale level, each
+                is a 4D-tensor, the channel number is num_points * 1.
+            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+                shape (num_gts, 5) in (x, y, w, h, angle) format.
+            gt_labels (list[Tensor]): Class indices corresponding to each box.
+            img_metas (list[dict]): Meta information of each image, e.g.,
+                image size, scaling factor, etc.
+            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
+                boxes can be ignored when computing the loss.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        assert len(cls_scores) == len(bbox_preds) \
+            == len(angle_preds) == len(centernesses)
+        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+        all_level_points = self.prior_generator.grid_priors(
+            featmap_sizes,
+            dtype=bbox_preds[0].dtype,
+            device=bbox_preds[0].device)
+        labels, bbox_targets, angle_targets = self.get_targets(
+            all_level_points, gt_bboxes, gt_labels)
+
+        num_imgs = cls_scores[0].size(0)
+        # flatten cls_scores, bbox_preds and centerness
+        flatten_cls_scores = [
+            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
+            for cls_score in cls_scores
+        ]
+        flatten_bbox_preds = [
+            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
+            for bbox_pred in bbox_preds
+        ]
+        flatten_angle_preds = [
+            angle_pred.permute(0, 2, 3, 1).reshape(-1, self.coding_len)
+            for angle_pred in angle_preds
+        ]
+        flatten_centerness = [
+            centerness.permute(0, 2, 3, 1).reshape(-1)
+            for centerness in centernesses
+        ]
+        flatten_cls_scores = torch.cat(flatten_cls_scores)
+        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
+        flatten_angle_preds = torch.cat(flatten_angle_preds)
+        flatten_centerness = torch.cat(flatten_centerness)
+        flatten_labels = torch.cat(labels)
+        flatten_bbox_targets = torch.cat(bbox_targets)
+        flatten_angle_targets = torch.cat(angle_targets)
+        # repeat points to align with bbox_preds
+        flatten_points = torch.cat(
+            [points.repeat(num_imgs, 1) for points in all_level_points])
+
+        # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
+        bg_class_ind = self.num_classes
+        
pos_inds = ((flatten_labels >= 0) + & (flatten_labels < bg_class_ind)).nonzero().reshape(-1) + num_pos = torch.tensor( + len(pos_inds), dtype=torch.float, device=bbox_preds[0].device) + num_pos = max(reduce_mean(num_pos), 1.0) + loss_cls = self.loss_cls( + flatten_cls_scores, flatten_labels, avg_factor=num_pos) + + pos_bbox_preds = flatten_bbox_preds[pos_inds] + pos_angle_preds = flatten_angle_preds[pos_inds] + pos_centerness = flatten_centerness[pos_inds] + pos_bbox_targets = flatten_bbox_targets[pos_inds] + pos_angle_targets = flatten_angle_targets[pos_inds] + pos_centerness_targets = self.centerness_target(pos_bbox_targets) + # centerness weighted iou loss + centerness_denorm = max( + reduce_mean(pos_centerness_targets.sum().detach()), 1e-6) + + if len(pos_inds) > 0: + pos_points = flatten_points[pos_inds] + if self.seprate_angle: + bbox_coder = self.h_bbox_coder + else: + bbox_coder = self.bbox_coder + pos_bbox_preds = torch.cat([pos_bbox_preds, pos_angle_preds], + dim=-1) + pos_bbox_targets = torch.cat( + [pos_bbox_targets, pos_angle_targets], dim=-1) + pos_decoded_bbox_preds = bbox_coder.decode(pos_points, + pos_bbox_preds) + pos_decoded_target_preds = bbox_coder.decode( + pos_points, pos_bbox_targets) + loss_bbox = self.loss_bbox( + pos_decoded_bbox_preds, + pos_decoded_target_preds, + weight=pos_centerness_targets, + avg_factor=centerness_denorm) + if self.seprate_angle: + loss_angle = self.loss_angle( + pos_angle_preds, pos_angle_targets, avg_factor=num_pos) + loss_centerness = self.loss_centerness( + pos_centerness, pos_centerness_targets, avg_factor=num_pos) + else: + loss_bbox = pos_bbox_preds.sum() + loss_centerness = pos_centerness.sum() + if self.seprate_angle: + loss_angle = pos_angle_preds.sum() + + if self.seprate_angle: + return dict( + loss_cls=loss_cls, + loss_bbox=loss_bbox, + loss_angle=loss_angle, + loss_centerness=loss_centerness) + else: + return dict( + loss_cls=loss_cls, + loss_bbox=loss_bbox, + loss_centerness=loss_centerness) + + def 
_get_target_single(self, gt_bboxes, gt_labels, points, regress_ranges,
+                           num_points_per_lvl):
+        """Compute regression, classification and angle targets for a single
+        image."""
+        num_points = points.size(0)
+        num_gts = gt_labels.size(0)
+        if num_gts == 0:
+            return gt_labels.new_full((num_points,), self.num_classes), \
+                   gt_bboxes.new_zeros((num_points, 4)), \
+                   gt_bboxes.new_zeros((num_points, self.coding_len))
+
+        labels, bbox_targets, angle_targets = \
+            super(CSLRFCOSHead, self)._get_target_single(gt_bboxes,
+                                                         gt_labels,
+                                                         points,
+                                                         regress_ranges,
+                                                         num_points_per_lvl)
+        # encode the scalar angle target into the CSL representation
+        angle_targets = self.angle_coder.encode(angle_targets)
+
+        return labels, bbox_targets, angle_targets
+
+    def _get_bboxes_single(self,
+                           cls_scores,
+                           bbox_preds,
+                           angle_preds,
+                           centernesses,
+                           mlvl_points,
+                           img_shape,
+                           scale_factor,
+                           cfg,
+                           rescale=False):
+        """Transform outputs for a single batch item into bbox predictions.
+
+        Args:
+            cls_scores (list[Tensor]): Box scores for a single scale level.
+                Has shape (num_points * num_classes, H, W).
+            bbox_preds (list[Tensor]): Box energies / deltas for a single
+                scale level with shape (num_points * 4, H, W).
+            angle_preds (list[Tensor]): Box angle for a single scale level
+                with shape (num_points * coding_len, H, W).
+            centernesses (list[Tensor]): Centerness for a single scale level
+                with shape (num_points * 1, H, W).
+            mlvl_points (list[Tensor]): Box reference for a single scale level
+                with shape (num_total_points, 2).
+            img_shape (tuple[int]): Shape of the input image,
+                (height, width, 3).
+            scale_factor (ndarray): Scale factor of the image arranged as
+                (w_scale, h_scale, w_scale, h_scale).
+            cfg (mmcv.Config): Test / postprocessing configuration,
+                if None, test_cfg would be used.
+            rescale (bool): If True, return boxes in original image space.
+
+        Returns:
+            Tensor: Labeled boxes in shape (n, 6), where the first 5 columns
+                are bounding box positions (x, y, w, h, angle) and the
+                6-th column is a score between 0 and 1.
+ """ + cfg = self.test_cfg if cfg is None else cfg + assert len(cls_scores) == len(bbox_preds) == len(mlvl_points) + mlvl_bboxes = [] + mlvl_scores = [] + mlvl_centerness = [] + for cls_score, bbox_pred, angle_pred, centerness, points in zip( + cls_scores, bbox_preds, angle_preds, centernesses, + mlvl_points): + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + scores = cls_score.permute(1, 2, 0).reshape( + -1, self.cls_out_channels).sigmoid() + centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid() + + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) + angle_pred = angle_pred.permute(1, 2, + 0).reshape(-1, self.coding_len) + nms_pre = cfg.get('nms_pre', -1) + if nms_pre > 0 and scores.shape[0] > nms_pre: + max_scores, _ = (scores * centerness[:, None]).max(dim=1) + _, topk_inds = max_scores.topk(nms_pre) + points = points[topk_inds, :] + bbox_pred = bbox_pred[topk_inds, :] + angle_pred = angle_pred[topk_inds, :] + scores = scores[topk_inds, :] + centerness = centerness[topk_inds] + + angle_pred = self.angle_coder.decode(angle_pred).unsqueeze(-1) + bbox_pred = torch.cat([bbox_pred, angle_pred], dim=-1) + bboxes = self.bbox_coder.decode( + points, bbox_pred, max_shape=img_shape) + mlvl_bboxes.append(bboxes) + mlvl_scores.append(scores) + mlvl_centerness.append(centerness) + mlvl_bboxes = torch.cat(mlvl_bboxes) + if rescale: + scale_factor = mlvl_bboxes.new_tensor(scale_factor) + mlvl_bboxes[..., :4] = mlvl_bboxes[..., :4] / scale_factor + mlvl_scores = torch.cat(mlvl_scores) + padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) + # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 + # BG cat_id: num_class + mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) + mlvl_centerness = torch.cat(mlvl_centerness) + det_bboxes, det_labels = multiclass_nms_rotated( + mlvl_bboxes, + mlvl_scores, + cfg.score_thr, + cfg.nms, + cfg.max_per_img, + score_factors=mlvl_centerness) + return det_bboxes, det_labels + + @force_fp32( + 
apply_to=('cls_scores', 'bbox_preds', 'angle_preds', 'centernesses'))
+    def refine_bboxes(self, cls_scores, bbox_preds, angle_preds, centernesses):
+        """This function is used in S2ANet, whose num_anchors=1."""
+        num_levels = len(cls_scores)
+        assert num_levels == len(bbox_preds)
+        num_imgs = cls_scores[0].size(0)
+        for i in range(num_levels):
+            assert num_imgs == cls_scores[i].size(0) == bbox_preds[i].size(0)
+
+        featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
+        mlvl_points = self.prior_generator.grid_priors(featmap_sizes,
+                                                       bbox_preds[0].dtype,
+                                                       bbox_preds[0].device)
+        bboxes_list = [[] for _ in range(num_imgs)]
+
+        for lvl in range(num_levels):
+            bbox_pred = bbox_preds[lvl]
+            angle_pred = angle_preds[lvl]
+            bbox_pred = bbox_pred.permute(0, 2, 3, 1)
+            bbox_pred = bbox_pred.reshape(num_imgs, -1, 4)
+            angle_pred = angle_pred.permute(0, 2, 3, 1)
+            angle_pred = angle_pred.reshape(num_imgs, -1, self.coding_len)
+            angle_pred = self.angle_coder.decode(angle_pred)
+            bbox_pred = torch.cat([bbox_pred, angle_pred], dim=-1)
+
+            points = mlvl_points[lvl]
+
+            for img_id in range(num_imgs):
+                bbox_pred_i = bbox_pred[img_id]
+                decode_bbox_i = self.bbox_coder.decode(points, bbox_pred_i)
+                bboxes_list[img_id].append(decode_bbox_i.detach())
+
+        return bboxes_list
diff --git a/mmrotate/models/dense_heads/rotated_anchor_free_head.py b/mmrotate/models/dense_heads/rotated_anchor_free_head.py
new file mode 100644
index 000000000..ee0138970
--- /dev/null
+++ b/mmrotate/models/dense_heads/rotated_anchor_free_head.py
@@ -0,0 +1,96 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from mmdet.core.anchor.point_generator import MlvlPointGenerator
+from mmdet.models.dense_heads import AnchorFreeHead
+
+from mmrotate.core import build_bbox_coder
+from ..builder import ROTATED_HEADS, build_loss
+
+
+@ROTATED_HEADS.register_module()
+class RotatedAnchorFreeHead(AnchorFreeHead):
+    """Rotated Anchor-free head (Rotated FCOS, etc.).
+
+    Args:
+        num_classes (int): Number of categories excluding the background
+            category.
+        in_channels (int): Number of channels in the input feature map.
+        feat_channels (int): Number of hidden channels. Used in child classes.
+        stacked_convs (int): Number of stacking convs of the head.
+        strides (tuple): Downsample factor of each feature map.
+        dcn_on_last_conv (bool): If true, use dcn in the last layer of
+            towers. Default: False.
+        conv_bias (bool | str): If specified as `auto`, it will be decided by
+            the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
+            None, otherwise False. Default: "auto".
+        loss_cls (dict): Config of classification loss.
+        loss_bbox (dict): Config of localization loss.
+        bbox_coder (dict): Config of bbox coder. Defaults to
+            'DistancePointBBoxCoder'.
+        conv_cfg (dict): Config dict for convolution layer. Default: None.
+        norm_cfg (dict): Config dict for normalization layer. Default: None.
+        train_cfg (dict): Training config of anchor head.
+        test_cfg (dict): Testing config of anchor head.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+ """ # noqa: W605 + + def __init__(self, + num_classes, + in_channels, + feat_channels=256, + stacked_convs=4, + strides=(4, 8, 16, 32, 64), + dcn_on_last_conv=False, + conv_bias='auto', + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='IoULoss', loss_weight=1.0), + bbox_coder=dict(type='DistancePointBBoxCoder'), + conv_cfg=None, + norm_cfg=None, + train_cfg=None, + test_cfg=None, + init_cfg=dict( + type='Normal', + layer='Conv2d', + std=0.01, + override=dict( + type='Normal', + name='conv_cls', + std=0.01, + bias_prob=0.01))): + super(AnchorFreeHead, self).__init__(init_cfg) + self.num_classes = num_classes + self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False) + if self.use_sigmoid_cls: + self.cls_out_channels = num_classes + else: + self.cls_out_channels = num_classes + 1 + self.in_channels = in_channels + self.feat_channels = feat_channels + self.stacked_convs = stacked_convs + self.strides = strides + self.dcn_on_last_conv = dcn_on_last_conv + assert conv_bias == 'auto' or isinstance(conv_bias, bool) + self.conv_bias = conv_bias + self.loss_cls = build_loss(loss_cls) + self.loss_bbox = build_loss(loss_bbox) + self.bbox_coder = build_bbox_coder(bbox_coder) + + self.prior_generator = MlvlPointGenerator(strides) + + # In order to keep a more general interface and be consistent with + # anchor_head. We can think of point like one anchor + self.num_base_priors = self.prior_generator.num_base_priors[0] + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.fp16_enabled = False + + self._init_layers() diff --git a/mmrotate/models/dense_heads/rotated_fcos_head.py b/mmrotate/models/dense_heads/rotated_fcos_head.py new file mode 100644 index 000000000..acf9196c0 --- /dev/null +++ b/mmrotate/models/dense_heads/rotated_fcos_head.py @@ -0,0 +1,667 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import Scale
+from mmcv.runner import force_fp32
+from mmdet.core import multi_apply, reduce_mean
+
+from mmrotate.core import build_bbox_coder, multiclass_nms_rotated
+from ..builder import ROTATED_HEADS, build_loss
+from .rotated_anchor_free_head import RotatedAnchorFreeHead
+
+INF = 1e8
+
+
+@ROTATED_HEADS.register_module()
+class RotatedFCOSHead(RotatedAnchorFreeHead):
+    """Anchor-free head used in `FCOS <https://arxiv.org/abs/1904.01355>`_.
+
+    The FCOS head does not use anchor boxes. Instead, bounding boxes are
+    predicted at each pixel and a centerness measure is used to suppress
+    low-quality predictions. Here norm_on_bbox, centerness_on_reg and
+    dcn_on_last_conv are training tricks used in the official repo, which
+    bring remarkable mAP gains of up to 4.9 points. Please see
+    https://github.com/tianzhi0549/FCOS for more detail.
+
+    Args:
+        num_classes (int): Number of categories excluding the background
+            category.
+        in_channels (int): Number of channels in the input feature map.
+        strides (list[int] | list[tuple[int, int]]): Strides of points
+            in multiple feature levels. Default: (4, 8, 16, 32, 64).
+        regress_ranges (tuple[tuple[int, int]]): Regress range of multiple
+            level points.
+        center_sampling (bool): If true, use center sampling. Default: False.
+        center_sample_radius (float): Radius of center sampling. Default: 1.5.
+        norm_on_bbox (bool): If true, normalize the regression targets
+            with FPN strides. Default: False.
+        centerness_on_reg (bool): If true, position centerness on the
+            regress branch. Please refer to
+            https://github.com/tianzhi0549/FCOS/issues/89#issuecomment-516877042.
+            Default: False.
+        separate_angle (bool): If true, angle prediction is separated from
+            bbox regression loss. Default: False.
+        scale_angle (bool): If true, add scale to angle pred branch.
+            Default: True.
+        h_bbox_coder (dict): Config of horizontal bbox coder, only used when
+            separate_angle is True.
+        conv_bias (bool | str): If specified as `auto`, it will be decided by
+            the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
+            None, otherwise False. Default: "auto".
+        loss_cls (dict): Config of classification loss.
+        loss_bbox (dict): Config of localization loss.
+        loss_angle (dict): Config of angle loss, only used when
+            separate_angle is True.
+        loss_centerness (dict): Config of centerness loss.
+        norm_cfg (dict): Dictionary to construct and config norm layer.
+            Default: norm_cfg=dict(type='GN', num_groups=32,
+            requires_grad=True).
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+
+    Example:
+        >>> self = RotatedFCOSHead(11, 7)
+        >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
+        >>> cls_score, bbox_pred, angle_pred, centerness = self.forward(feats)
+        >>> assert len(cls_score) == len(self.scales)
+    """  # noqa: E501
+
+    def __init__(self,
+                 num_classes,
+                 in_channels,
+                 regress_ranges=((-1, 64), (64, 128), (128, 256), (256, 512),
+                                 (512, INF)),
+                 center_sampling=False,
+                 center_sample_radius=1.5,
+                 norm_on_bbox=False,
+                 centerness_on_reg=False,
+                 separate_angle=False,
+                 scale_angle=True,
+                 h_bbox_coder=dict(type='DistancePointBBoxCoder'),
+                 loss_cls=dict(
+                     type='FocalLoss',
+                     use_sigmoid=True,
+                     gamma=2.0,
+                     alpha=0.25,
+                     loss_weight=1.0),
+                 loss_bbox=dict(type='IoULoss', loss_weight=1.0),
+                 loss_angle=dict(type='L1Loss', loss_weight=1.0),
+                 loss_centerness=dict(
+                     type='CrossEntropyLoss',
+                     use_sigmoid=True,
+                     loss_weight=1.0),
+                 norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
+                 init_cfg=dict(
+                     type='Normal',
+                     layer='Conv2d',
+                     std=0.01,
+                     override=dict(
+                         type='Normal',
+                         name='conv_cls',
+                         std=0.01,
+                         bias_prob=0.01)),
+                 **kwargs):
+        self.regress_ranges = regress_ranges
+        self.center_sampling = center_sampling
+        self.center_sample_radius = center_sample_radius
+        self.norm_on_bbox = norm_on_bbox
+        self.centerness_on_reg = centerness_on_reg
+        self.seprate_angle = separate_angle
+        self.is_scale_angle = 
scale_angle + super().__init__( + num_classes, + in_channels, + loss_cls=loss_cls, + loss_bbox=loss_bbox, + norm_cfg=norm_cfg, + init_cfg=init_cfg, + **kwargs) + self.loss_centerness = build_loss(loss_centerness) + if self.seprate_angle: + self.loss_angle = build_loss(loss_angle) + self.h_bbox_coder = build_bbox_coder(h_bbox_coder) + # Angle predict length + + def _init_layers(self): + """Initialize layers of the head.""" + super()._init_layers() + self.conv_centerness = nn.Conv2d(self.feat_channels, 1, 3, padding=1) + self.conv_angle = nn.Conv2d(self.feat_channels, 1, 3, padding=1) + self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides]) + if self.is_scale_angle: + self.scale_angle = Scale(1.0) + + def forward(self, feats): + """Forward features from the upstream network. + Args: + feats (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + Returns: + tuple: + cls_scores (list[Tensor]): Box scores for each scale level, \ + each is a 4D-tensor, the channel number is \ + num_points * num_classes. + bbox_preds (list[Tensor]): Box energies / deltas for each \ + scale level, each is a 4D-tensor, the channel number is \ + num_points * 4. + angle_preds (list[Tensor]): Box angle for each scale level, \ + each is a 4D-tensor, the channel number is num_points * 1. + centernesses (list[Tensor]): centerness for each scale level, \ + each is a 4D-tensor, the channel number is num_points * 1. + """ + return multi_apply(self.forward_single, feats, self.scales, + self.strides) + + def forward_single(self, x, scale, stride): + """Forward features of a single scale level. + + Args: + x (Tensor): FPN feature maps of the specified stride. + scale (:obj: `mmcv.cnn.Scale`): Learnable scale module to resize + the bbox prediction. + stride (int): The corresponding stride for feature maps, only + used to normalize the bbox prediction when self.norm_on_bbox + is True. 
+
+        Returns:
+            tuple: scores for each class, bbox predictions, angle predictions
+                and centerness predictions of input feature maps.
+        """
+        cls_score, bbox_pred, cls_feat, reg_feat = super().forward_single(x)
+        if self.centerness_on_reg:
+            centerness = self.conv_centerness(reg_feat)
+        else:
+            centerness = self.conv_centerness(cls_feat)
+        # scale the bbox_pred of different level
+        # float to avoid overflow when enabling FP16
+        bbox_pred = scale(bbox_pred).float()
+        if self.norm_on_bbox:
+            # bbox_pred needed for gradient computation has been modified
+            # by F.relu(bbox_pred) when run with PyTorch 1.10. So replace
+            # F.relu(bbox_pred) with bbox_pred.clamp(min=0)
+            bbox_pred = bbox_pred.clamp(min=0)
+            if not self.training:
+                bbox_pred *= stride
+        else:
+            bbox_pred = bbox_pred.exp()
+        angle_pred = self.conv_angle(reg_feat)
+        if self.is_scale_angle:
+            angle_pred = self.scale_angle(angle_pred).float()
+        return cls_score, bbox_pred, angle_pred, centerness
+
+    @force_fp32(
+        apply_to=('cls_scores', 'bbox_preds', 'angle_preds', 'centernesses'))
+    def loss(self,
+             cls_scores,
+             bbox_preds,
+             angle_preds,
+             centernesses,
+             gt_bboxes,
+             gt_labels,
+             img_metas,
+             gt_bboxes_ignore=None):
+        """Compute loss of the head.
+
+        Args:
+            cls_scores (list[Tensor]): Box scores for each scale level,
+                each is a 4D-tensor, the channel number is
+                num_points * num_classes.
+            bbox_preds (list[Tensor]): Box energies / deltas for each scale
+                level, each is a 4D-tensor, the channel number is
+                num_points * 4.
+            angle_preds (list[Tensor]): Box angle for each scale level,
+                each is a 4D-tensor, the channel number is num_points * 1.
+            centernesses (list[Tensor]): centerness for each scale level, each
+                is a 4D-tensor, the channel number is num_points * 1.
+            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+                shape (num_gts, 5) in (x, y, w, h, angle) format.
+ gt_labels (list[Tensor]): class indices corresponding to each box + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + gt_bboxes_ignore (None | list[Tensor]): specify which bounding + boxes can be ignored when computing the loss. + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + assert len(cls_scores) == len(bbox_preds) \ + == len(angle_preds) == len(centernesses) + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + all_level_points = self.prior_generator.grid_priors( + featmap_sizes, + dtype=bbox_preds[0].dtype, + device=bbox_preds[0].device) + labels, bbox_targets, angle_targets = self.get_targets( + all_level_points, gt_bboxes, gt_labels) + + num_imgs = cls_scores[0].size(0) + # flatten cls_scores, bbox_preds and centerness + flatten_cls_scores = [ + cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels) + for cls_score in cls_scores + ] + flatten_bbox_preds = [ + bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4) + for bbox_pred in bbox_preds + ] + flatten_angle_preds = [ + angle_pred.permute(0, 2, 3, 1).reshape(-1, 1) + for angle_pred in angle_preds + ] + flatten_centerness = [ + centerness.permute(0, 2, 3, 1).reshape(-1) + for centerness in centernesses + ] + flatten_cls_scores = torch.cat(flatten_cls_scores) + flatten_bbox_preds = torch.cat(flatten_bbox_preds) + flatten_angle_preds = torch.cat(flatten_angle_preds) + flatten_centerness = torch.cat(flatten_centerness) + flatten_labels = torch.cat(labels) + flatten_bbox_targets = torch.cat(bbox_targets) + flatten_angle_targets = torch.cat(angle_targets) + # repeat points to align with bbox_preds + flatten_points = torch.cat( + [points.repeat(num_imgs, 1) for points in all_level_points]) + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + bg_class_ind = self.num_classes + pos_inds = ((flatten_labels >= 0) + & (flatten_labels < bg_class_ind)).nonzero().reshape(-1) + num_pos = torch.tensor( + len(pos_inds), 
dtype=torch.float, device=bbox_preds[0].device) + num_pos = max(reduce_mean(num_pos), 1.0) + loss_cls = self.loss_cls( + flatten_cls_scores, flatten_labels, avg_factor=num_pos) + + pos_bbox_preds = flatten_bbox_preds[pos_inds] + pos_angle_preds = flatten_angle_preds[pos_inds] + pos_centerness = flatten_centerness[pos_inds] + pos_bbox_targets = flatten_bbox_targets[pos_inds] + pos_angle_targets = flatten_angle_targets[pos_inds] + pos_centerness_targets = self.centerness_target(pos_bbox_targets) + # centerness weighted iou loss + centerness_denorm = max( + reduce_mean(pos_centerness_targets.sum().detach()), 1e-6) + + if len(pos_inds) > 0: + pos_points = flatten_points[pos_inds] + if self.seprate_angle: + bbox_coder = self.h_bbox_coder + else: + bbox_coder = self.bbox_coder + pos_bbox_preds = torch.cat([pos_bbox_preds, pos_angle_preds], + dim=-1) + pos_bbox_targets = torch.cat( + [pos_bbox_targets, pos_angle_targets], dim=-1) + pos_decoded_bbox_preds = bbox_coder.decode(pos_points, + pos_bbox_preds) + pos_decoded_target_preds = bbox_coder.decode( + pos_points, pos_bbox_targets) + loss_bbox = self.loss_bbox( + pos_decoded_bbox_preds, + pos_decoded_target_preds, + weight=pos_centerness_targets, + avg_factor=centerness_denorm) + if self.seprate_angle: + loss_angle = self.loss_angle( + pos_angle_preds, pos_angle_targets, avg_factor=num_pos) + loss_centerness = self.loss_centerness( + pos_centerness, pos_centerness_targets, avg_factor=num_pos) + else: + loss_bbox = pos_bbox_preds.sum() + loss_centerness = pos_centerness.sum() + if self.seprate_angle: + loss_angle = pos_angle_preds.sum() + + if self.seprate_angle: + return dict( + loss_cls=loss_cls, + loss_bbox=loss_bbox, + loss_angle=loss_angle, + loss_centerness=loss_centerness) + else: + return dict( + loss_cls=loss_cls, + loss_bbox=loss_bbox, + loss_centerness=loss_centerness) + + def get_targets(self, points, gt_bboxes_list, gt_labels_list): + """Compute regression, classification and centerness targets for points + in 
multiple images. + + Args: + points (list[Tensor]): Points of each fpn level, each has shape + (num_points, 2). + gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image, + each has shape (num_gt, 4). + gt_labels_list (list[Tensor]): Ground truth labels of each box, + each has shape (num_gt,). + Returns: + tuple: + concat_lvl_labels (list[Tensor]): Labels of each level. \ + concat_lvl_bbox_targets (list[Tensor]): BBox targets of each \ + level. + concat_lvl_angle_targets (list[Tensor]): Angle targets of \ + each level. + """ + assert len(points) == len(self.regress_ranges) + num_levels = len(points) + # expand regress ranges to align with points + expanded_regress_ranges = [ + points[i].new_tensor(self.regress_ranges[i])[None].expand_as( + points[i]) for i in range(num_levels) + ] + # concat all levels points and regress ranges + concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0) + concat_points = torch.cat(points, dim=0) + + # the number of points per img, per lvl + num_points = [center.size(0) for center in points] + + # get labels and bbox_targets of each image + labels_list, bbox_targets_list, angle_targets_list = multi_apply( + self._get_target_single, + gt_bboxes_list, + gt_labels_list, + points=concat_points, + regress_ranges=concat_regress_ranges, + num_points_per_lvl=num_points) + + # split to per img, per level + labels_list = [labels.split(num_points, 0) for labels in labels_list] + bbox_targets_list = [ + bbox_targets.split(num_points, 0) + for bbox_targets in bbox_targets_list + ] + angle_targets_list = [ + angle_targets.split(num_points, 0) + for angle_targets in angle_targets_list + ] + + # concat per level image + concat_lvl_labels = [] + concat_lvl_bbox_targets = [] + concat_lvl_angle_targets = [] + for i in range(num_levels): + concat_lvl_labels.append( + torch.cat([labels[i] for labels in labels_list])) + bbox_targets = torch.cat( + [bbox_targets[i] for bbox_targets in bbox_targets_list]) + angle_targets = torch.cat( + 
[angle_targets[i] for angle_targets in angle_targets_list]) + if self.norm_on_bbox: + bbox_targets = bbox_targets / self.strides[i] + concat_lvl_bbox_targets.append(bbox_targets) + concat_lvl_angle_targets.append(angle_targets) + return (concat_lvl_labels, concat_lvl_bbox_targets, + concat_lvl_angle_targets) + + def _get_target_single(self, gt_bboxes, gt_labels, points, regress_ranges, + num_points_per_lvl): + """Compute regression, classification and angle targets for a single + image.""" + num_points = points.size(0) + num_gts = gt_labels.size(0) + if num_gts == 0: + return gt_labels.new_full((num_points,), self.num_classes), \ + gt_bboxes.new_zeros((num_points, 4)), \ + gt_bboxes.new_zeros((num_points, 1)) + + areas = gt_bboxes[:, 2] * gt_bboxes[:, 3] + # TODO: figure out why these two are different + # areas = areas[None].expand(num_points, num_gts) + areas = areas[None].repeat(num_points, 1) + regress_ranges = regress_ranges[:, None, :].expand( + num_points, num_gts, 2) + points = points[:, None, :].expand(num_points, num_gts, 2) + gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 5) + gt_ctr, gt_wh, gt_angle = torch.split(gt_bboxes, [2, 2, 1], dim=2) + + cos_angle, sin_angle = torch.cos(gt_angle), torch.sin(gt_angle) + rot_matrix = torch.cat([cos_angle, sin_angle, -sin_angle, cos_angle], + dim=-1).reshape(num_points, num_gts, 2, 2) + offset = points - gt_ctr + offset = torch.matmul(rot_matrix, offset[..., None]) + offset = offset.squeeze(-1) + + w, h = gt_wh[..., 0], gt_wh[..., 1] + offset_x, offset_y = offset[..., 0], offset[..., 1] + left = w / 2 + offset_x + right = w / 2 - offset_x + top = h / 2 + offset_y + bottom = h / 2 - offset_y + bbox_targets = torch.stack((left, top, right, bottom), -1) + + # condition1: inside a gt bbox + inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0 + if self.center_sampling: + # condition1: inside a `center bbox` + radius = self.center_sample_radius + stride = offset.new_zeros(offset.shape) + + # project the points on 
current lvl back to the `original` sizes + lvl_begin = 0 + for lvl_idx, num_points_lvl in enumerate(num_points_per_lvl): + lvl_end = lvl_begin + num_points_lvl + stride[lvl_begin:lvl_end] = self.strides[lvl_idx] * radius + lvl_begin = lvl_end + + inside_center_bbox_mask = (abs(offset) < stride).all(dim=-1) + inside_gt_bbox_mask = torch.logical_and(inside_center_bbox_mask, + inside_gt_bbox_mask) + + # condition2: limit the regression range for each location + max_regress_distance = bbox_targets.max(-1)[0] + inside_regress_range = ( + (max_regress_distance >= regress_ranges[..., 0]) + & (max_regress_distance <= regress_ranges[..., 1])) + + # if there is still more than one object for a location, + # we choose the one with the minimal area + areas[inside_gt_bbox_mask == 0] = INF + areas[inside_regress_range == 0] = INF + min_area, min_area_inds = areas.min(dim=1) + + labels = gt_labels[min_area_inds] + labels[min_area == INF] = self.num_classes # set as BG + bbox_targets = bbox_targets[range(num_points), min_area_inds] + angle_targets = gt_angle[range(num_points), min_area_inds] + + return labels, bbox_targets, angle_targets + + def centerness_target(self, pos_bbox_targets): + """Compute centerness targets. + + Args: + pos_bbox_targets (Tensor): BBox targets of positive bboxes in shape + (num_pos, 4) + Returns: + Tensor: Centerness target. 
+ """ + # only calculate pos centerness targets, otherwise there may be nan + left_right = pos_bbox_targets[:, [0, 2]] + top_bottom = pos_bbox_targets[:, [1, 3]] + if len(left_right) == 0: + centerness_targets = left_right[..., 0] + else: + centerness_targets = ( + left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * ( + top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]) + return torch.sqrt(centerness_targets) + + @force_fp32( + apply_to=('cls_scores', 'bbox_preds', 'angle_preds', 'centernesses')) + def get_bboxes(self, + cls_scores, + bbox_preds, + angle_preds, + centernesses, + img_metas, + cfg=None, + rescale=None): + """Transform network output for a batch into bbox predictions. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_points * num_classes, H, W) + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_points * 4, H, W) + angle_preds (list[Tensor]): Box angle for each scale level \ + with shape (N, num_points * 1, H, W) + centernesses (list[Tensor]): Centerness for each scale level with + shape (N, num_points * 1, H, W) + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used + rescale (bool): If True, return boxes in original image space + + Returns: + list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple. + The first item is an (n, 6) tensor, where the first 5 columns + are bounding box positions (x, y, w, h, angle) and the 6-th + column is a score between 0 and 1. The second item is a + (n,) tensor where each item is the predicted class label of the + corresponding box. 
+ """ + assert len(cls_scores) == len(bbox_preds) + num_levels = len(cls_scores) + + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + + mlvl_points = self.prior_generator.grid_priors(featmap_sizes, + bbox_preds[0].dtype, + bbox_preds[0].device) + result_list = [] + for img_id in range(len(img_metas)): + cls_score_list = [ + cls_scores[i][img_id].detach() for i in range(num_levels) + ] + bbox_pred_list = [ + bbox_preds[i][img_id].detach() for i in range(num_levels) + ] + angle_pred_list = [ + angle_preds[i][img_id].detach() for i in range(num_levels) + ] + centerness_pred_list = [ + centernesses[i][img_id].detach() for i in range(num_levels) + ] + img_shape = img_metas[img_id]['img_shape'] + scale_factor = img_metas[img_id]['scale_factor'] + det_bboxes = self._get_bboxes_single(cls_score_list, + bbox_pred_list, + angle_pred_list, + centerness_pred_list, + mlvl_points, img_shape, + scale_factor, cfg, rescale) + result_list.append(det_bboxes) + return result_list + + def _get_bboxes_single(self, + cls_scores, + bbox_preds, + angle_preds, + centernesses, + mlvl_points, + img_shape, + scale_factor, + cfg, + rescale=False): + """Transform outputs for a single batch item into bbox predictions. + + Args: + cls_scores (list[Tensor]): Box scores for a single scale level + Has shape (num_points * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for a single scale + level with shape (num_points * 4, H, W). + angle_preds (list[Tensor]): Box angle for a single scale level \ + with shape (N, num_points * 1, H, W). + centernesses (list[Tensor]): Centerness for a single scale level + with shape (num_points * 1, H, W). + mlvl_points (list[Tensor]): Box reference for a single scale level + with shape (num_total_points, 4). + img_shape (tuple[int]): Shape of the input image, + (height, width, 3). + scale_factor (ndarray): Scale factor of the image arrange as + (w_scale, h_scale, w_scale, h_scale). 
+ cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + + Returns: + Tensor: Labeled boxes in shape (n, 6), where the first 5 columns + are bounding box positions (x, y, w, h, angle) and the + 6-th column is a score between 0 and 1. + """ + cfg = self.test_cfg if cfg is None else cfg + assert len(cls_scores) == len(bbox_preds) == len(mlvl_points) + mlvl_bboxes = [] + mlvl_scores = [] + mlvl_centerness = [] + for cls_score, bbox_pred, angle_pred, centerness, points in zip( + cls_scores, bbox_preds, angle_preds, centernesses, + mlvl_points): + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + scores = cls_score.permute(1, 2, 0).reshape( + -1, self.cls_out_channels).sigmoid() + centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid() + + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) + angle_pred = angle_pred.permute(1, 2, 0).reshape(-1, 1) + bbox_pred = torch.cat([bbox_pred, angle_pred], dim=1) + nms_pre = cfg.get('nms_pre', -1) + if nms_pre > 0 and scores.shape[0] > nms_pre: + max_scores, _ = (scores * centerness[:, None]).max(dim=1) + _, topk_inds = max_scores.topk(nms_pre) + points = points[topk_inds, :] + bbox_pred = bbox_pred[topk_inds, :] + scores = scores[topk_inds, :] + centerness = centerness[topk_inds] + bboxes = self.bbox_coder.decode( + points, bbox_pred, max_shape=img_shape) + mlvl_bboxes.append(bboxes) + mlvl_scores.append(scores) + mlvl_centerness.append(centerness) + mlvl_bboxes = torch.cat(mlvl_bboxes) + if rescale: + scale_factor = mlvl_bboxes.new_tensor(scale_factor) + mlvl_bboxes[..., :4] = mlvl_bboxes[..., :4] / scale_factor + mlvl_scores = torch.cat(mlvl_scores) + padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) + # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 + # BG cat_id: num_class + mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) + mlvl_centerness = torch.cat(mlvl_centerness) + 
det_bboxes, det_labels = multiclass_nms_rotated( + mlvl_bboxes, + mlvl_scores, + cfg.score_thr, + cfg.nms, + cfg.max_per_img, + score_factors=mlvl_centerness) + return det_bboxes, det_labels + + @force_fp32( + apply_to=('cls_scores', 'bbox_preds', 'angle_preds', 'centernesses')) + def refine_bboxes(self, cls_scores, bbox_preds, angle_preds, centernesses): + """This function will be used in S2ANet, whose num_anchors=1.""" + num_levels = len(cls_scores) + assert num_levels == len(bbox_preds) + num_imgs = cls_scores[0].size(0) + for i in range(num_levels): + assert num_imgs == cls_scores[i].size(0) == bbox_preds[i].size(0) + + # device = cls_scores[0].device + featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] + mlvl_points = self.prior_generator.grid_priors(featmap_sizes, + bbox_preds[0].dtype, + bbox_preds[0].device) + bboxes_list = [[] for _ in range(num_imgs)] + + for lvl in range(num_levels): + bbox_pred = bbox_preds[lvl] + angle_pred = angle_preds[lvl] + bbox_pred = bbox_pred.permute(0, 2, 3, 1) + bbox_pred = bbox_pred.reshape(num_imgs, -1, 4) + angle_pred = angle_pred.permute(0, 2, 3, 1) + angle_pred = angle_pred.reshape(num_imgs, -1, 1) + bbox_pred = torch.cat([bbox_pred, angle_pred], dim=-1) + + points = mlvl_points[lvl] + + for img_id in range(num_imgs): + bbox_pred_i = bbox_pred[img_id] + decode_bbox_i = self.bbox_coder.decode(points, bbox_pred_i) + bboxes_list[img_id].append(decode_bbox_i.detach()) + + return bboxes_list diff --git a/mmrotate/models/detectors/__init__.py b/mmrotate/models/detectors/__init__.py index 433e89df5..536f5f3bf 100644 --- a/mmrotate/models/detectors/__init__.py +++ b/mmrotate/models/detectors/__init__.py @@ -6,6 +6,7 @@ from .redet import ReDet from .roi_transformer import RoITransformer from .rotate_faster_rcnn import RotatedFasterRCNN +from .rotated_fcos import RotatedFCOS from .rotated_reppoints import RotatedRepPoints from .rotated_retinanet import RotatedRetinaNet from .s2anet import S2ANet @@ -16,5 +17,5 @@ 
'RotatedRetinaNet', 'RotatedFasterRCNN', 'OrientedRCNN', 'RoITransformer', 'GlidingVertex', 'ReDet', 'R3Det', 'S2ANet', 'RotatedRepPoints', 'RotatedBaseDetector', 'RotatedTwoStageDetector', - 'RotatedSingleStageDetector' + 'RotatedSingleStageDetector', 'RotatedFCOS' ] diff --git a/mmrotate/models/detectors/rotated_fcos.py b/mmrotate/models/detectors/rotated_fcos.py new file mode 100644 index 000000000..e240d86a5 --- /dev/null +++ b/mmrotate/models/detectors/rotated_fcos.py @@ -0,0 +1,22 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..builder import ROTATED_DETECTORS +from .single_stage import RotatedSingleStageDetector + + +@ROTATED_DETECTORS.register_module() +class RotatedFCOS(RotatedSingleStageDetector): + """Implementation of Rotated `FCOS.`__ + + __ https://arxiv.org/abs/1904.01355 + """ + + def __init__(self, + backbone, + neck, + bbox_head, + train_cfg=None, + test_cfg=None, + pretrained=None, + init_cfg=None): + super(RotatedFCOS, self).__init__(backbone, neck, bbox_head, train_cfg, + test_cfg, pretrained, init_cfg) diff --git a/mmrotate/models/losses/__init__.py b/mmrotate/models/losses/__init__.py index 8594630e3..adefd8bf0 100644 --- a/mmrotate/models/losses/__init__.py +++ b/mmrotate/models/losses/__init__.py @@ -4,9 +4,10 @@ from .gaussian_dist_loss_v1 import GDLoss_v1 from .kf_iou_loss import KFLoss from .kld_reppoints_loss import KLDRepPointsLoss +from .rotated_iou_loss import RotatedIoULoss from .smooth_focal_loss import SmoothFocalLoss __all__ = [ 'GDLoss', 'GDLoss_v1', 'KFLoss', 'ConvexGIoULoss', 'BCConvexGIoULoss', - 'KLDRepPointsLoss', 'SmoothFocalLoss' + 'KLDRepPointsLoss', 'SmoothFocalLoss', 'RotatedIoULoss' ] diff --git a/mmrotate/models/losses/rotated_iou_loss.py b/mmrotate/models/losses/rotated_iou_loss.py new file mode 100644 index 000000000..911da617e --- /dev/null +++ b/mmrotate/models/losses/rotated_iou_loss.py @@ -0,0 +1,139 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import warnings + +import torch +import torch.nn as nn +from mmdet.models.losses.utils import weighted_loss + +from ..builder import ROTATED_LOSSES + +try: + from mmcv.ops import diff_iou_rotated_2d +except ImportError: + diff_iou_rotated_2d = None + + +@weighted_loss +def rotated_iou_loss(pred, target, linear=False, mode='log', eps=1e-6): + """Rotated IoU loss. + + Computing the IoU loss between a set of predicted rbboxes and target + rbboxes. + By default, the loss is calculated as the negative log of IoU. + + Args: + pred (torch.Tensor): Predicted bboxes of format (x, y, w, h, angle), + shape (n, 5). + target (torch.Tensor): Corresponding gt bboxes, shape (n, 5). + linear (bool, optional): If True, use linear scale of loss instead of + log scale. Default: False. + mode (str): Loss scaling mode, including "linear", "square", and "log". + Default: 'log' + eps (float): Eps to avoid log(0). + Returns: + torch.Tensor: Loss tensor. + """ + assert mode in ['linear', 'square', 'log'] + if linear: + mode = 'linear' + warnings.warn( + 'DeprecationWarning: Setting "linear=True" in ' + 'rotated_iou_loss is deprecated, please use "mode=`linear`" ' + 'instead.') + + if diff_iou_rotated_2d is None: + raise ImportError('Please install mmcv-full >= 1.5.0.') + + ious = diff_iou_rotated_2d(pred.unsqueeze(0), target.unsqueeze(0)) + ious = ious.squeeze(0).clamp(min=eps) + + if mode == 'linear': + loss = 1 - ious + elif mode == 'square': + loss = 1 - ious**2 + elif mode == 'log': + loss = -ious.log() + else: + raise NotImplementedError + return loss + + +@ROTATED_LOSSES.register_module() +class RotatedIoULoss(nn.Module): + """RotatedIoULoss. + + Computing the IoU loss between a set of predicted rbboxes and + target rbboxes. + Args: + linear (bool): If True, use linear scale of loss instead of the + scale determined by ``mode``. Default: False. + eps (float): Eps to avoid log(0). + reduction (str): Options are "none", "mean" and "sum". + loss_weight (float): Weight of loss. 
+ mode (str): Loss scaling mode, including "linear", "square", and "log". + Default: 'log' + """ + + def __init__(self, + linear=False, + eps=1e-6, + reduction='mean', + loss_weight=1.0, + mode='log'): + super(RotatedIoULoss, self).__init__() + assert mode in ['linear', 'square', 'log'] + if linear: + mode = 'linear' + warnings.warn('DeprecationWarning: Setting "linear=True" in ' + 'RotatedIoULoss is deprecated, please use ' + '"mode=`linear`" instead.') + self.mode = mode + self.linear = linear + self.eps = eps + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): + """Forward function. + + Args: + pred (torch.Tensor): The prediction. + target (torch.Tensor): The learning target of the prediction. + weight (torch.Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Defaults to None. Options are "none", "mean" and "sum". 
+ """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if (weight is not None) and (not torch.any(weight > 0)) and ( + reduction != 'none'): + if pred.dim() == weight.dim() + 1: + weight = weight.unsqueeze(1) + return (pred * weight).sum() # 0 + if weight is not None and weight.dim() > 1: + # TODO: remove this in the future + # reduce the weight of shape (n, 5) to (n,) to match the + # iou_loss of shape (n,) + assert weight.shape == pred.shape + weight = weight.mean(-1) + loss = self.loss_weight * rotated_iou_loss( + pred, + target, + weight, + mode=self.mode, + eps=self.eps, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss diff --git a/tests/test_models/test_loss.py b/tests/test_models/test_loss.py index 049e976fd..6df24953d 100644 --- a/tests/test_models/test_loss.py +++ b/tests/test_models/test_loss.py @@ -1,9 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. +import mmcv import pytest import torch +from mmrotate import digit_version from mmrotate.models.losses import (BCConvexGIoULoss, ConvexGIoULoss, GDLoss, - GDLoss_v1, KFLoss, KLDRepPointsLoss) + GDLoss_v1, KFLoss, KLDRepPointsLoss, + RotatedIoULoss) @pytest.mark.skipif( @@ -129,3 +132,44 @@ def test_kfiou_regression_losses(): targets_decode=targets_decode, avg_factor=10) assert isinstance(loss, torch.Tensor) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') +@pytest.mark.skipif( + digit_version(mmcv.__version__) <= digit_version('1.5.0'), + reason='requires mmcv>=1.5.0') +def test_rotated_iou_losses(): + """Tests convex regression losses.""" + pred = torch.rand((10, 5)).cuda() + target = torch.rand((10, 5)).cuda() + weight = torch.rand((10, )).cuda() + + # Test loss mode + loss = RotatedIoULoss(linear=True)(pred, target) + assert isinstance(loss, torch.Tensor) + + loss = RotatedIoULoss(mode='linear')(pred, target) + assert isinstance(loss, 
torch.Tensor) + + loss = RotatedIoULoss(mode='log')(pred, target) + assert isinstance(loss, torch.Tensor) + + loss = RotatedIoULoss(mode='square')(pred, target) + assert isinstance(loss, torch.Tensor) + + # Test loss forward + loss = RotatedIoULoss()(pred, target) + assert isinstance(loss, torch.Tensor) + + # Test loss forward with weight + loss = RotatedIoULoss()(pred, target, weight) + assert isinstance(loss, torch.Tensor) + + # Test loss forward with reduction_override + loss = RotatedIoULoss()(pred, target, reduction_override='mean') + assert isinstance(loss, torch.Tensor) + + # Test loss forward with avg_factor + loss = RotatedIoULoss()(pred, target, avg_factor=10) + assert isinstance(loss, torch.Tensor)
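
Reviewer's note: the three `mode` options of the new loss differ only in how an IoU value is mapped to a loss value. A minimal pure-Python sketch of that mapping for a single IoU scalar (`iou_to_loss` is a hypothetical helper for illustration; the real `rotated_iou_loss` computes batched IoUs with `diff_iou_rotated_2d` and applies the same scaling to tensors):

```python
import math


def iou_to_loss(iou, mode='log', eps=1e-6):
    """Map an IoU in [0, 1] to a loss value, mirroring the scaling
    modes of rotated_iou_loss. The IoU is clamped at eps so the log
    mode never evaluates log(0)."""
    iou = max(iou, eps)
    if mode == 'linear':
        return 1 - iou          # bounded in [0, 1 - eps]
    elif mode == 'square':
        return 1 - iou ** 2     # softer penalty for IoU close to 1
    elif mode == 'log':
        return -math.log(iou)   # grows without bound as IoU -> 0
    raise NotImplementedError(mode)
```

A perfect match (IoU = 1) gives zero loss in every mode, while the log mode penalizes near-zero overlaps far more heavily than the linear mode, which is why it is the default.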