Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion kauldron/optim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
"""Optimizers etc."""

# pylint: disable=g-importing-member

from kauldron.optim._freeze import partial_updates
from kauldron.optim._masks import exclude
from kauldron.optim._masks import select
from kauldron.optim.combine import named_chain
from kauldron.optim.transform import decay_to_init
# pylint: enable=g-importing-member
54 changes: 54 additions & 0 deletions kauldron/optim/_freeze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright 2024 The kauldron Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Freeze utils."""

from collections.abc import Callable
import functools
from typing import Any

import jax
import optax

_PyTree = Any


def partial_updates(
    optimizer: optax.GradientTransformation,
    mask: _PyTree | Callable[[_PyTree], _PyTree],
) -> optax.GradientTransformation:
  """Restricts `optimizer` to the parameters selected by `mask`.

  Leaves whose mask value is truthy are labeled `'train'` and handled by
  `optimizer`. All other leaves are labeled `'freeze'` and handled by
  `optax.set_to_zero()`.

  Args:
    optimizer: The optimizer used for the `'train'` leaves.
    mask: A tree of bools, or a callable which is invoked on the tree given to
      the label function and returns such a tree of bools.

  Returns:
    The `optax.multi_transform` combining both transformations.
  """
  transforms = {
      'train': optimizer,
      'freeze': optax.set_to_zero(),
  }
  # The labels are computed lazily, so a callable `mask` can be resolved
  # against the actual tree.
  label_fn = functools.partial(_make_labels, mask=mask)
  return optax.multi_transform(transforms, label_fn)


def _make_labels(tree, mask):
  """Converts a mask into a tree of `'train'` / `'freeze'` labels.

  Args:
    tree: The tree forwarded to `mask` when `mask` is callable. Unused
      otherwise.
    mask: A tree of bools, or a callable returning one when called on `tree`.

  Returns:
    A tree with the structure of the resolved mask, where truthy leaves become
    `'train'` and falsy leaves become `'freeze'`.
  """
  resolved_mask = mask(tree) if callable(mask) else mask
  return jax.tree.map(
      lambda selected: 'train' if selected else 'freeze',
      resolved_mask,
  )
58 changes: 58 additions & 0 deletions kauldron/optim/_freeze_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2024 The kauldron Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax.numpy as jnp
from kauldron import kd
import optax


def test_partial_updates():
  """Checks the generated labels and that the wrapped optimizer initializes."""
  lora_mask = kd.optim.select('lora')
  optimizer = kd.optim.partial_updates(
      optax.adam(learning_rate=1e-3),
      mask=lora_mask,
  )

  params = {
      'a': {
          'lora': {
              'x': jnp.zeros((2,)),
              'y': jnp.zeros((2,)),
          }
      },
      'x': jnp.zeros((2,)),
      'y': jnp.zeros((2,)),
  }

  # Only the leaves below a `lora` key are expected to be trained.
  expected_labels = {
      'a': {
          'lora': {
              'x': 'train',
              'y': 'train',
          }
      },
      'x': 'freeze',
      'y': 'freeze',
  }
  labels = kd.optim._freeze._make_labels(params, lora_mask)
  assert labels == expected_labels

  # TODO(epot): Could check the state params is empty for frozen params.
  optimizer.init(params)
149 changes: 149 additions & 0 deletions kauldron/optim/_masks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright 2024 The kauldron Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Masks utils."""

from collections.abc import Callable, Sequence
import re
from typing import Any

import jax

_PyTree = Any


# Improvements:
# * Could add `exclude=` kwargs, similar to `glob()`.


def select(pattern: str | Sequence[str]) -> Callable[[_PyTree], _PyTree]:
  r"""Builds a mask factory keeping only the sub-trees matching `pattern`.

  Each leaf path is rendered as its keys joined by `.` and searched against
  every pattern:

  * `xx` matches any `{'xx': ...}` dict, anywhere inside the tree. The match is
    strict, so `xx` does NOT match `{'xxyy': ...}`.
  * `xx.yy` matches `{'xx': {'yy': ...}}`.
  * Regexes are supported. When using a regex, make sure to escape `.` (e.g.
    `xx\.yy[0-9]+`).

  Example:

  ```python
  mask_fn = kd.optim.select("lora")

  mask_fn({
      'layer0': {
          'lora': {
              'a': jnp.zeros(),
              'b': jnp.zeros(),
          },
          'weights': jnp.zeros(),
          'bias': jnp.zeros(),
      }
  }) == {
      'layer0': {
          'lora': {
              'a': True,
              'b': True,
          },
          'weights': False,
          'bias': False,
      }
  }
  ```

  Args:
    pattern: A pattern, or sequence of patterns, to include. A leaf is `True`
      when any pattern matches its path. Everything else is `False`.

  Returns:
    The optax mask factory.
  """
  # Normalize to a list of patterns and compile them once, up-front.
  patterns = [pattern] if isinstance(pattern, str) else pattern
  compiled = [_make_regex(p) for p in patterns]

  def _is_selected(path: jax.tree_util.KeyPath) -> bool:
    dotted_path = ".".join(map(_jax_key_entry_to_str, path))
    for regex in compiled:
      if regex.search(dotted_path):
        return True
    return False

  def _make_mask(tree: _PyTree) -> _PyTree:
    # TODO(epot): Replace by `jax.tree.flatten_with_path` once Colab is updated
    paths_and_leaves, treedef = jax.tree_util.tree_flatten_with_path(tree)
    # One bool per leaf, then restore the original tree structure.
    flags = [_is_selected(path) for path, _ in paths_and_leaves]
    return jax.tree.unflatten(treedef, flags)

  return _make_mask


def exclude(pattern: str | Sequence[str]) -> Callable[[_PyTree], _PyTree]:
  """Builds a mask factory selecting everything except what matches `pattern`.

  This is the inverse of `select()`: each leaf is `True` exactly when
  `select(pattern)` would give `False`.

  Example:

  ```python
  optax.masked(
      optax.set_to_zero(),
      kd.optim.exclude("lora"),  # Only `lora` weights are trained.
  )
  ```

  Args:
    pattern: The pattern to exclude. See `select()` for more details.

  Returns:
    The optax mask factory.
  """
  select_mask_fn = select(pattern)

  def _make_mask(tree: _PyTree) -> _PyTree:
    selected = select_mask_fn(tree)
    # Flip every leaf of the select mask.
    return jax.tree.map(lambda flag: not flag, selected)

  return _make_mask


_REGEX_SPECIAL_CHARS = set("()[]?+*^$|\\")


def _make_regex(pattern: str) -> re.Pattern[str]:
  r"""Compiles `pattern` into a regex matching whole `.`-separated path parts.

  A pattern containing any character of `_REGEX_SPECIAL_CHARS` is treated as a
  regex and forwarded as-is. Any other pattern is escaped with `re.escape`, so
  its `.` only matches a literal `.`.

  The result is anchored on both sides by the start/end of the string or a
  `.`, so `xx` matches `xx` and `a.xx.b` but not `xxyy`.

  Args:
    pattern: The plain path or regex to match.

  Returns:
    The compiled regex, meant to be used with `.search()`.
  """
  # Auto-detect regex and forward them as-is. Otherwise, escape special
  # characters (`.`).
  if not any(c in _REGEX_SPECIAL_CHARS for c in pattern):
    pattern = re.escape(pattern)

  # The non-capturing group keeps a top-level `|` inside the user pattern from
  # splitting the anchors: without it, `a|b` would compile to
  # `(?:^|\.)a` OR `b(?:$|\.)` and match `ax` or `xb`.
  return re.compile(rf"(?:^|\.)(?:{pattern})(?:$|\.)")


def _jax_key_entry_to_str(
jax_key_entry: jax.tree_util.KeyEntry,
) -> str:
"""Convert a JaxKeyEntry into a valid `kontext.Path` element."""
match jax_key_entry:
case jax.tree_util.DictKey(key):
return key
case _:
raise TypeError(f"Unknown key entry type {type(jax_key_entry)}")
119 changes: 119 additions & 0 deletions kauldron/optim/_masks_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2024 The kauldron Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from kauldron import kd
from kauldron.optim import _masks


def test_select():
  """Checks `select` / `exclude` on flat, dotted, nested and regex patterns."""
  # Check the regex is restricted to the exact path.
  flat_tree = {
      "lora": 0,
      "notlora": 0,
      "lora.more": 0,
      "loranot.more": 0,
      "notlora.more": 0,
      "more.lora": 0,
      "more.notlora": 0,
      "more.lora.more": 0,
      "more.notlora.more": 0,
  }
  lora_selected = {
      "lora": True,
      "notlora": False,
      "lora.more": True,
      "loranot.more": False,
      "notlora.more": False,
      "more.lora": True,
      "more.notlora": False,
      "more.lora.more": True,
      "more.notlora.more": False,
  }
  assert kd.optim.select("lora")(flat_tree) == lora_selected

  # Exclude returns the opposite mask.
  lora_excluded = {key: not flag for key, flag in lora_selected.items()}
  assert kd.optim.exclude("lora")(flat_tree) == lora_excluded

  # Test that a `.` in the path is properly escaped.
  dotted_mask = kd.optim.select("lora.more")({
      "lora": 0,
      "loraxmore": 0,
      "lora.more": 0,
      "more.loraxmore.more": 0,
      "more.lora.more.more": 0,
  })
  assert dotted_mask == {
      "lora": False,
      "loraxmore": False,
      "lora.more": True,
      "more.loraxmore.more": False,
      "more.lora.more.more": True,
  }

  # Test that the select works on nested tree
  nested_mask = kd.optim.select("lora.more")({
      "lora": {
          "more": {
              "x": 0,
              "y": 0,
          },
          "notmore": 0,
      },
      "y": {"lora": {"more": 0}},
      "z": 0,
  })
  assert nested_mask == {
      "lora": {
          "more": {
              "x": True,
              "y": True,
          },
          "notmore": False,
      },
      "y": {"lora": {"more": True}},
      "z": False,
  }

  # Tests that regex patterns are detected and forwarded as-is.
  regex_mask = kd.optim.select("lora[0-9]+")({
      "lora00": 0,
      "lora1": 0,
      "lora1x": 0,
      "xx.lora": 0,
      "xx.lora3.aa": 0,
  })
  assert regex_mask == {
      "lora00": True,
      "lora1": True,
      "lora1x": False,
      "xx.lora": False,
      "xx.lora3.aa": True,
  }