Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ repos:
rev: 20.8b1
hooks:
- id: black
args:
- --config=pyproject.toml
types: [python]
123 changes: 72 additions & 51 deletions asteroid/dsp/beamforming.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
from torch import nn
import warnings
from functools import wraps


class SCM(nn.Module):
Expand Down Expand Up @@ -65,7 +66,6 @@ def from_atf_vect(
noise_scm_t = noise_scm.permute(0, 3, 1, 2) # -> bfmm
atf_vec_t = atf_vec.transpose(-1, -2).unsqueeze(-1) # -> bfm1

# numerator, _ = torch.solve(atf_vec_t, noise_scm_t) # -> bfm1
numerator = stable_solve(atf_vec_t, noise_scm_t) # -> bfm1

denominator = torch.matmul(atf_vec_t.conj().transpose(-1, -2), numerator) # -> bf11
Expand Down Expand Up @@ -99,7 +99,7 @@ def forward(
target_scm_t = target_scm.permute(0, 3, 1, 2) # -> bfmm

denominator = target_scm_t + self.mu * noise_scm_t
bf_vect, _ = torch.solve(target_scm_t, denominator)
bf_vect = stable_solve(target_scm_t, denominator)
bf_vect = bf_vect[..., ref_mic].transpose(-1, -2) # -> bfm1 -> bmf
output = self.apply_beamforming_vector(bf_vect, mix=mix) # -> bft
return output
Expand Down Expand Up @@ -162,38 +162,83 @@ def compute_scm(x: torch.Tensor, mask: torch.Tensor = None, normalize: bool = Tr
return scm


def stable_solve(b, a):
    """Return torch.solve(b, a) if matrix `a` is non-singular; else condition
    `a` with `condition_scm` (small diagonal loading) and solve again.

    NOTE(review): this is the pre-change version of `stable_solve`; the diff
    replaces it with a dtype-promoting wrapper defined further down.
    """
    try:
        return torch.solve(b, a)[0]
    except RuntimeError:
        # Factorization failed (singular `a`): regularize and retry once.
        a = condition_scm(a, 1e-6)
        return torch.solve(b, a)[0]


def condition_scm(x, eps=1e-6, dim1=-2, dim2=-1):
    """Condition input SCM with (x + eps tr(x) I) / (1 + eps) along `dim1` and `dim2`.

    See https://stt.msu.edu/users/mauryaas/Ashwini_JPEN.pdf (2.3).

    Args:
        x (Tensor): batch of square matrices; assumed 4d with shape (b, f, m, m).
        eps (float): relative weight of the diagonal loading term.
        dim1 (int): first matrix dimension; only -2 is supported.
        dim2 (int): second matrix dimension; only -1 is supported.

    Returns:
        Tensor: conditioned SCM, same shape as `x`.

    Raises:
        NotImplementedError: if `dim1`/`dim2` are not the last two dims.
    """
    # Assume 4d with ...mm
    if dim1 != -2 or dim2 != -1:
        raise NotImplementedError
    # Diagonal loading proportional to the mean diagonal value: eps * tr(x) / m.
    scale = eps * batch_trace(x, dim1=dim1, dim2=dim2)[..., None, None] / x.shape[dim1]
    scaled_eye = torch.eye(x.shape[dim1])[None, None] * scale
    return (x + scaled_eye) / (1 + eps)


def batch_trace(x, dim1=-2, dim2=-1):
    """Compute the trace along `dim1` and `dim2` for any matrix with `ndim>=2`.

    Args:
        x (Tensor): input with `ndim >= 2`; the trace is taken over the
            `dim1` x `dim2` matrix slices, batched over the remaining dims.
        dim1 (int): first matrix dimension.
        dim2 (int): second matrix dimension.

    Returns:
        Tensor: trace of each matrix slice; shape is `x`'s batch shape.
    """
    # torch.diagonal extracts the matrix diagonals (last output dim),
    # and summing over that dim yields the per-matrix trace.
    return torch.diagonal(x, dim1=dim1, dim2=dim2).sum(-1)


def stable_solve(b, a):
    """Solve ``a x = b`` robustly, regularizing ``a`` if it is singular.

    The actual solve is delegated to :func:`_stable_solve` and is always
    carried out in double precision; the result is cast back to the common
    input dtype of ``b`` and ``a``.
    """
    # Promote half/single precision inputs to double for numerical stability.
    input_dtype = _common_dtype(b, a)
    if input_dtype in (torch.float64, torch.complex128):
        solve_dtype = input_dtype
    else:
        solve_dtype = _to_double_map[input_dtype]
    solution = _stable_solve(b.to(solve_dtype), a.to(solve_dtype))
    return solution.to(input_dtype)


def _stable_solve(b, a, eps=1e-6):
    """Solve ``a x = b``; if ``a`` cannot be factorized, condition it and retry once."""
    try:
        solution, _ = torch.solve(b, a)
    except RuntimeError:
        # Singular (or otherwise non-factorizable) `a`: apply diagonal
        # loading via condition_scm and solve again.
        conditioned = condition_scm(a, eps)
        solution, _ = torch.solve(b, conditioned)
    return solution


def stable_cholesky(input, upper=False, out=None, eps=1e-6):
    """Compute the Cholesky decomposition of ``input``.
    If ``input`` is only p.s.d, add a small jitter to the diagonal.

    The factorization is run in double precision and the result is cast back
    to ``input``'s dtype.

    Args:
        input (Tensor): The tensor to compute the Cholesky decomposition of
        upper (bool, optional): See torch.cholesky
        out (Tensor, optional): See torch.cholesky
        eps (float): relative scale of the trace-based conditioning applied
            when the first factorization fails (i.e. ``input`` is not
            positive definite).

    NOTE(review): when ``out`` is given while ``input`` is promoted to double,
    ``torch.cholesky`` receives an ``out`` of a different dtype — confirm this
    is supported by the targeted torch version.
    """
    # Only run it in double
    input_dtype = input.dtype
    solve_dtype = input_dtype
    if input_dtype not in [torch.float64, torch.complex128]:
        solve_dtype = _to_double_map[input_dtype]
    return _stable_cholesky(input.to(solve_dtype), upper=upper, out=out, eps=eps).to(input_dtype)


def _stable_cholesky(input, upper=False, out=None, eps=1e-6):
    """Cholesky-factorize ``input``; condition it and retry once on failure."""
    try:
        return torch.cholesky(input, upper=upper, out=out)
    except RuntimeError:
        # Not positive definite: add trace-scaled diagonal loading and retry.
        conditioned = condition_scm(input, eps)
        return torch.cholesky(conditioned, upper=upper, out=out)


def generalized_eigenvalue_decomposition(a, b):
    """Solves the generalized eigenvalue decomposition through Cholesky decomposition.
    Returns eigen values and eigen vectors (ascending order).

    Args:
        a (Tensor): the matrix ``A`` in ``A v = lambda B v``.
        b (Tensor): the matrix ``B`` in ``A v = lambda B v``; presumably
            positive definite (it is Cholesky-factorized downstream) — confirm
            with callers.

    Returns:
        tuple(Tensor, Tensor): eigenvalues and eigenvectors, cast back to the
        common input dtype of `a` and `b`.
    """
    # NOTE: removed stray diff-residue line `cholesky = stable_cholesky(b, max_tries=2)`
    # left over from the old implementation — `max_tries` is no longer a
    # parameter of `stable_cholesky`, so that call would raise TypeError.
    # Only run it in double
    input_dtype = _common_dtype(a, b)
    solve_dtype = input_dtype
    if input_dtype not in [torch.float64, torch.complex128]:
        solve_dtype = _to_double_map[input_dtype]
    e_val, e_vec = _generalized_eigenvalue_decomposition(a.to(solve_dtype), b.to(solve_dtype))
    return e_val.to(input_dtype), e_vec.to(input_dtype)


def _generalized_eigenvalue_decomposition(a, b):
cholesky = stable_cholesky(b)
inv_cholesky = torch.inverse(cholesky)
# Compute C matrix L⁻1 A L^-T
cmat = inv_cholesky @ a @ inv_cholesky.conj().transpose(-1, -2)
Expand All @@ -204,40 +249,16 @@ def generalized_eigenvalue_decomposition(a, b):
return e_val, e_vec


def stable_cholesky(input, upper=False, out=None, jitter=1e-6, max_tries=2, verbose=False):
"""Compute the Cholesky decomposition of A.
If A is only p.s.d, add a small jitter to the diagonal.
# Dtype-promotion table used by `stable_solve`, `stable_cholesky` and
# `generalized_eigenvalue_decomposition`: half/single precision (real and
# complex) dtypes are mapped to double precision so the linear-algebra
# routines run in float64/complex128 for numerical stability.
_to_double_map = {
    torch.float16: torch.float64,
    torch.float32: torch.float64,
    torch.complex32: torch.complex128,
    torch.complex64: torch.complex128,
}

Args:
input (Tensor): The tensor to compute the Cholesky decomposition of
upper (bool, optional): See torch.cholesky
out (Tensor, optional): See torch.cholesky
jitter (float): The jitter to add to the diagonal of A in case A is only p.s.d.
max_tries (int, optional): Number of attempts (with increasing jitter) before raising an error.
verbose (bool): Whether to raise a warning if the jitter had to be added.

Adapted from GPytorch https://github.com/cornellius-gp/gpytorch/blob/master/gpytorch/utils/cholesky.py#L12
"""
try:
return torch.cholesky(input, upper=upper, out=out)
except RuntimeError as e:
clone = input.clone()
jitter_prev = 0
for i in range(max_tries):
jitter_new = jitter * (10 ** i)
clone.diagonal(dim1=-2, dim2=-1).add_(jitter_new - jitter_prev)
jitter_prev = jitter_new
try:
out = torch.cholesky(clone, upper=upper, out=out)
if verbose is True:
warnings.warn(
f"Had to add a jitter of {jitter_new:.1e} to compute the cholesky decomposition.",
RuntimeWarning,
)
return out
except RuntimeError:
continue
raise RuntimeError(
f"Matrix not positive definite after repeatedly adding jitter up to {jitter_new:.1e}. "
f"Original error on first attempt: {e}"
)
def _common_dtype(*args):
all_dtypes = [a.dtype for a in args]
if len(set(all_dtypes)) > 1:
raise RuntimeError(f"Expected inputs from the same dtype. Received {all_dtypes}.")
return all_dtypes[0]
8 changes: 3 additions & 5 deletions tests/dsp/beamforming_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ def test_mwf(n_mics, mu):


def test_stable_cholesky():
    """stable_cholesky succeeds on a symmetric positive-definite matrix."""
    # NOTE: the old assertions (pytest.warns with `verbose=True`, pytest.raises
    # with `jitter=0.0`) were diff residue from the removed API — those kwargs
    # no longer exist on `stable_cholesky`, so the calls would raise TypeError.
    a = torch.randn(3, 3)
    a = torch.mm(a, a.t())  # make symmetric positive-definite
    stable_cholesky(a)