Non-consuming ownership fork
Website | API Reference | Compatibility Table | Copilot Instructions | jax-js Discord
Fork notice: This is a fork of ekzhang/jax-js with a non-consuming ownership model. Operations leave inputs alive (no `.ref` needed), and `using` declarations provide deterministic GPU/WASM memory cleanup.

Why this fork? The original jax-js uses move semantics, where operations consume their inputs. This fork was created for teams familiar with MATLAB or Python (NumPy), where move semantics are unexpected. We also fast-tracked `lax.scan` and `lax.associativeScan` implementations. `using` declarations handle the common case (block-scoped arrays are disposed automatically), but patterns like method chains, loop-carried state, and nested results still need manual care. A built-in `checkLeaks` diagnostic and an ESLint plugin help catch what `using` misses. See Tradeoffs for an honest comparison.

See Differences from upstream for a full comparison between the original and this fork.
🤖 The fork code & documentation commits are AI-generated with gentle human supervision.
jax-js is a machine learning framework for the browser. It aims to bring JAX-style, high-performance CPU and GPU kernels to JavaScript, so you can run numerical applications on the web.
npm install github:hamk-uas/jax-js-nonconsuming
To pin a specific release tag (once available):
npm install github:hamk-uas/jax-js-nonconsuming#v0.3.0
pnpm users: pnpm requires explicit permission to run build scripts from Git dependencies. Add
this to your package.json:
{
"pnpm": {
"onlyBuiltDependencies": ["@hamk-uas/jax-js-nonconsuming"]
}
}
Without this, pnpm install fails with ERR_PNPM_GIT_DEP_PREPARE_NOT_ALLOWED. The prepare script
runs tsdown to build the package and its sub-packages (optax, loaders, onnx, eslint-plugin).
Under the hood, it translates array operations into a compiler representation, then synthesizes kernels in WebAssembly and WebGPU.
The original jax-js was written from scratch with zero external dependencies. jax-js and this fork maintain close API compatibility with NumPy/JAX. Since everything runs client-side, jax-js is one of the most portable GPU ML frameworks: it runs anywhere a browser can run.
import { numpy as np } from "@hamk-uas/jax-js-nonconsuming";
// Array operations, compatible with JAX/NumPy.
{
using x = np.array([1, 2, 3]);
using y = x.mul(4); // [4, 8, 12]
}
In vanilla JavaScript (without a bundler), just import from a module script tag. This is the easiest way to get started on a blank HTML page.
<script type="module">
import { numpy as np } from "https://esm.sh/@jax-js/jax";
</script>
This table refers to the latest version of each browser. WebGPU has gained wide support in browsers as of late 2025.
| Platform | CPU (Wasm) | GPU (WebGPU) | GPU (WebGL) |
|---|---|---|---|
| Chrome / Edge | β | β | β |
| Firefox | β | β - macOS 26+ | β |
| Safari | β | β - macOS 26+ | β |
| iOS | β | β - iOS 26+ | β |
| Chrome for Android | β | β | β |
| Firefox for Android | β | β | β |
| Node.js | β | β | β |
| Deno | β | β - async | β |
Community usage:
Demos:
Here's a quick, high-level comparison with other popular web ML runtimes. Performance labels are workload- and hardware-dependent.
| Feature | jax-js-nonconsuming | jax-js v0.1.9 | TensorFlow.js | onnxruntime-web |
|---|---|---|---|---|
| Overview | ||||
| API style | JAX/NumPy | JAX/NumPy | TensorFlow-like | Static ONNX graphs |
| Speed | Very fast | Very fast | Fast | Very fast |
| Bundle size (gzip) | 184 KB | 80 KB | 269 KB | 90 KB + 24 MB Wasm |
| Autodiff & JIT | ||||
| Gradients | β | β | β | β |
| Jacobian and Hessian | β | β | β | β |
| `jvp()` forward differentiation | β | β | β | β |
| `jit()` kernel fusion | β | β | β | β |
| `vmap()` auto-vectorization | β | β | β | β |
| `scan()` scan over leading axis | β | β | β | β |
| `associativeScan()` parallel prefix scan | β | β | β | β |
| Graph capture | β | β | β | β |
| Backends & Data | ||||
| WebGPU backend | β | β | π‘ Preview | β |
| WebGL backend | β | β | β | β |
| Wasm (CPU) backend | β | β | β | β |
| Eager array API | β | β | β | β |
| Run ONNX models | π‘ Partial | π‘ Partial | β | β |
| Read safetensors | β | β | β | β |
| Float64 | β | β | β | β |
| Float32 | β | β | β | β |
| Float16 | β | β | β | β |
| BFloat16 | β | β | β | β |
| Packed Uint8 | β | β | β | π‘ Partial |
| Mixed precision | β | β | β | β |
| Mixed devices | β | β | β | β |
| Ops & Numerics | ||||
| Arithmetic functions | β | β | β | β |
| Matrix multiplication | β | β | β | β |
| General einsum | β | β | π‘ Partial | π‘ Partial |
| Sorting | β | β | β | β |
| Activation functions | β | β | β | β |
| NaN/Inf numerics | β | β | β | β |
| Basic convolutions | β | β | β | β |
| n-d convolutions | β | β | β | β |
| Strided/dilated convolution | β | β | β | β |
| Cholesky, Lstsq | β | β | β | β |
| LU, Solve, Determinant | β | β | β | β |
| SVD | β | β | β | β |
| FFT | β | β | β | β |
| Basic RNG (Uniform, Normal) | β | β | β | β |
| Advanced RNG | β | β | β | β |
Programming in jax-js looks very similar to JAX,
just in JavaScript.
Create an array with np.array():
import { numpy as np } from "@hamk-uas/jax-js-nonconsuming";
using ar = np.array([1, 2, 3]);
By default, this is a float32 array, but you can specify a different dtype:
using ar = np.array([1, 2, 3], { dtype: np.int32 });
For more efficient construction, create an array from a JS TypedArray buffer:
const buf = new Float32Array([10, 20, 30, 100, 200, 300]);
using ar = np.array(buf).reshape([2, 3]);
Once you're done with it, you can unwrap a jax.Array back into JavaScript. This will also apply
any pending operations or lazy updates:
// 1) Returns a possibly nested JavaScript array.
ar.js();
await ar.jsAsync(); // Faster, non-blocking
// 2) Returns a flat TypedArray data buffer.
ar.dataSync();
await ar.data(); // Fastest, non-blocking
// If you don't need the array after reading data:
const floats = await ar.consumeData();
Arrays can have mathematical operations applied to them. For example:
import { numpy as np, scipySpecial as special } from "@hamk-uas/jax-js-nonconsuming";
using x = np.arange(100).astype(np.float32); // array of integers [0..99]
using y1 = x.add(x); // x + x
using y2 = np.sin(x); // sin(x)
using tanhX = np.tanh(x);
using y3 = tanhX.mul(5); // 5 * tanh(x)
using y4 = special.erfc(x); // erfc(x)
Big arrays take up a lot of memory. Python ML libraries override the `__del__()` method to free
memory, but JavaScript has no such API for running object destructors.
In jax-js-nonconsuming, operations do not consume their inputs: you can freely reuse an array
in multiple expressions without any special syntax. When you're done with an array, call
`.dispose()` to free its memory, or use JavaScript's `using` keyword for automatic disposal:
{
using x = np.array([1, 2, 3]);
using doubled = x.add(x);
const y = doubled.mul(x); // x used three times, no problem
y.dispose();
// x and doubled are automatically disposed at end of block
}
For best performance, wrap compute-heavy code in jit(). The JIT compiler automatically manages
intermediate buffers, allocating, reusing, and freeing them at the optimal points:
const f = jit((x: np.Array) => {
using sq = x.mul(x);
using s = sq.sum();
return np.sqrt(s);
});
const result = f(x); // intermediates freed automatically inside jit
result.dispose(); // caller disposes the output when done
f.dispose(); // free captured constants when the function is no longer needed
JIT functions support `using` for automatic cleanup, ideal for short-lived compiled functions:
{
using f = jit((x: np.Array) => x.mul(x).sum());
const result = f(input);
console.log(await result.data());
result.dispose();
// f's captured constants freed automatically at block end
}
JIT cache hierarchy. Disposing a jit function frees its captured GPU/WASM buffer constants
(the expensive part), but lightweight metadata caches survive for reuse:
| Cache | Freed by `f.dispose()`? | Freed by `clearCaches()`? |
|---|---|---|
| Captured constants (buffers) | Yes | Yes |
| JIT compilation cache | No (lightweight) | Yes |
| GPU shader pipelines | No (cheap handles) | No (device lifetime) |
For long-running applications (optimization loops, servers), call clearCaches() periodically to
reclaim all JIT metadata. This is rarely needed (the metadata caches are small), but it prevents
unbounded growth when creating many distinct JIT closures:
import { clearCaches } from "@hamk-uas/jax-js-nonconsuming";
for (const batch of batches) {
using f = jit((x) => model(params, x));
const loss = f(batch);
// ... update params ...
loss.dispose();
}
clearCaches(); // flush all JIT metadata after the loop
The @hamk-uas/eslint-plugin-jax-js catches the most common memory leaks (missing `using`,
use-after-dispose, unnecessary `.ref`) at edit time; see the plugin README for setup.
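For orientation, here is a minimal flat-config sketch. The `configs.invariance` preset is the one mentioned below for CI enforcement; whether the plugin's default export exposes additional presets is an assumption, so check the plugin README for the authoritative setup.

```javascript
// eslint.config.js -- sketch only; preset names beyond `invariance`
// are assumptions, see the plugin README for the authoritative setup.
import jaxJs from "@hamk-uas/eslint-plugin-jax-js";

export default [
  // Ownership-invariance preset referenced in this README (CI enforcement):
  jaxJs.configs.invariance,
];
```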
Ownership invariance principle: write code that is ownership-correct in both eager mode and
jit() mode. jit() is a performance optimization (fusion, recycling), not a semantics change. If
code leaks or relies on different ownership behavior in eager mode, treat it as a real bug. For CI
enforcement in user code, use jaxJs.configs.invariance from @hamk-uas/eslint-plugin-jax-js.
JAX's signature composable transformations are also supported in jax-js. Here is a simple example of
using grad and vmap to compute the derivative of a function:
import { numpy as np, grad, vmap } from "@hamk-uas/jax-js-nonconsuming";
using x = np.linspace(-10, 10, 1000);
using y1 = vmap(grad(np.sin))(x); // d/dx sin(x) = cos(x)
using y2 = np.cos(x);
np.allclose(y1, y2); // => true
The jit function is especially useful when doing long sequences of primitives on GPU, since it
fuses operations together into a single kernel dispatch. This
improves memory bandwidth usage on hardware
accelerators, which is the bottleneck on GPU rather than raw FLOPs. For instance:
export const hypot = jit(function hypot(x1: np.Array, x2: np.Array) {
using x1sq = np.square(x1);
using x2sq = np.square(x2);
using sum = x1sq.add(x2sq);
return np.sqrt(sum);
});
Without JIT, the hypot() function would require four kernel dispatches: two multiplies, one add,
and one sqrt. JIT fuses these together into a single kernel that does it all at once.
Nested jit is transparent. If a jitted function calls another jitted function, the inner jit
is structurally inlined: the compiler flattens nested jit boundaries and fuses everything into a
single compiled program. This means you can safely jit a model's predict method for fast
standalone inference, then also use it inside a larger jit(grad(...)) training step without any
overhead or conflict:
const model = {
predict: jit((params, x) => {
// ... model forward pass ...
}),
};
// Inference: calls the jitted predict directly (compiled kernel)
const logits = model.predict(params, x);
// Training: outer jit flattens the inner jit into a single fused program
const trainStep = jit((params, x, y) => {
const [loss, grads] = valueAndGrad(lossFn)(params, x, y);
return [loss, updateParams(params, grads)];
});
This also works through scan, grad, vmap, and all other transformations: inner jit
boundaries are always dissolved during compilation.
All functional transformations can take typed JsTrees of inputs and outputs. These are similar to
JAX's pytrees: basically a structure of nested JavaScript objects and arrays. For instance:
import { grad, numpy as np, tree } from "@hamk-uas/jax-js-nonconsuming";
type Params = {
foo: np.Array;
bar: np.Array[];
};
function getSums(p: Params) {
using fooSum = p.foo.sum();
using bar0 = p.bar[0].sum();
using bar1 = p.bar[1].sum();
using barSum = bar0.add(bar1);
return fooSum.add(barSum);
}
using g = tree.makeDisposable(
grad(getSums)({
foo: np.array([1, 2, 3]),
bar: [np.array([10]), np.array([11, 12])],
}),
);
// => { foo: [1, 1, 1], bar: [[1], [1, 1]] }
Note that you need to use `type` alias syntax rather than `interface` to define fine-grained
JsTree types.
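To make the JsTree shape concrete, here is a self-contained sketch (not the library's actual implementation) of the traversal idea: leaves are collected in deterministic key/index order, the same principle as JAX pytree flattening.

```typescript
// Self-contained sketch (not the library's actual implementation) of how a
// JsTree is traversed: numeric-array leaves are collected in deterministic
// key/index order, the same idea as JAX pytree flattening.
type Tree = number[] | Tree[] | { [key: string]: Tree };

function flattenTree(t: Tree, out: number[][] = []): number[][] {
  if (Array.isArray(t) && typeof t[0] === "number") {
    out.push(t as number[]); // leaf: a flat numeric array
  } else if (Array.isArray(t)) {
    for (const child of t as Tree[]) flattenTree(child, out);
  } else {
    for (const key of Object.keys(t)) flattenTree(t[key], out);
  }
  return out;
}

console.log(flattenTree({ foo: [1, 2, 3], bar: [[10], [11, 12]] }));
// [ [ 1, 2, 3 ], [ 10 ], [ 11, 12 ] ]
```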
Similar to JAX, jax-js has a concept of "devices" which are a backend that stores Arrays in memory and determines how to execute compiled operations on them.
There are currently 4 devices in jax-js:
- `cpu`: Slow, interpreted JS, only meant for debugging.
- `wasm`: WebAssembly, with optional multi-threaded parallel dispatch via WasmWorkerPool (requires SharedArrayBuffer).
- `webgpu`: WebGPU, available on supported browsers (Chrome, Firefox, Safari, iOS).
- `webgl`: WebGL2, via fragment shaders. This is an older graphics API that runs on almost all browsers, but it is much slower than WebGPU. It's offered on a best-effort basis and not as well supported. The `webgl` device has not been tested during development of jax-js-nonconsuming.

We recommend `webgpu` for best performance running neural networks and `wasm` for narrow
sequential computations. The default device is `wasm`, but you can change this at startup time:
import { defaultDevice, init } from "@hamk-uas/jax-js-nonconsuming";
const devices = await init(); // Starts all available backends.
if (devices.includes("webgpu")) {
defaultDevice("webgpu");
} else {
console.warn("WebGPU is not supported, falling back to Wasm.");
}
You can also place individual arrays on specific devices:
import { devicePut, numpy as np } from "@hamk-uas/jax-js-nonconsuming";
using ar = np.array([1, 2, 3]); // Starts with device="wasm"
await devicePut(ar, "webgpu"); // Now device="webgpu"
jax-js includes three helper libraries, available as sub-path exports from the main package:
import { adam } from "@hamk-uas/jax-js-nonconsuming/optax";
import { cachedFetch, safetensors } from "@hamk-uas/jax-js-nonconsuming/loaders";
import { ONNXModel } from "@hamk-uas/jax-js-nonconsuming/onnx";
No extra install needed: they're included when you install @hamk-uas/jax-js-nonconsuming. The
loaders and onnx sub-packages have optional native dependencies (@bufbuild/protobuf,
sentencepiece-buf, onnx-buf) that are listed in the main package's optionalDependencies and
installed automatically if available for your platform.
To see per-kernel traces in browser development tools, call jax.profiler.startTrace().
The WebGPU runtime includes an ML compiler with tile-aware optimizations, tuned for individual
browsers. Also, this library uniquely has the jit() feature that fuses operations together and
records an execution graph. jax-js achieves over 7000 GFLOP/s for matrix multiplication on an
Apple M4 Max chip (try it).
In that specific benchmark run, it was faster than both TensorFlow.js and ONNX Runtime Web, which both use handwritten libraries of custom kernels. Results vary by model, operator mix, and hardware.
It's still early though. There's a lot of low-hanging fruit to continue optimizing the library, as well as unique optimizations such as FlashAttention variants.
That's all for this short tutorial. Please see the generated API reference for detailed documentation.
The following technical details are for contributing to jax-js-nonconsuming and modifying its internals.
This repository is managed by pnpm. You can compile and build all packages in
watch mode with:
pnpm install
pnpm run build:watch
The pnpm install command automatically sets up Git hooks via
Husky. This repository intentionally shifts verification left:
we run the necessary quality and release-safety checks in local pre-commit rather than relying on
heavy CI gates for GitHub Pages deployment confidence.
Run checks iteratively while developing. You can also run linting and formatting manually:
pnpm lint # Run ESLint (includes @hamk-uas/eslint-plugin-jax-js ownership rules)
pnpm format # Format all files with Prettier
pnpm format:check # Check formatting without writing
pnpm check # Run TypeScript type checking
pnpm run test:eslint-plugin # Rule-level tests for in-repo ESLint plugin
pnpm run lint:ownership:internal # Maintainer transform ownership checks
Then you can run tests in a headless browser using Vitest.
pnpm exec playwright install
pnpm test
pnpm run test:policy:strict # Strict mode: no expected-failure debt
pnpm run test:arch # Architectural mode: failures gated by manifest
pnpm run test:website:smoke # Website build + smoke checks
Tests run in headless Chromium via Playwright with full WebGPU support. This requires a working Vulkan GPU driver on the host. Verify with:
vulkaninfo --summary # Should show your GPU device
Key requirements:

- The Chrome flags in vitest.config.ts use `--use-angle=vulkan` to route WebGPU through the host GPU. Without a working Vulkan driver, WebGPU tests will be skipped.
- `pnpm exec playwright install` downloads a bundled Chromium that includes the Dawn WebGPU implementation.
- WebGPU requires `isSecureContext === true`. Vitest's dev server serves on localhost, which qualifies automatically. Direct about:blank navigation would not.
- Make sure `DISPLAY` is set (X11) or a Wayland session is active. The config passes DISPLAY and XAUTHORITY from the environment.
- The `--no-sandbox` flag avoids permission issues in containers or non-root environments.

The full set of Chrome flags used (configured in vitest.config.ts):
--no-sandbox --headless=new --use-angle=vulkan --enable-features=Vulkan
--disable-vulkan-surface --enable-unsafe-webgpu
Troubleshooting:

- If no WebGPU device is found, run `vulkaninfo --summary` and ensure your GPU driver is installed.
- If no display is available, use Xvfb or a virtual framebuffer.
- Cross-origin isolation headers (required for SharedArrayBuffer) are configured in vitest.config.ts.

Libraries that depend on @hamk-uas/jax-js-nonconsuming (e.g. dlm-js) can reuse the same headless
WebGPU setup. Install the required dev dependencies:
pnpm add -D vitest @vitest/browser-playwright playwright
pnpm exec playwright install
Then create a vitest.config.ts with the same Chromium flags and COOP/COEP headers:
import { defineConfig } from "vitest/config";
import { playwright } from "vitest/browser";
export default defineConfig({
server: {
headers: {
"Cross-Origin-Embedder-Policy": "require-corp",
"Cross-Origin-Opener-Policy": "same-origin",
},
},
test: {
browser: {
enabled: true,
headless: true,
provider: playwright({
launchOptions: {
args: [
"--no-sandbox",
"--headless=new",
"--use-angle=vulkan",
"--enable-features=Vulkan",
"--disable-vulkan-surface",
"--enable-unsafe-webgpu",
],
env: {
DISPLAY: process.env.DISPLAY ?? ":0",
XAUTHORITY:
process.env.XAUTHORITY ?? `/run/user/${process.getuid?.() ?? 1000}/gdm/Xauthority`,
},
},
}),
instances: [{ browser: "chromium" }],
},
},
});
The Chrome flags route WebGPU through your host Vulkan driver in headless mode.
Browsers gate SharedArrayBuffer behind
cross-origin isolation.
Without it, the WASM worker pool cannot dispatch in parallel and several WebGPU features are
unavailable. This applies to production deployments, not just tests.
Your server must send two HTTP response headers on every page that uses jax-js:
| Header | Value |
|---|---|
| `Cross-Origin-Embedder-Policy` | `require-corp` |
| `Cross-Origin-Opener-Policy` | `same-origin` |
The Vitest config above already sets these for the dev server. For production, configure your hosting platform. Examples:
Vercel (vercel.json β this repo's own config):
{
"headers": [
{
"source": "/(.*)",
"headers": [
{ "key": "Cross-Origin-Embedder-Policy", "value": "require-corp" },
{ "key": "Cross-Origin-Opener-Policy", "value": "same-origin" }
]
}
]
}
Nginx:
add_header Cross-Origin-Embedder-Policy "require-corp" always;
add_header Cross-Origin-Opener-Policy "same-origin" always;
Netlify (_headers):
/*
Cross-Origin-Embedder-Policy: require-corp
Cross-Origin-Opener-Policy: same-origin
Side effect: `require-corp` means every sub-resource (image, script, iframe) must either be same-origin or served with a `Cross-Origin-Resource-Policy: cross-origin` header. If your page loads third-party assets without that header, they will be blocked. See the MDN guide for details.
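The header requirement can be sanity-checked programmatically. A framework-agnostic sketch (header names are lower-cased here, as the Fetch API normalizes them; the function name is illustrative):

```javascript
// Returns true when a page's response headers satisfy the cross-origin
// isolation requirements that jax-js needs for SharedArrayBuffer.
function isCrossOriginIsolatedResponse(headers) {
  return (
    headers["cross-origin-embedder-policy"] === "require-corp" &&
    headers["cross-origin-opener-policy"] === "same-origin"
  );
}

console.log(
  isCrossOriginIsolatedResponse({
    "cross-origin-embedder-policy": "require-corp",
    "cross-origin-opener-policy": "same-origin",
  }),
); // true
console.log(isCrossOriginIsolatedResponse({})); // false (headers missing)
```

In the browser itself, the global `crossOriginIsolated` boolean reports the same condition at runtime.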
Architectural mode is intended for large refactors and uses .ci/expected-failures.json as an
explicit, expiring debt ledger. See docs/testing-policy.md for workflow details.
For maintainer-only transform ownership checks in framework internals, run:
pnpm run lint:ownership:internal
For website ownership checks used by demos/repl:
pnpm run lint:ownership:website
Pre-commit is branch-aware and runs via scripts/precommit.sh.
On feature branches, the fast profile runs:

- build, check, lint, format:check
- test:eslint-plugin
- lint:ownership:internal, lint:ownership:website
- vitest run test/refcount.test.ts test/transform-compositions.test.ts

On release branches (main, master, release/*, hotfix/*), it additionally runs:

- test:policy:strict
- test:website:smoke

This keeps day-to-day feature iteration fast while enforcing release-grade checks when committing on main/release branches.
For large refactors with explicit, expiring expected-failure debt, use architectural mode:
JAX_ARCH_MODE=1 git commit -m "your message"
Architectural mode still enforces strict core invariant suites before applying manifest-based
failure accounting. See docs/testing-policy.md for workflow details.
You can override profile selection explicitly:
JAX_PRECOMMIT_PROFILE=feature git commit -m "..."
JAX_PRECOMMIT_PROFILE=full git commit -m "..."
Before merging to main, run one commit (or dry run) with full profile:
JAX_PRECOMMIT_PROFILE=full git commit -m "chore: pre-merge full local checks"
Inspiration from hamk-uas/eslint-plugin-jax-js: keep hooks explicit and developer-visible, and
keep maintainer release checks documented and reproducible.
We are currently on an older version of Playwright that supports using WebGPU in headless mode; newer versions skip the WebGPU tests.
To start a Vite dev server running the website, demos and REPL:
pnpm -C website dev
This section is for maintainers who create releases.
# 1. Make sure all checks pass
pnpm build
pnpm check
pnpm run test:policy:strict
pnpm run test:website:smoke
pnpm run test:eslint-plugin
pnpm run lint:ownership:internal
pnpm run lint:ownership:website
# Or equivalently (same full profile as main-branch pre-commit):
JAX_PRECOMMIT_PROFILE=full scripts/precommit.sh
# 2. Bump version, commit, and tag in one step.
# Choose patch / minor / major β see Version numbering below.
# The -m flag sets the commit message; %s is replaced with the new version.
npm version patch -m "chore(release): v%s"
# 3. Push the commit and tag
git push && git push --tags
# 4. Create a GitHub release via the gh CLI.
# --title: short phrase after the version number, e.g. "fix grad(solve)" or "add lax.foo"
# --notes: write a human-readable description (see format below); required, don't skip this
gh release create "v$(node -p "require('./package.json').version")" \
--title "v$(node -p "require('./package.json').version") - <short description>" \
--notes "## <Category>: <what changed>
### What was wrong / motivation
<one paragraph>
### What changed
- **\`src/...\`** β description
### Upgrade
\`\`\`bash
npm install github:hamk-uas/jax-js-nonconsuming#v$(node -p "require('./package.json').version")
\`\`\`
**Full Changelog**: https://github.com/hamk-uas/jax-js-nonconsuming/compare/vPREV...v$(node -p "require('./package.json').version")"
Users install specific tags, so after releasing they can upgrade with:
npm install github:hamk-uas/jax-js-nonconsuming#v0.2.1
| Change type | Bump | Example |
|---|---|---|
| Documentation only (README, comments, copilot-instructions) | none | Users on main get it automatically |
| Bug fix, precision improvement, internal refactor | patch | Kahan summation, ownership fix, test improvements |
| New jax-js/NumPy ops added to API surface | patch | New numpy.foo() function |
| New public API or feature (transform, backend capability) | minor | lax.scan, buffer recycling, new ESLint rule |
| Breaking API or ownership-model behavior change | major | Removing a public function, changing dispose rules |
This fork uses independent semver: it does not mirror upstream ekzhang/jax-js tags. Track
upstream compatibility in release notes, and choose bump level by user-visible impact in this fork.
When rebasing/syncing from upstream, the bump level depends on what user-facing changes come along.
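The bump levels in the table map directly onto semver components. A small self-contained sketch of what `npm version <level>` does to a plain version string (prerelease tags are out of scope here):

```javascript
// Illustrative restatement of semver bumping, matching what
// `npm version patch|minor|major` does to a plain x.y.z version.
function bump(version, level) {
  let [major, minor, patch] = version.split(".").map(Number);
  if (level === "major") { major += 1; minor = 0; patch = 0; }
  else if (level === "minor") { minor += 1; patch = 0; }
  else if (level === "patch") { patch += 1; }
  else throw new Error(`unknown bump level: ${level}`);
  return `${major}.${minor}.${patch}`;
}

console.log(bump("0.3.0", "patch")); // "0.3.1"
console.log(bump("0.3.1", "minor")); // "0.4.0"
console.log(bump("0.4.0", "major")); // "1.0.0"
```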
For simple bug-fix PRs (the common case), target main.

Contributions are welcome! Please open issues and PRs on this repository for topics specific to the non-consuming fork, such as:

- `using`-based patterns
- The @hamk-uas/eslint-plugin-jax-js linter
- lax.scan, buffer recycling, checkLeaks, or other fork-only features

For feature requests or bugs that apply to both branches (e.g., new NumPy/JAX ops, backend improvements, core tracing), please file them upstream at ekzhang/jax-js instead. This avoids duplicate work and ensures fixes land in both codebases.
Upstream sync policy: We may periodically rebase onto upstream to pick up new features and fixes, but there is no guarantee of continuous updates. Maintenance debt can accumulate across projects, and this fork may be reduced in scope or paused if priorities shift. Upstream jax-js continues independently.
Before submitting a PR, run the full CI checks locally:
pnpm build && pnpm check && pnpm run test:policy:strict && pnpm run test:website:smoke && pnpm run test:eslint-plugin && pnpm run lint:ownership:internal && pnpm run lint:ownership:website
This fork replaces the upstream move-semantics ownership model with a non-consuming model.
Outside ownership semantics and fork-specific features, the APIs are broadly aligned for common
NumPy/JAX usage (jit, grad, vmap, backends, demos), with some intentional divergence (for
example scan, checkLeaks, and ownership tooling).
| Aspect | Upstream (ekzhang/jax-js) | This fork (non-consuming) |
|---|---|---|
| Ownership model | Move semantics | Non-consuming |
| Operations consume inputs? | Yes: every op decrements refcount | No: inputs stay alive |
| `.ref` needed to reuse arrays? | Yes: `x.ref` before passing to a second op | Not in user code |
| `UseAfterFreeError` risk | Common if `.ref` is forgotten | Gone for reuse; still possible after explicit `.dispose()` |
| `using` declarations | Not used | First-class: auto-dispose at block end |
| ESLint plugin | @hamk-uas/eslint-plugin-jax-js (move) | @hamk-uas/eslint-plugin-jax-js (non-consuming) |
| `lax.scan` | Not implemented | Full support (JIT, autodiff, vmap, native compilation) |
| `lax.associativeScan` | Not implemented | Kogge-Stone parallel prefix scan; pytrees, any axis, reverse, autodiff. Eager mode routes through a cached whole-call jit wrapper (outside abstract tracing); explicit jit(...) still gives maximum throughput and manual cache lifetime control. |
| Buffer recycling | Not implemented | JIT-level recycle step + WebGPU buffer pool |
| `tree.makeDisposable` | Not available | Wraps any object for `using`-based cleanup |
| `Array.consumeData()` | Not available | Reads data and disposes in one call |
| `checkLeaks` diagnostic | Not available | Runtime leak detection with stack traces |
The non-consuming model makes some things easier and other things harder. Here are the real costs:
Important context: JIT neutralizes most intermediate-lifetime differences. Under jit(), both
ownership models compile to the same Jaxpr graph, and the compiler derives buffer lifetimes from
that graph, not from the ownership model. Intermediates and buffers are freed at exact last use; in
many common workloads this makes peak-memory behavior very similar across models. The tradeoffs
below apply primarily to eager mode, the mode you use for debugging, prototyping, and the REPL.
Production hot paths should be JIT-wrapped anyway for performance (kernel fusion), which narrows the
practical ownership-model gap to JIT boundaries (who disposes inputs, outputs, and the jit
function itself) and any code that runs outside jit().
Silent leaks replace noisy crashes. Move semantics crash immediately (UseAfterFreeError) when
you forget .ref: painful, but the error points straight at the bug. The non-consuming model never
crashes from reuse, but a missing .dispose() leaks GPU memory silently. For the common case of
block-scoped arrays, using declarations prevent this: arrays are disposed automatically at block
exit. But using can't help everywhere: method chains create anonymous intermediates that nobody
names (and thus nobody disposes), nested results from scan/grad need tree.dispose(), and
loop-carried state or arrays stored in caches require manual discipline. The checkLeaks diagnostic
(built into the test suite, so every test is leak-checked) and the ESLint plugin catch many of these
cases, but they are developer tools, not a runtime safety net. (Move semantics can also leak, e.g.
via an over-.ref'd array or a retained reference, but the fail-fast default for reuse bugs makes
those easier to spot.)
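The mechanism a diagnostic like checkLeaks relies on can be shown in a self-contained sketch (no jax-js required, all names here are hypothetical): track live allocations in a registry and compare counts around a scope.

```javascript
// Hypothetical sketch of the idea behind a leak diagnostic: a registry of
// live allocations is compared before and after a scope runs.
class FakeArray {
  static live = new Set();
  constructor(data) {
    this.data = data;
    FakeArray.live.add(this); // register on allocation
  }
  dispose() {
    FakeArray.live.delete(this); // deregister on explicit disposal
  }
}

function countLeaks(fn) {
  const before = FakeArray.live.size;
  fn();
  return FakeArray.live.size - before; // arrays allocated but never disposed
}

const leaked = countLeaks(() => {
  const x = new FakeArray([1, 2, 3]);
  const y = new FakeArray([4, 5, 6]); // never disposed: reported as a leak
  x.dispose();
});
console.log(leaked); // 1
```

The real diagnostic additionally records allocation stack traces so the report points at the line that created the leaked array.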
Higher peak memory in eager mode. Expression chains like x.mul(y).add(z).sub(w) create
intermediate arrays that linger until GC or explicit disposal. With move semantics, each
intermediate is freed as soon as the next operation consumes it. In the non-consuming model, all
intermediates stay alive simultaneously; for large tensors this can significantly increase peak
memory (the exact factor depends on chain length and tensor size). Breaking chains into using
temporaries solves this (intermediates are disposed at block exit), but the code is more verbose
than the NumPy equivalent. Under jit(), both models free intermediates at the optimal point; this
is purely an eager-mode difference. But eager mode is where you debug, and debugging with a higher
memory footprint is a real obstacle.
JavaScript GC doesn't know about GPU memory. The JS garbage collector tracks JS heap pressure,
not the 4 GB of VRAM on your GPU. A leaked 512×512 f32 buffer is 1 MB of GPU memory but only ~64
bytes of JS heap, so GC may never run. FinalizationRegistry is too slow and unpredictable to rely on.
This problem affects both ownership models: any jax-js program must eventually free GPU buffers
explicitly. The non-consuming model simply makes it easier to forget, because nothing crashes when
you do.
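To make the mismatch concrete (the ~64-byte wrapper figure is the rough estimate used above, not a measured value):

```javascript
// The arithmetic behind the example above: what a leaked 512×512 float32
// buffer costs in VRAM versus what the JS garbage collector can see.
const gpuBytes = 512 * 512 * 4; // float32 = 4 bytes per element
console.log(gpuBytes); // 1048576 (1 MiB of GPU memory)

const jsWrapperBytes = 64; // rough JS-heap size of the handle object (estimate)
console.log(gpuBytes / jsWrapperBytes); // 16384, i.e. ~16000x invisible to GC
```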
Method chains become a pain point in eager mode. a.mul(b).add(c).div(d) is natural in NumPy.
In the non-consuming model, each .method() allocates a new GPU buffer. The fix is using
declarations, but they require separate statements, one per intermediate. Under jit(), these
chains produce tracers (not real GPU buffers) and the compiler manages everything, so the memory
cost only appears in eager code. Still, eager code is where you prototype and learn the API:
// Leaks two intermediate GPU buffers in eager mode:
const leaky = a.mul(b).add(c).div(d);

// Correct, but more verbose than the NumPy equivalent:
using t1 = a.mul(b);
using t2 = t1.add(c);
const result = t2.div(d);
// Under jit(), intermediates are tracers (not real buffers), so the chain
// doesn't leak GPU memory in practice. But write the second form anyway:
// code should be ownership-correct in both eager and jit mode.
using has ecosystem gaps. The TC39 Explicit Resource Management proposal is not yet supported
everywhere: Svelte's parser can't handle using in .svelte files, and older bundlers may need
transpilation. A polyfill is included, but it adds friction.
More tooling required for edge cases. using handles the most common pattern (block-scoped
arrays) at the language level. But for patterns it doesn't cover (method chains, pytree results,
loop-carried state, long-lived closures), the non-consuming model leans on voluntary tooling: the
ESLint plugin for static analysis and checkLeaks for runtime detection. Move semantics fail fast
for reuse mistakes, but have their own blind spots (over-.ref, retained references, forgotten
vjpFn.dispose()) that also need tooling and discipline.
Neither model is free. Move semantics pay with UseAfterFreeError bugs, .ref boilerplate, and
their own leak surfaces (over-ref, retained refs). The non-consuming model eliminates those costs
but introduces its own: silent leaks for patterns that using can't cover, higher eager-mode memory
for unchained intermediates, and reliance on checkLeaks/ESLint for the gaps. Under jit(), the
two models converge: the compiled programs are identical. Both models need discipline; they just
fail in different ways. This fork bets that using-by-default plus opt-in tooling is easier to
manage for teams coming from Python/MATLAB, but it is a genuine tradeoff, not a free lunch.
Use this fork if you want non-consuming ownership, where `using`
declarations handle cleanup, and lax.scan is available. Use upstream if you would rather not
adopt the @hamk-uas ESLint plugin, or if you need to stay on the upstream release cadence.

The two versions are not drop-in interchangeable: ownership patterns from one model can behave
incorrectly or awkwardly in the other, especially around .ref and disposal discipline. The
@hamk-uas/eslint-plugin-jax-js included here enforces the non-consuming patterns and will flag
.ref usage as unnecessary.
To migrate from upstream:

- Remove `.ref` calls: operations no longer consume inputs.
- Prefer `using`: `using x = np.array(...)` auto-disposes at block end.
- Call `.dispose()` explicitly for long-lived arrays, or wrap in `tree.makeDisposable()`.
- Enable @hamk-uas/eslint-plugin-jax-js: it catches leaks, use-after-dispose, and unnecessary `.ref` at edit time. See the plugin README for setup.

This fork is developed primarily using AI coding agents (GitHub Copilot, Claude, GPT, Gemini) with
gentle human supervision. All changes go through the full CI pipeline (pnpm test, pnpm check),
and the pre-commit hook runs the complete test suite before every commit.