MobileStyleGAN is a lightweight, compressed version of StyleGAN2 designed for high-fidelity image synthesis on mobile/edge devices. The project uses knowledge distillation to train a smaller student network (MobileStyleGAN) to mimic a larger teacher network (StyleGAN2).
- Knowledge Distillation: Teacher-student training paradigm
- DWT-based Architecture: Uses Discrete Wavelet Transform for efficient upsampling
- Mobile-Optimized: Reduced model size while maintaining quality
- Deployment Ready: Supports ONNX, CoreML, and OpenVINO export
- Mapping Network: Transforms random noise z to style codes w
- Synthesis Network: Generates images from style codes using standard convolutions
- Mapping Network: Same as teacher (shared)
- Mobile Synthesis Network:
- Uses Depthwise Convolutions instead of standard convolutions
- Uses DWT (Discrete Wavelet Transform) for upsampling instead of bilinear interpolation
- Operates in frequency domain for efficiency
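The parameter savings from swapping standard convolutions for depthwise-separable ones can be sketched directly. This is an illustrative comparison, not the project's actual layer definitions:

```python
import torch.nn as nn

def n_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters())

in_ch, out_ch, k = 512, 512, 3

# Standard convolution: one dense in_ch x out_ch x k x k kernel
standard = nn.Conv2d(in_ch, out_ch, k, padding=1)

# Depthwise-separable: per-channel spatial filtering + 1x1 pointwise mixing
depthwise_separable = nn.Sequential(
    nn.Conv2d(in_ch, in_ch, k, padding=1, groups=in_ch),  # depthwise
    nn.Conv2d(in_ch, out_ch, 1),                          # pointwise
)

print(n_params(standard))             # 512*512*9 + 512 = 2_359_808
print(n_params(depthwise_separable))  # (512*9 + 512) + (512*512 + 512) = 267_776
```

At 512 channels the depthwise-separable variant needs roughly 9x fewer parameters, which is the main source of the student's size reduction.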
Purpose: Main training script for knowledge distillation
What it does:
- Loads configuration from JSON
- Initializes Distiller (teacher + student models)
- Sets up PyTorch Lightning trainer
- Handles model export (ONNX, CoreML)
Key Functions:
build_logger(): Creates logging interface (Neptune, TensorBoard, etc.)
main(): Orchestrates the training pipeline
Execution Order:
- Parse command-line arguments
- Load configuration file
- Initialize Distiller
- Load checkpoint (if provided)
- Setup PyTorch Lightning trainer
- Start training loop OR export model
Purpose: Interactive demo for visualizing generated images
What it does:
- Loads trained model
- Generates random images in real-time
- Displays images using OpenCV
- Allows switching between student/teacher generators
Execution Order:
- Load configuration
- Initialize Distiller
- Load checkpoint weights
- Loop:
- Generate random noise
- Forward pass through model
- Display image
- Wait for user input (press 'q' to quit)
Purpose: Batch image generation script
What it does:
- Generates multiple images in batches
- Saves images to disk as PNG files
- Supports batch processing for efficiency
Execution Order:
- Load configuration
- Initialize Distiller
- Load checkpoint weights
- Move model to specified device (CPU/CUDA)
- For each batch:
- Generate random noise
- Forward pass
- Save images to disk
Purpose: Side-by-side comparison of teacher vs student outputs
What it does:
- Generates images from both teacher and student models
- Displays them side-by-side
- Interactive visualization
Execution Order:
- Load configuration
- Initialize Distiller
- Load checkpoint weights
- Loop:
- Generate random noise
- Forward pass through both models (simultaneous_forward)
- Concatenate images horizontally
- Display comparison
- Wait for user input
Purpose: Calculate FID (Fréchet Inception Distance) score
What it does:
- Computes FID between real and generated images
- Uses InceptionV3 features
- Provides quantitative quality metric
Execution Order:
- Load InceptionV3 model
- Extract features from reference dataset
- Extract features from generated images
- Calculate FID score (statistical distance)
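The final step, the statistical distance between the two feature sets, can be sketched with NumPy alone. Here `real` and `fake` stand in for the InceptionV3 feature matrices the script actually extracts; the trace-of-matrix-square-root term is computed via eigenvalues:

```python
import numpy as np

def fid(feats_a: np.ndarray, feats_b: np.ndarray) -> float:
    mu_a, mu_b = feats_a.mean(0), feats_b.mean(0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    # Tr(sqrtm(cov_a @ cov_b)) via eigenvalues; for a product of covariance
    # matrices they are real and >= 0 up to numerical noise.
    eigvals = np.linalg.eigvals(cov_a @ cov_b)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0, None)).sum()
    return float(((mu_a - mu_b) ** 2).sum()
                 + np.trace(cov_a + cov_b) - 2 * tr_sqrt)

rng = np.random.default_rng(0)
real = rng.normal(size=(1000, 8))
fake = rng.normal(loc=0.5, size=(1000, 8))
print(fid(real, real))  # ~0: identical feature sets
print(fid(real, fake))  # > 0: distributions differ
```

Lower FID means the generated feature distribution is closer to the reference distribution.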
Purpose: Convert StyleGAN2 checkpoints from rosinality format
What it does:
- Converts StyleGAN2 checkpoints to MobileStyleGAN format
- Separates mapping network and synthesis network
- Creates compatible checkpoint files
Execution Order:
- Load rosinality checkpoint
- Extract mapping network weights
- Extract synthesis network weights
- Save separate checkpoint files
- Generate configuration JSON
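The separation step boils down to filtering a combined state dict by key prefix. A hedged sketch of that idea (the key names below are illustrative, not necessarily the checkpoint's real ones):

```python
import torch

# Toy combined checkpoint with mapping-network ("style.") and
# synthesis-network ("convs.") entries.
ckpt = {
    "style.1.weight": torch.zeros(512, 512),
    "style.1.bias": torch.zeros(512),
    "convs.0.conv.weight": torch.zeros(512, 512, 3, 3),
}

def select_by_prefix(state, prefix):
    # strip the prefix so keys match the target sub-network's own names
    return {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}

mapping_net = select_by_prefix(ckpt, "style.")
synthesis_net = select_by_prefix(ckpt, "convs.")
print(sorted(mapping_net))    # ['1.bias', '1.weight']
print(sorted(synthesis_net))  # ['0.conv.weight']
```

Each filtered dict can then be saved with `torch.save()` as its own checkpoint file.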
Purpose: Main distillation class - orchestrates teacher-student training
Key Components:
Distiller class (inherits from pl.LightningModule):
- Teacher Models: Pre-trained StyleGAN2 mapping and synthesis networks
- Student Model: MobileSynthesisNetwork (trainable)
- Loss Functions: DistillerLoss (L1, L2, perceptual, GAN losses)
- Evaluation: KID metric using InceptionV3
Key Methods:
__init__(): Loads teacher models, initializes student
make_sample(): Generates teacher outputs for distillation
generator_step(): Trains the student generator
discriminator_step(): Trains the discriminator
forward(): Inference (student or teacher)
simultaneous_forward(): Generates from both models
to_onnx() / to_coreml(): Model export
Execution Flow (Training):
- Generate random noise z
- Map to style codes w using the teacher mapping network
- Generate teacher image using the teacher synthesis network
- Generate student image using student synthesis network
- Compute loss (L1, L2, perceptual, GAN)
- Backpropagate and update student weights
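The steps above can be sketched with toy linear networks standing in for the real mapping/synthesis models (the real step also includes perceptual, GAN, and DWT terms; only the frozen-teacher / trainable-student mechanics are shown):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
teacher = nn.Linear(16, 16)
student = nn.Linear(16, 16)
for p in teacher.parameters():
    p.requires_grad_(False)          # teacher is frozen

opt = torch.optim.Adam(student.parameters(), lr=1e-2)
pixel_loss = nn.L1Loss()

z = torch.randn(8, 16)               # random noise batch
with torch.no_grad():
    target = teacher(z)              # "teacher image"

losses = []
for _ in range(200):
    opt.zero_grad()
    loss = pixel_loss(student(z), target)   # L1 distillation loss
    loss.backward()
    opt.step()
    losses.append(loss.item())

print(losses[0] > losses[-1])  # True: the student moves toward the teacher
```

Only the student's parameters receive gradient updates; the teacher acts purely as a target generator.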
Purpose: Utility functions for model loading, configuration, and image processing
Key Functions:
download_ckpt(): Downloads or loads local checkpoints
- Checks current directory first
- Falls back to /tmp/cache
- Downloads from Google Drive if needed
load_cfg(): Loads JSON configuration files
load_weights(): Loads model weights with size matching
tensor_to_img(): Converts PyTorch tensors to OpenCV images
select_weights(): Filters weights by prefix
apply_trace_model_mode(): Sets trace mode for ONNX export
Purpose: Model checkpoint management
What it does:
- Maps model names to download URLs
- Handles checkpoint loading from model zoo or local files
- Integrates with the download_ckpt() utility
Execution Flow:
- Load the model_zoo.json configuration
- Check if the name exists in the zoo
- If yes: Download/load from zoo
- If no: Load as local file path
Purpose: Dataset for training (noise generation)
What it does:
- Generates random noise vectors for training
- No actual image dataset needed (unsupervised)
Key Class:
NoiseDataset: Generates random noise on-the-fly
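A minimal sketch of such an on-the-fly noise dataset (class and parameter names here are assumptions, not necessarily the project's exact signature):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class NoiseDataset(Dataset):
    def __init__(self, length: int = 1000, style_dim: int = 512):
        self.length = length
        self.style_dim = style_dim

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # no images on disk: every sample is a fresh latent vector z
        return torch.randn(self.style_dim)

loader = DataLoader(NoiseDataset(length=16), batch_size=4)
batch = next(iter(loader))
print(batch.shape)  # torch.Size([4, 512])
```

Because targets are produced by the frozen teacher at train time, no image dataset ever needs to be stored.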
Purpose: Combined loss function for knowledge distillation
Components:
DistillerLoss class:
- L1 Loss: Pixel-level difference
- L2 Loss: Mean squared error
- Perceptual Loss: VGG16 feature matching
- GAN Loss: Adversarial training
- DWT Loss: Frequency domain matching
Key Methods:
loss_g(): Generator loss (student vs teacher)
loss_d(): Discriminator loss
reg_d(): Discriminator regularization
img_to_dwt() / dwt_to_img(): Domain conversion
Loss Computation:
- Compute losses in spatial domain (L1, L2 on RGB images)
- Compute losses in frequency domain (L1, L2 on DWT coefficients)
- Compute perceptual loss (VGG16 features)
- Compute GAN loss (discriminator output)
- Weighted sum of all losses
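The final aggregation step reduces to a config-weighted sum over the individual terms. A tiny sketch with illustrative names and values (the real weights come from the distillation_loss section of the config):

```python
# Hypothetical per-term loss values and config weights
loss_terms = {"l1": 0.8, "l2": 0.3, "perceptual": 1.2, "gan": 0.05, "dwt": 0.6}
weights = {"l1": 1.0, "l2": 1.0, "perceptual": 0.1, "gan": 0.01, "dwt": 1.0}

total = sum(weights[name] * value for name, value in loss_terms.items())
print(round(total, 4))  # 0.8 + 0.3 + 0.12 + 0.0005 + 0.6 = 1.8205
```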
Purpose: Perceptual loss using VGG16 features
Components:
PerceptualNetwork: Extracts VGG16 features
PerceptualLoss: Computes L1 loss on feature maps
What it does:
- Extracts features from multiple VGG16 layers
- Compares features between student and teacher outputs
- Provides semantic similarity measure
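Feature-level matching can be sketched as below. A small randomly-initialized conv stack stands in for pretrained VGG16 purely to keep the example self-contained; the real PerceptualLoss extracts features from several VGG16 layers and sums the per-layer distances:

```python
import torch
import torch.nn as nn

# Stand-in feature extractor (NOT VGG16): frozen, used only for comparison
feature_extractor = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),
)
for p in feature_extractor.parameters():
    p.requires_grad_(False)

def perceptual_loss(student_img, teacher_img):
    # L1 distance between feature maps rather than raw pixels
    return nn.functional.l1_loss(feature_extractor(student_img),
                                 feature_extractor(teacher_img))

a = torch.rand(1, 3, 32, 32)
print(perceptual_loss(a, a).item())  # 0.0 for identical inputs
```

Comparing features instead of pixels makes the loss tolerant to small spatial shifts while still penalizing semantic differences.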
Purpose: Non-saturating GAN loss implementation
What it does:
- Implements standard GAN training loss
- Supports discriminator regularization
Purpose: Differential augmentation for training stability
Purpose: StyleGAN mapping network (shared by teacher and student)
Architecture:
- Multi-layer MLP
- Transforms z (latent code) → w (style code)
Purpose: StyleGAN2 synthesis network (teacher model)
Architecture:
- Standard StyleGAN2 synthesis blocks
- Uses regular convolutions
- Bilinear upsampling
Purpose: Mobile-optimized synthesis network (student model)
Key Differences from Teacher:
- Uses DWT upsampling instead of bilinear
- Uses depthwise convolutions for efficiency
- Operates in frequency domain
Architecture:
ConstantInput: Initial constant tensor
StyledConv2d: Style-modulated convolution
MobileSynthesisBlock: Main building blocks
DWTInverse: Converts frequency domain to RGB
Forward Pass:
- Start with constant input
- Apply initial styled convolution
- Convert to frequency domain (12 channels: 3 low + 9 high)
- For each block:
- Upsample using DWT
- Apply styled convolutions
- Generate frequency representation
- Final DWT inverse to RGB image
Purpose: InceptionV3 for FID/KID evaluation
Components:
InceptionV3: Feature extractor
load_inception_v3(): Loads the pretrained model
fid_inception_v3(): FID-specific Inception model
Purpose: Building block for mobile synthesis network
Components:
IDWTUpsaplme: DWT-based upsampling
StyledConv2d: Style-modulated convolutions (x2)
MultichannelIamge: Converts to frequency domain
Forward Pass:
- Upsample using DWT
- First styled convolution
- Second styled convolution
- Convert to frequency representation
Purpose: DWT-based upsampling module
What it does:
- Uses Inverse Discrete Wavelet Transform for upsampling
- More efficient than bilinear interpolation
- Maintains frequency information
Purpose: Inverse DWT implementation
What it does:
- Reconstructs images from frequency domain
- Converts DWT coefficients back to RGB
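The forward/inverse pair can be illustrated with a single-level Haar transform: the forward DWT splits an image into four half-resolution subbands, and the inverse reassembles them losslessly. Stacking the four subbands of an RGB image is also what yields the 12 channels (3 low + 9 high) used elsewhere in the network. This is an illustrative sketch, not the project's exact kernels:

```python
import torch

def haar_dwt(x):
    a = x[..., 0::2, 0::2]  # top-left pixel of each 2x2 block
    b = x[..., 0::2, 1::2]
    c = x[..., 1::2, 0::2]
    d = x[..., 1::2, 1::2]
    ll = (a + b + c + d) / 2  # low-frequency (average) subband
    lh = (a + b - c - d) / 2  # high-frequency detail subbands
    hl = (a - b + c - d) / 2
    hh = (a - b - c + d) / 2
    return ll, lh, hl, hh

def haar_idwt(ll, lh, hl, hh):
    n, ch, h, w = ll.shape
    out = torch.empty(n, ch, 2 * h, 2 * w)  # 2x spatial upsampling
    out[..., 0::2, 0::2] = (ll + lh + hl + hh) / 2
    out[..., 0::2, 1::2] = (ll + lh - hl - hh) / 2
    out[..., 1::2, 0::2] = (ll - lh + hl - hh) / 2
    out[..., 1::2, 1::2] = (ll - lh - hl + hh) / 2
    return out

img = torch.rand(1, 3, 8, 8)
rec = haar_idwt(*haar_dwt(img))
print(torch.allclose(img, rec, atol=1e-6))  # True: perfect reconstruction
```

Because the transform is invertible, upsampling this way loses no frequency information, unlike bilinear interpolation.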
Purpose: Style-modulated convolution (standard)
Purpose: Wrapper for style-modulated convolutions
What it does:
- Applies style modulation
- Handles noise injection
- Supports both standard and depthwise convolutions
Purpose: Converts feature maps to frequency domain
What it does:
- Transforms feature channels to DWT representation
- Outputs 12 channels: one low-frequency (LL) and three high-frequency (LH/HL/HH) subbands per RGB channel (3 low + 9 high)
Purpose: Adds noise to feature maps
Purpose: Initial constant input tensor
Purpose: Custom CUDA operations
Files:
fused_act.py / fused_act_cuda.py: Fused activation functions
upfirdn2d.py / upfirdn2d_cuda.py: Upsampling/filtering operations
- CUDA kernels for performance optimization
Purpose: Main training configuration
Sections:
logger: Logging configuration (Neptune)
trainset / valset: Dataset parameters
teacher: Teacher model checkpoint paths
distillation_loss: Loss weights
trainer: Training hyperparameters
Purpose: Model checkpoint registry
What it contains:
- Model names → Download URLs
- MD5 checksums for verification
1. Parse arguments
↓
2. Load config (mobile_stylegan_ffhq.json)
↓
3. Initialize Distiller
├─ Load teacher mapping network (from model_zoo)
├─ Load teacher synthesis network (from model_zoo)
├─ Initialize student MobileSynthesisNetwork
├─ Setup DistillerLoss
└─ Setup InceptionV3 for evaluation
↓
4. Load checkpoint (if provided)
↓
5. Setup PyTorch Lightning Trainer
↓
6. Start training loop
├─ For each batch:
│ ├─ Generate random noise z
│ ├─ Map to style w (teacher mapping)
│ ├─ Generate teacher image
│ ├─ Generate student image
│ ├─ Compute losses
│ └─ Update student weights
└─ Validation (compute KID)
1. Load config
↓
2. Initialize Distiller
├─ Load teacher models
└─ Initialize student model
↓
3. Load checkpoint weights
↓
4. Generate images
├─ Random noise z
├─ Map to style w
└─ Forward through generator (student/teacher)
↓
5. Display/Save images
Training Step:
┌─────────────────┐
│ Random Noise z │
└────────┬────────┘
│
▼
┌─────────────────┐
│ Mapping Network │ (Teacher - frozen)
└────────┬────────┘
│
▼
┌────┴────┐
│ │
▼ ▼
┌──────────┐ ┌──────────────────┐
│ Teacher │ │ Student │
│ Synthesis│ │ Mobile Synthesis │
└────┬─────┘ └────────┬─────────┘
│ │
│ │
▼ ▼
┌──────────┐ ┌──────────┐
│ Teacher │ │ Student │
│ Image │ │ Image │
└────┬─────┘ └────┬─────┘
│ │
└────────┬────────┘
│
▼
┌───────────────┐
│ DistillerLoss │
│ - L1/L2 │
│ - Perceptual │
│ - GAN │
│ - DWT │
└───────┬───────┘
│
▼
┌───────────────┐
│ Backprop │
│ Update Student│
└───────────────┘
- Teacher: Pre-trained StyleGAN2 (frozen)
- Student: MobileStyleGAN (trainable)
- Process: Student learns to mimic teacher outputs
- Purpose: Efficient upsampling and frequency domain processing
- Usage:
- Upsampling in mobile synthesis blocks
- Loss computation in frequency domain
- Final image reconstruction
- Purpose: Reduce model parameters
- Usage: Replaces standard convolutions in student model
- Purpose: Control image generation via style codes
- Usage: Applied in all synthesis blocks
- L1/L2: Pixel-level matching
- Perceptual: Feature-level matching (VGG16)
- GAN: Adversarial training
- DWT: Frequency domain matching
Random Noise (z)
│
▼
Mapping Network
│
▼
Style Codes (w)
│
├─────────────────┐
│ │
▼ ▼
Teacher Student
Synthesis Mobile Synthesis
│ │
│ │
│ ┌───────┴───────┐
│ │ │
│ ▼ ▼
│ Frequency Frequency
│ Domain Domain
│ │ │
│ └───────┬───────┘
│ │
▼ ▼
RGB Image RGB Image
(Teacher) (Student)
NoiseDataset
│
▼
Random z vectors
│
▼
Distiller.make_sample()
│
├─→ Teacher Mapping → Teacher Synthesis → Teacher Image
│
└─→ Student Mobile Synthesis → Student Image
│
▼
DistillerLoss
│
├─→ L1/L2 Loss (spatial)
├─→ L1/L2 Loss (frequency)
├─→ Perceptual Loss (VGG16)
└─→ GAN Loss
│
▼
Total Loss
│
▼
Backpropagation
│
▼
Update Student Weights
train.py / demo.py / generate.py
│
├─→ core.distiller
│ ├─→ core.models.mapping_network
│ ├─→ core.models.synthesis_network (teacher)
│ ├─→ core.models.mobile_synthesis_network (student)
│ ├─→ core.loss.distiller_loss
│ └─→ core.models.inception_v3
│
├─→ core.utils
│ └─→ core.model_zoo
│
└─→ configs/*.json
mobile_synthesis_network.py
│
├─→ modules.mobile_synthesis_block
│ ├─→ modules.idwt_upsample
│ ├─→ modules.styled_conv2d
│ └─→ modules.multichannel_image
│
├─→ modules.constant_input
├─→ modules.idwt
└─→ modules.modulated_conv2d
This codebase implements MobileStyleGAN, a compressed version of StyleGAN2 using:
- Knowledge Distillation: Teacher (StyleGAN2) → Student (MobileStyleGAN)
- DWT-based Architecture: Efficient frequency domain processing
- Mobile Optimizations: Depthwise convolutions, reduced parameters
- Multi-scale Training: Spatial + frequency + perceptual losses
The main entry points are:
train.py: Trains the student model
demo.py: Interactive visualization
generate.py: Batch image generation
compare.py: Teacher vs student comparison
All scripts follow a similar pattern: load config → initialize Distiller → load weights → generate/display images.