Skip to content

Semiring rework#21

Open
rmanhaeve wants to merge 10 commits intomainfrom
semiring-rework
Open

Semiring rework#21
rmanhaeve wants to merge 10 commits intomainfrom
semiring-rework

Conversation

@rmanhaeve
Copy link
Copy Markdown
Contributor

Refactored the torch layer hierarchy. The main goals were making it easier to define custom semirings and cleaning up the deep method dispatch chains.

Layer hierarchy (before → after):

  • CircuitLayer was a god-class with _scatter_forward, _scatter_backward, _safe_exp, _scatter_logsumexp_forward all baked in. Now it's a thin wrapper around torch.scatter_reduce that just takes a reduce: str.
  • LogSumExpLayer subclasses CircuitLayer and overrides forward — instead of the old runtime if reduce == "logsumexp" branching.
  • ProbabilisticCircuitLayer was also tangled into CircuitLayer. Now it's a separate branch under AbstractCircuitLayer, with ProbabilisticSumLayer and ProbabilisticLogSumLayer as clean subclasses (no more if log_space branching or reduce_fn callables).
  • New GatherCircuitLayer for custom reductions: pads groups into a 2D tensor so users just provide a standard torch reduction (e.g. torch.nanmean) + a fill value. No scatter knowledge needed.

Custom semirings:

  • Semirings are now (sum_reduce, prod_reduce, zero, one, negate) where reduces are strings (for scatter_reduce-backed ops) or callables (for custom gather-based ops). No more passing layer constructors.
  • e.g. ("amin", "sum", float('inf'), 0.0, tropical_negate) for tropical semiring.

Utilities:

  • Extracted scatter_logsumexp and gather_indices into utils.py.

Performance:

  • No regressions, actually ~6-20% faster across the board. The speedup comes from flattening the method dispatch chains (the old code went 5 calls deep for logsumexp) and removing per-forward string comparisons/nan_to_num_ calls.

Tests & benchmarks:

  • 25 tests, 100% coverage on klay.torch.
  • Added tests/test_benchmarks.py using pytest-benchmark for per-commit regression tracking.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant