Conversation
|
Please apply the black formatter to all of your contributed |
bwpriest
left a comment
There was a problem hiding this comment.
Here are some initial comments on the state of the PR.
| "_omega", | ||
| "_T1", | ||
| "_T2", | ||
| "_T3" |
There was a problem hiding this comment.
Do you actually need the API to access these functions? Do you have tests for the correctness of each of these components?
There was a problem hiding this comment.
If you want to expose each of these functions, they need to have an implementation in each of the numpy.py, jax.py, torch.py, and mpi.py files or it will cause problems.
There was a problem hiding this comment.
Thanks for bringing this up. I think I had misunderstood what does and doesn't need to be exposed here.
I'm not sure how to design tests for the individual _* functions outside of shape correctness. At least, tests that don't require implementation of the whole workflow in the reference model.
| def _soap_fn( | ||
| diffs, | ||
| sensitivity: float | ||
| ): | ||
|
|
||
| return print("Jax backend not yet supported for SOAPKernels") |
There was a problem hiding this comment.
Please have this and any other functions in the jax.py, mpi.py, and torch.py files throw a NotImplementedError for now. See this example.
| from MuyGPyS._src.gp.kernels.soap.numpy import ( | ||
| _omega, | ||
| _T1, | ||
| _T2, | ||
| _T3, | ||
| _soap_fn | ||
| ) |
There was a problem hiding this comment.
Same as above. Have these functions raise NotImplementedErrors for now.
| Knm = (sensitivity - 1.0) * (omega**(sensitivity - 2.0)) * T2 * T3 + (omega**(sensitivity - 1.0)) * T1 | ||
|
|
There was a problem hiding this comment.
You might want this to be its own function so that you can import and test it.
| def _omega( | ||
| diffs | ||
| ) -> torch.ndarray: | ||
|
|
||
| ndim = diffs.ndim | ||
| slicer = [slice(None)] * ndim | ||
| slicer[-3] = 0 | ||
|
|
||
| qq_slice = diffs[tuple(slicer)] | ||
|
|
||
| return qq_slice | ||
|
|
||
|
|
||
| def _T1( | ||
| diffs | ||
| ) -> torch.ndarray: | ||
|
|
||
| ndim = diffs.ndim | ||
| slicer = [slice(None)] * ndim | ||
| slicer[-3] = 3 | ||
|
|
||
| dd_slice = diffs[tuple(slicer)] | ||
|
|
||
| return dd_slice | ||
|
|
||
|
|
||
| def _T2( | ||
| diffs | ||
| ) -> torch.ndarray: | ||
|
|
||
| ndim = diffs.ndim | ||
| slicer = [slice(None)] * ndim | ||
| slicer[-3] = 1 | ||
|
|
||
| diq_slice = diffs[tuple(slicer)] | ||
|
|
||
| return diq_slice | ||
|
|
||
|
|
||
| def _T3( | ||
| diffs | ||
| ) -> torch.ndarray: | ||
|
|
||
| ndim = diffs.ndim | ||
| slicer = [slice(None)] * ndim | ||
| slicer[-3] = 2 | ||
|
|
||
| djq_slice = diffs[tuple(slicer)] | ||
|
|
||
| return djq_slice | ||
|
|
||
|
|
||
| def _soap_fn( | ||
| diffs: torch.ndarray, | ||
| sensitivity: float | ||
| ) -> torch.ndarray: | ||
|
|
||
| omega = _omega(diffs) | ||
| T1 = _T1(diffs) | ||
| T2 = _T2(diffs) | ||
| T3 = _T3(diffs) | ||
|
|
||
| Knm = (sensitivity - 1.0) * (omega**(sensitivity - 2.0)) * T2 * T3 + (omega**(sensitivity - 1.0)) * T1 | ||
|
|
||
| Kij = sensitivity * torch.sum(Knm, axis=(-2, -1)) | ||
|
|
||
| return Kij |
There was a problem hiding this comment.
Please replace theses implementations with NotImplementedErrors so that you only have one version while you are working out the bugs. You can reimplement once you are confident that the numpy version works.
|
Need to add working tests for the following:
|
Implementation of the SOAP Kernel for MuyGPyS. To have it in working order, we need the following: