Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ def configure_model(self, cfg, ir_options): # noqa: C901
cfg.model.arch_type = cfg.model.type
cfg.model.type = super_type

# Hierarchical
if cfg.model.get("hierarchical"):
assert cfg.data.train.hierarchical_info == cfg.data.val.hierarchical_info == cfg.data.test.hierarchical_info
cfg.model.head.hierarchical_info = cfg.data.train.hierarchical_info

# OV-plugin
ir_model_path = ir_options.get("ir_model_path")
if ir_model_path:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,7 @@ def load_annotations(self):
if item_labels:
num_cls_heads = self.hierarchical_info["num_multiclass_heads"]

class_indices = [0] * (
self.hierarchical_info["num_multiclass_heads"] + self.hierarchical_info["num_multilabel_classes"]
)
class_indices = [0] * (num_cls_heads + self.hierarchical_info["num_multilabel_classes"])
for j in range(num_cls_heads):
class_indices[j] = -1
for otx_lbl in item_labels:
Expand All @@ -329,6 +327,19 @@ def load_annotations(self):
self.gt_labels.append(class_indices)
self.gt_labels = np.array(self.gt_labels)

self._update_heads_information()

def _update_heads_information(self):
    """Update heads information to find the empty heads.

    A multiclass head is treated as empty when every entry of its column in
    ``self.gt_labels`` is negative. The indices of such heads are recorded in
    ``self.hierarchical_info["empty_multiclass_head_indices"]`` so that they
    can be filtered out and the loss is calculated correctly.

    The method may be called repeatedly (for example after ``gt_labels`` has
    been replaced): an index that is already recorded is not appended a
    second time. When there are no label rows at all, nothing is recorded.
    """
    gt_labels = np.asarray(self.gt_labels)
    # ``np.array([])`` is 1-D, so column indexing below would raise; with no
    # samples there is nothing to inspect.
    if gt_labels.ndim < 2 or gt_labels.shape[0] == 0:
        return

    empty_heads = self.hierarchical_info["empty_multiclass_head_indices"]
    for head_idx in range(self.hierarchical_info["num_multiclass_heads"]):
        if head_idx in empty_heads:
            # Already known to be empty; avoid duplicate entries.
            continue
        # Vectorised equivalent of ``max(column) < 0``.
        if (gt_labels[:, head_idx] < 0).all():
            empty_heads.append(head_idx)

@staticmethod
def mean_top_k_accuracy(scores, labels, k=1):
"""Return mean of top-k accuracy."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,22 +87,26 @@ def forward_train(self, cls_score, gt_label, **kwargs):
cls_score = self.fc(cls_score)

losses = dict(loss=0.0)
num_effective_heads_in_batch = 0
for i in range(self.hierarchical_info["num_multiclass_heads"]):
head_gt = gt_label[:, i]
head_logits = cls_score[
:,
self.hierarchical_info["head_idx_to_logits_range"][str(i)][0] : self.hierarchical_info[
"head_idx_to_logits_range"
][str(i)][1],
]
valid_mask = head_gt >= 0
head_gt = head_gt[valid_mask].long()
head_logits = head_logits[valid_mask, :]
multiclass_loss = self.loss(head_logits, head_gt)
losses["loss"] += multiclass_loss
if i not in self.hierarchical_info["empty_multiclass_head_indices"]:
head_gt = gt_label[:, i]
head_logits = cls_score[
:,
self.hierarchical_info["head_idx_to_logits_range"][str(i)][0] : self.hierarchical_info[
"head_idx_to_logits_range"
][str(i)][1],
]
valid_mask = head_gt >= 0
head_gt = head_gt[valid_mask].long()
if len(head_gt) > 0:
head_logits = head_logits[valid_mask, :]
multiclass_loss = self.loss(head_logits, head_gt)
losses["loss"] += multiclass_loss
num_effective_heads_in_batch += 1

if self.hierarchical_info["num_multiclass_heads"] > 1:
losses["loss"] /= self.hierarchical_info["num_multiclass_heads"]
losses["loss"] /= num_effective_heads_in_batch

if self.compute_multilabel_loss:
head_gt = gt_label[:, self.hierarchical_info["num_multiclass_heads"] :]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,22 +117,26 @@ def forward_train(self, cls_score, gt_label, **kwargs):
cls_score = self.classifier(cls_score)

losses = dict(loss=0.0)
num_effective_heads_in_batch = 0
for i in range(self.hierarchical_info["num_multiclass_heads"]):
head_gt = gt_label[:, i]
head_logits = cls_score[
:,
self.hierarchical_info["head_idx_to_logits_range"][str(i)][0] : self.hierarchical_info[
"head_idx_to_logits_range"
][str(i)][1],
]
valid_mask = head_gt >= 0
head_gt = head_gt[valid_mask].long()
head_logits = head_logits[valid_mask, :]
multiclass_loss = self.loss(head_logits, head_gt)
losses["loss"] += multiclass_loss
if i not in self.hierarchical_info["empty_multiclass_head_indices"]:
head_gt = gt_label[:, i]
head_logits = cls_score[
:,
self.hierarchical_info["head_idx_to_logits_range"][str(i)][0] : self.hierarchical_info[
"head_idx_to_logits_range"
][str(i)][1],
]
valid_mask = head_gt >= 0
head_gt = head_gt[valid_mask].long()
if len(head_gt) > 0:
head_logits = head_logits[valid_mask, :]
multiclass_loss = self.loss(head_logits, head_gt)
losses["loss"] += multiclass_loss
num_effective_heads_in_batch += 1

if self.hierarchical_info["num_multiclass_heads"] > 1:
losses["loss"] /= self.hierarchical_info["num_multiclass_heads"]
losses["loss"] /= num_effective_heads_in_batch

if self.compute_multilabel_loss:
head_gt = gt_label[:, self.hierarchical_info["num_multiclass_heads"] :]
Expand Down
1 change: 1 addition & 0 deletions src/otx/algorithms/classification/utils/cls_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def get_multihead_class_info(label_schema: LabelSchemaEntity): # pylint: disabl
"class_to_group_idx": class_to_idx,
"all_groups": exclusive_groups + single_label_groups,
"label_to_idx": label_to_idx,
"empty_multiclass_head_indices": [],
}
return mixed_cls_heads_info

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,32 @@ def test_metric_hierarchical_adapter(self):
dataset = OTXHierarchicalClsDataset(
otx_dataset=self.dataset, labels=self.dataset.get_labels(), hierarchical_info=class_info
)

results = np.zeros((len(dataset), dataset.num_classes))
metrics = dataset.evaluate(results)

assert len(metrics) > 0
assert metrics["accuracy"] > 0

@e2e_pytest_unit
def test_hierarchical_with_empty_heads(self):
    """Check that a head whose labels are all -1 is reported as empty.

    The labels of one multiclass head are overwritten with -1 on a copy of
    the dataset, ``_update_heads_information`` is re-run, and the head index
    must then be listed in ``empty_multiclass_head_indices``.
    """
    from copy import deepcopy

    self.task_environment, self.dataset = init_environment(
        self.hyper_parameters, self.model_template, False, True, self.dataset_len
    )
    class_info = get_multihead_class_info(self.task_environment.label_schema)
    dataset = OTXHierarchicalClsDataset(
        otx_dataset=self.dataset, labels=self.dataset.get_labels(), hierarchical_info=class_info
    )

    pseudo_head_idx = 0
    # Copy the labels first: writing through the rows of ``dataset.gt_labels``
    # would silently modify the dataset built above.
    pseudo_gt_labels = np.array(dataset.gt_labels)
    pseudo_gt_labels[:, pseudo_head_idx] = -1

    pseudo_dataset = deepcopy(dataset)
    pseudo_dataset.gt_labels = pseudo_gt_labels
    pseudo_dataset._update_heads_information()

    # Membership, not positional lookup: the list holds head indices, so
    # indexing it by ``pseudo_head_idx`` is only meaningful by coincidence.
    assert pseudo_head_idx in pseudo_dataset.hierarchical_info["empty_multiclass_head_indices"]
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,14 @@ def head_type(self) -> None:

@pytest.fixture(autouse=True)
def setup(self, head_type) -> None:
self.num_classes = 3
self.head_dim = 5
self.num_classes = 6
self.head_dim = 10
self.cls_heads_info = {
"num_multiclass_heads": 1,
"num_multilabel_classes": 1,
"head_idx_to_logits_range": {"0": (0, 2)},
"num_single_label_classes": 2,
"num_multiclass_heads": 3,
"num_multilabel_classes": 0,
"head_idx_to_logits_range": {"0": (0, 2), "1": (2, 4), "2": (4, 6)},
"num_single_label_classes": 6,
"empty_multiclass_head_indices": [],
}
self.loss = dict(type="CrossEntropyLoss", use_sigmoid=False, reduction="mean", loss_weight=1.0)
self.multilabel_loss = dict(type=AsymmetricLossWithIgnore.__name__, reduction="sum")
Expand All @@ -43,13 +44,23 @@ def setup(self, head_type) -> None:
)
self.default_head.init_weights()
self.default_input = torch.ones((2, self.head_dim))
self.default_gt = torch.zeros((2, 2))
self.default_gt = torch.zeros((2, 3))

@e2e_pytest_unit
def test_forward(self) -> None:
    """forward_train must return a finite, non-negative loss for every label layout.

    Covered layouts: all heads labelled, head 0 unlabelled (-1) for the whole
    batch, and head 0 unlabelled for only one sample of the batch.
    """
    gt_cases = {
        "all heads labelled": self.default_gt,
        "head 0 empty in whole batch": torch.tensor([[-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]),
        "head 0 empty in one sample": torch.tensor([[0.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]),
    }
    for case_name, gt_label in gt_cases.items():
        result = self.default_head.forward_train(self.default_input, gt_label)
        assert "loss" in result, case_name
        loss = result["loss"]
        assert not torch.isnan(loss), case_name
        assert loss >= 0, case_name

@e2e_pytest_unit
def test_simple_test(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ def setup(self) -> None:
self.model_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "model.py"))
self.data_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "data_pipeline.py"))

self.multilabel_model_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "model_multilabel.py"))
self.hierarchical_model_cfg = MPAConfig.fromfile(
os.path.join(DEFAULT_CLS_TEMPLATE_DIR, "model_hierarchical.py")
)

@e2e_pytest_unit
def test_configure(self, mocker):
mock_cfg_base = mocker.patch.object(ClassificationConfigurer, "configure_base")
Expand Down Expand Up @@ -119,6 +124,12 @@ def test_configure_model(self):
assert self.model_cfg.model_task
assert self.model_cfg.model.head.in_channels == 960

multilabel_model_cfg = self.multilabel_model_cfg
self.configurer.configure_model(multilabel_model_cfg, ir_options)

h_label_model_cfg = self.hierarchical_model_cfg
self.configurer.configure_model(h_label_model_cfg, ir_options)

@e2e_pytest_unit
def test_configure_model_not_classification_task(self):
ir_options = {"ir_model_path": {"ir_weight_path": "", "ir_weight_init": ""}}
Expand Down