diff --git a/README.md b/README.md index 64e06cc6e..048206932 100644 --- a/README.md +++ b/README.md @@ -104,6 +104,7 @@ A summary can be found in the [Model Zoo](docs/en/model_zoo.md) page. * [x] [Rotated RepPoints-OBB](configs/rotated_reppoints/README.md) (ICCV'2019) * [x] [RoI Transformer](configs/roi_trans/README.md) (CVPR'2019) * [x] [Gliding Vertex](configs/gliding_vertex/README.md) (TPAMI'2020) +* [x] [CSL](configs/csl/README.md) (ECCV'2020) * [x] [R3Det](configs/r3det/README.md) (AAAI'2021) * [x] [S2A-Net](configs/s2anet/README.md) (TGRS'2021) * [x] [ReDet](configs/redet/README.md) (CVPR'2021) diff --git a/README_zh-CN.md b/README_zh-CN.md index 4e5de4a04..e6cd17f16 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -100,6 +100,7 @@ MMRotate 也提供了其他更详细的教程: * [x] [Rotated RepPoints-OBB](configs/rotated_reppoints/README.md) (ICCV'2019) * [x] [RoI Transformer](configs/roi_trans/README.md) (CVPR'2019) * [x] [Gliding Vertex](configs/gliding_vertex/README.md) (TPAMI'2020) +* [x] [CSL](configs/csl/README.md) (ECCV'2020) * [x] [R3Det](configs/r3det/README.md) (AAAI'2021) * [x] [S2A-Net](configs/s2anet/README.md) (TGRS'2021) * [x] [ReDet](configs/redet/README.md) (CVPR'2021) diff --git a/configs/csl/README.md b/configs/csl/README.md new file mode 100644 index 000000000..972f9dfbb --- /dev/null +++ b/configs/csl/README.md @@ -0,0 +1,43 @@ +# CSL +> [Arbitrary-Oriented Object Detection with Circular Smooth Label](https://link.springer.com/chapter/10.1007/978-3-030-58598-3_40) + + +## Abstract + +
+ +
+
+Arbitrary-oriented object detection has recently attracted increasing attention in vision due to its importance
in aerial imagery, scene text, face detection, etc. In this paper, we show that existing regression-based rotation detectors
suffer from the problem of discontinuous boundaries, which is directly caused by angular periodicity or corner ordering.
By a careful study, we find the root cause is that the ideal predictions are beyond the defined range. We design a
new rotation detection baseline that addresses the boundary problem by transforming angular prediction from a regression
problem into a classification task with little accuracy loss, whereby high-precision angle classification is devised, in
contrast to previous works using coarse granularity in rotation detection. We also propose a circular smooth label (CSL)
technique to handle the periodicity of the angle and increase the error tolerance to adjacent angles. We further
introduce four window functions in CSL and explore the effect of different window radius sizes on detection performance.
Extensive experiments and visual analysis on two large-scale public aerial-image datasets, i.e. DOTA and HRSC2016,
as well as the scene text datasets ICDAR2015 and MLT, show the effectiveness of our approach.
+
+## Results and models
+
+DOTA1.0
+
+| Backbone | mAP | Angle | Window func. 
| Omega | lr schd | Mem (GB) | Inf Time (fps) | Aug | Batch Size | Configs | Download | +|:------------:|:----------:|:-----------:|:-----------:|:-----------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:-------------:| +| ResNet50 (1024,1024,200) | 68.42 | le90 | - | - | 1x | 3.38 | 17.8 | - | 2 | [rotated_retinanet_obb_r50_fpn_1x_dota_le90](./rotated_retinanet_obb_r50_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le90/rotated_retinanet_obb_r50_fpn_1x_dota_le90-c0097bc4.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le90/rotated_retinanet_obb_r50_fpn_1x_dota_le90_20220128_130740.log.json) +| ResNet50 (1024,1024,200) | 68.79 | le90 | - | - | 1x | 2.36 | 25.9 | - | 2 | [rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90](./rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_retinanet/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90-01de71b5.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_retinanet/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90_20220303_183714.log.json) +| ResNet50 (1024,1024,200) | 69.51 | le90 | Gaussian | 4 | 1x | 2.60 | 24.0 | - | 2 | [rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90](./rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90-b4271aed.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90_20220321_010033.log.json) + + +## Citation +``` 
+@inproceedings{yang2020arbitrary,
+    title={Arbitrary-Oriented Object Detection with Circular Smooth Label},
+    author={Yang, Xue and Yan, Junchi},
+    booktitle={European Conference on Computer Vision},
+    pages={677--694},
+    year={2020}
+}
+```
diff --git a/configs/csl/metafile.yml b/configs/csl/metafile.yml
new file mode 100644
index 000000000..46dbcda9c
--- /dev/null
+++ b/configs/csl/metafile.yml
@@ -0,0 +1,27 @@
+Collections:
+- Name: CSL
+  Metadata:
+    Training Data: DOTAv1.0
+    Training Techniques:
+    - SGD with Momentum
+    - Weight Decay
+    Training Resources: 1x Quadro RTX 8000
+    Architecture:
+    - ResNet
+  Paper:
+    URL: https://link.springer.com/chapter/10.1007/978-3-030-58598-3_40
+    Title: 'Arbitrary-Oriented Object Detection with Circular Smooth Label'
+  README: configs/csl/README.md
+
+Models:
+  - Name: rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90
+    In Collection: CSL
+    Config: configs/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90.py
+    Metadata:
+      Training Data: DOTAv1.0
+    Results:
+      - Task: Oriented Object Detection
+        Dataset: DOTAv1.0
+        Metrics:
+          mAP: 69.51
+    Weights: https://download.openmmlab.com/mmrotate/v0.1.0/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90-b4271aed.pth
diff --git a/configs/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90.py b/configs/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90.py
new file mode 100644
index 000000000..f7e8d06a8
--- /dev/null
+++ b/configs/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90.py
@@ -0,0 +1,22 @@
+_base_ = \
+    ['../rotated_retinanet/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90.py']
+
+angle_version = 'le90'
+model = dict(
+    bbox_head=dict(
+        type='CSLRRetinaHead',
+        angle_coder=dict(
+            type='CSLCoder',
+            angle_version=angle_version,
+            omega=4,
+            window='gaussian',
+            radius=3),
+        loss_cls=dict(
+            type='FocalLoss',
+            use_sigmoid=True,
+            gamma=2.0, 
+ alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0), + loss_angle=dict( + type='SmoothFocalLoss', gamma=2.0, alpha=0.25, loss_weight=0.8))) diff --git a/configs/redet/redet_re50_refpn_1x_dota_ms_rr_le90.py b/configs/redet/redet_re50_refpn_1x_dota_ms_rr_le90.py index 6b12f69c6..a033006ba 100644 --- a/configs/redet/redet_re50_refpn_1x_dota_ms_rr_le90.py +++ b/configs/redet/redet_re50_refpn_1x_dota_ms_rr_le90.py @@ -1,6 +1,6 @@ -_base_ = ['./redet_re50_fpn_1x_dota_le90.py'] +_base_ = ['./redet_re50_refpn_1x_dota_le90.py'] -data_root = '/cluster/home/it_stu198/main/datasets/split_ms_dota1_0/' +data_root = 'datasets/split_ms_dota1_0/' angle_version = 'le90' img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) diff --git a/docs/en/model_zoo.md b/docs/en/model_zoo.md index 55f08ad41..4a4612f25 100644 --- a/docs/en/model_zoo.md +++ b/docs/en/model_zoo.md @@ -5,6 +5,7 @@ - [Rotated RepPoints-OBB](../../configs/rotated_reppoints/README.md) (ICCV'2019) - [RoI Transformer](../../configs/roi_trans/README.md) (CVPR'2019) - [Gliding Vertex](../../configs/gliding_vertex/README.md) (TPAMI'2020) +- [CSL](../../configs/csl/README.md) (ECCV'2020) - [R3Det](../../configs/r3det/README.md) (AAAI'2021) - [S2A-Net](../../configs/s2anet/README.md) (TGRS'2021) - [ReDet](../../configs/redet/README.md) (CVPR'2021) @@ -26,6 +27,7 @@ | ResNet50 (1024,1024,200) | 68.42 | le90 | 1x | 3.38 | 16.9 | - | 2 | [rotated_retinanet_obb_r50_fpn_1x_dota_le90](../../configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le90/rotated_retinanet_obb_r50_fpn_1x_dota_le90-c0097bc4.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le90/rotated_retinanet_obb_r50_fpn_1x_dota_le90_20220128_130740.log.json) | ResNet50 (1024,1024,200) | 68.79 | le90 | 1x 
| 2.36 | 22.4 | - | 2 | [rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90](../../configs/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_retinanet/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90-01de71b5.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_retinanet/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90_20220303_183714.log.json) | ResNet50 (1024,1024,200) | 69.49 | le135 | 1x | 4.05 | 8.6 | - | 2 | [g_reppoints_r50_fpn_1x_dota_le135](../../configs/g_reppoints/g_reppoints_r50_fpn_1x_dota_le135.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/g_reppoints/g_reppoints_r50_fpn_1x_dota_le135/g_reppoints_r50_fpn_1x_dota_le135-b840eed7.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/g_reppoints/g_reppoints_r50_fpn_1x_dota_le135/g_reppoints_r50_fpn_1x_dota_le135_20220202_233631.log.json) +| ResNet50 (1024,1024,200) | 69.51 | le90 | 1x | 4.40 | 24.0 | - | 2 | [rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90](../../configs/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90-b4271aed.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90_20220321_010033.log.json) | ResNet50 (1024,1024,200) | 69.55 | oc | 1x | 3.39 | 15.5 | - | 2 | [rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc](../../configs/gwd/rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/gwd/rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc/rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc-41fd7805.pth) | 
[log](https://download.openmmlab.com/mmrotate/v0.1.0/gwd/rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc/rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc_20220120_152421.log.json) | ResNet50 (1024,1024,200) | 69.60 | le90 | 1x | 3.38 | 15.1 | - | 2 | [rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90](../../configs/kfiou/rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/kfiou/rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90/rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90-03e02f75.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/kfiou/rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90/rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90_20220209_173225.log.json) | ResNet50 (1024,1024,200) | 69.63 | le135 | 1x | 3.45 | 16.1 | - | 2 | [cfa_r50_fpn_1x_dota_le135](../../configs/cfa/cfa_r50_fpn_1x_dota_le135.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/cfa/cfa_r50_fpn_1x_dota_le135/cfa_r50_fpn_1x_dota_le135-aed1cbc6.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/cfa/cfa_r50_fpn_1x_dota_le135/cfa_r50_fpn_1x_dota_le135_20220205_144859.log.json) @@ -57,4 +59,4 @@ - `MS` means multiple scale image split. - `RR` means random rotation. -The above models are trained with 1 * 1080Ti and inferred with 1 * 2080Ti. +The above models are trained with 1 * 1080Ti/2080Ti and inferred with 1 * 2080Ti. 
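The `Window func.` and `Omega` columns in the CSL rows added above describe how the angle is binned and smoothed before classification. As a rough illustration of the Gaussian circular smooth label that the new `CSLCoder` builds, here is a hedged plain-Python sketch (the function name `gaussian_csl` and its defaults are invented for illustration; this is not the mmrotate implementation):

```python
import math

def gaussian_csl(angle_bin, coding_len=180, radius=6.0):
    """Illustrative Gaussian circular smooth label for one target bin.

    angle_bin: index of the target angle bin in [0, coding_len).
    Returns coding_len soft labels that wrap around the circle, so bins
    on either side of the 0/180 boundary still receive high values.
    """
    label = [0.0] * coding_len
    for i in range(coding_len):
        # circular distance between bin i and the target bin
        d = min(abs(i - angle_bin), coding_len - abs(i - angle_bin))
        label[i] = math.exp(-d * d / (2 * radius ** 2))
    return label
```

With `omega = 1` there are 180 one-degree bins; the target bin gets weight 1.0 and neighbouring bins decay smoothly, which is what gives CSL its tolerance to small angle errors compared with a one-hot label.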
diff --git a/docs/zh_cn/model_zoo.md b/docs/zh_cn/model_zoo.md index 4ebf2cc0d..314918cf7 100644 --- a/docs/zh_cn/model_zoo.md +++ b/docs/zh_cn/model_zoo.md @@ -5,6 +5,7 @@ - [Rotated RepPoints-OBB](../../configs/rotated_reppoints/README.md) (ICCV'2019) - [RoI Transformer](../../configs/roi_trans/README.md) (CVPR'2019) - [Gliding Vertex](../../configs/gliding_vertex/README.md) (TPAMI'2020) +- [CSL](../../configs/csl/README.md) (ECCV'2020) - [R3Det](../../configs/r3det/README.md) (AAAI'2021) - [S2A-Net](../../configs/s2anet/README.md) (TGRS'2021) - [ReDet](../../configs/redet/README.md) (CVPR'2021) @@ -26,6 +27,7 @@ | ResNet50 (1024,1024,200) | 68.42 | le90 | 1x | 3.38 | 16.9 | - | 2 | [rotated_retinanet_obb_r50_fpn_1x_dota_le90](../../configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le90/rotated_retinanet_obb_r50_fpn_1x_dota_le90-c0097bc4.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le90/rotated_retinanet_obb_r50_fpn_1x_dota_le90_20220128_130740.log.json) | ResNet50 (1024,1024,200) | 69.49 | le135 | 1x | 4.05 | 8.6 | - | 2 | [g_reppoints_r50_fpn_1x_dota_le135](../../configs/g_reppoints/g_reppoints_r50_fpn_1x_dota_le135.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/g_reppoints/g_reppoints_r50_fpn_1x_dota_le135/g_reppoints_r50_fpn_1x_dota_le135-b840eed7.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/g_reppoints/g_reppoints_r50_fpn_1x_dota_le135/g_reppoints_r50_fpn_1x_dota_le135_20220202_233631.log.json) | ResNet50 (1024,1024,200) | 68.79 | le90 | 1x | 2.36 | 22.4 | - | 2 | [rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90](../../configs/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90.py) | 
[model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_retinanet/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90-01de71b5.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_retinanet/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90_20220303_183714.log.json) +| ResNet50 (1024,1024,200) | 69.51 | le90 | 1x | 4.40 | 24.0 | - | 2 | [rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90](../../configs/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90-b4271aed.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90_20220321_010033.log.json) | ResNet50 (1024,1024,200) | 69.55 | oc | 1x | 3.39 | 15.5 | - | 2 | [rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc](../../configs/gwd/rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/gwd/rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc/rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc-41fd7805.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/gwd/rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc/rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc_20220120_152421.log.json) | ResNet50 (1024,1024,200) | 69.60 | le90 | 1x | 3.38 | 15.1 | - | 2 | [rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90](../../configs/kfiou/rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/kfiou/rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90/rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90-03e02f75.pth) | 
[log](https://download.openmmlab.com/mmrotate/v0.1.0/kfiou/rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90/rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90_20220209_173225.log.json) | ResNet50 (1024,1024,200) | 69.63 | le135 | 1x | 3.45 | 16.1 | - | 2 | [cfa_r50_fpn_1x_dota_le135](../../configs/cfa/cfa_r50_fpn_1x_dota_le135.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/cfa/cfa_r50_fpn_1x_dota_le135/cfa_r50_fpn_1x_dota_le135-aed1cbc6.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/cfa/cfa_r50_fpn_1x_dota_le135/cfa_r50_fpn_1x_dota_le135_20220205_144859.log.json) @@ -57,4 +59,4 @@ - `MS` 表示多尺度图像增强。 - `RR` 表示随机旋转增强。 -上述模型都是使用 1 * 1080ti 训练得到的,并且在 1 * 2080ti 上进行推理测试。 +上述模型都是使用 1 * 1080ti/2080ti 训练得到的,并且在 1 * 2080ti 上进行推理测试。 diff --git a/mmrotate/core/anchor/anchor_generator.py b/mmrotate/core/anchor/anchor_generator.py index dec8f31dc..bdb378368 100644 --- a/mmrotate/core/anchor/anchor_generator.py +++ b/mmrotate/core/anchor/anchor_generator.py @@ -8,7 +8,10 @@ @ROTATED_ANCHOR_GENERATORS.register_module() class RotatedAnchorGenerator(AnchorGenerator): - """Standard rotate anchor generator for 2D anchor-based detectors.""" + """Fake rotate anchor generator for 2D anchor-based detectors. + + Horizontal bounding box represented by (x,y,w,h,theta). + """ def single_level_grid_priors(self, featmap_size, @@ -34,6 +37,11 @@ def single_level_grid_priors(self, anchors = super(RotatedAnchorGenerator, self).single_level_grid_priors( featmap_size, level_idx, dtype=dtype, device=device) + # The correct usage is: + # from ..bbox.transforms import hbb2obb + # anchors = hbb2obb(anchors, self.angle_version) + # instead of rudely setting the angle to all 0. + # However, the experiment shows that the performance has decreased. 
num_anchors = anchors.size(0) xy = (anchors[:, 2:] + anchors[:, :2]) / 2 wh = anchors[:, 2:] - anchors[:, :2] diff --git a/mmrotate/core/bbox/coder/__init__.py b/mmrotate/core/bbox/coder/__init__.py index 24ed61710..cfd7fb1a1 100644 --- a/mmrotate/core/bbox/coder/__init__.py +++ b/mmrotate/core/bbox/coder/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .angle_coder import CSLCoder from .delta_midpointoffset_rbbox_coder import MidpointOffsetCoder from .delta_xywha_hbbox_coder import DeltaXYWHAHBBoxCoder from .delta_xywha_rbbox_coder import DeltaXYWHAOBBoxCoder @@ -6,5 +7,5 @@ __all__ = [ 'DeltaXYWHAOBBoxCoder', 'DeltaXYWHAHBBoxCoder', 'MidpointOffsetCoder', - 'GVFixCoder', 'GVRatioCoder' + 'GVFixCoder', 'GVRatioCoder', 'CSLCoder' ] diff --git a/mmrotate/core/bbox/coder/angle_coder.py b/mmrotate/core/bbox/coder/angle_coder.py new file mode 100644 index 000000000..b84000a5e --- /dev/null +++ b/mmrotate/core/bbox/coder/angle_coder.py @@ -0,0 +1,114 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +from mmdet.core.bbox.coder.base_bbox_coder import BaseBBoxCoder + +from ..builder import ROTATED_BBOX_CODERS + + +@ROTATED_BBOX_CODERS.register_module() +class CSLCoder(BaseBBoxCoder): + """Circular Smooth Label Coder. + + `Circular Smooth Label (CSL) + `_ . + + Args: + angle_version (str): Angle definition. + omega (float, optional): Angle discretization granularity. + Default: 1. + window (str, optional): Window function. Default: gaussian. + radius (int/float): window radius, int type for + ['triangle', 'rect', 'pulse'], float type for + ['gaussian']. Default: 6. 
+    """
+
+    def __init__(self, angle_version, omega=1, window='gaussian', radius=6):
+        super().__init__()
+        self.angle_version = angle_version
+        assert angle_version in ['oc', 'le90', 'le135']
+        assert window in ['gaussian', 'triangle', 'rect', 'pulse']
+        self.angle_range = 90 if angle_version == 'oc' else 180
+        self.angle_offset_dict = {'oc': 0, 'le90': 90, 'le135': 45}
+        self.angle_offset = self.angle_offset_dict[angle_version]
+        self.omega = omega
+        self.window = window
+        self.radius = radius
+        self.coding_len = int(self.angle_range // omega)
+
+    def encode(self, angle_targets):
+        """Circular Smooth Label Encoder.
+
+        Args:
+            angle_targets (Tensor): Angle offset for each scale level.
+                Has shape (num_anchors * H * W, 1)
+
+        Returns:
+            Tensor: The CSL encoding of angle offset for each
+                scale level. Has shape (num_anchors * H * W, coding_len)
+        """
+
+        # radians to degrees
+        angle_targets_deg = angle_targets * (180 / math.pi)
+        # empty label
+        smooth_label = torch.zeros_like(angle_targets).repeat(
+            1, self.coding_len)
+        angle_targets_deg = (angle_targets_deg +
+                             self.angle_offset) / self.omega
+        # Float to Int
+        angle_targets_long = angle_targets_deg.long()
+
+        if self.window == 'pulse':
+            radius_range = angle_targets_long % self.coding_len
+            smooth_value = 1.0
+        elif self.window == 'rect':
+            base_radius_range = torch.arange(
+                -self.radius, self.radius, device=angle_targets_long.device)
+            radius_range = (base_radius_range +
+                            angle_targets_long) % self.coding_len
+            smooth_value = 1.0
+        elif self.window == 'triangle':
+            base_radius_range = torch.arange(
+                -self.radius, self.radius, device=angle_targets_long.device)
+            radius_range = (base_radius_range +
+                            angle_targets_long) % self.coding_len
+            smooth_value = 1.0 - torch.abs(
+                (1 / self.radius) * base_radius_range)
+
+        elif self.window == 'gaussian':
+            base_radius_range = torch.arange(
+                -self.angle_range // 2,
+                self.angle_range // 2,
+                device=angle_targets_long.device)
+
+            radius_range = (base_radius_range + 
+                            angle_targets_long) % self.coding_len
+            smooth_value = torch.exp(-torch.pow(base_radius_range, 2) /
+                                     (2 * self.radius**2))
+
+        else:
+            raise NotImplementedError
+
+        if isinstance(smooth_value, torch.Tensor):
+            smooth_value = smooth_value.unsqueeze(0).repeat(
+                smooth_label.size(0), 1)
+
+        return smooth_label.scatter(1, radius_range, smooth_value)
+
+    def decode(self, angle_preds):
+        """Circular Smooth Label Decoder.
+
+        Args:
+            angle_preds (Tensor): The CSL encoding of angle offset
+                for each scale level.
+                Has shape (num_anchors * H * W, coding_len)
+
+        Returns:
+            Tensor: Angle offset for each scale level.
+                Has shape (num_anchors * H * W,)
+        """
+        angle_cls_inds = torch.argmax(angle_preds, dim=1)
+        angle_pred = ((angle_cls_inds + 0.5) *
+                      self.omega) % self.angle_range - self.angle_offset
+        return angle_pred * (math.pi / 180)
diff --git a/mmrotate/core/bbox/transforms.py b/mmrotate/core/bbox/transforms.py
index 239eb2e7a..21bdfd3dc 100644
--- a/mmrotate/core/bbox/transforms.py
+++ b/mmrotate/core/bbox/transforms.py
@@ -573,10 +573,10 @@ def hbb2obb_oc(hbboxes):
     Returns:
         obbs (torch.Tensor): [x_ctr,y_ctr,w,h,angle]
     """
-    x = (hbboxes[:, 0::4] + hbboxes[:, 2::4]) * 0.5
-    y = (hbboxes[:, 1::4] + hbboxes[:, 3::4]) * 0.5
-    w = hbboxes[:, 2::4] - hbboxes[:, 0::4]
-    h = hbboxes[:, 3::4] - hbboxes[:, 1::4]
+    x = (hbboxes[..., 0] + hbboxes[..., 2]) * 0.5
+    y = (hbboxes[..., 1] + hbboxes[..., 3]) * 0.5
+    w = hbboxes[..., 2] - hbboxes[..., 0]
+    h = hbboxes[..., 3] - hbboxes[..., 1]
     theta = x.new_zeros(*x.shape)
     rbboxes = torch.stack([x, y, h, w, theta + np.pi / 2], dim=-1)
    return rbboxes
diff --git a/mmrotate/datasets/pipelines/transforms.py b/mmrotate/datasets/pipelines/transforms.py
index b04484f47..3a4510844 100644
--- a/mmrotate/datasets/pipelines/transforms.py
+++ b/mmrotate/datasets/pipelines/transforms.py
@@ -84,7 +84,7 @@ def bbox_flip(self, bboxes, img_shape, direction):
         if self.version == 'oc':
             rotated_flag = (bboxes[:, 4] != np.pi / 2) 
            flipped[rotated_flag, 4] = np.pi / 2 - bboxes[rotated_flag, 4]
-            flipped[rotated_flag, 2] = bboxes[rotated_flag, 3],
+            flipped[rotated_flag, 2] = bboxes[rotated_flag, 3]
             flipped[rotated_flag, 3] = bboxes[rotated_flag, 2]
         else:
             flipped[:, 4] = norm_angle(np.pi - bboxes[:, 4], self.version)
diff --git a/mmrotate/models/dense_heads/__init__.py b/mmrotate/models/dense_heads/__init__.py
index 3d28b1afa..0adcbdb8a 100644
--- a/mmrotate/models/dense_heads/__init__.py
+++ b/mmrotate/models/dense_heads/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .csl_rotated_retina_head import CSLRRetinaHead
 from .kfiou_odm_refine_head import KFIoUODMRefineHead
 from .kfiou_rotate_retina_head import KFIoURRetinaHead
 from .kfiou_rotate_retina_refine_head import KFIoURRetinaRefineHead
@@ -15,5 +16,5 @@
     'RotatedAnchorHead', 'RotatedRetinaHead', 'RotatedRPNHead',
     'OrientedRPNHead', 'RotatedRetinaRefineHead', 'ODMRefineHead',
     'KFIoURRetinaHead', 'KFIoURRetinaRefineHead', 'KFIoUODMRefineHead',
-    'RotatedRepPointsHead', 'SAMRepPointsHead'
+    'RotatedRepPointsHead', 'SAMRepPointsHead', 'CSLRRetinaHead'
 ]
diff --git a/mmrotate/models/dense_heads/csl_rotated_retina_head.py b/mmrotate/models/dense_heads/csl_rotated_retina_head.py
new file mode 100644
index 000000000..2cc19fc88
--- /dev/null
+++ b/mmrotate/models/dense_heads/csl_rotated_retina_head.py
@@ -0,0 +1,579 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import torch
+import torch.nn as nn
+from mmcv.runner import force_fp32
+from mmdet.core import images_to_levels, multi_apply, unmap
+
+from mmrotate.core import build_bbox_coder, multiclass_nms_rotated
+from ... import obb2hbb, rotated_anchor_inside_flags
+from ..builder import ROTATED_HEADS, build_loss
+from .rotated_retina_head import RotatedRetinaHead
+
+
+@ROTATED_HEADS.register_module()
+class CSLRRetinaHead(RotatedRetinaHead):
+    """Rotated retina head with Circular Smooth Label (CSL). 
+ + Args: + use_encoded_angle (bool): Decide whether to use encoded angle or + gt angle as target. Default: True. + shield_reg_angle (bool): Decide whether to shield the angle loss from + reg branch. Default: False. + angle_coder (dict): Config of angle coder. + loss_angle (dict): Config of angle classification loss. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ # noqa: W605 + + def __init__(self, + use_encoded_angle=True, + shield_reg_angle=False, + angle_coder=dict( + type='CSLCoder', + angle_version='le90', + omega=1, + window='gaussian', + radius=6), + loss_angle=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=1.0), + init_cfg=dict( + type='Normal', + layer='Conv2d', + std=0.01, + override=[ + dict( + type='Normal', + name='retina_cls', + std=0.01, + bias_prob=0.01), + dict( + type='Normal', + name='retina_angle_cls', + std=0.01, + bias_prob=0.01), + ]), + **kwargs): + self.angle_coder = build_bbox_coder(angle_coder) + self.coding_len = self.angle_coder.coding_len + super(CSLRRetinaHead, self).__init__(**kwargs, init_cfg=init_cfg) + self.shield_reg_angle = shield_reg_angle + self.loss_angle = build_loss(loss_angle) + self.use_encoded_angle = use_encoded_angle + + def _init_layers(self): + """Initialize layers of the head.""" + super(CSLRRetinaHead, self)._init_layers() + self.retina_angle_cls = nn.Conv2d( + self.feat_channels, + self.num_anchors * self.coding_len, + 3, + padding=1) + + def forward_single(self, x): + """Forward feature of a single scale level. + + Args: + x (torch.Tensor): Features of a single scale level. + + Returns: + tuple: + cls_score (torch.Tensor): Cls scores for a single scale level + the channels number is num_anchors * num_classes. + bbox_pred (torch.Tensor): Box energies / deltas for a single + scale level, the channels number is num_anchors * 5. + angle_cls (torch.Tensor): Angle for a single scale level + the channels number is num_anchors * coding_len. 
+        """
+        cls_feat = x
+        reg_feat = x
+        for cls_conv in self.cls_convs:
+            cls_feat = cls_conv(cls_feat)
+        for reg_conv in self.reg_convs:
+            reg_feat = reg_conv(reg_feat)
+        cls_score = self.retina_cls(cls_feat)
+        bbox_pred = self.retina_reg(reg_feat)
+        angle_cls = self.retina_angle_cls(reg_feat)
+        return cls_score, bbox_pred, angle_cls
+
+    def loss_single(self, cls_score, bbox_pred, angle_cls, anchors, labels,
+                    label_weights, bbox_targets, bbox_weights, angle_targets,
+                    angle_weights, num_total_samples):
+        """Compute loss of a single scale level.
+
+        Args:
+            cls_score (torch.Tensor): Box scores for each scale level.
+                Has shape (N, num_anchors * num_classes, H, W).
+            bbox_pred (torch.Tensor): Box energies / deltas for each scale
+                level with shape (N, num_anchors * 5, H, W).
+            angle_cls (torch.Tensor): Angle classification scores for each
+                scale level with shape (N, num_anchors * coding_len, H, W).
+            anchors (torch.Tensor): Box reference for each scale level with
+                shape (N, num_total_anchors, 5).
+            labels (torch.Tensor): Labels of each anchors with shape
+                (N, num_total_anchors).
+            label_weights (torch.Tensor): Label weights of each anchor with
+                shape (N, num_total_anchors)
+            bbox_targets (torch.Tensor): BBox regression targets of each
+                anchor with shape (N, num_total_anchors, 5).
+            bbox_weights (torch.Tensor): BBox regression loss weights of each
+                anchor with shape (N, num_total_anchors, 5).
+            angle_targets (torch.Tensor): Angle classification targets of
+                each anchor with shape (N, num_total_anchors, coding_len).
+            angle_weights (torch.Tensor): Angle classification loss weights
+                of each anchor with shape (N, num_total_anchors, 1).
+            num_total_samples (int): If sampling, num total samples equal to
+                the number of total anchors; Otherwise, it is the number of
+                positive anchors.
+
+        Returns:
+            loss_cls (torch.Tensor): cls. loss for each scale level.
+            loss_bbox (torch.Tensor): reg. loss for each scale level.
+            loss_angle (torch.Tensor): angle cls. loss for each scale level. 
+ """ + # Classification loss + labels = labels.reshape(-1) + label_weights = label_weights.reshape(-1) + cls_score = cls_score.permute(0, 2, 3, + 1).reshape(-1, self.cls_out_channels) + loss_cls = self.loss_cls( + cls_score, labels, label_weights, avg_factor=num_total_samples) + # Regression loss + bbox_targets = bbox_targets.reshape(-1, 5) + bbox_weights = bbox_weights.reshape(-1, 5) + # Shield angle in reg. branch + if self.shield_reg_angle: + bbox_weights[:, -1] = 0. + bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 5) + if self.reg_decoded_bbox: + anchors = anchors.reshape(-1, 5) + bbox_pred = self.bbox_coder.decode(anchors, bbox_pred) + + loss_bbox = self.loss_bbox( + bbox_pred, + bbox_targets, + bbox_weights, + avg_factor=num_total_samples) + + angle_cls = angle_cls.permute(0, 2, 3, 1).reshape(-1, self.coding_len) + angle_targets = angle_targets.reshape(-1, self.coding_len) + angle_weights = angle_weights.reshape(-1, 1) + + loss_angle = self.loss_angle( + angle_cls, + angle_targets, + weight=angle_weights, + avg_factor=num_total_samples) + + return loss_cls, loss_bbox, loss_angle + + @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'angle_clses')) + def loss(self, + cls_scores, + bbox_preds, + angle_clses, + gt_bboxes, + gt_labels, + img_metas, + gt_bboxes_ignore=None): + """Compute losses of the head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 5, H, W) + angle_clses (list[Tensor]): Box angles for each scale + level with shape (N, num_anchors * coding_len, H, W) + gt_bboxes (list[Tensor]): Ground truth bboxes for each image with + shape (num_gts, 5) in [cx, cy, w, h, a] format. + gt_labels (list[Tensor]): class indices corresponding to each box + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. 
+            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+                boxes can be ignored when computing the loss. Default: None
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+        assert len(featmap_sizes) == self.anchor_generator.num_levels
+
+        device = cls_scores[0].device
+
+        anchor_list, valid_flag_list = self.get_anchors(
+            featmap_sizes, img_metas, device=device)
+        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+        cls_reg_targets = self.get_targets(
+            anchor_list,
+            valid_flag_list,
+            gt_bboxes,
+            img_metas,
+            gt_bboxes_ignore_list=gt_bboxes_ignore,
+            gt_labels_list=gt_labels,
+            label_channels=label_channels)
+        if cls_reg_targets is None:
+            return None
+        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+         num_total_pos, num_total_neg, angle_target_list,
+         angle_weight_list) = cls_reg_targets
+        num_total_samples = (
+            num_total_pos + num_total_neg if self.sampling else num_total_pos)
+
+        # Anchor number of multi levels
+        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+        # Concat all level anchors and flags to a single tensor
+        concat_anchor_list = []
+        for i, _ in enumerate(anchor_list):
+            concat_anchor_list.append(torch.cat(anchor_list[i]))
+        all_anchor_list = images_to_levels(concat_anchor_list,
+                                           num_level_anchors)
+
+        losses_cls, losses_bbox, losses_angle = multi_apply(
+            self.loss_single,
+            cls_scores,
+            bbox_preds,
+            angle_clses,
+            all_anchor_list,
+            labels_list,
+            label_weights_list,
+            bbox_targets_list,
+            bbox_weights_list,
+            angle_target_list,
+            angle_weight_list,
+            num_total_samples=num_total_samples)
+        return dict(
+            loss_cls=losses_cls,
+            loss_bbox=losses_bbox,
+            loss_angle=losses_angle)
+
+    def _get_targets_single(self,
+                            flat_anchors,
+                            valid_flags,
+                            gt_bboxes,
+                            gt_bboxes_ignore,
+                            gt_labels,
+                            img_meta,
+                            label_channels=1,
+                            unmap_outputs=True):
+        """Compute regression and classification 
targets for anchors in a
+        single image.
+
+        Args:
+            flat_anchors (torch.Tensor): Multi-level anchors of the image,
+                which are concatenated into a single tensor of shape
+                (num_anchors, 5).
+            valid_flags (torch.Tensor): Multi level valid flags of the image,
+                which are concatenated into a single tensor of
+                shape (num_anchors,).
+            gt_bboxes (torch.Tensor): Ground truth bboxes of the image,
+                shape (num_gts, 5).
+            gt_bboxes_ignore (torch.Tensor): Ground truth bboxes to be
+                ignored, shape (num_ignored_gts, 5).
+            gt_labels (torch.Tensor): Ground truth labels of each box,
+                shape (num_gts,).
+            img_meta (dict): Meta info of the image.
+            label_channels (int): Channel of label. Default: 1.
+            unmap_outputs (bool): Whether to map outputs back to the original
+                set of anchors. Default: True.
+
+        Returns:
+            tuple:
+                labels (torch.Tensor): Labels of all anchors in the image.
+                label_weights (torch.Tensor): Label weights of all anchors.
+                bbox_targets (torch.Tensor): BBox targets of all anchors.
+                bbox_weights (torch.Tensor): BBox weights of all anchors.
+                pos_inds (torch.Tensor): Indices of positive anchors.
+                neg_inds (torch.Tensor): Indices of negative anchors.
+                sampling_result (SamplingResult): Result of the sampler.
+                angle_targets (torch.Tensor): Angle targets of all anchors.
+                angle_weights (torch.Tensor): Angle weights of all anchors.
+        """
+        inside_flags = rotated_anchor_inside_flags(
+            flat_anchors, valid_flags, img_meta['img_shape'][:2],
+            self.train_cfg.allowed_border)
+        if not inside_flags.any():
+            return (None, ) * 9
+        # Assign gt and sample anchors
+        anchors = flat_anchors[inside_flags, :]
+
+        if self.assign_by_circumhbbox is not None:
+            gt_bboxes_assign = obb2hbb(gt_bboxes, self.assign_by_circumhbbox)
+            assign_result = self.assigner.assign(
+                anchors, gt_bboxes_assign, gt_bboxes_ignore,
+                None if self.sampling else gt_labels)
+        else:
+            assign_result = self.assigner.assign(
+                anchors, gt_bboxes, gt_bboxes_ignore,
+                None if self.sampling else gt_labels)
+
sampling_result = self.sampler.sample(assign_result, anchors, + gt_bboxes) + + num_valid_anchors = anchors.shape[0] + bbox_targets = torch.zeros_like(anchors) + bbox_weights = torch.zeros_like(anchors) + angle_targets = torch.zeros_like(bbox_targets[:, 4:5]) + angle_weights = torch.zeros_like(bbox_targets[:, 4:5]) + labels = anchors.new_full((num_valid_anchors, ), + self.num_classes, + dtype=torch.long) + label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float) + + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + if len(pos_inds) > 0: + if not self.reg_decoded_bbox: + pos_bbox_targets = self.bbox_coder.encode( + sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes) + else: + pos_bbox_targets = sampling_result.pos_gt_bboxes + bbox_targets[pos_inds, :] = pos_bbox_targets + bbox_weights[pos_inds, :] = 1.0 + + if self.use_encoded_angle: + # Get encoded angle as target + angle_targets[pos_inds, :] = pos_bbox_targets[:, 4:5] + else: + # Get gt angle as target + angle_targets[pos_inds, :] = \ + sampling_result.pos_gt_bboxes[:, 4:5] + # Angle encoder + angle_targets = self.angle_coder.encode(angle_targets) + angle_weights[pos_inds, :] = 1.0 + + if gt_labels is None: + # Only rpn gives gt_labels as None + # Foreground is the first class since v2.5.0 + labels[pos_inds] = 0 + else: + labels[pos_inds] = gt_labels[ + sampling_result.pos_assigned_gt_inds] + if self.train_cfg.pos_weight <= 0: + label_weights[pos_inds] = 1.0 + else: + label_weights[pos_inds] = self.train_cfg.pos_weight + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + # Map up to original set of anchors + if unmap_outputs: + num_total_anchors = flat_anchors.size(0) + labels = unmap( + labels, num_total_anchors, inside_flags, + fill=self.num_classes) # fill bg label + label_weights = unmap(label_weights, num_total_anchors, + inside_flags) + bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags) + bbox_weights = unmap(bbox_weights, 
num_total_anchors, inside_flags)
+            angle_targets = unmap(angle_targets, num_total_anchors,
+                                  inside_flags)
+            angle_weights = unmap(angle_weights, num_total_anchors,
+                                  inside_flags)
+
+        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
+                neg_inds, sampling_result, angle_targets, angle_weights)
+
+    def _get_bboxes_single(self,
+                           cls_score_list,
+                           bbox_pred_list,
+                           angle_cls_list,
+                           mlvl_anchors,
+                           img_shape,
+                           scale_factor,
+                           cfg,
+                           rescale=False,
+                           with_nms=True):
+        """Transform outputs for a single batch item into bbox predictions.
+
+        Args:
+            cls_score_list (list[Tensor]): Box scores for a single scale level
+                with shape (num_anchors * num_classes, H, W).
+            bbox_pred_list (list[Tensor]): Box energies / deltas for a single
+                scale level with shape (num_anchors * 5, H, W).
+            angle_cls_list (list[Tensor]): Angle classification scores for a
+                single scale level with shape (num_anchors * coding_len, H, W).
+            mlvl_anchors (list[Tensor]): Box reference for a single scale level
+                with shape (num_total_anchors, 5).
+            img_shape (tuple[int]): Shape of the input image,
+                (height, width, 3).
+            scale_factor (ndarray): Scale factor of the image, arranged as
+                (w_scale, h_scale, w_scale, h_scale).
+            cfg (mmcv.Config): Test / postprocessing configuration.
+                If None, test_cfg would be used.
+            rescale (bool): If True, return boxes in original image space.
+                Default: False.
+            with_nms (bool): If True, do nms before return boxes.
+                Default: True.
+
+        Returns:
+            Tensor: Labeled boxes in shape (n, 6), where the first 5 columns
+                are bounding box positions (cx, cy, w, h, a) and the
+                6-th column is a score between 0 and 1.
+ """ + cfg = self.test_cfg if cfg is None else cfg + assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors) + mlvl_bboxes = [] + mlvl_scores = [] + for cls_score, bbox_pred, angle_cls, anchors in zip( + cls_score_list, bbox_pred_list, angle_cls_list, mlvl_anchors): + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + cls_score = cls_score.permute(1, 2, + 0).reshape(-1, self.cls_out_channels) + if self.use_sigmoid_cls: + scores = cls_score.sigmoid() + else: + scores = cls_score.softmax(-1) + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 5) + + angle_cls = angle_cls.permute(1, 2, 0).reshape( + -1, self.coding_len).sigmoid() + + nms_pre = cfg.get('nms_pre', -1) + if scores.shape[0] > nms_pre > 0: + # Get maximum scores for foreground classes. + if self.use_sigmoid_cls: + max_scores, _ = scores.max(dim=1) + else: + # Remind that we set FG labels to [0, num_class-1] + # since mmdet v2.0 + # BG cat_id: num_class + max_scores, _ = scores[:, :-1].max(dim=1) + _, topk_inds = max_scores.topk(nms_pre) + anchors = anchors[topk_inds, :] + bbox_pred = bbox_pred[topk_inds, :] + scores = scores[topk_inds, :] + angle_cls = angle_cls[topk_inds, :] + + # Angle decoder + angle_pred = self.angle_coder.decode(angle_cls) + + if self.use_encoded_angle: + bbox_pred[..., -1] = angle_pred + bboxes = self.bbox_coder.decode( + anchors, bbox_pred, max_shape=img_shape) + else: + bboxes = self.bbox_coder.decode( + anchors, bbox_pred, max_shape=img_shape) + bboxes[..., -1] = angle_pred + + mlvl_bboxes.append(bboxes) + mlvl_scores.append(scores) + mlvl_bboxes = torch.cat(mlvl_bboxes) + if rescale: + # Angle should not be rescaled + mlvl_bboxes[:, :4] = mlvl_bboxes[:, :4] / mlvl_bboxes.new_tensor( + scale_factor) + mlvl_scores = torch.cat(mlvl_scores) + if self.use_sigmoid_cls: + # Add a dummy background class to the backend when using sigmoid + # Remind that we set FG labels to [0, num_class-1] since mmdet v2.0 + # BG cat_id: num_class + padding = 
mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) + mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) + + if with_nms: + det_bboxes, det_labels = multiclass_nms_rotated( + mlvl_bboxes, mlvl_scores, cfg.score_thr, cfg.nms, + cfg.max_per_img) + return det_bboxes, det_labels + else: + return mlvl_bboxes, mlvl_scores + + @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'angle_clses')) + def get_bboxes(self, + cls_scores, + bbox_preds, + angle_clses, + img_metas, + cfg=None, + rescale=False, + with_nms=True): + """Transform network output for a batch into bbox predictions. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 5, H, W) + angle_clses (list[Tensor]): Box angles for each scale + level with shape (N, num_anchors * coding_len, H, W) + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + cfg (mmcv.Config | None): Test / postprocessing configuration, + if None, test_cfg would be used + rescale (bool): If True, return boxes in original image space. + Default: False. + with_nms (bool): If True, do nms before return boxes. + Default: True. + + Returns: + list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple. + The first item is an (n, 6) tensor, where the first 5 columns + are bounding box positions (cx, cy, w, h, a) and the + 6-th column is a score between 0 and 1. The second item is a + (n,) tensor where each item is the predicted class label of the + corresponding box. 
+ + Example: + >>> import mmcv + >>> self = AnchorHead( + >>> num_classes=9, + >>> in_channels=1, + >>> anchor_generator=dict( + >>> type='AnchorGenerator', + >>> scales=[8], + >>> ratios=[0.5, 1.0, 2.0], + >>> strides=[4,])) + >>> img_metas = [{'img_shape': (32, 32, 3), 'scale_factor': 1}] + >>> cfg = mmcv.Config(dict( + >>> score_thr=0.00, + >>> nms=dict(type='nms', iou_thr=1.0), + >>> max_per_img=10)) + >>> feat = torch.rand(1, 1, 3, 3) + >>> cls_score, bbox_pred = self.forward_single(feat) + >>> # Note the input lists are over different levels, not images + >>> cls_scores, bbox_preds = [cls_score], [bbox_pred] + >>> result_list = self.get_bboxes(cls_scores, bbox_preds, + >>> img_metas, cfg) + >>> det_bboxes, det_labels = result_list[0] + >>> assert len(result_list) == 1 + >>> assert det_bboxes.shape[1] == 5 + >>> assert len(det_bboxes) == len(det_labels) == cfg.max_per_img + """ + assert len(cls_scores) == len(bbox_preds) + num_levels = len(cls_scores) + + device = cls_scores[0].device + featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] + mlvl_anchors = self.anchor_generator.grid_priors( + featmap_sizes, device=device) + + result_list = [] + for img_id, _ in enumerate(img_metas): + cls_score_list = [ + cls_scores[i][img_id].detach() for i in range(num_levels) + ] + bbox_pred_list = [ + bbox_preds[i][img_id].detach() for i in range(num_levels) + ] + angle_cls_list = [ + angle_clses[i][img_id].detach() for i in range(num_levels) + ] + img_shape = img_metas[img_id]['img_shape'] + scale_factor = img_metas[img_id]['scale_factor'] + if with_nms: + # Some heads don't support with_nms argument + proposals = self._get_bboxes_single(cls_score_list, + bbox_pred_list, + angle_cls_list, + mlvl_anchors, img_shape, + scale_factor, cfg, rescale) + else: + proposals = self._get_bboxes_single(cls_score_list, + bbox_pred_list, + angle_cls_list, + mlvl_anchors, img_shape, + scale_factor, cfg, rescale, + with_nms) + result_list.append(proposals) + return 
result_list diff --git a/mmrotate/models/dense_heads/kfiou_rotate_retina_head.py b/mmrotate/models/dense_heads/kfiou_rotate_retina_head.py index 19c56f512..a1448b42c 100644 --- a/mmrotate/models/dense_heads/kfiou_rotate_retina_head.py +++ b/mmrotate/models/dense_heads/kfiou_rotate_retina_head.py @@ -83,7 +83,8 @@ def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights, positive anchors. Returns: - dict[str, Tensor]: A dictionary of loss components. + loss_cls (torch.Tensor): cls. loss for each scale level. + loss_bbox (torch.Tensor): reg. loss for each scale level. """ # classification loss labels = labels.reshape(-1) diff --git a/mmrotate/models/dense_heads/oriented_rpn_head.py b/mmrotate/models/dense_heads/oriented_rpn_head.py index ca35c0e31..823a8ed6c 100644 --- a/mmrotate/models/dense_heads/oriented_rpn_head.py +++ b/mmrotate/models/dense_heads/oriented_rpn_head.py @@ -155,7 +155,8 @@ def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights, positive anchors. Returns: - dict[str, Tensor]: A dictionary of loss components. + loss_cls (torch.Tensor): cls. loss for each scale level. + loss_bbox (torch.Tensor): reg. loss for each scale level. """ # classification loss labels = labels.reshape(-1) diff --git a/mmrotate/models/dense_heads/rotated_anchor_head.py b/mmrotate/models/dense_heads/rotated_anchor_head.py index 3dad51205..0b7f46bd0 100644 --- a/mmrotate/models/dense_heads/rotated_anchor_head.py +++ b/mmrotate/models/dense_heads/rotated_anchor_head.py @@ -187,7 +187,7 @@ def _get_targets_single(self, Args: flat_anchors (torch.Tensor): Multi-level anchors of the image, which are concatenated into a single tensor of shape - (num_anchors ,4) + (num_anchors, 5) valid_flags (torch.Tensor): Multi level valid flags of the image, which are concatenated into a single tensor of shape (num_anchors,). @@ -296,7 +296,7 @@ def get_targets(self, anchor_list (list[list[Tensor]]): Multi level anchors of each image. 
The outer list indicates images, and the inner list corresponds to feature levels of the image. Each element of - the inner list is a tensor of shape (num_anchors, 4). + the inner list is a tensor of shape (num_anchors, 5). valid_flag_list (list[list[Tensor]]): Multi level valid flags of each image. The outer list indicates images, and the inner list corresponds to feature levels of the image. Each element of @@ -405,7 +405,8 @@ def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights, positive anchors. Returns: - dict[str, Tensor]: A dictionary of loss components. + loss_cls (torch.Tensor): cls. loss for each scale level. + loss_bbox (torch.Tensor): reg. loss for each scale level. """ # classification loss labels = labels.reshape(-1) diff --git a/mmrotate/models/dense_heads/rotated_rpn_head.py b/mmrotate/models/dense_heads/rotated_rpn_head.py index 632acbb1e..eb6606867 100644 --- a/mmrotate/models/dense_heads/rotated_rpn_head.py +++ b/mmrotate/models/dense_heads/rotated_rpn_head.py @@ -274,7 +274,8 @@ def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights, positive anchors. Returns: - dict[str, Tensor]: A dictionary of loss components. + loss_cls (torch.Tensor): cls. loss for each scale level. + loss_bbox (torch.Tensor): reg. loss for each scale level. 
""" # classification loss labels = labels.reshape(-1) diff --git a/mmrotate/models/losses/__init__.py b/mmrotate/models/losses/__init__.py index b7050fbde..8594630e3 100644 --- a/mmrotate/models/losses/__init__.py +++ b/mmrotate/models/losses/__init__.py @@ -4,8 +4,9 @@ from .gaussian_dist_loss_v1 import GDLoss_v1 from .kf_iou_loss import KFLoss from .kld_reppoints_loss import KLDRepPointsLoss +from .smooth_focal_loss import SmoothFocalLoss __all__ = [ 'GDLoss', 'GDLoss_v1', 'KFLoss', 'ConvexGIoULoss', 'BCConvexGIoULoss', - 'KLDRepPointsLoss' + 'KLDRepPointsLoss', 'SmoothFocalLoss' ] diff --git a/mmrotate/models/losses/gaussian_dist_loss.py b/mmrotate/models/losses/gaussian_dist_loss.py index 1976f0c46..638d1a870 100644 --- a/mmrotate/models/losses/gaussian_dist_loss.py +++ b/mmrotate/models/losses/gaussian_dist_loss.py @@ -91,6 +91,31 @@ def postprocess(distance, fun='log1p', tau=1.0): @weighted_loss def gwd_loss(pred, target, fun='log1p', tau=1.0, alpha=1.0, normalize=True): """Gaussian Wasserstein distance loss. 
+    Derivation and simplification:
+        Given any positive-definite symmetric 2*2 matrix Z:
+            :math:`Tr(Z^{1/2}) = λ_1^{1/2} + λ_2^{1/2}`
+        where :math:`λ_1` and :math:`λ_2` are the eigenvalues of Z.
+        Meanwhile we have:
+            :math:`Tr(Z) = λ_1 + λ_2`
+
+            :math:`det(Z) = λ_1 * λ_2`
+        Combining these with the identity:
+            :math:`(λ_1^{1/2}+λ_2^{1/2})^2 = λ_1+λ_2+2*(λ_1 * λ_2)^{1/2}`
+        yields:
+            :math:`Tr(Z^{1/2}) = (Tr(Z) + 2 * (det(Z))^{1/2})^{1/2}`
+        For the GWD loss, the coupling term is:
+            :math:`Tr((Σ_p^{1/2} * Σ_t * Σ_p^{1/2})^{1/2})`
+        Letting :math:`Z = Σ_p^{1/2} * Σ_t * Σ_p^{1/2}`, then:
+            :math:`Tr(Z) = Tr(Σ_p^{1/2} * Σ_t * Σ_p^{1/2})
+            = Tr(Σ_p^{1/2} * Σ_p^{1/2} * Σ_t)
+            = Tr(Σ_p * Σ_t)`
+            :math:`det(Z) = det(Σ_p^{1/2} * Σ_t * Σ_p^{1/2})
+            = det(Σ_p^{1/2}) * det(Σ_t) * det(Σ_p^{1/2})
+            = det(Σ_p * Σ_t)`
+        and thus, using :math:`Tr(Z^{1/2}) = (Tr(Z) + 2 * (det(Z))^{1/2})^{1/2}`,
+        the coupling term can be rewritten as:
+            :math:`Tr((Σ_p^{1/2} * Σ_t * Σ_p^{1/2})^{1/2})
+            = (Tr(Σ_p * Σ_t) + 2 * (det(Σ_p * Σ_t))^{1/2})^{1/2}`

     Args:
         pred (torch.Tensor): Predicted bboxes.
@@ -102,6 +127,7 @@ def gwd_loss(pred, target, fun='log1p', tau=1.0, alpha=1.0, normalize=True):

     Returns:
         loss (torch.Tensor)
+
     """
     xy_p, Sigma_p = pred
     xy_t, Sigma_t = target
diff --git a/mmrotate/models/losses/smooth_focal_loss.py b/mmrotate/models/losses/smooth_focal_loss.py
new file mode 100644
index 000000000..f05c9a42c
--- /dev/null
+++ b/mmrotate/models/losses/smooth_focal_loss.py
@@ -0,0 +1,129 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+from mmdet.models import weight_reduce_loss
+
+from ..builder import ROTATED_LOSSES
+
+
+def smooth_focal_loss(pred,
+                      target,
+                      weight=None,
+                      gamma=2.0,
+                      alpha=0.25,
+                      reduction='mean',
+                      avg_factor=None):
+    """Smooth Focal Loss proposed in Circular Smooth Label (CSL).
+
+    `Circular Smooth Label (CSL)
+    `_ .
+
+    Args:
+        pred (torch.Tensor): The prediction.
+        target (torch.Tensor): The learning label of the prediction.
+        weight (torch.Tensor, optional): The weight of loss for each
+            prediction. Defaults to None.
+        gamma (float, optional): The gamma for calculating the modulating
+            factor. Defaults to 2.0.
+        alpha (float, optional): A balanced form for Focal Loss.
+            Defaults to 0.25.
+        reduction (str, optional): The method used to reduce the loss.
+            Options are "none", "mean" and "sum". Defaults to 'mean'.
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. Defaults to None.
+
+    Returns:
+        torch.Tensor: The calculated loss
+    """
+
+    pred_sigmoid = pred.sigmoid()
+    target = target.type_as(pred)
+    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
+    focal_weight = (alpha * target + (1 - alpha) *
+                    (1 - target)) * pt.pow(gamma)
+    loss = F.binary_cross_entropy_with_logits(
+        pred, target, reduction='none') * focal_weight
+    if weight is not None:
+        if weight.shape != loss.shape:
+            if weight.size(0) == loss.size(0):
+                # For most cases, weight is of shape (num_priors, ),
+                # which means it does not have the second axis num_class
+                weight = weight.view(-1, 1)
+            else:
+                # Sometimes, weight per anchor per class is also needed,
+                # e.g. in FSAF. But it may be flattened to shape
+                # (num_priors x num_class, ), while loss is still of shape
+                # (num_priors, num_class).
+                assert weight.numel() == loss.numel()
+                weight = weight.view(loss.size(0), -1)
+        assert weight.ndim == loss.ndim
+    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+    return loss
+
+
+@ROTATED_LOSSES.register_module()
+class SmoothFocalLoss(nn.Module):
+
+    def __init__(self,
+                 gamma=2.0,
+                 alpha=0.25,
+                 reduction='mean',
+                 loss_weight=1.0):
+        """Smooth Focal Loss.
+
+        Args:
+            gamma (float, optional): The gamma for calculating the modulating
+                factor. Defaults to 2.0.
+            alpha (float, optional): A balanced form for Focal Loss.
+                Defaults to 0.25.
+            reduction (str, optional): The method used to reduce the loss into
+                a scalar. Defaults to 'mean'. Options are "none", "mean" and
+                "sum".
+            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
+        """
+        super(SmoothFocalLoss, self).__init__()
+        self.gamma = gamma
+        self.alpha = alpha
+        self.reduction = reduction
+        self.loss_weight = loss_weight
+
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None):
+        """Forward function.
+
+        Args:
+            pred (torch.Tensor): The prediction.
+            target (torch.Tensor): The learning label of the prediction.
+            weight (torch.Tensor, optional): The weight of loss for each
+                prediction. Defaults to None.
+            avg_factor (int, optional): Average factor that is used to average
+                the loss. Defaults to None.
+            reduction_override (str, optional): The reduction method used to
+                override the original reduction method of the loss.
+                Options are "none", "mean" and "sum".
+
+        Returns:
+            torch.Tensor: The calculated loss
+        """
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
+
+        loss_cls = self.loss_weight * smooth_focal_loss(
+            pred,
+            target,
+            weight,
+            gamma=self.gamma,
+            alpha=self.alpha,
+            reduction=reduction,
+            avg_factor=avg_factor)
+
+        return loss_cls
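The per-element formula computed by `smooth_focal_loss` above can be checked by hand. Below is a minimal, dependency-free scalar sketch of the same math (the tensor version applies it element-wise over CSL-encoded angle targets, which may take any value in [0, 1], not just hard 0/1 labels); `smooth_focal_scalar` is an illustrative helper name, not part of this PR, and it returns the unreduced loss without the `weight`/`avg_factor` handling.

```python
import math


def smooth_focal_scalar(logit, target, gamma=2.0, alpha=0.25):
    """Scalar smooth focal loss: focal_weight * BCE-with-logits,
    mirroring the element-wise computation in smooth_focal_loss."""
    p = 1.0 / (1.0 + math.exp(-logit))  # sigmoid
    # pt is small when the prediction already agrees with the target
    pt = (1.0 - p) * target + p * (1.0 - target)
    focal_weight = (alpha * target + (1.0 - alpha) * (1.0 - target)) * pt**gamma
    bce = -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))
    return focal_weight * bce


# An uncommitted logit (0.0) against a positive smoothed label:
# p = 0.5, pt = 0.5, focal_weight = 0.25 * 0.25, bce = ln(2)
loss = smooth_focal_scalar(0.0, 1.0)
```

A confident correct prediction (e.g. logit 5.0 on a positive label) is down-weighted by `pt**gamma` and yields a far smaller loss, which is the focal-loss behavior the CSL head relies on across its many angle bins.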