This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit da6d472
Author: Fan
Commit message: use (m, m) temp space
Parent: 279e1ce

File tree: 6 files changed, +59 -62 lines

3rdparty/tvm

Submodule tvm updated from afd4b3e to 8bd9d4d

src/operator/c_lapack_api.h

Lines changed: 6 additions & 0 deletions

@@ -249,6 +249,10 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
  #define MXNET_LAPACK_sgetrf LAPACKE_sgetrf
  #define MXNET_LAPACK_dgetrf LAPACKE_dgetrf

+ // Internally A is factorized as U * L * VT, and (according to the tech report)
+ // we want to factorize it as UT * L * V, so we pass ut as u and v as vt.
+ // We also have to allocate at least m - 1 DType elements of workspace, as the
+ // internal LAPACKE function needs it to store `superb` (see the MKL documentation).
  #define MXNET_LAPACK_CWRAP_GESVD(prefix, dtype) \
    inline int MXNET_LAPACK_##prefix##gesvd(int matrix_layout, int m, int n, dtype* ut, \
                                            int ldut, dtype* s, dtype* v, int ldv, \

@@ -382,6 +386,8 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
  MXNET_LAPACK_CWRAP_SYEVD(ssyevd, float)
  MXNET_LAPACK_CWRAP_SYEVD(dsyevd, double)

+ // Note: Supports row-major format only. Internally, column-major is used, so all
+ // inputs/outputs are flipped and transposed; m and n are flipped as well.
  #define MXNET_LAPACK_CWRAP_GESVD(func, dtype) \
    inline int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* ut, \
                                   int ldut, dtype* s, dtype* v, int ldv, \
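To make the flip/transpose convention above concrete, here is a minimal, self-contained sketch (assuming a LAPACKE installation; illustrative, not the MXNet wrapper itself): the buffer of a row-major m x n matrix A, reinterpreted as column-major, is A^T (n x m), so running gesvd on that view swaps the roles of the u and vt output buffers, and `superb` needs min(m, n) - 1 elements, matching the comment's workspace requirement.

    // Sketch: SVD of a row-major A via LAPACKE's column-major gesvd.
    // A^T = u * diag(s) * vt  implies  A = vt^T * diag(s) * u^T, so the
    // buffer passed as `u` ends up holding the V-side factor of A and
    // vice versa -- the same role swap the wrapper's comment describes.
    #include <cstdio>
    #include <lapacke.h>

    int main() {
      const int m = 2, n = 3;                // A is m x n, row-major
      double a[m * n] = {1, 2, 3, 4, 5, 6};  // column-major n x m view = A^T
      double s[2];                           // min(m, n) singular values
      double u[n * 2];                       // n x min(m, n), for jobu = 'S'
      double vt[2 * m];                      // min(m, n) x m, for jobvt = 'S'
      double superb[1];                      // min(m, n) - 1 workspace elements
      int ret = LAPACKE_dgesvd(LAPACK_COL_MAJOR, 'S', 'S', n, m, a, n,
                               s, u, n, vt, 2, superb);
      if (ret != 0) return ret;
      std::printf("singular values: %g %g\n", s[0], s[1]);
      return 0;
    }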

src/operator/linalg.h

Lines changed: 1 addition & 2 deletions

@@ -195,8 +195,7 @@ int linalg_syevd_workspace_query(const Tensor<xpu, 2, DType>& A,

  // CPU/GPU-versions of LAPACK function "gesvd". Please refer to the
  // LAPACK documentation for further details.
- // Note:
- // - V is input and output parameter (overwritten by A)
+ // Note: V is input and output parameter (it overwrites A)

  template<typename xpu, typename DType>
  void linalg_gesvd(const Tensor<xpu, 2, DType>& UT,

src/operator/linalg_impl.h

Lines changed: 2 additions & 11 deletions

@@ -1262,25 +1262,16 @@ void linalg_gesvd<cpu, DType>(const Tensor<cpu, 2, DType>& UT, \
                               const Tensor<cpu, 1, DType>& work, \
                               Stream<cpu> *s) { \
    check_gesvd(UT, L, V); \
-   DType lwork(0); \
-   MXNET_LAPACK_##fname(MXNET_LAPACK_ROW_MAJOR, V.size(0), V.size(1), \
-                        UT.dptr_, UT.stride_, L.dptr_, V.dptr_, V.stride_, \
-                        &lwork, -1); \
+   int lwork(work.size(0)); \
    int ret(MXNET_LAPACK_##fname(MXNET_LAPACK_ROW_MAJOR, V.size(0), V.size(1), \
                                 UT.dptr_, UT.stride_, L.dptr_, V.dptr_, V.stride_, \
-                                work.dptr_, static_cast<int>(lwork))); \
+                                work.dptr_, lwork)); \
    CHECK_EQ(ret, 0) << #fname << " failed in lapack on cpu."; \
  }

  LINALG_CPU_GESVD(sgesvd, float)
  LINALG_CPU_GESVD(dgesvd, double)

- // Mangle temp storage requirements for DType and int into a single
- // request as we can only allocate one temp space per operator. We
- // partition this temp space into two chunks again when calling sseyvd.
- // Returned is the number of elements of type DType that the temp space
- // needs to accomodate. This also makes this function signature equivalent
- // to the work space query on GPU.
  #define LINALG_CPU_GESVD_WORKSPACE_QUERY(func, DType) \
    template<> inline \
    int linalg_gesvd_workspace_query<cpu, DType>(const Tensor<cpu, 2, DType>& UT, \
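The deleted lines above were the classic two-call LAPACK workspace idiom performed redundantly inside the solver; after this commit the query lives only in linalg_gesvd_workspace_query, and the solver simply trusts work.size(0). For reference, a self-contained sketch of that idiom (plain LAPACKE, not MXNet code), using LAPACKE_dgesvd_work:

    // Two-call lwork pattern: lwork = -1 performs a size query that writes
    // the optimal workspace length into wkopt; the second call does the SVD.
    #include <vector>
    #include <lapacke.h>

    int main() {
      const lapack_int m = 3, n = 3;
      std::vector<double> a = {4, 0, 0, 0, 3, 0, 0, 0, 2};  // column-major diag(4, 3, 2)
      std::vector<double> s(3), u(9), vt(9);
      double wkopt = 0.0;
      // 1) Workspace query: no factorization happens, only wkopt is written.
      LAPACKE_dgesvd_work(LAPACK_COL_MAJOR, 'A', 'A', m, n, a.data(), m,
                          s.data(), u.data(), m, vt.data(), n, &wkopt, -1);
      std::vector<double> work(static_cast<int>(wkopt));
      // 2) Real call with the buffer sized by the query.
      return LAPACKE_dgesvd_work(LAPACK_COL_MAJOR, 'A', 'A', m, n, a.data(), m,
                                 s.data(), u.data(), m, vt.data(), n,
                                 work.data(), static_cast<lapack_int>(work.size()));
    }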

src/operator/numpy/linalg/np_gesvd-inl.h

Lines changed: 48 additions & 46 deletions

@@ -61,8 +61,7 @@ struct GesvdVecSign {

  // (UT, L, V) = gesvd(A) [singular value decomposition]
  // - V can overwrite A
- // - Needs workspace (both DType and int), size of which is determined by a
- //   workspace query
+ // - Needs workspace (DType), size of which is determined by a workspace query
  struct gesvd {
    template<typename xpu, typename DType>
    static void op(const Tensor<xpu, 3, DType>& A,
@@ -126,6 +125,7 @@ MSHADOW_XINLINE double gesvd_back_helper_eps(double* X) {
    return 1e-100;
  }

+ // dA overwritten by L^-1 dA
  struct GesvdBackHelper_dV {
    template<typename DType>
    MSHADOW_XINLINE static void Map(int k, int m, int n, DType* L, int ldl,
@@ -144,14 +144,18 @@ struct GesvdBackHelper_dV {
    }
  };

+ // X (square) overwritten by X L
+ // Y overwritten by the diagonal of X
  struct GesvdBackHelper_G1 {
    template<typename DType>
    MSHADOW_XINLINE static void Map(int k, int m, int n, DType* X, int ldx,
-                                   DType* L, int ldl) {
+                                   DType* L, int ldl, DType* Y, int ldy) {
      const int offl(k * ldl);
+     const int offy(k * ldy);
      const int offx(k * m * ldx);
      DType numer(0.0);
      for (int i = 0; i < m; ++i) {
+       Y[offy + i] = X[offx + i * ldx + i];
        for (int j = 0; j < m; ++j) {
          numer = L[offl + j];
          X[offx + i * ldx + j] *= numer;
@@ -164,16 +168,15 @@ struct GesvdBackHelper_G2 {
    template<typename DType>
    MSHADOW_XINLINE static void Map(int k, int m, int n, DType* X, int ldx,
                                    DType* L, int ldl, DType* dL, int lddl,
-                                   DType* dA, int ldda, DType* V, int ldv) {
+                                   DType* Y, int ldy) {
      const int offx(k * m * ldx);
      const int offl(k * ldl);
      const int offdl(k * lddl);
-     const int offda(k * m * ldda);
-     const int offv(k * m * ldv);
+     const int offy(k * ldy);
      const DType eps(gesvd_back_helper_eps(X));
      DType denom1(0.0), denom2(0.0), elem(0.0);

-     for (int i = 0; i < m - 1; ++i) {
+     for (int i = 0; i < m; ++i) {
        for (int j = i + 1; j < m; ++j) {
          denom1 = L[offl + i] - L[offl + j];
          denom2 = L[offl + i] + L[offl + j];
@@ -183,14 +186,7 @@
          X[offx + i * ldx + j] = elem * L[offl + j];
          X[offx + j * ldx + i] = elem * L[offl + i];
        }
-     }
-     for (int i = 0; i < m; ++i) {
-       elem = DType(0.0);
-       for (int j = 0; j < n; ++j) {
-         elem += dA[offda + i * ldda + j] * V[offv + i * ldv + j];
-       }
-       elem = -elem + dL[offdl + i];
-       X[offx + i * ldx + i] = elem;
+       X[offx + i * ldx + i] = -Y[offy + i] + dL[offdl + i];
      }
    }
  };
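Why G2 no longer needs dA and V: the deleted loop recomputed each diagonal entry as

    X_{ii} = \mathrm{d}L_i - \sum_{j=0}^{n-1} \mathrm{d}A_{ij} V_{ij}

and the sum is exactly the i-th diagonal entry of dA V^T, a product the G1 step already forms via gemm::op(dA, V, tempM, ..., false, true, ...). GesvdBackHelper_G1 therefore stashes that diagonal in Y (before scaling X by L column-wise), so G2 just reads Y[i] instead of re-reading the full (m, n) arrays dA and V. The outer loop bound also changes from m - 1 to m because the diagonal assignment now lives inside that loop, and the i = m - 1 iteration (whose inner j-loop is empty) must still set the last diagonal entry.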
@@ -204,41 +200,49 @@ struct gesvd_backward {
                  const Tensor<xpu, 2, DType>& L,
                  const Tensor<xpu, 3, DType>& V,
                  const Tensor<xpu, 3, DType>& dA,
-                 const Tensor<xpu, 3, DType>& tempMs,
-                 const Tensor<xpu, 3, DType>& tempMr,
+                 const Tensor<xpu, 3, DType>& tempM,
+                 const Tensor<xpu, 2, DType>& tempMd,
                  Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
    // Backward of (UT, L, V) = gesvd(A)
    using namespace mxnet_op;
    if (dA.dptr_ != dV.dptr_) {
      Copy(dA, dV, s);
    }
    // From here on, we work on dA only
+   int k = dA.size(0), m = dA.size(1), n = dA.size(2);

    // Need temporal space, same shape as dUT
    // invdV:
    Kernel<GesvdBackHelper_dV, xpu>::Launch
-     (s, dA.size(0), dA.size(1), dA.size(2), L.dptr_, L.stride_, dA.dptr_, dA.stride_);
+     (s, k, m, n, L.dptr_, L.stride_, dA.dptr_, dA.stride_);

    // G1:
-   // This copy is just to make sure there are no invalid values (NaN, infinity) in tempM
-   Copy(tempMs, dUT, s);
-   Copy(tempMr, dA, s);
-   gemm::op(dA, V, tempMs, DType(1.0), DType(0.0), false, true, s);
+   // This is just to make sure there are no invalid values (NaN, infinity) in tempM and tempMd
+   tempM.FlatTo1D() = 0;
+   tempMd.FlatTo1D() = 0;
+   gemm::op(dA, V, tempM, DType(1.0), DType(0.0), false, true, s);
    Kernel<GesvdBackHelper_G1, xpu>::Launch
-     (s, dA.size(0), dA.size(1), dA.size(2), tempMs.dptr_, tempMs.stride_,
-      L.dptr_, L.stride_);
-   gemm::op(dUT, UT, tempMs, DType(1.0), DType(1.0), true, false, s);
+     (s, k, m, n, tempM.dptr_, tempM.stride_,
+      L.dptr_, L.stride_, tempMd.dptr_, tempMd.stride_);
+   gemm::op(dUT, UT, tempM, DType(1.0), DType(1.0), true, false, s);

    // G2:
    Kernel<GesvdBackHelper_G2, xpu>::Launch
-     (s, dA.size(0), dA.size(1), dA.size(2), tempMs.dptr_, tempMs.stride_,
-      L.dptr_, L.stride_, dL.dptr_, dL.stride_, dA.dptr_, dA.stride_,
-      V.dptr_, V.stride_);
+     (s, k, m, n, tempM.dptr_, tempM.stride_,
+      L.dptr_, L.stride_, dL.dptr_, dL.stride_,
+      tempMd.dptr_, tempMd.stride_);

    // G3:
-   gemm::op(tempMs, V, dA, DType(1.0), DType(1.0), false, false, s);
-   gemm::op(UT, dA, tempMr, DType(1.0), DType(0.0), false, false, s);
-   Copy(dA, tempMr, s);
+   gemm::op(tempM, V, dA, DType(1.0), DType(1.0), false, false, s);
+   for (int i = 0; i < n; i += m) {
+     int ncols = n - i < m ? n - i : m;
+     Tensor<xpu, 3, DType> t = Tensor<xpu, 3, DType>(dA.dptr_ + i,
+       Shape3(k, m, ncols), dA.stride_, dA.stream_);
+     Tensor<xpu, 3, DType> out = Tensor<xpu, 3, DType>(tempM.dptr_,
+       Shape3(k, m, ncols), tempM.stride_, tempM.stream_);
+     gemm::op(UT, t, out, DType(1.0), DType(0.0), false, false, s);
+     Copy(t, out, s);
+   }
  }
};
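The loop above is the heart of the commit: instead of a second full (k, m, n) temporary (the old tempMr) holding UT * dA, dA is updated in place, at most m columns at a time, through the (k, m, m) buffer tempM. Column blocks are independent under left-multiplication, so each block can be computed into scratch and copied back before moving on. A standalone sketch of the same technique on a single matrix (plain loops standing in for gemm::op; all names are illustrative, not MXNet APIs):

    // Compute B <- U * B for an m x m U and an m x n B, in place,
    // using only an m x m scratch buffer instead of an m x n temporary.
    #include <algorithm>
    #include <vector>

    void left_mul_inplace(const std::vector<double>& U,  // m x m, row-major
                          std::vector<double>& B,        // m x n, row-major
                          int m, int n) {
      std::vector<double> scratch(m * m);                // the (m, m) temp space
      for (int col = 0; col < n; col += m) {
        int ncols = std::min(m, n - col);                // width of this block
        // scratch(:, 0:ncols) = U * B(:, col:col+ncols)
        for (int i = 0; i < m; ++i) {
          for (int j = 0; j < ncols; ++j) {
            double acc = 0.0;
            for (int p = 0; p < m; ++p) acc += U[i * m + p] * B[p * n + col + j];
            scratch[i * m + j] = acc;
          }
        }
        // The block is finished; copy it back into B.
        for (int i = 0; i < m; ++i)
          for (int j = 0; j < ncols; ++j)
            B[i * n + col + j] = scratch[i * m + j];
      }
    }

The trade-off is one gemm call per ceil(n / m) blocks instead of a single large gemm, in exchange for O(m^2) rather than O(m * n) scratch.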

@@ -258,23 +262,21 @@ void NumpyLaGesvdBackward(const nnvm::NodeAttrs& attrs,
  }
  MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
    TBlob tspace(outputs[0]);
-   TBlob tempMs, tempMr;
+   TBlob tempM, tempMd;
+   int kmn = outputs[0].shape_.Size();
+   int kmm = inputs[0].shape_.Size();
+   int km = inputs[1].shape_.Size();
    if (req[0] == kAddTo) {
      Tensor<xpu, 1, OType> tempspace = ctx.requested[0]
-       .get_space_typed<xpu, 1, OType>(Shape1(2 * outputs[0].shape_.Size()), s);
-     tspace = TBlob(tempspace.Slice(0, outputs[0].shape_.Size()))
-       .reshape(outputs[0].shape_);
-     tempMs = TBlob(tempspace.Slice(outputs[0].shape_.Size(),
-                                    outputs[0].shape_.Size() + inputs[0].shape_.Size()))
-       .reshape(inputs[0].shape_);
-     tempMr = TBlob(tempspace.Slice(outputs[0].shape_.Size(),
-                                    2 * outputs[0].shape_.Size()))
-       .reshape(outputs[0].shape_);
+       .get_space_typed<xpu, 1, OType>(Shape1(kmn + kmm + km), s);
+     tspace = TBlob(tempspace.Slice(0, kmn)).reshape(outputs[0].shape_);
+     tempM = TBlob(tempspace.Slice(kmn, kmn + kmm)).reshape(inputs[0].shape_);
+     tempMd = TBlob(tempspace.Slice(kmn + kmm, kmn + kmm + km)).reshape(inputs[1].shape_);
    } else {
      Tensor<xpu, 1, OType> tempspace = ctx.requested[0]
-       .get_space_typed<xpu, 1, OType>(Shape1(outputs[0].shape_.Size()), s);
-     tempMs = TBlob(tempspace.Slice(0, inputs[0].shape_.Size())).reshape(inputs[0].shape_);
-     tempMr = TBlob(tempspace.Slice(0, outputs[0].shape_.Size())).reshape(outputs[0].shape_);
+       .get_space_typed<xpu, 1, OType>(Shape1(kmm + km), s);
+     tempM = TBlob(tempspace.Slice(0, kmm)).reshape(inputs[0].shape_);
+     tempMd = TBlob(tempspace.Slice(kmm, kmm + km)).reshape(inputs[1].shape_);
    }
    laop::op(inputs[0].FlatToKD<xpu, 3, OType>(s),  // dUT
             inputs[1].FlatToKD<xpu, 2, OType>(s),  // dL

@@ -283,8 +285,8 @@ void NumpyLaGesvdBackward(const nnvm::NodeAttrs& attrs,
             inputs[4].FlatToKD<xpu, 2, OType>(s),  // L
             inputs[5].FlatToKD<xpu, 3, OType>(s),  // V
             tspace.FlatToKD<xpu, 3, OType>(s),     // dA
-            tempMs.FlatToKD<xpu, 3, OType>(s),     // tempMs
-            tempMr.FlatToKD<xpu, 3, OType>(s),     // tempMr
+            tempM.FlatToKD<xpu, 3, OType>(s),      // tempM
+            tempMd.FlatToKD<xpu, 2, OType>(s),     // tempMd
             s, attrs);
    if (req[0] == kAddTo) {
      Tensor<xpu, 1, OType> out = outputs[0].FlatTo1D<xpu, OType>(s);
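Net effect on scratch memory, reading the Shape1 requests directly: for k batched m x n matrices, kmn = k*m*n, kmm = k*m*m, and km = k*m (here UT is k x m x m and V is k x m x n, so n >= m is the interesting case). The kAddTo path previously requested 2*kmn elements and the other path kmn; they now request kmn + kmm + km and kmm + km respectively, replacing the (m, n)-shaped temporary tempMr with the (m, m)-shaped tempM plus the length-m diagonal buffer tempMd. For example, at k = 1, m = 4, n = 1000 the kAddTo request drops from 8000 to 4000 + 16 + 4 = 4020 elements, and the non-kAddTo request from 4000 to 20.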

tests/python/unittest/test_numpy_op.py

Lines changed: 1 addition & 2 deletions

@@ -878,8 +878,7 @@ def get_grad(UT, L, V):
         for hybridize in [True, False]:
             for dtype in dtypes:
                 for shape in shapes:
-                    rtol = 1e-3
-                    atol = 1e-3
+                    rtol = atol = 0.01
                     test_svd = TestSVD()
                     if hybridize:
                         test_svd.hybridize()
