Skip to content

Commit 3e05c5a

Browse files
pytorchbot and malfet authored
[MPS] Properly handle conjugated tensors in bmm (#178010)
[MPS] Properly handle conjugated tensors in bmm (#177522). Both `bmm` and `addmm` lacked proper handling of conjugated inputs for some of their arguments. - Add regression tests - Fixes `test_noncontiguous_samples_linalg_svd_complex64`. Fixes #177474. Pull Request resolved: #177522. Approved by: https://github.com/Skylion007, https://github.com/kurtamohler (cherry picked from commit bd1afa6) Co-authored-by: Nikita Shulga <nikita.shulga@gmail.com>
1 parent db741c7 commit 3e05c5a

4 files changed

Lines changed: 35 additions & 29 deletions

File tree

aten/src/ATen/native/mps/operations/LinearAlgebra.mm

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -888,7 +888,8 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const
888888
std::string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_}) + ":" +
889889
std::to_string(beta.toDouble()) + ":" + std::to_string(alpha.toDouble());
890890
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
891-
MPSGraphTensor* biasTensor = mpsGraphRankedPlaceHolder(mpsGraph, *bias_);
891+
auto biasTensor = mpsGraphRankedPlaceHolder(mpsGraph, *bias_);
892+
auto biasTensor_ = bias_->is_conj() ? [mpsGraph conjugateWithTensor:biasTensor name:nil] : biasTensor;
892893

893894
// TODO: Use alpha and beta here with fill_.Scalar and mul
894895
auto [selfTensor, otherTensor, productTensor] = do_mm(mpsGraph, self, other);
@@ -901,11 +902,11 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const
901902
secondaryTensor:alphaTensor
902903
name:@"MM/alpha*(mat1@mat2)"];
903904
}
904-
auto biasTimesBetaTensor = biasTensor;
905+
auto biasTimesBetaTensor = biasTensor_;
905906
if (is_beta_non_zero && beta.toDouble() != 1.0) {
906907
auto betaTensor = [mpsGraph constantWithScalar:beta.toDouble()
907908
dataType:getMPSScalarType((*bias_).scalar_type())];
908-
biasTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor:biasTensor
909+
biasTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor:biasTensor_
909910
secondaryTensor:betaTensor
910911
name:@"MM/beta*input"];
911912
}
@@ -1112,7 +1113,8 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const
11121113
// Call tiled implementation if the number of elements exceeds 2^32
11131114
uint64_t resultSize = batch1.size(0) * batch1.size(1) * batch2.size(2);
11141115
if (resultSize > pow(2, 32)) {
1115-
result = tiled_bmm_out_mps_impl(batch1, batch2, result);
1116+
// Tiled path uses MPSNDArray directly, so resolve conjugate views upfront
1117+
result = tiled_bmm_out_mps_impl(batch1.resolve_conj(), batch2.resolve_conj(), result);
11161118
return result;
11171119
}
11181120

@@ -1130,16 +1132,18 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const
11301132
std::to_string(doTranspose);
11311133

11321134
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
1133-
MPSGraphTensor* batch1Tensor = mps::mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(batch1.scalar_type()));
1134-
MPSGraphTensor* batch2Tensor = mps::mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(batch2.scalar_type()));
1135-
MPSGraphTensor* batch2TensorTranspose = batch2Tensor;
1135+
auto batch1Tensor = mps::mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(batch1.scalar_type()));
1136+
auto batch2Tensor = mps::mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(batch2.scalar_type()));
1137+
1138+
auto batch1TensorOp = batch1.is_conj() ? [mpsGraph conjugateWithTensor:batch1Tensor name:nil] : batch1Tensor;
1139+
auto batch2TensorOp = batch2.is_conj() ? [mpsGraph conjugateWithTensor:batch2Tensor name:nil] : batch2Tensor;
11361140

