Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
name: Build

on: [push, pull_request, workflow_dispatch]
on:
push:
tags:
- '*'
pull_request:
workflow_dispatch:

jobs:
build_wheels:
Expand All @@ -15,11 +20,11 @@ jobs:
- uses: actions/checkout@v4

- name: Build wheels
uses: pypa/cibuildwheel@v2.21.3
uses: pypa/cibuildwheel@v3.4.0
env:
CIBW_TEST_REQUIRES: pytest torch==2.6.0, pysdd==1.0.5
CIBW_TEST_REQUIRES: pytest torch>=2.6.0 pysdd>=1.0.5 pytest-benchmark>=5.2.3
CIBW_TEST_COMMAND: "pytest {project}/tests"
CIBW_SKIP: cp36-* cp37-* cp38-* cp313-* pp* *i686 *ppc64le *s390x *win32* *musllinux*
CIBW_SKIP: cp36-* cp37-* cp38-* pp* *i686 *ppc64le *s390x *win32* *musllinux*
CIBW_TEST_SKIP: cp39-macosx_x86_64 cp310-macosx_x86_64 cp311-macosx_x86_64 cp312-macosx_x86_64

- uses: actions/upload-artifact@v4
Expand Down Expand Up @@ -49,6 +54,7 @@ jobs:

upload_wheels:
name: Upload wheels to PyPI
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/')
runs-on: ubuntu-latest
needs: [ build_wheels,build_sdist ]
environment:
Expand Down
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,22 @@ Key options:

Results are saved as JSON files under `results/`.

### Regression benchmarks

