Skip to content

Commit 072c8a8

Browse files
authored
Set label name with parents to avoid duplicates for AstypeAnnotations (#1492)
- Exclude `nan` from labels - Prefix label names with their parent name to avoid duplicate names in `AstypeAnnotations` - Add unit test for tabular datasets that include missing values - Add unit test for `AstypeAnnotations` when the label value is nan
1 parent 62ec011 commit 072c8a8

6 files changed

Lines changed: 146 additions & 9 deletions

File tree

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2323
(<https://github.com/openvinotoolkit/datumaro/pull/1471>)
2424
- Add ExtractedMask and update the importers that can use it to do so
2525
(<https://github.com/openvinotoolkit/datumaro/pull/1480>)
26+
- Prefix label names with their parent name to avoid duplicates in AstypeAnnotations
27+
(<https://github.com/openvinotoolkit/datumaro/pull/1492>)
2628

2729
### Bug fixes
2830
- Split the video directory into subsets to avoid overwriting

src/datumaro/plugins/data_formats/tabular.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,13 @@ def _parse(
115115
target_dtype = table.dtype(target_)
116116
if target_dtype in [int, float, pd.api.types.CategoricalDtype()]:
117117
# 'int' can be categorical, but we don't know this unless user gives information.
118-
labels = set(table.features(target_, unique=True))
118+
labels = set(
119+
[
120+
feature
121+
for feature in table.features(target_, unique=True)
122+
if not pd.isna(feature)
123+
]
124+
)
119125
if category is None:
120126
categories.add(target_, target_dtype, labels)
121127
else: # update labels if they are different.

src/datumaro/plugins/transforms.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1495,6 +1495,8 @@ def __init__(
14951495
):
14961496
super().__init__(extractor)
14971497

1498+
self._sep_token = ":"
1499+
14981500
if extractor.media_type() and not issubclass(extractor.media_type(), TableRow):
14991501
raise MediaTypeError(
15001502
"Media type is not table. This transform only support tabular media"
@@ -1523,6 +1525,7 @@ def __init__(
15231525
dst_parent = src_cat.name
15241526
dst_labels = sorted(src_cat.labels)
15251527
for dst_label in dst_labels:
1528+
dst_label = dst_parent + self._sep_token + str(dst_label)
15261529
dst_index = dst_label_cat.add(dst_label, parent=dst_parent, attributes={})
15271530
self._id_mapping[dst_label] = dst_index
15281531
dst_label_cat.add_label_group(src_cat.name, src_cat.labels, group_type=0)
@@ -1533,12 +1536,12 @@ def categories(self):
15331536
return self._categories
15341537

15351538
def transform_item(self, item: DatasetItem):
1536-
annotations = []
1537-
for name, value in item.annotations[0].values.items():
1538-
dtype = self._tabular_cat_types.get(name, None)
1539-
if dtype == CategoricalDtype():
1540-
annotations.append(Label(label=self._id_mapping[value]))
1541-
else:
1542-
annotations.append(Caption(value))
1539+
annotations = [
1540+
Label(label=self._id_mapping[name + self._sep_token + str(value)])
1541+
if self._tabular_cat_types.get(name) == CategoricalDtype() and value is not None
1542+
else Caption(value)
1543+
for name, value in item.annotations[0].values.items()
1544+
if value is not None
1545+
]
15431546

15441547
return self.wrap_item(item, annotations=annotations)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
date,day,period,nswprice,nswdemand,vicprice,vicdemand,transfer,class
2+
0.425556,5,0.340426,0.076108,0.392889,0.003467,0.422915,0.414912,UP
3+
0.425512,4,0.617021,0.060376,0.483041,0.003467,0.422915,0.414912,DOWN
4+
0.013982,4,0.042553,0.061967,0.521125,0.003467,0.422915,0.414912,DOWN
5+
0.907349,3,0.06383,0.080581,0.331003,0.00538,0.47566,0.441228,DOWN
6+
0.889341,0,0.361702,0.027141,0.379649,0.001624,0.248317,0.69386,DOWN
7+
0.433565,3,0.787234,0.082803,0.447337,0.003467,0.422915,0.414912,UP
8+
0.894474,4,0.787234,0.088087,0.840672,0.006012,0.752978,0.455702,UP
9+
0.866997,5,0.446809,0.037739,0.506992,0.002495,0.339202,0.664474,UP
10+
0.460909,3,0.319149,0.054672,0.585689,0.003741,0.448731,0.389912,DOWN
11+
0.031857,4,0.255319,0.055242,0.115739,0.003467,0.422915,0.414912,DOWN
12+
0.876023,5,1.0,0.028822,0.369087,0.001477,0.336872,0.769298,UP
13+
0.030707,6,0.042553,0.047526,0.132104,0.003467,0.422915,0.414912,DOWN
14+
0.500111,6,0.914894,0.06617,0.300952,0.00446,0.287416,0.420175,UP
15+
0.890093,3,0.744681,0.338747,0.960875,0.023332,0.857328,0.325,UP
16+
0.898544,6,0.531915,0.090068,0.476941,0.005544,0.35448,0.716228,UP
17+
0.434406,1,0.340426,0.051039,0.518596,0.003467,0.422915,0.414912,UP
18+
0.881023,0,0.787234,0.029302,0.409104,0.001847,0.418436,0.746053,UP
19+
0.872174,4,0.148936,0.019125,0.142368,0.000841,0.250388,0.875,DOWN
20+
0.469094,0,0.680851,0.029152,0.267034,0.001917,0.23796,0.587281,DOWN
21+
0.871776,2,0.595745,0.0269,0.530348,0.001687,0.682548,0.630702,UP
22+
0.456086,5,0.93617,0.053591,0.57468,0.003671,0.387364,0.565789,
23+
0.486837,3,0.446809,0.07959,0.499851,0.005501,0.493009,0.296491,
24+
0.009513,0,0.170213,0.041341,0.191461,0.003467,0.422915,0.414912,
25+
0.429052,1,0.659574,0.100546,0.512794,0.003467,0.422915,0.414912,

tests/unit/test_tabular_format.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ def fxt_electricity(fxt_tabular_root):
2828
yield Dataset.import_from(path, "tabular")
2929

3030

31+
@pytest.fixture()
32+
def fxt_electricity_missing(fxt_tabular_root):
33+
path = osp.join(fxt_tabular_root, "electricity_missing.csv")
34+
yield Dataset.import_from(path, "tabular", target={"input": "nswprice", "output": "class"})
35+
36+
3137
@pytest.fixture()
3238
def fxt_buddy_target():
3339
yield {"input": "length(m)", "output": ["breed_category", "pet_category"]}
@@ -178,3 +184,23 @@ def test_target_dtype(self, fxt_tabular_root, target, expected_included_labels)
178184
)
179185
def test_string_to_dict(self, input_string, expected_result):
180186
assert string_to_dict(input_string) == expected_result
187+
188+
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
189+
def test_can_import_tabular_file_with_missing_value(self, fxt_electricity_missing) -> None:
190+
import math
191+
192+
dataset: Type[Dataset] = fxt_electricity_missing
193+
expected_categories_keys = [("class", CategoricalDtype())]
194+
expected_category_labels = {"UP", "DOWN"}
195+
196+
result_categories = dataset.categories()[AnnotationType.tabular].items[0]
197+
assert [(result_categories.name, result_categories.dtype)] == expected_categories_keys
198+
assert len(dataset) == 24
199+
assert result_categories.labels == expected_category_labels
200+
201+
num_nan_annotations = sum(
202+
math.isnan(item.annotations[0].values["class"])
203+
for item in dataset
204+
if isinstance(item.annotations[0].values["class"], float)
205+
)
206+
assert num_nan_annotations == 4

tests/unit/test_transforms.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1271,6 +1271,37 @@ def setUp(self):
12711271
categories={},
12721272
media_type=TableRow,
12731273
)
1274+
self.table_label_nan = Table.from_list(
1275+
[{"class": "DOWN"}, {"class": "UP"}, {"class": None}]
1276+
)
1277+
self.dataset_label_nan = Dataset.from_iterable(
1278+
[
1279+
DatasetItem(
1280+
id="0",
1281+
subset="train",
1282+
media=TableRow(table=self.table_label_nan, index=0),
1283+
annotations=[Tabular(values={"class": "DOWN"})],
1284+
),
1285+
DatasetItem(
1286+
id="1",
1287+
subset="train",
1288+
media=TableRow(table=self.table_label_nan, index=1),
1289+
annotations=[Tabular(values={"class": "UP"})],
1290+
),
1291+
DatasetItem(
1292+
id="2",
1293+
subset="train",
1294+
media=TableRow(table=self.table_label_nan, index=2),
1295+
annotations=[Tabular(values={"class": None})],
1296+
),
1297+
],
1298+
categories={
1299+
AnnotationType.tabular: TabularCategories.from_iterable(
1300+
[("class", CategoricalDtype(), {"DOWN", "UP"})]
1301+
)
1302+
},
1303+
media_type=TableRow,
1304+
)
12741305

12751306
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
12761307
def test_split_arg_valid(self):
@@ -1325,7 +1356,7 @@ def test_transform_annotation_type_label(self):
13251356
],
13261357
categories={
13271358
AnnotationType.label: LabelCategories.from_iterable(
1328-
[("DOWN", "class"), ("UP", "class")]
1359+
[("class:DOWN", "class"), ("class:UP", "class")]
13291360
)
13301361
},
13311362
media_type=TableRow,
@@ -1369,3 +1400,47 @@ def test_transform_annotation_type_caption(self):
13691400
result = transforms.AstypeAnnotations(dataset)
13701401

13711402
compare_datasets(self, expected, result)
1403+
1404+
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
1405+
def test_transform_annotation_type_label_with_nan(self):
1406+
table = self.table_label_nan
1407+
expected = Dataset.from_iterable(
1408+
[
1409+
DatasetItem(
1410+
id="0",
1411+
subset="train",
1412+
media=TableRow(table=table, index=0),
1413+
annotations=[Label(label=0)],
1414+
),
1415+
DatasetItem(
1416+
id="1",
1417+
subset="train",
1418+
media=TableRow(table=table, index=1),
1419+
annotations=[Label(label=1)],
1420+
),
1421+
DatasetItem(
1422+
id="2",
1423+
subset="train",
1424+
media=TableRow(table=table, index=2),
1425+
annotations=[],
1426+
),
1427+
],
1428+
categories={
1429+
AnnotationType.label: LabelCategories.from_iterable(
1430+
[("class:DOWN", "class"), ("class:UP", "class")]
1431+
)
1432+
},
1433+
media_type=TableRow,
1434+
)
1435+
1436+
dataset = self.dataset_label_nan
1437+
result = transforms.AstypeAnnotations(dataset)
1438+
1439+
categories = result._categories.get(AnnotationType.label, None)
1440+
assert categories
1441+
1442+
# Check label_groups of categories
1443+
assert categories.label_groups[0].name == "class"
1444+
assert sorted(categories.label_groups[0].labels) == ["DOWN", "UP"]
1445+
1446+
compare_datasets(self, expected, result)

0 commit comments

Comments
 (0)