Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions asteroid/dsp/beamforming.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def stable_solve(b, a):
input_dtype = _common_dtype(b, a)
solve_dtype = input_dtype
if input_dtype not in [torch.float64, torch.complex128]:
solve_dtype = _to_double_map[input_dtype]
solve_dtype = _precision_mapping()[input_dtype]
return _stable_solve(b.to(solve_dtype), a.to(solve_dtype)).to(input_dtype)


Expand All @@ -351,7 +351,7 @@ def stable_cholesky(input, upper=False, out=None, eps=1e-6):
input_dtype = input.dtype
solve_dtype = input_dtype
if input_dtype not in [torch.float64, torch.complex128]:
solve_dtype = _to_double_map[input_dtype]
solve_dtype = _precision_mapping()[input_dtype]
return _stable_cholesky(input.to(solve_dtype), upper=upper, out=out, eps=eps).to(input_dtype)


Expand All @@ -371,7 +371,7 @@ def generalized_eigenvalue_decomposition(a, b):
input_dtype = _common_dtype(a, b)
solve_dtype = input_dtype
if input_dtype not in [torch.float64, torch.complex128]:
solve_dtype = _to_double_map[input_dtype]
solve_dtype = _precision_mapping()[input_dtype]
e_val, e_vec = _generalized_eigenvalue_decomposition(a.to(solve_dtype), b.to(solve_dtype))
return e_val.to(input_dtype), e_vec.to(input_dtype)

Expand Down Expand Up @@ -403,6 +403,36 @@ def _common_dtype(*args):
return all_dtypes[0]


USE_DOUBLE = True


def force_float_linalg():
global USE_DOUBLE
USE_DOUBLE = False


def force_double_linalg():
global USE_DOUBLE
USE_DOUBLE = True


def _precision_mapping():
if USE_DOUBLE:
return {
torch.float16: torch.float64,
torch.float32: torch.float64,
torch.complex32: torch.complex128,
torch.complex64: torch.complex128,
}
else:
return {
torch.float16: torch.float16,
torch.float32: torch.float32,
torch.complex32: torch.complex32,
torch.complex64: torch.complex64,
}


# Legacy
BeamFormer = Beamformer
SdwMwfBeamformer = SDWMWFBeamformer
Expand Down