Per-commit regression benchmarks use [pytest-benchmark](https://pytest-benchmark.readthedocs.io/). Run from the project root:

```bash
uv run pytest tests/test_benchmarks.py --benchmark-only --benchmark-save=<name>
```

Compare two saved runs:

```bash
uv run pytest-benchmark compare <old> <new> --columns=mean,stddev
```

Results are saved as JSON files under `.benchmarks/`.

## 📃 Paper

If you use KLay in your research, consider citing [our paper](https://openreview.net/pdf?id=Zes7Wyif8G).
Expand Down
81 changes: 81 additions & 0 deletions bench_compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Quick benchmark comparing forward/backward speed across semirings."""
import time
import torch
import klay
from pysdd.sdd import SddManager

def build_circuit(nb_vars=50):
    """Build a KLay circuit from a chained SDD over ``nb_vars`` variables.

    Returns the circuit together with the variable count so callers can
    size their weight tensors to match.
    """
    manager = SddManager(var_count=nb_vars)
    first, second, *rest = list(manager.vars)
    formula = first & second
    for literal in rest:
        formula = formula | (formula & literal)
    circuit = klay.Circuit()
    circuit.add_sdd(formula)
    return circuit, nb_vars

def bench_semiring(circuit, nb_vars, semiring, n_warmup=5, n_runs=50):
    """Time forward and forward+backward passes of ``circuit`` under a semiring.

    :param circuit: object exposing ``to_torch_module`` (a klay.Circuit).
    :param nb_vars: number of input variables (size of the weight vector).
    :param semiring: semiring name forwarded to ``to_torch_module``.
    :param n_warmup: untimed iterations run before each measurement.
    :param n_runs: timed iterations; mean per-call time is reported.
    :returns: tuple ``(fwd_time, bwd_time)`` of mean seconds per call.
    """
    m = circuit.to_torch_module(semiring=semiring)
    weights = torch.rand(nb_vars)

    # warmup (forward only)
    for _ in range(n_warmup):
        m(weights)

    # forward
    t0 = time.perf_counter()
    for _ in range(n_runs):
        m(weights)
    fwd_time = (time.perf_counter() - t0) / n_runs

    # forward + backward. Use a fresh detached leaf each iteration, exactly
    # like the timed loop below: the previous version called
    # weights.requires_grad_(True) in the warmup, mutating the shared tensor
    # in place and accumulating stale gradients into weights.grad across
    # warmup iterations, so warmup and measurement did not run the same code.
    for _ in range(n_warmup):
        w = weights.detach().requires_grad_(True)
        m(w).sum().backward()

    t0 = time.perf_counter()
    for _ in range(n_runs):
        w = weights.detach().requires_grad_(True)
        m(w).sum().backward()
    bwd_time = (time.perf_counter() - t0) / n_runs

    return fwd_time, bwd_time

def bench_probabilistic(circuit, nb_vars, semiring, n_warmup=5, n_runs=50):
    """Time forward and forward+backward passes of the probabilistic module.

    Weights are sampled uniformly and mapped to log-space when the log
    semiring is requested. Returns mean per-call ``(forward, forward+backward)``
    times in seconds.
    """
    module = circuit.to_torch_module(semiring=semiring, probabilistic=True)
    weights = torch.rand(nb_vars)
    if semiring == 'log':
        weights = weights.log()

    # Untimed warmup before the forward measurement.
    for _ in range(n_warmup):
        module(weights)

    start = time.perf_counter()
    for _ in range(n_runs):
        module(weights)
    fwd_time = (time.perf_counter() - start) / n_runs

    # Untimed warmup before the backward measurement.
    for _ in range(n_warmup):
        module(weights.requires_grad_(True)).sum().backward()

    start = time.perf_counter()
    for _ in range(n_runs):
        leaf = weights.detach().requires_grad_(True)
        module(leaf).sum().backward()
    bwd_time = (time.perf_counter() - start) / n_runs

    return fwd_time, bwd_time

if __name__ == "__main__":
circuit, nb_vars = build_circuit(50)
print(f"Circuit with {nb_vars} variables\n")
print(f"{'Semiring':<25} {'Forward (ms)':>12} {'Fwd+Bwd (ms)':>12}")
print("-" * 52)

for sr in ['real', 'log', 'mpe', 'godel']:
fwd, bwd = bench_semiring(circuit, nb_vars, sr)
print(f"{sr:<25} {fwd*1000:>12.3f} {bwd*1000:>12.3f}")

for sr in ['real', 'log']:
fwd, bwd = bench_probabilistic(circuit, nb_vars, sr)
print(f"prob-{sr:<20} {fwd*1000:>12.3f} {bwd*1000:>12.3f}")
12 changes: 11 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@ version = "0.0.4"
description = "Arithmetic circuits on the GPU"
readme = "README.md"
requires-python = ">=3.10"
dependencies = ["numpy"]
dependencies = [
"numpy",
"pysdd>=1.0.6",
"pytest>=9.0.2",
"torch>=2.10.0",
]
authors = [
{ name = "Jaron Maene" },
{ name = "Vincent Derkinderen" },
Expand Down Expand Up @@ -57,3 +62,8 @@ MACOSX_DEPLOYMENT_TARGET = "11.0"

[tool.pytest.ini_options]
pythonpath = ["."]

[dependency-groups]
dev = [
"pytest-benchmark>=5.2.3",
]
11 changes: 4 additions & 7 deletions src/klay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pathlib import Path


def to_torch_module(self: Circuit, semiring: str = "log", probabilistic: bool = False, eps: float = 0):
def to_torch_module(self: Circuit, semiring: str = "log", probabilistic: bool = False):
"""
Convert the circuit into a PyTorch module.

Expand All @@ -18,15 +18,12 @@ def to_torch_module(self: Circuit, semiring: str = "log", probabilistic: bool =
If enabled, construct a probabilistic circuit instead of an arithmetic circuit.
This means the inputs to a sum node are multiplied by a probability, and
we can interpret sum nodes as latent Categorical variables.
:param eps:
Epsilon used by log semiring for numerical stability.
"""
from .torch.circuit_modules import ProbabilisticCircuitModule
from .torch.circuit_modules import CircuitModule
from .torch import ProbabilisticCircuitModule, CircuitModule
indices = self._get_indices()
if probabilistic:
return ProbabilisticCircuitModule(*indices, semiring=semiring, eps=eps)
return CircuitModule(*indices, semiring=semiring, eps=eps)
return ProbabilisticCircuitModule(*indices, semiring=semiring)
return CircuitModule(*indices, semiring=semiring)


def to_jax_function(self: Circuit, semiring: str = "log"):
Expand Down
3 changes: 2 additions & 1 deletion src/klay/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .circuit_modules import CircuitModule, ProbabilisticCircuitModule
from .circuit_module import CircuitModule
from .probabilistic_circuit_module import ProbabilisticCircuitModule

__all__ = [
"CircuitModule",
Expand Down
70 changes: 70 additions & 0 deletions src/klay/torch/circuit_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from functools import partial

import torch
from torch import nn

from .layers import CircuitLayer, LogSumExpLayer, GatherCircuitLayer
from .utils import unroll_ixs, negate_real, log1mexp

_LAYER_CLASSES = {
"logsumexp": LogSumExpLayer,
}


class CircuitModule(nn.Module):
    """Torch module that evaluates a layered arithmetic circuit over a semiring.

    A semiring is selected either by name (a key of ``default_semirings``) or
    given directly as a 5-tuple ``(sum_reduce, prod_reduce, zero, one, negate)``,
    where the reduces are strings (scatter_reduce ops / known layer names) or
    callables, ``zero``/``one`` are the semiring's neutral elements, and
    ``negate`` maps positive literal weights to negative ones.
    """

    # name -> (sum_reduce, prod_reduce, zero, one, negate)
    default_semirings = {
        "real": ("sum", "prod", 0, 1, negate_real),
        "log": ("logsumexp", "sum", float('-inf'), 0, log1mexp),
        "mpe": ("amax", "prod", 0, 1, negate_real),
        "godel": ("amax", "amin", 0, 1, negate_real),
    }

    @staticmethod
    def _make_layer(reduce, fill_value):
        """Create a layer factory from a reduce spec.

        Strings use CircuitLayer (scatter_reduce) or a known layer class.
        Callables use GatherCircuitLayer with fill_value.
        """
        if isinstance(reduce, str):
            if reduce in _LAYER_CLASSES:
                return _LAYER_CLASSES[reduce]
            return partial(CircuitLayer, reduce=reduce)
        return partial(GatherCircuitLayer, reduce_fn=reduce, fill_value=fill_value)

    def __init__(self, ixs_in, ixs_out, semiring: str | tuple = 'real'):
        """
        :param ixs_in: per-layer input index sequences (one entry per layer).
        :param ixs_out: per-layer output index sequences, expanded via
            ``unroll_ixs`` before use.
        :param semiring: semiring name or custom 5-tuple (see class docstring).
        """
        super().__init__()
        self.semiring = semiring
        if isinstance(semiring, str):
            sum_reduce, prod_reduce, self.zero, self.one, self.negate = self.default_semirings[semiring]
        else:
            sum_reduce, prod_reduce, self.zero, self.one, self.negate = semiring

        # Factories, not modules: each circuit layer below is built from these.
        self.sum_layer = self._make_layer(sum_reduce, self.zero)
        self.prod_layer = self._make_layer(prod_reduce, self.one)

        layers = []
        for i, (ix_in, ix_out) in enumerate(zip(ixs_in, ixs_out)):
            ix_in = torch.as_tensor(ix_in, dtype=torch.long)
            ix_out = torch.as_tensor(ix_out, dtype=torch.long)
            ix_out = unroll_ixs(ix_out)
            # Layers alternate product/sum, starting with a product layer.
            layer = self.prod_layer if i % 2 == 0 else self.sum_layer
            layers.append(layer(ix_in, ix_out))
        self.layers = nn.Sequential(*layers)

    def forward(self, x_pos, x_neg=None):
        """Evaluate the circuit on positive (and optionally negative) literal weights."""
        x = self.encode_input(x_pos, x_neg)
        return self.layers(x)

    def encode_input(self, pos, neg):
        """Build the flat input vector consumed by the first layer.

        If ``neg`` is omitted, it is derived from ``pos`` via the semiring's
        negate function. The result is ``[zero, one, pos_0, neg_0, pos_1, ...]``:
        indices 0 and 1 hold the semiring constants, followed by the
        interleaved positive/negative literal weights.
        """
        if neg is None:
            neg = self.negate(pos)
        x = torch.stack([pos, neg], dim=1).flatten()
        units = torch.tensor([self.zero, self.one], dtype=pos.dtype, device=pos.device)
        return torch.cat([units, x])

    def sparsity(self, nb_vars: int) -> float:
        """Return the ratio of actual (sparse) parameters to a dense equivalent.

        The dense baseline counts a full weight matrix between each pair of
        consecutive layer widths, starting from ``nb_vars`` inputs.
        """
        sparse_params = sum(len(layer.ix_out) for layer in self.layers)
        layer_widths = [nb_vars] + [layer.out_shape[0] for layer in self.layers]
        dense_params = sum(layer_widths[i] * layer_widths[i + 1] for i in range(len(layer_widths) - 1))
        return sparse_params / dense_params
87 changes: 0 additions & 87 deletions src/klay/torch/circuit_modules.py

This file was deleted.

Loading
Loading