Commit 8ec0a5c

Authored by Ando233, wangyuqi, kashif, sayakpaul, dg845
feat: implement rae autoencoder. (#13046)
* feat: implement three RAE encoders (dinov2, siglip2, mae)
* feat: finish first version of autoencoder_rae
* fix formatting
* make fix-copies
* initial doc
* fix latent_mean / latent_var init types to accept config-friendly inputs
* use mean and std convention
* cleanup
* add rae to diffusers script
* use imports
* use attention
* remove unneeded class
* example training script
* input and ground truth sizes have to be the same
* fix argument
* move loss to training script
* cleanup
* simplify mixins
* fix training script
* fix entrypoint for instantiating the AutoencoderRAE
* added encoder_image_size config
* undo last change
* fixes from pretrained weights
* cleanups
* address reviews
* fix train script to use pretrained
* fix conversion script review
* latent normalization buffers are now always registered with no-op defaults
* Update examples/research_projects/autoencoder_rae/README.md (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)
* Update src/diffusers/models/autoencoders/autoencoder_rae.py (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)
* use image url
* Encoder is frozen
* fix slow test
* remove config
* use ModelTesterMixin and AutoencoderTesterMixin
* make quality
* strip final layernorm when converting
* _strip_final_layernorm_affine for training script
* fix test
* add dispatch forward and update conversion script
* update training script
* error out as soon as possible and add comments
* Update src/diffusers/models/autoencoders/autoencoder_rae.py (Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>)
* use buffer
* inline
* Update src/diffusers/models/autoencoders/autoencoder_rae.py (Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>)
* remove optional
* _noising takes a generator
* Update src/diffusers/models/autoencoders/autoencoder_rae.py (Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>)
* fix api
* rename
* remove unittest
* use randn_tensor
* fix device map on multigpu
* check if the key is missing in the original state dict and only then add to the allow_missing set
* remove initialize_weights

Co-authored-by: wangyuqi <wangyuqi@MBP-FJDQNJTWYN-0208.local>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
1 parent 29b9109 commit 8ec0a5c

File tree

11 files changed

+1977
-0
lines changed


docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
```diff
@@ -460,6 +460,8 @@
       title: AutoencoderKLQwenImage
     - local: api/models/autoencoder_kl_wan
       title: AutoencoderKLWan
+    - local: api/models/autoencoder_rae
+      title: AutoencoderRAE
     - local: api/models/consistency_decoder_vae
       title: ConsistencyDecoderVAE
     - local: api/models/autoencoder_oobleck
```
Lines changed: 89 additions & 0 deletions
<!-- Copyright 2026 The NYU Vision-X and HuggingFace Teams. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# AutoencoderRAE

The Representation Autoencoder (RAE) model was introduced in [Diffusion Transformers with Representation Autoencoders](https://huggingface.co/papers/2510.11690) by Boyang Zheng, Nanye Ma, Shengbang Tong, and Saining Xie from NYU VISIONx.

RAE combines a frozen pretrained vision encoder (DINOv2, SigLIP2, or MAE) with a trainable ViT-MAE-style decoder. In the two-stage RAE training recipe, the autoencoder is trained in stage 1 (reconstruction), and a diffusion model is then trained on the resulting latent space in stage 2 (generation).

The following RAE models are released and supported in Diffusers:

| Model | Encoder | Latent shape (C x H x W) |
|:------|:--------|:-------------------------|
| [`nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08) | DINOv2-base | 768 x 16 x 16 |
| [`nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08-i512`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08-i512) | DINOv2-base (512px) | 768 x 32 x 32 |
| [`nyu-visionx/RAE-dinov2-wReg-small-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-small-ViTXL-n08) | DINOv2-small | 384 x 16 x 16 |
| [`nyu-visionx/RAE-dinov2-wReg-large-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-large-ViTXL-n08) | DINOv2-large | 1024 x 16 x 16 |
| [`nyu-visionx/RAE-siglip2-base-p16-i256-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-siglip2-base-p16-i256-ViTXL-n08) | SigLIP2-base | 768 x 16 x 16 |
| [`nyu-visionx/RAE-mae-base-p16-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-mae-base-p16-ViTXL-n08) | MAE-base | 768 x 16 x 16 |
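The latent channel count is the encoder's hidden size (e.g. 768 for the base encoders), and the spatial grid follows from the encoder's patch layout: each non-overlapping patch becomes one latent position. A quick sketch of that arithmetic (the patch sizes here are properties of the underlying encoders, e.g. 14 for DINOv2 and 16 for SigLIP2-p16, not parameters of AutoencoderRAE itself):

```python
def latent_grid_size(encoder_input_size: int, patch_size: int) -> int:
    """Side length of the latent grid: one token per non-overlapping patch."""
    if encoder_input_size % patch_size != 0:
        raise ValueError("encoder input size must be divisible by the patch size")
    return encoder_input_size // patch_size


# DINOv2 uses 14x14 patches, so a 224px input yields a 16x16 latent grid.
print(latent_grid_size(224, 14))  # 16
# SigLIP2-base-p16 at its 256px input size also yields a 16x16 grid.
print(latent_grid_size(256, 16))  # 16
```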
## Loading a pretrained model

```python
from diffusers import AutoencoderRAE

model = AutoencoderRAE.from_pretrained(
    "nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
).to("cuda").eval()
```
## Encoding and decoding a real image

```python
import torch
from torchvision.transforms.functional import to_pil_image, to_tensor

from diffusers import AutoencoderRAE
from diffusers.utils import load_image

model = AutoencoderRAE.from_pretrained(
    "nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
).to("cuda").eval()

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png")
image = image.convert("RGB").resize((224, 224))
x = to_tensor(image).unsqueeze(0).to("cuda")  # (1, 3, 224, 224), values in [0, 1]

with torch.no_grad():
    latents = model.encode(x).latent  # (1, 768, 16, 16)
    recon = model.decode(latents).sample  # (1, 3, 256, 256)

recon_image = to_pil_image(recon[0].clamp(0, 1).cpu())
recon_image.save("recon.png")
```
## Latent normalization

Some pretrained checkpoints include per-channel `latents_mean` and `latents_std` statistics for normalizing the latent space. When present, `encode` and `decode` automatically apply the normalization and denormalization, respectively.

```python
# `x` is an image batch prepared as in the previous example.
model = AutoencoderRAE.from_pretrained(
    "nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
).to("cuda").eval()

# Latent normalization is handled automatically inside encode/decode
# when the checkpoint config includes latents_mean/latents_std.
with torch.no_grad():
    latents = model.encode(x).latent  # normalized latents
    recon = model.decode(latents).sample
```
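The normalization itself is just a per-channel affine map over the latent channels. A minimal standalone sketch with made-up statistics (not a real checkpoint's buffers), showing that normalize and denormalize are exact inverses:

```python
import torch


def normalize(latents, mean, std):
    # Broadcast per-channel (C,) stats over (B, C, H, W) latents.
    return (latents - mean[None, :, None, None]) / std[None, :, None, None]


def denormalize(latents, mean, std):
    return latents * std[None, :, None, None] + mean[None, :, None, None]


mean = torch.randn(768)      # hypothetical per-channel latents_mean
std = torch.rand(768) + 0.5  # hypothetical per-channel latents_std (kept > 0)
z = torch.randn(1, 768, 16, 16)

z_roundtrip = denormalize(normalize(z, mean, std), mean, std)
print(torch.allclose(z, z_roundtrip, atol=1e-5))  # True
```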
## AutoencoderRAE

[[autodoc]] AutoencoderRAE
  - encode
  - decode
  - all

## DecoderOutput

[[autodoc]] models.autoencoders.vae.DecoderOutput
Lines changed: 66 additions & 0 deletions
# Training AutoencoderRAE

This example trains the decoder of `AutoencoderRAE` (stage-1 style) while keeping the representation encoder frozen.

It follows the same high-level training recipe as the official RAE stage-1 setup:

- frozen encoder
- trainable decoder
- pixel reconstruction loss
- optional encoder feature consistency loss
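The objective combining these terms can be sketched as follows. This is a hypothetical helper for illustration, not the script's actual code; `encoder_loss_weight` and the `l1` reconstruction loss mirror the CLI flags used below:

```python
import torch
import torch.nn.functional as F


def rae_stage1_loss(recon, target, recon_type="l1",
                    encoder_feats=None, recon_feats=None,
                    encoder_loss_weight=0.1):
    """Pixel reconstruction loss plus optional encoder feature consistency."""
    if recon_type == "l1":
        loss = F.l1_loss(recon, target)
    else:
        loss = F.mse_loss(recon, target)
    # Optional consistency term: match the frozen encoder's features of the
    # reconstruction against those of the original image.
    if encoder_feats is not None and recon_feats is not None:
        loss = loss + encoder_loss_weight * F.mse_loss(recon_feats, encoder_feats)
    return loss


recon = torch.rand(2, 3, 256, 256)
target = torch.rand(2, 3, 256, 256)
loss = rae_stage1_loss(recon, target)
print(loss.item() >= 0)  # True
```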
## Quickstart

### Resume or finetune from pretrained weights

```bash
accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_rae.py \
  --pretrained_model_name_or_path nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08 \
  --train_data_dir /path/to/imagenet_like_folder \
  --output_dir /tmp/autoencoder-rae \
  --resolution 256 \
  --train_batch_size 8 \
  --learning_rate 1e-4 \
  --num_train_epochs 10 \
  --report_to wandb \
  --reconstruction_loss_type l1 \
  --use_encoder_loss \
  --encoder_loss_weight 0.1
```
### Train from scratch with a pretrained encoder

The following command launches RAE training with `facebook/dinov2-with-registers-base` as the base encoder.

```bash
accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_rae.py \
  --train_data_dir /path/to/imagenet_like_folder \
  --output_dir /tmp/autoencoder-rae \
  --resolution 256 \
  --encoder_type dinov2 \
  --encoder_name_or_path facebook/dinov2-with-registers-base \
  --encoder_input_size 224 \
  --patch_size 16 \
  --image_size 256 \
  --decoder_hidden_size 1152 \
  --decoder_num_hidden_layers 28 \
  --decoder_num_attention_heads 16 \
  --decoder_intermediate_size 4096 \
  --train_batch_size 8 \
  --learning_rate 1e-4 \
  --num_train_epochs 10 \
  --report_to wandb \
  --reconstruction_loss_type l1 \
  --use_encoder_loss \
  --encoder_loss_weight 0.1
```
Note: the stage-1 reconstruction loss assumes the target and output have matching spatial sizes, so `--resolution` must equal `--image_size`.

The dataset format is expected to be `ImageFolder`-compatible:

```text
train_data_dir/
  class_a/
    img_0001.jpg
  class_b/
    img_0002.jpg
```
