Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions web_ui/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion web_ui/packages/smart-tools/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"type": "module",
"dependencies": {
"@doodle3d/clipper-js": "~1.0.11",
"onnxruntime-web": "~1.16.3",
"onnxruntime-web": "~1.24.3",
"polylabel": "~1.1.0"
},
"devDependencies": {
Expand Down
12 changes: 6 additions & 6 deletions web_ui/packages/smart-tools/src/ritm/interfaces.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
// Copyright (C) 2022-2025 Intel Corporation
// LIMITED EDGE SOFTWARE DISTRIBUTION LICENSE

import * as ort from 'onnxruntime-web';
import { InferenceSession, Tensor } from 'onnxruntime-web';

import { Point, RegionOfInterest, ShapeType } from '../shared/interfaces';

export interface MainModelResponse {
instances: ort.Tensor;
instances_aux: ort.Tensor;
feature: ort.Tensor;
instances: Tensor;
instances_aux: Tensor;
feature: Tensor;
}

export interface Models {
preprocess: ort.InferenceSession;
main: ort.InferenceSession;
preprocess: InferenceSession;
main: InferenceSession;
}

export interface RITMPoint {
Expand Down
22 changes: 11 additions & 11 deletions web_ui/packages/smart-tools/src/ritm/ritm.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (C) 2022-2025 Intel Corporation
// LIMITED EDGE SOFTWARE DISTRIBUTION LICENSE

import * as ort from 'onnxruntime-web';
import { env, InferenceSession, Tensor } from 'onnxruntime-web';

import type { OpenCVTypes } from '../opencv/interfaces';
import { Point, Polygon, RegionOfInterest, Shape, ShapeType } from '../shared/interfaces';
Expand All @@ -21,22 +21,22 @@ class RITM {
constructor(private CV: OpenCVTypes.cv) {}

async load() {
ort.env.wasm.wasmPaths = sessionParams.wasmRoot;
env.wasm.wasmPaths = sessionParams.wasmRoot;

this.models = {
main: await this.loadModel(RITMModels.main),
preprocess: await this.loadModel(RITMModels.preprocess),
};
}

async loadModel(source: string): Promise<ort.InferenceSession> {
async loadModel(source: string): Promise<InferenceSession> {
const data = await (await loadSource(source))?.arrayBuffer();

if (!data) {
throw 'Could not load model';
}

return ort.InferenceSession.create(data);
return InferenceSession.create(data);
}

loadImage(imageData: ImageData) {
Expand Down Expand Up @@ -204,7 +204,7 @@ class RITM {
return resultContour;
}

buildResultMask(mask: ort.Tensor, box: OpenCVTypes.Rect): OpenCVTypes.Mat {
buildResultMask(mask: Tensor, box: OpenCVTypes.Rect): OpenCVTypes.Mat {
let normalMat: OpenCVTypes.Mat | null = null;
try {
this.sigmoid(mask);
Expand Down Expand Up @@ -263,13 +263,13 @@ class RITM {

const shape = [1, 3, templateSize.height, templateSize.width];

return new ort.Tensor('float32', data, shape);
return new Tensor('float32', data, shape);
} finally {
normal?.forEach((m) => m.delete());
}
}

buildImageTensor(box: OpenCVTypes.Rect, templateSize: OpenCVTypes.Size): ort.Tensor {
buildImageTensor(box: OpenCVTypes.Rect, templateSize: OpenCVTypes.Size): Tensor {
if (!this.image) {
throw 'buildImageTensor requires imageData to be loaded';
}
Expand All @@ -285,7 +285,7 @@ class RITM {
const shape = [1, 3, templateSize.height, templateSize.width];
const data = stackPlanes(this.CV, dst);

return new ort.Tensor('float32', data, shape);
return new Tensor('float32', data, shape);
} finally {
dst?.delete();
}
Expand Down Expand Up @@ -323,15 +323,15 @@ class RITM {
}
}

async runPreProcess(pointTensor: ort.Tensor): Promise<ort.Tensor> {
async runPreProcess(pointTensor: Tensor): Promise<Tensor> {
if (!this.models) {
throw 'RITM Model needs to be loaded before running preprocess';
}

return (await this.models.preprocess.run({ points: pointTensor })).coord_features;
}

async runMainModel(points: ort.Tensor, image: ort.Tensor): Promise<MainModelResponse> {
async runMainModel(points: Tensor, image: Tensor): Promise<MainModelResponse> {
if (!this.models) {
throw 'RITM Model needs to be loaded before running HRNet';
}
Expand All @@ -341,7 +341,7 @@ class RITM {
return this.models.main.run(tensors) as unknown as Promise<MainModelResponse>;
}

sigmoid(mask: ort.Tensor) {
sigmoid(mask: Tensor) {
const data = mask.data as Float32Array;

// Apply sigmoid function directly: 1 / (1 + e^(-x))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
// Copyright (C) 2022-2025 Intel Corporation
// LIMITED EDGE SOFTWARE DISTRIBUTION LICENSE

import * as ort from 'onnxruntime-web';
import { Tensor } from 'onnxruntime-web';

import type { OpenCVTypes } from '../opencv/interfaces';

interface PreprocessorResult {
tensor: ort.Tensor;
tensor: Tensor;
width: number;
height: number;
newWidth: number;
Expand Down Expand Up @@ -53,7 +53,13 @@ export class OpenCVPreprocessor {
throw new Error('Something went wrong with preprocessing the image.');
}

const tensor = new ort.Tensor('float32', input.data32F, [1, 3, this.config.size, this.config.size]);
// `input.data32F` is a view into WASM memory owned by OpenCV's `input` Mat, which is
// freed in the `finally` block below. `session.run()` uploads the tensor data
// asynchronously (especially on the WebGPU EP), so we must copy the data into a
// JS-owned Float32Array to avoid reading freed memory and hanging/garbage output.
const data = new Float32Array(input.data32F);
const tensor = new Tensor('float32', data, [1, 3, this.config.size, this.config.size]);

return { tensor, width, height, newWidth, newHeight };
} finally {
imageCv.delete();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (C) 2022-2025 Intel Corporation
// LIMITED EDGE SOFTWARE DISTRIBUTION LICENSE

import * as ort from 'onnxruntime-common';
import { Tensor } from 'onnxruntime-web';

import type { OpenCVTypes } from '../opencv/interfaces';
import { Point, ShapeType } from '../shared/interfaces';
Expand Down Expand Up @@ -67,7 +67,7 @@ export class SegmentAnythingDecoder {
return results;
}

private getIndexOfMaskWithHighestConfidence(iou_predictions: ort.Tensor) {
private getIndexOfMaskWithHighestConfidence(iou_predictions: Tensor) {
let predictionIdx = 0;

for (let p = 0; p < iou_predictions.dims[1]; p++) {
Expand All @@ -86,9 +86,9 @@ export class SegmentAnythingDecoder {
},
{ encoderResult, originalWidth, originalHeight, newWidth, newHeight }: EncodingOutput
): Promise<{
masks: ort.Tensor;
iouPredictions: ort.Tensor;
lowResMasks: ort.Tensor;
masks: Tensor;
iouPredictions: Tensor;
lowResMasks: Tensor;
}> {
const pointCoords: number[] = [];
const pointLabels: number[] = [];
Expand Down Expand Up @@ -118,17 +118,17 @@ export class SegmentAnythingDecoder {
}

const ratio = 1024 / Math.max(originalHeight, originalWidth);
const feeds: Record<string, ort.Tensor> = {
image_embeddings: encoderResult,
const feeds: Record<string, Tensor> = {
image_embeddings: new Tensor(encoderResult.type, encoderResult.data, encoderResult.dims),
// TODO: reuse the low_res_masks output, also use existing polygons?
mask_input: new ort.Tensor(new Float32Array(256 * 256).fill(1), [1, 1, 256, 256]),
has_mask_input: new ort.Tensor(new Float32Array(1).fill(0), [1]),
orig_im_size: new ort.Tensor(
mask_input: new Tensor(new Float32Array(256 * 256).fill(1), [1, 1, 256, 256]),
has_mask_input: new Tensor(new Float32Array(1).fill(0), [1]),
orig_im_size: new Tensor(
new Float32Array([Math.round(originalHeight * ratio), Math.round(originalWidth * ratio)]),
[2]
),
point_coords: new ort.Tensor(new Float32Array(pointCoords), [1, pointCoords.length / 2, 2]),
point_labels: new ort.Tensor(new Float32Array(pointLabels), [1, pointLabels.length]),
point_coords: new Tensor(new Float32Array(pointCoords), [1, pointCoords.length / 2, 2]),
point_labels: new Tensor(new Float32Array(pointLabels), [1, pointLabels.length]),
};

const outputData = await this.session.run(feeds);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import type { OpenCVTypes } from '@geti/smart-tools/opencv';
import type * as Comlink from 'comlink';
import * as ort from 'onnxruntime-common';
import { Tensor } from 'onnxruntime-web';

import { OpenCVPreprocessor, OpenCVPreprocessorConfig } from './pre-processing';
import { type Session } from './session';
Expand All @@ -12,8 +12,18 @@ type cv = typeof OpenCVTypes;

type ModelSession = Session | Comlink.Remote<Session>;

// A plain-object representation of ort.Tensor that survives structured-clone
// (Comlink transfers between workers). ort.Tensor instances lose their class
// identity and `location` property when cloned, causing onnxruntime >=1.20 to
// throw "invalid data location: undefined".
export type SerializableTensor = {
data: Float32Array;
dims: number[];
type: Tensor.Type;
};

export type EncodingOutput = {
encoderResult: ort.Tensor;
encoderResult: SerializableTensor;
originalWidth: number;
originalHeight: number;
newWidth: number;
Expand All @@ -38,7 +48,17 @@ export class SegmentAnythingEncoder {
console.timeEnd('[SAM] Encoding');

const outputNames = await this.session.outputNames();
const encoderResult = outputData[outputNames[0]];
const gpuTensor = outputData[outputNames[0]];

// ort.Tensor instances lose their class identity (and `location` getter)
// when structured-cloned by Comlink across workers, causing onnxruntime
// >=1.20 to throw "invalid data location: undefined". Store raw typed
// array data so the decoder can reconstruct a valid tensor.
const encoderResult: SerializableTensor = {
data: (await gpuTensor.getData()) as Float32Array,
dims: [...gpuTensor.dims],
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this simply be `gpuTensor.dims`?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to clone things, but we could maybe make them read-only. In any case, I will update this in my next PR.

type: gpuTensor.type as Tensor.Type,
};

const originalWidth = initialImageData.width;
const originalHeight = initialImageData.height;
Expand Down
28 changes: 21 additions & 7 deletions web_ui/packages/smart-tools/src/segment-anything/session.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
// Copyright (C) 2022-2025 Intel Corporation
// LIMITED EDGE SOFTWARE DISTRIBUTION LICENSE

import type { InferenceSession } from 'onnxruntime-common';
import * as ort from 'onnxruntime-web';
import { env, InferenceSession } from 'onnxruntime-web';

import { loadSource } from '../utils/tool-utils';
import { SessionParameters, sessionParams } from '../utils/wasm-utils';
Expand All @@ -14,26 +13,34 @@ const loadModel = async (modelPath: string) => {
export class Session {
ortSession: InferenceSession | undefined;
params: SessionParameters;
private runQueue: Promise<void> = Promise.resolve();

constructor() {
this.params = sessionParams;
}

public async init(modelPath: string) {
ort.env.wasm.numThreads = this.params.numThreads;
ort.env.wasm.wasmPaths = this.params.wasmRoot;
ort.env.wasm.simd = true;
env.wasm.numThreads = this.params.numThreads;
env.wasm.wasmPaths = this.params.wasmRoot;
env.wasm.simd = true;
// Suppress expected "some nodes not assigned to WebGPU EP" warnings β€”
// ORT intentionally keeps shape-related ops on CPU for performance.
env.logLevel = 'error';
Comment thread
jpggvilaca marked this conversation as resolved.

const modelData = await loadModel(modelPath);

if (!modelData) {
throw new Error(`Unable to load model from "${modelPath}"`);
}

const session = await ort.InferenceSession.create(modelData, {
const session = await InferenceSession.create(modelData, {
executionProviders: this.params.executionProviders,
graphOptimizationLevel: 'all',
executionMode: 'parallel',
// 0=verbose, 1=info, 2=warning, 3=error, 4=fatal. Silences the
// native "VerifyEachNodeIsAssignedToAnEp" warnings emitted when
// ORT intentionally keeps shape-related ops on the CPU EP.
logSeverityLevel: 3,
});

this.ortSession = session;
Expand All @@ -43,7 +50,14 @@ export class Session {
if (!this.ortSession) {
throw Error('the session is not initialized. Call `init()` method first.');
}
return await this.ortSession.run(input);
// onnxruntime-web does not support concurrent run() calls on the same session.
// Serialize calls through a void queue so the result type stays clean.
const runNext = this.runQueue.then(() => this.ortSession!.run(input));
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we make sure we have `ortSession`, i.e., avoid using the `!` non-null assertion?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

100% agree. Let me update this in the next PR.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really understand why we need this — could you please explain more?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Artifacts from Claude, but the gist is that each run() waits for the previous one before starting, so results always resolve in call order.

In any case, I agree this is confusing, so I will refactor this as well in my next PR. It will be something like:

const next = this.runQueue.catch(() => {}).then(() => session.run(input));

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should cancel ongoing processing if there is a new one?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to wait for it. The user clicked so it's intentional. Worry not. The next version will be much simpler :)

this.runQueue = runNext.then(
() => undefined,
() => undefined
);
return runNext;
}

public inputNames(): readonly string[] {
Expand Down
Loading
Loading