diff --git a/CHANGELOG.md b/CHANGELOG.md index c4fa1a132d..78b5cf36f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,9 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added -- Task-specific Splitter () - `WiderFace` dataset format () - Function to transform annotations to labels () +- Task-specific Splitter (, ) - `VGGFace2` dataset format () ### Changed diff --git a/README.md b/README.md index f8d32d5620..9fd34c3e3e 100644 --- a/README.md +++ b/README.md @@ -147,6 +147,14 @@ CVAT annotations ---> Publication, statistics etc. - polygons to instance masks and vise-versa - apply a custom colormap for mask annotations - rename or remove dataset labels + - Splitting a dataset into multiple subsets like `train`, `val`, and `test`: + - random split + - task-specific splits based on annotations, + which keep initial label and attribute distributions + - for classification task, based on labels + - for detection task, based on bboxes + - for re-identification task, based on labels, + avoiding having same IDs in training and test splits - Dataset quality checking - Simple checking for errors - Comparison with model infernece diff --git a/datumaro/plugins/splitter.py b/datumaro/plugins/splitter.py index 704e8c0966..02a2675124 100644 --- a/datumaro/plugins/splitter.py +++ b/datumaro/plugins/splitter.py @@ -7,14 +7,38 @@ from datumaro.components.extractor import (Transform, AnnotationType, DEFAULT_SUBSET_NAME) +from datumaro.components.cli_plugin import CliPlugin NEAR_ZERO = 1e-7 -class _TaskSpecificSplit(Transform): +class _TaskSpecificSplit(Transform, CliPlugin): + _default_split = [('train', 0.5), ('val', 0.2), ('test', 0.3)] + + @classmethod + def build_cmdline_parser(cls, **kwargs): + parser = super().build_cmdline_parser(**kwargs) + parser.add_argument('-s', '--subset', action='append', + type=cls._split_arg, dest='splits', + help="Subsets in the form: ':' " + "(repeatable, default: %s)" % dict(cls._default_split)) + parser.add_argument('--seed', type=int, help="Random seed") + return parser + + @staticmethod + def _split_arg(s): + parts = s.split(':') + if len(parts) != 2: + import argparse + raise argparse.ArgumentTypeError() + return (parts[0], float(parts[1])) + def __init__(self, dataset, splits, seed): super().__init__(dataset) + if splits is None: + splits = self._default_split + snames, sratio = self._validate_splits(splits) self._snames = snames @@ -112,8 +136,10 @@ def _group_by_attr(items): return by_attributes def _split_by_attr(self, datasets, snames, ratio, out_splits, - dataset_key="label"): + dataset_key=None): required = self._get_required(ratio) + if dataset_key is None: + dataset_key = "label" for key, items in datasets.items(): np.random.shuffle(items) by_attributes = self._group_by_attr(items) @@ -153,13 +179,18 @@ def __iter__(self): class ClassificationSplit(_TaskSpecificSplit): """ Splits dataset into train/val/test set in class-wise manner. |n + Splits dataset images in the specified ratio, keeping the initial class + distribution.|n |n Notes:|n - - Single label is expected for each DatasetItem.|n - - If there are not enough images in some class or attributes group, + - Each image is expected to have only one Label|n + - If Labels also have attributes, also splits by attribute values.|n + - If there is not enough images in some class or attributes group, the split ratio can't be guaranteed.|n + |n + Example:|n + |s|s%(prog)s --subset train:.5 --subset val:.2 --subset test:.3 """ - def __init__(self, dataset, splits, seed=None): """ Parameters @@ -195,24 +226,51 @@ def _split_dataset(self): self._set_parts(by_splits) -class MatchingReIDSplit(_TaskSpecificSplit): +class ReidentificationSplit(_TaskSpecificSplit): """ - Splits dataset for matching, especially re-id task.|n - First, splits dataset into 'train+val' and 'test' sets by person id.|n - Note that this splitting is not by DatasetItem. |n - Then, tags 'test' into 'gallery'/'query' in class-wise random manner.|n - Then, splits 'train+val' into 'train'/'val' sets in the same way.|n - Therefore, the final subsets would be 'train', 'val', 'test'. |n - And 'gallery', 'query' are tagged using anntoation group.|n - You can get the 'gallery' and 'query' sets using 'get_subset_by_group'.|n + Splits a dataset for re-identification task.|n + Produces a split with a specified ratio of images, avoiding having same + labels in different subsets.|n + |n + In this task, the test set should consist of images of unseen + people or objects during the training phase. |n + This function splits a dataset in the following way:|n + 1. Splits the dataset into 'train + val' and 'test' sets|n + |s|sbased on person or object ID.|n + 2. Splits 'test' set into 'test-gallery' and 'test-query' sets|n + |s|sin class-wise manner.|n + 3. Splits the 'train + val' set into 'train' and 'val' sets|n + |s|sin the same way.|n + The final subsets would be + 'train', 'val', 'test-gallery' and 'test-query'. |n + |n Notes:|n - - Single label is expected for each DatasetItem.|n - - Each label is expected to have attribute representing the person id. |n + - Each image is expected to have a single Label|n + - Object ID can be described by Label, or by attribute (--attr parameter)|n + - The splits of the test set are controlled by '--query' parameter. |n + |s|sGallery ratio would be 1.0 - query.|n + |n + Example: split a dataset in the specified ratio, split the test set|n + |s|s|s|sinto gallery and query in 1:1 ratio|n + |s|s%(prog)s --subset train:.5 --subset val:.2 --subset test:.3 --query .5|n + Example: use 'person_id' attribute for splitting|n + |s|s%(prog)s --attr person_id """ - _group_map = dict() + _default_query_ratio = 0.5 - def __init__(self, dataset, splits, test_splits, pid_name="PID", seed=None): + @classmethod + def build_cmdline_parser(cls, **kwargs): + parser = super().build_cmdline_parser(**kwargs) + parser.add_argument('--query', type=float, + help="Query ratio in the test set (default: %.3f)" + % cls._default_query_ratio) + parser.add_argument('--attr', type=str, dest='attr_for_id', + help="Attribute name representing the ID (default: use label)") + return parser + + def __init__(self, dataset, splits, query=None, + attr_for_id=None, seed=None): """ Parameters ---------- @@ -221,51 +279,61 @@ def __init__(self, dataset, splits, test_splits, pid_name="PID", seed=None): A list of (subset(str), ratio(float)) Subset is expected to be one of ["train", "val", "test"]. The sum of ratios is expected to be 1. - test_splits : list - A list of (subset(str), ratio(float)) - Subset is expected to be one of ["gallery", "query"]. - The sum of ratios is expected to be 1. - pid_name: str - attribute name representing the person id. (default: PID) + query : float + The ratio of 'test-query' set. + The ratio of 'test-gallery' set would be 1.0 - query. + attr_for_id: str + attribute name representing the person/object id. + if this is not specified, label would be used. seed : int, optional """ super().__init__(dataset, splits, seed) + if query is None: + query = self._default_query_ratio + + assert 0.0 <= query and query <= 1.0, \ + "Query ratio is expected to be in the range " \ + "[0, 1], but got %f" % query + test_splits = [('test-query', query), ('test-gallery', 1.0 - query)] + + # reset output subset names + self._subsets = {"train", "val", "test-gallery", "test-query"} self._test_splits = test_splits - self._pid_name = pid_name + self._attr_for_id = attr_for_id def _split_dataset(self): np.random.seed(self._seed) id_snames, id_ratio = self._snames, self._sratio - pid_name = self._pid_name + attr_for_id = self._attr_for_id dataset = self._extractor - groups = set() - - # group by PID(pid_name) - by_pid = dict() + # group by ID(attr_for_id) + by_id = dict() annotations = self._get_uniq_annotations(dataset) - for idx, ann in enumerate(annotations): - attributes = dict(ann.attributes.items()) - assert pid_name in attributes, \ - "'%s' is expected as an attribute name" % pid_name - person_id = attributes[pid_name] - if person_id not in by_pid: - by_pid[person_id] = [] - by_pid[person_id].append((idx, ann)) - groups.add(ann.group) - - max_group_id = max(groups) - self._group_map["gallery"] = max_group_id + 1 - self._group_map["query"] = max_group_id + 2 + if attr_for_id is None: # use label + for idx, ann in enumerate(annotations): + ID = getattr(ann, 'label', None) + if ID not in by_id: + by_id[ID] = [] + by_id[ID].append((idx, ann)) + else: # use attr_for_id + for idx, ann in enumerate(annotations): + attributes = dict(ann.attributes.items()) + assert attr_for_id in attributes, \ + "'%s' is expected as an attribute name" % attr_for_id + ID = attributes[attr_for_id] + if ID not in by_id: + by_id[ID] = [] + by_id[ID].append((idx, ann)) required = self._get_required(id_ratio) - if len(by_pid) < required: + if len(by_id) < required: log.warning("There's not enough IDs, which is %s, " "so train/val/test ratio can't be guaranteed." - % len(by_pid) + % len(by_id) ) # 1. split dataset into trval and test @@ -273,12 +341,12 @@ def _split_dataset(self): test = id_ratio[id_snames.index("test")] if "test" in id_snames else 0 if NEAR_ZERO < test: # has testset split_ratio = np.array([test, 1.0 - test]) - person_ids = list(by_pid.keys()) - np.random.shuffle(person_ids) - sections = self._get_sections(len(person_ids), split_ratio) - splits = np.array_split(person_ids, sections) - testset = {pid: by_pid[pid] for pid in splits[0]} - trval = {pid: by_pid[pid] for pid in splits[1]} + IDs = list(by_id.keys()) + np.random.shuffle(IDs) + sections = self._get_sections(len(IDs), split_ratio) + splits = np.array_split(IDs, sections) + testset = {pid: by_id[pid] for pid in splits[0]} + trval = {pid: by_id[pid] for pid in splits[1]} # follow the ratio of datasetitems as possible. # naive heuristic: exchange the best item one by one. @@ -287,32 +355,22 @@ def _split_dataset(self): self._rebalancing(testset, trval, expected_count, testset_total) else: testset = dict() - trval = by_pid + trval = by_id by_splits = dict() for subset in self._subsets: by_splits[subset] = [] - # 2. split 'test' into 'gallery' and 'query' + # 2. split 'test' into 'test-gallery' and 'test-query' if 0 < len(testset): - for person_id, items in testset.items(): - indice = [idx for idx, _ in items] - by_splits["test"].extend(indice) - - valid = ["gallery", "query"] - test_splits = self._test_splits - test_snames, test_ratio = self._validate_splits(test_splits, valid) - by_groups = {s: [] for s in test_snames} - self._split_by_attr(testset, test_snames, test_ratio, by_groups, - dataset_key=pid_name) - - # tag using group - for idx, item in enumerate(self._extractor): - for subset, split in by_groups.items(): - if idx in split: - group_id = self._group_map[subset] - item.annotations[0].group = group_id - break + test_snames = [] + test_ratio = [] + for sname, ratio in self._test_splits: + test_snames.append(sname) + test_ratio.append(float(ratio)) + + self._split_by_attr(testset, test_snames, test_ratio, by_splits, + dataset_key=attr_for_id) # 3. split 'trval' into 'train' and 'val' trval_snames = ["train", "val"] @@ -334,7 +392,7 @@ def _split_dataset(self): else: trval_ratio /= total_ratio # normalize self._split_by_attr(trval, trval_snames, trval_ratio, by_splits, - dataset_key=pid_name) + dataset_key=attr_for_id) self._set_parts(by_splits) @@ -352,6 +410,9 @@ def _rebalancing(test, trval, expected_count, testset_total): diffs[diff] = [(id_test, id_trval)] else: diffs[diff].append((id_test, id_trval)) + if len(diffs) == 0: # nothing would be changed by exchange + return + exchanges = [] while True: target_diff = expected_count - testset_total @@ -362,47 +423,49 @@ def _rebalancing(test, trval, expected_count, testset_total): if abs(target_diff) <= abs(target_diff - nearest): break choice = np.random.choice(range(len(diffs[nearest]))) - pid_test, pid_trval = diffs[nearest][choice] + id_test, id_trval = diffs[nearest][choice] testset_total += nearest new_diffs = dict() - for diff, person_ids in diffs.items(): + for diff, IDs in diffs.items(): new_list = [] - for id1, id2 in person_ids: - if id1 == pid_test or id2 == pid_trval: + for id1, id2 in IDs: + if id1 == id_test or id2 == id_trval: continue new_list.append((id1, id2)) if 0 < len(new_list): new_diffs[diff] = new_list diffs = new_diffs - exchanges.append((pid_test, pid_trval)) - # exchange - for pid_test, pid_trval in exchanges: - test[pid_trval] = trval.pop(pid_trval) - trval[pid_test] = test.pop(pid_test) + exchanges.append((id_test, id_trval)) - def get_subset_by_group(self, group: str): - available = list(self._group_map.keys()) - assert group in self._group_map, \ - "Unknown group '%s', available groups: %s" \ - % (group, available) - group_id = self._group_map[group] - return self.select(lambda item: item.annotations[0].group == group_id) + # exchange + for id_test, id_trval in exchanges: + test[id_trval] = trval.pop(id_trval) + trval[id_test] = test.pop(id_test) class DetectionSplit(_TaskSpecificSplit): """ - Splits dataset into train/val/test set for detection task.|n - For detection dataset, each image can have multiple bbox annotations.|n - Since one DataItem can't be included in multiple subsets at the same time, - the dataset can't be divided according to the bbox annotations.|n - Thus, we split dataset based on DatasetItem - while preserving label distribution as possible.|n + Splits a dataset into train/val/test subsets for detection task, + using object annotations as a basis for splitting.|n + Tries to produce an image split with the specified ratio, keeping the + initial distribution of class objects.|n + |n + In a detection dataset, each image can have multiple object annotations - + instance bounding boxes. Since an image shouldn't be included + in multiple subsets at the same time, and image annotations + shoudln't be split, in general, dataset annotations are unlikely to be split + exactly in the specified ratio. |n + This split tries to split dataset images as close as possible + to the specified ratio, keeping the initial class distribution.|n |n Notes:|n - - Each DatsetItem is expected to have one or more Bbox annotations.|n - - Label annotations are ignored. We only focus on the Bbox annotations.|n + - Each image is expected to have one or more Bbox annotations.|n + - Only Bbox annotations are considered.|n + |n + Example: split dataset so that each object class annotations were split|n + |s|s|s|sin the specified ratio between subsets|n + |s|s%(prog)s --subset train:.5 --subset val:.2 --subset test:.3 """ - def __init__(self, dataset, splits, seed=None): """ Parameters @@ -474,7 +537,6 @@ def _split_dataset(self): (sname, {k: v * ratio for k, v in n_combs.items()}) ) - ## # functions for keep the # of annotations not exceed the expected num def compute_penalty(counts, n_combs): p = 0 @@ -489,8 +551,6 @@ def update_nc(counts, n_combs): n_combs[k] = -1 return n_combs - ## - # 3-2. assign each DatasetItem to a split, one by one for idx, _ in sorted( init_scores.items(), key=lambda item: item[1], reverse=True diff --git a/docs/user_manual.md b/docs/user_manual.md index a38b1560ca..4c624a4332 100644 --- a/docs/user_manual.md +++ b/docs/user_manual.md @@ -346,6 +346,13 @@ datum filter \ -e '/item[image/width < image/height]' ``` +Example: extract a dataset with only images of subset `train`. +``` bash +datum project filter \ + -p test_project \ + -e '/item[subset="train"]' +``` + Example: extract a dataset with only large annotations of class `cat` and any non-`persons` ``` bash @@ -954,6 +961,20 @@ Example: split a dataset randomly to `train` and `test` subsets, ratio is 2:1 datum transform -t random_split -- --subset train:.67 --subset test:.33 ``` +Example: split a dataset in task-specific manner. Supported tasks are +classification, detection, and re-identification. + +``` bash +datum transform -t classification_split -- \ + --subset train:.5 --subset val:.2 --subset test:.3 + +datum transform -t detection_split -- \ + --subset train:.5 --subset val:.2 --subset test:.3 + +datum transform -t reidentification_split -- \ + --subset train:.5 --subset val:.2 --subset test:.3 --query .5 +``` + Example: convert polygons to masks, masks to boxes etc.: ``` bash diff --git a/tests/test_splitter.py b/tests/test_splitter.py index 276ed5f557..ba3cb5a174 100644 --- a/tests/test_splitter.py +++ b/tests/test_splitter.py @@ -202,16 +202,16 @@ def test_split_for_classification_multi_label_with_attr(self): self.assertEqual(9, attr_test["attr3"]["distribution"]["1"][0]) self.assertEqual(15, attr_test["attr3"]["distribution"]["2"][0]) - # random seed test - r1 = splitter.ClassificationSplit(source, splits, seed=1234) - r2 = splitter.ClassificationSplit(source, splits, seed=1234) - r3 = splitter.ClassificationSplit(source, splits, seed=4321) - self.assertEqual( - list(r1.get_subset("test")), list(r2.get_subset("test")) - ) - self.assertNotEqual( - list(r1.get_subset("test")), list(r3.get_subset("test")) - ) + with self.subTest("random seed test"): + r1 = splitter.ClassificationSplit(source, splits, seed=1234) + r2 = splitter.ClassificationSplit(source, splits, seed=1234) + r3 = splitter.ClassificationSplit(source, splits, seed=4321) + self.assertEqual( + list(r1.get_subset("test")), list(r2.get_subset("test")) + ) + self.assertNotEqual( + list(r1.get_subset("test")), list(r3.get_subset("test")) + ) def test_split_for_classification_gives_error(self): with self.subTest("no label"): @@ -255,74 +255,127 @@ def test_split_for_classification_gives_error(self): splits = [("train_", 0.5), ("val", 0.2), ("test", 0.3)] splitter.ClassificationSplit(source, splits) - def test_split_for_matching_reid(self): - counts = {i: (i % 3 + 1) * 7 for i in range(10)} - config = {"person": {"attrs": ["PID"], "counts": counts}} + def test_split_for_reidentification(self): + ''' + Test ReidentificationSplit using Dataset with label (ImageNet style) + ''' + def _get_present(stat): + values_present = [] + for label, dist in stat["distribution"].items(): + if dist[0] > 0: + values_present.append(label) + return set(values_present) + + for with_attr in [True, False]: + if with_attr: + counts = {i: (i % 3 + 1) * 7 for i in range(10)} + config = {"person": {"attrs": ["PID"], "counts": counts}} + attr_for_id = "PID" + else: + counts = {} + config = dict() + for i in range(10): + label = "label%d" % i + count = (i % 3 + 1) * 7 + counts[label] = count + config[label] = {"attrs": None, "counts": count} + attr_for_id = None + source = self._generate_dataset(config) + splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)] + query = 0.4 / 0.7 + actual = splitter.ReidentificationSplit(source, + splits, query, attr_for_id) + + stats = dict() + for sname in ["train", "val", "test-query", "test-gallery"]: + subset = actual.get_subset(sname) + stat = compute_ann_statistics(subset)["annotations"]["labels"] + if with_attr: + stat = stat["attributes"]["PID"] + stats[sname] = stat + + # check size of subsets + self.assertEqual(65, stats["train"]["count"]) + self.assertEqual(26, stats["val"]["count"]) + self.assertEqual(18, stats["test-gallery"]["count"]) + self.assertEqual(24, stats["test-query"]["count"]) + + # check ID separation between test set and others + train_ids = _get_present(stats["train"]) + test_ids = _get_present(stats["test-gallery"]) + for pid in train_ids: + assert pid not in test_ids + self.assertEqual(7, len(train_ids)) + self.assertEqual(3, len(test_ids)) + self.assertEqual(train_ids, _get_present(stats["val"])) + self.assertEqual(test_ids, _get_present(stats["test-query"])) + + # check trainval set statistics + trainval = stats["train"]["count"] + stats["val"]["count"] + expected_train_count = int(trainval * 0.5 / 0.7) + expected_val_count = int(trainval * 0.2 / 0.7) + self.assertEqual(expected_train_count, stats["train"]["count"]) + self.assertEqual(expected_val_count, stats["val"]["count"]) + dist_train = stats["train"]["distribution"] + dist_val = stats["val"]["distribution"] + for pid in train_ids: + total = counts[int(pid)] if with_attr else counts[pid] + self.assertEqual(int(total * 0.5 / 0.7), dist_train[pid][0]) + self.assertEqual(int(total * 0.2 / 0.7), dist_val[pid][0]) + + # check teset set statistics + dist_gallery = stats["test-gallery"]["distribution"] + dist_query = stats["test-query"]["distribution"] + for pid in test_ids: + total = counts[int(pid)] if with_attr else counts[pid] + self.assertEqual(int(total * 0.3 / 0.7), dist_gallery[pid][0]) + self.assertEqual(int(total * 0.4 / 0.7), dist_query[pid][0]) + + def test_split_for_reidentification_randomseed(self): + ''' + Test randomseed for reidentification + ''' + counts = {} + config = dict() + for i in range(10): + label = "label%d" % i + count = (i % 3 + 1) * 7 + counts[label] = count + config[label] = {"attrs": None, "counts": count} source = self._generate_dataset(config) - - splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)] - test_splits = [("query", 0.4 / 0.7), ("gallery", 0.3 / 0.7)] - actual = splitter.MatchingReIDSplit(source, splits, test_splits) - - stats = dict() - for sname in ["train", "val", "test"]: - subset = actual.get_subset(sname) - stat_subset = compute_ann_statistics(subset)["annotations"] - stat_attr = stat_subset["labels"]["attributes"]["PID"] - stats[sname] = stat_attr - - for sname in ["gallery", "query"]: - subset = actual.get_subset_by_group(sname) - stat_subset = compute_ann_statistics(subset)["annotations"] - stat_attr = stat_subset["labels"]["attributes"]["PID"] - stats[sname] = stat_attr - - self.assertEqual(65, stats["train"]["count"]) # depends on heuristic - self.assertEqual(26, stats["val"]["count"]) # depends on heuristic - self.assertEqual(42, stats["test"]["count"]) # depends on heuristic - - train_ids = stats["train"]["values present"] - self.assertEqual(7, len(train_ids)) - self.assertEqual(train_ids, stats["val"]["values present"]) - - trainval = stats["train"]["count"] + stats["val"]["count"] - self.assertEqual(int(trainval * 0.5 / 0.7), stats["train"]["count"]) - self.assertEqual(int(trainval * 0.2 / 0.7), stats["val"]["count"]) - - dist_train = stats["train"]["distribution"] - dist_val = stats["val"]["distribution"] - for pid in train_ids: - total = counts[int(pid)] - self.assertEqual(int(total * 0.5 / 0.7), dist_train[pid][0]) - self.assertEqual(int(total * 0.2 / 0.7), dist_val[pid][0]) - - test_ids = stats["test"]["values present"] - self.assertEqual(3, len(test_ids)) - self.assertEqual(test_ids, stats["gallery"]["values present"]) - self.assertEqual(test_ids, stats["query"]["values present"]) - - dist_test = stats["test"]["distribution"] - dist_gallery = stats["gallery"]["distribution"] - dist_query = stats["query"]["distribution"] - for pid in test_ids: - total = counts[int(pid)] - self.assertEqual(total, dist_test[pid][0]) - self.assertEqual(int(total * 0.3 / 0.7), dist_gallery[pid][0]) - self.assertEqual(int(total * 0.4 / 0.7), dist_query[pid][0]) - - # random seed test splits = [("train", 0.5), ("test", 0.5)] - r1 = splitter.MatchingReIDSplit(source, splits, test_splits, seed=1234) - r2 = splitter.MatchingReIDSplit(source, splits, test_splits, seed=1234) - r3 = splitter.MatchingReIDSplit(source, splits, test_splits, seed=4321) + query = 0.4 / 0.7 + r1 = splitter.ReidentificationSplit(source, splits, query, seed=1234) + r2 = splitter.ReidentificationSplit(source, splits, query, seed=1234) + r3 = splitter.ReidentificationSplit(source, splits, query, seed=4321) self.assertEqual( - list(r1.get_subset("test")), list(r2.get_subset("test")) + list(r1.get_subset("train")), list(r2.get_subset("train")) ) self.assertNotEqual( - list(r1.get_subset("test")), list(r3.get_subset("test")) + list(r1.get_subset("train")), list(r3.get_subset("train")) ) - def test_split_for_matching_reid_gives_error(self): + def test_split_for_reidentification_rebalance(self): + ''' + rebalance function shouldn't gives error when there's no exchange + ''' + config = dict() + for i in range(100): + label = "label%03d" % i + config[label] = {"attrs": None, "counts": 7} + source = self._generate_dataset(config) + splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)] + query = 0.4 / 0.7 + actual = splitter.ReidentificationSplit(source, splits, query) + + self.assertEqual(350, len(actual.get_subset("train"))) + self.assertEqual(140, len(actual.get_subset("val"))) + self.assertEqual(90, len(actual.get_subset("test-gallery"))) + self.assertEqual(120, len(actual.get_subset("test-query"))) + + def test_split_for_reidentification_gives_error(self): + query = 0.4 / 0.7 # valid query ratio + with self.subTest("no label"): source = Dataset.from_iterable([ DatasetItem(1, annotations=[]), @@ -331,8 +384,7 @@ def test_split_for_matching_reid_gives_error(self): with self.assertRaisesRegex(Exception, "exactly one is expected"): splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)] - test_splits = [("query", 0.4 / 0.7), ("gallery", 0.3 / 0.7)] - actual = splitter.MatchingReIDSplit(source, splits, test_splits) + actual = splitter.ReidentificationSplit(source, splits, query) len(actual.get_subset("train")) with self.subTest(msg="multi label"): @@ -343,8 +395,7 @@ def test_split_for_matching_reid_gives_error(self): with self.assertRaisesRegex(Exception, "exactly one is expected"): splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)] - test_splits = [("query", 0.4 / 0.7), ("gallery", 0.3 / 0.7)] - actual = splitter.MatchingReIDSplit(source, splits, test_splits) + actual = splitter.ReidentificationSplit(source, splits, query) len(actual.get_subset("train")) counts = {i: (i % 3 + 1) * 7 for i in range(10)} @@ -353,45 +404,27 @@ def test_split_for_matching_reid_gives_error(self): with self.subTest("wrong ratio"): with self.assertRaisesRegex(Exception, "in the range"): splits = [("train", -0.5), ("val", 0.2), ("test", 0.3)] - test_splits = [("query", 0.4 / 0.7), ("gallery", 0.3 / 0.7)] - splitter.MatchingReIDSplit(source, splits, test_splits) + splitter.ReidentificationSplit(source, splits, query) with self.assertRaisesRegex(Exception, "Sum of ratios"): splits = [("train", 0.6), ("val", 0.2), ("test", 0.3)] - test_splits = [("query", 0.4 / 0.7), ("gallery", 0.3 / 0.7)] - splitter.MatchingReIDSplit(source, splits, test_splits) + splitter.ReidentificationSplit(source, splits, query) with self.assertRaisesRegex(Exception, "in the range"): splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)] - test_splits = [("query", -0.4 / 0.7), ("gallery", 0.3 / 0.7)] - actual = splitter.MatchingReIDSplit(source, splits, test_splits) - len(actual.get_subset_by_group("query")) - - with self.assertRaisesRegex(Exception, "Sum of ratios"): - splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)] - test_splits = [("query", 0.5 / 0.7), ("gallery", 0.3 / 0.7)] - actual = splitter.MatchingReIDSplit(source, splits, test_splits) - len(actual.get_subset_by_group("query")) + actual = splitter.ReidentificationSplit(source, splits, -query) with self.subTest("wrong subset name"): with self.assertRaisesRegex(Exception, "Subset name"): splits = [("_train", 0.5), ("val", 0.2), ("test", 0.3)] - test_splits = [("query", 0.4 / 0.7), ("gallery", 0.3 / 0.7)] - splitter.MatchingReIDSplit(source, splits, test_splits) - - with self.assertRaisesRegex(Exception, "Subset name"): - splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)] - test_splits = [("_query", 0.4 / 0.7), ("gallery", 0.3 / 0.7)] - actual = splitter.MatchingReIDSplit(source, splits, test_splits) - len(actual.get_subset_by_group("query")) + splitter.ReidentificationSplit(source, splits, query) with self.subTest("wrong attribute name for person id"): splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)] - test_splits = [("query", 0.4 / 0.7), ("gallery", 0.3 / 0.7)] - actual = splitter.MatchingReIDSplit(source, splits, test_splits) + actual = splitter.ReidentificationSplit(source, splits, query) - with self.assertRaisesRegex(Exception, "Unknown group"): - actual.get_subset_by_group("_gallery") + with self.assertRaisesRegex(Exception, "Unknown subset"): + actual.get_subset("test") def _generate_detection_dataset(self, **kwargs): append_bbox = kwargs.get("append_bbox")