Closed
Labels: bug, help wanted
Description
🐛 Bug
When given complex tensors as input, dsp.beamforming.generalized_eigenvalue_decomposition() returns complex eigenvalues, whereas _generalized_eigenvalue_decomposition() returns real eigenvalues (along with complex eigenvectors).
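For context, plain PyTorch already returns real eigenvalues for complex Hermitian input through `torch.linalg.eigh`. The sketch below (no asteroid code involved) shows that casting such eigenvalues back to the complex input dtype keeps the values but loses the real dtype. That this cast is what happens inside the public wrapper is an assumption, consistent with the type conversion mentioned under Solutions.

```python
import torch

# A batch of complex Hermitian matrices, similar in structure to SCMs.
x = torch.randn(2, 3, 4, 4, dtype=torch.complex64)
herm = x @ x.conj().transpose(-1, -2)

# eigh on complex Hermitian input: real eigenvalues, complex eigenvectors.
e_val, e_vec = torch.linalg.eigh(herm)
print(e_val.dtype, e_vec.dtype)  # torch.float32 torch.complex64

# Assumed wrapper behaviour: cast the eigenvalues back to the input dtype.
# The result is complex with an all-zero imaginary part.
e_val_cast = e_val.to(herm.dtype)
print(e_val_cast.dtype)  # torch.complex64
print(torch.all(e_val_cast.imag == 0).item())  # True
```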
To Reproduce
import torch
from asteroid.dsp.beamforming import (
    generalized_eigenvalue_decomposition,
    compute_scm,
    _generalized_eigenvalue_decomposition,
)
shape = (2, 2, 3, 4)
a = torch.randn(shape, dtype=torch.complex64)
b = torch.randn(shape, dtype=torch.complex64)
scm_a = compute_scm(a).permute(0, 3, 1, 2) # bmmf -> bfmm
scm_b = compute_scm(b).permute(0, 3, 1, 2)
e_values, e_vectors = generalized_eigenvalue_decomposition(scm_a, scm_b)
print("e_values type: ", e_values.dtype)
print("e_vectors type: ", e_vectors.dtype)
e_values, e_vectors = _generalized_eigenvalue_decomposition(scm_a, scm_b)
print("e_values type: ", e_values.dtype)
print("e_vectors type: ", e_vectors.dtype)
Outputs
e_values type: torch.complex64
e_vectors type: torch.complex64
e_values type: torch.float32
e_vectors type: torch.complex64
Expected behavior
generalized_eigenvalue_decomposition() should return real eigenvalues, i.e. with dtype torch.float32 for torch.complex64 inputs, matching _generalized_eigenvalue_decomposition().
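Why real eigenvalues are the expected result: for a Hermitian `a` and a Hermitian positive-definite `b` (as spatial covariance matrices usually are), the generalized problem a v = λ b v can be reduced through the Cholesky factor b = L L^H to the standard Hermitian problem (L^-1 a L^-H) w = λ w, whose eigenvalues are real. The sketch below is one textbook way to do this in plain PyTorch; `gevd_cholesky` and `random_scm` are hypothetical names, and this is not necessarily asteroid's implementation.

```python
import torch


def gevd_cholesky(a: torch.Tensor, b: torch.Tensor):
    """Hypothetical sketch: solve a v = lambda b v, batched over leading dims.

    Assumes `a` is Hermitian and `b` is Hermitian positive-definite.
    """
    chol = torch.linalg.cholesky(b)            # b = L L^H
    inv_chol = torch.linalg.inv(chol)          # L^-1
    inv_chol_h = inv_chol.conj().transpose(-1, -2)  # L^-H
    cmat = inv_chol @ a @ inv_chol_h           # C = L^-1 a L^-H (Hermitian)
    e_val, w = torch.linalg.eigh(cmat)         # real eigenvalues, ascending
    e_vec = inv_chol_h @ w                     # map back: v = L^-H w
    return e_val, e_vec


def random_scm(shape, dtype):
    """Hypothetical helper: random Hermitian PSD matrices built as x x^H."""
    x = torch.randn(*shape, dtype=dtype)
    return x @ x.conj().transpose(-1, -2)


n = 4
a = random_scm((2, 3, n, n), torch.complex128)
# Adding the identity keeps b safely positive-definite.
b = random_scm((2, 3, n, n), torch.complex128) + torch.eye(n, dtype=torch.complex128)
e_val, e_vec = gevd_cholesky(a, b)
print(e_val.dtype, e_vec.dtype)  # torch.float64 torch.complex128

# Check a v = lambda b v, scaling each column j of b v by lambda_j.
residual = a @ e_vec - (b @ e_vec) * e_val.unsqueeze(-2)
print(residual.abs().max().item() < 1e-8)
```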
Solutions
- Add another dtype mapping that converts complex types to their real counterparts:
  return {
      torch.complex32: torch.float16,
      torch.complex64: torch.float32,
      torch.complex128: torch.float64,
  }
- Another solution is to drop the dtype conversion in generalized_eigenvalue_decomposition().