New WR 152.7s: Smear token embeddings 1 position forward, -15 steps #130
Conversation
Merge PR 118 …aining, improve skip connection gating, and enhance bfloat16 usage
Great job. These sparse sigmoid gates seem really powerful. How much did you experiment with the inner dim of the gate? E.g. you've chosen …
I only tried size 12. I also tried moving the gate to a different 12 dimensions than the 12 used by attention, but that was no better or slightly worse. Other things I tried: putting the gate after the norm, applying it only to x instead of x0, gating on the prior token instead of the current token, gating on both the prior and current token, and changing lr_multiple to 0.1 or 2.
Nice job, and thank you @ClassicLarry for adding this note
Such notes are extremely helpful in developing a better understanding of the architecture and the improvements. Sometimes people (especially in ML papers) leave out such ablation studies (or the results thereof), leaving me as the reader with less understanding (or even a mistaken conclusion) than I would have had otherwise.
Last year, I tried simple learnable (not data-dependent) 1-token smearing (a.k.a. window=2 causal conv), following RWKV.
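For readers who haven't seen this trick, here is a minimal sketch of a static, learnable 1-token smear (a width-2 causal conv over the sequence). The class and parameter names are illustrative and not taken from RWKV or from this repo:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class StaticSmear(nn.Module):
    """Learnable, data-independent 1-token smear: each position mixes in a
    per-channel learned fraction of the previous token (a width-2 causal conv)."""
    def __init__(self, dim: int):
        super().__init__()
        # one mixing coefficient per channel, initialized to 0 (starts as identity)
        self.mix = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, dim)
        # shift the sequence right by one, so position t sees position t-1 (zeros at t=0)
        prev = F.pad(x, (0, 0, 1, 0))[:, :-1, :]
        return x + self.mix * prev
```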
This PR builds on all recent WR improvements, including PR #127. From inspecting trained model weights, it was observed that multiple attention heads were consistently attending to the prior token. However, attention is a computationally inefficient way to attend to the prior token, so this PR builds that functionality in directly, in a lightweight manner. The first 12 dimensions of the residual stream/embedding are used to gate both the smear module and attention. Roughly, the model finds that (token + 0.07 * prior_token) is a more useful embedding representation than (token) alone.
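A rough sketch of that mechanism, written as a standalone PyTorch module. The names, the scalar gate shape, and the use of a plain sigmoid are my assumptions; the actual implementation in this PR may differ (e.g. in initialization, dtype handling, and how the gate is shared with attention):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedSmear(nn.Module):
    """Data-dependent 1-token smear: mix a gated fraction of the previous token's
    embedding into the current token. The gate is read from a small slice of the
    residual stream (12 dims here, matching the description above)."""
    def __init__(self, gate_dim: int = 12):
        super().__init__()
        self.gate_dim = gate_dim
        # per-token scalar gate computed from the first `gate_dim` channels
        self.gate_proj = nn.Linear(gate_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, dim)
        prev = F.pad(x, (0, 0, 1, 0))[:, :-1, :]              # previous token, zeros at position 0
        gate = torch.sigmoid(self.gate_proj(x[..., :self.gate_dim]))  # (batch, seq_len, 1)
        # a learned gate averaging around 0.07 gives roughly token + 0.07 * prior_token
        return x + gate * prev
```

As described above, the same 12-dimensional slice also gates attention in the PR; this sketch only covers the smear path.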
Note: This improvement is more marginal than the timing change would indicate. The prior WR had a mean loss of 3.2781. If I attempt to control for loss, the impact of this change appears to be closer to 5 steps, based on testing. I think the main value of this PR is not the timing improvement, but the introduction of a new design space for optimizing the architecture: directly modeling close-range information passing between tokens outside of attention.
Validation: