Skip to content

Commit 9ef99d0

Browse files
egalliashrit-ms
authored and committed
[WebNN EP] Automatically move input CPU tensors to ml-tensor (#23073)
### Description If it would improve performance, this patch moves CPU input tensors to ml-tensor before sending them to the ONNX Runtime WebNN EP. ### Motivation and Context We are currently performing 2 extra copies on input tensors located on the CPU when using the WebNN EP (JS -(copy)-> wasm heap -(copy)-> JS -> WebNN API). This patch removes these extra copies.
1 parent f3b1543 commit 9ef99d0

File tree

8 files changed

+196
-43
lines changed

8 files changed

+196
-43
lines changed

js/web/lib/wasm/jsep/backend-webnn.ts

Lines changed: 77 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,19 @@ export class WebNNBackend {
7575
* Current session id.
7676
*/
7777
private activeSessionId?: number;
78+
/**
79+
* Maps from session id to list of graph inputs.
80+
*/
81+
private sessionGraphInputs: Map<number, string[]> = new Map();
82+
/**
83+
* Temporary graph inputs for the current session.
84+
* These inputs will be registered when the session is created.
85+
*/
86+
private temporaryGraphInputs: string[] = [];
87+
/**
88+
* Temporary tensors for the current session.
89+
*/
90+
private temporarySessionTensorIds: Map<number, TensorId[]> = new Map();
7891

7992
constructor(env: Env) {
8093
configureLogger(env.logLevel!, !!env.debug);
@@ -88,9 +101,24 @@ export class WebNNBackend {
88101
}
89102

90103
public onRunStart(sessionId: number): void {
104+
LOG_DEBUG('verbose', () => `[WebNN] onRunStart {sessionId: ${sessionId}}`);
91105
this.activeSessionId = sessionId;
92106
}
93107

108+
public onRunEnd(sessionId: number): void {
109+
LOG_DEBUG('verbose', () => `[WebNN] onRunEnd {sessionId: ${sessionId}}`);
110+
const tensorIds = this.temporarySessionTensorIds.get(sessionId);
111+
if (!tensorIds) {
112+
return;
113+
}
114+
for (const tensorId of tensorIds) {
115+
LOG_DEBUG('verbose', () => `[WebNN] releasing temporary tensor {tensorId: ${tensorId}}`);
116+
this.tensorManager.releaseTensorId(tensorId);
117+
}
118+
this.temporarySessionTensorIds.delete(sessionId);
119+
this.activeSessionId = undefined;
120+
}
121+
94122
public async createMLContext(optionsOrDevice?: MLContextOptions | GPUDevice): Promise<MLContext> {
95123
if (optionsOrDevice instanceof GPUDevice) {
96124
const mlContextIndex = this.mlContextCache.findIndex((entry) => entry.gpuDevice === optionsOrDevice);
@@ -126,14 +154,6 @@ export class WebNNBackend {
126154
}
127155
}
128156

129-
public get currentContext(): MLContext {
130-
const mlContext = this.getMLContext(this.currentSessionId);
131-
if (!mlContext) {
132-
throw new Error(`No MLContext found for session ${this.currentSessionId}`);
133-
}
134-
return mlContext;
135-
}
136-
137157
public registerMLContext(sessionId: number, mlContext: MLContext): void {
138158
this.mlContextBySessionId.set(sessionId, mlContext);
139159
let sessionIds = this.sessionIdsByMLContext.get(mlContext);
@@ -142,9 +162,15 @@ export class WebNNBackend {
142162
this.sessionIdsByMLContext.set(mlContext, sessionIds);
143163
}
144164
sessionIds.add(sessionId);
165+
166+
if (this.temporaryGraphInputs.length > 0) {
167+
this.sessionGraphInputs.set(sessionId, this.temporaryGraphInputs);
168+
this.temporaryGraphInputs = [];
169+
}
145170
}
146171

147172
public onReleaseSession(sessionId: number): void {
173+
this.sessionGraphInputs.delete(sessionId);
148174
const mlContext = this.mlContextBySessionId.get(sessionId)!;
149175
if (!mlContext) {
150176
// Current session is not a WebNN session.
@@ -177,6 +203,7 @@ export class WebNNBackend {
177203
}
178204

179205
public async ensureTensor(
206+
sessionId: number | undefined,
180207
tensorId: TensorId,
181208
onnxDataType: DataType,
182209
dimensions: number[],
@@ -186,7 +213,34 @@ export class WebNNBackend {
186213
if (!webnnDataType) {
187214
throw new Error(`Unsupported ONNX data type: ${onnxDataType}`);
188215
}
189-
return this.tensorManager.ensureTensor(tensorId, webnnDataType, dimensions, copyOld);
216+
return this.tensorManager.ensureTensor(
217+
sessionId ?? this.currentSessionId,
218+
tensorId,
219+
webnnDataType,
220+
dimensions,
221+
copyOld,
222+
);
223+
}
224+
225+
public async createTemporaryTensor(
226+
sessionId: number,
227+
onnxDataType: DataType,
228+
shape: readonly number[],
229+
): Promise<TensorId> {
230+
LOG_DEBUG('verbose', () => `[WebNN] createTemporaryTensor {onnxDataType: ${onnxDataType}, shape: ${shape}}`);
231+
const dataType = onnxDataTypeToWebnnDataType.get(onnxDataType);
232+
if (!dataType) {
233+
throw new Error(`Unsupported ONNX data type: ${onnxDataType}`);
234+
}
235+
const tensorId = this.tensorManager.reserveTensorId();
236+
await this.tensorManager.ensureTensor(sessionId, tensorId, dataType, shape, false);
237+
const tensorIds = this.temporarySessionTensorIds.get(sessionId);
238+
if (!tensorIds) {
239+
this.temporarySessionTensorIds.set(sessionId, [tensorId]);
240+
} else {
241+
tensorIds.push(tensorId);
242+
}
243+
return tensorId;
190244
}
191245

192246
public uploadTensor(tensorId: TensorId, data: Uint8Array): void {
@@ -209,13 +263,13 @@ export class WebNNBackend {
209263
};
210264
}
211265

212-
public registerMLTensor(tensor: MLTensor, onnxDataType: DataType, dimensions: number[]): TensorId {
266+
public registerMLTensor(sessionId: number, tensor: MLTensor, onnxDataType: DataType, dimensions: number[]): TensorId {
213267
const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType);
214268
if (!webnnDataType) {
215269
throw new Error(`Unsupported ONNX data type: ${onnxDataType}`);
216270
}
217271

218-
const id = this.tensorManager.registerTensor(this.currentContext, tensor, webnnDataType, dimensions);
272+
const id = this.tensorManager.registerTensor(sessionId, tensor, webnnDataType, dimensions);
219273
LOG_DEBUG(
220274
'verbose',
221275
() =>
@@ -291,6 +345,18 @@ export class WebNNBackend {
291345
return builder.constant(desc, bufferView);
292346
}
293347

348+
public registerGraphInput(inputName: string): void {
349+
this.temporaryGraphInputs.push(inputName);
350+
}
351+
352+
public isGraphInput(sessionId: number, inputName: string): boolean {
353+
const inputNames = this.sessionGraphInputs.get(sessionId);
354+
if (!inputNames) {
355+
return false;
356+
}
357+
return inputNames.includes(inputName);
358+
}
359+
294360
public flush(): void {
295361
// Unlike the WebGPU backend, the WebNN backend does not need to flush any pending operations.
296362
}

js/web/lib/wasm/jsep/init.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,8 @@ export const init = async (
287287
// jsepReleaseTensorId,
288288
(tensorId: number) => backend.releaseTensorId(tensorId),
289289
// jsepEnsureTensor
290-
async (tensorId: number, onnxDataType: number, shape: number[], copyOld) =>
291-
backend.ensureTensor(tensorId, onnxDataType, shape, copyOld),
290+
async (sessionId: number | undefined, tensorId: number, onnxDataType: number, shape: number[], copyOld) =>
291+
backend.ensureTensor(sessionId, tensorId, onnxDataType, shape, copyOld),
292292
// jsepUploadTensor
293293
(tensorId: number, data: Uint8Array) => {
294294
backend.uploadTensor(tensorId, data);

js/web/lib/wasm/jsep/webnn/tensor-manager.ts

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ export interface TensorManager {
2727
* Ensure a MLTensor is created for the TensorId.
2828
*/
2929
ensureTensor(
30+
sessionId: number,
3031
tensorId: TensorId,
3132
dataType: MLOperandDataType,
3233
shape: readonly number[],
@@ -46,9 +47,9 @@ export interface TensorManager {
4647
*/
4748
releaseTensorsForSession(session: number): void;
4849
/**
49-
* Register an externally created MLTensor with a given MLContext and return a TensorId.
50+
* Register an externally created MLTensor with a given session id and return a TensorId.
5051
*/
51-
registerTensor(mlContext: MLContext, mlTensor: MLTensor, dataType: MLOperandDataType, shape: number[]): TensorId;
52+
registerTensor(sessionId: number, mlTensor: MLTensor, dataType: MLOperandDataType, shape: number[]): TensorId;
5253
}
5354

5455
let tensorGuid = 1;
@@ -177,11 +178,12 @@ class TensorIdTracker {
177178
}
178179

179180
public async ensureTensor(
180-
context: MLContext,
181+
sessionId: number,
181182
dataType: MLOperandDataType,
182183
shape: readonly number[],
183184
copyOld: boolean,
184185
): Promise<MLTensor> {
186+
const context = this.tensorManager.getMLContext(sessionId);
185187
if (this.wrapper) {
186188
if (this.wrapper.canReuseTensor(context, dataType, shape)) {
187189
return this.wrapper.tensor;
@@ -198,7 +200,7 @@ class TensorIdTracker {
198200

199201
// eslint-disable-next-line no-bitwise
200202
const usage = typeof MLTensorUsage == 'undefined' ? undefined : MLTensorUsage.READ | MLTensorUsage.WRITE;
201-
this.wrapper = await this.tensorManager.getCachedTensor(dataType, shape, usage, true, true);
203+
this.wrapper = await this.tensorManager.getCachedTensor(sessionId, dataType, shape, usage, true, true);
202204

203205
if (copyOld && this.activeUpload) {
204206
this.wrapper.write(this.activeUpload);
@@ -256,6 +258,14 @@ class TensorManagerImpl implements TensorManager {
256258

257259
constructor(private backend: WebNNBackend) {}
258260

261+
public getMLContext(sessionId: number): MLContext {
262+
const context = this.backend.getMLContext(sessionId);
263+
if (!context) {
264+
throw new Error('MLContext not found for session.');
265+
}
266+
return context;
267+
}
268+
259269
public reserveTensorId(): TensorId {
260270
const tensorId = createNewTensorId();
261271
this.tensorTrackersById.set(tensorId, new TensorIdTracker(this));
@@ -274,6 +284,7 @@ class TensorManagerImpl implements TensorManager {
274284
}
275285

276286
public async ensureTensor(
287+
sessionId: number,
277288
tensorId: TensorId,
278289
dataType: MLOperandDataType,
279290
shape: number[],
@@ -290,7 +301,7 @@ class TensorManagerImpl implements TensorManager {
290301
if (!tensor) {
291302
throw new Error('Tensor not found.');
292303
}
293-
return tensor.ensureTensor(this.backend.currentContext, dataType, shape, copyOld);
304+
return tensor.ensureTensor(sessionId, dataType, shape, copyOld);
294305
}
295306

296307
public upload(tensorId: TensorId, data: Uint8Array): void {
@@ -325,17 +336,18 @@ class TensorManagerImpl implements TensorManager {
325336
}
326337

327338
public registerTensor(
328-
mlContext: MLContext,
339+
sessionId: number,
329340
mlTensor: MLTensor,
330341
dataType: MLOperandDataType,
331342
shape: readonly number[],
332343
): TensorId {
344+
const context = this.getMLContext(sessionId);
333345
const tensorId = createNewTensorId();
334346
// Defaulting to READ | WRITE if usage is not provided.
335347
// eslint-disable-next-line no-bitwise
336348
const wrapper = new TensorWrapper({
337-
sessionId: this.backend.currentSessionId,
338-
context: mlContext,
349+
sessionId,
350+
context,
339351
tensor: mlTensor,
340352
dataType,
341353
shape,
@@ -349,14 +361,14 @@ class TensorManagerImpl implements TensorManager {
349361
* Get or create an MLTensor with the given data type and shape.
350362
*/
351363
public async getCachedTensor(
364+
sessionId: number,
352365
dataType: MLOperandDataType,
353366
shape: readonly number[],
354367
usage: MLTensorUsageFlags | undefined,
355368
writable: boolean,
356369
readable: boolean,
357370
): Promise<TensorWrapper> {
358-
const sessionId = this.backend.currentSessionId;
359-
const context = this.backend.currentContext;
371+
const context = this.getMLContext(sessionId);
360372
for (const [index, tensor] of this.freeTensors.entries()) {
361373
if (tensor.canReuseTensor(context, dataType, shape)) {
362374
LOG_DEBUG('verbose', () => `[WebNN] Reusing tensor {dataType: ${dataType}, shape: ${shape}}`);

0 commit comments

Comments
 (0)