Skip to content

csr matrix multiply in sum layer#19

Open
jjcmoon wants to merge 3 commits intomainfrom
sparse-mm
Open

csr matrix multiply in sum layer#19
jjcmoon wants to merge 3 commits intomainfrom
sparse-mm

Conversation

@jjcmoon
Copy link
Copy Markdown
Member

@jjcmoon jjcmoon commented Mar 5, 2026

  • sum layer can be formulated as sparse matrix multiply.
  • gives small (10% - 20%) speedup.
  • sparse matrix multiply doesn't work on MPS in torch, so there the old scatter reduce is used.

import numpy as np
import jax
import jax.numpy as jnp
import jax.experimental.sparse as jsparse
Copy link
Copy Markdown
Collaborator

@VincentDerk VincentDerk Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The documentation and import indicates it is experimental code, would be good to indicate this when we add a semiring that uses _plain_sum_layer.

Do I understand correctly that right now there is no such semiring in get_semiring, so the sparse code is not used yet?

Looks fine to me.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants