261 changes: 39 additions & 222 deletions notebooks/consistency.ipynb

Large diffs are not rendered by default.

12 changes: 8 additions & 4 deletions sage/iterated_estimator.py
@@ -27,7 +27,7 @@ def estimate_total(imputer, X, Y, batch_size, loss_fn):
return marginal_loss - mean_loss


def estimate_holdout_importance(imputer, X, Y, batch_size, loss_fn, batches):
def estimate_holdout_importance(imputer, X, Y, batch_size, loss_fn, batches, rng):
'''Estimate the impact of holding out features individually.'''
N, _ = X.shape
num_features = imputer.num_groups
@@ -38,7 +38,7 @@ def estimate_holdout_importance(imputer, X, Y, batch_size, loss_fn, batches):
# Sample the same batches for all features.
for it in range(batches):
# Sample minibatch.
mb = np.random.choice(N, batch_size)
mb = rng.choice(N, batch_size)
x = X[mb]
y = Y[mb]

@@ -66,9 +66,12 @@ class IteratedEstimator:
'''
def __init__(self,
imputer,
loss='cross entropy'):
loss='cross entropy',
random_state=None
):
self.imputer = imputer
self.loss_fn = utils.get_loss(loss, reduction='none')
self.rng = np.random.default_rng(seed=random_state)

def __call__(self,
X,
@@ -139,7 +142,8 @@ def __call__(self,
if verbose:
print('Determining feature ordering...')
holdout_importance = estimate_holdout_importance(
self.imputer, X, Y, batch_size, self.loss_fn, ordering_batches)
self.imputer, X, Y, batch_size, self.loss_fn, ordering_batches, self.rng
)
if verbose:
print('Done')
# Use np.abs in case there are large negative contributors.
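A minimal sketch of the behavior the new random_state / rng plumbing provides: generators created with np.random.default_rng and the same seed draw identical minibatch indices, so repeated runs of an estimator sample the same data. (Standalone illustration, not code from this PR; N and batch_size are arbitrary.)

import numpy as np

# Two generators seeded the way the estimators now seed self.rng draw
# identical minibatch indices on every call.
random_state = 42
N, batch_size = 1000, 32

rng_a = np.random.default_rng(seed=random_state)
rng_b = np.random.default_rng(seed=random_state)

assert np.array_equal(rng_a.choice(N, batch_size),
                      rng_b.choice(N, batch_size))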
13 changes: 9 additions & 4 deletions sage/kernel_estimator.py
@@ -82,9 +82,14 @@ class KernelEstimator:
imputer: model that accommodates held out features.
loss: loss function ('mse', 'cross entropy').
'''
def __init__(self, imputer, loss):
def __init__(self,
imputer,
loss='cross entropy',
random_state=None
):
self.imputer = imputer
self.loss_fn = utils.get_loss(loss, reduction='none')
self.rng = np.random.default_rng(seed=random_state)

def __call__(self,
X,
@@ -168,16 +173,16 @@ def __call__(self,
# Sample subsets.
for it in range(n_loops):
# Sample data.
mb = np.random.choice(N, batch_size)
mb = self.rng.choice(N, batch_size)
x = X[mb]
y = Y[mb]

# Sample subsets.
S = np.zeros((batch_size, num_features), dtype=bool)
num_included = np.random.choice(num_features - 1, size=batch_size,
num_included = self.rng.choice(num_features - 1, size=batch_size,
p=weights) + 1
for row, num in zip(S, num_included):
inds = np.random.choice(num_features, size=num, replace=False)
inds = self.rng.choice(num_features, size=num, replace=False)
row[inds] = 1

# Calculate loss.
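For reference, a standalone sketch of the subset-sampling loop shown above, driven by a seeded generator instead of the global np.random state. The coalition-size weights below are illustrative placeholders, not the weights the estimator actually computes.

import numpy as np

# Illustrative subset sampling: pick a coalition size per row from a
# weight distribution, then pick that many features without replacement.
rng = np.random.default_rng(seed=0)
num_features, batch_size = 5, 4

sizes = np.arange(1, num_features)
weights = 1.0 / (sizes * (num_features - sizes))   # placeholder weights
weights = weights / weights.sum()

S = np.zeros((batch_size, num_features), dtype=bool)
num_included = rng.choice(num_features - 1, size=batch_size, p=weights) + 1
for row, num in zip(S, num_included):
    inds = rng.choice(num_features, size=num, replace=False)
    row[inds] = True   # mark which features are observed in this row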
122 changes: 66 additions & 56 deletions sage/permutation_estimator.py
@@ -1,3 +1,4 @@
import joblib
import numpy as np
from sage import utils, core
from tqdm.auto import tqdm
@@ -13,9 +14,14 @@ class PermutationEstimator:
'''
def __init__(self,
imputer,
loss='cross entropy'):
loss='cross entropy',
n_jobs=1,
random_state=None
):
self.imputer = imputer
self.loss_fn = utils.get_loss(loss, reduction='none')
self.n_jobs = joblib.effective_n_jobs(n_jobs)
self.rng = np.random.default_rng(seed=random_state)

def __call__(self,
X,
@@ -71,12 +77,8 @@ def __call__(self,
assert min_coalition >= 0
assert max_coalition <= num_features
assert min_coalition < max_coalition
if min_coalition > 0 or max_coalition < num_features:
relaxed = True
explanation_type = 'Relaxed ' + explanation_type
else:
relaxed = False
sample_counts = None

explanation_type = 'Relaxed ' + explanation_type

# Possibly force convergence detection.
if n_permutations is None:
@@ -98,62 +100,29 @@ def __call__(self,
bar = tqdm(total=n_loops * batch_size * num_features)

# Setup.
arange = np.arange(batch_size)
scores = np.zeros((batch_size, num_features))
S = np.zeros((batch_size, num_features), dtype=bool)
permutations = np.tile(np.arange(num_features), (batch_size, 1))
tracker = utils.ImportanceTracker()

# Permutation sampling.
for it in range(n_loops):
# Sample data.
mb = np.random.choice(N, batch_size)
x = X[mb]
y = Y[mb]

# Sample permutations.
S[:] = 0
for i in range(batch_size):
np.random.shuffle(permutations[i])

# Calculate sample counts.
if relaxed:
scores[:] = 0
sample_counts = np.zeros(num_features, dtype=int)
for i in range(batch_size):
sample_counts[permutations[i, min_coalition:max_coalition]] = (
sample_counts[permutations[i, min_coalition:max_coalition]] + 1)

# Add necessary features to minimum coalition.
for i in range(min_coalition):
# Add next feature.
inds = permutations[:, i]
S[arange, inds] = 1

# Make prediction with minimum coalition.
y_hat = self.imputer(x, S)
prev_loss = self.loss_fn(y_hat, y)
# Performed iterations counter.
it = 0

# Add all remaining features.
for i in range(min_coalition, max_coalition):
# Add next feature.
inds = permutations[:, i]
S[arange, inds] = 1
while it < n_loops:
# Make sure we don't perform more iterations than n_loops.
num_batches = min(self.n_jobs, n_loops - it)

# Make prediction with missing features.
y_hat = self.imputer(x, S)
loss = self.loss_fn(y_hat, y)
batches = []
for _ in range(num_batches):
idxs = self.rng.choice(N, batch_size)
batches.append((X[idxs], Y[idxs]))

# Calculate delta sample.
scores[arange, inds] = prev_loss - loss
prev_loss = loss
results = joblib.Parallel(n_jobs=self.n_jobs)(
joblib.delayed(self._process_sample)(x, y, batch_size, num_features, min_coalition, max_coalition)
for x, y in batches
)

# Update bar (if not detecting convergence).
if bar and (not detect_convergence):
bar.update(batch_size)
it += self.n_jobs

# Update tracker.
tracker.update(scores, sample_counts)
for scores, sample_counts in results:
tracker.update(scores, sample_counts)

# Calculate progress.
std = np.max(tracker.std)
@@ -190,3 +159,44 @@ def __call__(self,
bar.close()

return core.Explanation(tracker.values, tracker.std, explanation_type)

def _process_sample(self, x, y, batch_size, num_features, min_coalition, max_coalition):
arange = np.arange(batch_size)
scores = np.zeros((batch_size, num_features))
S = np.zeros((batch_size, num_features), dtype=bool)
permutations = np.tile(np.arange(num_features), (batch_size, 1))

# Sample permutations.
for i in range(batch_size):
self.rng.shuffle(permutations[i])

# Calculate sample counts.
sample_counts = np.zeros(num_features, dtype=int)
for i in range(batch_size):
sample_counts[permutations[i, min_coalition:max_coalition]] += 1

# Add necessary features to minimum coalition.
for i in range(min_coalition):
# Add next feature.
inds = permutations[:, i]
S[arange, inds] = 1

# Make prediction with minimum coalition.
y_hat = self.imputer(x, S)
prev_loss = self.loss_fn(y_hat, y)

# Add all remaining features.
for i in range(min_coalition, max_coalition):
# Add next feature.
inds = permutations[:, i]
S[arange, inds] = 1

# Make prediction with missing features.
y_hat = self.imputer(x, S)
loss = self.loss_fn(y_hat, y)

# Calculate delta sample.
scores[arange, inds] = prev_loss - loss
prev_loss = loss

return scores, sample_counts
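The new _process_sample method scores one sampled batch independently of the others, which is what lets the dispatch loop above fan batches out with joblib. A hypothetical, self-contained illustration of that dispatch pattern (process_batch is a stand-in worker, not the estimator's real scoring logic):

import joblib
import numpy as np

# Stand-in for PermutationEstimator._process_sample: any function that
# scores one batch without touching shared state.
def process_batch(x, y):
    return float(x.sum() + y.sum())

rng = np.random.default_rng(seed=0)
X = rng.normal(size=(100, 8))
Y = rng.normal(size=100)
batch_size = 16

# Sample one batch per worker, run the workers in parallel, then fold the
# results in sequentially (the diff does this via tracker.update).
n_jobs = joblib.effective_n_jobs(2)
batches = [(X[idxs], Y[idxs])
           for idxs in (rng.choice(len(X), batch_size) for _ in range(n_jobs))]

results = joblib.Parallel(n_jobs=n_jobs)(
    joblib.delayed(process_batch)(x, y) for x, y in batches
)

One consequence of this structure, visible in the loop above, is that the iteration counter advances by n_jobs per dispatch, so convergence is checked once per group of batches rather than after every individual batch.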
14 changes: 9 additions & 5 deletions sage/sign_estimator.py
@@ -28,7 +28,7 @@ def estimate_total(imputer, X, Y, batch_size, loss_fn):
return marginal_loss - mean_loss


def estimate_holdout_importance(imputer, X, Y, batch_size, loss_fn, batches):
def estimate_holdout_importance(imputer, X, Y, batch_size, loss_fn, batches, rng):
'''Estimate the impact of holding out features individually.'''
N, _ = X.shape
num_features = imputer.num_groups
@@ -39,7 +39,7 @@ def estimate_holdout_importance(imputer, X, Y, batch_size, loss_fn, batches):
# Sample the same batches for all features.
for it in range(batches):
# Sample minibatch.
mb = np.random.choice(N, batch_size)
mb = rng.choice(N, batch_size)
x = X[mb]
y = Y[mb]

@@ -68,9 +68,12 @@ class SignEstimator:
'''
def __init__(self,
imputer,
loss='cross entropy'):
loss='cross entropy',
random_state=None
):
self.imputer = imputer
self.loss_fn = utils.get_loss(loss, reduction='none')
self.rng = np.random.default_rng(seed=random_state)

def __call__(self,
X,
@@ -133,7 +136,8 @@ def __call__(self,
if verbose:
print('Determining feature ordering...')
holdout_importance = estimate_holdout_importance(
self.imputer, X, Y, batch_size, self.loss_fn, ordering_batches)
self.imputer, X, Y, batch_size, self.loss_fn, ordering_batches, self.rng
)
if verbose:
print('Done')
# Use np.abs in case there are large negative contributors.
@@ -153,7 +157,7 @@ def __call__(self,
converged = False
while not converged:
# Sample data.
mb = np.random.choice(N, batch_size)
mb = self.rng.choice(N, batch_size)
x = X[mb]
y = Y[mb]

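The comment about np.abs above refers to the feature-ordering step, where features are ranked by the magnitude of their holdout importance so that large negative contributors are not pushed to the end of the ordering. The ordering code itself sits outside the visible diff; a hypothetical sketch of that ranking:

import numpy as np

# Hypothetical ordering sketch: rank features by |importance| so a large
# negative contributor (index 1) still comes first.
holdout_importance = np.array([0.30, -0.45, 0.05, 0.10])
ordering = np.argsort(-np.abs(holdout_importance))
print(ordering)  # [1 0 3 2]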
3 changes: 2 additions & 1 deletion setup.py
@@ -22,7 +22,8 @@
'numpy',
'scipy',
'matplotlib',
'tqdm'
'tqdm',
'joblib'
],
classifiers=[
"Programming Language :: Python :: 3",