diff --git a/kauldron/kontext/__init__.py b/kauldron/kontext/__init__.py index f599d9f0..0430aef5 100644 --- a/kauldron/kontext/__init__.py +++ b/kauldron/kontext/__init__.py @@ -28,6 +28,7 @@ from kauldron.kontext.annotate import resolve_from_keyed_obj from kauldron.kontext.annotate import resolve_from_keypaths from kauldron.kontext.filter_utils import filter_by_path +from kauldron.kontext.glob_paths import get_by_glob_path from kauldron.kontext.glob_paths import GlobPath from kauldron.kontext.glob_paths import set_by_path from kauldron.kontext.path_builder import path_builder_from diff --git a/kauldron/kontext/glob_paths.py b/kauldron/kontext/glob_paths.py index 206dbe36..95e51e5a 100644 --- a/kauldron/kontext/glob_paths.py +++ b/kauldron/kontext/glob_paths.py @@ -32,8 +32,8 @@ def set_by_path( obj: paths.Context, path: str | tuple[str, ...] | paths.AbstractPath, value: Any, -): - """Mutate the `obj` to set the value.""" +) -> list[str]: + """Mutate the `obj` to set the value. Returns the list of modified paths.""" match path: case str(): path = GlobPath.from_str(path) # Otherwise, try parsing key as path. @@ -47,17 +47,46 @@ def set_by_path( return path.set_in(obj, value) +def get_by_glob_path( + obj: paths.Context, + path: str | tuple[str, ...] | paths.AbstractPath, +) -> dict[str, Any]: + """Get values matching a (potentially glob) path.""" + match path: + case str(): + glob_path = GlobPath.from_str(path) + case tuple() as parts: + glob_path = GlobPath(*parts) + case GlobPath(): + glob_path = path + case paths.AbstractPath(): + glob_path = GlobPath(*path.parts) + case _: + raise TypeError(f"Unknown key/path {path} of type{type(path)}") + + return glob_path.get_from(obj) + + class GlobPath(paths.AbstractPath): """Represents a string path.""" _SUPPORT_GLOB = True - def set_in(self, context: paths.Context, value: Any) -> None: - """Set the object in the path.""" + def set_in(self, context: paths.Context, value: Any) -> list[str]: + """Set the object in the path. Returns the list of modified paths.""" try: - _set_in(context, self.parts, value) + return _set_in(context, self.parts, value) except Exception as e: # pylint: disable=broad-exception-caught epy.reraise(e, prefix=f"Error trying to mutate path {self}: ") + raise # Unreachable, but helps type checkers. + + def get_from(self, context: paths.Context) -> dict[str, Any]: + """Get values matching the glob path.""" + try: + return _get_in(context, self.parts) + except Exception as e: # pylint: disable=broad-exception-caught + epy.reraise(e, prefix=f"Error trying to get path {self}: ") + raise # Unreachable, but helps type checkers. @property def first_non_glob_parent(self) -> paths.Path: @@ -222,8 +251,9 @@ def _set_in( value: Any, *, missing_ok: bool = False, -) -> bool: - """Recursively set the value from the path.""" + _prefix: tuple[paths.Part, ...] = (), +) -> list[str]: + """Recursively set the value from the path. Returns modified paths.""" # Field reference are resolved in the config. if not parts: raise ValueError("Path is empty") @@ -233,7 +263,7 @@ def _set_in( # During glob, the object might contains branch which do not match if missing_ok and part not in wrapper: - return # Leaf not found, do not assign this branch # pytype: disable=bad-return-type + return [] if not rest: # Nothing left to recurse on, assign the leaf. if isinstance(part, path_parser.Wildcard): @@ -241,15 +271,64 @@ def _set_in( # ambiguity too. raise ValueError("Wildcards cannot be located at the end of a path.") wrapper[part] = value + return [str(paths.Path(*(_prefix + (part,))))] elif part == path_parser.Wildcard.DOUBLE_STAR: + modified = [] # Try to assign the rest in the current context - _set_in(context, rest, value, missing_ok=True) + modified.extend( + _set_in(context, rest, value, missing_ok=True, _prefix=_prefix) + ) if isinstance(wrapper, Leaf): # Leaf, do not recurse - return + return modified # Recurse over all elements - for _, new_context in wrapper.get_items(path_parser.Wildcard.STAR): - _set_in(new_context, parts, value) # Propagate the `**` to the leaves + for key, new_context in wrapper.get_items(path_parser.Wildcard.STAR): + modified.extend( + _set_in(new_context, parts, value, _prefix=_prefix + (key,)) + ) + return modified else: # Otherwise, recurse. - for _, new_context in wrapper.get_items(part): - _set_in(new_context, rest, value) # pytype: disable=bad-return-type - # TODO(epot): Reraise with the full path branch in which the error occured + modified = [] + for key, new_context in wrapper.get_items(part): + modified.extend( + _set_in(new_context, rest, value, _prefix=_prefix + (key,)) + ) + return modified + + +def _get_in( + context: paths.Context, + parts: Sequence[paths.Part], + *, + missing_ok: bool = False, + _prefix: tuple[paths.Part, ...] = (), +) -> dict[str, Any]: + """Recursively get values matching the glob path.""" + if not parts: + raise ValueError("Path is empty") + + wrapper = Node.make(context) + part, *rest = parts + + if missing_ok and part not in wrapper: + return {} + + if not rest: + if part == path_parser.Wildcard.DOUBLE_STAR: + raise ValueError("'**' cannot be at the end of a get path.") + elif part == path_parser.Wildcard.STAR: + return {str(paths.Path(*(_prefix + (k,)))): v for k, v in wrapper.items()} + else: + return {str(paths.Path(*(_prefix + (part,)))): wrapper[part]} + elif part == path_parser.Wildcard.DOUBLE_STAR: + result = {} + result.update(_get_in(context, rest, missing_ok=True, _prefix=_prefix)) + if isinstance(wrapper, Leaf): + return result + for key, child in wrapper.get_items(path_parser.Wildcard.STAR): + result.update(_get_in(child, parts, _prefix=_prefix + (key,))) + return result + else: + result = {} + for key, child in wrapper.get_items(part): + result.update(_get_in(child, rest, _prefix=_prefix + (key,))) + return result diff --git a/kauldron/kontext/glob_paths_test.py b/kauldron/kontext/glob_paths_test.py index ec15453e..3eeaeb83 100644 --- a/kauldron/kontext/glob_paths_test.py +++ b/kauldron/kontext/glob_paths_test.py @@ -24,6 +24,7 @@ def _assert_path( ctx, expected_ctx=None, expected_error=None, + expected_paths=None, ) -> None: path = kontext.GlobPath.from_str(path_str) assert str(path) == path_str @@ -32,8 +33,10 @@ def _assert_path( with pytest.raises(type(expected_error), match=expected_error.args[0]): path.set_in(ctx, 'new') else: - path.set_in(ctx, 'new') + modified = path.set_in(ctx, 'new') assert ctx == expected_ctx + if expected_paths is not None: + assert sorted(modified) == sorted(expected_paths) def test_glob_paths(): @@ -54,6 +57,7 @@ def test_glob_paths(): }, 'b2': 5, }, + expected_paths=['a.b[0]', 'a.b2[0]'], ) _assert_path( path_str='**.b', @@ -73,6 +77,7 @@ def test_glob_paths(): 'b2': 5, 'b': 'new', }, + expected_paths=['b', 'a.b'], ) _assert_path( path_str='a.**[0]', @@ -102,6 +107,7 @@ def test_glob_paths(): 'a1': 5, 'a2': [1, 2, 3], }, + expected_paths=['a.b[0]', 'a.b2[0]', 'a.b3.c2[0]'], ) _assert_path( path_str='a.**.b2[0]', @@ -131,6 +137,7 @@ def test_glob_paths(): 'a1': {'b2': [1, 2, 3]}, 'a2': [1, 2, 3], }, + expected_paths=['a.b2[0]', 'a.b3.b2[0]'], ) @@ -152,6 +159,7 @@ def test_config_dict(): konfig.ConfigDict({'b': 'new'}), ], }), + expected_paths=['a[0].b', 'a[1].b', 'a[2].b'], ) @@ -199,6 +207,7 @@ def test_glob_key_error(): 'a2': {'b': 'new', 'b2': 5}, 'a3': {'b': 'new'}, }, + expected_paths=['a.b', 'a2.b', 'a3.b'], ) @@ -216,3 +225,58 @@ def test_glob_first_non_glob_parent(glob_str, parent_str): assert kontext.GlobPath.from_str( glob_str ).first_non_glob_parent == kontext.Path.from_str(parent_str) + + +def test_set_by_path_returns_modified_paths(): + ctx = {'a': {'b': [1, 2], 'c': [3, 4]}, 'z': 5} + modified = kontext.set_by_path(ctx, 'a.*[0]', 'new') + assert sorted(modified) == ['a.b[0]', 'a.c[0]'] + + ctx = {'x': {'y': 'old'}, 'z': 'old2'} + modified = kontext.set_by_path(ctx, 'x.y', 'new') + assert modified == ['x.y'] + + ctx = {'a': {'b': 'old', 'c': {'b': 'old'}}, 'b': 'old'} + modified = kontext.set_by_path(ctx, '**.b', 'new') + assert sorted(modified) == ['a.b', 'a.c.b', 'b'] + + +def test_get_by_glob_path(): + ctx = { + 'a': { + 'b': [1, 2, 3], + 'c': [4, 5, 6], + }, + 'd': 7, + } + assert kontext.get_by_glob_path(ctx, 'a.b') == {'a.b': [1, 2, 3]} + + assert kontext.get_by_glob_path(ctx, 'a.*') == { + 'a.b': [1, 2, 3], + 'a.c': [4, 5, 6], + } + + assert kontext.get_by_glob_path(ctx, 'a.*[0]') == { + 'a.b[0]': 1, + 'a.c[0]': 4, + } + + +def test_get_by_glob_path_double_star(): + ctx = { + 'a': { + 'b': 'v1', + 'c': {'b': 'v2'}, + }, + 'b': 'v3', + } + assert kontext.get_by_glob_path(ctx, '**.b') == { + 'b': 'v3', + 'a.b': 'v1', + 'a.c.b': 'v2', + } + + +def test_get_by_glob_path_double_star_end_error(): + with pytest.raises(ValueError, match='cannot be at the end'): + kontext.get_by_glob_path({'a': 1}, 'a.**') diff --git a/kauldron/kontext/paths.py b/kauldron/kontext/paths.py index 869a2efe..c030a903 100644 --- a/kauldron/kontext/paths.py +++ b/kauldron/kontext/paths.py @@ -56,7 +56,7 @@ class AbstractPath(collections.abc.Sequence): __slots__ = ("parts",) - _SUPPORT_GLOB: ClassVar[bool] + _SUPPORT_GLOB: ClassVar[bool] # pylint: disable=declare-non-slot def __init__(self, *parts: Part): if not _is_valid_part(parts, wildcard_ok=self._SUPPORT_GLOB): @@ -117,7 +117,7 @@ def from_jax_path(cls, jax_path: tuple[JaxKeyEntry, ...]) -> Self: """ return cls(*(_jax_key_entry_to_kd_path_element(p) for p in jax_path)) - def set_in(self, context: Context, value: Any) -> None: + def set_in(self, context: Context, value: Any) -> list[str]: raise NotImplementedError("Abstract method") def relative_to(self, other: AbstractPath) -> Self: @@ -169,7 +169,7 @@ def get_from( return default return result - def set_in(self, context: Context, value: Any) -> None: + def set_in(self, context: Context, value: Any) -> list[str]: """Set the object in the path.""" root = context @@ -183,6 +183,7 @@ def set_in(self, context: Context, value: Any) -> None: ) root[target] = value + return [str(self)] def _jax_key_entry_to_kd_path_element(