
Conversation

@ClassicLarry
Collaborator

This PR builds on all recent WR improvements, including PR #127. Inspecting trained model weights showed that multiple attention heads were consistently attending to the prior token. However, attention is a computationally inefficient way to attend to the prior token, so this functionality is built in directly below in a lightweight manner. The first 12 dimensions of the residual stream/embed are used to gate both the smear module and attention. Roughly, the model finds that (token + 0.07 * prior_token) is a more useful embedding representation than (token) alone.

Note: This improvement is more marginal than the timing change would indicate. The prior WR had a mean loss of 3.2781. If I attempt to control for loss, the impact of this change appears closer to 5 steps based on testing. I think the main value of this PR is not the timing improvement but the introduction of a new design space for optimizing the architecture: directly modeling close-range information passing between tokens outside of attention.

# smear gate: reads the first 12 embedding channels and outputs one scalar per token
self.smear_gate = CastedLinear(12, 1)
self.smear_gate.weight.detach().zero_()  # zero init, so the gate starts at sigmoid(0) = 0.5

x = self.embed(input_seq)
# smear token embed forward 1 position
smear_lambda = self.scalars[5 * len(self.blocks)]  # learnable scalar controlling the smear strength
# gate on the first 12 channels of each current token, then add a gated copy of its predecessor
smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)]))
x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]])
x = x0 = norm(x[None])
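
For readers without the full training script, here is a minimal self-contained sketch of the same mechanism. It assumes a plain bias-free nn.Linear in place of CastedLinear and a single learnable parameter in place of the self.scalars entry; the class name SmearSketch and the dimensions are illustrative, not the repo's code.

import torch
import torch.nn as nn

class SmearSketch(nn.Module):
    # toy stand-in for the embedding-smear step above; not the actual module
    def __init__(self, model_dim: int = 768, gate_dim: int = 12):
        super().__init__()
        self.smear_gate = nn.Linear(gate_dim, 1, bias=False)  # reads the first 12 channels
        self.smear_gate.weight.detach().zero_()               # gate starts at sigmoid(0) = 0.5
        self.smear_lambda = nn.Parameter(torch.zeros(()))     # stand-in for the scalars entry

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (seq_len, model_dim) token embeddings
        gate_in = x[1:, :self.smear_gate.weight.size(-1)]     # first gate_dim channels of tokens 1..N
        gate = self.smear_lambda * torch.sigmoid(self.smear_gate(gate_in))  # (seq_len - 1, 1)
        # every token except the first receives a gated copy of its predecessor
        return torch.cat([x[:1], x[1:] + gate * x[:-1]])

x = torch.randn(8, 768)
print(SmearSketch()(x).shape)  # torch.Size([8, 768])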

Validation:

import scipy.stats
import torch

accs = [3.2781, 3.2792, 3.2765, 3.2796, 3.2803, 3.2801, 3.2787, 3.2798, 3.2787, 3.2786]  # final validation losses over 10 runs

times = [152.771, 152.816, 152.834, 152.755, 152.789, 152.773, 152.815, 152.796, 152.798, 152.754]  # wall-clock run times over the same 10 runs

print("p=%.4f" % scipy.stats.ttest_1samp(accs, 3.28, alternative="less").pvalue)
# p=0.0084
print("acc:", torch.std_mean(torch.tensor(accs)))
# acc: (tensor(0.0011), tensor(3.2790))

print("time:", torch.std_mean(torch.tensor(times)))
# time: (tensor(0.0269), tensor(152.7901))

@varunneal
Contributor

Great job. These sparse sigmoid gates seem really powerful. How much did you experiment with the inner dim of the gate? E.g. you've chosen 12 for this as well as the attention gate in #117

@ClassicLarry
Collaborator Author

> Great job. These sparse sigmoid gates seem really powerful. How much did you experiment with the inner dim of the gate? E.g. you've chosen 12 for this as well as the attention gate in #117

I only tried size 12. I also tried moving the gate to a different 12 dimensions than the 12 used by attention, but that was no better, or worse.

Other things I tried: putting it after the norm, only applying it to x instead of x0, gating on the prior token instead of the current token (sketched below), gating on both the prior and current tokens, and changing lr_multiple to 0.1 or 2.
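
To make one of those variants concrete, here is a hypothetical sketch of gating on the prior token instead of the current token, reusing the names from the snippet above (this is an ablation, not the merged version):

# hypothetical variant: the gate reads the *prior* token's first 12 channels
gate_in = x[:-1, :self.smear_gate.weight.size(-1)]  # tokens 0..N-1 drive the gate
smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(gate_in))
x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]])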

@trianxy

trianxy commented Sep 19, 2025

Nice job, and thank you @ClassicLarry for adding this note:

> Note: This improvement is more marginal than the timing change would indicate. The prior WR had a mean loss of 3.2781. If I attempt to control for loss, the impact of this change appears closer to 5 steps based on testing.

Such notes are extremely helpful in developing a better understanding of the architecture and the improvements. Sometimes people (especially in ML papers) leave out such ablation studies (or the results thereof), leaving me as the reader with less understanding (or even a mistaken conclusion) than I would have had otherwise.

@YouJiacheng
Contributor

Last year, I tried simple learnable (not data-dependent) 1-token smearing (a.k.a. a window-2 causal conv), following RWKV. It didn't achieve a significant improvement, so I dropped it.
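
For contrast, that input-independent variant reduces to a single learnable scalar, i.e. a width-2 causal convolution along the sequence. A minimal sketch, assuming the same (seq_len, model_dim) layout as above (StaticSmear is an illustrative name, not RWKV's or this repo's code):

import torch
import torch.nn as nn

class StaticSmear(nn.Module):
    # learnable but data-independent smear: every token receives the same
    # fixed fraction of its predecessor (equivalent to a width-2 causal conv)
    def __init__(self):
        super().__init__()
        self.smear_lambda = nn.Parameter(torch.zeros(()))  # one scalar for the whole model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (seq_len, model_dim)
        return torch.cat([x[:1], x[1:] + self.smear_lambda * x[:-1]])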

