From dbd4c45eae89a7a161efd3d7ea2aef05cca7c5e6 Mon Sep 17 00:00:00 2001 From: sungmanc Date: Wed, 19 Jul 2023 19:41:03 +0900 Subject: [PATCH 1/7] Fix h-labelissue --- .../adapters/mmcls/configurer.py | 5 ++++ .../adapters/mmcls/datasets/otx_datasets.py | 20 +++++++++++-- .../custom_hierarchical_linear_cls_head.py | 28 ++++++++++--------- ...custom_hierarchical_non_linear_cls_head.py | 25 +++++++++-------- .../classification/utils/cls_utils.py | 1 + 5 files changed, 51 insertions(+), 28 deletions(-) diff --git a/src/otx/algorithms/classification/adapters/mmcls/configurer.py b/src/otx/algorithms/classification/adapters/mmcls/configurer.py index d0ecbfb4e2c..87c91c97368 100644 --- a/src/otx/algorithms/classification/adapters/mmcls/configurer.py +++ b/src/otx/algorithms/classification/adapters/mmcls/configurer.py @@ -132,6 +132,11 @@ def configure_model(self, cfg, ir_options): # noqa: C901 cfg.model.arch_type = cfg.model.type cfg.model.type = super_type + # Hierarchical + if cfg.model.hierarchical: + assert cfg.data.train.hierarchical_info == cfg.data.val.hierarchical_info == cfg.data.test.hierarchical_info + cfg.model.head.hierarchical_info = cfg.data.train.hierarchical_info + # OV-plugin ir_model_path = ir_options.get("ir_model_path") if ir_model_path: diff --git a/src/otx/algorithms/classification/adapters/mmcls/datasets/otx_datasets.py b/src/otx/algorithms/classification/adapters/mmcls/datasets/otx_datasets.py index 70a4500d1b5..997d908df6d 100644 --- a/src/otx/algorithms/classification/adapters/mmcls/datasets/otx_datasets.py +++ b/src/otx/algorithms/classification/adapters/mmcls/datasets/otx_datasets.py @@ -309,9 +309,7 @@ def load_annotations(self): if item_labels: num_cls_heads = self.hierarchical_info["num_multiclass_heads"] - class_indices = [0] * ( - self.hierarchical_info["num_multiclass_heads"] + self.hierarchical_info["num_multilabel_classes"] - ) + class_indices = [0] * (num_cls_heads + self.hierarchical_info["num_multilabel_classes"]) for j in range(num_cls_heads): class_indices[j] = -1 for otx_lbl in item_labels: @@ -329,6 +327,22 @@ def load_annotations(self): self.gt_labels.append(class_indices) self.gt_labels = np.array(self.gt_labels) + self._update_heads_information(num_cls_heads) + + def _update_heads_information(self, num_cls_heads: int): + """Update heads information to find the empty heads. + + If there are no annotations at a specific head, this should be filtered out to calculate loss correctly. + + Args: + num_cls_heads (int): the number of multi-class heads. + + """ + for head_idx in range(num_cls_heads): + labels_in_head = self.gt_labels[:, head_idx] # type: ignore[call-overload] + if max(labels_in_head) < 0: + self.hierarchical_info["empty_multiclass_head_indices"].append(head_idx) + @staticmethod def mean_top_k_accuracy(scores, labels, k=1): """Return mean of top-k accuracy.""" diff --git a/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_linear_cls_head.py b/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_linear_cls_head.py index 5b3245a4f40..c2a7deaacc6 100644 --- a/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_linear_cls_head.py +++ b/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_linear_cls_head.py @@ -87,22 +87,24 @@ def forward_train(self, cls_score, gt_label, **kwargs): cls_score = self.fc(cls_score) losses = dict(loss=0.0) + num_empty_heads = len(self.hierarchical_info["empty_multiclass_head_indices"]) for i in range(self.hierarchical_info["num_multiclass_heads"]): - head_gt = gt_label[:, i] - head_logits = cls_score[ - :, - self.hierarchical_info["head_idx_to_logits_range"][str(i)][0] : self.hierarchical_info[ - "head_idx_to_logits_range" - ][str(i)][1], - ] - valid_mask = head_gt >= 0 - head_gt = head_gt[valid_mask].long() - head_logits = head_logits[valid_mask, :] - multiclass_loss = self.loss(head_logits, head_gt) - losses["loss"] += multiclass_loss + if i not in self.hierarchical_info["empty_multiclass_head_indices"]: + head_gt = gt_label[:, i] + head_logits = cls_score[ + :, + self.hierarchical_info["head_idx_to_logits_range"][str(i)][0] : self.hierarchical_info[ + "head_idx_to_logits_range" + ][str(i)][1], + ] + valid_mask = head_gt >= 0 + head_gt = head_gt[valid_mask].long() + head_logits = head_logits[valid_mask, :] + multiclass_loss = self.loss(head_logits, head_gt) + losses["loss"] += multiclass_loss if self.hierarchical_info["num_multiclass_heads"] > 1: - losses["loss"] /= self.hierarchical_info["num_multiclass_heads"] + losses["loss"] /= self.hierarchical_info["num_multiclass_heads"] - num_empty_heads if self.compute_multilabel_loss: head_gt = gt_label[:, self.hierarchical_info["num_multiclass_heads"] :] diff --git a/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_non_linear_cls_head.py b/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_non_linear_cls_head.py index 4b2691157e1..7fa529c99ba 100644 --- a/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_non_linear_cls_head.py +++ b/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_non_linear_cls_head.py @@ -118,18 +118,19 @@ def forward_train(self, cls_score, gt_label, **kwargs): losses = dict(loss=0.0) for i in range(self.hierarchical_info["num_multiclass_heads"]): - head_gt = gt_label[:, i] - head_logits = cls_score[ - :, - self.hierarchical_info["head_idx_to_logits_range"][str(i)][0] : self.hierarchical_info[ - "head_idx_to_logits_range" - ][str(i)][1], - ] - valid_mask = head_gt >= 0 - head_gt = head_gt[valid_mask].long() - head_logits = head_logits[valid_mask, :] - multiclass_loss = self.loss(head_logits, head_gt) - losses["loss"] += multiclass_loss + if i not in self.hierarchical_info["empty_multiclass_head_indices"]: + head_gt = gt_label[:, i] + head_logits = cls_score[ + :, + self.hierarchical_info["head_idx_to_logits_range"][str(i)][0] : self.hierarchical_info[ + "head_idx_to_logits_range" + ][str(i)][1], + ] + valid_mask = head_gt >= 0 + head_gt = head_gt[valid_mask].long() + head_logits = head_logits[valid_mask, :] + multiclass_loss = self.loss(head_logits, head_gt) + losses["loss"] += multiclass_loss if self.hierarchical_info["num_multiclass_heads"] > 1: losses["loss"] /= self.hierarchical_info["num_multiclass_heads"] diff --git a/src/otx/algorithms/classification/utils/cls_utils.py b/src/otx/algorithms/classification/utils/cls_utils.py index 8bb2b9630f2..23dc1ba1fa6 100644 --- a/src/otx/algorithms/classification/utils/cls_utils.py +++ b/src/otx/algorithms/classification/utils/cls_utils.py @@ -62,6 +62,7 @@ def get_multihead_class_info(label_schema: LabelSchemaEntity): # pylint: disabl "class_to_group_idx": class_to_idx, "all_groups": exclusive_groups + single_label_groups, "label_to_idx": label_to_idx, + "empty_multiclass_head_indices": [], } return mixed_cls_heads_info From 5f7e12644f7ec24963bdf0ba69bf7d8e86282cf1 Mon Sep 17 00:00:00 2001 From: sungmanc Date: Wed, 19 Jul 2023 20:28:17 +0900 Subject: [PATCH 2/7] Update unit tests --- .../adapters/mmcls/configurer.py | 2 +- .../adapters/mmcls/data/test_datasets.py | 24 ++++++++++++++++++- .../adapters/mmcls/test_configurer.py | 10 ++++++++ 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/src/otx/algorithms/classification/adapters/mmcls/configurer.py b/src/otx/algorithms/classification/adapters/mmcls/configurer.py index 87c91c97368..fe4529679a9 100644 --- a/src/otx/algorithms/classification/adapters/mmcls/configurer.py +++ b/src/otx/algorithms/classification/adapters/mmcls/configurer.py @@ -133,7 +133,7 @@ def configure_model(self, cfg, ir_options): # noqa: C901 cfg.model.type = super_type # Hierarchical - if cfg.model.hierarchical: + if cfg.model.get("hierarchical"): assert cfg.data.train.hierarchical_info == cfg.data.val.hierarchical_info == cfg.data.test.hierarchical_info cfg.model.head.hierarchical_info = cfg.data.train.hierarchical_info diff --git a/tests/unit/algorithms/classification/adapters/mmcls/data/test_datasets.py b/tests/unit/algorithms/classification/adapters/mmcls/data/test_datasets.py index 41e6890e02d..5712d79753f 100644 --- a/tests/unit/algorithms/classification/adapters/mmcls/data/test_datasets.py +++ b/tests/unit/algorithms/classification/adapters/mmcls/data/test_datasets.py @@ -142,9 +142,31 @@ def test_metric_hierarchical_adapter(self): dataset = OTXHierarchicalClsDataset( otx_dataset=self.dataset, labels=self.dataset.get_labels(), hierarchical_info=class_info ) - results = np.zeros((len(dataset), dataset.num_classes)) metrics = dataset.evaluate(results) assert len(metrics) > 0 assert metrics["accuracy"] > 0 + + @e2e_pytest_unit + def test_hierarchical_with_empty_heads(self): + self.task_environment, self.dataset = init_environment( + self.hyper_parameters, self.model_template, False, True, self.dataset_len + ) + class_info = get_multihead_class_info(self.task_environment.label_schema) + dataset = OTXHierarchicalClsDataset( + otx_dataset=self.dataset, labels=self.dataset.get_labels(), hierarchical_info=class_info + ) + pseudo_gt_labels = [] + pseudo_head_idx = 0 + for label in dataset.gt_labels: + pseudo_gt_label = label + pseudo_gt_label[pseudo_head_idx] = -1 + pseudo_gt_labels.append(pseudo_gt_label) + pseudo_gt_labels = np.array(pseudo_gt_labels) + + from copy import deepcopy + pseudo_dataset = deepcopy(dataset) + pseudo_dataset.gt_labels = pseudo_gt_labels + pseudo_dataset._update_heads_information(dataset.hierarchical_info["num_multiclass_heads"]) + assert pseudo_dataset.hierarchical_info["empty_multiclass_head_indices"][pseudo_head_idx] == 0 diff --git a/tests/unit/algorithms/classification/adapters/mmcls/test_configurer.py b/tests/unit/algorithms/classification/adapters/mmcls/test_configurer.py index 96e1efbf685..ca7c51813fe 100644 --- a/tests/unit/algorithms/classification/adapters/mmcls/test_configurer.py +++ b/tests/unit/algorithms/classification/adapters/mmcls/test_configurer.py @@ -22,6 +22,9 @@ def setup(self) -> None: self.configurer = ClassificationConfigurer() self.model_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "model.py")) self.data_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "data_pipeline.py")) + + self.multilabel_model_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "model_multilabel.py")) + self.hierarchical_model_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "model_hierarchical.py")) @e2e_pytest_unit def test_configure(self, mocker): @@ -118,6 +121,13 @@ def test_configure_model(self): self.configurer.configure_model(self.model_cfg, ir_options) assert self.model_cfg.model_task assert self.model_cfg.model.head.in_channels == 960 + + multilabel_model_cfg = self.multilabel_model_cfg + self.configurer.configure_model(multilabel_model_cfg, ir_options) + + h_label_model_cfg = self.hierarchical_model_cfg + self.configurer.configure_model(h_label_model_cfg, ir_options) + @e2e_pytest_unit def test_configure_model_not_classification_task(self): From cc16d53b8415fe52c342ac1a5562ac9b3f63928f Mon Sep 17 00:00:00 2001 From: sungmanc Date: Wed, 19 Jul 2023 20:28:49 +0900 Subject: [PATCH 3/7] Make black happy --- .../adapters/mmcls/data/test_datasets.py | 7 ++++--- .../classification/adapters/mmcls/test_configurer.py | 11 ++++++----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/unit/algorithms/classification/adapters/mmcls/data/test_datasets.py b/tests/unit/algorithms/classification/adapters/mmcls/data/test_datasets.py index 5712d79753f..9c002ea2801 100644 --- a/tests/unit/algorithms/classification/adapters/mmcls/data/test_datasets.py +++ b/tests/unit/algorithms/classification/adapters/mmcls/data/test_datasets.py @@ -164,9 +164,10 @@ def test_hierarchical_with_empty_heads(self): pseudo_gt_label[pseudo_head_idx] = -1 pseudo_gt_labels.append(pseudo_gt_label) pseudo_gt_labels = np.array(pseudo_gt_labels) - + from copy import deepcopy - pseudo_dataset = deepcopy(dataset) - pseudo_dataset.gt_labels = pseudo_gt_labels + + pseudo_dataset = deepcopy(dataset) + pseudo_dataset.gt_labels = pseudo_gt_labels pseudo_dataset._update_heads_information(dataset.hierarchical_info["num_multiclass_heads"]) assert pseudo_dataset.hierarchical_info["empty_multiclass_head_indices"][pseudo_head_idx] == 0 diff --git a/tests/unit/algorithms/classification/adapters/mmcls/test_configurer.py b/tests/unit/algorithms/classification/adapters/mmcls/test_configurer.py index ca7c51813fe..ae058a4d56d 100644 --- a/tests/unit/algorithms/classification/adapters/mmcls/test_configurer.py +++ b/tests/unit/algorithms/classification/adapters/mmcls/test_configurer.py @@ -22,9 +22,11 @@ def setup(self) -> None: self.configurer = ClassificationConfigurer() self.model_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "model.py")) self.data_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "data_pipeline.py")) - + self.multilabel_model_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "model_multilabel.py")) - self.hierarchical_model_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "model_hierarchical.py")) + self.hierarchical_model_cfg = MPAConfig.fromfile( + os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "model_hierarchical.py") + ) @e2e_pytest_unit def test_configure(self, mocker): @@ -121,13 +123,12 @@ def test_configure_model(self): self.configurer.configure_model(self.model_cfg, ir_options) assert self.model_cfg.model_task assert self.model_cfg.model.head.in_channels == 960 - + multilabel_model_cfg = self.multilabel_model_cfg self.configurer.configure_model(multilabel_model_cfg, ir_options) - + h_label_model_cfg = self.hierarchical_model_cfg self.configurer.configure_model(h_label_model_cfg, ir_options) - @e2e_pytest_unit def test_configure_model_not_classification_task(self): From a84204215ff639d9afc4f60f3a2bbd59c570870d Mon Sep 17 00:00:00 2001 From: sungmanc Date: Wed, 19 Jul 2023 23:47:30 +0900 Subject: [PATCH 4/7] Fix unittests --- .../mmcls/models/heads/test_custom_hierarchical_cls_head.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/algorithms/classification/adapters/mmcls/models/heads/test_custom_hierarchical_cls_head.py b/tests/unit/algorithms/classification/adapters/mmcls/models/heads/test_custom_hierarchical_cls_head.py index 11f6e100996..0d9f6035e1b 100644 --- a/tests/unit/algorithms/classification/adapters/mmcls/models/heads/test_custom_hierarchical_cls_head.py +++ b/tests/unit/algorithms/classification/adapters/mmcls/models/heads/test_custom_hierarchical_cls_head.py @@ -31,6 +31,7 @@ def setup(self, head_type) -> None: "num_multilabel_classes": 1, "head_idx_to_logits_range": {"0": (0, 2)}, "num_single_label_classes": 2, + "empty_multiclass_head_indices": [] } self.loss = dict(type="CrossEntropyLoss", use_sigmoid=False, reduction="mean", loss_weight=1.0) self.multilabel_loss = dict(type=AsymmetricLossWithIgnore.__name__, reduction="sum") From 7c1c83039d6f33445749a688f0d8b7cd3e394850 Mon Sep 17 00:00:00 2001 From: sungmanc Date: Wed, 19 Jul 2023 23:50:58 +0900 Subject: [PATCH 5/7] Make black happy --- .../mmcls/models/heads/test_custom_hierarchical_cls_head.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/algorithms/classification/adapters/mmcls/models/heads/test_custom_hierarchical_cls_head.py b/tests/unit/algorithms/classification/adapters/mmcls/models/heads/test_custom_hierarchical_cls_head.py index 0d9f6035e1b..a5dd0d3a058 100644 --- a/tests/unit/algorithms/classification/adapters/mmcls/models/heads/test_custom_hierarchical_cls_head.py +++ b/tests/unit/algorithms/classification/adapters/mmcls/models/heads/test_custom_hierarchical_cls_head.py @@ -31,7 +31,7 @@ def setup(self, head_type) -> None: "num_multilabel_classes": 1, "head_idx_to_logits_range": {"0": (0, 2)}, "num_single_label_classes": 2, - "empty_multiclass_head_indices": [] + "empty_multiclass_head_indices": [], } self.loss = dict(type="CrossEntropyLoss", use_sigmoid=False, reduction="mean", loss_weight=1.0) self.multilabel_loss = dict(type=AsymmetricLossWithIgnore.__name__, reduction="sum") From 64fc3436876c9e5581d2e02f11b264aba0a9ba68 Mon Sep 17 00:00:00 2001 From: sungmanc Date: Thu, 20 Jul 2023 10:16:11 +0900 Subject: [PATCH 6/7] Fix update heades information func --- .../adapters/mmcls/datasets/otx_datasets.py | 9 +++------ .../classification/adapters/mmcls/data/test_datasets.py | 2 +- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/otx/algorithms/classification/adapters/mmcls/datasets/otx_datasets.py b/src/otx/algorithms/classification/adapters/mmcls/datasets/otx_datasets.py index 997d908df6d..7522be7ea33 100644 --- a/src/otx/algorithms/classification/adapters/mmcls/datasets/otx_datasets.py +++ b/src/otx/algorithms/classification/adapters/mmcls/datasets/otx_datasets.py @@ -327,17 +327,14 @@ def load_annotations(self): self.gt_labels.append(class_indices) self.gt_labels = np.array(self.gt_labels) - self._update_heads_information(num_cls_heads) + self._update_heads_information() - def _update_heads_information(self, num_cls_heads: int): + def _update_heads_information(self): """Update heads information to find the empty heads. If there are no annotations at a specific head, this should be filtered out to calculate loss correctly. - - Args: - num_cls_heads (int): the number of multi-class heads. - """ + num_cls_heads = self.hierarchical_info["num_multiclass_heads"] for head_idx in range(num_cls_heads): labels_in_head = self.gt_labels[:, head_idx] # type: ignore[call-overload] if max(labels_in_head) < 0: diff --git a/tests/unit/algorithms/classification/adapters/mmcls/data/test_datasets.py b/tests/unit/algorithms/classification/adapters/mmcls/data/test_datasets.py index 9c002ea2801..b4719680125 100644 --- a/tests/unit/algorithms/classification/adapters/mmcls/data/test_datasets.py +++ b/tests/unit/algorithms/classification/adapters/mmcls/data/test_datasets.py @@ -169,5 +169,5 @@ def test_hierarchical_with_empty_heads(self): pseudo_dataset = deepcopy(dataset) pseudo_dataset.gt_labels = pseudo_gt_labels - pseudo_dataset._update_heads_information(dataset.hierarchical_info["num_multiclass_heads"]) + pseudo_dataset._update_heads_information() assert pseudo_dataset.hierarchical_info["empty_multiclass_head_indices"][pseudo_head_idx] == 0 From 2d785a2da162ff2e31699e7044a06b21c1c9f404 Mon Sep 17 00:00:00 2001 From: sungmanc Date: Thu, 20 Jul 2023 15:05:08 +0900 Subject: [PATCH 7/7] Update the logic: consider the loss per batch --- .../custom_hierarchical_linear_cls_head.py | 12 +++++---- ...custom_hierarchical_non_linear_cls_head.py | 11 +++++--- .../test_custom_hierarchical_cls_head.py | 26 +++++++++++++------ 3 files changed, 32 insertions(+), 17 deletions(-) diff --git a/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_linear_cls_head.py b/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_linear_cls_head.py index c2a7deaacc6..6776756bb61 100644 --- a/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_linear_cls_head.py +++ b/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_linear_cls_head.py @@ -87,7 +87,7 @@ def forward_train(self, cls_score, gt_label, **kwargs): cls_score = self.fc(cls_score) losses = dict(loss=0.0) - num_empty_heads = len(self.hierarchical_info["empty_multiclass_head_indices"]) + num_effective_heads_in_batch = 0 for i in range(self.hierarchical_info["num_multiclass_heads"]): if i not in self.hierarchical_info["empty_multiclass_head_indices"]: head_gt = gt_label[:, i] @@ -99,12 +99,14 @@ def forward_train(self, cls_score, gt_label, **kwargs): ] valid_mask = head_gt >= 0 head_gt = head_gt[valid_mask].long() - head_logits = head_logits[valid_mask, :] - multiclass_loss = self.loss(head_logits, head_gt) - losses["loss"] += multiclass_loss + if len(head_gt) > 0: + head_logits = head_logits[valid_mask, :] + multiclass_loss = self.loss(head_logits, head_gt) + losses["loss"] += multiclass_loss + num_effective_heads_in_batch += 1 if self.hierarchical_info["num_multiclass_heads"] > 1: - losses["loss"] /= self.hierarchical_info["num_multiclass_heads"] - num_empty_heads + losses["loss"] /= num_effective_heads_in_batch if self.compute_multilabel_loss: head_gt = gt_label[:, self.hierarchical_info["num_multiclass_heads"] :] diff --git a/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_non_linear_cls_head.py b/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_non_linear_cls_head.py index 7fa529c99ba..5397818fbf3 100644 --- a/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_non_linear_cls_head.py +++ b/src/otx/algorithms/classification/adapters/mmcls/models/heads/custom_hierarchical_non_linear_cls_head.py @@ -117,6 +117,7 @@ def forward_train(self, cls_score, gt_label, **kwargs): cls_score = self.classifier(cls_score) losses = dict(loss=0.0) + num_effective_heads_in_batch = 0 for i in range(self.hierarchical_info["num_multiclass_heads"]): if i not in self.hierarchical_info["empty_multiclass_head_indices"]: head_gt = gt_label[:, i] @@ -128,12 +129,14 @@ def forward_train(self, cls_score, gt_label, **kwargs): ] valid_mask = head_gt >= 0 head_gt = head_gt[valid_mask].long() - head_logits = head_logits[valid_mask, :] - multiclass_loss = self.loss(head_logits, head_gt) - losses["loss"] += multiclass_loss + if len(head_gt) > 0: + head_logits = head_logits[valid_mask, :] + multiclass_loss = self.loss(head_logits, head_gt) + losses["loss"] += multiclass_loss + num_effective_heads_in_batch += 1 if self.hierarchical_info["num_multiclass_heads"] > 1: - losses["loss"] /= self.hierarchical_info["num_multiclass_heads"] + losses["loss"] /= num_effective_heads_in_batch if self.compute_multilabel_loss: head_gt = gt_label[:, self.hierarchical_info["num_multiclass_heads"] :] diff --git a/tests/unit/algorithms/classification/adapters/mmcls/models/heads/test_custom_hierarchical_cls_head.py b/tests/unit/algorithms/classification/adapters/mmcls/models/heads/test_custom_hierarchical_cls_head.py index a5dd0d3a058..8f8ec9b6550 100644 --- a/tests/unit/algorithms/classification/adapters/mmcls/models/heads/test_custom_hierarchical_cls_head.py +++ b/tests/unit/algorithms/classification/adapters/mmcls/models/heads/test_custom_hierarchical_cls_head.py @@ -24,13 +24,13 @@ def head_type(self) -> None: @pytest.fixture(autouse=True) def setup(self, head_type) -> None: - self.num_classes = 3 - self.head_dim = 5 + self.num_classes = 6 + self.head_dim = 10 self.cls_heads_info = { - "num_multiclass_heads": 1, - "num_multilabel_classes": 1, - "head_idx_to_logits_range": {"0": (0, 2)}, - "num_single_label_classes": 2, + "num_multiclass_heads": 3, + "num_multilabel_classes": 0, + "head_idx_to_logits_range": {"0": (0, 2), "1": (2, 4), "2": (4, 6)}, + "num_single_label_classes": 6, "empty_multiclass_head_indices": [], } self.loss = dict(type="CrossEntropyLoss", use_sigmoid=False, reduction="mean", loss_weight=1.0) @@ -44,13 +44,23 @@ def setup(self, head_type) -> None: ) self.default_head.init_weights() self.default_input = torch.ones((2, self.head_dim)) - self.default_gt = torch.zeros((2, 2)) + self.default_gt = torch.zeros((2, 3)) @e2e_pytest_unit def test_forward(self) -> None: result = self.default_head.forward_train(self.default_input, self.default_gt) assert "loss" in result - assert result["loss"] >= 0 + assert result["loss"] >= 0 and not torch.isnan(result["loss"]) + + empty_head_gt_full = torch.tensor([[-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]) + result_include_empty_full = self.default_head.forward_train(self.default_input, empty_head_gt_full) + assert "loss" in result_include_empty_full + assert result_include_empty_full["loss"] >= 0 and not torch.isnan(result_include_empty_full["loss"]) + + empty_head_gt_partial = torch.tensor([[0.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]) + result_include_empty_partial = self.default_head.forward_train(self.default_input, empty_head_gt_partial) + assert "loss" in result_include_empty_partial + assert result_include_empty_partial["loss"] >= 0 and not torch.isnan(result_include_empty_partial["loss"]) @e2e_pytest_unit def test_simple_test(self) -> None: