Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## \[Unreleased\]

### New features
- Add PseudoLabeling transform for unlabeled datasets
(<https://github.com/openvinotoolkit/datumaro/pull/1594>)

### Enhancements
- Raise an appropriate error when exporting a datumaro dataset if its subset name contains path separators.
(<https://github.com/openvinotoolkit/datumaro/pull/1615>)

### Bug fixes

## Q3 2024 Release 1.9.0
## \[Q3 2024 Release 1.9.0\]
### New features
- Add a new CLI command: datum format
(<https://github.com/openvinotoolkit/datumaro/pull/1570>)
Expand Down
31 changes: 31 additions & 0 deletions docs/source/docs/command-reference/context_free/transform.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ Basic dataset item manipulations:
- [`remove_annotations`](#remove_annotations) - Removes annotations
- [`remove_attributes`](#remove_attributes) - Removes attributes
- [`astype_annotations`](#astype_annotations) - Convert annotation type
- [`pseudo_labeling`](#pseudo_labeling) - Generate pseudo labels for unlabeled data

Subset manipulations:
- [`random_split`](#random_split) - Splits dataset into subsets
Expand Down Expand Up @@ -838,3 +839,33 @@ correct [-h] [-r REPORT_PATH]
Optional arguments:
- `-h`, `--help` (flag) - Show this help message and exit
- `-r`, `--reports` (str) - A validation report from a 'validate' CLI (default=validation_reports.json)

#### `pseudo_labeling`

Assigns pseudo-labels to items in a dataset based on their similarity to predefined labels. This transform is useful for semi-supervised learning when dealing with missing or uncertain labels. Note that each item's existing annotations are replaced by the assigned pseudo-label.

The process includes:

- Similarity Computation: Uses hashing techniques to compute the similarity between items and predefined labels.
- Pseudo-Label Assignment: Assigns the most similar label as a pseudo-label to each item.

Attributes:

- `extractor` (IDataset) - Provides access to dataset items and their annotations.
- `labels` (Optional[List[str]]) - List of predefined labels for pseudo-labeling. Defaults to all available labels if not provided.
- `explorer` (Optional[Explorer]) - Computes hash keys for items and labels. If not provided, a new Explorer is created.

Usage:
```console
pseudo_labeling [-h] [--labels LABELS]
```

Optional arguments:
- `-h`, `--help` (flag) - Show this help message and exit
- `--labels` (str) - Comma-separated list of label names for pseudo-labeling

Examples:
- Assign pseudo-labels based on predefined labels
```console
datum transform -t pseudo_labeling -- --labels 'label1,label2'
```
5 changes: 5 additions & 0 deletions src/datumaro/plugins/specs.json
Original file line number Diff line number Diff line change
Expand Up @@ -1929,6 +1929,11 @@
"plugin_name": "remove_annotations",
"plugin_type": "Transform"
},
{
"import_path": "datumaro.plugins.transforms.PseudoLabeling",
"plugin_name": "pseudo_labeling",
"plugin_type": "Transform"
},
{
"import_path": "datumaro.plugins.transforms.RemoveAttributes",
"plugin_name": "remove_attributes",
Expand Down
64 changes: 64 additions & 0 deletions src/datumaro/plugins/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from pandas.api.types import CategoricalDtype

import datumaro.util.mask_tools as mask_tools
from datumaro.components.algorithms.hash_key_inference.explorer import Explorer
from datumaro.components.algorithms.hash_key_inference.hashkey_util import calculate_hamming
from datumaro.components.annotation import (
AnnotationType,
Bbox,
Expand All @@ -40,6 +42,7 @@
TabularCategories,
)
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.dataset import Dataset
from datumaro.components.dataset_base import DEFAULT_SUBSET_NAME, DatasetInfo, DatasetItem, IDataset
from datumaro.components.errors import (
AnnotationTypeError,
Expand Down Expand Up @@ -2004,3 +2007,64 @@ def transform_item(self, item):
refined_annotations.append(ann)

return self.wrap_item(item, media=refined_media, annotations=refined_annotations)


class PseudoLabeling(ItemTransform):
    """
    A class used to assign pseudo-labels to items in a dataset based on
    their similarity to predefined labels.|n
    |n
    This class leverages hashing techniques to compute the similarity
    between dataset items and a set of predefined labels.|n
    It assigns the most similar label as a pseudo-label to each item.
    This is particularly useful in semi-supervised
    learning scenarios where some labels are missing or uncertain.|n
    |n
    Note: every transformed item carries exactly one Label annotation.
    Annotations the item already had are replaced, not kept.|n
    |n
    Attributes:|n
    - extractor : IDataset|n
    The dataset extractor that provides access to dataset items and their annotations.|n
    - labels : Optional[List[str]]|n
    A list of label names to be used for pseudo-labeling.
    If not provided (or empty), all available labels in the dataset will be used.
    Every name must exist in the dataset's label categories.|n
    - explorer : Optional[Explorer]|n
    An optional Explorer object used to compute hash keys for items and labels.
    If not provided, a new Explorer will be created.|n
    """

    def __init__(
        self,
        extractor: IDataset,
        labels: Optional[List[str]] = None,
        explorer: Optional[Explorer] = None,
    ):
        """
        Resolve the candidate labels and precompute their hash keys.

        Raises:
            ValueError: if there is no label to choose from.
            KeyError: if a requested label is not defined in the dataset's
                label categories.
        """
        super().__init__(extractor)

        self._categories = self._extractor.categories()
        self._label_indices = self._categories[AnnotationType.label]._indices

        # A missing or empty list means "use every label known to the dataset".
        self._labels = list(labels) if labels else list(self._label_indices.keys())
        if not self._labels:
            raise ValueError(
                "PseudoLabeling needs at least one label, "
                "but the dataset does not define any label categories"
            )

        # Fail early with a clear message instead of a bare KeyError raised
        # later from transform_item(), and only for the items where the
        # unknown label happens to be the closest one.
        unknown = [name for name in self._labels if name not in self._label_indices]
        if unknown:
            raise KeyError(f"Labels {unknown} are not defined in the dataset categories")

        # Category index of each candidate label, aligned with self._labels
        # and with the rows of self._label_hashkeys.
        self._label_ids = [self._label_indices[name] for name in self._labels]

        if explorer is None:
            explorer = Explorer(Dataset.from_iterable(list(self._extractor)))
        self._explorer = explorer

        # One row of unpacked hash bits per candidate label.
        label_hashkeys = [
            np.unpackbits(self._explorer._get_hash_key_from_text_query(label).hash_key, axis=-1)
            for label in self._labels
        ]
        self._label_hashkeys = np.stack(label_hashkeys, axis=0)

    def categories(self):
        """Return the categories of the wrapped extractor, unchanged."""
        return self._categories

    def transform_item(self, item: DatasetItem):
        """
        Return a copy of ``item`` whose only annotation is the pseudo-label.

        The item's hash key is compared with every candidate label's hash key
        and the label with the highest score (smallest distance) is assigned.
        """
        item_bits = np.unpackbits(self._explorer._get_hash_key_from_item_query(item).hash_key)
        distances = calculate_hamming(item_bits, self._label_hashkeys)

        # Scoring and ordering are kept as they were so that the label chosen
        # for tied distances does not change; the epsilon avoids a division
        # by zero when a distance is exactly 0.
        inverse_distances = 1.0 / (distances + 1e-6)
        probs = inverse_distances / np.sum(inverse_distances)
        best = int(np.argsort(probs)[::-1][0])

        pseudo_annotation = [Label(label=self._label_ids[best])]
        return self.wrap_item(item, annotations=pseudo_annotation)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that the original annotations are replaced with the pseudo annotation. Is that intended? If so, how about adding a note about this to the documentation?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, I assumed that the items do not have any annotations, so there are no original annotations to replace. If my assumption is not appropriate, please let me know.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see. Then, we can enhance this feature when the assumption is changed.

56 changes: 56 additions & 0 deletions tests/unit/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import argparse
import logging as log
import os
import os.path as osp
import random
from unittest import TestCase
Expand All @@ -14,6 +15,7 @@

import datumaro.plugins.transforms as transforms
import datumaro.util.mask_tools as mask_tools
from datumaro.components.algorithms.hash_key_inference.explorer import Explorer
from datumaro.components.annotation import (
AnnotationType,
Bbox,
Expand Down Expand Up @@ -1673,3 +1675,57 @@ def test_transform_clean_after_astype_ann(self):
result_item = result.__getitem__(i)
self.assertEqual(expected_item.annotations, result_item.annotations)
self.assertEqual(expected_item.media, result_item.media)


class PseudoLabelingTest(TestCase):
    """Checks that the ``pseudo_labeling`` transform labels unannotated images."""

    def setUp(self):
        # Two unannotated images from the "explore_dataset" test asset, one
        # taken from its "dog" folder and one from its "cat" folder.
        self.data_path = get_test_asset_path("explore_dataset")
        self.categories = ["bird", "cat", "dog", "monkey"]
        self.source = Dataset.from_iterable(
            [
                DatasetItem(
                    id=0,
                    media=Image.from_file(
                        path=os.path.join(self.data_path, "dog", "ILSVRC2012_val_00001698.JPEG")
                    ),
                ),
                DatasetItem(
                    id=1,
                    media=Image.from_file(
                        path=os.path.join(self.data_path, "cat", "ILSVRC2012_val_00004894.JPEG")
                    ),
                ),
            ],
            categories=self.categories,
        )
        # Expected pseudo-label per item, in the order the items were added.
        self.expected_labels = ["dog", "cat"]

    def _assert_pseudo_labels(self, result):
        """Assert that ``result`` has one item per expected label, labeled accordingly."""
        label_indices = self.source.categories()[AnnotationType.label]._indices
        items = list(result)

        # zip() stops at the shorter input, so without this check a result
        # with missing items would pass without comparing anything.
        self.assertEqual(len(items), len(self.expected_labels))
        for item, expected in zip(items, self.expected_labels):
            self.assertEqual(item.annotations[0].label, label_indices[expected])

    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
    def test_transform_pseudolabeling_with_labels(self):
        dataset = self.source
        labels = self.categories
        explorer = Explorer(dataset)
        result = dataset.transform("pseudo_labeling", labels=labels, explorer=explorer)

        self._assert_pseudo_labels(result)

    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
    def test_transform_pseudolabeling_without_labels(self):
        dataset = self.source
        explorer = Explorer(dataset)
        result = dataset.transform("pseudo_labeling", explorer=explorer)

        self._assert_pseudo_labels(result)

    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
    def test_transform_pseudolabeling_without_explorer(self):
        dataset = self.source
        labels = self.categories
        result = dataset.transform("pseudo_labeling", labels=labels)

        self._assert_pseudo_labels(result)