From 2beb14a1a09d6432f1aae20adc50098e27c0427d Mon Sep 17 00:00:00 2001
From: aiihn <101506918+aiihn@users.noreply.github.com>
Date: Mon, 26 Aug 2024 03:41:17 +0000
Subject: [PATCH] Enable torch.cuda.amp.GradScaler to automatically adjust the
 loss scaling

---
 ct_train.py                  |  3 +-
 training/ct_training_loop.py | 56 ++++++++++++++++++++++++++++--------
 2 files changed, 46 insertions(+), 13 deletions(-)

diff --git a/ct_train.py b/ct_train.py
index 43b3724..d9976a5 100644
--- a/ct_train.py
+++ b/ct_train.py
@@ -78,6 +78,7 @@ def convert(self, value, param, ctx):
 @click.option('--fp16', help='Enable mixed-precision training', metavar='BOOL', type=bool, default=False, show_default=True)
 @click.option('--tf32', help='Enable tf32 for A100/H100 training speed', metavar='BOOL', type=bool, default=False, show_default=True)
 @click.option('--ls', help='Loss scaling', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=1, show_default=True)
+@click.option('--enable_gradscaler', help='Enable torch.cuda.amp.GradScaler. NOTE: overrides the loss scaling set by --ls', metavar='BOOL', type=bool, default=False, show_default=True)
 @click.option('--bench', help='Enable cuDNN benchmarking', metavar='BOOL', type=bool, default=True, show_default=True)
 @click.option('--cache', help='Cache dataset in CPU memory', metavar='BOOL', type=bool, default=True, show_default=True)
 @click.option('--workers', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=1, show_default=True)
@@ -164,7 +165,7 @@ def main(**kwargs):
     c.ema_halflife_kimg = int(opts.ema * 1000) if opts.ema is not None else opts.ema
     c.ema_beta = opts.ema_beta
     c.update(batch_size=opts.batch, batch_gpu=opts.batch_gpu)
-    c.update(loss_scaling=opts.ls, cudnn_benchmark=opts.bench, enable_tf32=opts.tf32)
+    c.update(loss_scaling=opts.ls, cudnn_benchmark=opts.bench, enable_tf32=opts.tf32, enable_gradscaler=opts.enable_gradscaler)
     c.update(kimg_per_tick=opts.tick, snapshot_ticks=opts.snap, state_dump_ticks=opts.dump, ckpt_ticks=opts.ckpt, double_ticks=opts.double)
     c.update(mid_t=opts.mid_t, metrics=opts.metrics, sample_ticks=opts.sample_every, eval_ticks=opts.eval_every)
diff --git a/training/ct_training_loop.py b/training/ct_training_loop.py
index a0db9fa..4f5f0b8 100644
--- a/training/ct_training_loop.py
+++ b/training/ct_training_loop.py
@@ -129,6 +129,7 @@ def training_loop(
     metrics = None, # Metrics for evaluation.
     cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark?
     enable_tf32 = False, # Enable tf32 for A100/H100 GPUs?
+    enable_gradscaler = False, # Enable torch.cuda.amp.GradScaler?
     device = torch.device('cuda'),
 ):
     # Initialize.
@@ -168,6 +169,14 @@ def training_loop(
     optimizer = dnnlib.util.construct_class_by_name(params=net.parameters(), **optimizer_kwargs) # subclass of torch.optim.Optimizer
     augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs) if augment_kwargs is not None else None # training.augment.AugmentPipe
 
+    dist.print0(f'GradScaler enabled: {enable_gradscaler}')
+    if enable_gradscaler:
+        # https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html#adding-gradscaler
+        # https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-accumulation
+        dist.print0('Setting up GradScaler...')
+        scaler = torch.cuda.amp.GradScaler()
+        dist.print0('Loss scaling set by --ls is overridden when GradScaler is enabled')
+
     dist.print0('Setting up DDP...')
     ddp = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], broadcast_buffers=False)
     ema = copy.deepcopy(net).eval().requires_grad_(False)
@@ -197,6 +206,13 @@ def training_loop(
         data = torch.load(resume_state_dump, map_location=torch.device('cpu'))
         misc.copy_params_and_buffers(src_module=data['net'], dst_module=net, require_all=True)
         optimizer.load_state_dict(data['optimizer_state'])
+        if enable_gradscaler:
+            if 'gradscaler_state' in data:
+                dist.print0(f'Loading GradScaler state from "{resume_state_dump}"...')
+                # Resuming works without restoring the GradScaler state, but loading it improves reproducibility.
+                scaler.load_state_dict(data['gradscaler_state'])
+            else:
+                dist.print0(f'GradScaler state not found in "{resume_state_dump}"; using the default state.')
         del data # conserve memory
 
     # Export sample images.
@@ -253,16 +269,24 @@ def update_scheduler(loss_fn):
 
                 loss = loss_fn(net=ddp, images=images, labels=labels, augment_pipe=augment_pipe)
                 training_stats.report('Loss/loss', loss)
-                # loss.sum().mul(loss_scaling / batch_gpu_total).backward()
-                loss.mul(loss_scaling).mean().backward()
-
-        # Update weights.
-        # for g in optimizer.param_groups:
-        #     g['lr'] = optimizer_kwargs['lr'] * min(cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1)
-        for param in net.parameters():
-            if param.grad is not None:
-                torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
-        optimizer.step()
+                if enable_gradscaler:
+                    scaler.scale(loss.mean()).backward()
+                else:
+                    # loss.sum().mul(loss_scaling / batch_gpu_total).backward()
+                    loss.mul(loss_scaling).mean().backward()
+
+        if enable_gradscaler:
+            # TODO: is torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad) needed when using GradScaler?
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            # Update weights.
+            # for g in optimizer.param_groups:
+            #     g['lr'] = optimizer_kwargs['lr'] * min(cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1)
+            for param in net.parameters():
+                if param.grad is not None:
+                    torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
+            optimizer.step()
 
         # Update EMA.
         if ema_halflife_kimg is not None:
@@ -317,7 +341,11 @@ def update_scheduler(loss_fn):
 
         # Save full dump of the training state.
         if (state_dump_ticks is not None) and (done or cur_tick % state_dump_ticks == 0) and cur_tick != 0 and dist.get_rank() == 0:
-            torch.save(dict(net=net, optimizer_state=optimizer.state_dict()), os.path.join(run_dir, f'training-state-{cur_tick:06d}.pt'))
+            if enable_gradscaler:
+                data = dict(net=net, optimizer_state=optimizer.state_dict(), gradscaler_state=scaler.state_dict())
+            else:
+                data = dict(net=net, optimizer_state=optimizer.state_dict())
+            torch.save(data, os.path.join(run_dir, f'training-state-{cur_tick:06d}.pt'))
 
         # Save latest checkpoints
         if (ckpt_ticks is not None) and (done or cur_tick % ckpt_ticks == 0) and cur_tick != 0:
@@ -335,7 +363,11 @@ def update_scheduler(loss_fn):
             del data # conserve memory
 
             if dist.get_rank() == 0:
-                torch.save(dict(net=net, optimizer_state=optimizer.state_dict()), os.path.join(run_dir, f'training-state-latest.pt'))
+                if enable_gradscaler:
+                    data = dict(net=net, optimizer_state=optimizer.state_dict(), gradscaler_state=scaler.state_dict())
+                else:
+                    data = dict(net=net, optimizer_state=optimizer.state_dict())
+                torch.save(data, os.path.join(run_dir, f'training-state-latest.pt'))
 
         # Sample Img
         if (sample_ticks is not None) and (done or cur_tick % sample_ticks == 0) and dist.get_rank() == 0:
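
For reference, below the diff is a minimal, self-contained sketch of the GradScaler pattern the patch follows (per the PyTorch AMP recipe linked in the diff): scale the loss before backward, let scaler.step() unscale the gradients and skip the optimizer step when they contain inf/NaN, call scaler.update() every iteration, and checkpoint scaler.state_dict() alongside the optimizer state. The toy model, random data, and autocast block are illustrative assumptions only; they are not code from this repository.

import torch

# Assumed toy setup; only the scaler calls and the state_dict round-trip mirror the patch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
amp_dtype = torch.float16 if device.type == 'cuda' else torch.bfloat16
net = torch.nn.Linear(16, 1).to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=1e-2)
scaler = torch.cuda.amp.GradScaler(enabled=(device.type == 'cuda'))

for step in range(10):
    x = torch.randn(8, 16, device=device)
    y = torch.randn(8, 1, device=device)
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type=device.type, dtype=amp_dtype):
        loss = torch.nn.functional.mse_loss(net(x), y)
    scaler.scale(loss).backward()  # backprop the scaled loss
    scaler.step(optimizer)         # unscales gradients; skips the step if they contain inf/NaN
    scaler.update()                # grows/shrinks the scale factor for the next iteration

# Checkpointing, as in the patch: save and restore the scaler state with the optimizer state.
state = dict(net=net.state_dict(), optimizer_state=optimizer.state_dict(), gradscaler_state=scaler.state_dict())
scaler.load_state_dict(state['gradscaler_state'])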