
[Question] Concerns about gradient propagation through .cuda() operations in CPU→GPU transfers #8

@lejelly

Hello!
Thank you for sharing your excellent research and implementation.
Your paper and code have been incredibly informative, and I'm learning a great deal about this novel approach to model merging using task vectors.


Context

Looking at the repository code, I noticed that in the AdaMerging class the merging coefficients (alpha) are combined with the task-vector parameters on the CPU, and the merged parameters are then moved to the GPU with params = tuple(p.cuda(0) for p in params) before the forward pass.

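def forward(self, inp, dataset_name):
    # Learnable merging coefficients
    alph = self.lambdas()
    # For each parameter: weighted sum over the matching entries of self.paramslist,
    # using the coefficients in alph[0] (moved to the CPU)
    params = tuple(sum(tuple(pi * lambdasi for pi, lambdasi in zip(p, alph[0].cpu()))) for j, p in enumerate(zip(*self.paramslist)))
    # Move the merged parameters to GPU 0
    params = tuple(p.cuda(0) for p in params)
    # Write the merged parameters into the model, then run the forward pass
    load_weights(self.model, self.names, params)
    feature = self.model(inp)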

My understanding is that in PyTorch's autograd, device-transfer operations such as .cpu() or .cuda() break the computation graph at the transfer point and prevent gradients from propagating past it.
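One way to test this premise directly is a small script, sketched here with hypothetical toy tensors standing in for the pretrained weights and task vectors (none of these names come from the repository): compose a merged parameter on the CPU, move it with .to(device), which plays the same role as p.cuda(0), call backward(), and check whether alpha.grad is populated and matches the analytic value. On a machine without CUDA the .to() call is a same-device no-op, so the check is only informative when a GPU is present.

```python
import torch

# Hypothetical toy tensors (not the repository's code): one flattened parameter
# built as pretrained + sum_k alpha_k * tau_k, composed on the CPU as in forward().
torch.manual_seed(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

pretrained = torch.randn(5)                  # frozen pretrained weights (CPU)
task_vectors = torch.randn(3, 5)             # K = 3 stacked task vectors (CPU)
alpha = torch.rand(3, requires_grad=True)    # learnable coefficients (CPU)

merged_cpu = pretrained + (alpha[:, None] * task_vectors).sum(dim=0)
merged_dev = merged_cpu.to(device)           # same role as p.cuda(0) in forward()
loss = merged_dev.sum()                      # simple scalar objective
loss.backward()

print("grad_fn after transfer:", merged_dev.grad_fn)
print("alpha.grad:", alpha.grad)

# With loss = sum(theta), d(loss)/d(alpha_k) = sum(tau_k).
expected = task_vectors.sum(dim=1)
print("matches analytic gradient:", torch.allclose(alpha.grad, expected))
```

If merged_dev.grad_fn were None, or alpha.grad stayed None after backward(), the transfer would have cut the graph; a populated alpha.grad that matches the analytic gradient means the transfer was recorded by autograd.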
The current flow appears to be:

  1. Compose pretrained_model_parameter + Σ alpha * Taskvector_k on CPU
  2. Move parameters back to GPU using .cuda()
  3. Perform forward/backward passes
  4. Repeat steps 1-3 on every forward call
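The four steps above can be sketched as a toy optimization loop. Everything here is hypothetical (a single linear "layer", random data, Adam with an arbitrary learning rate); it only mirrors the pattern of composing on the CPU, transferring to the compute device, and running forward/backward there, and whether alpha changes between iterations shows whether the optimizer actually received a gradient.

```python
import torch
import torch.nn.functional as F

# Hypothetical toy setup, not the repository's code: one linear layer whose weight
# is pretrained + sum_k alpha_k * tau_k, with all merge inputs kept on the CPU.
torch.manual_seed(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

num_tasks = 2
pretrained = torch.randn(4, 3)                                 # frozen, CPU
task_vectors = [torch.randn(4, 3) for _ in range(num_tasks)]   # frozen, CPU
alpha = torch.full((num_tasks,), 0.3, requires_grad=True)      # learnable, CPU

inp = torch.randn(8, 3, device=device)
target = torch.randn(8, 4, device=device)
optimizer = torch.optim.Adam([alpha], lr=1e-2)
alpha_before = alpha.detach().clone()

for step in range(5):                                          # 4. repeat steps 1-3
    # 1. Compose the merged weight on the CPU.
    merged = pretrained + sum(a * tv for a, tv in zip(alpha, task_vectors))
    # 2. Move it to the compute device (analogous to p.cuda(0)).
    weight = merged.to(device)
    # 3. Forward and backward passes on the device, then update alpha.
    loss = F.mse_loss(inp @ weight.t(), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print("alpha before:", alpha_before.tolist())
print("alpha after: ", alpha.detach().tolist())
```

In this toy, alpha is the only tensor that requires a gradient, so if the transfer detached it from the graph, loss.backward() would raise a RuntimeError instead of silently producing no update.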

This sequence suggests that during backpropagation (.backward()), gradients might not flow back to alpha (or lambdas_raw).
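For reference, the gradient each coefficient should receive can be written with the chain rule. This assumes the formulation from step 1 (each merged parameter tensor equals θ_pre + Σ_k α_k τ_k, with τ_k the task vector for task k), models the device copy c as a value-preserving map whose Jacobian is the identity, and uses p to index the model's parameter tensors and ⟨·,·⟩ for the elementwise inner product:

```latex
\begin{aligned}
\theta^{(p)}_{\text{cpu}} &= \theta^{(p)}_{\text{pre}} + \sum_{k=1}^{K} \alpha_k \, \tau^{(p)}_k,
  \qquad \theta^{(p)}_{\text{gpu}} = c\big(\theta^{(p)}_{\text{cpu}}\big) \\
\frac{\partial \mathcal{L}}{\partial \theta^{(p)}_{\text{cpu}}}
  &= \frac{\partial \mathcal{L}}{\partial \theta^{(p)}_{\text{gpu}}}
  \quad \text{(same values, copied back to the CPU)} \\
\frac{\partial \mathcal{L}}{\partial \alpha_k}
  &= \sum_{p} \Big\langle \frac{\partial \mathcal{L}}{\partial \theta^{(p)}_{\text{gpu}}},\ \tau^{(p)}_k \Big\rangle
\end{aligned}
```

The transfer contributes only the identity factor in the second line, so the question reduces to whether autograd records c in the graph at all; if it does not, the expression in the last line is never evaluated and alpha receives no gradient.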


Questions

In this implementation, it appears that gradients to the learnable alpha parameter might be blocked by the .cpu() → .cuda() operations. Am I missing something in my understanding?

If there's a specific mechanism or technique you've implemented to ensure gradient flow, I would greatly appreciate learning about it.


Thank you again for sharing your excellent research.
I would be very grateful for any insights you can provide when you have the time.
