Conversation
This may be a temporary solution until #567 is addressed.
This deals with the weights, not the model.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
The reason to do this is that it's not recommended to keep converted weights. For now we could create a branch with bfloat16 weights and use that for inference.
I removed the automatic conversion but left the other change in order to be able to load the models specifying a `dtype`.
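Specifying a dtype on load can be sketched as a cast over the parameter pytree after the float32 checkpoint is read. This is a hypothetical helper for illustration (the actual PR threads the dtype through the model's `from_pretrained`), not the PR's implementation:

```python
import jax.numpy as jnp
from jax import tree_util

def cast_params(params, dtype):
    """Cast every floating-point leaf of a parameter pytree to `dtype`.

    Hypothetical helper sketching dtype-on-load; integer leaves
    (e.g. position ids) are left untouched.
    """
    def cast(leaf):
        if jnp.issubdtype(leaf.dtype, jnp.floating):
            return leaf.astype(dtype)
        return leaf
    return tree_util.tree_map(cast, params)

# Toy float32 "checkpoint" cast to bfloat16.
params = {"unet": {"kernel": jnp.ones((2, 2), dtype=jnp.float32)}}
bf16_params = cast_params(params, jnp.bfloat16)
```

Note that this only touches the weights, not the model definition, which matches the point above about the change dealing with the weights rather than the model.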
```python
    latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
else:
    if latents.shape != latents_shape:
```
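The diff lines above sample the initial latents in float32 regardless of the weight dtype. A minimal sketch of the surrounding logic, with variable names assumed from the diff context:

```python
import jax
import jax.numpy as jnp

prng_seed = jax.random.PRNGKey(0)
latents_shape = (1, 4, 64, 64)
latents = None  # a caller may pass pre-made latents instead

if latents is None:
    # Sample in float32 so the scheduler arithmetic keeps full precision,
    # even when the model weights are bfloat16/float16.
    latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
else:
    if latents.shape != latents_shape:
        raise ValueError(
            f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}"
        )
```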
This file doesn't seem to be in main. Should be removed from here before merging.
The merge target is `flax_pipeline` for now.
Replaced by #600.
Changes:
- Allow a `dtype` to be specified on model load. This is a temporary solution until "Save training `dtype` as part of the configuration" (#567) is addressed in a more principled way.
- Convert weights to `bfloat16` or `float16` if necessary.
- Keep `float32` during the inference loop.
- See comment: #559 (comment)
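Keeping float32 in the inference loop matters because repeatedly adding small step updates, as a denoising loop does, stalls in bfloat16 once the spacing between representable values exceeds the step size. A toy accumulation (illustrative only, not the actual scheduler code) shows the effect:

```python
import jax.numpy as jnp

step = 1e-3
x32 = jnp.zeros((), dtype=jnp.float32)
x16 = jnp.zeros((), dtype=jnp.bfloat16)

for _ in range(1000):
    # float32 accumulates the small step faithfully...
    x32 = x32 + jnp.float32(step)
    # ...while bfloat16 (8-bit mantissa) eventually rounds the
    # addition away entirely and the running value stops growing.
    x16 = x16 + jnp.bfloat16(step)

print(float(x32))  # close to 1.0
print(float(x16))  # stalls far below 1.0
```

This is why the latents and scheduler state stay float32 even when the model weights are loaded in half precision.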