11import logging
2- import numpy as np
3- import matplotlib .pyplot as plt
4- from matplotlib .offsetbox import OffsetImage , AnnotationBbox
5-
6- from tqdm import tqdm
72from copy import deepcopy
8- from typing import Callable , Optional , Dict , List , Union , Tuple
3+ from typing import Callable , Dict , List , Optional , Tuple , Union
4+
5+ import matplotlib .pyplot as plt
6+ import numpy as np
7+ from matplotlib .offsetbox import AnnotationBbox , OffsetImage
8+ from skimage .transform import resize
99from sklearn .model_selection import KFold
1010from sklearn .neighbors import KNeighborsClassifier
11- from skimage . transform import resize
11+ from tqdm import tqdm
1212
13+ from alibi .api .defaults import (DEFAULT_DATA_PROTOSELECT ,
14+ DEFAULT_META_PROTOSELECT )
15+ from alibi .api .interfaces import Explanation , FitMixin , Summariser
1316from alibi .utils .distance import batch_compute_kernel_matrix
14- from alibi .api .interfaces import Summariser , Explanation , FitMixin
15- from alibi .api .defaults import DEFAULT_META_PROTOSELECT , DEFAULT_DATA_PROTOSELECT
1617from alibi .utils .kernel import EuclideanDistance
1718
1819logger = logging .getLogger (__name__ )
@@ -226,10 +227,10 @@ def _build_summary(self, protos: Dict[int, List[int]]) -> Explanation:
226227 Helper method to build the summary as an `Explanation` object.
227228 """
228229 data = deepcopy (DEFAULT_DATA_PROTOSELECT )
229- data ['prototypes_indices ' ] = np .concatenate (list (protos .values ())).astype (np .int32 )
230- data ['prototypes_labels ' ] = np .concatenate ([[self .label_inv_mapping [l ]] * len (protos [l ])
231- for l in protos ]).astype (np .int32 ) # noqa: E741
232- data ['prototypes' ] = self .Z [data ['prototypes_indices ' ]]
230+ data ['prototype_indices ' ] = np .concatenate (list (protos .values ())).astype (np .int32 )
231+ data ['prototype_labels ' ] = np .concatenate ([[self .label_inv_mapping [l ]] * len (protos [l ])
232+ for l in protos ]).astype (np .int32 ) # noqa: E741
233+ data ['prototypes' ] = self .Z [data ['prototype_indices ' ]]
233234 return Explanation (meta = self .meta , data = data )
234235
235236
@@ -262,7 +263,7 @@ def _helper_protoselect_euclidean_1knn(summariser: ProtoSelect,
262263 summary = summariser .summarise (num_prototypes = num_prototypes )
263264
264265 # train 1-knn classifier
265- X_protos , y_protos = summary .data ['prototypes' ], summary .data ['prototypes_labels ' ]
266+ X_protos , y_protos = summary .data ['prototypes' ], summary .data ['prototype_labels ' ]
266267 if len (X_protos ) == 0 :
267268 return None
268269
@@ -546,6 +547,79 @@ def _imscatterplot(x: np.ndarray,
546547 return ax
547548
548549
550+ def compute_prototype_importances (summary : 'Explanation' ,
551+ trainset : Tuple [np .ndarray , np .ndarray ],
552+ preprocess_fn : Optional [Callable [[np .ndarray ], np .ndarray ]] = None ,
553+ knn_kw : Optional [dict ] = None ) -> Dict [str , Optional [np .ndarray ]]:
554+
555+ """
556+ Computes the importance of each prototype. The importance of a prototype is the number of assigned
557+ training instances correctly classified according to the 1-KNN classifier
558+ (Bien and Tibshirani (2012): https://arxiv.org/abs/1202.5933).
559+
560+ Parameters
561+ ----------
562+ summary
563+ An `Explanation` object produced by a call to the
564+ :py:meth:`alibi.prototypes.protoselect.ProtoSelect.summarise` method.
565+ trainset
566+ Tuple, `(X_train, y_train)`, consisting of the training data instances with the corresponding labels.
567+ preprocess_fn
568+ Optional preprocessor function. If ``preprocess_fn=None``, no preprocessing is applied.
569+ knn_kw
570+ Keyword arguments passed to `sklearn.neighbors.KNeighborsClassifier`. The `n_neighbors` will be
571+ set automatically to 1, but the `metric` has to be specified according to the kernel distance used.
572+ If the `metric` is not specified, it will be set by default to ``'euclidean'``.
573+ See parameters description:
574+ https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html
575+
576+ Returns
577+ -------
578+ A dictionary containing:
579+
580+ - ``'prototype_indices'`` - an array of the prototype indices.
581+
582+ - ``'prototype_importances'`` - an array of prototype importances.
583+
584+ - ``'X_protos'`` - an array of raw prototypes.
585+
586+ - ``'X_protos_ft'`` - an optional array of preprocessed prototypes. If the ``preprocess_fn=None``, \
587+ no preprocessing is applied and ``None`` is returned instead.
588+ """
589+ if knn_kw is None :
590+ knn_kw = {}
591+
592+ if knn_kw .get ('metric' ) is None :
593+ knn_kw .update ({'metric' : 'euclidean' })
594+ logger .warning ("KNN metric was not specified. Automatically setting `metric='euclidean'`." )
595+
596+ X_train , y_train = trainset
597+ X_protos = summary .data ['prototypes' ]
598+ y_protos = summary .data ['prototype_labels' ]
599+
600+ # preprocess the dataset
601+ X_train_ft = _batch_preprocessing (X = X_train , preprocess_fn = preprocess_fn ) \
602+ if (preprocess_fn is not None ) else X_train
603+ X_protos_ft = _batch_preprocessing (X = X_protos , preprocess_fn = preprocess_fn ) \
604+ if (preprocess_fn is not None ) else X_protos
605+
606+ # train knn classifier
607+ knn = KNeighborsClassifier (n_neighbors = 1 , ** knn_kw )
608+ knn = knn .fit (X = X_protos_ft , y = y_protos )
609+
610+ # get neighbors indices for each training instance
611+ neigh_idx = knn .kneighbors (X = X_train_ft , n_neighbors = 1 , return_distance = False ).reshape (- 1 )
612+
613+ # compute how many correct labeled instances each prototype covers
614+ idx , counts = np .unique (neigh_idx [y_protos [neigh_idx ] == y_train ], return_counts = True )
615+ return {
616+ 'prototype_indices' : idx ,
617+ 'prototype_importances' : counts ,
618+ 'X_protos' : X_protos [idx ],
619+ 'X_protos_ft' : None if (preprocess_fn is None ) else X_protos_ft [idx ]
620+ }
621+
622+
549623def visualize_image_prototypes (summary : 'Explanation' ,
550624 trainset : Tuple [np .ndarray , np .ndarray ],
551625 reducer : Callable [[np .ndarray ], np .ndarray ],
@@ -560,7 +634,6 @@ def visualize_image_prototypes(summary: 'Explanation',
560634 Plot the images of the prototypes at the location given by the `reducer` representation.
561635 The size of each prototype is proportional to the logarithm of the number of assigned training instances correctly
562636 classified according to the 1-KNN classifier (Bien and Tibshirani (2012): https://arxiv.org/abs/1202.5933).
563-
564637 Parameters
565638 ----------
566639 summary
@@ -573,7 +646,7 @@ def visualize_image_prototypes(summary: 'Explanation',
573646 input instances if ``preprocess_fn=None``. If the `preprocess_fn` is specified, the reducer will be called
574647 on the feature representation obtained after passing the input instances through the `preprocess_fn`.
575648 preprocess_fn
576- Preprocessor function.
649+ Optional preprocessor function. If ``preprocess_fn=None``, no preprocessing is applied .
577650 knn_kw
578651 Keyword arguments passed to `sklearn.neighbors.KNeighborsClassifier`. The `n_neighbors` will be
579652 set automatically to 1, but the `metric` has to be specified according to the kernel distance used.
@@ -592,37 +665,27 @@ def visualize_image_prototypes(summary: 'Explanation',
592665 zoom_ub
593666 Zoom upper bound. The zoom will be scaled linearly between `[zoom_lb, zoom_ub]`.
594667 """
595- if knn_kw is None :
596- knn_kw = {}
597- if knn_kw .get ('metric' ) is None :
598- knn_kw .update ({'metric' : 'euclidean' })
599- logger .warning ("KNN metric was not specified. Automatically setting `metric='euclidean'`." )
600-
601- X_train , y_train = trainset
602- X_protos = summary .data ['prototypes' ]
603- y_protos = summary .data ['prototypes_labels' ]
604-
605- # preprocess the dataset
606- X_train_ft = _batch_preprocessing (X = X_train , preprocess_fn = preprocess_fn ) \
607- if (preprocess_fn is not None ) else X_train
608- X_protos_ft = _batch_preprocessing (X = X_protos , preprocess_fn = preprocess_fn ) \
609- if (preprocess_fn is not None ) else X_protos
610-
611- # train knn classifier
612- knn = KNeighborsClassifier (n_neighbors = 1 , ** knn_kw )
613- knn = knn .fit (X = X_protos_ft , y = y_protos )
668+ # compute how many correct labeled instances each prototype covers
669+ protos_importance = compute_prototype_importances (summary = summary ,
670+ trainset = trainset ,
671+ preprocess_fn = preprocess_fn ,
672+ knn_kw = knn_kw )
614673
615- # get neighbors indices for each training instance
616- neigh_idx = knn .kneighbors (X = X_train_ft , n_neighbors = 1 , return_distance = False ).reshape (- 1 )
674+ # unpack values
675+ counts = protos_importance ['prototype_importances' ]
676+ X_protos = protos_importance ['X_protos' ]
677+ X_protos_ft = protos_importance ['X_protos_ft' ] if (protos_importance ['X_protos_ft' ] is not None ) else X_protos
617678
618- # compute how many correct labeled instances each prototype covers
619- idx , counts = np .unique (neigh_idx [y_protos [neigh_idx ] == y_train ], return_counts = True )
620- zoom = np .log (counts )
679+ # compute image zoom
680+ zoom = np .log (counts ) # type: ignore[arg-type]
621681
622682 # compute 2D embedding
623- protos_2d = reducer (X_protos_ft [ idx ])
683+ protos_2d = reducer (X_protos_ft ) # type: ignore[arg-type]
624684 x , y = protos_2d [:, 0 ], protos_2d [:, 1 ]
625685
626686 # plot images
627- return _imscatterplot (x = x , y = y , images = X_protos , ax = ax , fig_kw = fig_kw , image_size = image_size ,
687+ return _imscatterplot (x = x , y = y ,
688+ images = X_protos , # type: ignore[arg-type]
689+ ax = ax , fig_kw = fig_kw ,
690+ image_size = image_size ,
628691 zoom = zoom , zoom_lb = zoom_lb , zoom_ub = zoom_ub )
0 commit comments