Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
dd8d8c4
Add file for SOAP kernel. Add new Metric
gsallaberry Jun 16, 2025
fab4e97
Add numpy implementation of _pairwise_similarity and _crosswise_simil…
gsallaberry Jun 16, 2025
09cc1b3
Add files to mirror shear_kernel backend
gsallaberry Jun 16, 2025
be1fa61
placeholder functions for soap kernel
gsallaberry Jun 16, 2025
6256f38
Base setup for SOAP Kernel
gsallaberry Jun 25, 2025
0293743
Restructured SOAP kernel implementation for consistency
gsallaberry Jun 26, 2025
504cbfb
Fix style so CI stops yelling at me
gsallaberry Jun 27, 2025
8a826d8
correct Kout shape
gsallaberry Jul 1, 2025
e5a372e
Tensor verification test framework
gsallaberry Jul 2, 2025
6c77862
add np.pad to _src.math.numpy.py
gsallaberry Jul 2, 2025
86fd189
Initial tests implemented but failing
gsallaberry Jul 2, 2025
7570dc5
Format contributed files with black
gsallaberry Jul 2, 2025
2e5ea68
Similarity tensor tests passing
gsallaberry Jul 2, 2025
32e1183
add raise NotImplementedError to unsupported backends
gsallaberry Jul 2, 2025
709301e
Add kernel tests
gsallaberry Jul 7, 2025
b79c864
Add numpy functions needed for SOAP tests
gsallaberry Jul 9, 2025
7467664
Add Jared's implementation to tests
gsallaberry Jul 9, 2025
0bca546
Change similarity tensor math to broadcasting. Updated test
gsallaberry Jul 22, 2025
9228db8
Update similarity tensor calculation to produce correct organization …
gsallaberry Jul 25, 2025
53c48f1
updated contributions to black 80-width standards
bwpriest Jul 28, 2025
a727674
fixed flake8 issues by commenting out some unused variables.
bwpriest Jul 28, 2025
2c8df46
Merge branch 'develop' into iap_develop
bwpriest Jul 28, 2025
b723cca
Add baseline functions to _test
gsallaberry Aug 7, 2025
1dba211
Merge branch 'iap_develop' of github.com:gsallaberry/MuyGPyS into iap…
gsallaberry Aug 7, 2025
40b3a8a
More shape experimentation
gsallaberry Aug 11, 2025
ab885e2
Separate Knm from _soap_fn
gsallaberry Sep 2, 2025
de4c958
Change similarity tensor creation to einsum
gsallaberry Sep 2, 2025
33cc6f9
add ndenumerate to numpy imports for tests
gsallaberry Sep 11, 2025
5716ab1
update tests
gsallaberry Sep 11, 2025
f64b108
Tutorial notebook
gsallaberry Sep 11, 2025
5ed7a90
Merge branch 'develop' into iap_develop
gsallaberry Sep 11, 2025
a196599
add an _out_similatiry for nontrivial Kout
gsallaberry Sep 12, 2025
09bbe58
Add infrastructure for Kout
gsallaberry Sep 12, 2025
dba5674
Syntax and style fixes
Sep 23, 2025
d1a8631
Modify Kout to match new workflow
Sep 23, 2025
2c9ba9e
Add imports for test functions
Sep 23, 2025
50b2dd9
import for Kout
Sep 23, 2025
2b83d1a
Add jax support for SOAP related fns
Sep 24, 2025
f446929
Add NotImplemented errors to kernel backends
Sep 24, 2025
d95b594
Raise correct errors for non-implemented backends
Sep 25, 2025
5bd26d0
fix typos
Sep 25, 2025
0906972
Example notebook
Sep 25, 2025
a5a04c5
Test infrastructure
Sep 26, 2025
37f50f5
Typo corrections
Sep 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
586 changes: 586 additions & 0 deletions experimental/SOAP_kernel.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions experimental/shear_2x3_offset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "gpshear_env",
"language": "python",
"name": "python3"
},
Expand All @@ -1019,7 +1019,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.9.18"
}
},
"nbformat": 4,
Expand Down
4 changes: 2 additions & 2 deletions experimental/shear_kernel.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1881,7 +1881,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "gp4iap",
"language": "python",
"name": "python3"
},
Expand All @@ -1895,7 +1895,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.11.7"
}
},
"nbformat": 4,
Expand Down
8 changes: 8 additions & 0 deletions src/MuyGPyS/_src/gp/kernels/soap/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Copyright 2021-2024 Lawrence Livermore National Security, LLC and other
# MuyGPyS Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: MIT

from MuyGPyS._src.util import _collect_implementation

(_soap_fn,) = _collect_implementation("MuyGPyS._src.gp.kernels.soap", "_soap_fn")
28 changes: 28 additions & 0 deletions src/MuyGPyS/_src/gp/kernels/soap/jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright 2021-2024 Lawrence Livermore National Security, LLC and other
# MuyGPyS Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: MIT


def _omega(diffs, sensitivity):
raise NotImplementedError("JAX backend not yet supported for SOAPKernel")


def _T1(diffs, sensitivity):
raise NotImplementedError("JAX backend not yet supported for SOAPKernel")


def _T2(diffs, sensitivity):
raise NotImplementedError("JAX backend not yet supported for SOAPKernel")


def _T3(diffs, sensitivity):
raise NotImplementedError("JAX backend not yet supported for SOAPKernel")


def _Knm(omega, T1, T2, T3, sensitivity):
raise NotImplementedError("JAX backend not yet supported for SOAPKernel")


def _soap_fn(diffs, sensitivity):
raise NotImplementedError("JAX backend not yet supported for SOAPKernel")
28 changes: 28 additions & 0 deletions src/MuyGPyS/_src/gp/kernels/soap/mpi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright 2021-2024 Lawrence Livermore National Security, LLC and other
# MuyGPyS Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: MIT


def _omega(diffs, sensitivity):
raise NotImplementedError("MPI backend not yet supported for SOAPKernel")


def _T1(diffs, sensitivity):
raise NotImplementedError("MPI backend not yet supported for SOAPKernel")


def _T2(diffs, sensitivity):
raise NotImplementedError("MPI backend not yet supported for SOAPKernel")


def _T3(diffs, sensitivity):
raise NotImplementedError("MPI backend not yet supported for SOAPKernel")


def _Knm(omega, T1, T2, T3, sensitivity):
raise NotImplementedError("MPI backend not yet supported for SOAPKernel")


def _soap_fn(diffs, sensitivity):
raise NotImplementedError("MPI backend not yet supported for SOAPKernel")
66 changes: 66 additions & 0 deletions src/MuyGPyS/_src/gp/kernels/soap/numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2021-2024 Lawrence Livermore National Security, LLC and other
# MuyGPyS Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: MIT

import MuyGPyS._src.math.numpy as mm


def _omega(diffs) -> mm.ndarray:
ndim = diffs.ndim
slicer = [slice(None)] * ndim
slicer[-3] = 0

qq_slice = diffs[tuple(slicer)]

return qq_slice


def _T1(diffs) -> mm.ndarray:
ndim = diffs.ndim
slicer = [slice(None)] * ndim
slicer[-3] = 3

dd_slice = diffs[tuple(slicer)]

return dd_slice


def _T2(diffs) -> mm.ndarray:
ndim = diffs.ndim
slicer = [slice(None)] * ndim
slicer[-3] = 1

diq_slice = diffs[tuple(slicer)]

return diq_slice


def _T3(diffs) -> mm.ndarray:
ndim = diffs.ndim
slicer = [slice(None)] * ndim
slicer[-3] = 2

djq_slice = diffs[tuple(slicer)]

return djq_slice


def _Knm(omega, T1, T2, T3, sensitivity) -> mm.ndarray:
Knm = (sensitivity - 1.0) * (omega ** (sensitivity - 2.0)) * (T2 * T3) + (
omega ** (sensitivity - 1.0)
) * T1
return Knm


def _soap_fn(diffs: mm.ndarray, sensitivity: float, **kwargs) -> mm.ndarray:
omega = _omega(diffs)
T1 = _T1(diffs)
T2 = _T2(diffs)
T3 = _T3(diffs)

Knm = _Knm(omega, T1, T2, T3, sensitivity)

Kij = sensitivity * mm.sum(Knm, axis=(-2, -1))

return Kij
28 changes: 28 additions & 0 deletions src/MuyGPyS/_src/gp/kernels/soap/torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright 2021-2024 Lawrence Livermore National Security, LLC and other
# MuyGPyS Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: MIT


def _omega(diffs, sensitivity):
raise NotImplementedError("MPI backend not yet supported for SOAPKernel")


def _T1(diffs, sensitivity):
raise NotImplementedError("MPI backend not yet supported for SOAPKernel")


def _T2(diffs, sensitivity):
raise NotImplementedError("MPI backend not yet supported for SOAPKernel")


def _T3(diffs, sensitivity):
raise NotImplementedError("MPI backend not yet supported for SOAPKernel")


def _Knm(omega, T1, T2, T3, sensitivity):
raise NotImplementedError("MPI backend not yet supported for SOAPKernel")


def _soap_fn(diffs, sensitivity):
raise NotImplementedError("MPI backend not yet supported for SOAPKernel")
6 changes: 6 additions & 0 deletions src/MuyGPyS/_src/gp/tensors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
_make_fast_predict_tensors,
_batch_features_tensor,
_crosswise_differences,
_crosswise_similarity,
_out_similarity,
_crosswise_tensor,
_pairwise_differences,
_pairwise_similarity,
_pairwise_tensor,
_fast_nn_update,
_make_heteroscedastic_tensor,
Expand All @@ -21,8 +24,11 @@
"_make_fast_predict_tensors",
"_batch_features_tensor",
"_crosswise_differences",
"_crosswise_similarity",
"_out_similarity",
"_crosswise_tensor",
"_pairwise_differences",
"_pairwise_similarity",
"_pairwise_tensor",
"_fast_nn_update",
"_make_heteroscedastic_tensor",
Expand Down
26 changes: 26 additions & 0 deletions src/MuyGPyS/_src/gp/tensors/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,32 @@ def _pairwise_tensor(
return _pairwise_differences(points)


@jit
def _crosswise_similarity(
data: jnp.ndarray,
nn_data: jnp.ndarray,
data_indices: jnp.ndarray,
nn_indices: jnp.ndarray,
):
raise NotImplementedError("JAX backend not yet supported for similarity tensors")


@jit
def _pairwise_similarity(
data: jnp.ndarray,
nn_indices: jnp.ndarray,
):
raise NotImplementedError("JAX backend not yet supported for similarity tensors")


@jit
def _out_similarity(
data: jnp.ndarray,
data_indices: jnp.ndarray
):
raise NotImplementedError("JAX backend not yet supported for similarity tensors")


@jit
def _F2(diffs: jnp.ndarray) -> jnp.ndarray:
return jnp.sum(diffs**2, axis=-1)
Expand Down
23 changes: 23 additions & 0 deletions src/MuyGPyS/_src/gp/tensors/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,29 @@ def _pairwise_differences(points: np.ndarray) -> np.ndarray:
)


def _crosswise_similarity(
data: np.ndarray,
nn_data: np.ndarray,
data_indices: np.ndarray,
nn_indices: np.ndarray,
):
raise NotImplementedError("MPI backend not yet supported for similarity tensors")


def _pairwise_similarity(
data: np.ndarray,
nn_indices: np.ndarray,
):
raise NotImplementedError("MPI backend not yet supported for similarity tensors")


def _out_similarity(
data: np.ndarray,
data_indices: np.ndarray
):
raise NotImplementedError("MPI backend not yet supported for similarity tensors")


def _fast_nn_update(
train_nn_indices: np.ndarray,
) -> np.ndarray:
Expand Down
60 changes: 60 additions & 0 deletions src/MuyGPyS/_src/gp/tensors/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,66 @@ def _pairwise_differences(points: np.ndarray) -> np.ndarray:
raise ValueError(f"points shape {points.shape} is not supported.")


def _crosswise_similarity(
data: np.ndarray,
nn_data: np.ndarray,
data_indices: np.ndarray,
nn_indices: np.ndarray,
) -> np.ndarray:
locations = data[data_indices]
points = nn_data[nn_indices].swapaxes(2, 1)

# working implementation without einsum
# dot = np.sum(
# locations[:, None, :, None, :, None, :, None, :]
# * points[:, :, None, :, None, :, None, :, :],
# axis=-1,
# ).swapaxes(1, 2)
# shape = dot.shape

# locations.shape = (i, x, d, a, q)
# points.shape = (i, y, k, e, b, q)
dot = np.einsum("ixdaq, iykebq -> iykxdeab", locations, points)

crosswise_similarity = dot.reshape(*dot.shape[:-4], -1, *dot.shape[-2:])

return crosswise_similarity


def _pairwise_similarity(
data: np.ndarray,
nn_indices: np.ndarray,
) -> np.ndarray:
points = data[nn_indices].swapaxes(2, 1)

# working implementation without einsum
# dot = np.sum(
# points[:, :, :, None, None, :, None, :, None, :]
# * points[:, None, None, :, :, None, :, None, :, :],
# axis=-1,
# )

# points.shape=(i, x, k, d, a, q) / (i, y, l, e, b, q)
dot = np.einsum("ixkdaq,iylebq->ixkyldeab", points, points)

pairwise_similarity = dot.reshape(*dot.shape[:5], -1, *dot.shape[-2:])

return pairwise_similarity


def _out_similarity(
data: np.ndarray,
data_indices: np.ndarray
) -> np.ndarray:
points = data[data_indices]

dot = np.einsum("ixdaq, iyebq -> ixydeab", points, points)

out_similarity = dot.reshape(*dot.shape[:3], -1, *dot.shape[-2:])

return out_similarity


def _F2(diffs: np.ndarray) -> np.ndarray:
return np.sum(diffs**2, axis=-1)

Expand Down
23 changes: 23 additions & 0 deletions src/MuyGPyS/_src/gp/tensors/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,29 @@ def _pairwise_differences(points: torch.ndarray) -> torch.ndarray:
raise ValueError(f"points shape {points.shape} is not supported.")


def _crosswise_similarity(
data: torch.ndarray,
nn_data: torch.ndarray,
data_indices: torch.ndarray,
nn_indices: torch.ndarray,
):
raise NotImplementedError("Torch backend not yet supported for similarity tensors")


def _pairwise_similarity(
data: torch.ndarray,
nn_indices: torch.ndarray,
):
raise NotImplementedError("Torch backend not yet supported for similarity tensors")


def _out_similarity(
data: torch.ndarray,
data_indices: torch.ndarray
):
raise NotImplementedError("Torch backend not yet supported for similarity tensors")


def _F2(diffs: torch.ndarray) -> torch.ndarray:
return torch.sum(diffs**2, axis=-1)

Expand Down
Loading