Releases: google-research/kauldron
v1.4.1
- `kd.ktyping`:
  - [Fix] Fix `PyTree[T]` traversing into registered types (e.g. `flax.struct.dataclass`).
  - [Fix] Fix `as_np_dtype` lookup when `torchapix` is imported.
  - [Fix] Fix `UNKNOWN_DIM` formatting in error messages.
- `kd.konfig`:
  - [Fix] Add hint to import errors if they look like the internal repo prefix is missing.
  - [Fix] Update error message propagation in `module_configdict`.
- `kd.random`:
  - [Extended] Support `kd.random.PRNGKey(seed)`.
- `kd.cli`:
  - [New] Add `run eval_shape` command.
  - [Changed] Refactor/beautify CLI help aesthetics.
- `kd.contrib`:
  - [New] Add a configurable library for `kd.contrib.data.SSTable`.
v1.4.0
Highlights:
- new `kd.cli` tool
- `py::` flag syntax for konfig
- expanded `kd.ktyping` with `@typechecked` context managers / dataclass / generator support
- meta-configs with `@dataclasses.dataclass`
- new Nnx wrapper API
- many quality-of-life improvements across data, evals, metrics, and checkpointing
- `kd.cli`:
  - [New] `kd.cli`: new kauldron CLI tool, a `:trainer_cli` binary automatically created by `kauldron_binary`, using noun-verb command style (e.g. `kd data element_spec`), with each command mirrored as a Python function.
- `kd.konfig`:
  - [New] `py::` flag value parsing: specify Python objects directly from CLI flags (e.g. `--cfg.xxx="py::my_module.MyObject(x=1)"`), with Lark grammar and alias resolution.
  - [New] Meta-configs: `@dataclasses.dataclass`-based config declaration with `__args__` CLI overrides and lazy config building.
  - [New] `konfig.export()`: serialize Python objects (dataclasses, arrays, dicts) to a dict representation.
  - [New] `unfreeze()` function for unfreezing `ImmutableDict`s.
  - [Extended] Support (named) tuples as dictionary keys in serialization.
  - [Extended] `DEFINE_config_file` accepts a `required` argument.
  - [Extended] `konfig.resolve` highlights where the original `ConfigDict` was created in tracebacks.
  - [Extended] Better error messages for config resolution failures and `FieldReference` with path tracking.
  - [Extended] Allow `konfig.restricted()` without specifying type.
  - [Changed] Always use literal evals when parsing flags: `--cfg.xxx=None` is now `None` rather than `'None'`.
  - [Changed] Two-stage resolution is now the default for `DEFINE_config_file`.
  - [Changed] Deprecated konfig property now raises an error.
  - [Fix] Fix JSON args parsing, list arguments from CLI, `unfreeze` bug, dynamic resolve trigger, `temporary_imports()` thread-safety, `BaseConfig` hash, and resolve errors.
- `kd.ktyping`:
  - [New] `ArraySpec`, `ElementSpec`, `PRNGKeyLike` types.
  - [New] `kt.isinstance` function for bool-returning type checks.
  - [New] Basic `PyTree[T]` annotation with runtime checking and path-aware errors.
  - [New] PyTree structure specs.
  - [New] Per-module config system.
  - [New] Warnings when mixing ktyping and jaxtyping.
  - [Extended] `@typechecked` now supports: context managers (`with typechecked():`), nested context managers, dataclasses, generator functions, methods / class methods / static methods.
  - [Extended] Improved shape inference for binary operations (e.g. `Array["a+1"]`).
  - [Extended] TensorFlow and XArray type support.
  - [Breaking] Rename `get_shape()` → `shape()` and `kt.dims` → `kt.dim`.
  - [Fix] Fix shape checking with TF Tensors, PRNGKey dtype for new-style JAX keys, Scalar type checking, array type union checking, broadcastable dims, typeguard 4.5.0 compatibility.
- `kd.nn`:
  - [Changed] New Nnx wrapper API: natively compatible with kontext keys, supports catching intermediates.
  - [New] Nnx wrapper documentation.
- `kd.data`:
  - [New] `LazyBagDataSource` for lazy loading of bag data.
  - [New] `SelectFromDatasets` for dataset mixtures with user-defined selection.
  - [New] `shard_by_process` to control dataset sharding behavior.
  - [New] `AddBias` transform.
  - [New] Random transforms for PyGrain pipelines.
  - [New] Padding batches with `batch_drop_remainder='pad'`.
  - [Extended] `CenterCrop` supports nD arrays.
  - [Extended] `Resize` supports min/max size targets.
  - [Extended] `RepeatFrames` works with both TF and NumPy/JAX arrays.
  - [Changed] Default `Resize` method for float inputs is now `bilinear` for JAX/NumPy (remains `area` for TF).
  - [Fix] Fix `ElementWiseRandomTransform`, `grain.shuffle` seed range, unknown-length datasets, `Tfds.decoders` with `ImmutableDict`, `Resize` device transfer, `element_spec` global vs device-local, filter transform type checking, walrus operator breaking TF autograph.
- `kd.kontext`:
  - [Extended] `set_by_path` returns the list of concrete modified paths when using glob patterns (`**`, `*`).
  - [Fix] Fix `kontext.imports()` errors in docs and `CONFIG_IMPORT` placeholder for Colab.
- `kd.train`:
  - [New] `NoOpTrainStep` for use cases skipping training.
  - [New] `checkify` support on `TrainStep.init` and `Evaluator.evaluate`.
  - [New] Expose `KDMetricWriter`, `Orchestrator`, `DirectoryBuilder` as public APIs for subclassing.
  - [New] `konfig_freeze` option to skip immutabledict conversion.
  - [Extended] `MultiTrainStep` subupdates for better logging.
  - [Extended] Device-to-host transfer for checkify error checking.
  - [Breaking] Rename `ShardingStrategy.ds` → `ShardingStrategy.batch`.
  - [Removed] Deprecate `CollectingState`.
  - [Fix] Fix sweeps bug with default `config_args`, `partial_updates` with integer keys, `transfer_guard` with `jax_debug_nans`, `MultiTrainStep` hashability, `ml_python` + `adhoc` error, `FSDPSharding` type annotation.
- `kd.evals`:
  - [New] `CheckpointedEvaluator` for resumable evaluations.
  - [New] Skip initial step 0 option.
  - [New] `NoopExporter`.
  - [New] `eval_step` added to `Evaluator`.
  - [Extended] Allow skipping checkpointing in `TrainEvaluator`.
  - [Changed] `NoOpCheckpointer` is default for `SamplingEvaluator`.
  - [Changed] `_ConcatContainer` speeds up `concat_field` aggregation.
  - [Fix] Fix non unresponsive with custom dataset in eval, duplicated `job_group` in eval_only.
- `kd.metrics`:
  - [New] `finalize` method for metric states.
  - [Extended] Support predicted labels (not just logits) in `Accuracy`.
  - [Extended] `min_field`, `max_field` for `AutoState`.
  - [Extended] Pytree support for `auto_state.sum_field`.
  - [Extended] Better error reporting for merging / finalizing / computing.
  - [Fix] Fix one-hot class count in segmentation metrics, `finalize()` bugs, `CollectingState.merge` performance.
- `kd.summaries` / `kd.vizual`:
  - [New] Confusion matrix summary.
  - [Extended] `ShowSegmentations`: `palette`, `edge`, and `hard` options.
  - [Extended] `ShowImages`: `cmap` option.
  - [Extended] `ImageGrid` convenience method.
  - [Fix] Fix `ShowDifferenceImages` type-check and JAX/numpy mismatch, `ShowImages` RGB output with NaN values, `bfloat16` support, integer arrays in `ShowSegmentations`.
- `kd.optim`:
  - [New] `ema_weights` wrapper for EMA weight tracking.
  - [Fix] Fix debias logic in `ema_params`.
- `kd.ckpts`:
  - [New] Custom Orbax preservation policy support.
  - [Removed] Remove deprecated `AbstractPartialLoader` alias.
  - [Fix] Fix EMA params loading for frozen params, checkpoint loading, snapshot directory race conditions, named tuple compatibility in parameter paths.
- `kd.contrib`:
  - [New] `NpzWriter`: metric writer saving array summaries to `.npz` files.
  - [New] `TreeUnflattenForKey` PyGrain transform.
  - [New] `GifVideoWriter` and `ShowVideosAsGif` for GIF video summaries.
  - [New] NNX-to-Linen wrapper `linen_from_nnx()`.
  - [New] Model exporter for JAX export.
  - [New] Online Mean+Covariance estimation state, `merge_field` in auto-state.
  - [Extended] `concat_field` works with pytrees.
- `kd.contrib.millstone`:
  - [New] New doc.
  - [Extended] Custom Borg runtime, eval dataset support, troubleshooting guide.
  - [Removed] Delete deprecated Millstone API.
  - [Fix] Fix Pathways server termination during eval.
- `kd.xm`:
  - [New] `jax_log_compiles` configurable via `xp.debug.jax_log_compiles`.
  - [Extended] Launch configargs support.
  - [Fix] Fix `cuda_compress` flag for non-GPU builds, duplicated `job_group`.
- `kd.random`:
  - [Changed] Move truncation of `as_seed()` to uint32 inside `as_seed()`.
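The `kd.konfig` entry above about literal evals (`--cfg.xxx=None` becoming `None` rather than `'None'`) can be illustrated with a minimal sketch built on the standard library's `ast.literal_eval`. This is not Kauldron's implementation: the helper name `parse_flag_value` is hypothetical, and the fallback to the raw string for non-literal input is an assumption.

```python
import ast


def parse_flag_value(raw: str):
  """Parses a raw CLI string as a Python literal (hypothetical helper)."""
  try:
    return ast.literal_eval(raw)
  except (ValueError, SyntaxError):
    # Assumption: input that is not a valid literal is kept as a string.
    return raw


for raw in ["None", "32", "[1, 2, 3]", "adam"]:
  value = parse_flag_value(raw)
  print(f"{raw!r} -> {value!r} ({type(value).__name__})")
```

With this kind of parsing, a flag value written as `None` reaches the config as the `None` object, while a bare word such as `adam` stays a string.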
v1.3.0
- Various bug fixes and improvements
v1.2.2
- `--xp.debug.catch_post_mortem` flag now works externally as well
- Fixed a problem with `init_transforms` that affected `optim.decay_to_init`
- Further removed deprecated summaries protocol
- Fixes regarding `grain` workers and non-thread-safe imports such as `einops`
- Lifted jit restriction for merging `auto_state.sum_field`
- Added support for `ImmutableDict` inside `konfig`
- Added `ShowTexts` summary
- Make grain an optional dependency on Windows
- Reduced logging noise
- Several other minor bugfixes
v1.2.1
- Minor bug fixes.
v1.2.0
- Fix `kd.sharding.FSDPSharding()` to support `jax.ShapeDtypeStruct`
- `kd.data`:
  - `kd.data.py.PyGrainPipeline` supports direct indexing (`ds[0]`).
  - `kd.data.py.HuggingFace` supports
- Typeguard / typechecking
- +various changes and improvements
v1.1.1
- Restore numpy 1.26 compatibility
v1.1.0
- Add `kd.nn.WrapperModule` to make an inner module transparent with respect to Flax modules.
- Many other changes...
v1.0.0
- `kd.kontext.Path` now supports tensor slicing. For example, keys like `"interm.tensor[..., 0:10, :, -1]"` will now work as expected.
- `kd.nn.interm_property` now supports accessing any intermediates from within the model via `self.interm.get_by_path('path.to.any.module.__call__[0]')`.
- Deprecated: Remove `--xp.sweep_info.names=` flag. Instead, sweeps are unified under `--xp.sweep` (see: https://kauldron.rtfd.io/en/latest/intro.html#sweeps)
- Add `kd.data.loader.TFData` for arbitrary `tf.data` pipelines
- Add `kd.data.InMemoryPipeline` for small datasets that fit in memory
- Add `kd.knn.convert` to convert any Flax module to klinen.
- Add `kontext.path_builder_from` to dynamically generate keys for the config with auto-complete and static type checking.
- Add `kd.data.BatchSize(XX)` util
- Breaking: `Evaluator(run_every=XX)` kwarg is removed. To migrate, use `Evaluator(run=kd.evals.RunEvery(XX))`
- Added: Eval can now be launched in a separate job:

  ```python
  cfg.evals = {
      'eval_train': kd.evals.Evaluator(
          run=kd.evals.RunEvery(100),  # Run along `train`
      ),
      'eval_eval': kd.evals.Evaluator(
          run=kd.evals.RunXM(),  # Run in a separate `eval` job.
      ),
  }
  ```
- New XManager launcher

  ```sh
  xmanager launch third_party/py/kauldron/xm/launch.py -- \
    --cfg=third_party/py/kauldron/examples/mnist_autoencoder.py \
    --cfg.train_ds.batch_size=32 \
    --xp.sweep \
    --xp.platform=a100 \
    --xp.debug.catch_post_mortem
  ```

  This unlocks many new features:

  - Based on `konfig` (so everything can be deeply configured).
  - Customize the work-unit directory name, default to `{xid}/{wid}-{sweep_kwargs}`, for better TensorBoard work-unit names.
  - Sweep on XManager architecture:

    ```python
    def sweep():
      for platform in ['a100', 'v100']:
        yield {'cfg.xm_job': kxm.Job(platform=platform)}
    ```

  - Possibility to launch eval jobs in a separate job
  - `ml_python` & xreload support for much faster XM iteration cycles
  - New `kd-xm` colab to quickly launch experiments without even having to open a terminal
- Changed: removed `Checkpointer.partial_initializer` and instead added `cfg.init_transforms`, which can be used to set multiple transformations for the params of the model (i.e. instances of `AbstractPartialLoader`).
- Changed: `konfig.imports()` is no longer lazy by default (configs don't need to be resolved inside `with ecolab.adhoc()` anymore!)
- Added:
  - `kd.optim`: Optimizer / optax utils
  - `kd.eval`: Eval moved to its own namespace
- Changed: Resolved konfig can now use attribute access for dict:
  - Before (still supported): `cfg.train_losses['my_loss']`
  - After: `cfg.train_losses.my_loss`
- Added: `kd.nn.set_train_property` to change the `self.is_training` property value inside a model:

  ```python
  class MyModule(nn.Module):

    @nn.compact
    def __call__(self, x):
      with kd.nn.set_train_property(False):
        x = self.pretrained_encoder(x)
  ```
- Added: `kd.nn.ExternalModule(flax_module)` to use any external flax modules inside Kauldron.
- And many, many more changes...
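To illustrate the "attribute access for dict" change listed above, here is a minimal sketch of how a mapping can expose its keys as attributes while keeping item access working. The `AttrDict` class is hypothetical and only mirrors the behavior described in the notes; it is not how konfig implements it.

```python
class AttrDict(dict):
  """Hypothetical dict that also exposes its keys as attributes."""

  def __getattr__(self, name):
    # Only called when normal attribute lookup fails, so dict methods
    # such as `.keys()` keep working as usual.
    try:
      return self[name]
    except KeyError as e:
      raise AttributeError(name) from e


train_losses = AttrDict(my_loss=0.25)

# Both access styles return the same value.
assert train_losses['my_loss'] == train_losses.my_loss
```

Raising `AttributeError` for missing keys keeps `hasattr()` and `getattr(..., default)` behaving as expected on such an object.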