Learned kernel MMD with KeOps backend #602
Conversation
Codecov Report

```
@@            Coverage Diff             @@
##           master     #602      +/-   ##
==========================================
- Coverage   83.58%   82.14%   -1.44%
==========================================
  Files         207      209       +2
  Lines       13838    14159     +321
==========================================
+ Hits        11566    11631      +65
- Misses       2272     2528     +256
```
```python
n_features = [5]
n_instances = [(100, 100), (100, 75)]
kernel_a = ['GaussianRBF', 'MyKernel']
kernel_b = ['GaussianRBF', 'MyKernel', None]
eps = [0.5, 'trainable']
tests_dk = list(product(n_features, n_instances, kernel_a, kernel_b, eps))
n_tests_dk = len(tests_dk)


@pytest.fixture
def deep_kernel_params(request):
    return tests_dk[request.param]


@pytest.mark.skipif(not has_keops, reason='Skipping since pykeops is not installed.')
@pytest.mark.parametrize('deep_kernel_params', list(range(n_tests_dk)), indirect=True)
def test_deep_kernel(deep_kernel_params):
```
Not a huge fan of using this pattern to parametrize tests; is there a reason not to parametrize each parameter directly? @ascillitoe
Yeah, the more conventional way to do it would be to parametrize each parameter separately, e.g. see `test_save_model` in `test_saving.py`:

```python
@parametrize('model', [encoder_model])
@parametrize('layer', [None, -1])
def test_save_model(data, model, layer, backend, tmp_path):
    ...
```

This approach has the advantage of giving much more descriptive test names, which is useful when things go wrong.

We keep writing tests with the current pattern for consistency with the existing tests. But, unless we are going to refactor the existing tests very soon, maybe we should prioritise adopting/exploring a new pattern so that we have less refactoring to do later...
OK, I wasn't aware of this tbh and followed the existing patterns.
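To make the trade-off concrete, here is a minimal sketch of both styles side by side (parameter values are taken from the snippet above; test names and the trimmed parameter set are illustrative). With the per-parameter style, pytest composes the test id from the parameter values, so a failing case identifies itself in the report:

```python
from itertools import product

import pytest

# Index-based style (the current pattern): test ids are opaque indices
# such as test_indexed[3], so a failure report does not tell you which
# parameter combination broke.
params = list(product([5], [(100, 100), (100, 75)], ['GaussianRBF', 'MyKernel']))

@pytest.mark.parametrize('idx', range(len(params)))
def test_indexed(idx):
    n_features, n_instances, kernel = params[idx]
    assert n_features == 5

# Per-parameter style: pytest builds descriptive ids from the parameter
# values themselves, which read much better when things go wrong.
@pytest.mark.parametrize('n_features', [5])
@pytest.mark.parametrize('n_instances', [(100, 100), (100, 75)])
@pytest.mark.parametrize('kernel', ['GaussianRBF', 'MyKernel'])
def test_separate(n_features, n_instances, kernel):
    assert n_features == 5
```

Both collect the same 4 cases here; the difference is purely in how the failing case is named.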
```python
if has_keops:
    class MyKernel(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x: LazyTensor, y: LazyTensor) -> LazyTensor:
            return (- ((x - y) ** 2).sum(-1)).exp()
```
Is typing the only reason why this needs to be inside the `if has_keops` block? If so, we could just use forward references to `'LazyTensor'`? @ascillitoe
Good point. Looks like that might be the case...
Did the forward-ref not work?
@jklaise and I were chatting earlier and decided (correct me if I'm wrong) that it's not necessarily a better solution, just a bit different, so I kept it as is.
OK fair enough!
(I don't have a strong opinion either way)
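For reference, the forward-reference idea being discussed can be sketched as follows (`gaussian_kernel` is a hypothetical stand-in for the kernel above; the `TYPE_CHECKING` guard is one common variant of the pattern). Quoted annotations are never evaluated at runtime, so the module imports fine even when pykeops is not installed:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never at runtime,
    # so pykeops does not need to be installed to import this module.
    from pykeops.torch import LazyTensor


def gaussian_kernel(x: 'LazyTensor', y: 'LazyTensor') -> 'LazyTensor':
    # The quoted (forward) references are plain strings at runtime.
    return (-((x - y) ** 2).sum(-1)).exp()
```

The same trick would apply to the `nn.Module` version: only the type hints mention `LazyTensor`, so the class definition itself would no longer need to sit inside the `if has_keops` block.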
jklaise left a comment:
Overall LGTM. OK with the proposed departure from the DeepKernel API for keops, assuming DeepKernel is never really meant to be used by the user directly?
No, its intended usage is always within a learned detector. I will then wait until @ascillitoe leaves comments before possibly making some final changes and merging.
@arnaudvl the proposed api for the keops
```python
tests_lkdrift = ['tensorflow', 'pytorch', 'PyToRcH', 'mxnet']

if has_keops:
    class MyKernelKeops(nn.Module):
```
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "metadata": { | ||
| "pycharm": { |
Super nitpicky: It is preferable to strip out this unnecessary metadata...
Following #548, this PR also extends the learned kernel MMD detector with the KeOps backend, to further scale and speed up the detector.

- Further testing of `LearnedKernelDrift` is required given the decreased coverage.
- `kernel_a` and, optionally, `kernel_b` need to be set.
- `utils.keops.kernels.DeepKernel`: the main issue here is that `DeepKernel.proj` is not used within the DeepKernel's forward pass, but is explicitly called by the learned kernel drift detector. The reason is that we first need to do regular torch tensor computations using the projection, and then compute the kernel matrix with those projected features. KeOps is only used for the latter, not when computing the projection itself. So technically you could pass a separate projection (i.e. `DeepKernel.proj`) to the drift detector and apply a weighted-sum kernel later on the original and projected data. But this would break the consistency of the API/input kwargs with the other backends and realistically make the detector harder to understand. As a result, I chose to keep the same DeepKernel format as the PyTorch and TensorFlow backends and deal with this difference directly in the drift detector. This also means there are explicit checks in place (`self.has_proj` and `self.has_kernel_b`) to determine whether the DeepKernel format is used and the projection can be done separately.
- Add `num_workers` for both the KeOps and PyTorch backends, since it can make a significant difference to dataloader speed for higher numbers of instances. (Add `num_workers` to PyTorch/KeOps backend where relevant #611)
- A `batch_size_predict` kwarg, which I would also like to add to the PyTorch backend. The reason is that the optimal batch size for training can be wildly different from that for prediction (where we just care about being as fast as possible within our compute budget). So if we e.g. pick `batch_size=32` for training, we might want to change this to `batch_size_predict=1000` for prediction using the trained kernel. There is also another reason why they can be very different: during training of the detector, the whole training batch (all tensors incl. the projection etc.) needs to fit on the GPU. But for predictions across all permutations at once, we can first compute all the projections separately, and then lazily evaluate both the projections and the original instances for all permutations. This means we can likely get away with much higher batch sizes for the projection predictions. (Add `batch_size_predict` as kwarg to PyTorch backend for learned detectors. #612)

The smaller potential PyTorch changes (`num_workers` and `batch_size_predict`) can be done in a quick follow-up PR.
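To illustrate the DeepKernel structure the description refers to, here is a minimal numpy sketch of the weighted-sum logic: `kernel_a` applied to projected features, `kernel_b` applied to the raw inputs, combined via `eps`. All names (`rbf`, `deep_kernel`, the random linear `proj`) are illustrative stand-ins, and KeOps lazy evaluation is replaced by dense numpy for clarity:

```python
import numpy as np

def rbf(x, y, sigma=1.0):
    # Dense pairwise Gaussian RBF kernel matrix (stand-in for GaussianRBF).
    d2 = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    return np.exp(-d2 / (2 * sigma ** 2))

def deep_kernel(x, y, proj, eps=0.5):
    # kernel_a on the projected features, kernel_b on the raw inputs,
    # combined as a weighted sum: (1 - eps) * k_a + eps * k_b.
    # In the KeOps backend, proj(x) / proj(y) are computed eagerly with
    # torch first; only the kernel matrix itself is evaluated lazily.
    return (1 - eps) * rbf(proj(x), proj(y)) + eps * rbf(x, y)

rng = np.random.default_rng(0)
x, y = rng.normal(size=(5, 4)), rng.normal(size=(6, 4))
W = rng.normal(size=(4, 2))            # stand-in for a learned projection
K = deep_kernel(x, y, lambda z: z @ W)
assert K.shape == (5, 6)
```

This separation is exactly why the detector calls `proj` explicitly rather than relying on the kernel's forward pass: the projection is an ordinary tensor computation, while only the pairwise kernel evaluation benefits from KeOps.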