1- import pytest
2- import numpy as np
31from copy import deepcopy
2+
3+ import matplotlib
4+ import numpy as np
5+ import pytest
46from sklearn .datasets import make_classification
57from sklearn .model_selection import train_test_split
68
9+ from alibi .api .defaults import (DEFAULT_DATA_PROTOSELECT ,
10+ DEFAULT_META_PROTOSELECT )
11+ from alibi .api .interfaces import Explanation
712from alibi .prototypes import ProtoSelect
13+ from alibi .prototypes .protoselect import (_batch_preprocessing , _imscatterplot ,
14+ compute_prototype_importances ,
15+ cv_protoselect_euclidean ,
16+ visualize_image_prototypes )
817from alibi .utils .kernel import EuclideanDistance
9- from alibi .prototypes .protoselect import cv_protoselect_euclidean , _batch_preprocessing , compute_prototype_importances
10- from alibi .api .defaults import DEFAULT_META_PROTOSELECT , DEFAULT_DATA_PROTOSELECT
11- from alibi .api .interfaces import Explanation
1218
1319
1420@pytest .mark .parametrize ('n_classes' , [2 , 3 , 5 , 10 ])
@@ -188,16 +194,16 @@ def importance_data():
188194 meta ['params' ] = {
189195 'kernel_distance' : 'EuclideanDistance' ,
190196 'eps' : 0.5 ,
191- 'lambda_penalty' : 0.06666666666666667 ,
197+ 'lambda_penalty' : 0.066 ,
192198 'batch_size' : 10000000000 ,
193199 'verbose' : True
194200 }
195201 data ['prototypes' ] = np .array ([
196- [0.5488135 , 0.71518937 ],
197- [5.79172504 , 4.52889492 ],
198- [- 3.53852064 , 4.78052918 ]
202+ [0.548 , 0.715 ],
203+ [5.791 , 4.528 ],
204+ [- 3.53 , 4.780 ]
199205 ])
200- data ['prototype_indices' ] = np .array ([0 , 5 , 11 ], dtype = np .int32 ),
206+ data ['prototype_indices' ] = np .array ([0 , 5 , 11 ], dtype = np .int32 ),
201207 data ['prototype_labels' ] = np .array ([0 , 1 , 2 ], dtype = np .int32 )
202208 summary = Explanation (meta = meta , data = data )
203209 return trainset , summary
@@ -214,3 +220,69 @@ def test_compute_prototype_importances(importance_data):
214220 importances = compute_prototype_importances (summary = summary , trainset = trainset )
215221 assert np .allclose (expected_importances , importances ['prototype_importances' ])
216222
223+
224+ @pytest .fixture (scope = 'module' )
225+ def plot_data ():
226+ n_samples = 10
227+ x = np .random .uniform (low = - 10 , high = 10 , size = (n_samples , ))
228+ y = np .random .uniform (low = - 10 , high = 10 , size = (n_samples , ))
229+
230+ image_size = (5 , 5 )
231+ images = np .random .rand (n_samples , * image_size , 3 )
232+
233+ zoom_lb , zoom_ub = 3 , 7
234+ zoom = np .random .permutation (np .linspace (1 , 5 , n_samples ))
235+ return {
236+ 'x' : x ,
237+ 'y' : y ,
238+ 'image_size' : image_size ,
239+ 'images' : images ,
240+ 'zoom_lb' : zoom_lb ,
241+ 'zoom_ub' : zoom_ub ,
242+ 'zoom' : zoom
243+ }
244+
245+
246+ @pytest .mark .parametrize ('use_zoom' , [False , True ])
247+ def test__imscatterplot (plot_data , use_zoom ):
248+ """ Test `_imscatterplot` function. """
249+ ax = _imscatterplot (x = plot_data ['x' ],
250+ y = plot_data ['y' ],
251+ images = plot_data ['images' ],
252+ image_size = plot_data ['image_size' ],
253+ zoom = plot_data ['zoom' ] if use_zoom else None ,
254+ zoom_lb = plot_data ['zoom_lb' ],
255+ zoom_ub = plot_data ['zoom_ub' ],
256+ sort_by_zoom = True )
257+
258+ annboxes = [x for x in ax .get_children () if isinstance (x , matplotlib .offsetbox .AnnotationBbox )]
259+ data = np .array ([annbox .offsetbox .get_data () for annbox in annboxes ])
260+ zoom = np .array ([annbox .offsetbox .get_zoom () for annbox in annboxes ])
261+
262+ sorted_idx = np .argsort (plot_data ['zoom' ])[::- 1 ] if use_zoom else None
263+ expected_data = plot_data ['images' ][sorted_idx ]
264+
265+ if not use_zoom :
266+ expected_zoom = np .ones (len (plot_data ['zoom' ]))
267+ else :
268+ expected_zoom = plot_data ['zoom' ][sorted_idx ]
269+ expected_zoom = (expected_zoom - expected_zoom .min ()) / (expected_zoom .max () - expected_zoom .min ())
270+ expected_zoom = expected_zoom * (plot_data ['zoom_ub' ] - plot_data ['zoom_lb' ]) + plot_data ['zoom_lb' ]
271+
272+ assert np .allclose (expected_data , data )
273+ assert np .allclose (expected_zoom , zoom )
274+
275+
276+ def test_visualize_image_prototypes (mocker ):
277+ """ Test the `visualize_image_prototypes` function. """
278+ importances = {
279+ 'prototype_importances' : np .random .randint (2 , 50 , size = (10 , )),
280+ 'X_protos' : np .random .randn (10 , 2 ),
281+ 'X_protos_ft' : None
282+ }
283+
284+ m1 = mocker .patch ('alibi.prototypes.protoselect.compute_prototype_importances' , return_value = importances )
285+ m2 = mocker .patch ('alibi.prototypes.protoselect._imscatterplot' )
286+ visualize_image_prototypes (summary = None , trainset = None , reducer = lambda x : x )
287+ m1 .assert_called_once ()
288+ m2 .assert_called_once ()
0 commit comments