Skip to content

Commit 02fbdb2

Browse files
Included tests for plotting functionality.
1 parent 8941f32 commit 02fbdb2

File tree

1 file changed

+82
-10
lines changed

1 file changed

+82
-10
lines changed

alibi/prototypes/tests/test_protoselect.py

Lines changed: 82 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
1-
import pytest
2-
import numpy as np
31
from copy import deepcopy
2+
3+
import matplotlib
4+
import numpy as np
5+
import pytest
46
from sklearn.datasets import make_classification
57
from sklearn.model_selection import train_test_split
68

9+
from alibi.api.defaults import (DEFAULT_DATA_PROTOSELECT,
10+
DEFAULT_META_PROTOSELECT)
11+
from alibi.api.interfaces import Explanation
712
from alibi.prototypes import ProtoSelect
13+
from alibi.prototypes.protoselect import (_batch_preprocessing, _imscatterplot,
14+
compute_prototype_importances,
15+
cv_protoselect_euclidean,
16+
visualize_image_prototypes)
817
from alibi.utils.kernel import EuclideanDistance
9-
from alibi.prototypes.protoselect import cv_protoselect_euclidean, _batch_preprocessing, compute_prototype_importances
10-
from alibi.api.defaults import DEFAULT_META_PROTOSELECT, DEFAULT_DATA_PROTOSELECT
11-
from alibi.api.interfaces import Explanation
1218

1319

1420
@pytest.mark.parametrize('n_classes', [2, 3, 5, 10])
@@ -188,16 +194,16 @@ def importance_data():
188194
meta['params'] = {
189195
'kernel_distance': 'EuclideanDistance',
190196
'eps': 0.5,
191-
'lambda_penalty': 0.06666666666666667,
197+
'lambda_penalty': 0.066,
192198
'batch_size': 10000000000,
193199
'verbose': True
194200
}
195201
data['prototypes'] = np.array([
196-
[0.5488135, 0.71518937],
197-
[5.79172504, 4.52889492],
198-
[-3.53852064, 4.78052918]
202+
[0.548, 0.715],
203+
[5.791, 4.528],
204+
[-3.53, 4.780]
199205
])
200-
data['prototype_indices'] = np.array([0, 5, 11], dtype=np.int32),
206+
data['prototype_indices'] = np.array([0, 5, 11], dtype=np.int32),
201207
data['prototype_labels'] = np.array([0, 1, 2], dtype=np.int32)
202208
summary = Explanation(meta=meta, data=data)
203209
return trainset, summary
@@ -214,3 +220,69 @@ def test_compute_prototype_importances(importance_data):
214220
importances = compute_prototype_importances(summary=summary, trainset=trainset)
215221
assert np.allclose(expected_importances, importances['prototype_importances'])
216222

223+
224+
@pytest.fixture(scope='module')
225+
def plot_data():
226+
n_samples = 10
227+
x = np.random.uniform(low=-10, high=10, size=(n_samples, ))
228+
y = np.random.uniform(low=-10, high=10, size=(n_samples, ))
229+
230+
image_size = (5, 5)
231+
images = np.random.rand(n_samples, *image_size, 3)
232+
233+
zoom_lb, zoom_ub = 3, 7
234+
zoom = np.random.permutation(np.linspace(1, 5, n_samples))
235+
return {
236+
'x': x,
237+
'y': y,
238+
'image_size': image_size,
239+
'images': images,
240+
'zoom_lb': zoom_lb,
241+
'zoom_ub': zoom_ub,
242+
'zoom': zoom
243+
}
244+
245+
246+
@pytest.mark.parametrize('use_zoom', [False, True])
247+
def test__imscatterplot(plot_data, use_zoom):
248+
""" Test `_imscatterplot` function. """
249+
ax = _imscatterplot(x=plot_data['x'],
250+
y=plot_data['y'],
251+
images=plot_data['images'],
252+
image_size=plot_data['image_size'],
253+
zoom=plot_data['zoom'] if use_zoom else None,
254+
zoom_lb=plot_data['zoom_lb'],
255+
zoom_ub=plot_data['zoom_ub'],
256+
sort_by_zoom=True)
257+
258+
annboxes = [x for x in ax.get_children() if isinstance(x, matplotlib.offsetbox.AnnotationBbox)]
259+
data = np.array([annbox.offsetbox.get_data() for annbox in annboxes])
260+
zoom = np.array([annbox.offsetbox.get_zoom() for annbox in annboxes])
261+
262+
sorted_idx = np.argsort(plot_data['zoom'])[::-1] if use_zoom else None
263+
expected_data = plot_data['images'][sorted_idx]
264+
265+
if not use_zoom:
266+
expected_zoom = np.ones(len(plot_data['zoom']))
267+
else:
268+
expected_zoom = plot_data['zoom'][sorted_idx]
269+
expected_zoom = (expected_zoom - expected_zoom.min()) / (expected_zoom.max() - expected_zoom.min())
270+
expected_zoom = expected_zoom * (plot_data['zoom_ub'] - plot_data['zoom_lb']) + plot_data['zoom_lb']
271+
272+
assert np.allclose(expected_data, data)
273+
assert np.allclose(expected_zoom, zoom)
274+
275+
276+
def test_visualize_image_prototypes(mocker):
277+
""" Test the `visualize_image_prototypes` function. """
278+
importances = {
279+
'prototype_importances': np.random.randint(2, 50, size=(10, )),
280+
'X_protos': np.random.randn(10, 2),
281+
'X_protos_ft': None
282+
}
283+
284+
m1 = mocker.patch('alibi.prototypes.protoselect.compute_prototype_importances', return_value=importances)
285+
m2 = mocker.patch('alibi.prototypes.protoselect._imscatterplot')
286+
visualize_image_prototypes(summary=None, trainset=None, reducer=lambda x: x)
287+
m1.assert_called_once()
288+
m2.assert_called_once()

0 commit comments

Comments
 (0)