Skip to content
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `VGGFace2` dataset format (<https://github.com/openvinotoolkit/datumaro/pull/69>)

### Changed
-
- `Dataset` class extended with new operations: `save`, `load`, `export`, `import_from`, `detect`, `run_model` (<https://github.com/openvinotoolkit/datumaro/pull/71>)
- `Dataset` operations return `Dataset` instances, allowing to chain operations (<https://github.com/openvinotoolkit/datumaro/pull/71>)
- Allowed importing `Extractor`-only defined formats (in `Project.import_from`, `dataset.import_from` and CLI/`project import`)

### Deprecated
-
Expand Down
42 changes: 9 additions & 33 deletions datumaro/cli/commands/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os.path as osp

from datumaro.components.project import Environment
from datumaro.components.dataset import Dataset

from ..contexts.project import FilterModes
from ..util import CliException, MultilineFormatter, make_file_name
Expand Down Expand Up @@ -68,46 +69,24 @@ def convert_command(args):
raise CliException("Converter for format '%s' is not found" % \
args.output_format)
extra_args = converter.from_cmdline(args.extra_args)
def converter_proxy(extractor, save_dir):
return converter.convert(extractor, save_dir, **extra_args)

filter_args = FilterModes.make_filter_args(args.filter_mode)

fmt = args.input_format
if not args.input_format:
matches = []
for format_name in env.importers.items:
log.debug("Checking '%s' format...", format_name)
importer = env.make_importer(format_name)
try:
match = importer.detect(args.source)
if match:
log.debug("format matched")
matches.append((format_name, importer))
except NotImplementedError:
log.debug("Format '%s' does not support auto detection.",
format_name)

matches = env.detect_dataset(args.source)
if len(matches) == 0:
log.error("Failed to detect dataset format. "
"Try to specify format with '-if/--input-format' parameter.")
return 1
elif len(matches) != 1:
log.error("Multiple formats match the dataset: %s. "
"Try to specify format with '-if/--input-format' parameter.",
', '.join(m[0] for m in matches))
', '.join(matches))
return 2

format_name, importer = matches[0]
args.input_format = format_name
fmt = matches[0]
log.info("Source dataset format detected as '%s'", args.input_format)
else:
try:
importer = env.make_importer(args.input_format)
if hasattr(importer, 'from_cmdline'):
extra_args = importer.from_cmdline()
except KeyError:
raise CliException("Importer for format '%s' is not found" % \
args.input_format)

source = osp.abspath(args.source)

Expand All @@ -121,15 +100,12 @@ def converter_proxy(extractor, save_dir):
(osp.basename(source), make_file_name(args.output_format)))
dst_dir = osp.abspath(dst_dir)

project = importer(source)
dataset = project.make_dataset()
dataset = Dataset.import_from(source, fmt)

log.info("Exporting the dataset")
dataset.export_project(
save_dir=dst_dir,
converter=converter_proxy,
filter_expr=args.filter,
**filter_args)
if args.filter:
dataset = dataset.filter(args.filter, **filter_args)
dataset.export(args.output_format, save_dir=dst_dir, **extra_args)

log.info("Dataset exported to '%s' as '%s'" % \
(dst_dir, args.output_format))
Expand Down
2 changes: 1 addition & 1 deletion datumaro/cli/contexts/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import os.path as osp
import re

from datumaro.components.config import DEFAULT_FORMAT
from datumaro.components.dataset import DEFAULT_FORMAT
from datumaro.components.project import Environment

from ...util import CliException, MultilineFormatter, add_subparser
Expand Down
51 changes: 22 additions & 29 deletions datumaro/cli/contexts/project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,50 +172,43 @@ def import_command(args):
log.info("Importing project from '%s'" % args.source)

extra_args = {}
fmt = args.format
if not args.format:
if args.extra_args:
raise CliException("Extra args can not be used without format")

log.info("Trying to detect dataset format...")

matches = []
for format_name in env.importers.items:
log.debug("Checking '%s' format...", format_name)
importer = env.make_importer(format_name)
try:
match = importer.detect(args.source)
if match:
log.debug("format matched")
matches.append((format_name, importer))
except NotImplementedError:
log.debug("Format '%s' does not support auto detection.",
format_name)

matches = env.detect_dataset(args.source)
if len(matches) == 0:
log.error("Failed to detect dataset format automatically. "
"Try to specify format with '-f/--format' parameter.")
return 1
elif len(matches) != 1:
log.error("Multiple formats match the dataset: %s. "
"Try to specify format with '-f/--format' parameter.",
', '.join(m[0] for m in matches))
', '.join(matches))
return 2

format_name, importer = matches[0]
args.format = format_name
else:
try:
importer = env.make_importer(args.format)
if hasattr(importer, 'from_cmdline'):
extra_args = importer.from_cmdline(args.extra_args)
except KeyError:
raise CliException("Importer for format '%s' is not found" % \
args.format)

log.info("Importing project as '%s'" % args.format)

source = osp.abspath(args.source)
project = importer(source, **extra_args)
fmt = matches[0]
elif args.extra_args:
if fmt in env.importers:
arg_parser = env.importers[fmt]
elif fmt in env.extractors:
arg_parser = env.extractors[fmt]
else:
raise CliException("Unknown format '%s'. A format can be added"
"by providing an Extractor and Importer plugins" % fmt)

if hasattr(arg_parser, 'from_cmdline'):
extra_args = arg_parser.from_cmdline(args.extra_args)
else:
raise CliException("Format '%s' does not accept "
"extra parameters" % fmt)

log.info("Importing project as '%s'" % fmt)

project = Project.import_from(osp.abspath(args.source), fmt, **extra_args)
project.config.project_name = project_name
project.config.project_dir = project_dir

Expand Down
3 changes: 0 additions & 3 deletions datumaro/components/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,3 @@ def set(self, key, value):
return super().set(key, value)
else:
return super().set(key, value)


DEFAULT_FORMAT = 'datumaro'
117 changes: 110 additions & 7 deletions datumaro/components/dataset.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
# Copyright (C) 2020 Intel Corporation
# Copyright (C) 2020-2021 Intel Corporation
#
# SPDX-License-Identifier: MIT

from collections import OrderedDict, defaultdict
from typing import Iterable, Union, Dict, List
import logging as log
import os
import os.path as osp
import shutil

from datumaro.components.extractor import (Extractor, LabelCategories,
AnnotationType, DatasetItem, DEFAULT_SUBSET_NAME)
from datumaro.components.dataset_filter import \
XPathDatasetFilter, XPathAnnotationsFilter
from datumaro.components.environment import Environment
from datumaro.util import error_rollback
from datumaro.util.log_utils import logging_disabled


DEFAULT_FORMAT = 'datumaro'

class Dataset(Extractor):
class Subset(Extractor):
def __init__(self, parent):
Expand All @@ -28,7 +37,8 @@ def categories(self):

