Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions kauldron/kontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from kauldron.kontext.annotate import resolve_from_keyed_obj
from kauldron.kontext.annotate import resolve_from_keypaths
from kauldron.kontext.filter_utils import filter_by_path
from kauldron.kontext.glob_paths import get_by_glob_path
from kauldron.kontext.glob_paths import GlobPath
from kauldron.kontext.glob_paths import set_by_path
from kauldron.kontext.path_builder import path_builder_from
Expand Down
109 changes: 94 additions & 15 deletions kauldron/kontext/glob_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def set_by_path(
obj: paths.Context,
path: str | tuple[str, ...] | paths.AbstractPath,
value: Any,
):
"""Mutate the `obj` to set the value."""
) -> list[str]:
"""Mutate the `obj` to set the value. Returns the list of modified paths."""
match path:
case str():
path = GlobPath.from_str(path) # Otherwise, try parsing key as path.
Expand All @@ -47,17 +47,46 @@ def set_by_path(
return path.set_in(obj, value)


def get_by_glob_path(
obj: paths.Context,
path: str | tuple[str, ...] | paths.AbstractPath,
) -> dict[str, Any]:
"""Get values matching a (potentially glob) path."""
match path:
case str():
glob_path = GlobPath.from_str(path)
case tuple() as parts:
glob_path = GlobPath(*parts)
case GlobPath():
glob_path = path
case paths.AbstractPath():
glob_path = GlobPath(*path.parts)
case _:
raise TypeError(f"Unknown key/path {path} of type{type(path)}")

return glob_path.get_from(obj)


class GlobPath(paths.AbstractPath):
"""Represents a string path."""

_SUPPORT_GLOB = True

def set_in(self, context: paths.Context, value: Any) -> None:
"""Set the object in the path."""
def set_in(self, context: paths.Context, value: Any) -> list[str]:
"""Set the object in the path. Returns the list of modified paths."""
try:
_set_in(context, self.parts, value)
return _set_in(context, self.parts, value)
except Exception as e: # pylint: disable=broad-exception-caught
epy.reraise(e, prefix=f"Error trying to mutate path {self}: ")
raise # Unreachable, but helps type checkers.

def get_from(self, context: paths.Context) -> dict[str, Any]:
"""Get values matching the glob path."""
try:
return _get_in(context, self.parts)
except Exception as e: # pylint: disable=broad-exception-caught
epy.reraise(e, prefix=f"Error trying to get path {self}: ")
raise # Unreachable, but helps type checkers.

@property
def first_non_glob_parent(self) -> paths.Path:
Expand Down Expand Up @@ -222,8 +251,9 @@ def _set_in(
value: Any,
*,
missing_ok: bool = False,
) -> bool:
"""Recursively set the value from the path."""
_prefix: tuple[paths.Part, ...] = (),
) -> list[str]:
"""Recursively set the value from the path. Returns modified paths."""
# Field reference are resolved in the config.
if not parts:
raise ValueError("Path is empty")
Expand All @@ -233,23 +263,72 @@ def _set_in(

# During glob, the object might contains branch which do not match
if missing_ok and part not in wrapper:
return # Leaf not found, do not assign this branch # pytype: disable=bad-return-type
return []

if not rest: # Nothing left to recurse on, assign the leaf.
if isinstance(part, path_parser.Wildcard):
# I don't think there's a use-case for this. For glob, this would create
# ambiguity too.
raise ValueError("Wildcards cannot be located at the end of a path.")
wrapper[part] = value
return [str(paths.Path(*(_prefix + (part,))))]
elif part == path_parser.Wildcard.DOUBLE_STAR:
modified = []
# Try to assign the rest in the current context
_set_in(context, rest, value, missing_ok=True)
modified.extend(
_set_in(context, rest, value, missing_ok=True, _prefix=_prefix)
)
if isinstance(wrapper, Leaf): # Leaf, do not recurse
return
return modified
# Recurse over all elements
for _, new_context in wrapper.get_items(path_parser.Wildcard.STAR):
_set_in(new_context, parts, value) # Propagate the `**` to the leaves
for key, new_context in wrapper.get_items(path_parser.Wildcard.STAR):
modified.extend(
_set_in(new_context, parts, value, _prefix=_prefix + (key,))
)
return modified
else: # Otherwise, recurse.
for _, new_context in wrapper.get_items(part):
_set_in(new_context, rest, value) # pytype: disable=bad-return-type
# TODO(epot): Reraise with the full path branch in which the error occured
modified = []
for key, new_context in wrapper.get_items(part):
modified.extend(
_set_in(new_context, rest, value, _prefix=_prefix + (key,))
)
return modified


def _get_in(
context: paths.Context,
parts: Sequence[paths.Part],
*,
missing_ok: bool = False,
_prefix: tuple[paths.Part, ...] = (),
) -> dict[str, Any]:
"""Recursively get values matching the glob path."""
if not parts:
raise ValueError("Path is empty")

wrapper = Node.make(context)
part, *rest = parts

if missing_ok and part not in wrapper:
return {}

if not rest:
if part == path_parser.Wildcard.DOUBLE_STAR:
raise ValueError("'**' cannot be at the end of a get path.")
elif part == path_parser.Wildcard.STAR:
return {str(paths.Path(*(_prefix + (k,)))): v for k, v in wrapper.items()}
else:
return {str(paths.Path(*(_prefix + (part,)))): wrapper[part]}
elif part == path_parser.Wildcard.DOUBLE_STAR:
result = {}
result.update(_get_in(context, rest, missing_ok=True, _prefix=_prefix))
if isinstance(wrapper, Leaf):
return result
for key, child in wrapper.get_items(path_parser.Wildcard.STAR):
result.update(_get_in(child, parts, _prefix=_prefix + (key,)))
return result
else:
result = {}
for key, child in wrapper.get_items(part):
result.update(_get_in(child, rest, _prefix=_prefix + (key,)))
return result
66 changes: 65 additions & 1 deletion kauldron/kontext/glob_paths_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def _assert_path(
ctx,
expected_ctx=None,
expected_error=None,
expected_paths=None,
) -> None:
path = kontext.GlobPath.from_str(path_str)
assert str(path) == path_str
Expand All @@ -32,8 +33,10 @@ def _assert_path(
with pytest.raises(type(expected_error), match=expected_error.args[0]):
path.set_in(ctx, 'new')
else:
path.set_in(ctx, 'new')
modified = path.set_in(ctx, 'new')
assert ctx == expected_ctx
if expected_paths is not None:
assert sorted(modified) == sorted(expected_paths)


def test_glob_paths():
Expand All @@ -54,6 +57,7 @@ def test_glob_paths():
},
'b2': 5,
},
expected_paths=['a.b[0]', 'a.b2[0]'],
)
_assert_path(
path_str='**.b',
Expand All @@ -73,6 +77,7 @@ def test_glob_paths():
'b2': 5,
'b': 'new',
},
expected_paths=['b', 'a.b'],
)
_assert_path(
path_str='a.**[0]',
Expand Down Expand Up @@ -102,6 +107,7 @@ def test_glob_paths():
'a1': 5,
'a2': [1, 2, 3],
},
expected_paths=['a.b[0]', 'a.b2[0]', 'a.b3.c2[0]'],
)
_assert_path(
path_str='a.**.b2[0]',
Expand Down Expand Up @@ -131,6 +137,7 @@ def test_glob_paths():
'a1': {'b2': [1, 2, 3]},
'a2': [1, 2, 3],
},
expected_paths=['a.b2[0]', 'a.b3.b2[0]'],
)


