diff --git a/README.md b/README.md
index 64e06cc6e..048206932 100644
--- a/README.md
+++ b/README.md
@@ -104,6 +104,7 @@ A summary can be found in the [Model Zoo](docs/en/model_zoo.md) page.
* [x] [Rotated RepPoints-OBB](configs/rotated_reppoints/README.md) (ICCV'2019)
* [x] [RoI Transformer](configs/roi_trans/README.md) (CVPR'2019)
* [x] [Gliding Vertex](configs/gliding_vertex/README.md) (TPAMI'2020)
+* [x] [CSL](configs/csl/README.md) (ECCV'2020)
* [x] [R3Det](configs/r3det/README.md) (AAAI'2021)
* [x] [S2A-Net](configs/s2anet/README.md) (TGRS'2021)
* [x] [ReDet](configs/redet/README.md) (CVPR'2021)
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 4e5de4a04..e6cd17f16 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -100,6 +100,7 @@ MMRotate 也提供了其他更详细的教程:
* [x] [Rotated RepPoints-OBB](configs/rotated_reppoints/README.md) (ICCV'2019)
* [x] [RoI Transformer](configs/roi_trans/README.md) (CVPR'2019)
* [x] [Gliding Vertex](configs/gliding_vertex/README.md) (TPAMI'2020)
+* [x] [CSL](configs/csl/README.md) (ECCV'2020)
* [x] [R3Det](configs/r3det/README.md) (AAAI'2021)
* [x] [S2A-Net](configs/s2anet/README.md) (TGRS'2021)
* [x] [ReDet](configs/redet/README.md) (CVPR'2021)
diff --git a/configs/csl/README.md b/configs/csl/README.md
new file mode 100644
index 000000000..972f9dfbb
--- /dev/null
+++ b/configs/csl/README.md
@@ -0,0 +1,43 @@
+# CSL
+> [Arbitrary-Oriented Object Detection with Circular Smooth Label](https://link.springer.com/chapter/10.1007/978-3-030-58598-3_40)
+
+
+## Abstract
+
+
+

+
+
+Arbitrary-oriented object detection has recently attracted increasing attention in vision for their importance
+in aerial imagery, scene text, and face etc. In this paper, we show that existing regression-based rotation detectors
+suffer the problem of discontinuous boundaries, which is directly caused by angular periodicity or corner ordering.
+By a careful study, we find the root cause is that the ideal predictions are beyond the defined range. We design a
+new rotation detection baseline, to address the boundary problem by transforming angular prediction from a regression
+problem to a classification task with little accuracy loss, whereby high-precision angle classification is devised in
+contrast to previous works using coarse-granularity in rotation detection. We also propose a circular smooth label (CSL)
+technique to handle the periodicity of the angle and increase the error tolerance to adjacent angles. We further
+introduce four window functions in CSL and explore the effect of different window radius sizes on detection performance.
+Extensive experiments and visual analysis on two large-scale public datasets for aerial images i.e. DOTA, HRSC2016,
+as well as scene text dataset ICDAR2015 and MLT, show the effectiveness of our approach.
+
+## Results and models
+
+DOTA1.0
+
+| Backbone | mAP | Angle | Window func. | Omega | lr schd | Mem (GB) | Inf Time (fps) | Aug | Batch Size | Configs | Download |
+|:------------:|:----------:|:-----------:|:-----------:|:-----------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:-------------:|
+| ResNet50 (1024,1024,200) | 68.42 | le90 | - | - | 1x | 3.38 | 17.8 | - | 2 | [rotated_retinanet_obb_r50_fpn_1x_dota_le90](./rotated_retinanet_obb_r50_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le90/rotated_retinanet_obb_r50_fpn_1x_dota_le90-c0097bc4.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le90/rotated_retinanet_obb_r50_fpn_1x_dota_le90_20220128_130740.log.json)
+| ResNet50 (1024,1024,200) | 68.79 | le90 | - | - | 1x | 2.36 | 25.9 | - | 2 | [rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90](./rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_retinanet/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90-01de71b5.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_retinanet/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90_20220303_183714.log.json)
+| ResNet50 (1024,1024,200) | 69.51 | le90 | Gaussian | 4 | 1x | 2.60 | 24.0 | - | 2 | [rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90](./rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90-b4271aed.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90_20220321_010033.log.json)
+
+
+## Citation
+```
+@inproceedings{yang2020arbitrary,
+ title={Arbitrary-Oriented Object Detection with Circular Smooth Label},
+ author={Yang, Xue and Yan, Junchi},
+ booktitle={European Conference on Computer Vision},
+ pages={677--694},
+ year={2020}
+}
+```
diff --git a/configs/csl/metafile.yml b/configs/csl/metafile.yml
new file mode 100644
index 000000000..46dbcda9c
--- /dev/null
+++ b/configs/csl/metafile.yml
@@ -0,0 +1,27 @@
+Collections:
+- Name: CSL
+ Metadata:
+ Training Data: DOTAv1.0
+ Training Techniques:
+ - SGD with Momentum
+ - Weight Decay
+ Training Resources: 1x Quadro RTX 8000
+ Architecture:
+ - ResNet
+ Paper:
+ URL: https://link.springer.com/chapter/10.1007/978-3-030-58598-3_40
+ Title: 'Arbitrary-Oriented Object Detection with Circular Smooth Label'
+ README: configs/csl/README.md
+
+Models:
+ - Name: rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90
+ In Collection: csl
+ Config: configs/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90.py
+ Metadata:
+ Training Data: DOTAv1.0
+ Results:
+ - Task: Oriented Object Detection
+ Dataset: DOTAv1.0
+ Metrics:
+ mAP: 69.51
+ Weights: https://download.openmmlab.com/mmrotate/v0.1.0/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90-b4271aed.pth
diff --git a/configs/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90.py b/configs/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90.py
new file mode 100644
index 000000000..f7e8d06a8
--- /dev/null
+++ b/configs/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90.py
@@ -0,0 +1,22 @@
+_base_ = \
+ ['../rotated_retinanet/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90.py']
+
+angle_version = 'le90'
+model = dict(
+ bbox_head=dict(
+ type='CSLRRetinaHead',
+ angle_coder=dict(
+ type='CSLCoder',
+ angle_version=angle_version,
+ omega=4,
+ window='gaussian',
+ radius=3),
+ loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ loss_bbox=dict(type='L1Loss', loss_weight=1.0),
+ loss_angle=dict(
+ type='SmoothFocalLoss', gamma=2.0, alpha=0.25, loss_weight=0.8)))
diff --git a/configs/redet/redet_re50_refpn_1x_dota_ms_rr_le90.py b/configs/redet/redet_re50_refpn_1x_dota_ms_rr_le90.py
index 6b12f69c6..a033006ba 100644
--- a/configs/redet/redet_re50_refpn_1x_dota_ms_rr_le90.py
+++ b/configs/redet/redet_re50_refpn_1x_dota_ms_rr_le90.py
@@ -1,6 +1,6 @@
-_base_ = ['./redet_re50_fpn_1x_dota_le90.py']
+_base_ = ['./redet_re50_refpn_1x_dota_le90.py']
-data_root = '/cluster/home/it_stu198/main/datasets/split_ms_dota1_0/'
+data_root = 'datasets/split_ms_dota1_0/'
angle_version = 'le90'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
diff --git a/docs/en/model_zoo.md b/docs/en/model_zoo.md
index 55f08ad41..4a4612f25 100644
--- a/docs/en/model_zoo.md
+++ b/docs/en/model_zoo.md
@@ -5,6 +5,7 @@
- [Rotated RepPoints-OBB](../../configs/rotated_reppoints/README.md) (ICCV'2019)
- [RoI Transformer](../../configs/roi_trans/README.md) (CVPR'2019)
- [Gliding Vertex](../../configs/gliding_vertex/README.md) (TPAMI'2020)
+- [CSL](../../configs/csl/README.md) (ECCV'2020)
- [R3Det](../../configs/r3det/README.md) (AAAI'2021)
- [S2A-Net](../../configs/s2anet/README.md) (TGRS'2021)
- [ReDet](../../configs/redet/README.md) (CVPR'2021)
@@ -26,6 +27,7 @@
| ResNet50 (1024,1024,200) | 68.42 | le90 | 1x | 3.38 | 16.9 | - | 2 | [rotated_retinanet_obb_r50_fpn_1x_dota_le90](../../configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le90/rotated_retinanet_obb_r50_fpn_1x_dota_le90-c0097bc4.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le90/rotated_retinanet_obb_r50_fpn_1x_dota_le90_20220128_130740.log.json)
| ResNet50 (1024,1024,200) | 68.79 | le90 | 1x | 2.36 | 22.4 | - | 2 | [rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90](../../configs/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_retinanet/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90-01de71b5.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_retinanet/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90_20220303_183714.log.json)
| ResNet50 (1024,1024,200) | 69.49 | le135 | 1x | 4.05 | 8.6 | - | 2 | [g_reppoints_r50_fpn_1x_dota_le135](../../configs/g_reppoints/g_reppoints_r50_fpn_1x_dota_le135.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/g_reppoints/g_reppoints_r50_fpn_1x_dota_le135/g_reppoints_r50_fpn_1x_dota_le135-b840eed7.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/g_reppoints/g_reppoints_r50_fpn_1x_dota_le135/g_reppoints_r50_fpn_1x_dota_le135_20220202_233631.log.json)
+| ResNet50 (1024,1024,200) | 69.51 | le90 | 1x | 4.40 | 24.0 | - | 2 | [rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90](../../configs/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90-b4271aed.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90_20220321_010033.log.json)
| ResNet50 (1024,1024,200) | 69.55 | oc | 1x | 3.39 | 15.5 | - | 2 | [rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc](../../configs/gwd/rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/gwd/rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc/rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc-41fd7805.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/gwd/rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc/rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc_20220120_152421.log.json)
| ResNet50 (1024,1024,200) | 69.60 | le90 | 1x | 3.38 | 15.1 | - | 2 | [rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90](../../configs/kfiou/rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/kfiou/rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90/rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90-03e02f75.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/kfiou/rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90/rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90_20220209_173225.log.json)
| ResNet50 (1024,1024,200) | 69.63 | le135 | 1x | 3.45 | 16.1 | - | 2 | [cfa_r50_fpn_1x_dota_le135](../../configs/cfa/cfa_r50_fpn_1x_dota_le135.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/cfa/cfa_r50_fpn_1x_dota_le135/cfa_r50_fpn_1x_dota_le135-aed1cbc6.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/cfa/cfa_r50_fpn_1x_dota_le135/cfa_r50_fpn_1x_dota_le135_20220205_144859.log.json)
@@ -57,4 +59,4 @@
- `MS` means multiple scale image split.
- `RR` means random rotation.
-The above models are trained with 1 * 1080Ti and inferred with 1 * 2080Ti.
+The above models are trained with 1 * 1080Ti/2080Ti and inferred with 1 * 2080Ti.
diff --git a/docs/zh_cn/model_zoo.md b/docs/zh_cn/model_zoo.md
index 4ebf2cc0d..314918cf7 100644
--- a/docs/zh_cn/model_zoo.md
+++ b/docs/zh_cn/model_zoo.md
@@ -5,6 +5,7 @@
- [Rotated RepPoints-OBB](../../configs/rotated_reppoints/README.md) (ICCV'2019)
- [RoI Transformer](../../configs/roi_trans/README.md) (CVPR'2019)
- [Gliding Vertex](../../configs/gliding_vertex/README.md) (TPAMI'2020)
+- [CSL](../../configs/csl/README.md) (ECCV'2020)
- [R3Det](../../configs/r3det/README.md) (AAAI'2021)
- [S2A-Net](../../configs/s2anet/README.md) (TGRS'2021)
- [ReDet](../../configs/redet/README.md) (CVPR'2021)
@@ -26,6 +27,7 @@
| ResNet50 (1024,1024,200) | 68.42 | le90 | 1x | 3.38 | 16.9 | - | 2 | [rotated_retinanet_obb_r50_fpn_1x_dota_le90](../../configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le90/rotated_retinanet_obb_r50_fpn_1x_dota_le90-c0097bc4.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le90/rotated_retinanet_obb_r50_fpn_1x_dota_le90_20220128_130740.log.json)
| ResNet50 (1024,1024,200) | 69.49 | le135 | 1x | 4.05 | 8.6 | - | 2 | [g_reppoints_r50_fpn_1x_dota_le135](../../configs/g_reppoints/g_reppoints_r50_fpn_1x_dota_le135.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/g_reppoints/g_reppoints_r50_fpn_1x_dota_le135/g_reppoints_r50_fpn_1x_dota_le135-b840eed7.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/g_reppoints/g_reppoints_r50_fpn_1x_dota_le135/g_reppoints_r50_fpn_1x_dota_le135_20220202_233631.log.json)
| ResNet50 (1024,1024,200) | 68.79 | le90 | 1x | 2.36 | 22.4 | - | 2 | [rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90](../../configs/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_retinanet/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90-01de71b5.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/rotated_retinanet/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90_20220303_183714.log.json)
+| ResNet50 (1024,1024,200) | 69.51 | le90 | 1x | 4.40 | 24.0 | - | 2 | [rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90](../../configs/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90-b4271aed.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/csl/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90/rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90_20220321_010033.log.json)
| ResNet50 (1024,1024,200) | 69.55 | oc | 1x | 3.39 | 15.5 | - | 2 | [rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc](../../configs/gwd/rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/gwd/rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc/rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc-41fd7805.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/gwd/rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc/rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc_20220120_152421.log.json)
| ResNet50 (1024,1024,200) | 69.60 | le90 | 1x | 3.38 | 15.1 | - | 2 | [rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90](../../configs/kfiou/rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/kfiou/rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90/rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90-03e02f75.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/kfiou/rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90/rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90_20220209_173225.log.json)
| ResNet50 (1024,1024,200) | 69.63 | le135 | 1x | 3.45 | 16.1 | - | 2 | [cfa_r50_fpn_1x_dota_le135](../../configs/cfa/cfa_r50_fpn_1x_dota_le135.py) | [model](https://download.openmmlab.com/mmrotate/v0.1.0/cfa/cfa_r50_fpn_1x_dota_le135/cfa_r50_fpn_1x_dota_le135-aed1cbc6.pth) | [log](https://download.openmmlab.com/mmrotate/v0.1.0/cfa/cfa_r50_fpn_1x_dota_le135/cfa_r50_fpn_1x_dota_le135_20220205_144859.log.json)
@@ -57,4 +59,4 @@
- `MS` 表示多尺度图像增强。
- `RR` 表示随机旋转增强。
-上述模型都是使用 1 * 1080ti 训练得到的,并且在 1 * 2080ti 上进行推理测试。
+上述模型都是使用 1 * 1080ti/2080ti 训练得到的,并且在 1 * 2080ti 上进行推理测试。
diff --git a/mmrotate/core/anchor/anchor_generator.py b/mmrotate/core/anchor/anchor_generator.py
index dec8f31dc..bdb378368 100644
--- a/mmrotate/core/anchor/anchor_generator.py
+++ b/mmrotate/core/anchor/anchor_generator.py
@@ -8,7 +8,10 @@
@ROTATED_ANCHOR_GENERATORS.register_module()
class RotatedAnchorGenerator(AnchorGenerator):
- """Standard rotate anchor generator for 2D anchor-based detectors."""
+ """Fake rotate anchor generator for 2D anchor-based detectors.
+
+ Horizontal bounding box represented by (x,y,w,h,theta).
+ """
def single_level_grid_priors(self,
featmap_size,
@@ -34,6 +37,11 @@ def single_level_grid_priors(self,
anchors = super(RotatedAnchorGenerator, self).single_level_grid_priors(
featmap_size, level_idx, dtype=dtype, device=device)
+ # The correct usage is:
+ # from ..bbox.transforms import hbb2obb
+ # anchors = hbb2obb(anchors, self.angle_version)
+ # instead of rudely setting the angle to all 0.
+ # However, the experiment shows that the performance has decreased.
num_anchors = anchors.size(0)
xy = (anchors[:, 2:] + anchors[:, :2]) / 2
wh = anchors[:, 2:] - anchors[:, :2]
diff --git a/mmrotate/core/bbox/coder/__init__.py b/mmrotate/core/bbox/coder/__init__.py
index 24ed61710..cfd7fb1a1 100644
--- a/mmrotate/core/bbox/coder/__init__.py
+++ b/mmrotate/core/bbox/coder/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .angle_coder import CSLCoder
from .delta_midpointoffset_rbbox_coder import MidpointOffsetCoder
from .delta_xywha_hbbox_coder import DeltaXYWHAHBBoxCoder
from .delta_xywha_rbbox_coder import DeltaXYWHAOBBoxCoder
@@ -6,5 +7,5 @@
__all__ = [
'DeltaXYWHAOBBoxCoder', 'DeltaXYWHAHBBoxCoder', 'MidpointOffsetCoder',
- 'GVFixCoder', 'GVRatioCoder'
+ 'GVFixCoder', 'GVRatioCoder', 'CSLCoder'
]
diff --git a/mmrotate/core/bbox/coder/angle_coder.py b/mmrotate/core/bbox/coder/angle_coder.py
new file mode 100644
index 000000000..b84000a5e
--- /dev/null
+++ b/mmrotate/core/bbox/coder/angle_coder.py
@@ -0,0 +1,114 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+from mmdet.core.bbox.coder.base_bbox_coder import BaseBBoxCoder
+
+from ..builder import ROTATED_BBOX_CODERS
+
+
+@ROTATED_BBOX_CODERS.register_module()
+class CSLCoder(BaseBBoxCoder):
+ """Circular Smooth Label Coder.
+
+ `Circular Smooth Label (CSL)
+ `_ .
+
+ Args:
+ angle_version (str): Angle definition.
+ omega (float, optional): Angle discretization granularity.
+ Default: 1.
+ window (str, optional): Window function. Default: gaussian.
+ radius (int/float): window radius, int type for
+ ['triangle', 'rect', 'pulse'], float type for
+ ['gaussian']. Default: 6.
+ """
+
+ def __init__(self, angle_version, omega=1, window='gaussian', radius=6):
+ super().__init__()
+ self.angle_version = angle_version
+ assert angle_version in ['oc', 'le90', 'le135']
+ assert window in ['gaussian', 'triangle', 'rect', 'pulse']
+ self.angle_range = 90 if angle_version == 'oc' else 180
+ self.angle_offset_dict = {'oc': 0, 'le90': 90, 'le135': 45}
+ self.angle_offset = self.angle_offset_dict[angle_version]
+ self.omega = omega
+ self.window = window
+ self.radius = radius
+ self.coding_len = int(self.angle_range // omega)
+
+ def encode(self, angle_targets):
+ """Circular Smooth Label Encoder.
+
+ Args:
+ angle_targets (Tensor): Angle offset for each scale level
+ Has shape (num_anchors * H * W, 1)
+
+ Returns:
+ list[Tensor]: The csl encoding of angle offset for each
+ scale level. Has shape (num_anchors * H * W, coding_len)
+ """
+
+ # radius to degree
+ angle_targets_deg = angle_targets * (180 / math.pi)
+ # empty label
+ smooth_label = torch.zeros_like(angle_targets).repeat(
+ 1, self.coding_len)
+ angle_targets_deg = (angle_targets_deg +
+ self.angle_offset) / self.omega
+ # Float to Int
+ angle_targets_long = angle_targets_deg.long()
+
+ if self.window == 'pulse':
+ radius_range = angle_targets_long % self.coding_len
+ smooth_value = 1.0
+ elif self.window == 'rect':
+ base_radius_range = torch.arange(
+ -self.radius, self.radius, device=angle_targets_long.device)
+ radius_range = (base_radius_range +
+ angle_targets_long) % self.coding_len
+ smooth_value = 1.0
+ elif self.window == 'triangle':
+ base_radius_range = torch.arange(
+ -self.radius, self.radius, device=angle_targets_long.device)
+ radius_range = (base_radius_range +
+ angle_targets_long) % self.coding_len
+ smooth_value = 1.0 - torch.abs(
+ (1 / self.radius) * base_radius_range)
+
+ elif self.window == 'gaussian':
+ base_radius_range = torch.arange(
+ -self.angle_range // 2,
+ self.angle_range // 2,
+ device=angle_targets_long.device)
+
+ radius_range = (base_radius_range +
+ angle_targets_long) % self.coding_len
+ smooth_value = torch.exp(-torch.pow(base_radius_range, 2) /
+ (2 * self.radius**2))
+
+ else:
+ raise NotImplementedError
+
+ if isinstance(smooth_value, torch.Tensor):
+ smooth_value = smooth_value.unsqueeze(0).repeat(
+ smooth_label.size(0), 1)
+
+ return smooth_label.scatter(1, radius_range, smooth_value)
+
+ def decode(self, angle_preds):
+ """Circular Smooth Label Decoder.
+
+ Args:
+ angle_preds (Tensor): The csl encoding of angle offset
+ for each scale level.
+ Has shape (num_anchors * H * W, coding_len)
+
+ Returns:
+ list[Tensor]: Angle offset for each scale level.
+ Has shape (num_anchors * H * W, 1)
+ """
+ angle_cls_inds = torch.argmax(angle_preds, dim=1)
+ angle_pred = ((angle_cls_inds + 0.5) *
+ self.omega) % self.angle_range - self.angle_offset
+ return angle_pred * (math.pi / 180)
diff --git a/mmrotate/core/bbox/transforms.py b/mmrotate/core/bbox/transforms.py
index 239eb2e7a..21bdfd3dc 100644
--- a/mmrotate/core/bbox/transforms.py
+++ b/mmrotate/core/bbox/transforms.py
@@ -573,10 +573,10 @@ def hbb2obb_oc(hbboxes):
Returns:
obbs (torch.Tensor): [x_ctr,y_ctr,w,h,angle]
"""
- x = (hbboxes[:, 0::4] + hbboxes[:, 2::4]) * 0.5
- y = (hbboxes[:, 1::4] + hbboxes[:, 3::4]) * 0.5
- w = hbboxes[:, 2::4] - hbboxes[:, 0::4]
- h = hbboxes[:, 3::4] - hbboxes[:, 1::4]
+ x = (hbboxes[..., 0] + hbboxes[..., 2]) * 0.5
+ y = (hbboxes[..., 1] + hbboxes[..., 3]) * 0.5
+ w = hbboxes[..., 2] - hbboxes[..., 0]
+ h = hbboxes[..., 3] - hbboxes[..., 1]
theta = x.new_zeros(*x.shape)
rbboxes = torch.stack([x, y, h, w, theta + np.pi / 2], dim=-1)
return rbboxes
diff --git a/mmrotate/datasets/pipelines/transforms.py b/mmrotate/datasets/pipelines/transforms.py
index b04484f47..3a4510844 100644
--- a/mmrotate/datasets/pipelines/transforms.py
+++ b/mmrotate/datasets/pipelines/transforms.py
@@ -84,7 +84,7 @@ def bbox_flip(self, bboxes, img_shape, direction):
if self.version == 'oc':
rotated_flag = (bboxes[:, 4] != np.pi / 2)
flipped[rotated_flag, 4] = np.pi / 2 - bboxes[rotated_flag, 4]
- flipped[rotated_flag, 2] = bboxes[rotated_flag, 3],
+ flipped[rotated_flag, 2] = bboxes[rotated_flag, 3]
flipped[rotated_flag, 3] = bboxes[rotated_flag, 2]
else:
flipped[:, 4] = norm_angle(np.pi - bboxes[:, 4], self.version)
diff --git a/mmrotate/models/dense_heads/__init__.py b/mmrotate/models/dense_heads/__init__.py
index 3d28b1afa..0adcbdb8a 100644
--- a/mmrotate/models/dense_heads/__init__.py
+++ b/mmrotate/models/dense_heads/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .csl_rotated_retina_head import CSLRRetinaHead
from .kfiou_odm_refine_head import KFIoUODMRefineHead
from .kfiou_rotate_retina_head import KFIoURRetinaHead
from .kfiou_rotate_retina_refine_head import KFIoURRetinaRefineHead
@@ -15,5 +16,5 @@
'RotatedAnchorHead', 'RotatedRetinaHead', 'RotatedRPNHead',
'OrientedRPNHead', 'RotatedRetinaRefineHead', 'ODMRefineHead',
'KFIoURRetinaHead', 'KFIoURRetinaRefineHead', 'KFIoUODMRefineHead',
- 'RotatedRepPointsHead', 'SAMRepPointsHead'
+ 'RotatedRepPointsHead', 'SAMRepPointsHead', 'CSLRRetinaHead'
]
diff --git a/mmrotate/models/dense_heads/csl_rotated_retina_head.py b/mmrotate/models/dense_heads/csl_rotated_retina_head.py
new file mode 100644
index 000000000..2cc19fc88
--- /dev/null
+++ b/mmrotate/models/dense_heads/csl_rotated_retina_head.py
@@ -0,0 +1,579 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import torch
+import torch.nn as nn
+from mmcv.runner import force_fp32
+from mmdet.core import images_to_levels, multi_apply, unmap
+
+from mmrotate.core import build_bbox_coder, multiclass_nms_rotated
+from ... import obb2hbb, rotated_anchor_inside_flags
+from ..builder import ROTATED_HEADS, build_loss
+from .rotated_retina_head import RotatedRetinaHead
+
+
+@ROTATED_HEADS.register_module()
+class CSLRRetinaHead(RotatedRetinaHead):
+ """Rotational Anchor-based refine head.
+
+ Args:
+ use_encoded_angle (bool): Decide whether to use encoded angle or
+ gt angle as target. Default: True.
+ shield_reg_angle (bool): Decide whether to shield the angle loss from
+ reg branch. Default: False.
+ angle_coder (dict): Config of angle coder.
+ loss_angle (dict): Config of angle classification loss.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """ # noqa: W605
+
+ def __init__(self,
+ use_encoded_angle=True,
+ shield_reg_angle=False,
+ angle_coder=dict(
+ type='CSLCoder',
+ angle_version='le90',
+ omega=1,
+ window='gaussian',
+ radius=6),
+ loss_angle=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.0),
+ init_cfg=dict(
+ type='Normal',
+ layer='Conv2d',
+ std=0.01,
+ override=[
+ dict(
+ type='Normal',
+ name='retina_cls',
+ std=0.01,
+ bias_prob=0.01),
+ dict(
+ type='Normal',
+ name='retina_angle_cls',
+ std=0.01,
+ bias_prob=0.01),
+ ]),
+ **kwargs):
+ self.angle_coder = build_bbox_coder(angle_coder)
+ self.coding_len = self.angle_coder.coding_len
+ super(CSLRRetinaHead, self).__init__(**kwargs, init_cfg=init_cfg)
+ self.shield_reg_angle = shield_reg_angle
+ self.loss_angle = build_loss(loss_angle)
+ self.use_encoded_angle = use_encoded_angle
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ super(CSLRRetinaHead, self)._init_layers()
+ self.retina_angle_cls = nn.Conv2d(
+ self.feat_channels,
+ self.num_anchors * self.coding_len,
+ 3,
+ padding=1)
+
+ def forward_single(self, x):
+ """Forward feature of a single scale level.
+
+ Args:
+ x (torch.Tensor): Features of a single scale level.
+
+ Returns:
+ tuple:
+ cls_score (torch.Tensor): Cls scores for a single scale level
+ the channels number is num_anchors * num_classes.
+ bbox_pred (torch.Tensor): Box energies / deltas for a single
+ scale level, the channels number is num_anchors * 5.
+ angle_cls (torch.Tensor): Angle for a single scale level
+ the channels number is num_anchors * coding_len.
+ """
+ cls_feat = x
+ reg_feat = x
+ for cls_conv in self.cls_convs:
+ cls_feat = cls_conv(cls_feat)
+ for reg_conv in self.reg_convs:
+ reg_feat = reg_conv(reg_feat)
+ cls_score = self.retina_cls(cls_feat)
+ bbox_pred = self.retina_reg(reg_feat)
+ angle_cls = self.retina_angle_cls(reg_feat)
+ return cls_score, bbox_pred, angle_cls
+
+ def loss_single(self, cls_score, bbox_pred, angle_cls, anchors, labels,
+ label_weights, bbox_targets, bbox_weights, angle_targets,
+ angle_weights, num_total_samples):
+ """Compute loss of a single scale level.
+
+ Args:
+ cls_score (torch.Tensor): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W).
+ bbox_pred (torch.Tensor): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 5, H, W).
+ anchors (torch.Tensor): Box reference for each scale level with
+ shape (N, num_total_anchors, 5).
+ labels (torch.Tensor): Labels of each anchors with shape
+ (N, num_total_anchors).
+ label_weights (torch.Tensor): Label weights of each anchor with
+ shape (N, num_total_anchors)
+ bbox_targets (torch.Tensor): BBox regression targets of each anchor
+ weight shape (N, num_total_anchors, 5).
+ bbox_weights (torch.Tensor): BBox regression loss weights of each
+ anchor with shape (N, num_total_anchors, 5).
+ angle_targets (torch.Tensor): Angle classification targets of
+ each anchor weight shape (N, num_total_anchors, coding_len).
+ angle_weights (torch.Tensor): Angle classification loss weights
+ of each anchor with shape (N, num_total_anchors, 1).
+ num_total_samples (int): If sampling, num total samples equal to
+ the number of total anchors; Otherwise, it is the number of
+ positive anchors.
+
+ Returns:
+ loss_cls (torch.Tensor): cls. loss for each scale level.
+ loss_bbox (torch.Tensor): reg. loss for each scale level.
+ loss_angle (torch.Tensor): angle cls. loss for each scale level.
+ """
+ # Classification loss
+ labels = labels.reshape(-1)
+ label_weights = label_weights.reshape(-1)
+ cls_score = cls_score.permute(0, 2, 3,
+ 1).reshape(-1, self.cls_out_channels)
+ loss_cls = self.loss_cls(
+ cls_score, labels, label_weights, avg_factor=num_total_samples)
+ # Regression loss
+ bbox_targets = bbox_targets.reshape(-1, 5)
+ bbox_weights = bbox_weights.reshape(-1, 5)
+ # Shield angle in reg. branch
+ if self.shield_reg_angle:
+ bbox_weights[:, -1] = 0.
+ bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 5)
+ if self.reg_decoded_bbox:
+ anchors = anchors.reshape(-1, 5)
+ bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
+
+ loss_bbox = self.loss_bbox(
+ bbox_pred,
+ bbox_targets,
+ bbox_weights,
+ avg_factor=num_total_samples)
+
+ angle_cls = angle_cls.permute(0, 2, 3, 1).reshape(-1, self.coding_len)
+ angle_targets = angle_targets.reshape(-1, self.coding_len)
+ angle_weights = angle_weights.reshape(-1, 1)
+
+ loss_angle = self.loss_angle(
+ angle_cls,
+ angle_targets,
+ weight=angle_weights,
+ avg_factor=num_total_samples)
+
+ return loss_cls, loss_bbox, loss_angle
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'angle_clses'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ angle_clses,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 5, H, W)
+ angle_clses (list[Tensor]): Box angles for each scale
+ level with shape (N, num_anchors * coding_len, H, W)
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 5) in [cx, cy, w, h, a] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss. Default: None
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.anchor_generator.num_levels
+
+ device = cls_scores[0].device
+
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels)
+ if cls_reg_targets is None:
+ return None
+ (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+ num_total_pos, num_total_neg, angle_target_list,
+ angle_weight_list) = cls_reg_targets
+ num_total_samples = (
+ num_total_pos + num_total_neg if self.sampling else num_total_pos)
+
+ # Anchor number of multi levels
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+ # Concat all level anchors and flags to a single tensor
+ concat_anchor_list = []
+ for i, _ in enumerate(anchor_list):
+ concat_anchor_list.append(torch.cat(anchor_list[i]))
+ all_anchor_list = images_to_levels(concat_anchor_list,
+ num_level_anchors)
+
+ losses_cls, losses_bbox, losses_angle = multi_apply(
+ self.loss_single,
+ cls_scores,
+ bbox_preds,
+ angle_clses,
+ all_anchor_list,
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ bbox_weights_list,
+ angle_target_list,
+ angle_weight_list,
+ num_total_samples=num_total_samples)
+ return dict(
+ loss_cls=losses_cls,
+ loss_bbox=losses_bbox,
+ loss_angle=losses_angle)
+
+ def _get_targets_single(self,
+ flat_anchors,
+ valid_flags,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ img_meta,
+ label_channels=1,
+ unmap_outputs=True):
+ """Compute regression and classification targets for anchors in a
+ single image.
+
+ Args:
+ flat_anchors (torch.Tensor): Multi-level anchors of the image,
+ which are concatenated into a single tensor of shape
+ (num_anchors, 5)
+ valid_flags (torch.Tensor): Multi level valid flags of the image,
+ which are concatenated into a single tensor of
+ shape (num_anchors,).
+ gt_bboxes (torch.Tensor): Ground truth bboxes of the image,
+ shape (num_gts, 5).
+ gt_bboxes_ignore (torch.Tensor): Ground truth bboxes to be
+ ignored, shape (num_ignored_gts, 5).
+ img_meta (dict): Meta info of the image.
+ gt_labels (torch.Tensor): Ground truth labels of each box,
+ shape (num_gts,).
+ label_channels (int): Channel of label. Default: 1.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors. Default: True.
+
+ Returns:
+ tuple:
+ labels (Tensor): Labels of all anchors in the image
+ label_weights (Tensor): Label weights of all anchors
+ bbox_targets (Tensor): BBox targets of all anchors
+ bbox_weights (Tensor): BBox weights of all anchors
+ pos_inds (Tensor): Indices of positive anchors
+ neg_inds (Tensor): Indices of negative anchors
+ sampling_result (SamplingResult): Sampling result of the image
+ angle_targets (Tensor): Angle targets of all anchors
+ angle_weights (Tensor): Angle weights of all anchors
+ """
+ inside_flags = rotated_anchor_inside_flags(
+ flat_anchors, valid_flags, img_meta['img_shape'][:2],
+ self.train_cfg.allowed_border)
+ if not inside_flags.any():
+ return (None, ) * 9
+ # Assign gt and sample anchors
+ anchors = flat_anchors[inside_flags, :]
+
+ if self.assign_by_circumhbbox is not None:
+ gt_bboxes_assign = obb2hbb(gt_bboxes, self.assign_by_circumhbbox)
+ assign_result = self.assigner.assign(
+ anchors, gt_bboxes_assign, gt_bboxes_ignore,
+ None if self.sampling else gt_labels)
+ else:
+ assign_result = self.assigner.assign(
+ anchors, gt_bboxes, gt_bboxes_ignore,
+ None if self.sampling else gt_labels)
+
+ sampling_result = self.sampler.sample(assign_result, anchors,
+ gt_bboxes)
+
+ num_valid_anchors = anchors.shape[0]
+ bbox_targets = torch.zeros_like(anchors)
+ bbox_weights = torch.zeros_like(anchors)
+ angle_targets = torch.zeros_like(bbox_targets[:, 4:5])
+ angle_weights = torch.zeros_like(bbox_targets[:, 4:5])
+ labels = anchors.new_full((num_valid_anchors, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
+
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+ if len(pos_inds) > 0:
+ if not self.reg_decoded_bbox:
+ pos_bbox_targets = self.bbox_coder.encode(
+ sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
+ else:
+ pos_bbox_targets = sampling_result.pos_gt_bboxes
+ bbox_targets[pos_inds, :] = pos_bbox_targets
+ bbox_weights[pos_inds, :] = 1.0
+
+ if self.use_encoded_angle:
+ # Get encoded angle as target
+ angle_targets[pos_inds, :] = pos_bbox_targets[:, 4:5]
+ else:
+ # Get gt angle as target
+ angle_targets[pos_inds, :] = \
+ sampling_result.pos_gt_bboxes[:, 4:5]
+ # Angle encoder
+ angle_targets = self.angle_coder.encode(angle_targets)
+ angle_weights[pos_inds, :] = 1.0
+
+ if gt_labels is None:
+ # Only rpn gives gt_labels as None
+ # Foreground is the first class since v2.5.0
+ labels[pos_inds] = 0
+ else:
+ labels[pos_inds] = gt_labels[
+ sampling_result.pos_assigned_gt_inds]
+ if self.train_cfg.pos_weight <= 0:
+ label_weights[pos_inds] = 1.0
+ else:
+ label_weights[pos_inds] = self.train_cfg.pos_weight
+ if len(neg_inds) > 0:
+ label_weights[neg_inds] = 1.0
+
+ # Map up to original set of anchors
+ if unmap_outputs:
+ num_total_anchors = flat_anchors.size(0)
+ labels = unmap(
+ labels, num_total_anchors, inside_flags,
+ fill=self.num_classes) # fill bg label
+ label_weights = unmap(label_weights, num_total_anchors,
+ inside_flags)
+ bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
+ bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
+ angle_targets = unmap(angle_targets, num_total_anchors,
+ inside_flags)
+ angle_weights = unmap(angle_weights, num_total_anchors,
+ inside_flags)
+
+ return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
+ neg_inds, sampling_result, angle_targets, angle_weights)
+
+ def _get_bboxes_single(self,
+ cls_score_list,
+ bbox_pred_list,
+ angle_cls_list,
+ mlvl_anchors,
+ img_shape,
+ scale_factor,
+ cfg,
+ rescale=False,
+ with_nms=True):
+ """Transform outputs for a single batch item into bbox predictions.
+
+ Args:
+ cls_score_list (list[Tensor]): Box scores for a single scale level.
+ Has shape (num_anchors * num_classes, H, W).
+ bbox_pred_list (list[Tensor]): Box energies / deltas for a single
+ scale level with shape (num_anchors * 5, H, W).
+ angle_cls_list (list[Tensor]): Angle classification scores for a
+ single scale level with shape (num_anchors * coding_len, H, W).
+ mlvl_anchors (list[Tensor]): Box reference for a single scale level
+ with shape (num_total_anchors, 5).
+ img_shape (tuple[int]): Shape of the input image,
+ (height, width, 3).
+ scale_factor (ndarray): Scale factor of the image, arranged as
+ (w_scale, h_scale, w_scale, h_scale).
+ cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+
+ Returns:
+ Tensor: Labeled boxes in shape (n, 6), where the first 5 columns
+ are bounding box positions (cx, cy, w, h, a) and the
+ 6-th column is a score between 0 and 1.
+ """
+ cfg = self.test_cfg if cfg is None else cfg
+ assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)
+ mlvl_bboxes = []
+ mlvl_scores = []
+ for cls_score, bbox_pred, angle_cls, anchors in zip(
+ cls_score_list, bbox_pred_list, angle_cls_list, mlvl_anchors):
+ assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+ cls_score = cls_score.permute(1, 2,
+ 0).reshape(-1, self.cls_out_channels)
+ if self.use_sigmoid_cls:
+ scores = cls_score.sigmoid()
+ else:
+ scores = cls_score.softmax(-1)
+ bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 5)
+
+ angle_cls = angle_cls.permute(1, 2, 0).reshape(
+ -1, self.coding_len).sigmoid()
+
+ nms_pre = cfg.get('nms_pre', -1)
+ if scores.shape[0] > nms_pre > 0:
+ # Get maximum scores for foreground classes.
+ if self.use_sigmoid_cls:
+ max_scores, _ = scores.max(dim=1)
+ else:
+ # Remind that we set FG labels to [0, num_class-1]
+ # since mmdet v2.0
+ # BG cat_id: num_class
+ max_scores, _ = scores[:, :-1].max(dim=1)
+ _, topk_inds = max_scores.topk(nms_pre)
+ anchors = anchors[topk_inds, :]
+ bbox_pred = bbox_pred[topk_inds, :]
+ scores = scores[topk_inds, :]
+ angle_cls = angle_cls[topk_inds, :]
+
+ # Angle decoder
+ angle_pred = self.angle_coder.decode(angle_cls)
+
+ if self.use_encoded_angle:
+ bbox_pred[..., -1] = angle_pred
+ bboxes = self.bbox_coder.decode(
+ anchors, bbox_pred, max_shape=img_shape)
+ else:
+ bboxes = self.bbox_coder.decode(
+ anchors, bbox_pred, max_shape=img_shape)
+ bboxes[..., -1] = angle_pred
+
+ mlvl_bboxes.append(bboxes)
+ mlvl_scores.append(scores)
+ mlvl_bboxes = torch.cat(mlvl_bboxes)
+ if rescale:
+ # Angle should not be rescaled
+ mlvl_bboxes[:, :4] = mlvl_bboxes[:, :4] / mlvl_bboxes.new_tensor(
+ scale_factor)
+ mlvl_scores = torch.cat(mlvl_scores)
+ if self.use_sigmoid_cls:
+ # Add a dummy background class to the backend when using sigmoid
+ # Remind that we set FG labels to [0, num_class-1] since mmdet v2.0
+ # BG cat_id: num_class
+ padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
+ mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
+
+ if with_nms:
+ det_bboxes, det_labels = multiclass_nms_rotated(
+ mlvl_bboxes, mlvl_scores, cfg.score_thr, cfg.nms,
+ cfg.max_per_img)
+ return det_bboxes, det_labels
+ else:
+ return mlvl_bboxes, mlvl_scores
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'angle_clses'))
+ def get_bboxes(self,
+ cls_scores,
+ bbox_preds,
+ angle_clses,
+ img_metas,
+ cfg=None,
+ rescale=False,
+ with_nms=True):
+ """Transform network output for a batch into bbox predictions.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level.
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 5, H, W)
+ angle_clses (list[Tensor]): Angle classification scores for each
+ scale level with shape (N, num_anchors * coding_len, H, W)
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ cfg (mmcv.Config | None): Test / postprocessing configuration,
+ if None, test_cfg would be used
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+
+ Returns:
+ list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
+ The first item is an (n, 6) tensor, where the first 5 columns
+ are bounding box positions (cx, cy, w, h, a) and the
+ 6-th column is a score between 0 and 1. The second item is a
+ (n,) tensor where each item is the predicted class label of the
+ corresponding box.
+
+ Example:
+ >>> import mmcv
+ >>> self = AnchorHead(
+ >>> num_classes=9,
+ >>> in_channels=1,
+ >>> anchor_generator=dict(
+ >>> type='AnchorGenerator',
+ >>> scales=[8],
+ >>> ratios=[0.5, 1.0, 2.0],
+ >>> strides=[4,]))
+ >>> img_metas = [{'img_shape': (32, 32, 3), 'scale_factor': 1}]
+ >>> cfg = mmcv.Config(dict(
+ >>> score_thr=0.00,
+ >>> nms=dict(type='nms', iou_thr=1.0),
+ >>> max_per_img=10))
+ >>> feat = torch.rand(1, 1, 3, 3)
+ >>> cls_score, bbox_pred, angle_cls = self.forward_single(feat)
+ >>> # Note the input lists are over different levels, not images
+ >>> cls_scores, bbox_preds = [cls_score], [bbox_pred]
+ >>> angle_clses = [angle_cls]
+ >>> result_list = self.get_bboxes(cls_scores, bbox_preds,
+ >>> angle_clses, img_metas, cfg)
+ >>> det_bboxes, det_labels = result_list[0]
+ >>> assert len(result_list) == 1
+ >>> assert det_bboxes.shape[1] == 6
+ >>> assert len(det_bboxes) == len(det_labels) == cfg.max_per_img
+ """
+ assert len(cls_scores) == len(bbox_preds)
+ num_levels = len(cls_scores)
+
+ device = cls_scores[0].device
+ featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
+ mlvl_anchors = self.anchor_generator.grid_priors(
+ featmap_sizes, device=device)
+
+ result_list = []
+ for img_id, _ in enumerate(img_metas):
+ cls_score_list = [
+ cls_scores[i][img_id].detach() for i in range(num_levels)
+ ]
+ bbox_pred_list = [
+ bbox_preds[i][img_id].detach() for i in range(num_levels)
+ ]
+ angle_cls_list = [
+ angle_clses[i][img_id].detach() for i in range(num_levels)
+ ]
+ img_shape = img_metas[img_id]['img_shape']
+ scale_factor = img_metas[img_id]['scale_factor']
+ if with_nms:
+ # Some heads don't support with_nms argument
+ proposals = self._get_bboxes_single(cls_score_list,
+ bbox_pred_list,
+ angle_cls_list,
+ mlvl_anchors, img_shape,
+ scale_factor, cfg, rescale)
+ else:
+ proposals = self._get_bboxes_single(cls_score_list,
+ bbox_pred_list,
+ angle_cls_list,
+ mlvl_anchors, img_shape,
+ scale_factor, cfg, rescale,
+ with_nms)
+ result_list.append(proposals)
+ return result_list
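Editor's note: the head above delegates angle handling to `self.angle_coder.encode` / `.decode`, whose implementation is outside this diff. As a rough sketch of what circular smooth label coding does per the CSL paper (the Gaussian window and its radius here are illustrative assumptions, not the repo's actual coder or defaults): the angle is discretized into `coding_len` bins, the one-hot bin label is smoothed by a window function wrapped circularly so that the first and last bins are neighbours, and decoding takes the highest-scoring bin.

```python
import numpy as np


def csl_encode(angle_bin, coding_len=180, radius=6.0):
    """Sketch of CSL encoding: a circular Gaussian window centred on the
    target angle bin, wrapped so bin 0 and bin coding_len-1 are adjacent."""
    bins = np.arange(coding_len)
    d = np.abs(bins - angle_bin)
    d = np.minimum(d, coding_len - d)  # circular distance to target bin
    return np.exp(-(d ** 2) / (2 * radius ** 2))


def csl_decode(label):
    """Sketch of CSL decoding: pick the highest-scoring bin."""
    return int(np.argmax(label))


enc = csl_encode(45)
assert csl_decode(enc) == 45          # round-trip recovers the bin
enc0 = csl_encode(0)
assert np.isclose(enc0[1], enc0[179])  # boundary bins are treated as neighbours
```

The circular wrap is the point of CSL: a prediction one bin past the angular boundary is penalized as lightly as one bin short of it, which removes the boundary discontinuity of direct angle regression.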
diff --git a/mmrotate/models/dense_heads/kfiou_rotate_retina_head.py b/mmrotate/models/dense_heads/kfiou_rotate_retina_head.py
index 19c56f512..a1448b42c 100644
--- a/mmrotate/models/dense_heads/kfiou_rotate_retina_head.py
+++ b/mmrotate/models/dense_heads/kfiou_rotate_retina_head.py
@@ -83,7 +83,8 @@ def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights,
positive anchors.
Returns:
- dict[str, Tensor]: A dictionary of loss components.
+ loss_cls (torch.Tensor): cls. loss for each scale level.
+ loss_bbox (torch.Tensor): reg. loss for each scale level.
"""
# classification loss
labels = labels.reshape(-1)
diff --git a/mmrotate/models/dense_heads/oriented_rpn_head.py b/mmrotate/models/dense_heads/oriented_rpn_head.py
index ca35c0e31..823a8ed6c 100644
--- a/mmrotate/models/dense_heads/oriented_rpn_head.py
+++ b/mmrotate/models/dense_heads/oriented_rpn_head.py
@@ -155,7 +155,8 @@ def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights,
positive anchors.
Returns:
- dict[str, Tensor]: A dictionary of loss components.
+ loss_cls (torch.Tensor): cls. loss for each scale level.
+ loss_bbox (torch.Tensor): reg. loss for each scale level.
"""
# classification loss
labels = labels.reshape(-1)
diff --git a/mmrotate/models/dense_heads/rotated_anchor_head.py b/mmrotate/models/dense_heads/rotated_anchor_head.py
index 3dad51205..0b7f46bd0 100644
--- a/mmrotate/models/dense_heads/rotated_anchor_head.py
+++ b/mmrotate/models/dense_heads/rotated_anchor_head.py
@@ -187,7 +187,7 @@ def _get_targets_single(self,
Args:
flat_anchors (torch.Tensor): Multi-level anchors of the image,
which are concatenated into a single tensor of shape
- (num_anchors ,4)
+ (num_anchors, 5)
valid_flags (torch.Tensor): Multi level valid flags of the image,
which are concatenated into a single tensor of
shape (num_anchors,).
@@ -296,7 +296,7 @@ def get_targets(self,
anchor_list (list[list[Tensor]]): Multi level anchors of each
image. The outer list indicates images, and the inner list
corresponds to feature levels of the image. Each element of
- the inner list is a tensor of shape (num_anchors, 4).
+ the inner list is a tensor of shape (num_anchors, 5).
valid_flag_list (list[list[Tensor]]): Multi level valid flags of
each image. The outer list indicates images, and the inner list
corresponds to feature levels of the image. Each element of
@@ -405,7 +405,8 @@ def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights,
positive anchors.
Returns:
- dict[str, Tensor]: A dictionary of loss components.
+ loss_cls (torch.Tensor): cls. loss for each scale level.
+ loss_bbox (torch.Tensor): reg. loss for each scale level.
"""
# classification loss
labels = labels.reshape(-1)
diff --git a/mmrotate/models/dense_heads/rotated_rpn_head.py b/mmrotate/models/dense_heads/rotated_rpn_head.py
index 632acbb1e..eb6606867 100644
--- a/mmrotate/models/dense_heads/rotated_rpn_head.py
+++ b/mmrotate/models/dense_heads/rotated_rpn_head.py
@@ -274,7 +274,8 @@ def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights,
positive anchors.
Returns:
- dict[str, Tensor]: A dictionary of loss components.
+ loss_cls (torch.Tensor): cls. loss for each scale level.
+ loss_bbox (torch.Tensor): reg. loss for each scale level.
"""
# classification loss
labels = labels.reshape(-1)
diff --git a/mmrotate/models/losses/__init__.py b/mmrotate/models/losses/__init__.py
index b7050fbde..8594630e3 100644
--- a/mmrotate/models/losses/__init__.py
+++ b/mmrotate/models/losses/__init__.py
@@ -4,8 +4,9 @@
from .gaussian_dist_loss_v1 import GDLoss_v1
from .kf_iou_loss import KFLoss
from .kld_reppoints_loss import KLDRepPointsLoss
+from .smooth_focal_loss import SmoothFocalLoss
__all__ = [
'GDLoss', 'GDLoss_v1', 'KFLoss', 'ConvexGIoULoss', 'BCConvexGIoULoss',
- 'KLDRepPointsLoss'
+ 'KLDRepPointsLoss', 'SmoothFocalLoss'
]
diff --git a/mmrotate/models/losses/gaussian_dist_loss.py b/mmrotate/models/losses/gaussian_dist_loss.py
index 1976f0c46..638d1a870 100644
--- a/mmrotate/models/losses/gaussian_dist_loss.py
+++ b/mmrotate/models/losses/gaussian_dist_loss.py
@@ -91,6 +91,31 @@ def postprocess(distance, fun='log1p', tau=1.0):
@weighted_loss
def gwd_loss(pred, target, fun='log1p', tau=1.0, alpha=1.0, normalize=True):
"""Gaussian Wasserstein distance loss.
+ Derivation and simplification:
+ Given any positive-definite symmetric 2*2 matrix Z:
+ :math:`Tr(Z^{1/2}) = λ_1^{1/2} + λ_2^{1/2}`
+ where :math:`λ_1` and :math:`λ_2` are the eigenvalues of Z.
+ Meanwhile we have:
+ :math:`Tr(Z) = λ_1 + λ_2`
+
+ :math:`det(Z) = λ_1 * λ_2`
+ Combining these with the identity:
+ :math:`(λ_1^{1/2}+λ_2^{1/2})^2 = λ_1+λ_2+2 *(λ_1 * λ_2)^{1/2}`
+ yields:
+ :math:`Tr(Z^{1/2}) = (Tr(Z) + 2 * (det(Z))^{1/2})^{1/2}`
+ For the GWD loss, the troublesome coupling term is:
+ :math:`Tr((Σ_p^{1/2} * Σ_t * Σ_p^{1/2})^{1/2})`
+ Assuming :math:`Z = Σ_p^{1/2} * Σ_t * Σ_p^{1/2}`, then:
+ :math:`Tr(Z) = Tr(Σ_p^{1/2} * Σ_t * Σ_p^{1/2})
+ = Tr(Σ_p^{1/2} * Σ_p^{1/2} * Σ_t)
+ = Tr(Σ_p * Σ_t)`
+ :math:`det(Z) = det(Σ_p^{1/2} * Σ_t * Σ_p^{1/2})
+ = det(Σ_p^{1/2}) * det(Σ_t) * det(Σ_p^{1/2})
+ = det(Σ_p * Σ_t)`
+ and thus the coupling term can be rewritten as:
+ :math:`Tr((Σ_p^{1/2} * Σ_t * Σ_p^{1/2})^{1/2})
+ = (Tr(Σ_p * Σ_t) + 2 * (det(Σ_p * Σ_t))^{1/2})^{1/2}`
Args:
pred (torch.Tensor): Predicted bboxes.
@@ -102,6 +127,7 @@ def gwd_loss(pred, target, fun='log1p', tau=1.0, alpha=1.0, normalize=True):
Returns:
loss (torch.Tensor)
+
"""
xy_p, Sigma_p = pred
xy_t, Sigma_t = target
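Editor's note: the trace identity derived in the docstring above is easy to spot-check numerically. For a random symmetric positive-definite 2x2 matrix, `Tr(Z^{1/2})` computed from the eigenvalues matches `(Tr(Z) + 2 * det(Z)^{1/2})^{1/2}`:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 2))
Z = A @ A.T + 2 * np.eye(2)   # random symmetric positive-definite 2x2 matrix

lam = np.linalg.eigvalsh(Z)   # eigenvalues of Z
lhs = np.sqrt(lam).sum()      # Tr(Z^{1/2}) = λ_1^{1/2} + λ_2^{1/2}
rhs = np.sqrt(np.trace(Z) + 2 * np.sqrt(np.linalg.det(Z)))
assert np.isclose(lhs, rhs)
```

This is what lets `gwd_loss` avoid an explicit matrix square root of the coupling term: only `Tr(Σ_p Σ_t)` and `det(Σ_p Σ_t)` are needed.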
diff --git a/mmrotate/models/losses/smooth_focal_loss.py b/mmrotate/models/losses/smooth_focal_loss.py
new file mode 100644
index 000000000..f05c9a42c
--- /dev/null
+++ b/mmrotate/models/losses/smooth_focal_loss.py
@@ -0,0 +1,129 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+from mmdet.models import weight_reduce_loss
+
+from ..builder import ROTATED_LOSSES
+
+
+def smooth_focal_loss(pred,
+ target,
+ weight=None,
+ gamma=2.0,
+ alpha=0.25,
+ reduction='mean',
+ avg_factor=None):
+ """Smooth Focal Loss proposed in Circular Smooth Label (CSL).
+
+ `Circular Smooth Label (CSL)
+ <https://link.springer.com/chapter/10.1007/978-3-030-58598-3_40>`_ .
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ gamma (float, optional): The gamma for calculating the modulating
+ factor. Defaults to 2.0.
+ alpha (float, optional): A balanced form for Focal Loss.
+ Defaults to 0.25.
+ reduction (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Options are "none", "mean" and "sum".
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+
+ pred_sigmoid = pred.sigmoid()
+ target = target.type_as(pred)
+ pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
+ focal_weight = (alpha * target + (1 - alpha) *
+ (1 - target)) * pt.pow(gamma)
+ loss = F.binary_cross_entropy_with_logits(
+ pred, target, reduction='none') * focal_weight
+ if weight is not None:
+ if weight.shape != loss.shape:
+ if weight.size(0) == loss.size(0):
+ # For most cases, weight is of shape (num_priors, ),
+ # which means it does not have the second axis num_class
+ weight = weight.view(-1, 1)
+ else:
+ # Sometimes, weight per anchor per class is also needed. e.g.
+ # in FSAF. But it may be flattened of shape
+ # (num_priors x num_class, ), while loss is still of shape
+ # (num_priors, num_class).
+ assert weight.numel() == loss.numel()
+ weight = weight.view(loss.size(0), -1)
+ assert weight.ndim == loss.ndim
+ loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+ return loss
+
+
+@ROTATED_LOSSES.register_module()
+class SmoothFocalLoss(nn.Module):
+
+ def __init__(self,
+ gamma=2.0,
+ alpha=0.25,
+ reduction='mean',
+ loss_weight=1.0):
+ """Smooth Focal Loss.
+
+ Args:
+ gamma (float, optional): The gamma for calculating the modulating
+ factor. Defaults to 2.0.
+ alpha (float, optional): A balanced form for Focal Loss.
+ Defaults to 0.25.
+ reduction (str, optional): The method used to reduce the loss into
+ a scalar. Defaults to 'mean'. Options are "none", "mean" and
+ "sum".
+ loss_weight (float, optional): Weight of loss. Defaults to 1.0.
+ """
+ super(SmoothFocalLoss, self).__init__()
+ self.gamma = gamma
+ self.alpha = alpha
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None):
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Options are "none", "mean" and "sum".
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+
+ loss_cls = self.loss_weight * smooth_focal_loss(
+ pred,
+ target,
+ weight,
+ gamma=self.gamma,
+ alpha=self.alpha,
+ reduction=reduction,
+ avg_factor=avg_factor)
+
+ return loss_cls
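Editor's note: the element-wise formula in `smooth_focal_loss` above is simple enough to transcribe into NumPy for a sanity check. This is a standalone sketch of the math only (weighting and reduction omitted), not the module's API; it uses the numerically stable closed form of BCE-with-logits:

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def smooth_focal_loss_np(pred, target, gamma=2.0, alpha=0.25):
    """Element-wise smooth focal loss: focal modulation applied to
    BCE-with-logits against soft (circularly smoothed) labels."""
    p = sigmoid(pred)
    pt = (1 - p) * target + p * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt ** gamma
    # stable BCE with logits: max(x, 0) - x*t + log(1 + exp(-|x|))
    bce = np.maximum(pred, 0) - pred * target + np.log1p(np.exp(-np.abs(pred)))
    return focal_weight * bce


# At pred=0, target=1: p=0.5, pt=0.5, weight=0.25*0.25, bce=log(2)
out = smooth_focal_loss_np(np.array([0.0]), np.array([1.0]))
assert np.isclose(out[0], 0.0625 * np.log(2.0))
```

Unlike standard focal loss, the targets here are continuous in [0, 1] (the smoothed angle labels), so `pt` and the alpha term interpolate between the positive and negative branches rather than switching on a hard 0/1 label.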