11371141
if (doTranspose) {
1138-
batch2TensorTranspose = [mpsGraph transposeTensor:batch2Tensor dimension:-1 withDimension:-2 name:nil];
1142+
batch2TensorOp = [mpsGraph transposeTensor:batch2TensorOp dimension:-1 withDimension:-2 name:nil];
11391143
}
11401144

1141-
MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:batch1Tensor
1142-
secondaryTensor:batch2TensorTranspose
1145+
MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:batch1TensorOp
1146+
secondaryTensor:batch2TensorOp
11431147
name:@"MM/(batch1@batch2)"];
11441148

11451149
newCachedGraph->batch1Tensor_ = batch1Tensor;

test/test_mps.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1207,6 +1207,27 @@ def test_bmm(self):
12071207
self.assertEqual(output_cpu, output_mps)
12081208
self.assertEqual(output_cpu.size(), output_mps.size())
12091209

1210+
def test_bmm_conj(self):
1211+
# bmm must respect the conjugate bit on input tensors.
1212+
# See https://github.com/pytorch/pytorch/issues/177474
1213+
a = torch.randn(4, 3, 5, dtype=torch.complex64, device="mps")
1214+
b = torch.randn(4, 5, 2, dtype=torch.complex64, device="mps")
1215+
result_mps = torch.bmm(a, b.conj())
1216+
result_cpu = torch.bmm(a.cpu(), b.cpu().conj())
1217+
self.assertEqual(result_cpu, result_mps)
1218+
result_mps = torch.bmm(a.conj(), b)
1219+
result_cpu = torch.bmm(a.cpu().conj(), b.cpu())
1220+
self.assertEqual(result_cpu, result_mps)
1221+
1222+
def test_addmm_conj(self):
1223+
# Regression test: addmm must respect the conjugate bit on the bias tensor.
1224+
bias = torch.randn(3, 2, dtype=torch.complex64, device="mps")
1225+
a = torch.randn(3, 5, dtype=torch.complex64, device="mps")
1226+
b = torch.randn(5, 2, dtype=torch.complex64, device="mps")
1227+
result_mps = torch.addmm(bias.conj(), a, b)
1228+
result_cpu = torch.addmm(bias.cpu().conj(), a.cpu(), b.cpu())
1229+
self.assertEqual(result_cpu, result_mps)
1230+
12101231
@xfailIf(MACOS_VERSION < 15.0)
12111232
@parametrize("dtype", [torch.float16, torch.bfloat16])
12121233
def test_large_bmm(self, dtype):

torch/testing/_internal/common_methods_invocations.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19251,14 +19251,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
1925119251
device_type='mps', dtypes=[torch.float32]),
1925219252
# The operator 'aten::take' is not currently implemented for the MPS device
1925319253
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type='mps'),
19254-
# RuntimeError: svd_backward: The singular vectors in the complex
19255-
# case are specified up to multiplication by e^{i phi}. The
19256-
# specified loss function depends on this phase term, making it
19257-
# ill-defined.
19258-
DecorateInfo(
19259-
unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples',
19260-
device_type='mps', dtypes=(torch.complex64,)
19261-
),
1926219254
)),
1926319255
OpInfo('svd_lowrank',
1926419256
op=lambda *args, **kwargs: wrapper_set_seed(

torch/testing/_internal/opinfo/definitions/linalg.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2728,17 +2728,6 @@ def make_input():
27282728
"test_out_warning",
27292729
device_type="mps",
27302730
),
2731-
# MPS: RuntimeError: svd_backward: The singular vectors in the
2732-
# complex case are specified up to multiplication by e^{i phi}. The
2733-
# specified loss function depends on this phase term, making it
2734-
# ill-defined.
2735-
DecorateInfo(
2736-
unittest.expectedFailure,
2737-
"TestCommon",
2738-
"test_noncontiguous_samples",
2739-
device_type="mps",
2740-
dtypes=(torch.complex64,),
2741-
),
27422731
),
27432732
),
27442733
OpInfo(

0 commit comments

Comments
 (0)