diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 109dcd8..632f340 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -1,6 +1,11 @@ name: Build -on: [push, pull_request, workflow_dispatch] +on: + push: + tags: + - '*' + pull_request: + workflow_dispatch: jobs: build_wheels: @@ -15,11 +20,11 @@ jobs: - uses: actions/checkout@v4 - name: Build wheels - uses: pypa/cibuildwheel@v2.21.3 + uses: pypa/cibuildwheel@v3.4.0 env: - CIBW_TEST_REQUIRES: pytest torch==2.6.0, pysdd==1.0.5 + CIBW_TEST_REQUIRES: pytest torch>=2.6.0 pysdd>=1.0.5 pytest-benchmark>=5.2.3 CIBW_TEST_COMMAND: "pytest {project}/tests" - CIBW_SKIP: cp36-* cp37-* cp38-* cp313-* pp* *i686 *ppc64le *s390x *win32* *musllinux* + CIBW_SKIP: cp36-* cp37-* cp38-* pp* *i686 *ppc64le *s390x *win32* *musllinux* CIBW_TEST_SKIP: cp39-macosx_x86_64 cp310-macosx_x86_64 cp311-macosx_x86_64 cp312-macosx_x86_64 - uses: actions/upload-artifact@v4 @@ -49,6 +54,7 @@ jobs: upload_wheels: name: Upload wheels to PyPI + if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') runs-on: ubuntu-latest needs: [ build_wheels,build_sdist ] environment: diff --git a/README.md b/README.md index 81ace2a..fa780eb 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,22 @@ Key options: Results are saved as JSON files under `results/`. +### Regression benchmarks + +Per-commit regression benchmarks use [pytest-benchmark](https://pytest-benchmark.readthedocs.io/). Run from the project root: + +```bash +uv run pytest tests/test_benchmarks.py --benchmark-only --benchmark-save= +``` + +Compare two saved runs: + +```bash +uv run pytest-benchmark compare --columns=mean,stddev +``` + +Results are saved as JSON files under `.benchmarks/`. + ## 📃 Paper If you use KLay in your research, consider citing [our paper](https://openreview.net/pdf?id=Zes7Wyif8G). 
def build_circuit(nb_vars=50):
    """Build a klay circuit over ``nb_vars`` SDD variables.

    NOTE(review): ``sdd | (sdd & v)`` equals ``sdd`` by absorption, so the
    SDD library may simplify this to a constant-size circuit — confirm the
    benchmark circuit actually grows with ``nb_vars``.

    :returns: ``(circuit, nb_vars)`` tuple.
    """
    mgr = SddManager(var_count=nb_vars)
    variables = list(mgr.vars)
    sdd = variables[0] & variables[1]
    for v in variables[2:]:
        sdd = sdd | (sdd & v)
    c = klay.Circuit()
    c.add_sdd(sdd)
    return c, nb_vars


def _time_module(module, weights, n_warmup, n_runs):
    """Return mean wall-clock seconds for (forward, forward+backward).

    Uses a fresh detached leaf per backward run so gradients never
    accumulate across iterations and ``weights`` is never mutated in place
    (the original called ``weights.requires_grad_(True)`` on the shared
    tensor during warmup).
    """
    # forward warmup + timing
    for _ in range(n_warmup):
        module(weights)

    t0 = time.perf_counter()
    for _ in range(n_runs):
        module(weights)
    fwd_time = (time.perf_counter() - t0) / n_runs

    # forward + backward warmup + timing
    for _ in range(n_warmup):
        w = weights.detach().requires_grad_(True)
        module(w).sum().backward()

    t0 = time.perf_counter()
    for _ in range(n_runs):
        w = weights.detach().requires_grad_(True)
        module(w).sum().backward()
    bwd_time = (time.perf_counter() - t0) / n_runs

    return fwd_time, bwd_time


def bench_semiring(circuit, nb_vars, semiring, n_warmup=5, n_runs=50):
    """Benchmark an arithmetic-circuit module for one semiring.

    :returns: ``(forward_seconds, forward_backward_seconds)`` means.
    """
    m = circuit.to_torch_module(semiring=semiring)
    return _time_module(m, torch.rand(nb_vars), n_warmup, n_runs)


def bench_probabilistic(circuit, nb_vars, semiring, n_warmup=5, n_runs=50):
    """Benchmark a probabilistic-circuit module for one semiring.

    Log-semiring weights are mapped into log-space, matching the module's
    expected input domain.

    :returns: ``(forward_seconds, forward_backward_seconds)`` means.
    """
    m = circuit.to_torch_module(semiring=semiring, probabilistic=True)
    weights = torch.rand(nb_vars)
    if semiring == 'log':
        weights = weights.log()
    return _time_module(m, weights, n_warmup, n_runs)