Expand All @@ -152,6 +159,7 @@ def test_config_dict():
konfig.ConfigDict({'b': 'new'}),
],
}),
expected_paths=['a[0].b', 'a[1].b', 'a[2].b'],
)


Expand Down Expand Up @@ -199,6 +207,7 @@ def test_glob_key_error():
'a2': {'b': 'new', 'b2': 5},
'a3': {'b': 'new'},
},
expected_paths=['a.b', 'a2.b', 'a3.b'],
)


Expand All @@ -216,3 +225,58 @@ def test_glob_first_non_glob_parent(glob_str, parent_str):
assert kontext.GlobPath.from_str(
glob_str
).first_non_glob_parent == kontext.Path.from_str(parent_str)


def test_set_by_path_returns_modified_paths():
ctx = {'a': {'b': [1, 2], 'c': [3, 4]}, 'z': 5}
modified = kontext.set_by_path(ctx, 'a.*[0]', 'new')
assert sorted(modified) == ['a.b[0]', 'a.c[0]']

ctx = {'x': {'y': 'old'}, 'z': 'old2'}
modified = kontext.set_by_path(ctx, 'x.y', 'new')
assert modified == ['x.y']

ctx = {'a': {'b': 'old', 'c': {'b': 'old'}}, 'b': 'old'}
modified = kontext.set_by_path(ctx, '**.b', 'new')
assert sorted(modified) == ['a.b', 'a.c.b', 'b']


def test_get_by_glob_path():
ctx = {
'a': {
'b': [1, 2, 3],
'c': [4, 5, 6],
},
'd': 7,
}
assert kontext.get_by_glob_path(ctx, 'a.b') == {'a.b': [1, 2, 3]}

assert kontext.get_by_glob_path(ctx, 'a.*') == {
'a.b': [1, 2, 3],
'a.c': [4, 5, 6],
}

assert kontext.get_by_glob_path(ctx, 'a.*[0]') == {
'a.b[0]': 1,
'a.c[0]': 4,
}


def test_get_by_glob_path_double_star():
ctx = {
'a': {
'b': 'v1',
'c': {'b': 'v2'},
},
'b': 'v3',
}
assert kontext.get_by_glob_path(ctx, '**.b') == {
'b': 'v3',
'a.b': 'v1',
'a.c.b': 'v2',
}


def test_get_by_glob_path_double_star_end_error():
with pytest.raises(ValueError, match='cannot be at the end'):
kontext.get_by_glob_path({'a': 1}, 'a.**')
7 changes: 4 additions & 3 deletions kauldron/kontext/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class AbstractPath(collections.abc.Sequence):

__slots__ = ("parts",)

_SUPPORT_GLOB: ClassVar[bool]
_SUPPORT_GLOB: ClassVar[bool] # pylint: disable=declare-non-slot

def __init__(self, *parts: Part):
if not _is_valid_part(parts, wildcard_ok=self._SUPPORT_GLOB):
Expand Down Expand Up @@ -117,7 +117,7 @@ def from_jax_path(cls, jax_path: tuple[JaxKeyEntry, ...]) -> Self:
"""
return cls(*(_jax_key_entry_to_kd_path_element(p) for p in jax_path))

def set_in(self, context: Context, value: Any) -> None:
def set_in(self, context: Context, value: Any) -> list[str]:
raise NotImplementedError("Abstract method")

def relative_to(self, other: AbstractPath) -> Self:
Expand Down Expand Up @@ -169,7 +169,7 @@ def get_from(
return default
return result

def set_in(self, context: Context, value: Any) -> None:
def set_in(self, context: Context, value: Any) -> list[str]:
"""Set the object in the path."""
root = context

Expand All @@ -183,6 +183,7 @@ def set_in(self, context: Context, value: Any) -> None:
)

root[target] = value
return [str(self)]


def _jax_key_entry_to_kd_path_element(
Expand Down
Loading