@@ -75,6 +75,19 @@ export class WebNNBackend {
    * Current session id.
    */
   private activeSessionId?: number;
+  /**
+   * Maps from session id to list of graph inputs.
+   */
+  private sessionGraphInputs: Map<number, string[]> = new Map();
+  /**
+   * Temporary graph inputs for the current session.
+   * These inputs will be registered when the session is created.
+   */
+  private temporaryGraphInputs: string[] = [];
+  /**
+   * Temporary tensors for the current session.
+   */
+  private temporarySessionTensorIds: Map<number, TensorId[]> = new Map();
 
   constructor(env: Env) {
     configureLogger(env.logLevel!, !!env.debug);
@@ -88,9 +101,24 @@ export class WebNNBackend {
   }
 
   public onRunStart(sessionId: number): void {
+    LOG_DEBUG('verbose', () => `[WebNN] onRunStart {sessionId: ${sessionId}}`);
     this.activeSessionId = sessionId;
   }
 
+  public onRunEnd(sessionId: number): void {
+    LOG_DEBUG('verbose', () => `[WebNN] onRunEnd {sessionId: ${sessionId}}`);
+    const tensorIds = this.temporarySessionTensorIds.get(sessionId);
+    if (!tensorIds) {
+      return;
+    }
+    for (const tensorId of tensorIds) {
+      LOG_DEBUG('verbose', () => `[WebNN] releasing temporary tensor {tensorId: ${tensorId}}`);
+      this.tensorManager.releaseTensorId(tensorId);
+    }
+    this.temporarySessionTensorIds.delete(sessionId);
+    this.activeSessionId = undefined;
+  }
+
   public async createMLContext(optionsOrDevice?: MLContextOptions | GPUDevice): Promise<MLContext> {
     if (optionsOrDevice instanceof GPUDevice) {
       const mlContextIndex = this.mlContextCache.findIndex((entry) => entry.gpuDevice === optionsOrDevice);
@@ -126,14 +154,6 @@ export class WebNNBackend {
     }
   }
 
-  public get currentContext(): MLContext {
-    const mlContext = this.getMLContext(this.currentSessionId);
-    if (!mlContext) {
-      throw new Error(`No MLContext found for session ${this.currentSessionId}`);
-    }
-    return mlContext;
-  }
-
   public registerMLContext(sessionId: number, mlContext: MLContext): void {
     this.mlContextBySessionId.set(sessionId, mlContext);
     let sessionIds = this.sessionIdsByMLContext.get(mlContext);
@@ -142,9 +162,15 @@ export class WebNNBackend {
       this.sessionIdsByMLContext.set(mlContext, sessionIds);
     }
     sessionIds.add(sessionId);
+
+    if (this.temporaryGraphInputs.length > 0) {
+      this.sessionGraphInputs.set(sessionId, this.temporaryGraphInputs);
+      this.temporaryGraphInputs = [];
+    }
   }
 
   public onReleaseSession(sessionId: number): void {
+    this.sessionGraphInputs.delete(sessionId);
     const mlContext = this.mlContextBySessionId.get(sessionId)!;
     if (!mlContext) {
       // Current session is not a WebNN session.
@@ -177,6 +203,7 @@ export class WebNNBackend {
   }
 
   public async ensureTensor(
+    sessionId: number | undefined,
     tensorId: TensorId,
     onnxDataType: DataType,
     dimensions: number[],
@@ -186,7 +213,34 @@ export class WebNNBackend {
     if (!webnnDataType) {
       throw new Error(`Unsupported ONNX data type: ${onnxDataType}`);
     }
-    return this.tensorManager.ensureTensor(tensorId, webnnDataType, dimensions, copyOld);
+    return this.tensorManager.ensureTensor(
+      sessionId ?? this.currentSessionId,
+      tensorId,
+      webnnDataType,
+      dimensions,
+      copyOld,
+    );
+  }
+
+  public async createTemporaryTensor(
+    sessionId: number,
+    onnxDataType: DataType,
+    shape: readonly number[],
+  ): Promise<TensorId> {
+    LOG_DEBUG('verbose', () => `[WebNN] createTemporaryTensor {onnxDataType: ${onnxDataType}, shape: ${shape}}`);
+    const dataType = onnxDataTypeToWebnnDataType.get(onnxDataType);
+    if (!dataType) {
+      throw new Error(`Unsupported ONNX data type: ${onnxDataType}`);
+    }
+    const tensorId = this.tensorManager.reserveTensorId();
+    await this.tensorManager.ensureTensor(sessionId, tensorId, dataType, shape, false);
+    const tensorIds = this.temporarySessionTensorIds.get(sessionId);
+    if (!tensorIds) {
+      this.temporarySessionTensorIds.set(sessionId, [tensorId]);
+    } else {
+      tensorIds.push(tensorId);
+    }
+    return tensorId;
   }
 
   public uploadTensor(tensorId: TensorId, data: Uint8Array): void {
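
The hunk above defines a small lifecycle for scratch tensors: `createTemporaryTensor` reserves an id, allocates it through the tensor manager for the given session, and records it in `temporarySessionTensorIds`; `onRunEnd` later releases everything recorded for that session. The following is an illustrative sketch only, not part of this change: `runWithScratchTensor`, the data-type value 1 (float32), and the shape are assumed placeholders.

// Sketch: how a caller might pair createTemporaryTensor with the run hooks.
async function runWithScratchTensor(backend: WebNNBackend, sessionId: number): Promise<void> {
  backend.onRunStart(sessionId);
  // Assumed placeholder values: ONNX DataType 1 (float32) and a fixed shape.
  const scratchId = await backend.createTemporaryTensor(sessionId, 1, [1, 3, 224, 224]);
  // ... run the graph, using scratchId wherever an intermediate MLTensor is needed ...
  backend.onRunEnd(sessionId); // releases every id recorded in temporarySessionTensorIds
}
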
@@ -209,13 +263,13 @@
     };
   }
 
-  public registerMLTensor(tensor: MLTensor, onnxDataType: DataType, dimensions: number[]): TensorId {
+  public registerMLTensor(sessionId: number, tensor: MLTensor, onnxDataType: DataType, dimensions: number[]): TensorId {
     const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType);
     if (!webnnDataType) {
       throw new Error(`Unsupported ONNX data type: ${onnxDataType}`);
     }
 
-    const id = this.tensorManager.registerTensor(this.currentContext, tensor, webnnDataType, dimensions);
+    const id = this.tensorManager.registerTensor(sessionId, tensor, webnnDataType, dimensions);
     LOG_DEBUG(
       'verbose',
       () =>
@@ -291,6 +345,18 @@ export class WebNNBackend {
     return builder.constant(desc, bufferView);
   }
 
+  public registerGraphInput(inputName: string): void {
+    this.temporaryGraphInputs.push(inputName);
+  }
+
+  public isGraphInput(sessionId: number, inputName: string): boolean {
+    const inputNames = this.sessionGraphInputs.get(sessionId);
+    if (!inputNames) {
+      return false;
+    }
+    return inputNames.includes(inputName);
+  }
+
   public flush(): void {
     // Unlike the WebGPU backend, the WebNN backend does not need to flush any pending operations.
   }
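
Taken together, `registerGraphInput`, `registerMLContext`, and `isGraphInput` form the graph-input registration flow: input names pushed while the model is being compiled are staged in `temporaryGraphInputs`, moved into `sessionGraphInputs` when the session's MLContext is registered, and queried per session afterwards. The sketch below shows only that ordering; `setUpSession`, the `'input_0'` name, and the context options are assumed placeholders rather than the actual session-setup call sites.

// Sketch of the registration order implied by the diff above.
async function setUpSession(backend: WebNNBackend, sessionId: number): Promise<void> {
  const mlContext = await backend.createMLContext({ deviceType: 'gpu' });
  backend.registerGraphInput('input_0');           // staged in temporaryGraphInputs
  backend.registerMLContext(sessionId, mlContext); // moves staged names into sessionGraphInputs
  console.log(backend.isGraphInput(sessionId, 'input_0')); // true
}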