diff --git a/torchtune/generation/_generation.py b/torchtune/generation/_generation.py index 56fe109e6b..be8f54e659 100644 --- a/torchtune/generation/_generation.py +++ b/torchtune/generation/_generation.py @@ -147,7 +147,7 @@ def get_causal_mask_from_padding_mask( diagonal=0, ).repeat(bsz, 1, 1) mask.narrow(2, 0, seq_len).mul_(padding_mask[:, None, :].expand(-1, seq_len, -1)) - mask.diagonal(dim1=1, dim2=2).copy_(True) + mask.diagonal(dim1=1, dim2=2).copy_(torch.Tensor([True])) return mask