11import pytest
22import numpy as np
3- from sklearn .datasets import load_iris , load_boston
3+ from sklearn .datasets import load_iris , load_diabetes
44from sklearn .linear_model import LogisticRegression , LinearRegression
55from sklearn .svm import SVR
66from alibi .confidence .model_linearity import linearity_measure , LinearityMeasure
@@ -31,24 +31,22 @@ def test_linear_superposition(input_shape, nb_instances):
3131@pytest .mark .parametrize ('nb_instances' , (1 , 5 ))
3232@pytest .mark .parametrize ('nb_samples' , (2 , 10 ))
3333def test_sample_knn (nb_instances , nb_samples ):
34-
3534 iris = load_iris ()
3635 X_train = iris .data
3736 input_shape = X_train .shape [1 :]
38- x = np .ones ((nb_instances , ) + input_shape )
37+ x = np .ones ((nb_instances ,) + input_shape )
3938
4039 X_samples = _sample_knn (x = x , X_train = X_train , nb_samples = nb_samples )
4140
4241 assert X_samples .shape [0 ] == nb_instances
4342 assert X_samples .shape [1 ] == nb_samples
4443
4544
46- @pytest .mark .parametrize ('nb_instances' , (5 , ))
47- @pytest .mark .parametrize ('nb_samples' , (3 , ))
45+ @pytest .mark .parametrize ('nb_instances' , (5 ,))
46+ @pytest .mark .parametrize ('nb_samples' , (3 ,))
4847@pytest .mark .parametrize ('input_shape' , ((3 ,), (4 , 4 , 1 )))
4948def test_sample_grid (nb_instances , nb_samples , input_shape ):
50-
51- x = np .ones ((nb_instances , ) + input_shape )
49+ x = np .ones ((nb_instances ,) + input_shape )
5250 nb_features = x .reshape (x .shape [0 ], - 1 ).shape [1 ]
5351 feature_range = np .array ([[0 , 1 ] for _ in range (nb_features )])
5452
@@ -64,7 +62,6 @@ def test_sample_grid(nb_instances, nb_samples, input_shape):
6462@pytest .mark .parametrize ('nb_instances' , (1 , 10 ))
6563@pytest .mark .parametrize ('agg' , ('global' , 'pairwise' ))
6664def test_linearity_measure_class (method , epsilon , res , nb_instances , agg ):
67-
6865 iris = load_iris ()
6966 X_train = iris .data
7067 y_train = iris .target
@@ -94,9 +91,8 @@ def predict_fn(x):
9491@pytest .mark .parametrize ('nb_instances' , (1 , 10 ))
9592@pytest .mark .parametrize ('agg' , ('global' , 'pairwise' ))
9693def test_linearity_measure_reg (method , epsilon , res , nb_instances , agg ):
97-
98- boston = load_boston ()
99- X_train , y_train = boston .data , boston .target
94+ diabetes = load_diabetes ()
95+ X_train , y_train = diabetes .data , diabetes .target
10096 x = X_train [0 : nb_instances ].reshape (nb_instances , - 1 )
10197
10298 lg = LinearRegression ()
@@ -155,7 +151,6 @@ def predict_fn_multi(x):
155151@pytest .mark .parametrize ('nb_instances' , (1 , 10 ))
156152@pytest .mark .parametrize ('agg' , ('global' , 'pairwise' ))
157153def test_LinearityMeasure_class (method , epsilon , res , nb_instances , agg ):
158-
159154 iris = load_iris ()
160155 X_train = iris .data
161156 y_train = iris .target
@@ -180,9 +175,8 @@ def predict_fn(x):
180175@pytest .mark .parametrize ('nb_instances' , (1 , 10 ))
181176@pytest .mark .parametrize ('agg' , ('global' , 'pairwise' ))
182177def test_LinearityMeasure_reg (method , epsilon , res , nb_instances , agg ):
183-
184- boston = load_boston ()
185- X_train , y_train = boston .data , boston .target
178+ diabetes = load_diabetes ()
179+ X_train , y_train = diabetes .data , diabetes .target
186180 x = X_train [0 : nb_instances ].reshape (nb_instances , - 1 )
187181
188182 lg = LinearRegression ()
0 commit comments