diff --git a/src/datumaro/__init__.py b/src/datumaro/__init__.py index 263ca82ac6..c448057c58 100644 --- a/src/datumaro/__init__.py +++ b/src/datumaro/__init__.py @@ -6,7 +6,6 @@ from . import errors as errors from . import ops as ops -from . import project as project from .components.algorithms import LossDynamicsAnalyzer from .components.annotation import ( NO_GROUP, diff --git a/src/datumaro/components/project.py b/src/datumaro/components/project.py deleted file mode 100644 index 730cf25273..0000000000 --- a/src/datumaro/components/project.py +++ /dev/null @@ -1,2710 +0,0 @@ -# Copyright (C) 2019-2023 Intel Corporation -# -# SPDX-License-Identifier: MIT - -from __future__ import annotations - -import logging as log -import os -import os.path as osp -import re -import shutil -import tempfile -import unittest.mock -from contextlib import ExitStack, suppress -from enum import Enum, auto -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Generic, - Iterable, - Iterator, - List, - NewType, - Optional, - Tuple, - TypeVar, - Union, -) - -from datumaro.components.config import Config -from datumaro.components.config_model import ( - BuildStage, - BuildTarget, - Model, - PipelineConfig, - ProjectConfig, - ProjectLayout, - Source, - TreeConfig, - TreeLayout, -) -from datumaro.components.dataset import DEFAULT_FORMAT, Dataset, IDataset -from datumaro.components.environment import Environment -from datumaro.components.errors import ( - DatasetMergeError, - EmptyCommitError, - EmptyPipelineError, - ForeignChangesError, - InvalidStageError, - MigrationError, - MismatchingObjectError, - MissingObjectError, - MissingPipelineHeadError, - MissingSourceHashError, - MultiplePipelineHeadsError, - OldProjectError, - PathOutsideSourceError, - ProjectAlreadyExists, - ProjectNotFoundError, - ReadonlyDatasetError, - ReadonlyProjectError, - SourceExistsError, - SourceUrlInsideProjectError, - UnexpectedUrlError, - UnknownRefError, - UnknownSourceError, - UnknownStageError, - UnknownTargetError, - UnsavedChangesError, - VcsAlreadyExists, - VcsError, -) -from datumaro.components.launcher import Launcher -from datumaro.util import find, parse_json_file, parse_str_enum_value -from datumaro.util.deprecation import deprecated -from datumaro.util.log_utils import catch_logs, logging_disabled -from datumaro.util.os_util import ( - copytree, - generate_next_name, - is_subpath, - make_file_name, - rmfile, - rmtree, -) -from datumaro.util.scope import on_error_do, scope_add, scoped - -if TYPE_CHECKING: - import networkx as nx - -else: - from datumaro.util.import_util import lazy_import - - nx = lazy_import("networkx") - - -@deprecated(deprecated_version="1.11", removed_version="1.12") -class ProjectSourceDataset(IDataset): - def __init__(self, path: str, tree: Tree, source: str, readonly: bool = False): - config = tree.sources[source] - - rpath = path - if config.path: - rpath = osp.join(path, config.path) - if "path" in config.options: - rpath = osp.join(path, config.options.pop("path")) - - dataset = Dataset.import_from(rpath, env=tree.env, format=config.format, **config.options) - - # Using rpath won't allow to save directly with .save() when a file - # path is specified. Dataset doesn't know the root location and if - # it exists at all, but in a project, we do. 
-        dataset.bind(path, format=dataset.format, options=dataset.options)
-
-        self.__dict__["_dataset"] = dataset
-
-        self.__dict__["_config"] = config
-        self.__dict__["_readonly"] = readonly
-        self.__dict__["name"] = source
-
-    def save(self, save_dir=None, **kwargs):
-        if self.readonly and (
-            save_dir is None or osp.abspath(save_dir) == osp.abspath(self.data_path)
-        ):
-            raise ReadonlyDatasetError()
-        self._dataset.save(save_dir, **kwargs)
-
-    @property
-    def readonly(self):
-        return self._readonly or not self.is_bound
-
-    @property
-    def config(self):
-        return self._config
-
-    def __getattr__(self, name):
-        return getattr(self._dataset, name)
-
-    def __setattr__(self, name, value):
-        return setattr(self._dataset, name, value)
-
-    def __iter__(self):
-        yield from self._dataset
-
-    def __len__(self):
-        return len(self._dataset)
-
-    def subsets(self):
-        return self._dataset.subsets()
-
-    def get_subset(self, name):
-        return self._dataset.get_subset(name)
-
-    def infos(self):
-        return self._dataset.infos()
-
-    def categories(self):
-        return self._dataset.categories()
-
-    def get(self, id, subset=None):
-        return self._dataset.get(id, subset)
-
-    def media_type(self):
-        return self._dataset.media_type()
-
-    def ann_types(self):
-        return self._dataset.ann_types()
-
-
-class IgnoreMode(Enum):
-    rewrite = auto()
-    append = auto()
-    remove = auto()
-
-
-def _update_ignore_file(
-    paths: Union[str, List[str]],
-    repo_root: str,
-    filepath: str,
-    mode: Union[None, str, IgnoreMode] = None,
-):
-    def _make_ignored_path(path):
-        path = osp.join(repo_root, osp.normpath(path))
-        assert is_subpath(path, base=repo_root)
-
-        # Prepend the '/' to match only direct children.
-        # Otherwise, the rule could match in any part of the path.
-        return "/" + osp.relpath(path, repo_root).replace("\\", "/")
-
-    header = "# The file is autogenerated by Datumaro"
-
-    mode = parse_str_enum_value(mode, IgnoreMode, IgnoreMode.append)
-
-    if isinstance(paths, str):
-        paths = [paths]
-    paths = {osp.join(repo_root, osp.normpath(p)): _make_ignored_path(p) for p in paths}
-
-    openmode = "r+"
-    if not osp.isfile(filepath):
-        openmode = "w+"  # r+ cannot create, w truncates
-    with open(filepath, openmode) as f:
-        lines = []
-        if mode in {IgnoreMode.append, IgnoreMode.remove}:
-            for line in f:
-                lines.append(line.strip())
-            f.seek(0)
-
-        new_lines = []
-        for line in lines:
-            if not line or line.startswith("#"):
-                new_lines.append(line)
-                continue
-
-            line_path = osp.join(
-                repo_root,
-                osp.normpath(line.split("#", maxsplit=1)[0]).replace("\\", "/").lstrip("/"),
-            )
-
-            if mode == IgnoreMode.append:
-                if line_path in paths:
-                    paths.pop(line_path)
-                new_lines.append(line)
-            elif mode == IgnoreMode.remove:
-                if line_path not in paths:
-                    new_lines.append(line)
-
-        if mode in {IgnoreMode.rewrite, IgnoreMode.append}:
-            new_lines.extend(paths.values())
-
-        if not new_lines or new_lines[0] != header:
-            print(header, file=f)
-        for line in new_lines:
-            print(line, file=f)
-        f.truncate()
-
-
-CrudEntry = TypeVar("CrudEntry")
-T = TypeVar("T")
-
-
-class CrudProxy(Generic[CrudEntry]):
-    @property
-    def _data(self) -> Dict[str, CrudEntry]:
-        raise NotImplementedError()
-
-    def __len__(self):
-        return len(self._data)
-
-    def __getitem__(self, name: str) -> CrudEntry:
-        return self._data[name]
-
-    def get(
-        self, name: str, default: Union[None, T, CrudEntry] = None
-    ) -> Union[None, T, CrudEntry]:
-        return self._data.get(name, default)
-
-    def __iter__(self) -> Iterator[CrudEntry]:
-        return iter(self._data.keys())
-
-    def items(self) -> Iterable[Tuple[str, CrudEntry]]:
-        return iter(self._data.items())
-
-    def __contains__(self, name: str):
-        return name in self._data
-
-
-class _DataSourceBase(CrudProxy[Source]):
-    def __init__(self, tree: Tree, config_field: str):
-        self._tree = tree
-        self._field = config_field
-
-    @property
-    def _data(self) -> Dict[str, Source]:
-        return self._tree.config[self._field]
-
-    def add(self, name: str, value: Union[Dict, Config, Source]) -> Source:
-        if name in self:
-            raise SourceExistsError(name)
-
-        return self._data.set(name, value)
-
-    def remove(self, name: str):
-        self._data.remove(name)
-
-
-@deprecated(deprecated_version="1.11", removed_version="1.12")
-class ProjectSources(_DataSourceBase):
-    def __init__(self, tree: Tree):
-        super().__init__(tree, "sources")
-
-    def __getitem__(self, name):
-        try:
-            return super().__getitem__(name)
-        except KeyError as e:
-            raise KeyError("Unknown source '%s'" % name) from e
-
-
-class BuildStageType(Enum):
-    source = auto()
-    project = auto()
-    transform = auto()
-    filter = auto()
-    convert = auto()
-    inference = auto()
-
-
-class Pipeline:
-    @staticmethod
-    def _create_graph(config: PipelineConfig):
-        graph = nx.DiGraph()
-        for entry in config:
-            target_name = entry["name"]
-            parents = entry["parents"]
-            target = BuildStage(entry["config"])
-
-            graph.add_node(target_name, config=target)
-            for prev_stage in parents:
-                graph.add_edge(prev_stage, target_name)
-
-        return graph
-
-    def __init__(self, config: PipelineConfig = None):
-        self._head = None
-
-        if config is not None:
-            self._graph = self._create_graph(config)
-            if not self.head:
-                raise MissingPipelineHeadError()
-        else:
-            self._graph = nx.DiGraph()
-
-    def __getattr__(self, key):
-        return getattr(self._graph, key)
-
-    @staticmethod
-    def _find_head_node(graph) -> Optional[str]:
-        head = None
-        for node in graph.nodes:
-            if graph.out_degree(node) == 0:
-                if head is not None:
-                    raise MultiplePipelineHeadsError(
-                        "A pipeline can have only one "
-                        "main target, but it has at least 2: %s, %s" % (head, node)
-                    )
-                head = node
-        return head
-
-    @property
-    def head(self) -> str:
-        if self._head is None:
-            self._head = self._find_head_node(self._graph)
-        return self._head
-
-    @property
-    def head_node(self):
-        return self._graph.nodes[self.head]
-
-    @staticmethod
-    def _serialize(graph) -> PipelineConfig:
-        serialized = PipelineConfig()
-        for node_name, node in graph.nodes.items():
-            serialized.nodes.append(
-                {
-                    "name": node_name,
-                    "parents": list(graph.predecessors(node_name)),
-                    "config": dict(node["config"]),
-                }
-            )
-        return serialized
-
-    @staticmethod
-    def _get_subgraph(graph, target):
-        """
-        Returns a subgraph with all the target dependencies and
-        the target itself.
-        """
-        return graph.subgraph(nx.ancestors(graph, target) | {target})
-
-    def get_slice(self, target) -> Pipeline:
-        pipeline = Pipeline()
-        pipeline._graph = self._get_subgraph(self._graph, target).copy()
-        return pipeline
-
-
-@deprecated(deprecated_version="1.11", removed_version="1.12")
-class ProjectBuilder:
-    def __init__(self, project: Project, tree: Tree):
-        self._project = project
-        self._tree = tree
-
-    def make_dataset(self, pipeline: Pipeline) -> IDataset:
-        dataset = self._get_resulting_dataset(pipeline)
-
-        # TODO: Maybe we need to save and load here, because it can modify the
-        # dataset, unless we work with the internal format. For example, it can
-        # add format-specific attributes. This will be needed as soon as
-        # format-converting stages (export, convert, load) are allowed.
- # - # TODO: If the target was rebuilt from sources, it may require saving - # and hashing, so the resulting hash could be compared with the saved - # one in the pipeline. This is needed to make sure the reproduced - # version of the dataset is correct. Currently we only rely on the - # initial source version check, which can be not enough if stages - # produce different result (because of the library changes etc). - # - # save_in_cache(project, pipeline) # update and check hash in config! - # dataset = load_dataset(project, pipeline) - - return dataset - - def _run_pipeline(self, pipeline: Pipeline): - self._validate_pipeline(pipeline) - - missing_sources, wd_hashes = self._find_missing_sources(pipeline) - for source_name in missing_sources: - source = self._tree.sources[source_name] - - if wd_hashes.get(source_name): - raise ForeignChangesError( - "Local source '%s' data does not " - "match any previous source revision. Probably, the source " - "was modified outside Datumaro. You can restore the " - "latest source revision with 'checkout' command." % source_name - ) - - if self._project.readonly: - # Source re-downloading is prohibited in readonly projects - # because it can seriously hurt free storage space. It must - # be run manually, so that the user could know about this. - log.info( - "Skipping re-downloading missing source '%s', " - "because the project is read-only. Automatic downloading " - "is disabled in read-only projects.", - source_name, - ) - continue - - if not source.hash: - raise MissingSourceHashError( - "Unable to re-download source " - "'%s': the source was added with no hash information. " % source_name - ) - - with self._project._make_tmp_dir() as tmp_dir: - obj_hash, _, _ = self._project._download_source(source.url, tmp_dir) - - if source.hash and source.hash != obj_hash: - raise MismatchingObjectError( - "Downloaded source '%s' data is different " - "from what is saved in the build pipeline: " - "'%s' vs '%s'" % (source_name, obj_hash, source.hash) - ) - - return self._init_pipeline(pipeline, working_dir_hashes=wd_hashes) - - def _get_resulting_dataset(self, pipeline): - graph, head = self._run_pipeline(pipeline) - return graph.nodes[head]["dataset"] - - def _init_pipeline(self, pipeline: Pipeline, working_dir_hashes=None): - """ - Initializes datasets in the pipeline nodes. Currently, only the head - node will have a dataset on exit, so no extra memory is wasted - for the intermediate nodes. - """ - - def _join_parent_datasets(force=False): - parents = {p: graph.nodes[p] for p in graph.predecessors(stage_name)} - - if 1 < len(parents) or force: - try: - dataset = Dataset.from_extractors( - *(p["dataset"] for p in parents.values()), env=self._tree.env - ) - except DatasetMergeError as e: - e.sources = set(parents) - raise e - else: - dataset = next(iter(parents.values()))["dataset"] - - # clear fully utilized datasets to release memory - for p_name, p in parents.items(): - p["_use_count"] = p.get("_use_count", 0) + 1 - - if p_name != head and p["_use_count"] == graph.out_degree(p_name): - p.pop("dataset") - - return dataset - - if working_dir_hashes is None: - working_dir_hashes = {} - - def _try_load_from_disk(stage_name: str, stage_config: BuildStage) -> Dataset: - # Check if we can restore this stage from the cache or - # from the working directory. - # - # If we have a hash, we have executed this stage already - # and can have a cache entry or, - # if this is the last stage of a target in the working tree, - # we can use data from the working directory. 
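For orientation, the restore order that the comment above describes can be sketched as follows. This is a simplified outline only, not the real implementation: `try_restore_stage`, `hash_of`, and `load_dataset` are hypothetical stand-ins for the DVC-based hashing and `ProjectSourceDataset` loading that the project actually performs.

```python
def try_restore_stage(stage_hash, working_dir, is_head_stage, project):
    def hash_of(path):  # placeholder for Project.compute_source_hash()
        ...

    def load_dataset(path):  # placeholder for ProjectSourceDataset(...)
        ...

    # 1. Working-directory data is usable for the head stage of a working-tree
    #    source, or for any stage whose recorded hash matches the directory.
    if working_dir:
        if (stage_hash and stage_hash == hash_of(working_dir)) or (
            not stage_hash and is_head_stage
        ):
            return load_dataset(working_dir)

    # 2. Otherwise, fall back to the project cache, then the DVC object cache.
    if stage_hash:
        if project._is_cached(stage_hash):
            return load_dataset(project.cache_path(stage_hash))
        if project._can_retrieve_from_vcs_cache(stage_hash):
            return load_dataset(project._materialize_obj(stage_hash))

    return None  # nothing restorable: the stage must be rebuilt from parents
```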
- stage_hash = stage_config.hash - - data_dir = None - cached = False - - source_name, source_stage_name = ProjectBuildTargets.split_target_name(stage_name) - if self._tree.is_working_tree and source_name in self._tree.sources: - target = self._tree.build_targets[source_name] - data_dir = self._project.source_data_dir(source_name) - wd_hash = working_dir_hashes.get(source_name) - - if not stage_hash: - if source_stage_name == target.head.name and osp.isdir(data_dir): - pass - else: - log.debug( - "Build: skipping loading stage '%s' from " - "working dir '%s', because the stage has no hash " - "and is not the head stage", - stage_name, - data_dir, - ) - data_dir = None - elif not wd_hash: - if osp.isdir(data_dir): - wd_hash = self._project.compute_source_hash(data_dir) - working_dir_hashes[source_name] = wd_hash - else: - log.debug( - "Build: skipping checking working dir '%s', " - "because it does not exist", - data_dir, - ) - data_dir = None - - if stage_hash and stage_hash != wd_hash: - log.debug( - "Build: skipping loading stage '%s' from " - "working dir '%s', because hashes do not match", - stage_name, - data_dir, - ) - data_dir = None - - if not data_dir and stage_hash: - if self._project._is_cached(stage_hash): - data_dir = self._project.cache_path(stage_hash) - cached = True - elif self._project._can_retrieve_from_vcs_cache(stage_hash): - data_dir = self._project._materialize_obj(stage_hash) - cached = True - - if not data_dir or not osp.isdir(data_dir): - log.debug( - "Build: skipping loading stage '%s' from " - "cache obj '%s', because it is not available", - stage_name, - stage_hash, - ) - return None - - if data_dir: - assert osp.isdir(data_dir), data_dir - log.debug("Build: loading stage '%s' from '%s'", stage_name, data_dir) - return ProjectSourceDataset( - data_dir, self._tree, source_name, readonly=cached or self._project.readonly - ) - - return None - - # Pipeline is assumed to be validated already - graph = pipeline._graph - head = pipeline.head - - # traverse the graph and initialize nodes from sources to the head - to_visit = [head] - while to_visit: - stage_name = to_visit.pop() - stage = graph.nodes[stage_name] - stage_config = stage["config"] - stage_type = BuildStageType[stage_config.type] - stage_hash = stage_config.hash - - assert stage.get("dataset") is None - - dataset = _try_load_from_disk(stage_name, stage_config) - if dataset is not None: - stage["dataset"] = dataset - continue - - uninitialized_parents = [] - for p_name in graph.predecessors(stage_name): - parent = graph.nodes[p_name] - if parent.get("dataset") is None: - uninitialized_parents.append(p_name) - - if uninitialized_parents: - to_visit.append(stage_name) - to_visit.extend(uninitialized_parents) - continue - - if stage_type == BuildStageType.transform: - kind = stage_config.kind - try: - transform = self._tree.env.transforms[kind] - except KeyError as e: - raise UnknownStageError("Unknown transform '%s'" % kind) from e - - dataset = _join_parent_datasets() - dataset = dataset.transform(transform, **stage_config.params) - - elif stage_type == BuildStageType.filter: - dataset = _join_parent_datasets() - dataset = dataset.filter(**stage_config.params) - - elif stage_type == BuildStageType.inference: - kind = stage_config.kind - model = self._project.make_model(kind) - - dataset = _join_parent_datasets() - dataset = dataset.run_model(model) - - elif stage_type == BuildStageType.source: - # Stages of type "Source" cannot have inputs, - # they are build tree inputs themselves - assert 
graph.in_degree(stage_name) == 0, stage_name - - # The only valid situation we get here is that it is a - # generated source: - # - No cache entry - # - No local dir data - source_name = ProjectBuildTargets.strip_target_name(stage_name) - source = self._tree.sources[source_name] - if not source.is_generated: - # Source is missing in the cache and the working tree, - # and cannot be retrieved from the VCS cache. - # It is assumed that all the missing sources were - # downloaded earlier. - raise MissingObjectError( - "Failed to initialize stage '%s': " - "object '%s' was not found in cache" % (stage_name, stage_hash) - ) - - # Generated sources do not require a data directory, - # but they still can be bound to a directory - if self._tree.is_working_tree: - source_dir = self._project.source_data_dir(source_name) - else: - source_dir = None - dataset = ProjectSourceDataset( - source_dir, - self._tree, - source_name, - readonly=not source_dir or self._project.readonly, - ) - - elif stage_type == BuildStageType.project: - dataset = _join_parent_datasets(force=True) - - elif stage_type == BuildStageType.convert: - dataset = _join_parent_datasets() - - else: - raise UnknownStageError("Unexpected stage type '%s'" % stage_type) - - stage["dataset"] = dataset - - return graph, head - - @staticmethod - def _validate_pipeline(pipeline: Pipeline): - graph = pipeline._graph - if ( - len(graph) == 0 - or len(graph) == 1 - and next(iter(graph.nodes)) - == ProjectBuildTargets.make_target_name( - ProjectBuildTargets.MAIN_TARGET, ProjectBuildTargets.BASE_STAGE - ) - ): - raise EmptyPipelineError() - - head = pipeline.head - if not head: - raise MissingPipelineHeadError() - - for stage_name, stage in graph.nodes.items(): - stage_type = BuildStageType[stage["config"].type] - - if graph.in_degree(stage_name) == 0: - if stage_type != BuildStageType.source: - raise InvalidStageError( - "Stage '%s' of type '%s' must have inputs" % (stage_name, stage_type.name) - ) - else: - if stage_type == BuildStageType.source: - raise InvalidStageError( - "Stage '%s' of type '%s' can't have inputs" % (stage_name, stage_type.name) - ) - - if graph.out_degree(stage_name) == 0: - if stage_name != head: - raise InvalidStageError( - "Stage '%s' of type '%s' has no outputs, " - "but is not the head stage" % (stage_name, stage_type.name) - ) - - def _find_missing_sources(self, pipeline: Pipeline): - work_dir_hashes = {} - - def _can_retrieve(stage_name: str, stage_config: BuildStage): - stage_hash = stage_config.hash - - source_name, source_stage_name = ProjectBuildTargets.split_target_name(stage_name) - if self._tree.is_working_tree and source_name in self._tree.sources: - target = self._tree.build_targets[source_name] - data_dir = self._project.source_data_dir(source_name) - - if not stage_hash: - return source_stage_name == target.head.name and osp.isdir(data_dir) - - wd_hash = work_dir_hashes.get(source_name) - if not wd_hash and osp.isdir(data_dir): - wd_hash = self._project.compute_source_hash( - self._project.source_data_dir(source_name) - ) - work_dir_hashes[source_name] = wd_hash - - if stage_hash and stage_hash == wd_hash: - return True - - if stage_hash and self._project.is_obj_cached(stage_hash): - return True - - return False - - missing_sources = set() - checked_deps = set() - unchecked_deps = [pipeline.head] - while unchecked_deps: - stage_name = unchecked_deps.pop() - if stage_name in checked_deps: - continue - - stage_config = pipeline._graph.nodes[stage_name]["config"] - - if not _can_retrieve(stage_name, 
stage_config): - if pipeline._graph.in_degree(stage_name) == 0: - assert stage_config.type == "source", stage_config.type - source_name = self._tree.build_targets.strip_target_name(stage_name) - source = self._tree.sources[source_name] - if not source.is_generated: - missing_sources.add(source_name) - else: - for p in pipeline._graph.predecessors(stage_name): - if p not in checked_deps: - unchecked_deps.append(p) - continue - - checked_deps.add(stage_name) - return missing_sources, work_dir_hashes - - -@deprecated(deprecated_version="1.11", removed_version="1.12") -class ProjectBuildTargets(CrudProxy[BuildTarget]): - MAIN_TARGET = "project" - BASE_STAGE = "root" - - def __init__(self, tree: Tree): - self._tree = tree - - @property - def _data(self): - data = self._tree.config.build_targets - - if self.MAIN_TARGET not in data: - data[self.MAIN_TARGET] = { - "stages": [ - BuildStage( - { - "name": self.BASE_STAGE, - "type": BuildStageType.project.name, - } - ), - ] - } - - for source in self._tree.sources: - if source not in data: - data[source] = { - "stages": [ - BuildStage( - { - "name": self.BASE_STAGE, - "type": BuildStageType.source.name, - } - ), - ] - } - - return data - - def __contains__(self, key): - if "." in key: - target, stage = self.split_target_name(key) - return target in self._data and self._data[target].find_stage(stage) is not None - return key in self._data - - def add_target(self, name) -> BuildTarget: - return self._data.set( - name, - { - "stages": [ - BuildStage( - { - "name": self.BASE_STAGE, - "type": BuildStageType.source.name, - } - ), - ] - }, - ) - - def add_stage(self, target, value, prev=None, name=None) -> str: - target_name = target - target_stage_name = None - if "." in target: - target_name, target_stage_name = self.split_target_name(target) - - if prev is None: - prev = target_stage_name - - target = self._data[target_name] - - if prev: - prev_stage = find(enumerate(target.stages), lambda e: e[1].name == prev) - if prev_stage is None: - raise KeyError("Can't find stage '%s'" % prev) - prev_stage = prev_stage[0] - else: - prev_stage = len(target.stages) - 1 - - name = value.get("name") or name - if not name: - name = generate_next_name( - (s.name for s in target.stages), "stage", sep="-", default="1" - ) - else: - if target.find_stage(name): - raise VcsError("Stage '%s' already exists" % name) - value["name"] = name - - value = BuildStage(value) - assert value.type in BuildStageType.__members__ - target.stages.insert(prev_stage + 1, value) - - return self.make_target_name(target_name, name) - - def remove_target(self, name: str): - assert name != self.MAIN_TARGET, "Can't remove the main target" - self._data.remove(name) - - def remove_stage(self, target: str, name: str): - assert name not in {self.BASE_STAGE}, "Can't remove a default stage" - - target = self._data[target] - idx = find(enumerate(target.stages), lambda e: e[1].name == name) - if idx is None: - raise KeyError("Can't find stage '%s'" % name) - target.stages.remove(idx) - - def add_transform_stage( - self, target: str, transform: str, params: Optional[Dict] = None, name: Optional[str] = None - ): - if transform not in self._tree.env.transforms: - raise KeyError("Unknown transform '%s'" % transform) - - return self.add_stage( - target, - { - "type": BuildStageType.transform.name, - "kind": transform, - "params": params or {}, - }, - name=name, - ) - - def add_inference_stage( - self, target: str, model: str, params: Optional[Dict] = None, name: Optional[str] = None - ): - if model not in 
self._tree._project.models: - raise KeyError("Unknown model '%s'" % model) - - return self.add_stage( - target, - { - "type": BuildStageType.inference.name, - "kind": model, - "params": params or {}, - }, - name=name, - ) - - def add_filter_stage( - self, target: str, expr: str, params: Optional[Dict] = None, name: Optional[str] = None - ): - params = params or {} - params["expr_or_filter_func"] = expr - return self.add_stage( - target, - { - "type": BuildStageType.filter.name, - "params": params, - }, - name=name, - ) - - def add_convert_stage( - self, target: str, format: str, params: Optional[Dict] = None, name: Optional[str] = None - ): - if not self._tree.env.is_format_known(format): - raise KeyError("Unknown format '%s'" % format) - - return self.add_stage( - target, - { - "type": BuildStageType.convert.name, - "kind": format, - "params": params or {}, - }, - name=name, - ) - - @staticmethod - def make_target_name(target: str, stage: Optional[str] = None) -> str: - if stage: - return "%s.%s" % (target, stage) - return target - - @classmethod - def split_target_name(cls, name: str) -> Tuple[str, str]: - if "." in name: - target, stage = name.split(".", maxsplit=1) - if not target: - raise ValueError("Wrong build target name '%s': " "a name can't be empty" % name) - if not stage: - raise ValueError( - "Wrong build target name '%s': " - "expected stage name after the separator" % name - ) - else: - target = name - stage = cls.BASE_STAGE - return target, stage - - @classmethod - def strip_target_name(cls, name: str) -> str: - return cls.split_target_name(name)[0] - - def _make_full_pipeline(self) -> Pipeline: - pipeline = Pipeline() - graph = pipeline._graph - - for target_name, target in self.items(): - if target_name == self.MAIN_TARGET: - # main target combines all the others - prev_stages = [ - self.make_target_name(n, t.head.name) - for n, t in self.items() - if n != self.MAIN_TARGET - ] - else: - prev_stages = [self.make_target_name(t, self[t].head.name) for t in target.parents] - - for stage in target.stages: - stage_name = self.make_target_name(target_name, stage["name"]) - - graph.add_node(stage_name, config=stage) - - for prev_stage in prev_stages: - graph.add_edge(prev_stage, stage_name) - prev_stages = [stage_name] - - return pipeline - - def make_pipeline(self, target: str) -> Pipeline: - if target not in self: - raise UnknownTargetError(target) - - # a subgraph with all the target dependencies - if "." not in target: - target = self.make_target_name(target, self[target].head.name) - - return self._make_full_pipeline().get_slice(target) - - -class GitWrapper: - @staticmethod - def module(): - try: - import git - - return git - except ModuleNotFoundError as e: - raise ModuleNotFoundError( - "Can't import the 'git' package. " - "Make sure GitPython is installed, or install it with " - "'pip install datumaro[default]'." 
) from e
-
-    def _git_dir(self):
-        return osp.join(self._project_dir, ".git")
-
-    def __init__(self, project_dir, repo=None):
-        self._project_dir = project_dir
-        self.repo = repo
-
-        if repo is None and osp.isdir(project_dir) and osp.isdir(self._git_dir()):
-            self.repo = self.module().Repo(project_dir)
-
-    @property
-    def initialized(self):
-        return self.repo is not None
-
-    def init(self):
-        if self.initialized:
-            return
-
-        repo = self.module().Repo.init(path=self._project_dir)
-        repo.config_writer().set_value("user", "name", "User").set_value(
-            "user", "email", "<>"
-        ).release()
-
-        # GitPython's init produces an incomplete repo, which becomes normal
-        # only after the first commit. Until a commit is made, some of
-        # GitPython's functions will throw useless errors.
-        # Call "git init" directly to get the desired behaviour.
-        repo.git.init()
-
-        self.repo = repo
-
-    def close(self):
-        if self.repo:
-            self.repo.close()
-            self.repo = None
-
-    def __del__(self):
-        with suppress(Exception):
-            self.close()
-
-    def checkout(self, ref: str, dst_dir=None, clean=False, force=False):
-        # If the user wants to navigate to a head, we need to supply its object
-        # instead of just a string. Otherwise, we'll get a detached head.
-        try:
-            ref_obj = self.repo.heads[ref]
-        except IndexError:
-            ref_obj = ref
-
-        commit = self.repo.commit(ref)
-        tree = commit.tree
-
-        if not dst_dir:
-            dst_dir = self._project_dir
-
-        repo_dir = osp.abspath(self._project_dir)
-        dst_dir = osp.abspath(dst_dir)
-        assert is_subpath(dst_dir, base=repo_dir)
-
-        if not force:
-            statuses = self.status(tree, base_dir=dst_dir)
-
-            # Only modified files produce conflicts in checkout
-            dst_rpath = osp.relpath(dst_dir, repo_dir)
-            conflicts = [osp.join(dst_rpath, p) for p, s in statuses.items() if s == "M"]
-            if conflicts:
-                raise UnsavedChangesError(conflicts)
-
-        self.repo.head.ref = ref_obj
-        self.repo.head.reset(working_tree=False)
-
-        if clean:
-            rmtree(dst_dir)
-
-        self.write_tree(tree, dst_dir)
-
-    def add(self, paths, base=None):
-        """
-        Adds paths to the index.
-        Paths can be truncated relative to base.
-        """
-
-        path_rewriter = None
-        if base:
-            base = osp.abspath(base)
-            repo_root = osp.abspath(self._project_dir)
-            assert is_subpath(base, base=repo_root), "Base path should be inside of the repo"
-            base = osp.relpath(base, repo_root)
-            path_rewriter = lambda entry: osp.relpath(entry.path, base).replace("\\", "/")
-
-        if isinstance(paths, str):
-            paths = [paths]
-
-        # A workaround for path_rewriter incompatibility
-        # with directory paths expansion
-        paths_to_add = []
-        for path in paths:
-            if not osp.isdir(path):
-                paths_to_add.append(path)
-                continue
-
-            for d, _, filenames in os.walk(path):
-                for fn in filenames:
-                    paths_to_add.append(osp.join(d, fn))
-
-        self.repo.index.add(paths_to_add, path_rewriter=path_rewriter)
-
-    def commit(self, message) -> str:
-        """
-        Creates a new revision from the index.
-        Returns: the new revision hash.
-        """
-        return self.repo.index.commit(message).hexsha
-
-    GitTree = NewType("GitTree", object)
-    GitStatus = NewType("GitStatus", str)
-
-    def status(
-        self, paths: Union[str, GitTree, Iterable[str]] = None, base_dir: str = None
-    ) -> Dict[str, GitStatus]:
-        """
-        Compares the working directory and the index.
-
-        Parameters:
-            paths: an iterable of paths to compare, a git.Tree, or None.
-                When None, uses all the paths from HEAD.
-            base_dir: a base path for paths. Paths will be prepended by this.
-                When None or '', uses the repo root. This can be useful if the
-                index contains displaced paths, which need to be mapped onto
-                real paths.
-
-        The statuses are:
-        - "A" for added paths
-        - "D" for deleted paths
-        - "R" for renamed paths
-        - "M" for paths with modified data
-        - "T" for paths with a changed type
-
-        Returns: { abspath(base_dir + path): status }
-        """
-
-        if paths is None or isinstance(paths, self.module().objects.tree.Tree):
-            if paths is None:
-                tree = self.repo.head.commit.tree
-            else:
-                tree = paths
-            paths = (obj.path for obj in tree.traverse() if obj.type == "blob")
-        elif isinstance(paths, str):
-            paths = [paths]
-
-        if not base_dir:
-            base_dir = self._project_dir
-
-        repo_dir = osp.abspath(self._project_dir)
-        base_dir = osp.abspath(base_dir)
-        assert is_subpath(base_dir, base=repo_dir)
-
-        statuses = {}
-        for obj_path in paths:
-            file_path = osp.join(base_dir, obj_path)
-
-            index_entry = self.repo.index.entries.get((obj_path, 0), None)
-            file_exists = osp.isfile(file_path)
-            if not file_exists and index_entry:
-                status = "D"
-            elif file_exists and not index_entry:
-                status = "A"
-            elif file_exists and index_entry:
-                # '--ignore-cr-at-eol' doesn't affect '--name-status',
-                # so we can't really obtain 'T'
-                status = self.repo.git.diff("--ignore-cr-at-eol", index_entry.hexsha, file_path)
-                if status:
-                    status = "M"
-                assert status in {"", "M", "T"}, status
-            else:
-                status = ""  # ignore missing paths
-
-            if status:
-                statuses[obj_path] = status
-
-        return statuses
-
-    def is_ref(self, rev):
-        try:
-            self.repo.commit(rev)
-            return True
-        except (ValueError, self.module().exc.BadName):
-            return False
-
-    def has_commits(self):
-        return self.is_ref("HEAD")
-
-    def get_tree(self, ref):
-        return self.repo.tree(ref)
-
-    def write_tree(self, tree, base_path: str, include_files: Optional[List[str]] = None):
-        os.makedirs(base_path, exist_ok=True)
-
-        for obj in tree.traverse(visit_once=True):
-            if include_files and obj.path not in include_files:
-                continue
-
-            path = osp.join(base_path, obj.path)
-            os.makedirs(osp.dirname(path), exist_ok=True)
-            if obj.type == "blob":
-                with open(path, "wb") as f:
-                    obj.stream_data(f)
-            elif obj.type == "tree":
-                pass
-            else:
-                raise ValueError(
-                    "Unexpected object type in a git tree: %s (%s)" % (obj.type, obj.hexsha)
-                )
-
-    @property
-    def head(self) -> str:
-        return self.repo.head.commit.hexsha
-
-    @property
-    def branch(self) -> str:
-        if self.repo.head.is_detached:
-            return None
-        return self.repo.active_branch
-
-    def rev_parse(self, ref: str) -> Tuple[str, str]:
-        """
-        Expands named refs and tags.
- - Returns: object type, object hash - """ - obj = self.repo.rev_parse(ref) - return obj.type, obj.hexsha - - def ignore( - self, - paths: Union[str, List[str]], - mode: Union[None, str, IgnoreMode] = None, - gitignore: Optional[str] = None, - ): - if not gitignore: - gitignore = ".gitignore" - repo_root = self._project_dir - gitignore = osp.abspath(osp.join(repo_root, gitignore)) - assert is_subpath(gitignore, base=repo_root), gitignore - - _update_ignore_file(paths, repo_root=repo_root, mode=mode, filepath=gitignore) - - HASH_LEN = 40 - - @classmethod - def is_hash(cls, s: str) -> bool: - return len(s) == cls.HASH_LEN - - def log(self, depth=10) -> List[Tuple[Any, int]]: - """ - Returns: a list of (commit, index) pairs - """ - - commits = [] - - if not self.has_commits(): - return commits - - for commit in zip(self.repo.iter_commits(rev="HEAD"), range(depth)): - commits.append(commit) - return commits - - -class DvcWrapper: - @staticmethod - def module(): - try: - import dvc - import dvc.cli - import dvc.env - import dvc.repo - - return dvc - except ModuleNotFoundError as e: - raise ModuleNotFoundError( - "Can't import the 'dvc' package. " - "Make sure DVC is installed, or install it with " - "'pip install datumaro[default]'." - ) from e - - def _dvc_dir(self): - return osp.join(self._project_dir, ".dvc") - - class DvcError(Exception): - pass - - def __init__(self, project_dir): - self._project_dir = project_dir - self.repo = None - - if osp.isdir(project_dir) and osp.isdir(self._dvc_dir()): - with logging_disabled(): - self.repo = self.module().repo.Repo(project_dir) - - @property - def initialized(self): - return self.repo is not None - - def init(self): - if self.initialized: - return - - with logging_disabled(): - self.repo = self.module().repo.Repo.init(self._project_dir) - - repo_dir = osp.join(self._project_dir, ".dvc") - _update_ignore_file( - [osp.join(repo_dir, "plots")], - filepath=osp.join(repo_dir, ".gitignore"), - repo_root=repo_dir, - ) - - def close(self): - if self.repo: - self.repo.close() - self.repo = None - - def __del__(self): - with suppress(Exception): - self.close() - - def checkout(self, targets=None): - args = ["checkout"] - if targets: - if isinstance(targets, str): - args.append(targets) - else: - args.extend(targets) - self._exec(args) - - def add(self, paths, no_commit=False): - args = ["add"] - if no_commit: - args.append("--no-commit") - if paths: - if isinstance(paths, str): - args.append(paths) - else: - args.extend(paths) - self._exec(args) - - def _exec(self, args, hide_output=True, answer_on_input="y"): - args = ["--cd", self._project_dir] + args - - # Avoid calling an extra process. Improves call performance and - # removes an extra console window on Windows. 
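For illustration, driving the DVC command-line interface in-process like this boils down to a call of the following shape. This is a minimal sketch, assuming the `dvc` package is importable; the prompt patching and log capturing that `_exec` performs are omitted, and the project path is hypothetical.

```python
import dvc.cli

# Equivalent of running "dvc add data" from the project directory,
# but without spawning a separate "dvc" process.
retcode = dvc.cli.main(["--cd", "/path/to/project", "add", "data"])
if retcode != 0:
    raise RuntimeError("DVC command failed")
```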
- os.environ[self.module().env.DVC_NO_ANALYTICS] = "1" - - with ExitStack() as es: - es.callback(os.chdir, os.getcwd()) # restore cd after DVC - - if answer_on_input is not None: - - def _input(*args): - return answer_on_input - - es.enter_context(unittest.mock.patch("dvc.prompt.input", new=_input)) - - log.debug("Calling DVC main with args: %s", args) - - logs = es.enter_context(catch_logs("dvc")) - retcode = self.module().cli.main(args) - - logs = logs.getvalue() - if retcode != 0: - raise self.DvcError(logs) - if not hide_output: - print(logs) - return logs - - def is_cached(self, obj_hash): - path = self.obj_path(obj_hash) - if not osp.isfile(path): - return False - - if obj_hash.endswith(self.DIR_HASH_SUFFIX): - objects = parse_json_file(path) - for entry in objects: - if not osp.isfile(self.obj_path(entry["md5"])): - return False - - return True - - def obj_path(self, obj_hash, root=None): - assert self.is_hash(obj_hash), obj_hash - if not root: - root = osp.join(self._project_dir, ".dvc", "cache", "files", "md5") - return osp.join(root, obj_hash[:2], obj_hash[2:]) - - def ignore( - self, - paths: Union[str, List[str]], - mode: Union[None, str, IgnoreMode] = None, - dvcignore: Optional[str] = None, - ): - if not dvcignore: - dvcignore = ".dvcignore" - repo_root = self._project_dir - dvcignore = osp.abspath(osp.join(repo_root, dvcignore)) - assert is_subpath(dvcignore, base=repo_root), dvcignore - - _update_ignore_file(paths, repo_root=repo_root, mode=mode, filepath=dvcignore) - - # This ruamel parser is needed to preserve comments, - # order and form (if multiple forms allowed by the standard) - # of the entries in the file. It can be reused. - import ruamel.yaml as yaml - - yaml_parser = yaml.YAML(typ="rt") - - @classmethod - def get_hash_from_dvcfile(cls, path) -> str: - with open(path) as f: - contents = cls.yaml_parser.load(f) - return contents["outs"][0]["md5"] - - FILE_HASH_LEN = 32 - DIR_HASH_SUFFIX = ".dir" - DIR_HASH_LEN = FILE_HASH_LEN + len(DIR_HASH_SUFFIX) - - @classmethod - def is_file_hash(cls, s: str) -> bool: - return len(s) == cls.FILE_HASH_LEN - - @classmethod - def is_dir_hash(cls, s: str) -> bool: - return len(s) == cls.DIR_HASH_LEN and s.endswith(cls.DIR_HASH_SUFFIX) - - @classmethod - def is_hash(cls, s: str) -> bool: - return cls.is_file_hash(s) or cls.is_dir_hash(s) - - def write_obj(self, obj_hash, dst_dir, allow_links=True): - def _copy_obj(src, dst, link=False): - os.makedirs(osp.dirname(dst), exist_ok=True) - if link: - os.link(src, dst) - else: - shutil.copy(src, dst, follow_symlinks=True) - - src = self.obj_path(obj_hash) - if osp.isfile(src): - _copy_obj(src, dst_dir, link=allow_links) - return - - src += self.DIR_HASH_SUFFIX - if not osp.isfile(src): - raise UnknownRefError(obj_hash) - - src_meta = parse_json_file(src) - for entry in src_meta: - _copy_obj( - self.obj_path(entry["md5"]), osp.join(dst_dir, entry["relpath"]), link=allow_links - ) - - def remove_cache_obj(self, obj_hash: str): - src = self.obj_path(obj_hash) - if osp.isfile(src): - rmfile(src) - return - - src += self.DIR_HASH_SUFFIX - if not osp.isfile(src): - raise UnknownRefError(obj_hash) - - src_meta = parse_json_file(src) - for entry in src_meta: - entry_path = self.obj_path(entry["md5"]) - if osp.isfile(entry_path): - rmfile(entry_path) - - rmfile(src) - - -class Tree: - # can be: - # - attached to the work dir - # - attached to a revision - - def __init__( - self, - project: Project, - config: Union[None, Dict, Config, TreeConfig] = None, - rev: Union[None, Revision] = None, - ): - 
assert isinstance(project, Project) - assert not rev or project.is_ref(rev), rev - - if not isinstance(config, TreeConfig): - config = TreeConfig(config) - if config.format_version != 2: - raise ValueError( - "Unexpected tree config version '%s', expected 2" % config.format_version - ) - self._config = config - - self._project = project - self._rev = rev - - self._sources = ProjectSources(self) - self._targets = ProjectBuildTargets(self) - - def save(self): - self.dump(self._config.config_path) - - def dump(self, path): - os.makedirs(osp.dirname(path), exist_ok=True) - self._config.dump(path) - - def clone(self) -> Tree: - return Tree(self._project, TreeConfig(self.config), self._rev) - - @property - def sources(self) -> ProjectSources: - return self._sources - - @property - def build_targets(self) -> ProjectBuildTargets: - return self._targets - - @property - def config(self) -> Config: - return self._config - - @property - def env(self) -> Environment: - return self._project.env - - @property - def rev(self) -> Union[None, Revision]: - return self._rev - - def make_pipeline(self, target: Optional[str] = None) -> Pipeline: - if not target: - target = "project" - - return self.build_targets.make_pipeline(target) - - def make_dataset(self, target: Union[None, str, Pipeline] = None) -> Dataset: - if not target or isinstance(target, str): - pipeline = self.make_pipeline(target) - elif isinstance(target, Pipeline): - pipeline = target - else: - raise TypeError(f"Unexpected target type {type(target)}") - - return ProjectBuilder(self._project, self).make_dataset(pipeline) - - @property - def is_working_tree(self) -> bool: - return not self._rev - - def source_data_dir(self, source) -> str: - if self.is_working_tree: - return self._project.source_data_dir(source) - - obj_hash = self.build_targets[source].head.hash - return self._project.cache_path(obj_hash) - - -class DiffStatus(Enum): - added = auto() - modified = auto() - removed = auto() - missing = auto() - foreign_modified = auto() - - -Revision = NewType("Revision", str) # a commit hash or a named reference -ObjectId = NewType("ObjectId", str) # a commit or an object hash - - -@deprecated(deprecated_version="1.11", removed_version="1.12") -class Project: - @staticmethod - def find_project_dir(path: str) -> Optional[str]: - path = osp.abspath(path) - - if osp.basename(path) != ProjectLayout.aux_dir: - path = osp.join(path, ProjectLayout.aux_dir) - - if osp.isdir(path): - return path - - return None - - @staticmethod - @scoped - def migrate_from_v1_to_v2(src_dir: str, dst_dir: str, skip_import_errors=False): - if not osp.isdir(src_dir): - raise FileNotFoundError("Source project is not found") - - if osp.exists(dst_dir): - raise FileExistsError("Output path already exists") - - src_dir = osp.abspath(src_dir) - dst_dir = osp.abspath(dst_dir) - if src_dir == dst_dir: - raise MigrationError( - "Source and destination paths are the same. " - "Project migration cannot be done inplace." 
- ) - - old_aux_dir = osp.join(src_dir, ".datumaro") - old_config = Config.parse(osp.join(old_aux_dir, "config.yaml")) - if old_config.format_version != 1: - raise MigrationError( - "Failed to migrate project: " - "unexpected old version '%s'" % old_config.format_version - ) - - on_error_do(rmtree, dst_dir, ignore_errors=True) - new_project = scope_add(Project.init(dst_dir)) - - new_wtree_dir = osp.join(new_project._aux_dir, ProjectLayout.working_tree_dir) - os.makedirs(new_wtree_dir, exist_ok=True) - - old_plugins_dir = osp.join(old_aux_dir, "plugins") - if osp.isdir(old_plugins_dir): - copytree(old_plugins_dir, osp.join(new_project._aux_dir, ProjectLayout.plugins_dir)) - - old_models_dir = osp.join(old_aux_dir, "models") - if osp.isdir(old_models_dir): - copytree(old_models_dir, osp.join(new_project._aux_dir, ProjectLayout.models_dir)) - - new_project.env.load_plugins(osp.join(new_project._aux_dir, ProjectLayout.plugins_dir)) - - new_tree_config = new_project.working_tree.config - new_local_config = new_project.config - - if "models" in old_config: - for name, old_model in old_config.models.items(): - new_local_config.models[name] = Model( - {"launcher": old_model["launcher"], "options": old_model["options"]} - ) - - if "sources" in old_config: - for name, old_source in old_config.sources.items(): - is_local = False - source_dir = osp.join(src_dir, "sources", name) - url = osp.abspath(osp.join(source_dir, old_source["url"])) - rpath = None - if osp.exists(url): - if is_subpath(url, source_dir): - if url != source_dir: - rpath = osp.relpath(url, source_dir) - url = source_dir - is_local = True - elif osp.isfile(url): - url, rpath = osp.split(url) - elif not old_source["url"]: - url = "" - - try: - source = new_project.import_source( - name, - url=url, - rpath=rpath, - format=old_source["format"], - options=old_source["options"], - ) - if is_local: - source.url = "" - - new_project.working_tree.make_dataset(name) - except Exception as e: - if not skip_import_errors: - raise MigrationError(f"Failed to migrate the source '{name}'") from e - else: - log.warning( - f"Failed to migrate the source '{name}'. " - "Try to add this source manually with " - "'datum project import', once migration is finished. The " - "reason is: %s", - e, - ) - new_project.remove_source(name, force=True, keep_data=False) - - old_dataset_dir = osp.join(src_dir, "dataset") - if osp.isdir(old_dataset_dir): - # Such source cannot be represented in v2 directly. - # However, it can be considered a generated source with - # working tree data. - name = generate_next_name( - list(new_tree_config.sources), "local_dataset", sep="-", default="1" - ) - source = new_project.import_source(name, url=old_dataset_dir, format=DEFAULT_FORMAT) - - # Make the source generated. It can only have local data. - source.url = "" - - new_project.save() - new_project.close() - - def __init__(self, path: Optional[str] = None, readonly=False): - if not path: - path = osp.curdir - found_path = self.find_project_dir(path) - if not found_path: - raise ProjectNotFoundError(path) - - old_config_path = osp.join(found_path, "config.yaml") - if osp.isfile(old_config_path): - if Config.parse(old_config_path).format_version != 2: - raise OldProjectError() - - self._aux_dir = found_path - self._root_dir = osp.dirname(found_path) - - self._readonly = readonly - - # Force import errors on missing dependencies. 
- # - # TODO: maybe allow class use in some cases, which not require - # Git or DVC - GitWrapper.module() - DvcWrapper.module() - - self._git = GitWrapper(self._root_dir) - self._dvc = DvcWrapper(self._root_dir) - - self._working_tree = None - self._head_tree = None - - local_config = osp.join(self._aux_dir, ProjectLayout.conf_file) - if osp.isfile(local_config): - self._config = ProjectConfig.parse(local_config) - else: - self._config = ProjectConfig() - - self._env = Environment() - - plugins_dir = osp.join(self._aux_dir, ProjectLayout.plugins_dir) - if osp.isdir(plugins_dir): - self._env.load_plugins(plugins_dir) - - def _init_vcs(self): - # DVC requires Git to be initialized - if not self._git.initialized: - self._git.init() - self._git.ignore( - [ - ProjectLayout.cache_dir, - ], - gitignore=osp.join(self._aux_dir, ".gitignore"), - ) - self._git.ignore([]) # create the file - if not self._dvc.initialized: - self._dvc.init() - self._dvc.ignore( - [ - osp.join(self._aux_dir, ProjectLayout.cache_dir), - osp.join(self._aux_dir, ProjectLayout.working_tree_dir), - ] - ) - self._git.repo.index.remove( - osp.join(self._root_dir, ".dvc", "plots"), r=True, ignore_unmatch=True - ) - self.commit("Initial commit", allow_empty=True) - - @classmethod - @scoped - def init(cls, path) -> Project: - existing_project = cls.find_project_dir(path) - if existing_project: - raise ProjectAlreadyExists(path) - - path = osp.abspath(path) - if osp.basename(path) != ProjectLayout.aux_dir: - path = osp.join(path, ProjectLayout.aux_dir) - - project_dir = osp.dirname(path) - if not osp.isdir(project_dir): - on_error_do(rmtree, project_dir, ignore_errors=True) - - os.makedirs(path, exist_ok=True) - - on_error_do(rmtree, osp.join(project_dir, ProjectLayout.cache_dir), ignore_errors=True) - on_error_do(rmtree, osp.join(project_dir, ProjectLayout.tmp_dir), ignore_errors=True) - os.makedirs(osp.join(path, ProjectLayout.cache_dir)) - os.makedirs(osp.join(path, ProjectLayout.tmp_dir)) - - git_dir, dvc_dir = osp.join(project_dir, ".git"), osp.join(project_dir, ".dvc") - - if osp.exists(git_dir): - raise VcsAlreadyExists(git_dir) - if osp.exists(dvc_dir): - raise VcsAlreadyExists(dvc_dir) - - on_error_do(rmtree, git_dir, ignore_errors=True) - on_error_do(rmtree, dvc_dir, ignore_errors=True) - - project = Project(path) - project._init_vcs() - - return project - - def close(self): - if self._dvc: - self._dvc.close() - self._dvc = None - - if self._git: - self._git.close() - self._git = None - - def __del__(self): - with suppress(Exception): - self.close() - - def __enter__(self): - return self - - def __exit__(self, *args, **kwargs): - self.close() - - def save(self): - self._config.dump(osp.join(self._aux_dir, ProjectLayout.conf_file)) - - if self._working_tree: - self._working_tree.save() - - @property - def readonly(self) -> bool: - return self._readonly - - @property - def working_tree(self) -> Tree: - if self._working_tree is None: - self._working_tree = self.get_rev(None) - return self._working_tree - - @property - def head(self) -> Tree: - if self._head_tree is None: - self._head_tree = self.get_rev("HEAD") - return self._head_tree - - @property - def head_rev(self) -> Revision: - return self._git.head - - @property - def branch(self) -> str: - return self._git.branch - - @property - def config(self) -> Config: - return self._config - - @property - def env(self) -> Environment: - return self._env - - @property - def models(self) -> Dict[str, Model]: - return dict(self._config.models) - - def get_rev(self, rev: 
Union[None, Revision]) -> Tree: - """ - Reference conventions: - - None or "" - working dir - - "<40 symbols>" - revision hash - """ - - obj_type, obj_hash = self._parse_ref(rev) - assert obj_type == self._ObjectIdKind.tree, obj_type - - if self._is_working_tree_ref(obj_hash): - config_path = osp.join( - self._aux_dir, ProjectLayout.working_tree_dir, TreeLayout.conf_file - ) - if osp.isfile(config_path): - tree_config = TreeConfig.parse(config_path) - else: - tree_config = TreeConfig() - os.makedirs(osp.dirname(config_path), exist_ok=True) - tree_config.dump(config_path) - tree_config.config_path = config_path - tree_config.base_dir = osp.dirname(config_path) - tree = Tree(config=tree_config, project=self, rev=obj_hash) - else: - if not self.is_rev_cached(obj_hash): - self._materialize_rev(obj_hash) - - rev_dir = self.cache_path(obj_hash) - tree_config = TreeConfig.parse(osp.join(rev_dir, TreeLayout.conf_file)) - tree_config.base_dir = rev_dir - tree = Tree(config=tree_config, project=self, rev=obj_hash) - return tree - - def is_rev_cached(self, rev: Revision) -> bool: - obj_type, obj_hash = self._parse_ref(rev) - assert obj_type == self._ObjectIdKind.tree, obj_type - return self._is_cached(obj_hash) - - def is_obj_cached(self, obj_hash: ObjectId) -> bool: - return self._is_cached(obj_hash) or self._can_retrieve_from_vcs_cache(obj_hash) - - @staticmethod - def _is_working_tree_ref(ref: Union[None, Revision, ObjectId]) -> bool: - return not ref - - class _ObjectIdKind(Enum): - # Project revision data. Currently, a Git commit hash. - tree = auto() - - # Source revision data. DVC directories and files. - blob = auto() - - def _parse_ref(self, ref: Union[None, Revision, ObjectId]) -> Tuple[_ObjectIdKind, ObjectId]: - """ - Resolves the reference to an object hash. - """ - - if self._is_working_tree_ref(ref): - return self._ObjectIdKind.tree, ref - - try: - obj_type, obj_hash = self._git.rev_parse(ref) - except Exception: # nosec try_except_pass - pass # Ignore git errors - else: - if obj_type != "commit": - raise UnknownRefError(obj_hash) - - return self._ObjectIdKind.tree, obj_hash - - try: - assert self._dvc.is_hash(ref), ref - return self._ObjectIdKind.blob, ref - except Exception as e: - raise UnknownRefError(ref) from e - - def _materialize_rev(self, rev: Revision) -> str: - """ - Restores the revision tree data in the project cache from Git. - - Returns: cache object path - """ - # TODO: maybe avoid this operation by providing a virtual filesystem - # object - - # Allowed to be run when readonly, because it doesn't modify project - # data and doesn't hurt disk space. 
- - obj_dir = self.cache_path(rev) - if osp.isdir(obj_dir): - return obj_dir - - tree = self._git.get_tree(rev) - self._git.write_tree(tree, obj_dir) - return obj_dir - - def _is_cached(self, obj_hash: ObjectId): - return osp.isdir(self.cache_path(obj_hash)) - - def cache_path(self, obj_hash: ObjectId) -> str: - assert self._git.is_hash(obj_hash) or self._dvc.is_hash(obj_hash), obj_hash - if self._dvc.is_dir_hash(obj_hash): - obj_hash = obj_hash[: self._dvc.FILE_HASH_LEN] - - return osp.join(self._aux_dir, ProjectLayout.cache_dir, obj_hash[:2], obj_hash[2:]) - - def _can_retrieve_from_vcs_cache(self, obj_hash: ObjectId): - if not self._dvc.is_dir_hash(obj_hash): - dir_check = self._dvc.is_cached(obj_hash + self._dvc.DIR_HASH_SUFFIX) - else: - dir_check = False - return dir_check or self._dvc.is_cached(obj_hash) - - def source_data_dir(self, name: str) -> str: - return osp.join(self._root_dir, name) - - def _source_dvcfile_path(self, name: str, root: Optional[str] = None) -> str: - """ - root - Path to the tree root directory. If not set, - the working tree is used. - """ - - if not root: - root = osp.join(self._aux_dir, ProjectLayout.working_tree_dir) - return osp.join(root, TreeLayout.sources_dir, name, "source.dvc") - - def _make_tmp_dir(self, suffix: Optional[str] = None): - project_tmp_dir = osp.join(self._aux_dir, ProjectLayout.tmp_dir) - os.makedirs(project_tmp_dir, exist_ok=True) - if suffix: - suffix = "_" + suffix - - return tempfile.TemporaryDirectory(suffix=suffix, dir=project_tmp_dir) - - def remove_cache_obj(self, ref: Union[Revision, ObjectId]): - if self.readonly: - raise ReadonlyProjectError() - - obj_type, obj_hash = self._parse_ref(ref) - - if self._is_cached(obj_hash): - rmtree(self.cache_path(obj_hash)) - - if obj_type == self._ObjectIdKind.tree: - # Revision metadata is cheap enough and needed to materialize - # the revision, so we keep it in the Git cache. 
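To make the object-cache layout used above concrete: `cache_path` fans a hash out into a two-character prefix directory followed by the remainder, under the project's aux directory. A worked example, assuming `ProjectLayout.cache_dir` resolves to a `cache` subdirectory and using a made-up hash and path:

```python
import os.path as osp

aux_dir = "/repo/.datumaro"  # hypothetical project aux directory
obj_hash = "0123456789abcdef0123456789abcdef"  # 32-char DVC file hash

# Mirrors Project.cache_path(): <aux>/cache/<first 2 chars>/<rest>
print(osp.join(aux_dir, "cache", obj_hash[:2], obj_hash[2:]))
# -> /repo/.datumaro/cache/01/23456789abcdef0123456789abcdef
```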
- pass - elif obj_type == self._ObjectIdKind.blob: - self._dvc.remove_cache_obj(obj_hash) - else: - raise ValueError("Unexpected object type '%s'" % obj_type) - - def validate_source_name(self, name: str): - if not name: - raise ValueError("Source name cannot be empty") - - disallowed_symbols = r"[^\\ \.\~\-\w]" - found_wrong_symbols = re.findall(disallowed_symbols, name) - if found_wrong_symbols: - raise ValueError("Source name contains invalid symbols: %s" % found_wrong_symbols) - - valid_filename = make_file_name(name) - if valid_filename != name: - raise ValueError( - "Source name contains " "invalid symbols: %s" % (set(name) - set(valid_filename)) - ) - - if name.startswith("."): - raise ValueError("Source name can't start with '.'") - - reserved_names = {"dataset", "build", "project"} - if name.lower() in reserved_names: - raise ValueError("Source name is reserved for internal use") - - @scoped - def _download_source( - self, url: str, dst_dir: str, *, no_cache: bool = False, no_hash: bool = False - ) -> Tuple[str, str, str]: - assert url - assert dst_dir - - dvcfile = osp.join(dst_dir, "source.dvc") - data_dir = osp.join(dst_dir, "data") - - log.debug(f"Copying from '{url}' to '{data_dir}'") - - if osp.isdir(url): - copytree(url, data_dir) - elif osp.isfile(url): - os.makedirs(data_dir, exist_ok=True) - shutil.copy(url, data_dir) - else: - raise UnexpectedUrlError(url) - on_error_do(rmtree, data_dir, ignore_errors=True) - - log.debug("Done") - - if not no_hash: - obj_hash = self.compute_source_hash(data_dir, dvcfile=dvcfile, no_cache=no_cache) - if not no_cache: - log.debug("Data is added to DVC cache") - log.debug("Data hash: '%s'", obj_hash) - else: - obj_hash = "" - - return obj_hash, dvcfile, data_dir - - @staticmethod - def _get_source_hash(dvcfile): - obj_hash = DvcWrapper.get_hash_from_dvcfile(dvcfile) - if obj_hash.endswith(DvcWrapper.DIR_HASH_SUFFIX): - obj_hash = obj_hash[: -len(DvcWrapper.DIR_HASH_SUFFIX)] - return obj_hash - - @scoped - def compute_source_hash( - self, - data_dir: str, - dvcfile: Optional[str] = None, - no_cache: bool = True, - ) -> ObjectId: - if not dvcfile: - tmp_dir = scope_add(self._make_tmp_dir()) - dvcfile = osp.join(tmp_dir, "source.dvc") - - self._dvc.add(data_dir, no_commit=no_cache) - - gen_dvcfile = osp.join(self._root_dir, data_dir + ".dvc") - if os.path.isfile(gen_dvcfile): - shutil.move(gen_dvcfile, dvcfile) - - obj_hash = self._get_source_hash(dvcfile) - return obj_hash - - def refresh_source_hash(self, source: str, no_cache: bool = True) -> ObjectId: - """ - Computes and updates the source hash in the working directory. - - Returns: hash - """ - - if self.readonly: - raise ReadonlyProjectError() - - build_target = self.working_tree.build_targets[source] - source_dir = self.source_data_dir(source) - - if not osp.isdir(source_dir): - return None - - dvcfile = self._source_dvcfile_path(source) - os.makedirs(osp.dirname(dvcfile), exist_ok=True) - obj_hash = self.compute_source_hash(source_dir, dvcfile=dvcfile, no_cache=no_cache) - - build_target.head.hash = obj_hash - if not build_target.has_stages: - self.working_tree.sources[source].hash = obj_hash - - return obj_hash - - def _materialize_obj(self, obj_hash: ObjectId) -> str: - """ - Restores the object data in the project cache from DVC. - - Returns: cache object path - """ - # TODO: maybe avoid this operation by providing a virtual filesystem - # object - - # Allowed to be run when readonly, because it shouldn't hurt disk - # space, if object is materialized with symlinks. 
-
-        if not self._can_retrieve_from_vcs_cache(obj_hash):
-            raise MissingObjectError(obj_hash)
-
-        dst_dir = self.cache_path(obj_hash)
-        if osp.isdir(dst_dir):
-            return dst_dir
-
-        self._dvc.write_obj(obj_hash, dst_dir, allow_links=True)
-        return dst_dir
-
-    @scoped
-    def import_source(
-        self,
-        name: str,
-        url: Optional[str],
-        format: str,
-        options: Optional[Dict] = None,
-        *,
-        no_cache: bool = True,
-        no_hash: bool = True,
-        rpath: Optional[str] = None,
-    ) -> Source:
-        """
-        Adds a new source (dataset) to the working directory of the project.
-
-        When 'rpath' is specified, all the data from the URL is copied, but
-        only the specified file is read. This is required to support subtasks
-        and subsets in datasets.
-
-        Parameters:
-            name (str): Name of the new source
-            url (str): URL of the new source. A path to a file or directory
-            format (str): Dataset format
-            options (dict): Options for the format Extractor
-            no_cache (bool): Don't put a copy of files into the project cache.
-                Can be used to reduce project cache size.
-            no_hash (bool): Don't compute source data hash. Implies "no_cache".
-                Useful to reduce import time at the cost of disabled data
-                integrity checks.
-            rpath (str): Used to specify a relative path to the dataset
-                inside the directory pointed to by URL.
-
-        Returns: the new source config
-        """
-
-        if self.readonly:
-            raise ReadonlyProjectError()
-
-        self.validate_source_name(name)
-
-        if name in self.working_tree.sources:
-            raise SourceExistsError(name)
-
-        data_dir = self.source_data_dir(name)
-        if osp.exists(data_dir):
-            if os.listdir(data_dir):
-                raise FileExistsError("Source directory '%s' already exists" % data_dir)
-            os.rmdir(data_dir)
-
-        if url:
-            url = osp.abspath(url)
-            if not osp.exists(url):
-                raise FileNotFoundError(url)
-
-            if is_subpath(url, base=self._root_dir):
-                raise SourceUrlInsideProjectError()
-
-            if rpath:
-                rpath = osp.normpath(osp.join(url, rpath))
-
-                if not osp.exists(rpath):
-                    raise FileNotFoundError(rpath)
-
-                if not is_subpath(rpath, base=url):
-                    raise PathOutsideSourceError(
-                        "Source data path is outside of the directory "
-                        "specified by the source URL: '%s', '%s'" % (rpath, url)
-                    )
-
-                rpath = osp.relpath(rpath, url)
-            elif osp.isfile(url):
-                rpath = osp.basename(url)
-            else:
-                rpath = None
-
-        if no_hash:
-            no_cache = True
-
-        config = Source(
-            {
-                "url": (url or "").replace("\\", "/"),
-                "path": (rpath or "").replace("\\", "/"),
-                "format": format,
-                "options": options or {},
-            }
-        )
-
-        if not config.is_generated:
-            dvcfile = self._source_dvcfile_path(name)
-            os.makedirs(osp.dirname(dvcfile), exist_ok=True)
-
-            with self._make_tmp_dir() as tmp_dir:
-                obj_hash, tmp_dvcfile, tmp_data_dir = self._download_source(
-                    url, tmp_dir, no_cache=no_cache, no_hash=no_hash
-                )
-
-                shutil.move(tmp_data_dir, data_dir)
-                on_error_do(rmtree, data_dir)
-
-                if not no_hash:
-                    os.replace(tmp_dvcfile, dvcfile)
-                    config["hash"] = obj_hash
-
-        self._git.ignore([data_dir])
-
-        config = self.working_tree.sources.add(name, config)
-        target = self.working_tree.build_targets.add_target(name)
-        target.root.hash = config.hash
-
-        self.working_tree.save()
-
-        return config
-
-    @scoped
-    def add_source(
-        self, path: str, format: str, options: Optional[Dict] = None, *, rpath: Optional[str] = None
-    ) -> Tuple[str, Source]:
-        """
-        Adds a new source (dataset) from the working directory of the project.
-
-        Only directories from the project root can be added. This command is
-        useful after a source was removed and you need to re-add it, or when
-        the dataset was copied or downloaded manually.
-
-        When 'rpath' is specified, all the data from the path is copied, but
-        only the specified file is read. This is required to support subtasks
-        and subsets in datasets.
-
-        Parameters:
-            path (str): Path to the new source. Must be a directory directly
-                inside the project root
-            format (str): Dataset format
-            options (dict): Options for the format Extractor
-            rpath (str): Used to specify a relative path to the dataset
-                inside the directory pointed to by 'path'.
-
-        Returns: the name and the config of the new source
-        """
-
-        if self.readonly:
-            raise ReadonlyProjectError()
-
-        if not path:
-            raise ValueError("Source path cannot be empty")
-
-        path = osp.abspath(path)
-
-        name = osp.basename(path)
-        self.validate_source_name(name)
-
-        if name in self.working_tree.sources:
-            raise SourceExistsError(name)
-
-        if not osp.isdir(path):
-            raise FileNotFoundError("Source directory '%s' is not found" % path)
-
-        if not (is_subpath(path, base=self._root_dir) and osp.dirname(path) == self._root_dir):
-            raise UnexpectedUrlError(
-                "The source path is expected to be a directory in the project root"
-            )
-
-        if rpath:
-            rpath = osp.normpath(osp.join(path, rpath))
-
-            if not osp.exists(rpath):
-                raise FileNotFoundError(rpath)
-
-            if not is_subpath(rpath, base=path):
-                raise PathOutsideSourceError(
-                    "Source data path is outside of the directory "
-                    "specified by the source path: '%s', '%s'" % (rpath, path)
-                )
-
-            rpath = osp.relpath(rpath, path)
-        else:
-            rpath = None
-
-        self._git.ignore([path])
-
-        config = self.working_tree.sources.add(
-            name,
-            {
-                "url": (path or "").replace("\\", "/"),
-                "path": (rpath or "").replace("\\", "/"),
-                "format": format,
-                "options": options or {},
-            },
-        )
-        self.working_tree.build_targets.add_target(name)
-
-        self.working_tree.save()
-
-        return name, config
-
-    def remove_source(self, name: str, *, force: bool = False, keep_data: bool = True):
-        """
-        Options:
-          - force (bool) - ignores errors and tries to wipe remaining data
-          - keep_data (bool) - leaves source data untouched
-        """
-
-        if self.readonly:
-            raise ReadonlyProjectError()
-
-        if name not in self.working_tree.sources and not force:
-            raise UnknownSourceError(name)
-
-        self.working_tree.sources.remove(name)
-
-        data_dir = self.source_data_dir(name)
-        if not keep_data:
-            if osp.isdir(data_dir):
-                rmtree(data_dir)
-
-        dvcfile = self._source_dvcfile_path(name)
-        if osp.isfile(dvcfile):
-            try:
-                rmfile(dvcfile)
-            except Exception:
-                if not force:
-                    raise
-
-        self.working_tree.build_targets.remove_target(name)
-
-        self.working_tree.save()
-
-        self._git.ignore([data_dir], mode="remove")
-
-    def commit(
-        self,
-        message: str,
-        *,
-        no_cache: bool = False,
-        allow_empty: bool = False,
-        allow_foreign: bool = False,
-    ) -> Revision:
-        """
-        Copies the tree and objects from the working directory to the cache.
-        Creates a new commit. Moves the HEAD pointer to the new commit.
-
-        Options:
-          - no_cache (bool) - don't put added dataset data into the cache,
-            store only metainfo. Can be used to reduce storage size.
-          - allow_empty (bool) - allow commits with no changes.
-          - allow_foreign (bool) - allow commits with changes made outside Datumaro.
-
-        Returns: the new commit hash
-        """
-
-        if self.readonly:
-            raise ReadonlyProjectError()
-
-        statuses = self.status()
-
-        if not allow_empty and not statuses:
-            raise EmptyCommitError()
-
-        for t, s in statuses.items():
-            if s == DiffStatus.foreign_modified:
-                # TODO: compute a patch and a new stage, remove allow_foreign
-                if allow_foreign:
-                    log.warning(
-                        "The source '%s' has been changed "
-                        "without Datumaro. It will be saved, but it will "
-                        "only be available for reproduction from the cache.",
-                        t,
-                    )
-                else:
-                    raise ForeignChangesError(
-                        "The source '%s' is changed outside Datumaro. You can "
-                        "restore the latest source revision with the "
-                        "'checkout' command." % t
-                    )
-
-        for s in self.working_tree.sources:
-            self.refresh_source_hash(s, no_cache=no_cache)
-
-        wtree_dir = osp.join(self._aux_dir, ProjectLayout.working_tree_dir)
-        self.working_tree.save()
-        self._git.add(wtree_dir, base=wtree_dir)
-
-        extra_files = [
-            osp.join(self._root_dir, ".dvc", ".gitignore"),
-            osp.join(self._root_dir, ".dvc", "config"),
-            osp.join(self._root_dir, ".dvcignore"),
-            osp.join(self._root_dir, ".gitignore"),
-            osp.join(self._aux_dir, ".gitignore"),
-        ]
-        self._git.add(extra_files, base=self._root_dir)
-
-        head = self._git.commit(message)
-
-        rev_dir = self.cache_path(head)
-        copytree(wtree_dir, rev_dir)
-        for p in extra_files:
-            if osp.isfile(p):
-                dst_path = osp.join(rev_dir, osp.relpath(p, self._root_dir))
-                os.makedirs(osp.dirname(dst_path), exist_ok=True)
-                shutil.copyfile(p, dst_path)
-
-        self._head_tree = None
-
-        return head
-
-    @staticmethod
-    def _move_dvc_dir(src_dir, dst_dir):
-        for name in {"config", ".gitignore"}:
-            os.replace(osp.join(src_dir, name), osp.join(dst_dir, name))
-
-    def checkout(
-        self,
-        rev: Union[None, Revision] = None,
-        sources: Union[None, str, Iterable[str]] = None,
-        *,
-        force: bool = False,
-    ):
-        """
-        Copies the tree and objects from the cache to the working tree.
-
-        Sets HEAD to the specified revision, unless sources are specified.
-        When sources are specified, only copies objects from the cache to
-        the working tree. When neither a revision nor sources are specified,
-        restores the sources from the current revision.
-
-        By default, uses the current (HEAD) revision.
-
-        Options:
-          - force (bool) - ignore unsaved changes.
By default, an error is raised - """ - - if self.readonly: - raise ReadonlyProjectError() - - if isinstance(sources, str): - sources = {sources} - elif sources is None: - sources = {} - else: - sources = set(sources) - - rev = rev or "HEAD" - - if sources: - rev_tree = self.get_rev(rev) - - # Check targets - for s in sources: - if s not in rev_tree.sources: - raise UnknownSourceError(s) - - rev_dir = rev_tree.config.base_dir - with self._make_tmp_dir() as tmp_dir: - dvcfiles = [] - - for s in sources: - dvcfile = self._source_dvcfile_path(s, root=rev_dir) - - tmp_dvcfile = osp.join(tmp_dir, s + ".dvc") - with open(dvcfile) as f: - conf = self._dvc.yaml_parser.load(f) - - conf["wdir"] = self._root_dir - - with open(tmp_dvcfile, "w") as f: - self._dvc.yaml_parser.dump(conf, f) - - dvcfiles.append(tmp_dvcfile) - - self._dvc.checkout(dvcfiles) - - self._git.ignore(sources) - - for s in sources: - self.working_tree.config.sources[s] = rev_tree.config.sources[s] - self.working_tree.config.build_targets[s] = rev_tree.config.build_targets[s] - - self.working_tree.save() - else: - # Check working tree for unsaved changes, - # set HEAD to the revision - # write revision tree to working tree - wtree_dir = osp.join(self._aux_dir, ProjectLayout.working_tree_dir) - self._git.checkout(rev, dst_dir=wtree_dir, clean=True, force=force) - self._move_dvc_dir(osp.join(wtree_dir, ".dvc"), osp.join(self._root_dir, ".dvc")) - - self._working_tree = None - - # Restore sources from the commit. - # Work with the working tree instead of cache, to - # avoid extra memory use from materializing - # the head commit sources in the cache - rev_tree = self.working_tree - with self._make_tmp_dir() as tmp_dir: - dvcfiles = [] - - for s in rev_tree.sources: - dvcfile = self._source_dvcfile_path(s) - - tmp_dvcfile = osp.join(tmp_dir, s + ".dvc") - with open(dvcfile) as f: - conf = self._dvc.yaml_parser.load(f) - - conf["wdir"] = self._root_dir - - with open(tmp_dvcfile, "w") as f: - self._dvc.yaml_parser.dump(conf, f) - - dvcfiles.append(tmp_dvcfile) - - self._dvc.checkout(dvcfiles) - - os.replace(osp.join(wtree_dir, ".gitignore"), osp.join(self._root_dir, ".gitignore")) - os.replace(osp.join(wtree_dir, ".dvcignore"), osp.join(self._root_dir, ".dvcignore")) - - self._working_tree = None - - def is_ref(self, ref: Union[None, str]) -> bool: - if self._is_working_tree_ref(ref): - return True - return self._git.is_ref(ref) - - def has_commits(self) -> bool: - return self._git.has_commits() - - def status(self) -> Dict[str, DiffStatus]: - wd = self.working_tree - - if not self.has_commits(): - return {s: DiffStatus.added for s in wd.sources} - - head = self.head - - changed_targets = {} - - for t_name, wd_target in wd.build_targets.items(): - if t_name == ProjectBuildTargets.MAIN_TARGET: - continue - - if osp.isdir(self.source_data_dir(t_name)): - old_hash = wd_target.head.hash - new_hash = self.compute_source_hash(t_name, no_cache=True) - - if old_hash and old_hash != new_hash: - changed_targets[t_name] = DiffStatus.foreign_modified - - for t_name in set(head.build_targets) | set(wd.build_targets): - if t_name == ProjectBuildTargets.MAIN_TARGET: - continue - if t_name in changed_targets: - continue - - head_target = head.build_targets.get(t_name) - wd_target = wd.build_targets.get(t_name) - - status = None - - if head_target is None: - status = DiffStatus.added - elif wd_target is None: - status = DiffStatus.removed - else: - if head_target != wd_target: - status = DiffStatus.modified - elif not 
osp.isdir(self.source_data_dir(t_name)): - status = DiffStatus.missing - - if status: - changed_targets[t_name] = status - - return changed_targets - - def history(self, max_count=10) -> List[Tuple[Revision, str]]: - return [(c.hexsha, c.message) for c, _ in self._git.log(max_count)] - - def diff( - self, rev_a: Union[Tree, Revision], rev_b: Union[Tree, Revision] - ) -> Dict[str, DiffStatus]: - """ - Compares 2 revision trees. - - Returns: { target_name: status } for changed targets - """ - - if rev_a == rev_b: - return {} - - if isinstance(rev_a, str): - tree_a = self.get_rev(rev_a) - else: - tree_a = rev_a - - if isinstance(rev_b, str): - tree_b = self.get_rev(rev_b) - else: - tree_b = rev_b - - changed_targets = {} - - for t_name in set(tree_a.build_targets) | set(tree_b.build_targets): - if t_name == ProjectBuildTargets.MAIN_TARGET: - continue - - head_target = tree_a.build_targets.get(t_name) - wd_target = tree_b.build_targets.get(t_name) - - status = None - - if head_target is None: - status = DiffStatus.added - elif wd_target is None: - status = DiffStatus.removed - else: - if head_target != wd_target: - status = DiffStatus.modified - - if status: - changed_targets[t_name] = status - - return changed_targets - - def model_data_dir(self, name: str) -> str: - return osp.join(self._aux_dir, ProjectLayout.models_dir, name) - - def make_model(self, name: str) -> Launcher: - model = self._config.models[name] - model_dir = self.model_data_dir(name) - if not osp.isdir(model_dir): - model_dir = None - return self._env.make_launcher(model.launcher, **model.options, model_dir=model_dir) - - def add_model(self, name: str, launcher: str, options: Dict[str, Any] = None) -> Model: - if self.readonly: - raise ReadonlyProjectError() - - if launcher not in self.env.launchers: - raise KeyError("Unknown launcher '%s'" % launcher) - - if not name: - raise ValueError("Model name can't be empty") - - if name in self.models: - raise KeyError("Model '%s' already exists" % name) - - return self._config.models.set(name, {"launcher": launcher, "options": options or {}}) - - def remove_model(self, name: str): - if self.readonly: - raise ReadonlyProjectError() - - if name not in self.models: - raise KeyError("Unknown model '%s'" % name) - - self._config.models.remove(name) - - data_dir = self.model_data_dir(name) - if osp.isdir(data_dir): - rmtree(data_dir) diff --git a/src/datumaro/project.py b/src/datumaro/project.py index 07de8fc40d..06f83171c7 100644 --- a/src/datumaro/project.py +++ b/src/datumaro/project.py @@ -2,20 +2,8 @@ # # SPDX-License-Identifier: MIT -# ruff: noqa: F401 +# Project functionality has been removed as of version 1.12 +# This module is kept for backward compatibility but all exports have been removed -# This module is a usability proxy for components.project - -from .components.project import ( - BuildStageType, - DiffStatus, - IgnoreMode, - ObjectId, - Pipeline, - Project, - ProjectBuilder, - ProjectBuildTargets, - ProjectSourceDataset, - Revision, - Tree, -) +# All project-related functionality has been deprecated and removed. +# For dataset operations, use datumaro.components.dataset.Dataset directly. 
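For downstream code, the module replacement above amounts to a one-line import change. A minimal migration sketch, assuming a local dataset in the default Datumaro format (the paths below are placeholders):

    # Before (removed in 1.12):
    #   from datumaro.components.project import Dataset
    # After: import Dataset from its canonical module
    from datumaro.components.dataset import Dataset

    dataset = Dataset.import_from("path/to/dataset", format="datumaro")
    dataset.export("path/to/output", format="coco")  # re-exporting works as before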
diff --git a/tests/integration/cli/test_compare.py b/tests/integration/cli/test_compare.py index dfb8cf3c99..d3417c29a5 100644 --- a/tests/integration/cli/test_compare.py +++ b/tests/integration/cli/test_compare.py @@ -21,9 +21,9 @@ PolyLine, ) from datumaro.components.comparator import DistanceComparator +from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DatasetItem from datumaro.components.media import Image -from datumaro.components.project import Dataset from ...requirements import Requirements, mark_requirement diff --git a/tests/integration/cli/test_merge.py b/tests/integration/cli/test_merge.py index 1ac28cc300..302d035d0b 100644 --- a/tests/integration/cli/test_merge.py +++ b/tests/integration/cli/test_merge.py @@ -10,9 +10,9 @@ import datumaro.plugins.data_formats.voc.format as VOC from datumaro.components.annotation import AnnotationType, Bbox, LabelCategories, MaskCategories +from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DatasetItem from datumaro.components.media import Image -from datumaro.components.project import Dataset from ...requirements import Requirements, mark_requirement diff --git a/tests/integration/cli/test_patch.py b/tests/integration/cli/test_patch.py index a55219c9e0..fed480ce1e 100644 --- a/tests/integration/cli/test_patch.py +++ b/tests/integration/cli/test_patch.py @@ -4,9 +4,9 @@ import numpy as np from datumaro.components.annotation import Bbox +from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DatasetItem from datumaro.components.media import Image -from datumaro.components.project import Dataset from ...requirements import Requirements, mark_requirement diff --git a/tests/unit/data_formats/arrow/conftest.py b/tests/unit/data_formats/arrow/conftest.py index 22f99ae149..e2cdec418c 100644 --- a/tests/unit/data_formats/arrow/conftest.py +++ b/tests/unit/data_formats/arrow/conftest.py @@ -9,9 +9,9 @@ import pytest from datumaro.components.annotation import Cuboid3d, Label +from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DatasetItem from datumaro.components.media import Image, PointCloud -from datumaro.components.project import Dataset from datumaro.util.image import encode_image from ..datumaro.conftest import ( diff --git a/tests/unit/data_formats/arrow/test_arrow_format.py b/tests/unit/data_formats/arrow/test_arrow_format.py index e5a0744080..42f7a9da3e 100644 --- a/tests/unit/data_formats/arrow/test_arrow_format.py +++ b/tests/unit/data_formats/arrow/test_arrow_format.py @@ -8,10 +8,10 @@ import numpy as np import pytest +from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DatasetItem from datumaro.components.environment import Environment from datumaro.components.media import FromFileMixin, Image -from datumaro.components.project import Dataset from datumaro.plugins.data_formats.arrow import ArrowExporter, ArrowImporter from datumaro.plugins.transforms import Sort diff --git a/tests/unit/data_formats/datumaro/conftest.py b/tests/unit/data_formats/datumaro/conftest.py index 5af727cc56..2336340ab8 100644 --- a/tests/unit/data_formats/datumaro/conftest.py +++ b/tests/unit/data_formats/datumaro/conftest.py @@ -28,9 +28,9 @@ PolyLine, RleMask, ) +from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DatasetItem from datumaro.components.media import Image, MediaElement, PointCloud, Video, VideoFrame -from 
datumaro.components.project import Dataset from datumaro.plugins.data_formats.datumaro.format import DatumaroPath from datumaro.util.mask_tools import generate_colormap diff --git a/tests/unit/data_formats/datumaro/test_datumaro_format.py b/tests/unit/data_formats/datumaro/test_datumaro_format.py index 0575e20f72..523a21de2f 100644 --- a/tests/unit/data_formats/datumaro/test_datumaro_format.py +++ b/tests/unit/data_formats/datumaro/test_datumaro_format.py @@ -12,12 +12,12 @@ import numpy as np import pytest +from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DatasetItem from datumaro.components.environment import Environment from datumaro.components.errors import PathSeparatorInSubsetNameError from datumaro.components.importer import DatasetImportError from datumaro.components.media import Image -from datumaro.components.project import Dataset from datumaro.plugins.data_formats.datumaro.exporter import DatumaroExporter from datumaro.plugins.data_formats.datumaro.format import DatumaroPath from datumaro.plugins.data_formats.datumaro.importer import DatumaroImporter diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py index b4b262238f..6200c9e30e 100644 --- a/tests/unit/test_api.py +++ b/tests/unit/test_api.py @@ -18,10 +18,9 @@ def test_can_import_core(self): def test_can_reach_module_alias_symbols_from_base(self): import datumaro as dm - assert hasattr(dm.project, "Project") assert hasattr(dm.errors, "DatumaroError") @mark_requirement(Requirements.DATUM_API) def test_can_import_from_module_aliases(self): + from datumaro.components.dataset import Dataset from datumaro.errors import DatumaroError - from datumaro.project import Project diff --git a/tests/unit/test_compare.py b/tests/unit/test_compare.py index 112d986cee..b78ecc4d7e 100644 --- a/tests/unit/test_compare.py +++ b/tests/unit/test_compare.py @@ -12,9 +12,9 @@ from datumaro.cli.util.compare import DistanceCompareVisualizer from datumaro.components.annotation import Bbox, Caption, Label, Mask, Points from datumaro.components.comparator import DistanceComparator, EqualityComparator, TableComparator +from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DEFAULT_SUBSET_NAME, DatasetItem from datumaro.components.media import Image -from datumaro.components.project import Dataset from ..requirements import Requirements, mark_requirement diff --git a/tests/unit/test_icdar_format.py b/tests/unit/test_icdar_format.py index da6a54bbcd..22059ca0f7 100644 --- a/tests/unit/test_icdar_format.py +++ b/tests/unit/test_icdar_format.py @@ -5,10 +5,10 @@ import numpy as np from datumaro.components.annotation import Bbox, Caption, Mask, Polygon +from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DatasetItem from datumaro.components.environment import Environment from datumaro.components.media import Image -from datumaro.components.project import Dataset from datumaro.plugins.data_formats.icdar.base import ( IcdarTextLocalizationImporter, IcdarTextSegmentationImporter, diff --git a/tests/unit/test_image_dir_format.py b/tests/unit/test_image_dir_format.py index b3d72766c8..864fe1078c 100644 --- a/tests/unit/test_image_dir_format.py +++ b/tests/unit/test_image_dir_format.py @@ -4,7 +4,6 @@ from datumaro.components.dataset import Dataset, StreamDataset from datumaro.components.dataset_base import DatasetItem from datumaro.components.media import Image -from datumaro.components.project import Dataset from 
datumaro.plugins.data_formats.image_dir import ImageDirExporter from ..requirements import Requirements, mark_requirement diff --git a/tests/unit/test_image_zip_format.py b/tests/unit/test_image_zip_format.py index 09cd7d8b49..aa7018e7ed 100644 --- a/tests/unit/test_image_zip_format.py +++ b/tests/unit/test_image_zip_format.py @@ -3,9 +3,9 @@ import numpy as np +from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DatasetItem from datumaro.components.media import Image, save_image -from datumaro.components.project import Dataset from datumaro.plugins.data_formats.image_zip import ImageZipExporter, ImageZipPath from ..requirements import Requirements, mark_requirement diff --git a/tests/unit/test_kitti_3d_format.py b/tests/unit/test_kitti_3d_format.py index 3dbadb2507..5148e60d64 100644 --- a/tests/unit/test_kitti_3d_format.py +++ b/tests/unit/test_kitti_3d_format.py @@ -4,10 +4,10 @@ import numpy as np from datumaro.components.annotation import AnnotationType, Bbox, LabelCategories +from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DatasetItem from datumaro.components.environment import Environment from datumaro.components.media import Image, PointCloud -from datumaro.components.project import Dataset from datumaro.plugins.data_formats.kitti_3d.importer import Kitti3dImporter from tests.requirements import Requirements, mark_requirement diff --git a/tests/unit/test_kitti_raw_format.py b/tests/unit/test_kitti_raw_format.py index 498e99b20f..7510332e9e 100644 --- a/tests/unit/test_kitti_raw_format.py +++ b/tests/unit/test_kitti_raw_format.py @@ -6,10 +6,10 @@ import numpy as np from datumaro.components.annotation import AnnotationType, Cuboid3d, LabelCategories +from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DatasetItem from datumaro.components.environment import Environment from datumaro.components.media import Image, PointCloud -from datumaro.components.project import Dataset from datumaro.plugins.data_formats.kitti_raw.base import KittiRawImporter from datumaro.plugins.data_formats.kitti_raw.exporter import KittiRawExporter diff --git a/tests/unit/test_labeling.py b/tests/unit/test_labeling.py index fab87e70a9..a7cf3e643f 100644 --- a/tests/unit/test_labeling.py +++ b/tests/unit/test_labeling.py @@ -7,9 +7,9 @@ import numpy as np from datumaro.components.annotation import AnnotationType, Label, LabelCategories +from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DatasetItem from datumaro.components.media import Image -from datumaro.components.project import Dataset from ..requirements import Requirements, mark_requirement diff --git a/tests/unit/test_ndr.py b/tests/unit/test_ndr.py index 2d13fe3231..4b8a64c420 100644 --- a/tests/unit/test_ndr.py +++ b/tests/unit/test_ndr.py @@ -4,10 +4,10 @@ import datumaro.plugins.ndr as ndr from datumaro.components.annotation import AnnotationType, Label, LabelCategories +from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DatasetItem from datumaro.components.errors import MediaShapeError from datumaro.components.media import Image -from datumaro.components.project import Dataset from ..requirements import Requirements, mark_requirement diff --git a/tests/unit/test_project.py b/tests/unit/test_project.py deleted file mode 100644 index 3074faa0e5..0000000000 --- a/tests/unit/test_project.py +++ /dev/null @@ -1,1404 +0,0 @@ -# Copyright (C) 
2019-2023 Intel Corporation
-#
-# SPDX-License-Identifier: MIT
-
-import os
-import os.path as osp
-import shutil
-import textwrap
-from typing import List, Sequence
-from unittest import TestCase
-
-import numpy as np
-import pytest
-
-from datumaro.components.annotation import Annotation, Bbox, Label
-from datumaro.components.config_model import Model, Source
-from datumaro.components.dataset import DEFAULT_FORMAT, Dataset
-from datumaro.components.dataset_base import DatasetBase, DatasetItem
-from datumaro.components.errors import (
-    DatasetMergeError,
-    EmptyCommitError,
-    EmptyPipelineError,
-    ForeignChangesError,
-    MismatchingObjectError,
-    MissingObjectError,
-    MissingSourceHashError,
-    OldProjectError,
-    PathOutsideSourceError,
-    ReadonlyProjectError,
-    SourceExistsError,
-    SourceUrlInsideProjectError,
-    UnexpectedUrlError,
-    UnknownTargetError,
-    VcsAlreadyExists,
-)
-from datumaro.components.launcher import Launcher
-from datumaro.components.media import Image
-from datumaro.components.project import DiffStatus, Project
-from datumaro.components.transformer import ItemTransform
-from datumaro.util.os_util import find_files
-from datumaro.util.scope import scope_add, scoped
-
-from ..requirements import Requirements, mark_requirement
-
-from tests.utils.assets import get_test_asset_path
-from tests.utils.test_utils import TestDir, compare_datasets, compare_dirs
-
-
-class ProjectNewTest:
-    @pytest.fixture(params=[".git", ".dvc"])
-    def fxt_vcs_exist_dir(self, test_dir, request):
-        vcs_dir = osp.join(test_dir, request.param)
-        os.makedirs(vcs_dir)
-        with open(osp.join(vcs_dir, "dummy.file"), "w") as fp:
-            fp.write("dummy")
-        yield test_dir
-
-    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
-    def test_project_init_failed_by_vcs_already_exist(self, fxt_vcs_exist_dir):
-        with pytest.raises(VcsAlreadyExists):
-            Project.init(fxt_vcs_exist_dir)
-
-        # Assert that Project.init() does not remove the existing VCS directory
-        assert len(list(find_files(fxt_vcs_exist_dir, ".file", recursive=True))) > 0
-
-
-class ProjectTest(TestCase):
-    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
-    @scoped
-    def test_can_init_and_load(self):
-        test_dir = scope_add(TestDir())
-
-        scope_add(Project.init(test_dir)).close()
-        scope_add(Project(test_dir))
-
-        self.assertTrue(".datumaro" in os.listdir(test_dir))
-
-    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
-    @scoped
-    def test_can_find_project_in_project_dir(self):
-        test_dir = scope_add(TestDir())
-
-        scope_add(Project.init(test_dir))
-
-        self.assertEqual(osp.join(test_dir, ".datumaro"), Project.find_project_dir(test_dir))
-
-    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
-    @scoped
-    def test_cant_find_project_when_no_project(self):
-        test_dir = scope_add(TestDir())
-
-        self.assertEqual(None, Project.find_project_dir(test_dir))
-
-    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
-    @scoped
-    def test_can_add_local_model(self):
-        class TestLauncher(Launcher):
-            pass
-
-        source_name = "source"
-        config = Model({"launcher": "test", "options": {"a": 5, "b": "hello"}})
-
-        test_dir = scope_add(TestDir())
-        project = scope_add(Project.init(test_dir))
-        project.env.launchers.register("test", TestLauncher)
-
-        project.add_model(source_name, launcher=config.launcher, options=config.options)
-
-        added = project.models[source_name]
-        self.assertEqual(added.launcher, config.launcher)
-        self.assertEqual(added.options, config.options)
-
-    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
-    @scoped
-    def test_can_run_inference(self):
-        class TestLauncher(Launcher):
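-            # Stub launcher: "predicts" a label index equal to the value of
-            # the first pixel of each input image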
def launch(self, batch: Sequence[DatasetItem]) -> List[List[Annotation]]: - return [[Label(inp.media.data[0, 0, 0])] for inp in batch] - - expected = Dataset.from_iterable( - [ - DatasetItem( - 0, media=Image.from_numpy(data=np.zeros([2, 2, 3])), annotations=[Label(0)] - ), - DatasetItem( - 1, media=Image.from_numpy(data=np.ones([2, 2, 3])), annotations=[Label(1)] - ), - ], - categories=["a", "b"], - ) - - launcher_name = "custom_launcher" - model_name = "model" - - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "source") - source_dataset = Dataset.from_iterable( - [ - DatasetItem(0, media=Image.from_numpy(data=np.zeros([2, 2, 3]) * 0)), - DatasetItem(1, media=Image.from_numpy(data=np.ones([2, 2, 3]) * 1)), - ], - categories=["a", "b"], - ) - source_dataset.save(source_url, save_media=True) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.env.launchers.register(launcher_name, TestLauncher) - project.add_model(model_name, launcher=launcher_name) - project.import_source("source", source_url, format=DEFAULT_FORMAT) - - dataset = project.working_tree.make_dataset() - model = project.make_model(model_name) - - inference = dataset.run_model(model) - - compare_datasets(self, expected, inference) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_import_local_dir_source(self): - test_dir = scope_add(TestDir()) - source_base_url = osp.join(test_dir, "test_repo") - source_file_path = osp.join(source_base_url, "x", "y.txt") - os.makedirs(osp.dirname(source_file_path), exist_ok=True) - with open(source_file_path, "w") as f: - f.write("hello") - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", url=source_base_url, format="fmt") - - source = project.working_tree.sources["s1"] - self.assertEqual("fmt", source.format) - compare_dirs(self, source_base_url, project.source_data_dir("s1")) - with open(osp.join(test_dir, "proj", ".gitignore")) as f: - self.assertTrue("/s1" in [line.strip() for line in f]) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_import_local_file_source(self): - # In this variant, we copy and read just the file specified - - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "f.txt") - os.makedirs(osp.dirname(source_url), exist_ok=True) - with open(source_url, "w") as f: - f.write("hello") - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", url=source_url, format="fmt") - - source = project.working_tree.sources["s1"] - self.assertEqual("fmt", source.format) - self.assertEqual("f.txt", source.path) - - self.assertEqual({"f.txt"}, set(os.listdir(project.source_data_dir("s1")))) - with open(osp.join(test_dir, "proj", ".gitignore")) as f: - self.assertTrue("/s1" in [line.strip() for line in f]) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_import_local_source_with_relpath(self): - # This form must copy all the data in URL, but read only - # specified files. Required to support subtasks and subsets. 
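-        # Below, rpath points at the "b" subset annotation file, so only that
-        # subset is read, while the whole source directory is still copied.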
- - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "source") - source_dataset = Dataset.from_iterable( - [ - DatasetItem( - 0, - subset="a", - media=Image.from_numpy(data=np.zeros([2, 2, 3])), - annotations=[Bbox(1, 2, 3, 4, label=0)], - ), - DatasetItem( - 1, - subset="b", - media=Image.from_numpy(data=np.zeros((10, 20, 3))), - annotations=[Bbox(1, 2, 3, 4, label=1)], - ), - ], - categories=["a", "b"], - ) - source_dataset.save(source_url, save_media=True) - - expected_dataset = Dataset.from_iterable( - [ - DatasetItem( - 1, - subset="b", - media=Image.from_numpy(data=np.zeros((10, 20, 3))), - annotations=[Bbox(1, 2, 3, 4, label=1)], - ), - ], - categories=["a", "b"], - ) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - - project.import_source( - "s1", url=source_url, format=DEFAULT_FORMAT, rpath=osp.join("annotations", "b.json") - ) - - source = project.working_tree.sources["s1"] - self.assertEqual(DEFAULT_FORMAT, source.format) - - compare_dirs(self, source_url, project.source_data_dir("s1")) - read_dataset = project.working_tree.make_dataset("s1") - compare_datasets(self, expected_dataset, read_dataset, require_media=True) - - with open(osp.join(test_dir, "proj", ".gitignore")) as f: - self.assertTrue("/s1" in [line.strip() for line in f]) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_cant_import_local_source_with_relpath_outside(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "source") - os.makedirs(source_url) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - - with self.assertRaises(PathOutsideSourceError): - project.import_source("s1", url=source_url, format=DEFAULT_FORMAT, rpath="..") - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_cant_import_local_source_with_url_inside_project(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "qq") - with open(source_url, "w") as f: - f.write("hello") - - project = scope_add(Project.init(test_dir)) - - with self.assertRaises(SourceUrlInsideProjectError): - project.import_source("s1", url=source_url, format=DEFAULT_FORMAT) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_report_incompatible_sources(self): - test_dir = scope_add(TestDir()) - source1_url = osp.join(test_dir, "dataset1") - dataset1 = Dataset.from_iterable( - [ - DatasetItem(1, annotations=[Label(0)]), - ], - categories=["a", "b"], - ) - dataset1.save(source1_url) - - source2_url = osp.join(test_dir, "dataset2") - dataset2 = Dataset.from_iterable( - [ - DatasetItem(1, annotations=[Label(0)]), - ], - categories=["c", "d"], - ) - dataset2.save(source2_url) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", url=source1_url, format=DEFAULT_FORMAT) - project.import_source("s2", url=source2_url, format=DEFAULT_FORMAT) - - with self.assertRaises(DatasetMergeError) as cm: - project.working_tree.make_dataset() - - self.assertEqual({"s1.root", "s2.root"}, cm.exception.sources) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_cant_import_sources_with_same_names(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "test_repo") - dataset = Dataset.from_iterable( - [ - DatasetItem(1, annotations=[Label(0)]), - DatasetItem(2, annotations=[Label(1)]), - ], - categories=["a", "b"], - ) - dataset.save(source_url) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", 
url=source_url, format=DEFAULT_FORMAT) - - with self.assertRaises(SourceExistsError): - project.import_source("s1", url=source_url, format=DEFAULT_FORMAT) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_import_generated_source(self): - test_dir = scope_add(TestDir()) - source_name = "source" - origin = Source( - { - # no url - "format": "fmt", - "options": {"c": 5, "d": "hello"}, - } - ) - project = scope_add(Project.init(test_dir)) - - project.import_source(source_name, url="", format=origin.format, options=origin.options) - - added = project.working_tree.sources[source_name] - self.assertEqual(added.format, origin.format) - self.assertEqual(added.options, origin.options) - with open(osp.join(test_dir, ".gitignore")) as f: - self.assertTrue("/" + source_name in [line.strip() for line in f]) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_cant_import_source_with_wrong_name(self): - test_dir = scope_add(TestDir()) - project = scope_add(Project.init(test_dir)) - - for name in {"dataset", "project", "build", ".any"}: - with self.subTest(name=name), self.assertRaisesRegex(ValueError, "Source name"): - project.import_source(name, url="", format="fmt") - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_add_project_local_source(self): - test_dir = scope_add(TestDir()) - proj_dir = osp.join(test_dir, "proj") - - project = scope_add(Project.init(proj_dir)) - - source_base_url = osp.join(proj_dir, "x") - source_file_path = osp.join(source_base_url, "y.txt") - os.makedirs(osp.dirname(source_file_path)) - with open(source_file_path, "w") as f: - f.write("hello") - - name, source = project.add_source(source_base_url, format="fmt") - - self.assertEqual("x", name) - self.assertEqual(project.working_tree.sources[name], source) - self.assertEqual("fmt", source.format) - - with open(osp.join(test_dir, "proj", ".gitignore")) as f: - self.assertTrue("/x" in [line.strip() for line in f]) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_cant_add_source_deep_in_the_project(self): - test_dir = scope_add(TestDir()) - proj_dir = osp.join(test_dir, "proj") - - project = scope_add(Project.init(proj_dir)) - - source_base_url = osp.join(proj_dir, "x", "y") - source_file_path = osp.join(source_base_url, "y.txt") - os.makedirs(osp.dirname(source_file_path)) - with open(source_file_path, "w") as f: - f.write("hello") - - with self.assertRaises(UnexpectedUrlError): - project.add_source(source_base_url, format="fmt") - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_cant_add_source_outside_project(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "x") - os.makedirs(source_url) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - - with self.assertRaises(UnexpectedUrlError): - project.add_source(source_url, format="fmt") - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_remove_source_and_keep_data(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "test_source.txt") - os.makedirs(osp.dirname(source_url), exist_ok=True) - with open(source_url, "w") as f: - f.write("hello") - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", url=source_url, format=DEFAULT_FORMAT) - - project.remove_source("s1", keep_data=True) - - self.assertFalse("s1" in project.working_tree.sources) - compare_dirs(self, source_url, project.source_data_dir("s1")) - with 
open(osp.join(test_dir, "proj", ".gitignore")) as f: - self.assertFalse("/s1" in [line.strip() for line in f]) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_remove_source_and_wipe_data(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "test_source.txt") - os.makedirs(osp.dirname(source_url), exist_ok=True) - with open(source_url, "w") as f: - f.write("hello") - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", url=source_url, format=DEFAULT_FORMAT) - - project.remove_source("s1", keep_data=False) - - self.assertFalse("s1" in project.working_tree.sources) - self.assertFalse(osp.exists(project.source_data_dir("s1"))) - with open(osp.join(test_dir, "proj", ".gitignore")) as f: - self.assertFalse("/s1" in [line.strip() for line in f]) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_redownload_source_rev_noncached(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "source") - source_dataset = Dataset.from_iterable( - [ - DatasetItem( - 0, - media=Image.from_numpy(data=np.ones((2, 3, 3))), - annotations=[Bbox(1, 2, 3, 4, label=0)], - ), - DatasetItem( - 1, - subset="s", - media=Image.from_numpy(data=np.zeros((10, 20, 3))), - annotations=[Bbox(1, 2, 3, 4, label=1)], - ), - ], - categories=["a", "b"], - ) - source_dataset.save(source_url, save_media=True) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", url=source_url, format=DEFAULT_FORMAT) - project.commit("A commit") - - # remove local source data - project.remove_cache_obj(project.working_tree.build_targets["s1"].head.hash) - shutil.rmtree(project.source_data_dir("s1")) - - read_dataset = project.working_tree.make_dataset("s1") - - compare_datasets(self, source_dataset, read_dataset) - compare_dirs( - self, source_url, project.cache_path(project.working_tree.build_targets["s1"].root.hash) - ) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_redownload_source_and_check_data_hash(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "source") - source_dataset = Dataset.from_iterable( - [ - DatasetItem( - 0, - media=Image.from_numpy(data=np.zeros((2, 3, 3))), - annotations=[Bbox(1, 2, 3, 4, label=0)], - ), - DatasetItem( - 1, - subset="s", - media=Image.from_numpy(data=np.zeros((10, 20, 3))), - annotations=[Bbox(1, 2, 3, 4, label=1)], - ), - ], - categories=["a", "b"], - ) - source_dataset.save(source_url, save_media=True) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", url=source_url, format=DEFAULT_FORMAT) - project.commit("A commit") - - # remove local source data - project.remove_cache_obj(project.working_tree.build_targets["s1"].head.hash) - shutil.rmtree(project.source_data_dir("s1")) - - # modify the source repo - with open(osp.join(source_url, "extra_file.txt"), "w") as f: - f.write("text\n") - - with self.assertRaises(MismatchingObjectError): - project.working_tree.make_dataset("s1") - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_use_source_from_cache_with_working_copy(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "source") - source_dataset = Dataset.from_iterable( - [ - DatasetItem( - 0, - media=Image.from_numpy(data=np.zeros((2, 3, 3))), - annotations=[Bbox(1, 2, 3, 4, label=0)], - ), - DatasetItem( - 1, - subset="s", - media=Image.from_numpy(data=np.zeros((10, 20, 3))), - 
annotations=[Bbox(1, 2, 3, 4, label=1)], - ), - ], - categories=["a", "b"], - ) - source_dataset.save(source_url, save_media=True) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", url=source_url, format=DEFAULT_FORMAT) - project.commit("A commit") - - shutil.rmtree(project.source_data_dir("s1")) - - read_dataset = project.working_tree.make_dataset("s1") - - compare_datasets(self, source_dataset, read_dataset) - self.assertFalse(osp.isdir(project.source_data_dir("s1"))) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_raises_an_error_if_local_data_unknown(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "source") - source_dataset = Dataset.from_iterable( - [ - DatasetItem( - 0, - media=Image.from_numpy(data=np.zeros((2, 3, 3))), - annotations=[Bbox(1, 2, 3, 4, label=0)], - ), - DatasetItem( - 1, - subset="s", - media=Image.from_numpy(data=np.zeros((10, 20, 3))), - annotations=[Bbox(1, 2, 3, 4, label=1)], - ), - ], - categories=["a", "b"], - ) - source_dataset.save(source_url, save_media=True) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", url=source_url, format=DEFAULT_FORMAT) - project.commit("A commit") - - # remove the cached object so that it couldn't be matched - project.remove_cache_obj(project.working_tree.build_targets["s1"].root.hash) - - # modify local source data - with open(osp.join(project.source_data_dir("s1"), "extra.txt"), "w") as f: - f.write("text\n") - - with self.assertRaises(ForeignChangesError): - project.working_tree.make_dataset("s1") - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_read_working_copy_of_source(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "source") - source_dataset = Dataset.from_iterable( - [ - DatasetItem( - 0, - media=Image.from_numpy(data=np.zeros((2, 3, 3))), - annotations=[Bbox(1, 2, 3, 4, label=0)], - ), - DatasetItem( - 1, - subset="s", - media=Image.from_numpy(data=np.zeros((1, 2, 3))), - annotations=[Bbox(1, 2, 3, 4, label=1)], - ), - ], - categories=["a", "b"], - ) - source_dataset.save(source_url, save_media=True) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", url=source_url, format=DEFAULT_FORMAT) - - read_dataset = project.working_tree.make_dataset("s1") - - compare_datasets(self, source_dataset, read_dataset) - compare_dirs(self, source_url, project.source_data_dir("s1")) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_read_current_revision_of_source(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "source") - source_dataset = Dataset.from_iterable( - [ - DatasetItem( - 0, - media=Image.from_numpy(data=np.zeros((2, 3, 3))), - annotations=[Bbox(1, 2, 3, 4, label=0)], - ), - DatasetItem( - 1, - subset="s", - media=Image.from_numpy(data=np.zeros((1, 2, 3))), - annotations=[Bbox(1, 2, 3, 4, label=1)], - ), - ], - categories=["a", "b"], - ) - source_dataset.save(source_url, save_media=True) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", url=source_url, format=DEFAULT_FORMAT) - project.commit("A commit") - - shutil.rmtree(project.source_data_dir("s1")) - - read_dataset = project.head.make_dataset("s1") - - compare_datasets(self, source_dataset, read_dataset) - self.assertFalse(osp.isdir(project.source_data_dir("s1"))) - compare_dirs(self, source_url, project.head.source_data_dir("s1")) 
- - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_make_dataset_from_project(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "test_repo") - source_dataset = Dataset.from_iterable( - [ - DatasetItem(1, annotations=[Label(0)]), - ], - categories=["a", "b"], - ) - source_dataset.save(source_url) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", url=source_url, format=DEFAULT_FORMAT) - - read_dataset = project.working_tree.make_dataset() - - compare_datasets(self, source_dataset, read_dataset) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_make_dataset_from_source(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "test_repo") - dataset = Dataset.from_iterable( - [ - DatasetItem(1, annotations=[Label(0)]), - ], - categories=["a", "b"], - ) - dataset.save(source_url) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", url=source_url, format=DEFAULT_FORMAT) - - built_dataset = project.working_tree.make_dataset("s1") - - compare_datasets(self, dataset, built_dataset) - self.assertEqual(DEFAULT_FORMAT, built_dataset.format) - self.assertEqual(project.source_data_dir("s1"), built_dataset.data_path) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_cant_make_dataset_from_empty_project(self): - test_dir = scope_add(TestDir()) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - - with self.assertRaises(EmptyPipelineError): - project.working_tree.make_dataset() - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_cant_make_dataset_from_unknown_target(self): - test_dir = scope_add(TestDir()) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - - with self.assertRaises(UnknownTargetError): - project.working_tree.make_dataset("s1") - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_add_filter_stage(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "test_repo") - dataset = Dataset.from_iterable( - [ - DatasetItem(1, annotations=[Label(0)]), - DatasetItem(2, annotations=[Label(1)]), - ], - categories=["a", "b"], - ) - dataset.save(source_url) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", url=source_url, format=DEFAULT_FORMAT) - - new_tree = project.working_tree.clone() - stage = new_tree.build_targets.add_filter_stage("s1", '/item/annotation[label="b"]') - - self.assertTrue(stage in new_tree.build_targets) - self.assertTrue(stage not in project.working_tree.build_targets) - - resulting_dataset = project.working_tree.make_dataset(new_tree.make_pipeline("s1")) - compare_datasets( - self, - Dataset.from_iterable( - [ - DatasetItem(2, annotations=[Label(1)]), - ], - categories=["a", "b"], - ), - resulting_dataset, - ) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_add_convert_stage(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "test_repo") - dataset = Dataset.from_iterable( - [ - DatasetItem(1, annotations=[Label(0)]), - DatasetItem(2, annotations=[Label(1)]), - ], - categories=["a", "b"], - ) - dataset.save(source_url) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", url=source_url, format=DEFAULT_FORMAT) - - stage = project.working_tree.build_targets.add_convert_stage("s1", DEFAULT_FORMAT) - - self.assertTrue(stage in 
project.working_tree.build_targets) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_add_transform_stage(self): - class TestTransform(ItemTransform): - def __init__(self, extractor, p1=None, p2=None): - super().__init__(extractor) - self.p1 = p1 - self.p2 = p2 - - def transform_item(self, item): - return self.wrap_item(item, attributes={"p1": self.p1, "p2": self.p2}) - - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "test_repo") - dataset = Dataset.from_iterable( - [ - DatasetItem(1, annotations=[Label(0)]), - DatasetItem(2, annotations=[Label(1)]), - ], - categories=["a", "b"], - ) - dataset.save(source_url) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", url=source_url, format=DEFAULT_FORMAT) - project.working_tree.env.transforms.register("tr", TestTransform) - - # TODO: simplify adding stages and making datasets from them - new_tree = project.working_tree.clone() - stage = new_tree.build_targets.add_transform_stage( - "s1", "tr", params={"p1": 5, "p2": ["1", 2, 3.5]} - ) - - self.assertTrue(stage in new_tree.build_targets) - self.assertTrue(stage not in project.working_tree.build_targets) - - resulting_dataset = project.working_tree.make_dataset(new_tree.make_pipeline("s1")) - compare_datasets( - self, - Dataset.from_iterable( - [ - DatasetItem( - 1, annotations=[Label(0)], attributes={"p1": 5, "p2": ["1", 2, 3.5]} - ), - DatasetItem( - 2, annotations=[Label(1)], attributes={"p1": 5, "p2": ["1", 2, 3.5]} - ), - ], - categories=["a", "b"], - ), - resulting_dataset, - ) - - project.working_tree.config.update(new_tree.config) - self.assertTrue(stage in project.working_tree.build_targets) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_make_dataset_from_stage(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "test_repo") - dataset = Dataset.from_iterable( - [ - DatasetItem(1, annotations=[Label(0)]), - DatasetItem(2, annotations=[Label(1)]), - ], - categories=["a", "b"], - ) - dataset.save(source_url) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", url=source_url, format=DEFAULT_FORMAT) - - built_dataset = project.working_tree.make_dataset("s1.root") - - compare_datasets(self, dataset, built_dataset) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_commit(self): - test_dir = scope_add(TestDir()) - project = scope_add(Project.init(test_dir)) - - commit_hash = project.commit("First commit", allow_empty=True) - - self.assertTrue(project.is_ref(commit_hash)) - self.assertEqual(len(project.history()), 2) - self.assertEqual(project.history()[0], (commit_hash, "First commit")) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_cant_commit_empty(self): - test_dir = scope_add(TestDir()) - project = scope_add(Project.init(test_dir)) - - with self.assertRaises(EmptyCommitError): - project.commit("First commit") - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_commit_patch(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "test_source.txt") - os.makedirs(osp.dirname(source_url), exist_ok=True) - with open(source_url, "w") as f: - f.write("hello") - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", source_url, format=DEFAULT_FORMAT) - project.commit("First commit") - - source_path = osp.join(project.source_data_dir("s1"), osp.basename(source_url)) - with 
open(source_path, "w") as f: - f.write("world") - - commit_hash = project.commit("Second commit", allow_foreign=True) - - self.assertTrue(project.is_ref(commit_hash)) - self.assertNotEqual( - project.get_rev("HEAD~1").build_targets["s1"].head.hash, - project.working_tree.build_targets["s1"].head.hash, - ) - self.assertTrue(project.is_obj_cached(project.working_tree.build_targets["s1"].head.hash)) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_cant_commit_foreign_changes(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "test_source.txt") - os.makedirs(osp.dirname(source_url), exist_ok=True) - with open(source_url, "w") as f: - f.write("hello") - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", source_url, format=DEFAULT_FORMAT) - project.commit("First commit") - - source_path = osp.join(project.source_data_dir("s1"), osp.basename(source_url)) - with open(source_path, "w") as f: - f.write("world") - - with self.assertRaises(ForeignChangesError): - project.commit("Second commit") - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_checkout_revision(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "test_source.txt") - os.makedirs(osp.dirname(source_url), exist_ok=True) - with open(source_url, "w") as f: - f.write("hello") - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", source_url, format=DEFAULT_FORMAT) - project.commit("First commit") - - source_path = osp.join(project.source_data_dir("s1"), osp.basename(source_url)) - with open(source_path, "w") as f: - f.write("world") - project.commit("Second commit", allow_foreign=True) - - project.checkout("HEAD~1") - - compare_dirs(self, source_url, project.source_data_dir("s1")) - with open(osp.join(test_dir, "proj", ".gitignore")) as f: - self.assertTrue("/s1" in [line.strip() for line in f]) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_checkout_sources(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "test_repo") - dataset = Dataset.from_iterable( - [ - DatasetItem(1, annotations=[Label(0)]), - DatasetItem(2, annotations=[Label(1)]), - ], - categories=["a", "b"], - ) - dataset.save(source_url) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", url=source_url, format=DEFAULT_FORMAT) - project.import_source("s2", url=source_url, format=DEFAULT_FORMAT) - project.commit("Commit 1") - project.remove_source("s1", keep_data=False) # remove s1 from tree - shutil.rmtree(project.source_data_dir("s2")) # modify s2 "manually" - - project.checkout(sources=["s1", "s2"]) - - compare_dirs(self, source_url, project.source_data_dir("s1")) - compare_dirs(self, source_url, project.source_data_dir("s2")) - with open(osp.join(test_dir, "proj", ".gitignore")) as f: - lines = [line.strip() for line in f] - self.assertTrue("/s1" in lines) - self.assertTrue("/s2" in lines) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_checkout_with_force(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "test_repo") - dataset = Dataset.from_iterable( - [ - DatasetItem(1, annotations=[Label(0)]), - DatasetItem(2, annotations=[Label(1)]), - ], - categories=["a", "b"], - ) - dataset.save(source_url) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", url=source_url, format=DEFAULT_FORMAT) - 
project.import_source("s2", url=source_url, format=DEFAULT_FORMAT) - project.commit("Commit 1") - project.remove_source("s1", keep_data=False) # remove s1 from tree - shutil.rmtree(project.source_data_dir("s2")) # modify s2 "manually" - - project.checkout(force=True) - - compare_dirs(self, source_url, project.source_data_dir("s1")) - compare_dirs(self, source_url, project.source_data_dir("s2")) - with open(osp.join(test_dir, "proj", ".gitignore")) as f: - lines = [line.strip() for line in f] - self.assertTrue("/s1" in lines) - self.assertTrue("/s2" in lines) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_checkout_sources_from_revision(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "test_repo") - dataset = Dataset.from_iterable( - [ - DatasetItem(1, annotations=[Label(0)]), - DatasetItem(2, annotations=[Label(1)]), - ], - categories=["a", "b"], - ) - dataset.save(source_url) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", url=source_url, format=DEFAULT_FORMAT) - project.commit("Commit 1") - project.remove_source("s1", keep_data=False) - project.commit("Commit 2") - - project.checkout(rev="HEAD~1", sources=["s1"]) - - compare_dirs(self, source_url, project.source_data_dir("s1")) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_check_status(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "test_repo") - dataset = Dataset.from_iterable( - [ - DatasetItem(1, annotations=[Label(0)]), - DatasetItem(2, annotations=[Label(1)]), - ], - categories=["a", "b"], - ) - dataset.save(source_url) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", url=source_url, format=DEFAULT_FORMAT) - project.import_source("s2", url=source_url, format=DEFAULT_FORMAT) - project.import_source("s3", url=source_url, format=DEFAULT_FORMAT) - project.import_source("s4", url=source_url, format=DEFAULT_FORMAT) - project.import_source("s5", url=source_url, format=DEFAULT_FORMAT) - project.commit("Commit 1") - - project.remove_source("s2") - project.import_source("s6", url=source_url, format=DEFAULT_FORMAT) - - shutil.rmtree(project.source_data_dir("s3")) - - project.working_tree.build_targets.add_transform_stage("s4", "reindex") - project.working_tree.make_dataset("s4").save() - project.refresh_source_hash("s4") - - s5_dir = osp.join(project.source_data_dir("s5")) - with open(osp.join(s5_dir, "annotations", "t.txt"), "w") as f: - f.write("hello") - - status = project.status() - self.assertEqual( - { - "s2": DiffStatus.removed, - "s3": DiffStatus.missing, - "s4": DiffStatus.modified, - "s5": DiffStatus.foreign_modified, - "s6": DiffStatus.added, - }, - status, - ) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_compare_revisions(self): - test_dir = scope_add(TestDir()) - source_url = osp.join(test_dir, "test_repo") - dataset = Dataset.from_iterable( - [ - DatasetItem(1, annotations=[Label(0)]), - DatasetItem(2, annotations=[Label(1)]), - ], - categories=["a", "b"], - ) - dataset.save(source_url) - - project = scope_add(Project.init(osp.join(test_dir, "proj"))) - project.import_source("s1", url=source_url, format=DEFAULT_FORMAT) - project.import_source("s2", url=source_url, format=DEFAULT_FORMAT) - rev1 = project.commit("Commit 1") - - project.remove_source("s2") - project.import_source("s3", url=source_url, format=DEFAULT_FORMAT) - rev2 = project.commit("Commit 2") - - diff = project.diff(rev1, 
-        self.assertEqual(diff, {"s2": DiffStatus.removed, "s3": DiffStatus.added})
-
-    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
-    @scoped
-    def test_can_restore_revision(self):
-        test_dir = scope_add(TestDir())
-        source_url = osp.join(test_dir, "test_repo")
-        dataset = Dataset.from_iterable(
-            [
-                DatasetItem(1, annotations=[Label(0)]),
-                DatasetItem(2, annotations=[Label(1)]),
-            ],
-            categories=["a", "b"],
-        )
-        dataset.save(source_url)
-
-        project = scope_add(Project.init(osp.join(test_dir, "proj")))
-        project.import_source("s1", url=source_url, format=DEFAULT_FORMAT)
-        rev1 = project.commit("Commit 1")
-
-        project.remove_cache_obj(rev1)
-
-        self.assertFalse(project.is_rev_cached(rev1))
-
-        head_dataset = project.head.make_dataset()
-
-        self.assertTrue(project.is_rev_cached(rev1))
-        compare_datasets(self, dataset, head_dataset)
-
-    @mark_requirement(Requirements.DATUM_BUG_404)
-    @scoped
-    def test_can_add_plugin(self):
-        test_dir = scope_add(TestDir())
-        scope_add(Project.init(test_dir)).close()
-
-        plugin_dir = osp.join(test_dir, ".datumaro", "plugins")
-        os.makedirs(plugin_dir)
-        with open(osp.join(plugin_dir, "__init__.py"), "w") as f:
-            f.write(
-                textwrap.dedent(
-                    """
-                    from datumaro.components.dataset_base import (SubsetBase,
-                        DatasetItem)
-                    class MyBase(SubsetBase):
-                        def __init__(self, *args, **kwargs):
-                            super().__init__()
-                        def __iter__(self):
-                            yield from [
-                                DatasetItem('1'),
-                                DatasetItem('2'),
-                            ]
-                        def ann_types(self):
-                            return {}
-                    """
-                )
-            )
-
-        project = scope_add(Project(test_dir))
-        project.import_source("src", url="", format="my")
-
-        expected = Dataset.from_iterable([DatasetItem("1"), DatasetItem("2")])
-        actual = project.working_tree.make_dataset()
-        compare_datasets(self, expected, actual)
-
-    @mark_requirement(Requirements.DATUM_BUG_402)
-    @scoped
-    def test_can_transform_by_name(self):
-        class CustomExtractor(DatasetBase):
-            def __init__(self, *args, **kwargs):
-                super().__init__()
-
-            def __iter__(self):
-                return iter(
-                    [
-                        DatasetItem("a"),
-                        DatasetItem("b"),
-                    ]
-                )
-
-        test_dir = scope_add(TestDir())
-        extractor_name = "ext1"
-        project = scope_add(Project.init(test_dir))
-        project.env.extractors.register(extractor_name, CustomExtractor)
-        project.import_source("src1", url="", format=extractor_name)
-        dataset = project.working_tree.make_dataset()
-
-        dataset = dataset.transform("reindex")
-
-        expected = Dataset.from_iterable(
-            [
-                DatasetItem(1),
-                DatasetItem(2),
-            ]
-        )
-        compare_datasets(self, expected, dataset)
-
-    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
-    @scoped
-    def test_cant_modify_readonly(self):
-        test_dir = scope_add(TestDir())
-        dataset_url = osp.join(test_dir, "dataset")
-        Dataset.from_iterable(
-            [
-                DatasetItem("a"),
-                DatasetItem("b"),
-            ]
-        ).save(dataset_url)
-
-        proj_dir = osp.join(test_dir, "proj")
-        with Project.init(proj_dir) as project:
-            project.import_source("source1", url=dataset_url, format=DEFAULT_FORMAT)
-            project.commit("first commit")
-            project.remove_source("source1")
-            commit2 = project.commit("second commit")
-            project.checkout("HEAD~1")
-            project.remove_cache_obj(commit2)
-            project.remove_cache_obj(project.working_tree.sources["source1"].hash)
-
-        project = scope_add(Project(proj_dir, readonly=True))
-
-        self.assertTrue(project.readonly)
-
-        with self.subTest("add source"), self.assertRaises(ReadonlyProjectError):
-            project.import_source("src1", url="", format=DEFAULT_FORMAT)
-
-        with self.subTest("remove source"), self.assertRaises(ReadonlyProjectError):
-            project.remove_source("src1")
-
-        with self.subTest("add model"), self.assertRaises(ReadonlyProjectError):
self.subTest("add model"), self.assertRaises(ReadonlyProjectError): - project.add_model("m1", launcher="x") - - with self.subTest("remove model"), self.assertRaises(ReadonlyProjectError): - project.remove_model("m1") - - with self.subTest("checkout"), self.assertRaises(ReadonlyProjectError): - project.checkout("HEAD") - - with self.subTest("commit"), self.assertRaises(ReadonlyProjectError): - project.commit("third commit", allow_empty=True) - - # Can't re-download the source in a readonly project - with self.subTest("make_dataset"), self.assertRaises(MissingObjectError): - project.get_rev("HEAD").make_dataset() - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_import_without_hashing(self): - test_dir = scope_add(TestDir()) - dataset_url = osp.join(test_dir, "dataset") - dataset = Dataset.from_iterable( - [ - DatasetItem("a"), - DatasetItem("b"), - ] - ) - dataset.save(dataset_url) - - proj_dir = osp.join(test_dir, "proj") - project = scope_add(Project.init(proj_dir)) - project.import_source("source1", url=dataset_url, format=DEFAULT_FORMAT, no_hash=True) - - self.assertEqual("", project.working_tree.sources["source1"].hash) - compare_dirs(self, dataset_url, project.source_data_dir("source1")) - compare_datasets(self, dataset, project.working_tree.make_dataset()) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_check_status_of_unhashed(self): - test_dir = scope_add(TestDir()) - dataset_url = osp.join(test_dir, "dataset") - Dataset.from_iterable( - [ - DatasetItem("a"), - DatasetItem("b"), - ] - ).save(dataset_url) - - proj_dir = osp.join(test_dir, "proj") - project = scope_add(Project.init(proj_dir)) - project.import_source("source1", url=dataset_url, format=DEFAULT_FORMAT, no_hash=True) - project.import_source("source2", url=dataset_url, format=DEFAULT_FORMAT, no_hash=True) - project.working_tree.build_targets.add_transform_stage("source2", "reindex") - - status = project.status() - self.assertEqual(status["source1"], DiffStatus.added) - self.assertEqual(status["source2"], DiffStatus.added) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_can_commit_unhashed(self): - test_dir = scope_add(TestDir()) - dataset_url = osp.join(test_dir, "dataset") - Dataset.from_iterable( - [ - DatasetItem("a"), - DatasetItem("b"), - ] - ).save(dataset_url) - - proj_dir = osp.join(test_dir, "proj") - project = scope_add(Project.init(proj_dir)) - project.import_source("source1", url=dataset_url, format=DEFAULT_FORMAT, no_hash=True) - project.commit("a commit") - - self.assertNotEqual("", project.working_tree.sources["source1"].hash) - self.assertNotEqual("", project.working_tree.build_targets["source1"].head.hash) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @scoped - def test_cant_redownload_unhashed(self): - test_dir = scope_add(TestDir()) - dataset_url = osp.join(test_dir, "dataset") - Dataset.from_iterable( - [ - DatasetItem("a"), - DatasetItem("b"), - ] - ).save(dataset_url) - - proj_dir = osp.join(test_dir, "proj") - project = scope_add(Project.init(proj_dir)) - project.import_source("source1", url=dataset_url, format=DEFAULT_FORMAT, no_hash=True) - project.working_tree.build_targets.add_transform_stage("source1", "reindex") - project.commit("a commit") - - with self.assertRaises(MissingSourceHashError): - project.working_tree.make_dataset("source1.root") - - @mark_requirement(Requirements.DATUM_BUG_602) - @scoped - def test_can_save_local_source_with_relpath(self): - test_dir = scope_add(TestDir()) - 
-        source_url = osp.join(test_dir, "source")
-        source_dataset = Dataset.from_iterable(
-            [
-                DatasetItem(
-                    0,
-                    subset="a",
-                    media=Image.from_numpy(data=np.ones((2, 3, 3))),
-                    annotations=[Bbox(1, 2, 3, 4, label=0)],
-                ),
-                DatasetItem(
-                    1,
-                    subset="b",
-                    media=Image.from_numpy(data=np.zeros((10, 20, 3))),
-                    annotations=[Bbox(1, 2, 3, 4, label=1)],
-                ),
-            ],
-            categories=["a", "b"],
-        )
-        source_dataset.save(source_url, save_media=True)
-
-        project = scope_add(Project.init(osp.join(test_dir, "proj")))
-        project.import_source(
-            "s1", url=source_url, format=DEFAULT_FORMAT, rpath=osp.join("annotations", "b.json")
-        )
-
-        read_dataset = project.working_tree.make_dataset("s1")
-        self.assertEqual(read_dataset.data_path, project.source_data_dir("s1"))
-
-        read_dataset.save()
-
-
-class BackwardCompatibilityTests_v0_1(TestCase):
-    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
-    @scoped
-    def test_can_migrate_old_project(self):
-        expected_dataset = Dataset.from_iterable(
-            [
-                DatasetItem(0, subset="train", annotations=[Label(0)]),
-                DatasetItem(1, subset="test", annotations=[Label(1)]),
-                DatasetItem(2, subset="train", annotations=[Label(0)]),
-            ],
-            categories=["a", "b"],
-        )
-
-        test_dir = scope_add(TestDir())
-        old_proj_dir = osp.join(test_dir, "old_proj")
-        new_proj_dir = osp.join(test_dir, "new_proj")
-        shutil.copytree(get_test_asset_path("compat", "v0.1", "project"), old_proj_dir)
-
-        with self.assertLogs(None) as logs:
-            Project.migrate_from_v1_to_v2(old_proj_dir, new_proj_dir, skip_import_errors=True)
-
-        self.assertIn("Failed to migrate the source 'source3'", "\n".join(logs.output))
-
-        project = scope_add(Project(new_proj_dir))
-        loaded_dataset = project.working_tree.make_dataset()
-        compare_datasets(self, expected_dataset, loaded_dataset)
-
-    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
-    @scoped
-    def test_cant_load_old_project(self):
-        test_dir = scope_add(TestDir())
-        proj_dir = osp.join(test_dir, "old_proj")
-        shutil.copytree(get_test_asset_path("compat", "v0.1", "project"), proj_dir)
-
-        with self.assertRaises(OldProjectError):
-            scope_add(Project(proj_dir))
diff --git a/tests/unit/test_sampler.py b/tests/unit/test_sampler.py
index 40159e7014..6f0da3be9f 100644
--- a/tests/unit/test_sampler.py
+++ b/tests/unit/test_sampler.py
@@ -7,9 +7,9 @@
 from unittest import TestCase, skipIf
 
 from datumaro.components.annotation import AnnotationType, Label, LabelCategories
+from datumaro.components.dataset import Dataset
 from datumaro.components.dataset_base import DatasetItem
 from datumaro.components.media import Image
-from datumaro.components.project import Dataset
 from datumaro.plugins.sampler.random_sampler import LabelRandomSampler, RandomSampler
 
 from tests.utils.test_utils import compare_datasets, compare_datasets_strict
diff --git a/tests/unit/test_sly_pointcloud_format.py b/tests/unit/test_sly_pointcloud_format.py
index 49f534cc8c..4d2968beb9 100644
--- a/tests/unit/test_sly_pointcloud_format.py
+++ b/tests/unit/test_sly_pointcloud_format.py
@@ -6,10 +6,10 @@
 import numpy as np
 
 from datumaro.components.annotation import AnnotationType, Cuboid3d, LabelCategories
+from datumaro.components.dataset import Dataset
 from datumaro.components.dataset_base import DatasetItem
 from datumaro.components.environment import Environment
 from datumaro.components.media import Image, PointCloud
-from datumaro.components.project import Dataset
 from datumaro.plugins.data_formats.sly_pointcloud.base import SuperviselyPointCloudImporter
 from datumaro.plugins.data_formats.sly_pointcloud.exporter import SuperviselyPointCloudExporter
 
diff --git a/tests/unit/test_splitter.py b/tests/unit/test_splitter.py
index 5321ffe678..7dc972c3fc 100644
--- a/tests/unit/test_splitter.py
+++ b/tests/unit/test_splitter.py
@@ -11,10 +11,10 @@
     Mask,
     Polygon,
 )
+from datumaro.components.dataset import Dataset
 from datumaro.components.dataset_base import DatasetItem
 from datumaro.components.media import Image
 from datumaro.components.operations import compute_ann_statistics
-from datumaro.components.project import Dataset
 
 from ..requirements import Requirements, mark_requirement
 
diff --git a/tests/unit/test_video.py b/tests/unit/test_video.py
index 8802e51e67..58e75d9ae0 100644
--- a/tests/unit/test_video.py
+++ b/tests/unit/test_video.py
@@ -15,7 +15,6 @@
 from datumaro.components.dataset import Dataset
 from datumaro.components.dataset_base import DatasetItem
 from datumaro.components.media import Image, Video, VideoFrame
-from datumaro.components.project import Project
 from datumaro.util.scope import Scope, on_exit_do, scope_add, scoped
 
 from ..requirements import Requirements, mark_requirement
@@ -225,89 +224,6 @@ def test_can_split_and_load(self, test_dir):
         compare_datasets(TestCase(), expected, actual)
 
 
-class ProjectTest:
-    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
-    def test_can_release_resources_on_exit(self, test_dir):
-        video_path = _make_sample_video(test_dir)
-        with Scope() as scope:
-            project_dir = scope.add(TestDir())
-
-            project = scope.add(Project.init(project_dir))
-
-            project.import_source(
-                "src",
-                osp.dirname(video_path),
-                "video_frames",
-                rpath=osp.basename(video_path),
-            )
-
-            assert len(project.working_tree.make_dataset()) == 4
-        assert not osp.exists(project_dir)
-
-    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
-    @scoped
-    def test_can_release_resources_on_remove(self, test_dir):
-        video_path = _make_sample_video(test_dir)
-        project_dir = scope_add(TestDir())
-
-        project = scope_add(Project.init(project_dir))
-
-        project.import_source(
-            "src",
-            osp.dirname(video_path),
-            "video_frames",
-            rpath=osp.basename(video_path),
-        )
-        project.commit("commit 1")
-
-        assert len(project.working_tree.make_dataset()) == 4
-        assert osp.isdir(osp.join(project_dir, "src"))
-
-        project.remove_source("src", keep_data=False)
-
-        assert not osp.exists(osp.join(project_dir, "src"))
-
-    @pytest.mark.xfail(
-        sys.platform == "win32",
-        reason="failing due to a file because it is being used by another process",
-    )
-    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
-    @scoped
-    def test_can_release_resources_on_checkout(self, test_dir):
-        video_path = _make_sample_video(test_dir)
-        project_dir = scope_add(TestDir())
-
-        project = scope_add(Project.init(project_dir))
-
-        src_url = osp.join(project_dir, "src")
-        src = Dataset.from_iterable(
-            [
-                DatasetItem(1),
-            ],
-            categories=["a"],
-        )
-        src.save(src_url)
-        project.add_source(src_url, "datumaro")
-        project.commit("commit 1")
-
-        project.remove_source("src", keep_data=False)
-
-        project.import_source(
-            "src",
-            osp.dirname(video_path),
-            "video_frames",
-            rpath=osp.basename(video_path),
-        )
-        project.commit("commit 2")
-
-        assert len(project.working_tree.make_dataset()) == 4
-        assert osp.isdir(osp.join(project_dir, "src"))
-
-        project.checkout("HEAD~1")
-
-        assert len(project.working_tree.make_dataset()) == 1
-
-
 @pytest.mark.new
 class VideoAnnotationTest:
     @mark_requirement(Requirements.DATUM_GENERAL_REQ)