@@ -888,7 +888,8 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const
888888 std::string key = " addmm_out_mps_impl" + getTensorsStringKey ({self, other, *bias_}) + " :" +
889889 std::to_string (beta.toDouble ()) + " :" + std::to_string (alpha.toDouble ());
890890 auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
891- MPSGraphTensor* biasTensor = mpsGraphRankedPlaceHolder (mpsGraph, *bias_);
891+ auto biasTensor = mpsGraphRankedPlaceHolder (mpsGraph, *bias_);
892+ auto biasTensor_ = bias_->is_conj () ? [mpsGraph conjugateWithTensor: biasTensor name: nil ] : biasTensor;
892893
893894 // TODO: Use alpha and beta here with fill_.Scalar and mul
894895 auto [selfTensor, otherTensor, productTensor] = do_mm (mpsGraph, self, other);
@@ -901,11 +902,11 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const
901902 secondaryTensor: alphaTensor
902903 name: @" MM/alpha*(mat1@mat2)" ];
903904 }
904- auto biasTimesBetaTensor = biasTensor ;
905+ auto biasTimesBetaTensor = biasTensor_ ;
905906 if (is_beta_non_zero && beta.toDouble () != 1.0 ) {
906907 auto betaTensor = [mpsGraph constantWithScalar: beta.toDouble ()
907908 dataType: getMPSScalarType ((*bias_).scalar_type ())];
908- biasTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor: biasTensor
909+ biasTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor: biasTensor_
909910 secondaryTensor: betaTensor
910911 name: @" MM/beta*input" ];
911912 }
@@ -1112,7 +1113,8 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const
11121113 // Call tiled implementation if the number of elements exceeds 2^32
11131114 uint64_t resultSize = batch1.size (0 ) * batch1.size (1 ) * batch2.size (2 );
11141115 if (resultSize > pow (2 , 32 )) {
1115- result = tiled_bmm_out_mps_impl (batch1, batch2, result);
1116+ // Tiled path uses MPSNDArray directly, so resolve conjugate views upfront
1117+ result = tiled_bmm_out_mps_impl (batch1.resolve_conj (), batch2.resolve_conj (), result);
11161118 return result;
11171119 }
11181120
@@ -1130,16 +1132,18 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const
11301132 std::to_string (doTranspose);
11311133
11321134 auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
1133- MPSGraphTensor* batch1Tensor = mps::mpsGraphUnrankedPlaceHolder (mpsGraph, getMPSDataType (batch1.scalar_type ()));
1134- MPSGraphTensor* batch2Tensor = mps::mpsGraphUnrankedPlaceHolder (mpsGraph, getMPSDataType (batch2.scalar_type ()));
1135- MPSGraphTensor* batch2TensorTranspose = batch2Tensor;
1135+ auto batch1Tensor = mps::mpsGraphUnrankedPlaceHolder (mpsGraph, getMPSDataType (batch1.scalar_type ()));
1136+ auto batch2Tensor = mps::mpsGraphUnrankedPlaceHolder (mpsGraph, getMPSDataType (batch2.scalar_type ()));
1137+
1138+ auto batch1TensorOp = batch1.is_conj () ? [mpsGraph conjugateWithTensor: batch1Tensor name: nil ] : batch1Tensor;
1139+ auto batch2TensorOp = batch2.is_conj () ? [mpsGraph conjugateWithTensor: batch2Tensor name: nil ] : batch2Tensor;
11361140
11371141 if (doTranspose) {
1138- batch2TensorTranspose = [mpsGraph transposeTensor: batch2Tensor dimension: -1 withDimension: -2 name: nil ];
1142+ batch2TensorOp = [mpsGraph transposeTensor: batch2TensorOp dimension: -1 withDimension: -2 name: nil ];
11391143 }
11401144
1141- MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor: batch1Tensor
1142- secondaryTensor: batch2TensorTranspose
1145+ MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor: batch1TensorOp
1146+ secondaryTensor: batch2TensorOp
11431147 name: @" MM/(batch1@batch2)" ];
11441148
11451149 newCachedGraph->batch1Tensor_ = batch1Tensor;
0 commit comments