print(f"{'Semiring':<25} {'Forward (ms)':>12} {'Fwd+Bwd (ms)':>12}") + print("-" * 52) + + for sr in ['real', 'log', 'mpe', 'godel']: + fwd, bwd = bench_semiring(circuit, nb_vars, sr) + print(f"{sr:<25} {fwd*1000:>12.3f} {bwd*1000:>12.3f}") + + for sr in ['real', 'log']: + fwd, bwd = bench_probabilistic(circuit, nb_vars, sr) + print(f"prob-{sr:<20} {fwd*1000:>12.3f} {bwd*1000:>12.3f}") diff --git a/pyproject.toml b/pyproject.toml index c560753..b42e1c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,12 @@ version = "0.0.4" description = "Arithmetic circuits on the GPU" readme = "README.md" requires-python = ">=3.10" -dependencies = ["numpy"] +dependencies = [ + "numpy", + "pysdd>=1.0.6", + "pytest>=9.0.2", + "torch>=2.10.0", +] authors = [ { name = "Jaron Maene" }, { name = "Vincent Derkinderen" }, @@ -57,3 +62,8 @@ MACOSX_DEPLOYMENT_TARGET = "11.0" [tool.pytest.ini_options] pythonpath = ["."] + +[dependency-groups] +dev = [ + "pytest-benchmark>=5.2.3", +] diff --git a/src/klay/__init__.py b/src/klay/__init__.py index a584f51..51207b8 100644 --- a/src/klay/__init__.py +++ b/src/klay/__init__.py @@ -8,7 +8,7 @@ from pathlib import Path -def to_torch_module(self: Circuit, semiring: str = "log", probabilistic: bool = False, eps: float = 0): +def to_torch_module(self: Circuit, semiring: str = "log", probabilistic: bool = False): """ Convert the circuit into a PyTorch module. @@ -18,15 +18,12 @@ def to_torch_module(self: Circuit, semiring: str = "log", probabilistic: bool = If enabled, construct a probabilistic circuit instead of an arithmetic circuit. This means the inputs to a sum node are multiplied by a probability, and we can interpret sum nodes as latent Categorical variables. - :param eps: - Epsilon used by log semiring for numerical stability. 
# Maps named sum-reductions to dedicated layer classes; every other string
# reduce goes through the generic scatter-based CircuitLayer.
_LAYER_CLASSES = {
    "logsumexp": LogSumExpLayer,
}


class CircuitModule(nn.Module):
    """Torch module evaluating a layerized arithmetic circuit.

    Layers strictly alternate: even indices are product layers, odd indices
    are sum layers (see the ``i % 2`` dispatch in ``__init__``).
    """

    # Each entry: (sum_reduce, prod_reduce, zero, one, negate) — the order
    # matches the tuple unpacking in __init__, so a caller may also pass a
    # custom 5-tuple instead of a name.
    default_semirings = {
        "real": ("sum", "prod", 0, 1, negate_real),
        "log": ("logsumexp", "sum", float('-inf'), 0, log1mexp),
        "mpe": ("amax", "prod", 0, 1, negate_real),
        "godel": ("amax", "amin", 0, 1, negate_real),
    }

    @staticmethod
    def _make_layer(reduce, fill_value):
        """Create a layer factory from a reduce spec.

        Strings use CircuitLayer (scatter_reduce) or a known layer class.
        Callables use GatherCircuitLayer with fill_value.
        """
        if isinstance(reduce, str):
            if reduce in _LAYER_CLASSES:
                return _LAYER_CLASSES[reduce]
            return partial(CircuitLayer, reduce=reduce)
        return partial(GatherCircuitLayer, reduce_fn=reduce, fill_value=fill_value)

    def __init__(self, ixs_in, ixs_out, semiring: str | tuple = 'real'):
        """
        :param ixs_in: per-layer input index arrays (converted to long tensors).
        :param ixs_out: per-layer output index arrays (run through unroll_ixs).
        :param semiring: a name from ``default_semirings`` or a custom
            ``(sum_reduce, prod_reduce, zero, one, negate)`` tuple.
        """
        super().__init__()
        self.semiring = semiring
        if isinstance(semiring, str):
            sum_reduce, prod_reduce, self.zero, self.one, self.negate = self.default_semirings[semiring]
        else:
            sum_reduce, prod_reduce, self.zero, self.one, self.negate = semiring

        self.sum_layer = self._make_layer(sum_reduce, self.zero)
        self.prod_layer = self._make_layer(prod_reduce, self.one)

        layers = []
        for i, (ix_in, ix_out) in enumerate(zip(ixs_in, ixs_out)):
            ix_in = torch.as_tensor(ix_in, dtype=torch.long)
            ix_out = torch.as_tensor(ix_out, dtype=torch.long)
            ix_out = unroll_ixs(ix_out)
            # even layers multiply, odd layers sum
            layer = self.prod_layer if i % 2 == 0 else self.sum_layer
            layers.append(layer(ix_in, ix_out))
        self.layers = nn.Sequential(*layers)

    def forward(self, x_pos, x_neg=None):
        """Evaluate the circuit on positive (and optional negative) literal weights."""
        x = self.encode_input(x_pos, x_neg)
        return self.layers(x)

    def encode_input(self, pos, neg):
        """Encode literal weights as ``[zero, one, pos0, neg0, pos1, neg1, ...]``.

        When ``neg`` is None it is derived from ``pos`` via the semiring's
        negation. The two leading entries are the semiring's neutral
        elements, addressed by fixed indices inside the circuit.
        """
        if neg is None:
            neg = self.negate(pos)
        x = torch.stack([pos, neg], dim=1).flatten()
        units = torch.tensor([self.zero, self.one], dtype=pos.dtype, device=pos.device)
        return torch.cat([units, x])

    def sparsity(self, nb_vars: int) -> float:
        """Ratio of this circuit's edge count to a fully dense layered network
        with the same layer widths (smaller means sparser)."""
        sparse_params = sum(len(layer.ix_out) for layer in self.layers)
        layer_widths = [nb_vars] + [layer.out_shape[0] for layer in self.layers]
        dense_params = sum(layer_widths[i] * layer_widths[i + 1] for i in range(len(layer_widths) - 1))
        return sparse_params / dense_params
.utils import unroll_ixs - - -class CircuitModule(nn.Module): - def __init__(self, ixs_in, ixs_out, semiring: str = 'real', eps: float = 0): - super(CircuitModule, self).__init__() - self.semiring = semiring - self._eps = 0 - - self.sum_layer, self.prod_layer, self.zero, self.one, self.negate = \ - get_semiring(semiring, self.is_probabilistic()) - - layers = [] - for i, (ix_in, ix_out) in enumerate(zip(ixs_in, ixs_out)): - ix_in = torch.as_tensor(ix_in, dtype=torch.long) - ix_out = torch.as_tensor(ix_out, dtype=torch.long) - ix_out = unroll_ixs(ix_out) - layer = self.prod_layer if i % 2 == 0 else self.sum_layer - layers.append(layer(ix_in, ix_out, eps)) - self.layers = nn.Sequential(*layers) - - - def forward(self, x_pos, x_neg=None): - x = self.encode_input(x_pos, x_neg) - return self.layers(x) - - def encode_input(self, pos, neg): - if neg is None: - neg = self.negate(pos, self._eps) - x = torch.stack([pos, neg], dim=1).flatten() - units = torch.tensor([self.zero, self.one], dtype=pos.dtype, device=pos.device) - return torch.cat([units, x]) - - def sparsity(self, nb_vars: int) -> float: - sparse_params = sum(len(layer.ix_out) for layer in self.layers) - layer_widths = [nb_vars] + [layer.out_shape[0] for layer in self.layers] - dense_params = sum(layer_widths[i] * layer_widths[i + 1] for i in range(len(layer_widths) - 1)) - return sparse_params / dense_params - - def to_pc(self, x_pos, x_neg=None): - """ Converts the circuit into a probabilistic circuit.""" - assert self.semiring == "log" or self.semiring == "real" - pc = ProbabilisticCircuitModule([], [], self.semiring) - layers = [] - - x = self.encode_input(x_pos, x_neg) - for i, layer in enumerate(self.layers): - if isinstance(layer, self.sum_layer): - new_layer = pc.sum_layer(layer.ix_in, layer.ix_out, layer._eps) - weights = x.log() if self.semiring == "real" else x - new_layer.weights.data = weights[new_layer.ix_in] - else: - new_layer = layer - x = layer(x) - layers.append(new_layer) - - pc.layers = 
class AbstractCircuitLayer(nn.Module):
    """Base layer holding the sparse edge index buffers of one circuit layer.

    ``ix_in`` selects input nodes per edge; ``ix_out`` assigns each edge to
    an output node. Shapes are derived from the index extrema, so ``ix_out``
    is assumed sorted with its last element the largest — TODO confirm
    against unroll_ixs in the caller.
    """

    def __init__(self, ix_in: torch.Tensor, ix_out: torch.Tensor):
        super().__init__()
        self.register_buffer('ix_in', ix_in)
        self.register_buffer('ix_out', ix_out)
        self.out_shape = (self.ix_out[-1].item() + 1,)
        self.in_shape = (self.ix_in.max().item() + 1,)

    def sample(self, y: torch.Tensor) -> torch.Tensor:
        """Propagate a selection mask backwards: an input node is selected
        when any of its outgoing edges leads to a selected output (amax over
        the scattered edge values)."""
        y = y[self.ix_out]
        output = torch.zeros(self.in_shape, dtype=y.dtype, device=y.device)
        return torch.scatter_reduce(output, 0, index=self.ix_in, src=y, reduce="amax", include_self=False)


class CircuitLayer(AbstractCircuitLayer):
    """Layer computed with a single ``scatter_reduce`` using a named torch
    reduction (``"sum"``, ``"prod"``, ``"amax"``, ``"amin"``, ...)."""

    def __init__(self, ix_in: torch.Tensor, ix_out: torch.Tensor, reduce: str):
        super().__init__(ix_in, ix_out)
        self.reduce = reduce

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        src = x[self.ix_in]
        # include_self=False: the empty output buffer never contributes.
        out = torch.empty(self.out_shape, dtype=src.dtype, device=src.device)
        return torch.scatter_reduce(out, 0, index=self.ix_out, src=src, reduce=self.reduce, include_self=False)


class LogSumExpLayer(CircuitLayer):
    """Numerically stable segment-wise logsumexp layer (log-semiring sum).

    Inherits CircuitLayer with ``reduce="amax"`` only so ``self.reduce`` is
    set; ``forward`` is fully overridden by ``scatter_logsumexp``.
    """

    def __init__(self, ix_in: torch.Tensor, ix_out: torch.Tensor):
        super().__init__(ix_in, ix_out, reduce="amax")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return scatter_logsumexp(x[self.ix_in], self.ix_out, self.out_shape)


class GatherCircuitLayer(AbstractCircuitLayer):
    """Layer for arbitrary callable reductions.

    Edges are gathered into a dense ``(num_outputs, max_group_size)`` tensor
    padded with ``fill_value`` (the semiring's neutral element), then
    ``reduce_fn(..., dim=1)`` collapses each row. Trades memory for the
    ability to use any torch reduction callable.
    """

    def __init__(self, ix_in: torch.Tensor, ix_out: torch.Tensor, reduce_fn: Callable, fill_value: float = 0):
        super().__init__(ix_in, ix_out)
        self.reduce_fn = reduce_fn
        self.fill_value = fill_value
        order, sorted_index, positions, max_len = gather_indices(ix_out, self.out_shape[0])
        self.register_buffer('order', order)
        self.register_buffer('sorted_index', sorted_index)
        self.register_buffer('positions', positions)
        self.max_len = max_len

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        sorted_src = x[self.ix_in][self.order]
        padded = x.new_full((self.out_shape[0], self.max_len), self.fill_value)
        padded[self.sorted_index, self.positions] = sorted_src
        return self.reduce_fn(padded, dim=1)
torch.exp(x), max_output - - def _scatter_logsumexp_forward(self, x: torch.Tensor): - x, max_output = self._safe_exp(x) - output = torch.full(self.out_shape, self._eps, dtype=x.dtype, device=x.device) - output = torch.scatter_add(output, 0, index=self.ix_out, src=x) - output = torch.log(output) + max_output - return output - - def sample(self, y): - return self._scatter_backward(y[self.ix_out], "amax") - - -class SumLayer(CircuitLayer): - def forward(self, x): - return self._scatter_forward(x[self.ix_in], "sum") - - -class ProdLayer(CircuitLayer): - def forward(self, x): - return self._scatter_forward(x[self.ix_in], "prod") - - -class MinLayer(CircuitLayer): - def forward(self, x): - return self._scatter_forward(x[self.ix_in], "amin") - - -class MaxLayer(CircuitLayer): - def forward(self, x): - return self._scatter_forward(x[self.ix_in], "amax") - - -class LogSumLayer(CircuitLayer): - def forward(self, x): - return self._scatter_forward(x[self.ix_in], "logsumexp") - - -class ProbabilisticCircuitLayer(CircuitLayer): - def __init__(self, ix_in, ix_out, eps): - super().__init__(ix_in, ix_out, eps) - self.weights = nn.Parameter(torch.randn_like(ix_in, dtype=torch.float32)) - - def get_edge_weights(self): - exp_weights, _ = self._safe_exp(self.weights) - norm = self._scatter_forward(exp_weights, "sum") - return exp_weights / norm[self.ix_out] - - def renorm_weights(self, x): - with torch.no_grad(): - self.weights.data = self.get_log_edge_weights() + x - - def get_log_edge_weights(self): - norm = self._scatter_logsumexp_forward(self.weights) - return self.weights - norm[self.ix_out] - - def sample(self, y): - weights = self.get_log_edge_weights() - noise = -(-torch.log(torch.rand_like(weights) + self._eps) + self._eps).log() - gumbels = weights + noise - samples = self._scatter_forward(gumbels, "amax") - samples = samples[self.ix_out] == gumbels - samples &= y[self.ix_out].to(torch.bool) - return self._scatter_backward(samples, "sum") > 0 - - -class 
ProbabilisticSumLayer(ProbabilisticCircuitLayer): - def forward(self, x): - x = self.get_edge_weights() * x[self.ix_in] - return self._scatter_forward(x, "sum") - - def condition(self, x): - x2 = self.forward(x) - self.renorm_weights(x[self.ix_in].log()) - return x2 - - -class ProbabilisticLogSumLayer(ProbabilisticCircuitLayer): - def forward(self, x): - x = self.get_log_edge_weights() + x[self.ix_in] - return self._scatter_logsumexp_forward(x) - - def condition(self, x): - y = self.forward(x) - self.renorm_weights(x[self.ix_in]) - return y - - -def get_semiring(name: str, probabilistic: bool): - """ - For a given semiring, returns the sum and product layer, - the zero and one elements, and a negation function. - """ - if probabilistic: - if name == "real": - return ProbabilisticSumLayer, ProdLayer, 0, 1, negate_real - if name == "log": - return ProbabilisticLogSumLayer, SumLayer, float('-inf'), 0, log1mexp - raise ValueError(f"Unknown probabilistic semiring {name}") - else: - if name == "real": - return SumLayer, ProdLayer, 0, 1, negate_real - elif name == "log": - return LogSumLayer, SumLayer, float('-inf'), 0, log1mexp - elif name == "mpe": - return MaxLayer, ProdLayer, 0, 1, negate_real - elif name == "godel": - return MaxLayer, MinLayer, 0, 1, negate_real - raise ValueError(f"Unknown semiring {name}") +SumLayer = partial(CircuitLayer, reduce="sum") +ProdLayer = partial(CircuitLayer, reduce="prod") +MinLayer = partial(CircuitLayer, reduce="amin") +MaxLayer = partial(CircuitLayer, reduce="amax") diff --git a/src/klay/torch/probabilistic_circuit_module.py b/src/klay/torch/probabilistic_circuit_module.py new file mode 100644 index 0000000..94fa904 --- /dev/null +++ b/src/klay/torch/probabilistic_circuit_module.py @@ -0,0 +1,63 @@ +import torch + +from .circuit_module import CircuitModule +from .probabilistic_layers import ProbabilisticCircuitLayer, ProbabilisticSumLayer, ProbabilisticLogSumLayer + +_PROB_LAYER_CLASSES = { + "sum": ProbabilisticSumLayer, + 
class ProbabilisticCircuitModule(CircuitModule):
    """Circuit whose sum layers carry learnable, normalized edge weights,
    making each sum node interpretable as a latent Categorical variable.

    Only named semirings with a probabilistic sum layer are supported
    (see ``_PROB_LAYER_CLASSES``).
    """

    def __init__(self, ixs_in, ixs_out, semiring: str = 'real'):
        """
        :param ixs_in: per-layer input index arrays (as in CircuitModule).
        :param ixs_out: per-layer output index arrays.
        :param semiring: a named semiring whose sum reduction appears in
            ``_PROB_LAYER_CLASSES``; custom tuples are rejected.
        :raises ValueError: if ``semiring`` is not a string.
        """
        if not isinstance(semiring, str):
            raise ValueError(f"ProbabilisticCircuitModule only supports named semirings {list(_PROB_LAYER_CLASSES)}, got {semiring!r}")
        super().__init__(ixs_in, ixs_out, semiring)
        sum_reduce = self.default_semirings[semiring][0]
        self.sum_layer = _PROB_LAYER_CLASSES[sum_reduce]
        # Rebuild sum layers as probabilistic
        layers = []
        for i, layer in enumerate(self.layers):
            if i % 2 == 1:
                layers.append(self.sum_layer(layer.ix_in, layer.ix_out))
            else:
                layers.append(layer)
        self.layers = torch.nn.Sequential(*layers)

    def sample(self):
        """ Samples from the probabilistic circuit distribution. """
        # Start from a selected root and walk the layers in reverse;
        # y[2::2] strips the two unit entries and keeps positive literals.
        y = torch.tensor([1])
        for layer in reversed(self.layers):
            y = layer.sample(y)
        return y[2::2]

    def condition(self, x_pos, x_neg):
        """Condition the distribution on evidence ``(x_pos, x_neg)``:
        probabilistic layers renormalize their weights in place while the
        evidence is propagated; returns the evidence's circuit value."""
        x = self.encode_input(x_pos, x_neg)
        for layer in self.layers:
            x = layer.condition(x) \
                if isinstance(layer, ProbabilisticCircuitLayer) \
                else layer(x)
        return x

    @staticmethod
    def from_circuit(circuit: CircuitModule, x_pos, x_neg=None):
        """ Converts the circuit into a probabilistic circuit.

        Sum-layer weights are initialized from the given literal weights so
        that each edge weight reflects its input's (log-)value; only the
        ``"log"`` and ``"real"`` semirings are supported.
        """
        assert circuit.semiring == "log" or circuit.semiring == "real"
        pc = ProbabilisticCircuitModule([], [], circuit.semiring)
        layers = []

        x = circuit.encode_input(x_pos, x_neg)
        for i, layer in enumerate(circuit.layers):
            if i % 2 == 1:  # sum layers are at odd indices
                new_layer = pc.sum_layer(layer.ix_in, layer.ix_out)
                # weights are stored in log-space regardless of semiring
                weights = x.log() if circuit.semiring == "real" else x
                new_layer.weights.data = weights[new_layer.ix_in]
            else:
                new_layer = layer
            x = layer(x)
            layers.append(new_layer)

        pc.layers = torch.nn.Sequential(*layers)
        return pc
b/src/klay/torch/probabilistic_layers.py @@ -0,0 +1,56 @@ +import torch +from torch import nn + +from .layers import AbstractCircuitLayer +from .utils import scatter_logsumexp + + +class ProbabilisticCircuitLayer(AbstractCircuitLayer): + def __init__(self, ix_in, ix_out): + super().__init__(ix_in, ix_out) + self.weights = nn.Parameter(torch.randn_like(ix_in, dtype=torch.float32)) + + def get_edge_weights(self): + return self.get_log_edge_weights().exp() + + def get_log_edge_weights(self): + norm = scatter_logsumexp(self.weights, self.ix_out, self.out_shape) + return self.weights - norm[self.ix_out] + + def renorm_weights(self, x): + with torch.no_grad(): + self.weights.data = self.get_log_edge_weights() + x + + def sample(self, y): + weights = self.get_log_edge_weights() + noise = -(-torch.log(torch.rand_like(weights))).log() + gumbels = weights + noise + max_vals = torch.full(self.out_shape, float('-inf'), dtype=gumbels.dtype, device=gumbels.device) + max_vals = torch.scatter_reduce(max_vals, 0, index=self.ix_out, src=gumbels.detach(), reduce="amax", include_self=False) + samples = max_vals[self.ix_out] == gumbels + samples &= y[self.ix_out].to(torch.bool) + result = torch.zeros(self.in_shape, dtype=torch.long, device=y.device) + result = torch.scatter_reduce(result, 0, index=self.ix_in, src=samples.long(), reduce="sum", include_self=False) + return result > 0 + + +class ProbabilisticSumLayer(ProbabilisticCircuitLayer): + def forward(self, x): + weighted = self.get_edge_weights() * x[self.ix_in] + out = torch.zeros(self.out_shape, dtype=weighted.dtype, device=x.device) + return torch.scatter_add(out, 0, index=self.ix_out, src=weighted) + + def condition(self, x): + y = self.forward(x) + self.renorm_weights(x[self.ix_in].log()) + return y + + +class ProbabilisticLogSumLayer(ProbabilisticCircuitLayer): + def forward(self, x): + return scatter_logsumexp(self.get_log_edge_weights() + x[self.ix_in], self.ix_out, self.out_shape) + + def condition(self, x): + y = 
def negate_real(x):
    """Negation in the real semiring: maps a weight p to 1 - p."""
    return 1 - x


def gather_indices(index: torch.Tensor, num_groups: int):
    """Compute indices for gathering sparse values into a dense 2D
    (num_groups, max_group_size) tensor.

    :param index: 1-D long tensor assigning each element to a group.
    :param num_groups: total number of groups (rows of the dense tensor).
    :returns: ``(order, sorted_index, positions, max_len)`` where ``order``
        stably sorts elements by group, ``sorted_index`` is the group id per
        sorted element, ``positions`` the column of each sorted element
        within its group, and ``max_len`` the largest group size.
    """
    counts = torch.bincount(index, minlength=num_groups)
    max_len = counts.max().item()
    order = index.argsort(stable=True)
    sorted_index = index[order]
    # offsets[g] = start position of group g in the sorted order
    offsets = counts.cumsum(0).roll(1)
    offsets[0] = 0
    # fix: allocate the arange on the same device as `index`, otherwise
    # construction from CUDA index tensors fails with a device mismatch
    positions = torch.arange(len(index), dtype=torch.long, device=index.device) - offsets[sorted_index]
    return order, sorted_index, positions, max_len


def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, out_shape: tuple) -> torch.Tensor:
    """Numerically stable segment-wise logsumexp.

    Computes ``out[g] = log(sum(exp(src[index == g])))`` via the standard
    max-shift trick. The max is detached so gradients flow only through the
    exp-sum term.
    """
    max_vals = torch.full(out_shape, float('-inf'), dtype=src.dtype, device=src.device)
    max_vals = torch.scatter_reduce(max_vals, 0, index=index, src=src.detach(), reduce="amax", include_self=True)
    # fix: when every element of a group is -inf (the log-semiring zero),
    # src - max is (-inf) - (-inf) = nan; map it to 0 so the group result
    # is log(count) + (-inf) = -inf instead of nan (matches the old
    # _safe_exp behavior in the removed layers.py).
    shifted = (src - max_vals[index]).nan_to_num(nan=0.0, posinf=float('inf'), neginf=float('-inf'))
    exp_sum = torch.zeros(out_shape, dtype=src.dtype, device=src.device)
    exp_sum = torch.scatter_add(exp_sum, 0, index=index, src=shifted.exp())
    return torch.log(exp_sum) + max_vals
@@ +"""Benchmarks for circuit evaluation. Run with: uv run pytest tests/test_benchmarks.py --benchmark-only""" +import pytest + +pytest.importorskip("torch") +pytest.importorskip("pysdd") + +import torch +import klay +from pysdd.sdd import SddManager + + +# ── Fixtures ────────────────────────────────────────────────────────────────── + +@pytest.fixture(scope="module") +def small_circuit(): + """Small circuit: 10 variables.""" + mgr = SddManager(var_count=10) + vs = list(mgr.vars) + sdd = vs[0] & vs[1] + for v in vs[2:]: + sdd = sdd | (sdd & v) + c = klay.Circuit() + c.add_sdd(sdd) + return c, 10 + + +@pytest.fixture(scope="module") +def medium_circuit(): + """Medium circuit: 50 variables.""" + mgr = SddManager(var_count=50) + vs = list(mgr.vars) + sdd = vs[0] & vs[1] + for v in vs[2:]: + sdd = sdd | (sdd & v) + c = klay.Circuit() + c.add_sdd(sdd) + return c, 50 + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + +def _forward(module, weights): + module(weights) + + +def _forward_backward(module, weights): + w = weights.detach().requires_grad_(True) + module(w).sum().backward() + + +# ── Non-probabilistic benchmarks ───────────────────────────────────────────── + +@pytest.mark.parametrize("semiring", ["real", "log", "mpe", "godel"]) +def test_forward_small(benchmark, small_circuit, semiring): + circuit, nb_vars = small_circuit + m = circuit.to_torch_module(semiring=semiring) + weights = torch.rand(nb_vars) + benchmark(_forward, m, weights) + + +@pytest.mark.parametrize("semiring", ["real", "log", "mpe", "godel"]) +def test_forward_backward_small(benchmark, small_circuit, semiring): + circuit, nb_vars = small_circuit + m = circuit.to_torch_module(semiring=semiring) + weights = torch.rand(nb_vars) + benchmark(_forward_backward, m, weights) + + +@pytest.mark.parametrize("semiring", ["real", "log", "mpe", "godel"]) +def test_forward_medium(benchmark, medium_circuit, semiring): + circuit, nb_vars = medium_circuit + m = 
circuit.to_torch_module(semiring=semiring) + weights = torch.rand(nb_vars) + benchmark(_forward, m, weights) + + +@pytest.mark.parametrize("semiring", ["real", "log", "mpe", "godel"]) +def test_forward_backward_medium(benchmark, medium_circuit, semiring): + circuit, nb_vars = medium_circuit + m = circuit.to_torch_module(semiring=semiring) + weights = torch.rand(nb_vars) + benchmark(_forward_backward, m, weights) + + +# ── Probabilistic benchmarks ───────────────────────────────────────────────── + +@pytest.mark.parametrize("semiring", ["real", "log"]) +def test_prob_forward_medium(benchmark, medium_circuit, semiring): + circuit, nb_vars = medium_circuit + m = circuit.to_torch_module(semiring=semiring, probabilistic=True) + weights = torch.rand(nb_vars) + if semiring == "log": + weights = weights.log() + benchmark(_forward, m, weights) + + +@pytest.mark.parametrize("semiring", ["real", "log"]) +def test_prob_forward_backward_medium(benchmark, medium_circuit, semiring): + circuit, nb_vars = medium_circuit + m = circuit.to_torch_module(semiring=semiring, probabilistic=True) + weights = torch.rand(nb_vars) + if semiring == "log": + weights = weights.log() + benchmark(_forward_backward, m, weights) + + +# ── Sampling benchmark ─────────────────────────────────────────────────────── + +def test_prob_sample_medium(benchmark, medium_circuit): + circuit, nb_vars = medium_circuit + m = circuit.to_torch_module(semiring='real', probabilistic=True) + weights = torch.rand(nb_vars) + m(weights) # initialize + benchmark(m.sample) diff --git a/tests/test_manual.py b/tests/test_manual.py index fc3058d..b028e2c 100644 --- a/tests/test_manual.py +++ b/tests/test_manual.py @@ -1,5 +1,7 @@ import pytest +from klay.torch import ProbabilisticCircuitModule + pytest.importorskip("torch") pytest.importorskip("pysdd") @@ -48,7 +50,7 @@ def test_create_pc(): c.set_root(and_node) m = c.to_torch_module(semiring='real') - m = m.to_pc(torch.tensor([0.4, 0.8, 0.5])) + m = 
ProbabilisticCircuitModule.from_circuit(m, torch.tensor([0.4, 0.8, 0.5])) edge_weights = m.layers[1].get_edge_weights() expected_weights = torch.tensor([2/3, 1/3, 2/7, 5/7]) assert torch.allclose(edge_weights, expected_weights) @@ -163,6 +165,102 @@ def test_sdd_literal(): assert torch.allclose(m(weights), expected) +def test_custom_semiring_tropical(): + """Test CircuitModule with a manually defined tropical (min-plus) semiring. + + Tropical semiring: ⊕ = min, ⊗ = +, zero = +∞, one = 0. + For circuit AND(OR(l1, l2), OR(l2, l3)) with costs [1, 2, 3]: + OR(l1, l2) = min(1, 2) = 1 + OR(l2, l3) = min(2, 3) = 2 + AND(...) = 1 + 2 = 3 + """ + def tropical_negate(x): + return -x + + tropical_semiring = ("amin", "sum", float('inf'), 0.0, tropical_negate) + + c = klay.Circuit() + l1, l2, l3 = c.literal_node(1), c.literal_node(2), c.literal_node(3) + or1 = c.or_node([l1, l2]) + or2 = c.or_node([l2, l3]) + c.set_root(c.and_node([or1, or2])) + + m = c.to_torch_module(semiring=tropical_semiring) + costs = torch.tensor([1.0, 2.0, 3.0]) + result = m(costs) + + expected = torch.tensor([3.0]) + assert torch.allclose(result, expected), f"Expected {expected}, got {result}" + + +def test_custom_callable_layer(): + """GatherCircuitLayer accepts a standard reduction + fill_value.""" + from klay.torch.layers import GatherCircuitLayer + + # Three groups of unequal size: [0,1,2] -> 0, [3,4] -> 1, [5] -> 2 + ix_in = torch.tensor([0, 1, 2, 3, 4, 5]) + ix_out = torch.tensor([0, 0, 0, 1, 1, 2]) + layer = GatherCircuitLayer(ix_in, ix_out, reduce_fn=torch.nanmean, fill_value=float('nan')) + + x = torch.tensor([1.0, 3.0, 8.0, 2.0, 6.0, 5.0]) + result = layer(x) + expected = torch.tensor([4.0, 4.0, 5.0]) # mean([1,3,8])=4, mean([2,6])=4, mean([5])=5 + assert torch.allclose(result, expected) + + +def test_custom_callable_semiring(): + """CircuitModule accepts a semiring with a custom callable reduction.""" + from klay.torch.utils import log1mexp + + c = klay.Circuit() + l1, l2, l3 = 
c.literal_node(1), c.literal_node(-2), c.literal_node(3) + or1 = c.or_node([l1, l2]) + or2 = c.or_node([l2, l3]) + c.set_root(c.and_node([or1, or2])) + + # Custom callable (torch.logsumexp) for sum, string for prod + semiring = (torch.logsumexp, "sum", float('-inf'), 0.0, log1mexp) + m = c.to_torch_module(semiring=semiring) + + weights = torch.tensor([0.4, 0.8, 0.5]) + expected = c.to_torch_module(semiring='log')(weights.log()).exp() + assert torch.allclose(m(weights.log()).exp(), expected) + + +def test_probabilistic_rejects_custom_semiring(): + c = klay.Circuit() + l1, l2 = c.literal_node(1), c.literal_node(-2) + c.set_root(c.or_node([l1, l2])) + with pytest.raises(ValueError, match="only supports named semirings"): + c.to_torch_module(semiring=("sum", "prod", 0, 1, lambda x: 1 - x), probabilistic=True) + + +def test_sparsity(): + c = klay.Circuit() + l1, l2, l3 = c.literal_node(1), c.literal_node(2), c.literal_node(3) + or1 = c.or_node([l1, l2]) + or2 = c.or_node([l2, l3]) + c.set_root(c.and_node([or1, or2])) + + m = c.to_torch_module(semiring='real') + assert 0 < m.sparsity(3) <= 1 + + +def test_log_pc_conditioning(): + c = klay.Circuit() + p1, p2 = c.literal_node(1), c.literal_node(2) + n1, n2 = c.literal_node(-1), c.literal_node(-2) + and_node1 = c.and_node([p1, p2]) + and_node2 = c.and_node([n1, n2]) + or_node = c.or_node([and_node1, and_node2]) + c.set_root(or_node) + + m = c.to_torch_module(semiring='log', probabilistic=True) + m.condition(torch.tensor([0.0, 0.0]), torch.tensor([0.0, float('-inf')])) + for _ in range(20): + assert torch.allclose(m.sample(), torch.tensor([True, True])) + + def test_sdd_multiroot(): sdd_mgr = SddManager(var_count=2) a, b = sdd_mgr.vars