Learned kernel MMD with KeOps backend #602
Merged
Commits
18 commits
072a1d7 first commit adding learned kernel mmd with keops backend (arnaudvl)
31e9d04 update method docs learned kernel (arnaudvl)
f11bf52 update preprocessing and types (arnaudvl)
352f666 add test and update output type deep kernel (arnaudvl)
d006137 update example (arnaudvl)
841ea66 test equivalence learned kernel mmd2 keops with torch implementation (arnaudvl)
a2cb8ac add deep kernel keops test and add skipif for all keops tests (arnaudvl)
cdce757 remove print statement (arnaudvl)
ee8938d fix flake8 (arnaudvl)
b70c21a handle optional keops dependency (arnaudvl)
869e01f clarify bandwidth setting (arnaudvl)
6926cfb handle keops optional dependency in test (arnaudvl)
a4ba39e update pydantic model schema (arnaudvl)
baf5f18 add DeepKernel to keops dependency management test (arnaudvl)
d7c4e4d add keops to top level learned kernel test (arnaudvl)
1f6d5e0 update test learned kernel (arnaudvl)
d64cb29 clarify test variable and make proj type explicit (arnaudvl)
36cb4aa remove unnecessary metadata (arnaudvl)
130 changes: 130 additions & 0 deletions
alibi_detect/cd/keops/tests/test_learned_kernel_keops.py

    @@ -0,0 +1,130 @@
    +from itertools import product
    +import numpy as np
    +import pytest
    +import torch
    +import torch.nn as nn
    +from typing import Callable, Optional, Union
    +from alibi_detect.utils.frameworks import has_keops
    +from alibi_detect.utils.pytorch import GaussianRBF as GaussianRBFTorch
    +from alibi_detect.utils.pytorch import mmd2_from_kernel_matrix
    +if has_keops:
    +    from alibi_detect.cd.keops.learned_kernel import LearnedKernelDriftKeops
    +    from alibi_detect.utils.keops import GaussianRBF
    +    from pykeops.torch import LazyTensor
    +
    +n = 50  # number of instances used for the reference and test data samples in the tests
    +
    +
    +if has_keops:
    +    class MyKernel(nn.Module):
    +        def __init__(self, n_features: int, proj: bool):
    +            super().__init__()
    +            sigma = .1
    +            self.kernel = GaussianRBF(trainable=True, sigma=torch.Tensor([sigma]))
    +            self.has_proj = proj
    +            if proj:
    +                self.proj = nn.Linear(n_features, 2)
    +                self.kernel_b = GaussianRBF(trainable=True, sigma=torch.Tensor([sigma]))
    +
    +        def forward(self, x_proj: LazyTensor, y_proj: LazyTensor, x: Optional[LazyTensor] = None,
    +                    y: Optional[LazyTensor] = None) -> LazyTensor:
    +            similarity = self.kernel(x_proj, y_proj)
    +            if self.has_proj:
    +                similarity = similarity + self.kernel_b(x, y)
    +            return similarity
    +
    +
    +# test List[Any] inputs to the detector
    +def identity_fn(x: Union[torch.Tensor, list]) -> torch.Tensor:
    +    if isinstance(x, list):
    +        return torch.from_numpy(np.array(x))
    +    else:
    +        return x
    +
    +
    +p_val = [.05]
    +n_features = [4]
    +preprocess_at_init = [True, False]
    +update_x_ref = [None, {'reservoir_sampling': 1000}]
    +preprocess_fn = [None, identity_fn]
    +n_permutations = [10]
    +batch_size_permutations = [5, 1000000]
    +train_size = [.5]
    +retrain_from_scratch = [True]
    +batch_size_predict = [1000000]
    +preprocess_batch = [None, identity_fn]
    +has_proj = [True, False]
    +tests_lkdrift = list(product(p_val, n_features, preprocess_at_init, update_x_ref, preprocess_fn,
    +                             n_permutations, batch_size_permutations, train_size, retrain_from_scratch,
    +                             batch_size_predict, preprocess_batch, has_proj))
    +n_tests = len(tests_lkdrift)
    +
    +
    +@pytest.fixture
    +def lkdrift_params(request):
    +    return tests_lkdrift[request.param]
    +
    +
    +@pytest.mark.skipif(not has_keops, reason='Skipping since pykeops is not installed.')
    +@pytest.mark.parametrize('lkdrift_params', list(range(n_tests)), indirect=True)
    +def test_lkdrift(lkdrift_params):
    +    p_val, n_features, preprocess_at_init, update_x_ref, preprocess_fn, \
    +        n_permutations, batch_size_permutations, train_size, retrain_from_scratch, \
    +        batch_size_predict, preprocess_batch, has_proj = lkdrift_params
    +
    +    np.random.seed(0)
    +    torch.manual_seed(0)
    +
    +    kernel = MyKernel(n_features, has_proj)
    +    x_ref = np.random.randn(*(n, n_features)).astype(np.float32)
    +    x_test1 = np.ones_like(x_ref)
    +    to_list = False
    +    if preprocess_batch is not None and preprocess_fn is None:
    +        to_list = True
    +        x_ref = [_ for _ in x_ref]
    +        update_x_ref = None
    +
    +    cd = LearnedKernelDriftKeops(
    +        x_ref=x_ref,
    +        kernel=kernel,
    +        p_val=p_val,
    +        preprocess_at_init=preprocess_at_init,
    +        update_x_ref=update_x_ref,
    +        preprocess_fn=preprocess_fn,
    +        n_permutations=n_permutations,
    +        batch_size_permutations=batch_size_permutations,
    +        train_size=train_size,
    +        retrain_from_scratch=retrain_from_scratch,
    +        batch_size_predict=batch_size_predict,
    +        preprocess_batch_fn=preprocess_batch,
    +        batch_size=32,
    +        epochs=1
    +    )
    +
    +    x_test0 = x_ref.copy()
    +    preds_0 = cd.predict(x_test0)
    +    assert cd.n == len(x_test0) + len(x_ref)
    +    assert preds_0['data']['is_drift'] == 0
    +
    +    if to_list:
    +        x_test1 = [_ for _ in x_test1]
    +    preds_1 = cd.predict(x_test1)
    +    assert cd.n == len(x_test1) + len(x_test0) + len(x_ref)
    +    assert preds_1['data']['is_drift'] == 1
    +    assert preds_0['data']['distance'] < preds_1['data']['distance']
    +
    +    # ensure the keops MMD^2 estimate matches the pytorch implementation for the same kernel
    +    if not isinstance(x_ref, list) and update_x_ref is None and not has_proj:
    +        if isinstance(preprocess_fn, Callable):
    +            x_ref, x_test1 = cd.preprocess(x_test1)
    +        n_ref, n_test = x_ref.shape[0], x_test1.shape[0]
    +        x_all = torch.from_numpy(np.concatenate([x_ref, x_test1], axis=0)).float()
    +        perms = [torch.randperm(n_ref + n_test) for _ in range(n_permutations)]
    +        mmd2 = cd._mmd2(x_all, perms, n_ref, n_test)[0]
    +
    +        if isinstance(preprocess_batch, Callable):
    +            x_all = preprocess_batch(x_all)
    +        kernel = GaussianRBFTorch(sigma=cd.kernel.kernel.sigma)
    +        kernel_mat = kernel(x_all, x_all)
    +        mmd2_torch = mmd2_from_kernel_matrix(kernel_mat, n_test)
    +        np.testing.assert_almost_equal(mmd2, mmd2_torch, decimal=6)
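The last block of `test_lkdrift` checks that the KeOps MMD^2 estimate agrees with one computed from a dense PyTorch kernel matrix. For readers unfamiliar with the quantity being compared, here is a minimal pure-Python sketch of that idea. It assumes the usual unbiased MMD^2 estimator and a Gaussian RBF kernel of the form exp(-||x - y||^2 / (2 * sigma^2)); the helper names (`rbf`, `mmd2_from_matrix`, `mmd2_streaming`, `kernel_matrix`) are hypothetical and are not the alibi-detect or KeOps implementations.

```python
import math
import random


def rbf(x, y, sigma=1.0):
    """Gaussian RBF kernel exp(-||x - y||^2 / (2 * sigma^2)) between two points."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))


def kernel_matrix(points, sigma=1.0):
    """Dense kernel matrix over all pairs of points."""
    return [[rbf(p, q, sigma) for q in points] for p in points]


def mmd2_from_matrix(k, n_test):
    """Unbiased MMD^2 from the kernel matrix of the stacked [reference; test] rows."""
    n_all = len(k)
    n_ref = n_all - n_test
    ref, test = range(n_ref), range(n_ref, n_all)
    # Within-sample terms skip the diagonal so that k(x, x) does not bias the estimate.
    k_xx = sum(k[i][j] for i in ref for j in ref if i != j) / (n_ref * (n_ref - 1))
    k_yy = sum(k[i][j] for i in test for j in test if i != j) / (n_test * (n_test - 1))
    k_xy = sum(k[i][j] for i in ref for j in test) / (n_ref * n_test)
    return k_xx + k_yy - 2.0 * k_xy


def mmd2_streaming(x_ref, x_test, sigma=1.0):
    """Same estimator, evaluating kernel values on the fly instead of storing the matrix."""
    n_ref, n_test = len(x_ref), len(x_test)
    k_xx = sum(rbf(a, b, sigma) for i, a in enumerate(x_ref)
               for j, b in enumerate(x_ref) if i != j) / (n_ref * (n_ref - 1))
    k_yy = sum(rbf(a, b, sigma) for i, a in enumerate(x_test)
               for j, b in enumerate(x_test) if i != j) / (n_test * (n_test - 1))
    k_xy = sum(rbf(a, b, sigma) for a in x_ref for b in x_test) / (n_ref * n_test)
    return k_xx + k_yy - 2.0 * k_xy


random.seed(0)
n, n_features = 50, 4
x_ref = [[random.gauss(0.0, 1.0) for _ in range(n_features)] for _ in range(n)]
x_same = [[random.gauss(0.0, 1.0) for _ in range(n_features)] for _ in range(n)]
x_shift = [[1.0] * n_features for _ in range(n)]  # mirrors x_test1 = np.ones_like(x_ref)

mmd2_same = mmd2_from_matrix(kernel_matrix(x_ref + x_same), n)
mmd2_shift = mmd2_from_matrix(kernel_matrix(x_ref + x_shift), n)
```

The two routes (`mmd2_from_matrix` and `mmd2_streaming`) play roughly the roles of the dense PyTorch path and the lazily reduced KeOps path: they should give the same number up to floating-point error, and the shifted sample should yield a larger MMD^2 than a fresh sample from the reference distribution.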

    @@ -7,6 +7,10 @@
     from alibi_detect.cd import LearnedKernelDrift
     from alibi_detect.cd.pytorch.learned_kernel import LearnedKernelDriftTorch
     from alibi_detect.cd.tensorflow.learned_kernel import LearnedKernelDriftTF
    +from alibi_detect.utils.frameworks import has_keops
    +if has_keops:
    +    from alibi_detect.cd.keops.learned_kernel import LearnedKernelDriftKeops
    +    from pykeops.torch import LazyTensor
     
     n, n_features = 100, 5
     
    @@ -37,7 +41,16 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
             return torch.einsum('ji,ki->jk', self.dense(x), self.dense(y))
     
     
    -tests_lkdrift = ['tensorflow', 'pytorch', 'PyToRcH', 'mxnet']
    +if has_keops:
    +    class MyKernelKeops(nn.Module):

Contributor
Same comment as this one?

    +        def __init__(self):
    +            super().__init__()
    +
    +        def forward(self, x: LazyTensor, y: LazyTensor) -> LazyTensor:
    +            return (- ((x - y) ** 2).sum(-1)).exp()
    +
    +
    +tests_lkdrift = ['tensorflow', 'pytorch', 'keops', 'PyToRcH', 'mxnet']
     n_tests = len(tests_lkdrift)
     
     
    @@ -53,6 +66,8 @@ def test_lkdrift(lkdrift_params):
             kernel = MyKernelTorch(n_features)
         elif backend.lower() == 'tensorflow':
             kernel = MyKernelTF(n_features)
    +    elif has_keops and backend.lower() == 'keops':
    +        kernel = MyKernelKeops()
         else:
             kernel = None
         x_ref = np.random.randn(*(n, n_features))
    @@ -61,10 +76,15 @@ def test_lkdrift(lkdrift_params):
             cd = LearnedKernelDrift(x_ref=x_ref, kernel=kernel, backend=backend)
         except NotImplementedError:
             cd = None
    +    except ImportError:
    +        assert not has_keops
    +        cd = None
     
         if backend.lower() == 'pytorch':
             assert isinstance(cd._detector, LearnedKernelDriftTorch)
         elif backend.lower() == 'tensorflow':
             assert isinstance(cd._detector, LearnedKernelDriftTF)
    +    elif has_keops and backend.lower() == 'keops':
    +        assert isinstance(cd._detector, LearnedKernelDriftKeops)
         else:
             assert cd is None
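The top-level test above relies on three outcomes when a backend string is passed in: a case-insensitive match ('PyToRcH'), a `NotImplementedError` for an unknown backend ('mxnet'), and an `ImportError` when 'keops' is requested but pykeops is missing. A rough sketch of that dispatch logic follows. `resolve_backend` and `BACKEND_MODULES` are hypothetical names for illustration only and do not reproduce how `LearnedKernelDrift` resolves its backend.

```python
import importlib.util
from typing import Dict, Optional

# Hypothetical mapping from backend name to the module that must be importable.
BACKEND_MODULES = {"tensorflow": "tensorflow", "pytorch": "torch", "keops": "pykeops"}


def resolve_backend(backend: str, installed: Optional[Dict[str, bool]] = None) -> str:
    """Return the normalised backend name or raise if it cannot be used."""
    name = backend.lower()
    if name not in BACKEND_MODULES:
        raise NotImplementedError(f"{backend} is not a supported backend.")
    if installed is None:
        # Probe the environment without importing the (possibly heavy) package.
        available = importlib.util.find_spec(BACKEND_MODULES[name]) is not None
    else:
        # Explicit availability map, handy for tests that simulate a missing package.
        available = installed.get(name, False)
    if not available:
        raise ImportError(
            f"Backend {name!r} needs {BACKEND_MODULES[name]!r}, which is not installed."
        )
    return name
```

Passing an explicit `installed` map makes the "keops not installed" branch testable on any machine, which is the same situation the `except ImportError: assert not has_keops` branch in the test covers.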

    @@ -1,8 +1,12 @@
     from alibi_detect.utils.missing_optional_dependency import import_optional
     
     
    -GaussianRBF = import_optional('alibi_detect.utils.keops.kernels', names=['GaussianRBF'])
    +GaussianRBF, DeepKernel = import_optional(
    +    'alibi_detect.utils.keops.kernels',
    +    names=['GaussianRBF', 'DeepKernel']
    +)
     
     __all__ = [
    -    "GaussianRBF"
    +    "GaussianRBF",
    +    "DeepKernel"
     ]
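The change above moves from a single-name to a two-name `import_optional` call, so the return value is unpacked into `GaussianRBF, DeepKernel`. As an illustration of the general pattern (deferring the `ImportError` until a missing object is actually used), here is a small sketch. `optional_import` and `MissingDependency` below are hypothetical stand-ins written for this note and should not be read as the actual `alibi_detect.utils.missing_optional_dependency` code.

```python
import importlib


class MissingDependency:
    """Placeholder that defers the ImportError until the object is actually used."""

    def __init__(self, name, error):
        self._name = name
        self._error = error

    def _fail(self):
        raise ImportError(
            f"{self._name} requires an optional dependency that is not installed."
        ) from self._error

    def __getattr__(self, attr):
        # Only reached for attributes that are not set in __init__.
        self._fail()

    def __call__(self, *args, **kwargs):
        self._fail()


def optional_import(module_name, names):
    """Import `names` from `module_name`, or return placeholders if the import fails."""
    try:
        module = importlib.import_module(module_name)
    except ImportError as err:
        objs = [MissingDependency(f"{module_name}.{n}", err) for n in names]
    else:
        objs = [getattr(module, n) for n in names]
    # One name returns the object itself; several names return a tuple to unpack.
    return objs[0] if len(names) == 1 else tuple(objs)


# Usage: the module import succeeds either way; failures surface only on use.
sqrt, floor = optional_import("math", names=["sqrt", "floor"])
Thing = optional_import("definitely_not_installed_pkg_xyz", names=["Thing"])
```

With this shape, importing the package never fails just because pykeops is absent; the error only appears when a KeOps-backed object is instantiated or accessed.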