Skip to content

SOAP Kernel for IAP#246

Merged
bwpriest merged 44 commits intollnl:developfrom
gsallaberry:iap_develop
Sep 26, 2025
Merged

SOAP Kernel for IAP#246
bwpriest merged 44 commits intollnl:developfrom
gsallaberry:iap_develop

Conversation

@gsallaberry
Copy link
Contributor

@gsallaberry gsallaberry commented Jun 17, 2025

Implementation of the SOAP Kernel for MuyGPyS. To have it in working order, we need the following:

  • MetricFn for invoking dot products
  • Implementation of a dot product crosswise and pairwise tensor
  • Implementation of SOAP kernel function
  • Working optimizer with SOAP kernel (?)
  • Working posterior mean with SOAP kernel
  • Docstrings

@bwpriest
Copy link
Member

bwpriest commented Jul 2, 2025

Please apply the black formatter to all of your contributed .py files with the args --line-length 80. If you are using an editor like vscode, you can add a black extension and add --line-length and 80 to the Black-formatter: Args setting. You may also need to specify the path to the black formatter if you are not linking your python environment to vscode.

Copy link
Member

@bwpriest bwpriest left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here are some initial comments on the state of the PR.

Comment on lines 11 to 14
"_omega",
"_T1",
"_T2",
"_T3"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you actually need the API to access these functions? Do you have tests for the correctness of each of these components?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to expose each of these functions, they need to have an implementation in each of the numpy.py, jax.py, torch.py, and mpi.py files or it will cause problems.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for bringing this up. I think I had misunderstood what does and doesn't need to be exposed here.

I'm not sure how to design tests for the individual _* functions outside of shape correctness. At least, tests that don't require implementation of the whole workflow in the reference model.

Comment on lines 11 to 16
def _soap_fn(
diffs,
sensitivity: float
):

return print("Jax backend not yet supported for SOAPKernels")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please have this and any other functions in the jax.py, mpi.py, and torch.py files throw a NotImplementedError for now. See this example.

Comment on lines 6 to 12
from MuyGPyS._src.gp.kernels.soap.numpy import (
_omega,
_T1,
_T2,
_T3,
_soap_fn
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above. Have these functions raise NotImplementedErrors for now.

Comment on lines 72 to 73
Knm = (sensitivity - 1.0) * (omega**(sensitivity - 2.0)) * T2 * T3 + (omega**(sensitivity - 1.0)) * T1

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might want this to be its own function so that you can import and test it.

Comment on lines 9 to 75
def _omega(
diffs
) -> torch.ndarray:

ndim = diffs.ndim
slicer = [slice(None)] * ndim
slicer[-3] = 0

qq_slice = diffs[tuple(slicer)]

return qq_slice


def _T1(
diffs
) -> torch.ndarray:

ndim = diffs.ndim
slicer = [slice(None)] * ndim
slicer[-3] = 3

dd_slice = diffs[tuple(slicer)]

return dd_slice


def _T2(
diffs
) -> torch.ndarray:

ndim = diffs.ndim
slicer = [slice(None)] * ndim
slicer[-3] = 1

diq_slice = diffs[tuple(slicer)]

return diq_slice


def _T3(
diffs
) -> torch.ndarray:

ndim = diffs.ndim
slicer = [slice(None)] * ndim
slicer[-3] = 2

djq_slice = diffs[tuple(slicer)]

return djq_slice


def _soap_fn(
diffs: torch.ndarray,
sensitivity: float
) -> torch.ndarray:

omega = _omega(diffs)
T1 = _T1(diffs)
T2 = _T2(diffs)
T3 = _T3(diffs)

Knm = (sensitivity - 1.0) * (omega**(sensitivity - 2.0)) * T2 * T3 + (omega**(sensitivity - 1.0)) * T1

Kij = sensitivity * torch.sum(Knm, axis=(-2, -1))

return Kij
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please replace theses implementations with NotImplementedErrors so that you only have one version while you are working out the bugs. You can reimplement once you are confident that the numpy version works.

@gsallaberry
Copy link
Contributor Author

gsallaberry commented Sep 2, 2025

Need to add working tests for the following:

  • _pairwise_similarity
  • _crosswise_similarity
  • _soap_fn
  • SOAP Kernel posterior mean
  • SOAP Kernel posterior variance

Copy link
Member

@bwpriest bwpriest left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@bwpriest bwpriest marked this pull request as ready for review September 26, 2025 17:24
@bwpriest bwpriest merged commit db23e15 into llnl:develop Sep 26, 2025
25 of 26 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants