
Type conversion issue in generalized_eigenvalue_decomposition() #518

@ldelebec

Description


🐛 Bug

Function dsp.beamforming.generalized_eigenvalue_decomposition() returns complex eigenvalues, whereas _generalized_eigenvalue_decomposition() returns real eigenvalues (along with complex eigenvectors), when given complex tensors as input.

To Reproduce

import torch
from asteroid.dsp.beamforming import (
    generalized_eigenvalue_decomposition,
    compute_scm,
    _generalized_eigenvalue_decomposition,
)

shape = (2, 2, 3, 4)

a = torch.randn(shape, dtype=torch.complex64)
b = torch.randn(shape, dtype=torch.complex64)

scm_a = compute_scm(a).permute(0, 3, 1, 2)  # bmmf -> bfmm
scm_b = compute_scm(b).permute(0, 3, 1, 2)

e_values, e_vectors = generalized_eigenvalue_decomposition(scm_a, scm_b)

print("e_values type: ", e_values.dtype)
print("e_vectors type: ", e_vectors.dtype)

e_values, e_vectors = _generalized_eigenvalue_decomposition(scm_a, scm_b)

print("e_values type: ", e_values.dtype)
print("e_vectors type: ", e_vectors.dtype)

Outputs

e_values type:  torch.complex64
e_vectors type:  torch.complex64
e_values type:  torch.float32
e_vectors type:  torch.complex64

Expected behavior

generalized_eigenvalue_decomposition() should return real eigenvalues, i.e. dtype torch.float32 for complex64 inputs, consistent with _generalized_eigenvalue_decomposition().
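For context, a real dtype is what a Cholesky-based reduction naturally produces. When `a` is Hermitian and `b` is Hermitian positive definite (as spatial covariance matrices usually are), the problem a v = λ b v reduces to a standard Hermitian problem, and `torch.linalg.eigh` returns real eigenvalues for complex Hermitian input. The sketch below illustrates this under those assumptions; `cholesky_gevd` and `random_hpd` are hypothetical helpers written for this example, not asteroid's code.

```python
import torch


def cholesky_gevd(a: torch.Tensor, b: torch.Tensor):
    """Solve a @ v = lam * b @ v for Hermitian a and Hermitian positive-definite b.

    Hypothetical helper for illustration; batched over leading dimensions.
    """
    # Factor b = L @ L^H.
    chol = torch.linalg.cholesky(b)
    inv_chol = torch.linalg.inv(chol)
    inv_chol_h = inv_chol.conj().transpose(-2, -1)
    # C = L^{-1} a L^{-H} is Hermitian, so the standard Hermitian solver applies.
    c = inv_chol @ a @ inv_chol_h
    e_val, e_vec = torch.linalg.eigh(c)  # e_val is real, e_vec keeps the complex dtype
    # Map eigenvectors of C back to the original problem: v = L^{-H} w.
    return e_val, inv_chol_h @ e_vec


def random_hpd(*batch: int, n: int, dtype=torch.complex64) -> torch.Tensor:
    """Random Hermitian positive-definite matrices of shape (*batch, n, n)."""
    x = torch.randn(*batch, n, n, dtype=dtype)
    return x @ x.conj().transpose(-2, -1) + n * torch.eye(n, dtype=dtype)


torch.manual_seed(0)
a = random_hpd(2, 3, n=4)  # (batch, freq, mic, mic)
b = random_hpd(2, 3, n=4)
e_values, e_vectors = cholesky_gevd(a, b)
print(e_values.dtype)   # torch.float32
print(e_vectors.dtype)  # torch.complex64
```

The eigenvalues come out as float32 without any explicit cast, which is the behavior the wrapper is expected to preserve.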

Solutions

  1. Add a second dtype mapping, from complex dtypes to their real counterparts, and use it when casting the eigenvalues:
        return {
            torch.complex32: torch.float16,
            torch.complex64: torch.float32,
            torch.complex128: torch.float64
        }
  2. Alternatively, drop the dtype conversion from generalized_eigenvalue_decomposition() altogether.
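Solution 1 could look roughly like the sketch below. It assumes the wrapper calls an inner decomposition (the `decompose` argument, standing in for `_generalized_eigenvalue_decomposition`) at a higher working precision and then casts the results back. `COMPLEX_TO_REAL`, `gevd_real_eigenvalues` and `eigh_identity_b` are hypothetical names, and complex128 as the working precision is an assumption for illustration, not necessarily what asteroid does.

```python
from typing import Callable, Tuple

import torch

Tensor = torch.Tensor

# Complex dtype -> real dtype of its components (the mapping from solution 1).
COMPLEX_TO_REAL = {
    torch.complex64: torch.float32,
    torch.complex128: torch.float64,
}
if hasattr(torch, "complex32"):  # half-precision complex, guarded for older PyTorch
    COMPLEX_TO_REAL[torch.complex32] = torch.float16


def gevd_real_eigenvalues(
    a: Tensor,
    b: Tensor,
    decompose: Callable[[Tensor, Tensor], Tuple[Tensor, Tensor]],
) -> Tuple[Tensor, Tensor]:
    """Hypothetical wrapper: decompose in complex128, then cast back.

    Eigenvectors are returned in the input dtype; eigenvalues are returned in
    the matching real dtype instead of being cast back to complex.
    """
    if not a.is_complex():
        raise TypeError(f"expected a complex tensor, got {a.dtype}")
    in_dtype = a.dtype
    e_val, e_vec = decompose(a.to(torch.complex128), b.to(torch.complex128))
    return e_val.to(COMPLEX_TO_REAL[in_dtype]), e_vec.to(in_dtype)


# Stand-in decomposition for demonstration only: with b = I the generalized
# problem reduces to the standard Hermitian one solved by torch.linalg.eigh.
def eigh_identity_b(a: Tensor, b: Tensor) -> Tuple[Tensor, Tensor]:
    return torch.linalg.eigh(a)


x = torch.randn(2, 3, 4, 4, dtype=torch.complex64)
a = x @ x.conj().transpose(-2, -1)
b = torch.eye(4, dtype=torch.complex64).expand_as(a)
e_val, e_vec = gevd_real_eigenvalues(a, b, eigh_identity_b)
print(e_val.dtype, e_vec.dtype)  # torch.float32 torch.complex64
```

Keeping the mapping separate from the existing precision mapping means eigenvalues and eigenvectors can each be cast to the dtype that matches their mathematical type. Solution 2 avoids the question entirely by returning whatever the inner decomposition produces, at the cost of giving up the precision handling.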

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working), help wanted (Extra attention is needed)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests