@@ -61,8 +61,7 @@ struct GesvdVecSign {
6161
6262// (UT, L, V) = gesvd(A) [singular value decomposition]
6363// - V can overwrite A
64- // - Needs workspace (both DType and int), size of which is determined by a
65- // workspace query
64+ // - Needs workspace (DType), size of which is determined by a workspace query
6665struct gesvd {
6766 template <typename xpu, typename DType>
6867 static void op (const Tensor<xpu, 3 , DType>& A,
@@ -126,6 +125,7 @@ MSHADOW_XINLINE double gesvd_back_helper_eps(double* X) {
126125 return 1e-100 ;
127126}
128127
128+ // dA overwritten by L^-1 dA
129129struct GesvdBackHelper_dV {
130130 template <typename DType>
131131 MSHADOW_XINLINE static void Map (int k, int m, int n, DType* L, int ldl,
@@ -144,14 +144,18 @@ struct GesvdBackHelper_dV {
144144 }
145145};
146146
147+ // X (square) overwritten by X * diag(L), i.e. column j of X is scaled by L[j]
148+ // Y overwritten by the original diagonal of X (read before X is scaled)
147149struct GesvdBackHelper_G1 {
148150 template <typename DType>
149151 MSHADOW_XINLINE static void Map (int k, int m, int n, DType* X, int ldx,
150- DType* L, int ldl) {
152+ DType* L, int ldl, DType* Y, int ldy ) {
151153 const int offl (k * ldl);
154+ const int offy (k * ldy);
152155 const int offx (k * m * ldx);
153156 DType numer (0.0 );
154157 for (int i = 0 ; i < m; ++i) {
158+ Y[offy + i] = X[offx + i * ldx + i];
155159 for (int j = 0 ; j < m; ++j) {
156160 numer = L[offl + j];
157161 X[offx + i * ldx + j] *= numer;
@@ -164,16 +168,15 @@ struct GesvdBackHelper_G2 {
164168 template <typename DType>
165169 MSHADOW_XINLINE static void Map (int k, int m, int n, DType* X, int ldx,
166170 DType* L, int ldl, DType* dL, int lddl,
167- DType* dA , int ldda, DType* V, int ldv ) {
171+ DType* Y , int ldy ) {
168172 const int offx (k * m * ldx);
169173 const int offl (k * ldl);
170174 const int offdl (k * lddl);
171- const int offda (k * m * ldda);
172- const int offv (k * m * ldv);
175+ const int offy (k * ldy);
173176 const DType eps (gesvd_back_helper_eps (X));
174177 DType denom1 (0.0 ), denom2 (0.0 ), elem (0.0 );
175178
176- for (int i = 0 ; i < m - 1 ; ++i) {
179+ for (int i = 0 ; i < m; ++i) {
177180 for (int j = i + 1 ; j < m; ++j) {
178181 denom1 = L[offl + i] - L[offl + j];
179182 denom2 = L[offl + i] + L[offl + j];
@@ -183,14 +186,7 @@ struct GesvdBackHelper_G2 {
183186 X[offx + i * ldx + j] = elem * L[offl + j];
184187 X[offx + j * ldx + i] = elem * L[offl + i];
185188 }
186- }
187- for (int i = 0 ; i < m; ++i) {
188- elem = DType (0.0 );
189- for (int j = 0 ; j < n; ++j) {
190- elem += dA[offda + i * ldda + j] * V[offv + i * ldv + j];
191- }
192- elem = -elem + dL[offdl + i];
193- X[offx + i * ldx + i] = elem;
189+ X[offx + i * ldx + i] = -Y[offy + i] + dL[offdl + i];
194190 }
195191 }
196192};
@@ -204,41 +200,49 @@ struct gesvd_backward {
204200 const Tensor<xpu, 2 , DType>& L,
205201 const Tensor<xpu, 3 , DType>& V,
206202 const Tensor<xpu, 3 , DType>& dA,
207- const Tensor<xpu, 3 , DType>& tempMs ,
208- const Tensor<xpu, 3 , DType>& tempMr ,
203+ const Tensor<xpu, 3 , DType>& tempM ,
204+ const Tensor<xpu, 2 , DType>& tempMd ,
209205 Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
210206 // Backward of (UT, L, V) = gesvd(A)
211207 using namespace mxnet_op ;
212208 if (dA.dptr_ != dV.dptr_ ) {
213209 Copy (dA, dV, s);
214210 }
215211 // From here on, we work on dA only
212+ int k = dA.size (0 ), m = dA.size (1 ), n = dA.size (2 );
216213
217214 // Need temporary space, same shape as dUT
218215 // invdV:
219216 Kernel<GesvdBackHelper_dV, xpu>::Launch
220- (s, dA. size ( 0 ), dA. size ( 1 ), dA. size ( 2 ) , L.dptr_ , L.stride_ , dA.dptr_ , dA.stride_ );
217+ (s, k, m, n , L.dptr_ , L.stride_ , dA.dptr_ , dA.stride_ );
221218
222219 // G1:
223- // This copy is just to make sure there are no invalid values (NaN, infinity) in tempM
224- Copy (tempMs, dUT, s) ;
225- Copy (tempMr, dA, s) ;
226- gemm::op (dA, V, tempMs , DType (1.0 ), DType (0.0 ), false , true , s);
220+ // This is just to make sure there are no invalid values (NaN, infinity) in tempM and tempMd
221+ tempM. FlatTo1D () = 0 ;
222+ tempMd. FlatTo1D () = 0 ;
223+ gemm::op (dA, V, tempM , DType (1.0 ), DType (0.0 ), false , true , s);
227224 Kernel<GesvdBackHelper_G1, xpu>::Launch
228- (s, dA. size ( 0 ), dA. size ( 1 ), dA. size ( 2 ), tempMs .dptr_ , tempMs .stride_ ,
229- L.dptr_ , L.stride_ );
230- gemm::op (dUT, UT, tempMs , DType (1.0 ), DType (1.0 ), true , false , s);
225+ (s, k, m, n, tempM .dptr_ , tempM .stride_ ,
226+ L.dptr_ , L.stride_ , tempMd. dptr_ , tempMd. stride_ );
227+ gemm::op (dUT, UT, tempM , DType (1.0 ), DType (1.0 ), true , false , s);
231228
232229 // G2:
233230 Kernel<GesvdBackHelper_G2, xpu>::Launch
234- (s, dA. size ( 0 ), dA. size ( 1 ), dA. size ( 2 ), tempMs .dptr_ , tempMs .stride_ ,
235- L.dptr_ , L.stride_ , dL.dptr_ , dL.stride_ , dA. dptr_ , dA. stride_ ,
236- V .dptr_ , V .stride_ );
231+ (s, k, m, n, tempM .dptr_ , tempM .stride_ ,
232+ L.dptr_ , L.stride_ , dL.dptr_ , dL.stride_ ,
233+ tempMd .dptr_ , tempMd .stride_ );
237234
238235 // G3:
239- gemm::op (tempMs, V, dA, DType (1.0 ), DType (1.0 ), false , false , s);
240- gemm::op (UT, dA, tempMr, DType (1.0 ), DType (0.0 ), false , false , s);
241- Copy (dA, tempMr, s);
236+ gemm::op (tempM, V, dA, DType (1.0 ), DType (1.0 ), false , false , s);
237+ for (int i = 0 ; i < n; i += m) {
238+ int ncols = n - i < m ? n - i : m;
239+ Tensor<xpu, 3 , DType> t = Tensor<xpu, 3 , DType>(dA.dptr_ + i,
240+ Shape3 (k, m, ncols), dA.stride_ , dA.stream_ );
241+ Tensor<xpu, 3 , DType> out = Tensor<xpu, 3 , DType>(tempM.dptr_ ,
242+ Shape3 (k, m, ncols), tempM.stride_ , tempM.stream_ );
243+ gemm::op (UT, t, out, DType (1.0 ), DType (0.0 ), false , false , s);
244+ Copy (t, out, s);
245+ }
242246 }
243247};
244248
@@ -258,23 +262,21 @@ void NumpyLaGesvdBackward(const nnvm::NodeAttrs& attrs,
258262 }
259263 MSHADOW_SGL_DBL_TYPE_SWITCH (outputs[0 ].type_flag_ , OType, {
260264 TBlob tspace (outputs[0 ]);
261- TBlob tempMs, tempMr;
265+ TBlob tempM, tempMd;
266+ int kmn = outputs[0 ].shape_ .Size ();
267+ int kmm = inputs[0 ].shape_ .Size ();
268+ int km = inputs[1 ].shape_ .Size ();
262269 if (req[0 ] == kAddTo ) {
263270 Tensor<xpu, 1 , OType> tempspace = ctx.requested [0 ]
264- .get_space_typed <xpu, 1 , OType>(Shape1 (2 * outputs[0 ].shape_ .Size ()), s);
265- tspace = TBlob (tempspace.Slice (0 , outputs[0 ].shape_ .Size ()))
266- .reshape (outputs[0 ].shape_ );
267- tempMs = TBlob (tempspace.Slice (outputs[0 ].shape_ .Size (),
268- outputs[0 ].shape_ .Size () + inputs[0 ].shape_ .Size ()))
269- .reshape (inputs[0 ].shape_ );
270- tempMr = TBlob (tempspace.Slice (outputs[0 ].shape_ .Size (),
271- 2 * outputs[0 ].shape_ .Size ()))
272- .reshape (outputs[0 ].shape_ );
271+ .get_space_typed <xpu, 1 , OType>(Shape1 (kmn + kmm + km), s);
272+ tspace = TBlob (tempspace.Slice (0 , kmn)).reshape (outputs[0 ].shape_ );
273+ tempM = TBlob (tempspace.Slice (kmn, kmn + kmm)).reshape (inputs[0 ].shape_ );
274+ tempMd = TBlob (tempspace.Slice (kmn + kmm, kmn + kmm + km)).reshape (inputs[1 ].shape_ );
273275 } else {
274276 Tensor<xpu, 1 , OType> tempspace = ctx.requested [0 ]
275- .get_space_typed <xpu, 1 , OType>(Shape1 (outputs[ 0 ]. shape_ . Size () ), s);
276- tempMs = TBlob (tempspace.Slice (0 , inputs[ 0 ]. shape_ . Size () )).reshape (inputs[0 ].shape_ );
277- tempMr = TBlob (tempspace.Slice (0 , outputs[ 0 ]. shape_ . Size ())) .reshape (outputs[ 0 ].shape_ );
277+ .get_space_typed <xpu, 1 , OType>(Shape1 (kmm + km ), s);
278+ tempM = TBlob (tempspace.Slice (0 , kmm )).reshape (inputs[0 ].shape_ );
279+ tempMd = TBlob (tempspace.Slice (kmm, kmm + km)) .reshape (inputs[ 1 ].shape_ );
278280 }
279281 laop::op (inputs[0 ].FlatToKD <xpu, 3 , OType>(s), // dUT
280282 inputs[1 ].FlatToKD <xpu, 2 , OType>(s), // dL
@@ -283,8 +285,8 @@ void NumpyLaGesvdBackward(const nnvm::NodeAttrs& attrs,
283285 inputs[4 ].FlatToKD <xpu, 2 , OType>(s), // L
284286 inputs[5 ].FlatToKD <xpu, 3 , OType>(s), // V
285287 tspace.FlatToKD <xpu, 3 , OType>(s), // dA
286- tempMs .FlatToKD <xpu, 3 , OType>(s), // tempMs
287- tempMr .FlatToKD <xpu, 3 , OType>(s), // tempMr
288+ tempM .FlatToKD <xpu, 3 , OType>(s), // tempM
289+ tempMd .FlatToKD <xpu, 2 , OType>(s), // tempMd
288290 s, attrs);
289291 if (req[0 ] == kAddTo ) {
290292 Tensor<xpu, 1 , OType> out = outputs[0 ].FlatTo1D <xpu, OType>(s);
0 commit comments