diff --git a/configs/_base_/datasets/dota_coco.py b/configs/_base_/datasets/dota_coco.py index f7c66bbc9..a1f0d3888 100644 --- a/configs/_base_/datasets/dota_coco.py +++ b/configs/_base_/datasets/dota_coco.py @@ -43,7 +43,7 @@ ] metainfo = dict( - CLASSES=('plane', 'baseball-diamond', 'bridge', 'ground-track-field', + classes=('plane', 'baseball-diamond', 'bridge', 'ground-track-field', 'small-vehicle', 'large-vehicle', 'ship', 'tennis-court', 'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout', 'harbor', 'swimming-pool', 'helicopter')) diff --git a/configs/_base_/datasets/hrsid.py b/configs/_base_/datasets/hrsid.py index 61c739dcd..3b21f7988 100644 --- a/configs/_base_/datasets/hrsid.py +++ b/configs/_base_/datasets/hrsid.py @@ -42,7 +42,7 @@ 'scale_factor')) ] -metainfo = dict(CLASSES=('ship', )) +metainfo = dict(classes=('ship', )) train_dataloader = dict( batch_size=2, diff --git a/configs/_base_/datasets/rsdd.py b/configs/_base_/datasets/rsdd.py index a79c48079..3fc1c6a3c 100644 --- a/configs/_base_/datasets/rsdd.py +++ b/configs/_base_/datasets/rsdd.py @@ -42,7 +42,7 @@ 'scale_factor')) ] -metainfo = dict(CLASSES=('ship', )) +metainfo = dict(classes=('ship', )) train_dataloader = dict( batch_size=2, diff --git a/configs/_base_/datasets/srsdd.py b/configs/_base_/datasets/srsdd.py index df0f03026..52cdab65b 100644 --- a/configs/_base_/datasets/srsdd.py +++ b/configs/_base_/datasets/srsdd.py @@ -43,7 +43,7 @@ ] metainfo = dict( - CLASSES=('Container', 'Dredger', 'LawEnforce', 'Cell-Container', 'ore-oil', + classes=('Container', 'Dredger', 'LawEnforce', 'Cell-Container', 'ore-oil', 'Fishing')) train_dataloader = dict( diff --git a/configs/_base_/datasets/ssdd.py b/configs/_base_/datasets/ssdd.py index cc618b5a2..c1fb5cd57 100644 --- a/configs/_base_/datasets/ssdd.py +++ b/configs/_base_/datasets/ssdd.py @@ -42,7 +42,7 @@ 'scale_factor')) ] -metainfo = dict(CLASSES=('ship', )) +metainfo = dict(classes=('ship', )) train_dataloader = dict( batch_size=2, diff --git a/mmrotate/datasets/dior.py b/mmrotate/datasets/dior.py index 77904d037..c87957727 100755 --- a/mmrotate/datasets/dior.py +++ b/mmrotate/datasets/dior.py @@ -26,14 +26,14 @@ class DIORDataset(BaseDataset): """ METAINFO = { - 'CLASSES': + 'classes': ('airplane', 'airport', 'baseballfield', 'basketballcourt', 'bridge', 'chimney', 'expressway-service-area', 'expressway-toll-station', 'dam', 'golffield', 'groundtrackfield', 'harbor', 'overpass', 'ship', 'stadium', 'storagetank', 'tenniscourt', 'trainstation', 'vehicle', 'windmill'), - # PALETTE is a list of color tuples, which is used for visualization. - 'PALETTE': [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), + # palette is a list of color tuples, which is used for visualization. + 'palette': [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228), (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30), (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30), (165, 42, 42), @@ -59,11 +59,11 @@ def load_data_list(self) -> List[dict]: Returns: list[dict]: Annotation info from XML file. """ - assert self._metainfo.get('CLASSES', None) is not None, \ - 'CLASSES in `DIORDataset` can not be None.' + assert self._metainfo.get('classes', None) is not None, \ + 'classes in `DIORDataset` can not be None.' self.cat2label = { cat: i - for i, cat in enumerate(self.metainfo['CLASSES']) + for i, cat in enumerate(self.metainfo['classes']) } data_list = [] diff --git a/mmrotate/datasets/dota.py b/mmrotate/datasets/dota.py index db96df3cf..c32e1d8e7 100644 --- a/mmrotate/datasets/dota.py +++ b/mmrotate/datasets/dota.py @@ -26,13 +26,13 @@ class DOTADataset(BaseDataset): """ METAINFO = { - 'CLASSES': + 'classes': ('plane', 'baseball-diamond', 'bridge', 'ground-track-field', 'small-vehicle', 'large-vehicle', 'ship', 'tennis-court', 'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout', 'harbor', 'swimming-pool', 'helicopter'), - # PALETTE is a list of color tuples, which is used for visualization. - 'PALETTE': [(165, 42, 42), (189, 183, 107), (0, 255, 0), (255, 0, 0), + # palette is a list of color tuples, which is used for visualization. + 'palette': [(165, 42, 42), (189, 183, 107), (0, 255, 0), (255, 0, 0), (138, 43, 226), (255, 128, 0), (255, 0, 255), (0, 255, 255), (255, 193, 193), (0, 51, 153), (255, 250, 205), (0, 139, 139), (255, 255, 0), @@ -53,7 +53,7 @@ def load_data_list(self) -> List[dict]: List[dict]: A list of annotation. """ # noqa: E501 cls_map = {c: i - for i, c in enumerate(self.metainfo['CLASSES']) + for i, c in enumerate(self.metainfo['classes']) } # in mmdet v2.0 label is 0-based data_list = [] if self.ann_file == '': @@ -153,13 +153,13 @@ class DOTAv15Dataset(DOTADataset): """ METAINFO = { - 'CLASSES': + 'classes': ('plane', 'baseball-diamond', 'bridge', 'ground-track-field', 'small-vehicle', 'large-vehicle', 'ship', 'tennis-court', 'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout', 'harbor', 'swimming-pool', 'helicopter', 'container-crane'), - # PALETTE is a list of color tuples, which is used for visualization. - 'PALETTE': [(165, 42, 42), (189, 183, 107), (0, 255, 0), (255, 0, 0), + # palette is a list of color tuples, which is used for visualization. + 'palette': [(165, 42, 42), (189, 183, 107), (0, 255, 0), (255, 0, 0), (138, 43, 226), (255, 128, 0), (255, 0, 255), (0, 255, 255), (255, 193, 193), (0, 51, 153), (255, 250, 205), (0, 139, 139), (255, 255, 0), @@ -177,14 +177,14 @@ class DOTAv2Dataset(DOTADataset): """ METAINFO = { - 'CLASSES': + 'classes': ('plane', 'baseball-diamond', 'bridge', 'ground-track-field', 'small-vehicle', 'large-vehicle', 'ship', 'tennis-court', 'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout', 'harbor', 'swimming-pool', 'helicopter', 'container-crane', 'airport', 'helipad'), - # PALETTE is a list of color tuples, which is used for visualization. - 'PALETTE': [(165, 42, 42), (189, 183, 107), (0, 255, 0), (255, 0, 0), + # palette is a list of color tuples, which is used for visualization. + 'palette': [(165, 42, 42), (189, 183, 107), (0, 255, 0), (255, 0, 0), (138, 43, 226), (255, 128, 0), (255, 0, 255), (0, 255, 255), (255, 193, 193), (0, 51, 153), (255, 250, 205), (0, 139, 139), (255, 255, 0), diff --git a/mmrotate/datasets/hrsc.py b/mmrotate/datasets/hrsc.py index 8cd1c60b6..b102a6fba 100644 --- a/mmrotate/datasets/hrsc.py +++ b/mmrotate/datasets/hrsc.py @@ -35,7 +35,7 @@ class HRSCDataset(BaseDataset): """ METAINFO = { - 'CLASSES': + 'classes': ('ship', 'aircraft carrier', 'warcraft', 'merchant ship', 'Nimitz', 'Enterprise', 'Arleigh Burke', 'WhidbeyIsland', 'Perry', 'Sanantonio', 'Ticonderoga', 'Kitty Hawk', 'Kuznetsov', 'Abukuma', 'Austen', @@ -43,8 +43,8 @@ class HRSCDataset(BaseDataset): 'Hovercraft', 'yacht', 'CntShip(_|.--.--|_]=', 'Cruise', 'submarine', 'lute', 'Medical', 'Car carrier(======|', 'Ford-class', 'Midway-class', 'Invincible-class'), - # PALETTE is a list of color tuples, which is used for visualization. - 'PALETTE': + # palette is a list of color tuples, which is used for visualization. + 'palette': [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228), (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30), (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30), @@ -53,8 +53,8 @@ class HRSCDataset(BaseDataset): (199, 100, 0), (72, 0, 118), (255, 179, 240), (0, 125, 92), (209, 0, 151), (188, 208, 182), (0, 220, 176), (255, 99, 164), (92, 0, 73)], - # CLASSES_ID is a tuple, which is used for ``self.catid2label`` - 'CLASSES_ID': + # classes_id is a tuple, which is used for ``self.catid2label`` + 'classes_id': ('01', '02', '03', '04', '05', '06', '07', '08', '09', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '22', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33') @@ -84,16 +84,16 @@ def load_data_list(self) -> List[dict]: Returns: list[dict]: Annotation info from XML file. """ - assert self._metainfo.get('CLASSES', None) is not None, \ - 'CLASSES in `HRSCDataset` can not be None.' + assert self._metainfo.get('classes', None) is not None, \ + 'classes in `HRSCDataset` can not be None.' if self.classwise: self.catid2label = { ('1' + '0' * 6 + cls_id): i - for i, cls_id in enumerate(self._metainfo['CLASSES_ID']) + for i, cls_id in enumerate(self._metainfo['classes_id']) } else: - self._metainfo['CLASSES'] = ('ship', ) - self._metainfo['PALETTE'] = [ + self._metainfo['classes'] = ('ship', ) + self._metainfo['palette'] = [ (220, 20, 60), ] diff --git a/mmrotate/visualization/palette.py b/mmrotate/visualization/palette.py index 557357270..59e5aa864 100644 --- a/mmrotate/visualization/palette.py +++ b/mmrotate/visualization/palette.py @@ -31,13 +31,13 @@ def get_palette(palette: Union[List[tuple], str, tuple], dataset_palette = [tuple(c) for c in palette] elif palette == 'dota': from mmrotate.datasets import DOTADataset - dataset_palette = DOTADataset.METAINFO['PALETTE'] + dataset_palette = DOTADataset.METAINFO['palette'] elif palette == 'sar': from mmrotate.datasets import SARDataset - dataset_palette = SARDataset.METAINFO['PALETTE'] + dataset_palette = SARDataset.METAINFO['palette'] elif palette == 'hrsc': from mmrotate.datasets import HRSCDataset - dataset_palette = HRSCDataset.METAINFO['PALETTE'] + dataset_palette = HRSCDataset.METAINFO['palette'] elif is_str(palette): dataset_palette = [mmcv.color_val(palette)[::-1]] * num_classes else: diff --git a/tests/test_datasets/test_dior.py b/tests/test_datasets/test_dior.py index dd9c043bf..2cdc48e2e 100755 --- a/tests/test_datasets/test_dior.py +++ b/tests/test_datasets/test_dior.py @@ -27,4 +27,4 @@ def test_dior(self): 'tests/data/dior/Annotations/Oriented Bounding Boxes/00001.xml') self.assertEqual(len(data_list[0]['instances']), 1) self.assertEqual(dataset.get_cat_ids(0), [9]) - self.assertEqual(len(dataset._metainfo['CLASSES']), 20) + self.assertEqual(len(dataset._metainfo['classes']), 20) diff --git a/tests/test_datasets/test_hrsc.py b/tests/test_datasets/test_hrsc.py index 43d408cbe..b88a55594 100644 --- a/tests/test_datasets/test_hrsc.py +++ b/tests/test_datasets/test_hrsc.py @@ -28,7 +28,7 @@ def test_hrsc(self): 'tests/data/hrsc/FullDataSet/Annotations/100000006.xml') self.assertEqual(len(data_list[0]['instances']), 1) self.assertEqual(dataset.get_cat_ids(0), [0]) - self.assertEqual(dataset._metainfo['CLASSES'], ('ship', )) + self.assertEqual(dataset._metainfo['classes'], ('ship', )) def test_hrsc_classwise(self): dataset = HRSCDataset( @@ -52,4 +52,4 @@ def test_hrsc_classwise(self): 'tests/data/hrsc/FullDataSet/Annotations/100000006.xml') self.assertEqual(len(data_list[0]['instances']), 1) self.assertEqual(dataset.get_cat_ids(0), [12]) - self.assertEqual(len(dataset._metainfo['CLASSES']), 31) + self.assertEqual(len(dataset._metainfo['classes']), 31) diff --git a/tests/test_visualization/test_palette.py b/tests/test_visualization/test_palette.py index cfafec149..73eb6c2f4 100644 --- a/tests/test_visualization/test_palette.py +++ b/tests/test_visualization/test_palette.py @@ -24,8 +24,8 @@ def test_palette(): assert color == (255, 0, 0) # test dataset str - palette = get_palette('dota', len(DOTADataset.METAINFO['CLASSES'])) - assert len(palette) == len(DOTADataset.METAINFO['CLASSES']) + palette = get_palette('dota', len(DOTADataset.METAINFO['classes'])) + assert len(palette) == len(DOTADataset.METAINFO['classes']) assert palette[0] == (165, 42, 42) # test random diff --git a/tools/analysis_tools/confusion_matrix.py b/tools/analysis_tools/confusion_matrix.py index e8020f2db..7c71bd906 100644 --- a/tools/analysis_tools/confusion_matrix.py +++ b/tools/analysis_tools/confusion_matrix.py @@ -77,7 +77,7 @@ def calculate_confusion_matrix(dataset, tp_iou_thr (float|optional): IoU threshold to be considered as matched. Default: 0.5. """ - num_classes = len(dataset.metainfo['CLASSES']) + num_classes = len(dataset.metainfo['classes']) confusion_matrix = np.zeros(shape=[num_classes + 1, num_classes + 1]) assert len(dataset) == len(results) prog_bar = ProgressBar(len(results)) @@ -256,7 +256,7 @@ def main(): args.tp_iou_thr) plot_confusion_matrix( confusion_matrix, - dataset.metainfo['CLASSES'] + ('background', ), + dataset.metainfo['classes'] + ('background', ), save_dir=args.save_dir, color_theme=args.color_theme, show=args.show)