@classmethod
def from_iterable(cls, iterable: Iterable[DatasetItem],
categories: Union[Dict, List[str]] = None):
categories: Union[Dict, List[str]] = None,
env: Environment = None):
if isinstance(categories, list):
categories = { AnnotationType.label:
LabelCategories.from_iterable(categories)
Expand All @@ -44,12 +54,12 @@ def __iter__(self):
def categories(self):
return categories

return cls.from_extractors(_extractor())
return cls.from_extractors(_extractor(), env=env)

@classmethod
def from_extractors(cls, *sources):
def from_extractors(cls, *sources, env=None):
categories = cls._merge_categories(s.categories() for s in sources)
dataset = Dataset(categories=categories)
dataset = Dataset(categories=categories, env=env)

# merge items
subsets = defaultdict(lambda: cls.Subset(dataset))
Expand All @@ -67,9 +77,12 @@ def from_extractors(cls, *sources):
dataset._subsets = dict(subsets)
return dataset

def __init__(self, categories=None):
def __init__(self, categories=None, env=None):
super().__init__()

assert env is None or isinstance(env, Environment), env
self._env = env

self._subsets = {}

if not categories:
Expand Down Expand Up @@ -183,4 +196,94 @@ def _merge_anno(a, b):
def _merge_categories(sources):
# TODO: implement properly with merging and annotations remapping
from .operations import merge_categories
return merge_categories(sources)
return merge_categories(sources)

@error_rollback('on_error', implicit=True)
def export(self, converter, save_dir, **kwargs):
if isinstance(converter, str):
converter = self.env.make_converter(converter)

save_dir = osp.abspath(save_dir)
if not osp.exists(save_dir):
on_error.do(shutil.rmtree, save_dir, ignore_errors=True)
os.makedirs(save_dir, exist_ok=True)
converter(self, save_dir=save_dir, **kwargs)

def transform(self, method, *args, **kwargs):
if isinstance(method, str):
method = self.env.make_transform(method)

result = super().transform(method, *args, **kwargs)
return Dataset.from_extractors(result, env=self._env)

def run_model(self, model, batch_size=1):
from datumaro.components.launcher import Launcher, ModelTransform
if isinstance(model, Launcher):
return self.transform(ModelTransform, launcher=model,
batch_size=batch_size)
elif isinstance(model, ModelTransform):
return self.transform(model, batch_size=batch_size)
else:
raise TypeError('Unexpected model argument type: %s' % type(model))

@property
def env(self):
if not self._env:
self._env = Environment()
return self._env

def save(self, save_dir, **kwargs):
self.export(DEFAULT_FORMAT, save_dir=save_dir, **kwargs)

@classmethod
def load(cls, path, **kwargs):
return cls.import_from(path, format=DEFAULT_FORMAT, **kwargs)

@classmethod
def import_from(cls, path, format=None, env=None, **kwargs): #pylint: disable=redefined-builtin
from datumaro.components.config_model import Source

if env is None:
env = Environment()

# TODO: remove importers, put this logic into extractors
if not format:
format = cls.detect(path, env)
if format in env.importers:
importer = env.make_importer(format)
with logging_disabled(log.INFO):
project = importer(path, **kwargs)
detected_sources = list(project.config.sources.values())
elif format in env.extractors:
detected_sources = [{
'url': path, 'format': format, 'options': kwargs
}]
else:
raise Exception("Unknown source format '%s'. To make it "
"available, add the corresponding Extractor implementation "
"to the environment" % format)

extractors = []
for src_conf in detected_sources:
if not isinstance(src_conf, Source):
src_conf = Source(src_conf)
extractors.append(env.make_extractor(
src_conf.format, src_conf.url, **src_conf.options
))

return cls.from_extractors(*extractors)

@staticmethod
def detect(path, env=None):
if env is None:
env = Environment()

matches = env.detect_dataset(path)
if not matches:
raise Exception("Failed to detect dataset format automatically: "
"no matching formats found")
if 1 < len(matches):
raise Exception("Failed to detect dataset format automatically:"
" data matches more than one format: %s" % \
', '.join(matches))
return matches[0]
19 changes: 19 additions & 0 deletions datumaro/components/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,22 @@ def register_model(self, name, model):

def unregister_model(self, name):
self.models.unregister(name)

def is_format_known(self, name):
return name in self.importers or name in self.extractors

def detect_dataset(self, path):
matches = []

for format_name, importer in self.importers.items.items():
log.debug("Checking '%s' format...", format_name)
try:
match = importer.detect(path)
if match:
log.debug("format matched")
matches.append(format_name)
except NotImplementedError:
log.debug("Format '%s' does not support auto detection.",
format_name)

return matches
Loading