diff --git a/package.json b/package.json index 224682fb9..1b4f9754d 100644 --- a/package.json +++ b/package.json @@ -43,7 +43,7 @@ "@huggingface/jinja": "^0.2.2" }, "optionalDependencies": { - "onnxruntime-node": "1.14.0" + "onnxruntime-node": "~1.19.0" }, "devDependencies": { "@types/jest": "^29.5.1", diff --git a/src/backends/onnx.js b/src/backends/onnx.js index 0bee3dce7..3d30087e9 100644 --- a/src/backends/onnx.js +++ b/src/backends/onnx.js @@ -26,7 +26,6 @@ export let ONNX; export const executionProviders = [ // 'webgpu', - 'wasm' ]; if (typeof process !== 'undefined' && process?.release?.name === 'node') { @@ -34,12 +33,14 @@ if (typeof process !== 'undefined' && process?.release?.name === 'node') { ONNX = ONNX_NODE.default ?? ONNX_NODE; // Add `cpu` execution provider, with higher precedence that `wasm`. - executionProviders.unshift('cpu'); + executionProviders.push('cuda', 'cpu'); } else { // Running in a browser-environment ONNX = ONNX_WEB.default ?? ONNX_WEB; + executionProviders.push('wasm'); + // SIMD for WebAssembly does not operate correctly in some recent versions of iOS (16.4.x). // As a temporary fix, we disable it for now. // For more information, see: https://github.com/microsoft/onnxruntime/issues/15644 diff --git a/src/utils/tensor.js b/src/utils/tensor.js index 469054cac..a9051c24b 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -56,16 +56,25 @@ export class Tensor { */ constructor(...args) { if (args[0] instanceof ONNXTensor) { + const tensor = args[0]; + // Create shallow copy - Object.assign(this, args[0]); + Object.assign(this, tensor); + // Object.assign() doesn't catch the data prop for some reason + this.data = tensor.data; } else { - // Create new tensor - Object.assign(this, new ONNXTensor( + const tensor = new ONNXTensor( /** @type {DataType} */(args[0]), /** @type {Exclude} */(args[1]), args[2] - )); + ); + + // Create new tensor + Object.assign(this, tensor); + + // Object.assign() doesn't catch the data prop for some reason + this.data = tensor.data; } return new Proxy(this, {