diff --git a/CHANGELOG.md b/CHANGELOG.md index b6a96cfb9a..cc9b6bbdbf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Refactor Datumaro format code and test code () +### Fixed +- Fix image filenames and anomaly mask appearance in MVTec exporter + () + ## 24/02/2023 - Release v1.0.0 ### Added - Add Data Explorer diff --git a/datumaro/plugins/data_formats/mvtec/exporter.py b/datumaro/plugins/data_formats/mvtec/exporter.py index ed704a76b3..269ee1017a 100644 --- a/datumaro/plugins/data_formats/mvtec/exporter.py +++ b/datumaro/plugins/data_formats/mvtec/exporter.py @@ -77,7 +77,7 @@ def apply(self): labels.append(self.get_label(ann.label)) if self._save_media: - self._save_image(item, subdir=osp.join(subset_name, labels[0])) + self._save_image(item, subdir=subset_name) bboxes = [a for a in item.annotations if a.type == AnnotationType.bbox] if bboxes and MvtecTask.detection in self._tasks: @@ -92,7 +92,7 @@ def apply(self): if not osp.exists(osp.join(self._save_dir, osp.dirname(mask_path))): os.mkdir(osp.join(self._save_dir, osp.dirname(mask_path))) - cv2.imwrite(osp.join(self._save_dir, mask_path), mask) + cv2.imwrite(osp.join(self._save_dir, mask_path), mask * 255) masks = [a for a in item.annotations if a.type == AnnotationType.mask] if masks and MvtecTask.segmentation in self._tasks: @@ -103,7 +103,7 @@ def apply(self): if not osp.exists(osp.join(self._save_dir, osp.dirname(mask_path))): os.mkdir(osp.join(self._save_dir, osp.dirname(mask_path))) cv2.imwrite( - osp.join(self._save_dir, mask_path), masks[0].image.astype(np.uint8) + osp.join(self._save_dir, mask_path), masks[0].image.astype(np.uint8) * 255 ) def get_label(self, label_id): diff --git a/tests/unit/test_mvtec_format.py b/tests/unit/test_mvtec_format.py index 2129ac159f..7944ba65d1 100644 --- a/tests/unit/test_mvtec_format.py +++ b/tests/unit/test_mvtec_format.py @@ -1,5 +1,6 @@ from unittest import TestCase +import cv2 import numpy as np from datumaro.components.annotation import AnnotationType, Bbox, Label, LabelCategories, Mask @@ -226,3 +227,47 @@ def test_can_detect_mvtec(self): with self.subTest(path=path, task=subtask): detected_formats = env.detect_dataset(path) self.assertIn(subtask.NAME, detected_formats) + + +class MVTecExporterTest(TestCase): + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_segmentation_masks_saved_as_binary_image(self): + source_dataset = Dataset.from_iterable( + [ + DatasetItem( + id="label_1/000", + media=Image(data=np.ones((8, 8, 3))), + annotations=[Mask(image=np.ones((8, 8), dtype=np.uint8), label=1)], + ), + ], + categories={ + AnnotationType.label: LabelCategories.from_iterable( + "label_" + str(label) for label in range(2) + ), + }, + ) + + with TestDir() as test_dir: + MvtecExporter.convert(source_dataset, test_dir, save_media=True) + assert cv2.imread(test_dir + "/ground_truth/label_1/000_mask.png").max() == 255 + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_detection_masks_saved_as_binary_image(self): + source_dataset = Dataset.from_iterable( + [ + DatasetItem( + id="label_1/000", + media=Image(data=np.ones((8, 8, 3))), + annotations=[Bbox(0, 0, 8, 8, label=1)], + ), + ], + categories={ + AnnotationType.label: LabelCategories.from_iterable( + "label_" + str(label) for label in range(2) + ), + }, + ) + + with TestDir() as test_dir: + MvtecExporter.convert(source_dataset, test_dir, save_media=True) + assert cv2.imread(test_dir + "/ground_truth/label_1/000_mask.png").max() == 255