jax-js-nonconsuming (Fork)

    jax-js-nonconsuming logo: JAX in pure JavaScript

    Non-consuming ownership fork

    Website | API Reference | Compatibility Table | Copilot Instructions | jax-js Discord

    Fork notice: This is a fork of ekzhang/jax-js with a non-consuming ownership model. Operations leave inputs alive (no .ref needed), and using declarations provide deterministic GPU/WASM memory cleanup.

    Why this fork? The original jax-js uses move semantics, where operations consume their inputs. This fork was created for teams familiar with MATLAB or Python (NumPy) where move semantics are unexpected. We also fast-tracked lax.scan and lax.associativeScan implementations. using declarations handle the common case β€” block-scoped arrays are disposed automatically β€” but patterns like method chains, loop-carried state, and nested results still need manual care. A built-in checkLeaks diagnostic and ESLint plugin help catch what using misses. See Tradeoffs for an honest comparison.

    See Differences from upstream for a full comparison between the original and this fork.

    πŸ€– The fork code & documentation commits are AI-generated with gentle human supervision.

    jax-js is a machine learning framework for the browser. It aims to bring JAX-style, high-performance CPU and GPU kernels to JavaScript, so you can run numerical applications on the web.

    npm install github:hamk-uas/jax-js-nonconsuming
    

    To pin a specific release tag (once available):

    npm install github:hamk-uas/jax-js-nonconsuming#v0.3.0
    

    pnpm users: pnpm requires explicit permission to run build scripts from Git dependencies. Add this to your package.json:

{
  "pnpm": {
    "onlyBuiltDependencies": ["@hamk-uas/jax-js-nonconsuming"]
  }
}

    Without this, pnpm install fails with ERR_PNPM_GIT_DEP_PREPARE_NOT_ALLOWED. The prepare script runs tsdown to build the package and its sub-packages (optax, loaders, onnx, eslint-plugin).

    Under the hood, it translates array operations into a compiler representation, then synthesizes kernels in WebAssembly and WebGPU.

The original jax-js was written from scratch with zero external dependencies. jax-js and this fork maintain close API compatibility with NumPy/JAX. Since everything runs client-side, jax-js is one of the most portable GPU ML frameworks: it works anywhere a browser can run.

    import { numpy as np } from "@hamk-uas/jax-js-nonconsuming";

// Array operations, compatible with JAX/NumPy.
{
  using x = np.array([1, 2, 3]);
  using y = x.mul(4); // [4, 8, 12]
}

    In vanilla JavaScript (without a bundler), just import from a module script tag. This is the easiest way to get started on a blank HTML page.

<script type="module">
  import { numpy as np } from "https://esm.sh/@jax-js/jax";
</script>

This table refers to the latest versions of each browser. WebGPU has gained wide support in browsers as of late 2025.

    Platform CPU (Wasm) GPU (WebGPU) GPU (WebGL)
    Chrome / Edge βœ… βœ… βœ…
    Firefox βœ… βœ… - macOS 26+ βœ…
    Safari βœ… βœ… - macOS 26+ βœ…
    iOS βœ… βœ… - iOS 26+ βœ…
    Chrome for Android βœ… βœ… βœ…
    Firefox for Android βœ… ❌ βœ…
    Node.js βœ… ❌ ❌
    Deno βœ… βœ… - async ❌

    Community usage:

    Demos:

    Here's a quick, high-level comparison with other popular web ML runtimes. Performance labels are workload- and hardware-dependent.

    Feature jax-js-nonconsuming jax-js v0.1.9 TensorFlow.js onnxruntime-web
    Overview
    API style JAX/NumPy JAX/NumPy TensorFlow-like Static ONNX graphs
    Speed Very fast Very fast Fast Very fast
    Bundle size (gzip) 184 KB 80 KB 269 KB 90 KB + 24 MB Wasm
    Autodiff & JIT
    Gradients βœ… βœ… βœ… ❌
    Jacobian and Hessian βœ… βœ… ❌ ❌
    jvp() forward differentiation βœ… βœ… ❌ ❌
    jit() kernel fusion βœ… βœ… ❌ ❌
    vmap() auto-vectorization βœ… βœ… ❌ ❌
    scan() scan over leading axis βœ… ❌ ❌ ❌
    associativeScan() parallel prefix scan βœ… ❌ ❌ ❌
    Graph capture βœ… βœ… ❌ βœ…
    Backends & Data
    WebGPU backend βœ… βœ… 🟑 Preview βœ…
    WebGL backend βœ… βœ… βœ… βœ…
    Wasm (CPU) backend βœ… βœ… βœ… βœ…
    Eager array API βœ… βœ… βœ… ❌
    Run ONNX models 🟑 Partial 🟑 Partial ❌ βœ…
    Read safetensors βœ… βœ… ❌ ❌
    Float64 βœ… βœ… ❌ ❌
    Float32 βœ… βœ… βœ… βœ…
    Float16 βœ… βœ… ❌ βœ…
    BFloat16 ❌ ❌ ❌ ❌
    Packed Uint8 ❌ ❌ ❌ 🟑 Partial
    Mixed precision βœ… βœ… ❌ βœ…
    Mixed devices βœ… βœ… ❌ ❌
    Ops & Numerics
    Arithmetic functions βœ… βœ… βœ… βœ…
    Matrix multiplication βœ… βœ… βœ… βœ…
    General einsum βœ… βœ… 🟑 Partial 🟑 Partial
    Sorting βœ… βœ… ❌ ❌
    Activation functions βœ… βœ… βœ… βœ…
    NaN/Inf numerics βœ… βœ… βœ… βœ…
    Basic convolutions βœ… βœ… βœ… βœ…
    n-d convolutions βœ… βœ… ❌ βœ…
    Strided/dilated convolution βœ… βœ… βœ… βœ…
    Cholesky, Lstsq βœ… βœ… ❌ ❌
    LU, Solve, Determinant βœ… βœ… ❌ ❌
    SVD ❌ ❌ ❌ ❌
    FFT βœ… βœ… βœ… βœ…
    Basic RNG (Uniform, Normal) βœ… βœ… βœ… βœ…
    Advanced RNG βœ… βœ… ❌ ❌

    Programming in jax-js looks very similar to JAX, just in JavaScript.

    Create an array with np.array():

    import { numpy as np } from "@hamk-uas/jax-js-nonconsuming";

    using ar = np.array([1, 2, 3]);

    By default, this is a float32 array, but you can specify a different dtype:

    using ar = np.array([1, 2, 3], { dtype: np.int32 });
    

    For more efficient construction, create an array from a JS TypedArray buffer:

    const buf = new Float32Array([10, 20, 30, 100, 200, 300]);
    using ar = np.array(buf).reshape([2, 3]);

    Once you're done with it, you can unwrap a jax.Array back into JavaScript. This will also apply any pending operations or lazy updates:

    // 1) Returns a possibly nested JavaScript array.
    ar.js();
    await ar.jsAsync(); // Faster, non-blocking

    // 2) Returns a flat TypedArray data buffer.
    ar.dataSync();
    await ar.data(); // Fastest, non-blocking

    // If you don't need the array after reading data:
    const floats = await ar.consumeData();

    Arrays can have mathematical operations applied to them. For example:

    import { numpy as np, scipySpecial as special } from "@hamk-uas/jax-js-nonconsuming";

using x = np.arange(100).astype(np.float32); // [0, 1, ..., 99] cast to float32

    using y1 = x.add(x); // x + x
    using y2 = np.sin(x); // sin(x)
    using tanhX = np.tanh(x);
    using y3 = tanhX.mul(5); // 5 * tanh(x)
    using y4 = special.erfc(x); // erfc(x)

Large arrays take up a lot of memory. Python ML libraries override the __del__() method to free memory, but JavaScript has no equivalent API for running object destructors.

    In jax-js-nonconsuming, operations do not consume their inputs β€” you can freely reuse an array in multiple expressions without any special syntax. When you're done with an array, call .dispose() to free its memory, or use JavaScript's using keyword for automatic disposal:

{
  using x = np.array([1, 2, 3]);
  using doubled = x.add(x);
  const y = doubled.mul(x); // x used three times β€” no problem
  y.dispose();
  // x and doubled are automatically disposed at end of block
}

    For best performance, wrap compute-heavy code in jit(). The JIT compiler automatically manages intermediate buffers β€” allocating, reusing, and freeing them at the optimal points:

const f = jit((x: np.Array) => {
  using sq = x.mul(x);
  using s = sq.sum();
  return np.sqrt(s);
});
const result = f(x); // intermediates freed automatically inside jit
result.dispose(); // caller disposes the output when done
f.dispose(); // free captured constants when the function is no longer needed

    JIT functions support using for automatic cleanup β€” ideal for short-lived compiled functions:

{
  using f = jit((x: np.Array) => x.mul(x).sum());
  const result = f(input);
  console.log(await result.data());
  result.dispose();
  // f's captured constants freed automatically at block end
}

    JIT cache hierarchy. Disposing a jit function frees its captured GPU/WASM buffer constants (the expensive part), but lightweight metadata caches survive for reuse:

    Cache Freed by f.dispose()? Freed by clearCaches()?
    Captured constants (buffers) Yes Yes
    JIT compilation cache No (lightweight) Yes
    GPU shader pipelines No (cheap handles) No (device lifetime)

    For long-running applications (optimization loops, servers), call clearCaches() periodically to reclaim all JIT metadata. This is rarely needed β€” the metadata caches are small β€” but it prevents unbounded growth when creating many distinct JIT closures:

    import { clearCaches } from "@hamk-uas/jax-js-nonconsuming";

for (const batch of batches) {
  using f = jit((x) => model(params, x));
  const loss = f(batch);
  // ... update params ...
  loss.dispose();
}
    clearCaches(); // flush all JIT metadata after the loop

    The @hamk-uas/eslint-plugin-jax-js catches the most common memory leaks (missing using, use-after-dispose, unnecessary .ref) at edit time β€” see the plugin README for setup.

    Ownership invariance principle: write code that is ownership-correct in both eager mode and jit() mode. jit() is a performance optimization (fusion, recycling), not a semantics change. If code leaks or relies on different ownership behavior in eager mode, treat it as a real bug. For CI enforcement in user code, use jaxJs.configs.invariance from @hamk-uas/eslint-plugin-jax-js.

    JAX's signature composable transformations are also supported in jax-js. Here is a simple example of using grad and vmap to compute the derivative of a function:

    import { numpy as np, grad, vmap } from "@hamk-uas/jax-js-nonconsuming";

    using x = np.linspace(-10, 10, 1000);

    using y1 = vmap(grad(np.sin))(x); // d/dx sin(x) = cos(x)
    using y2 = np.cos(x);

    np.allclose(y1, y2); // => true

    The jit function is especially useful when doing long sequences of primitives on GPU, since it fuses operations together into a single kernel dispatch. This improves memory bandwidth usage on hardware accelerators, which is the bottleneck on GPU rather than raw FLOPs. For instance:

export const hypot = jit(function hypot(x1: np.Array, x2: np.Array) {
  using x1sq = np.square(x1);
  using x2sq = np.square(x2);
  using sum = x1sq.add(x2sq);
  return np.sqrt(sum);
});

Without JIT, the hypot() function would require four kernel dispatches: two squares, one add, and one sqrt. JIT fuses these together into a single kernel that does it all at once.

    Nested jit is transparent. If a jitted function calls another jitted function, the inner jit is structurally inlined β€” the compiler flattens nested jit boundaries and fuses everything into a single compiled program. This means you can safely jit a model's predict method for fast standalone inference, then also use it inside a larger jit(grad(...)) training step without any overhead or conflict:

const model = {
  predict: jit((params, x) => {
    // ... model forward pass ...
  }),
};

    // Inference: calls the jitted predict directly β€” compiled kernel
    const logits = model.predict(params, x);

    // Training: outer jit flattens the inner jit β€” single fused program
const trainStep = jit((params, x, y) => {
  // lossFn and updateParams are assumed to be defined elsewhere
  const [loss, grads] = valueAndGrad(lossFn)(params, x, y);
  return [loss, updateParams(params, grads)];
});

    This also works through scan, grad, vmap, and all other transformations β€” inner jit boundaries are always dissolved during compilation.
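As a small sketch of such a composition (mirroring the vmap(grad(np.sin)) example earlier; the square function and the input values are illustrative):

```typescript
import { numpy as np, jit, grad, vmap } from "@hamk-uas/jax-js-nonconsuming";

{
  // Inner jit: compiled on its own for standalone use...
  using square = jit((x: np.Array) => x.mul(x));

  // ...and dissolved when wrapped in an outer jit(vmap(grad(...))):
  // the whole pipeline compiles to one fused program.
  using dSquare = jit(vmap(grad(square)));

  using xs = np.array([1, 2, 3]);
  using ys = dSquare(xs); // d/dx xΒ² = 2x, so ys holds [2, 4, 6]
}
```

Both jit functions and all arrays are bound with using, so their buffers and captured constants are released at block exit.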

All functional transformations can take typed JsTree structures as inputs and outputs. These are similar to JAX's pytrees: just nested JavaScript objects and arrays. For instance:

    import { grad, numpy as np, tree } from "@hamk-uas/jax-js-nonconsuming";

type Params = {
  foo: np.Array;
  bar: np.Array[];
};

function getSums(p: Params) {
  using fooSum = p.foo.sum();
  using bar0 = p.bar[0].sum();
  using bar1 = p.bar[1].sum();
  using barSum = bar0.add(bar1);
  return fooSum.add(barSum);
}

using g = tree.makeDisposable(
  grad(getSums)({
    foo: np.array([1, 2, 3]),
    bar: [np.array([10]), np.array([11, 12])],
  }),
);
// => { foo: [1, 1, 1], bar: [[1], [1, 1]] }

    Note that you need to use type alias syntax rather than interface to define fine-grained JsTree types.
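For example (a minimal sketch; the Params shape here is illustrative):

```typescript
import { numpy as np } from "@hamk-uas/jax-js-nonconsuming";

// βœ… type alias: the object structure stays visible to fine-grained
// JsTree type inference across transformations like grad().
type Params = {
  w: np.Array;
  b: np.Array;
};

// ❌ interface: fine-grained JsTree typing cannot see through it.
// interface ParamsI { w: np.Array; b: np.Array }
```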

Similar to JAX, jax-js has a concept of "devices": backends that store Arrays in memory and determine how compiled operations execute on them.

There are currently four devices in jax-js:

    • cpu: Slow, interpreted JS, only meant for debugging.
    • wasm: WebAssembly, with optional multi-threaded parallel dispatch via WasmWorkerPool (requires SharedArrayBuffer).
    • webgpu: WebGPU, available on supported browsers (Chrome, Firefox, Safari, iOS).
    • webgl: WebGL2, via fragment shaders. This is an older graphics API that runs on almost all browsers, but it is much slower than WebGPU. It's offered on a best-effort basis and not as well-supported. The webgl device has not been tested during development of jax-js-nonconsuming.

    We recommend webgpu for best performance running neural networks and wasm for narrow sequential computations. The default device is wasm, but you can change this at startup time:

    import { defaultDevice, init } from "@hamk-uas/jax-js-nonconsuming";

    const devices = await init(); // Starts all available backends.

if (devices.includes("webgpu")) {
  defaultDevice("webgpu");
} else {
  console.warn("WebGPU is not supported, falling back to Wasm.");
}

    You can also place individual arrays on specific devices:

    import { devicePut, numpy as np } from "@hamk-uas/jax-js-nonconsuming";

    using ar = np.array([1, 2, 3]); // Starts with device="wasm"
    await devicePut(ar, "webgpu"); // Now device="webgpu"

    jax-js includes three helper libraries, available as sub-path exports from the main package:

    • optax β€” optimizers (Adam, SGD) and gradient processing
    • loaders β€” Safetensors, BPE tokenizers, OPFS-cached downloads
    • onnx β€” load ONNX models into native jax-js functions
    import { adam } from "@hamk-uas/jax-js-nonconsuming/optax";
    import { cachedFetch, safetensors } from "@hamk-uas/jax-js-nonconsuming/loaders";
    import { ONNXModel } from "@hamk-uas/jax-js-nonconsuming/onnx";

    No extra install needed β€” they're included when you install @hamk-uas/jax-js-nonconsuming. The loaders and onnx sub-packages have optional native dependencies (@bufbuild/protobuf, sentencepiece-buf, onnx-buf) that are listed in the main package's optionalDependencies and installed automatically if available for your platform.

To see per-kernel traces in browser developer tools, call jax.profiler.startTrace().

The WebGPU runtime includes an ML compiler with tile-aware optimizations, tuned for individual browsers. This library is also unique among web ML runtimes in offering jit(), which fuses operations together and records an execution graph. jax-js achieves over 7000 GFLOP/s for matrix multiplication on an Apple M4 Max chip (try it).

    In that specific benchmark run, it was faster than both TensorFlow.js and ONNX Runtime Web, which both use handwritten libraries of custom kernels. Results vary by model, operator mix, and hardware.

    It's still early though. There's a lot of low-hanging fruit to continue optimizing the library, as well as unique optimizations such as FlashAttention variants.

    That's all for this short tutorial. Please see the generated API reference for detailed documentation.

    The following technical details are for contributing to jax-js-nonconsuming and modifying its internals.

    This repository is managed by pnpm. You can compile and build all packages in watch mode with:

    pnpm install
    pnpm run build:watch

    The pnpm install command automatically sets up Git hooks via Husky. This repository intentionally shifts verification left: we run the necessary quality and release-safety checks in local pre-commit rather than relying on heavy CI gates for GitHub Pages deployment confidence.

You can run linting, formatting, and related checks iteratively while developing:

pnpm lint                        # Run ESLint (includes @hamk-uas/eslint-plugin-jax-js ownership rules)
pnpm format                      # Format all files with Prettier
pnpm format:check                # Check formatting without writing
pnpm check                       # Run TypeScript type checking
pnpm run test:eslint-plugin      # Rule-level tests for in-repo ESLint plugin
pnpm run lint:ownership:internal # Maintainer transform ownership checks

Then you can run tests in a headless browser using Vitest:

    pnpm exec playwright install
    pnpm test
    pnpm run test:policy:strict # Strict mode: no expected-failure debt
    pnpm run test:arch # Architectural mode: failures gated by manifest
    pnpm run test:website:smoke # Website build + smoke checks

    Tests run in headless Chromium via Playwright with full WebGPU support. This requires a working Vulkan GPU driver on the host. Verify with:

    vulkaninfo --summary   # Should show your GPU device
    

    Key requirements:

    1. Vulkan driver β€” the headless Chrome flags in vitest.config.ts use --use-angle=vulkan to route WebGPU through the host GPU. Without a working Vulkan driver, WebGPU tests will be skipped.
    2. Playwright browsers β€” pnpm exec playwright install downloads a bundled Chromium that includes the Dawn WebGPU implementation.
    3. Secure context β€” WebGPU requires isSecureContext === true. Vitest's dev server serves on localhost, which qualifies automatically. Direct about:blank navigation would not.
    4. Display server β€” even in headless mode, Chrome needs access to a display for GPU initialization. On Linux, ensure DISPLAY is set (X11) or a Wayland session is active. The config passes DISPLAY and XAUTHORITY from the environment.
    5. No root / no sandbox β€” the --no-sandbox flag avoids permission issues in containers or non-root environments.

    The full set of Chrome flags used (configured in vitest.config.ts):

    --no-sandbox --headless=new --use-angle=vulkan --enable-features=Vulkan
    --disable-vulkan-surface --enable-unsafe-webgpu

    Troubleshooting:

    • If WebGPU tests are skipped, check vulkaninfo --summary and ensure your GPU driver is installed.
    • On headless servers (no display), use Xvfb or a virtual framebuffer.
    • COOP/COEP headers (required for SharedArrayBuffer) are configured in vitest.config.ts.

    Libraries that depend on @hamk-uas/jax-js-nonconsuming (e.g. dlm-js) can reuse the same headless WebGPU setup. Install the required dev dependencies:

    pnpm add -D vitest @vitest/browser-playwright playwright
    pnpm exec playwright install

    Then create a vitest.config.ts with the same Chromium flags and COOP/COEP headers:

import { defineConfig } from "vitest/config";
import { playwright } from "@vitest/browser-playwright";

export default defineConfig({
  server: {
    headers: {
      "Cross-Origin-Embedder-Policy": "require-corp",
      "Cross-Origin-Opener-Policy": "same-origin",
    },
  },
  test: {
    browser: {
      enabled: true,
      headless: true,
      provider: playwright({
        launchOptions: {
          args: [
            "--no-sandbox",
            "--headless=new",
            "--use-angle=vulkan",
            "--enable-features=Vulkan",
            "--disable-vulkan-surface",
            "--enable-unsafe-webgpu",
          ],
          env: {
            DISPLAY: process.env.DISPLAY ?? ":0",
            XAUTHORITY:
              process.env.XAUTHORITY ?? `/run/user/${process.getuid?.() ?? 1000}/gdm/Xauthority`,
          },
        },
      }),
      instances: [{ browser: "chromium" }],
    },
  },
});

    The Chrome flags route WebGPU through your host Vulkan driver in headless mode.

    Browsers gate SharedArrayBuffer behind cross-origin isolation. Without it, the WASM worker pool cannot dispatch in parallel and several WebGPU features are unavailable. This applies to production deployments, not just tests.

    Your server must send two HTTP response headers on every page that uses jax-js:

    Header Value
    Cross-Origin-Embedder-Policy require-corp
    Cross-Origin-Opener-Policy same-origin

    The Vitest config above already sets these for the dev server. For production, configure your hosting platform. Examples:

    Vercel (vercel.json β€” this repo's own config):

{
  "headers": [
    {
      "source": "/(.*)",
      "headers": [
        { "key": "Cross-Origin-Embedder-Policy", "value": "require-corp" },
        { "key": "Cross-Origin-Opener-Policy", "value": "same-origin" }
      ]
    }
  ]
}

    Nginx:

    add_header Cross-Origin-Embedder-Policy "require-corp" always;
    add_header Cross-Origin-Opener-Policy   "same-origin"  always;
    

    Netlify (_headers):

/*
  Cross-Origin-Embedder-Policy: require-corp
  Cross-Origin-Opener-Policy: same-origin
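Node.js (a minimal sketch using the built-in http module, e.g. for local testing; the port and static response body are illustrative):

```javascript
import http from "node:http";

// Every response that serves a page using jax-js must carry both headers
// to enable cross-origin isolation (and thus SharedArrayBuffer).
const server = http.createServer((req, res) => {
  res.setHeader("Cross-Origin-Embedder-Policy", "require-corp");
  res.setHeader("Cross-Origin-Opener-Policy", "same-origin");
  res.setHeader("Content-Type", "text/html");
  res.end("<!doctype html><title>jax-js app</title>");
});

server.listen(8080, () => console.log("http://localhost:8080"));
```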

    Side effect: require-corp means every sub-resource (image, script, iframe) must either be same-origin or served with a Cross-Origin-Resource-Policy: cross-origin header. If your page loads third-party assets without that header, they will be blocked. See the MDN guide for details.

    Architectural mode is intended for large refactors and uses .ci/expected-failures.json as an explicit, expiring debt ledger. See docs/testing-policy.md for workflow details.

    For maintainer-only transform ownership checks in framework internals, run:

    pnpm run lint:ownership:internal
    

    For website ownership checks used by demos/repl:

    pnpm run lint:ownership:website
    

    Pre-commit is branch-aware and runs via scripts/precommit.sh.

    • Feature profile (default on non-main branches):
      • build, check, lint, format:check
      • test:eslint-plugin
      • lint:ownership:internal, lint:ownership:website
      • core invariants: vitest run test/refcount.test.ts test/transform-compositions.test.ts
    • Full profile (default on main, master, release/*, hotfix/*):
      • everything in feature profile
      • test:policy:strict
      • test:website:smoke

    This keeps day-to-day feature iteration fast while enforcing release-grade checks when committing on main/release branches.

    For large refactors with explicit, expiring expected-failure debt, use architectural mode:

    JAX_ARCH_MODE=1 git commit -m "your message"
    

    Architectural mode still enforces strict core invariant suites before applying manifest-based failure accounting. See docs/testing-policy.md for workflow details.

    You can override profile selection explicitly:

    JAX_PRECOMMIT_PROFILE=feature git commit -m "..."
    JAX_PRECOMMIT_PROFILE=full git commit -m "..."

    Before merging to main, run one commit (or dry run) with full profile:

    JAX_PRECOMMIT_PROFILE=full git commit -m "chore: pre-merge full local checks"
    

    Inspiration from hamk-uas/eslint-plugin-jax-js: keep hooks explicit and developer-visible, and keep maintainer release checks documented and reproducible.

We currently stay on an older Playwright version whose bundled Chromium supports WebGPU in headless mode; with newer Playwright versions, the WebGPU tests are skipped.

    To start a Vite dev server running the website, demos and REPL:

    pnpm -C website dev
    

    This section is for maintainers who create releases.

    # 1. Make sure all checks pass
    pnpm build
    pnpm check
    pnpm run test:policy:strict
    pnpm run test:website:smoke
    pnpm run test:eslint-plugin
    pnpm run lint:ownership:internal
    pnpm run lint:ownership:website

    # Or equivalently (same full profile as main-branch pre-commit):
    JAX_PRECOMMIT_PROFILE=full scripts/precommit.sh

    # 2. Bump version, commit, and tag in one step.
    # Choose patch / minor / major β€” see Version numbering below.
    # The -m flag sets the commit message; %s is replaced with the new version.
    npm version patch -m "chore(release): v%s"

    # 3. Push the commit and tag
    git push && git push --tags

    # 4. Create a GitHub release via the gh CLI.
    # --title: short phrase after the version number, e.g. "fix grad(solve)" or "add lax.foo"
    # --notes: write a human-readable description (see format below) β€” required, don't skip this
    gh release create "v$(node -p "require('./package.json').version")" \
    --title "v$(node -p "require('./package.json').version") β€” <short description>" \
    --notes "## <Category>: <what changed>

    ### What was wrong / motivation
    <one paragraph>

    ### What changed
    - **\`src/...\`** β€” description

    ### Upgrade
    \`\`\`bash
    npm install github:hamk-uas/jax-js-nonconsuming#v$(node -p "require('./package.json').version")
    \`\`\`

    **Full Changelog**: https://github.com/hamk-uas/jax-js-nonconsuming/compare/vPREV...v$(node -p "require('./package.json').version")"

    Users install specific tags, so after releasing they can upgrade with:

    npm install github:hamk-uas/jax-js-nonconsuming#v0.2.1
    
    Change type Bump Example
    Documentation only (README, comments, copilot-instructions) none Users on main get it automatically
    Bug fix, precision improvement, internal refactor patch Kahan summation, ownership fix, test improvements
    New jax-js/NumPy ops added to API surface patch New numpy.foo() function
    New public API or feature (transform, backend capability) minor lax.scan, buffer recycling, new ESLint rule
    Breaking API or ownership-model behavior change major Removing a public function, changing dispose rules

    This fork uses independent semver β€” it does not mirror upstream ekzhang/jax-js tags. Track upstream compatibility in release notes, and choose bump level by user-visible impact in this fork. When rebasing/syncing from upstream, the bump level depends on what user-facing changes come along.

    For simple bug-fix PRs (the common case):

    1. Merge the PR to main.
    2. Version, tag, push, and create the GitHub release β€” bug fixes are always a patch bump. Follow the releasing steps above.

    Contributions are welcome! Please open issues and PRs on this repository for topics specific to the non-consuming fork, such as:

    • The non-consuming ownership model and using-based patterns
    • The @hamk-uas/eslint-plugin-jax-js linter
    • lax.scan, buffer recycling, checkLeaks, or other fork-only features
    • Documentation, demos, or examples specific to this fork

    For feature requests or bugs that apply to both branches (e.g., new NumPy/JAX ops, backend improvements, core tracing), please file them upstream at ekzhang/jax-js instead. This avoids duplicate work and ensures fixes land in both codebases.

    Upstream sync policy: We may periodically rebase onto upstream to pick up new features and fixes, but there is no guarantee of continuous updates. Maintenance debt can accumulate across projects, and this fork may be reduced in scope or paused if priorities shift. Upstream jax-js continues independently.

    Before submitting a PR, run the full CI checks locally:

    pnpm build && pnpm check && pnpm run test:policy:strict && pnpm run test:website:smoke && pnpm run test:eslint-plugin && pnpm run lint:ownership:internal && pnpm run lint:ownership:website
    

    This fork replaces the upstream move-semantics ownership model with a non-consuming model. Outside ownership semantics and fork-specific features, the APIs are broadly aligned for common NumPy/JAX usage (jit, grad, vmap, backends, demos), with some intentional divergence (for example scan, checkLeaks, and ownership tooling).

    Aspect Upstream (ekzhang/jax-js) This fork (non-consuming)
    Ownership model Move semantics Non-consuming
    Operations consume inputs? Yes β€” every op decrements refcount No β€” inputs stay alive
    .ref needed to reuse arrays? Yes β€” x.ref before passing to a second op Not in user code
    UseAfterFreeError risk Common if .ref is forgotten Gone for reuse; still possible after explicit .dispose()
    using declarations Not used First-class β€” auto-dispose at block end
    ESLint plugin @hamk-uas/eslint-plugin-jax-js (move) @hamk-uas/eslint-plugin-jax-js (non-consuming)
    lax.scan Not implemented Full support (JIT, autodiff, vmap, native compilation)
    lax.associativeScan Not implemented Kogge-Stone parallel prefix scan; pytrees, any axis, reverse, autodiff. Eager mode routes through a cached whole-call jit wrapper (outside abstract tracing); explicit jit(...) still gives maximum throughput and manual cache lifetime control.
    Buffer recycling Not implemented JIT-level recycle step + WebGPU buffer pool
    tree.makeDisposable Not available Wraps any object for using-based cleanup
    Array.consumeData() Not available Reads data and disposes in one call
    checkLeaks diagnostic Not available Runtime leak detection with stack traces

    The non-consuming model makes some things easier and other things harder. Here are the real costs:

    Important context: JIT neutralizes most intermediate-lifetime differences. Under jit(), both ownership models compile to the same Jaxpr graph, and the compiler derives buffer lifetimes from that graph β€” not from the ownership model. Intermediates and buffers are freed at exact last use; in many common workloads this makes peak-memory behavior very similar across models. The tradeoffs below apply primarily to eager mode β€” the mode you use for debugging, prototyping, and the REPL. Production hot paths should be JIT-wrapped anyway for performance (kernel fusion), which narrows the practical ownership-model gap to JIT boundaries (who disposes inputs, outputs, and the jit function itself) and any code that runs outside jit().

    Silent leaks replace noisy crashes. Move semantics crash immediately (UseAfterFreeError) when you forget .ref β€” painful, but the error points straight at the bug. The non-consuming model never crashes from reuse, but a missing .dispose() leaks GPU memory silently. For the common case β€” block-scoped arrays β€” using declarations prevent this: arrays are disposed automatically at block exit. But using can't help everywhere: method chains create anonymous intermediates that nobody names (and thus nobody disposes), nested results from scan/grad need tree.dispose(), and loop-carried state or arrays stored in caches require manual discipline. The checkLeaks diagnostic (built into the test suite so every test is leak-checked) and the ESLint plugin catch many of these cases, but they are developer tools, not a runtime safety net. (Move semantics can also leak β€” e.g. an over-.ref'd array or a retained reference β€” but the fail-fast default for reuse bugs makes those easier to spot.)

    Higher peak memory in eager mode. Expression chains like x.mul(y).add(z).sub(w) create intermediate arrays that linger until GC or explicit disposal. With move semantics, each intermediate is freed as soon as the next operation consumes it. In the non-consuming model, all intermediates stay alive simultaneously β€” for large tensors this can significantly increase peak memory (the exact factor depends on chain length and tensor size). Breaking chains into using temporaries solves this (intermediates are disposed at block exit), but the code is more verbose than the NumPy equivalent. Under jit(), both models free intermediates at the optimal point β€” this is purely an eager-mode difference. But eager mode is where you debug, and debugging with higher memory footprint is a real obstacle.

    JavaScript GC doesn't know about GPU memory. The JS garbage collector tracks JS heap pressure, not the 4 GB of VRAM on your GPU. A leaked 512Γ—512 f32 buffer is 1 MB of GPU memory but only ~64 bytes of JS heap. GC may never run. FinalizationRegistry is too slow and unpredictable to rely on. This problem affects both ownership models β€” any jax-js program must eventually free GPU buffers explicitly. The non-consuming model simply makes it easier to forget, because nothing crashes when you do.

    Method chains become a pain point in eager mode. a.mul(b).add(c).div(d) is natural in NumPy. In the non-consuming model, each .method() allocates a new GPU buffer. The fix is using declarations, but they require separate statements β€” one per intermediate. Under jit(), these chains produce tracers (not real GPU buffers) and the compiler manages everything β€” so the memory cost only appears in eager code. Still, eager code is where you prototype and learn the API:

    // ❌ Leaks two intermediate GPU buffers in eager mode:
    const result = a.mul(b).add(c).div(d);

    // βœ… Correct, but more verbose than the NumPy equivalent:
    using t1 = a.mul(b);
    using t2 = t1.add(c);
    const result = t2.div(d);

    // Under jit(), intermediates are tracers (not real buffers), so the chain
    // doesn't leak GPU memory in practice. But write the second form anyway β€”
    // code should be ownership-correct in both eager and jit mode.

    using has ecosystem gaps. The TC39 Explicit Resource Management proposal is not yet supported everywhere β€” Svelte's parser can't handle using in .svelte files, and older bundlers may need transpilation. A polyfill is included, but it adds friction.
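For environments where using isn't parseable, the underlying protocol still is: using relies on the resource exposing [Symbol.dispose], and a try/finally block is the manual equivalent. A minimal sketch (FakeBuffer is a hypothetical stand-in for a jax-js Array):

```javascript
// What `using` desugars to, roughly: call [Symbol.dispose]() at block exit.
class FakeBuffer {
  disposed = false;
  dispose() {
    this.disposed = true; // a real Array would free its GPU/WASM buffer here
  }
  [Symbol.dispose]() {
    this.dispose();
  }
}

// Manual fallback for `using buf = new FakeBuffer()` where `using` is
// unsupported (e.g. a parser that can't handle the syntax yet):
function withBuffer(fn) {
  const buf = new FakeBuffer();
  try {
    return fn(buf);
  } finally {
    buf.dispose(); // runs at scope exit, like `using`
  }
}
```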

    More tooling required for edge cases. using handles the most common pattern (block-scoped arrays) at the language level. But for patterns it doesn't cover β€” method chains, pytree results, loop-carried state, long-lived closures β€” the non-consuming model leans on voluntary tooling: the ESLint plugin for static analysis and checkLeaks for runtime detection. Move semantics fail fast for reuse mistakes, but have their own blind spots (over-.ref, retained references, forgotten vjpFn.dispose()) that also need tooling and discipline.

    Neither model is free. Move semantics pay with UseAfterFreeError bugs, .ref boilerplate, and their own leak surfaces (over-ref, retained refs). The non-consuming model eliminates those costs but introduces its own: silent leaks for patterns that using can't cover, higher eager-mode memory for unchained intermediates, and reliance on checkLeaks/ESLint for the gaps. Under jit(), the two models converge β€” the compiled programs are identical. Both models need discipline; they just fail in different ways. This fork bets that using-by-default plus opt-in tooling is easier to manage for teams coming from Python/MATLAB β€” but it is a genuine tradeoff, not a free lunch.

    • Use this fork if you want a simpler ownership model where arrays can be freely reused, using declarations handle cleanup, and lax.scan is available.
    • Use upstream if you prefer fail-fast ownership enforcement (crashes over silent leaks), are already invested in the move-semantics model and the @hamk-uas ESLint plugin, or if you need to stay on the upstream release cadence.

    The two versions are not drop-in interchangeable β€” ownership patterns from one model can behave incorrectly or awkwardly in the other, especially around .ref and disposal discipline. The @hamk-uas/eslint-plugin-jax-js included here enforces the non-consuming patterns and will flag .ref usage as unnecessary.

    1. Remove all .ref calls β€” operations no longer consume inputs.
    2. Replace manual refcount juggling with using β€” using x = np.array(...) auto-disposes at block end.
    3. Call .dispose() explicitly for long-lived arrays β€” or wrap in tree.makeDisposable().
    4. Install @hamk-uas/eslint-plugin-jax-js β€” it catches leaks, use-after-dispose, and unnecessary .ref at edit time. See the plugin README for setup.
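As a before/after sketch of steps 1 to 3 (the upstream snippet is shown as comments and is illustrative of move-semantics style, not verbatim upstream code):

```typescript
import { numpy as np } from "@hamk-uas/jax-js-nonconsuming";

// Upstream (move semantics): ops consume inputs, so reuse needs .ref.
//   const x = np.array([1, 2, 3]);
//   const y = x.ref.add(x); // .ref keeps x alive past this op
//   const z = y.mul(x);     // consumes y and x

// This fork (non-consuming): drop .ref and let `using` dispose at block end.
{
  using x = np.array([1, 2, 3]);
  using y = x.add(x); // x stays alive
  using z = y.mul(x); // reuse x freely; x, y, z disposed at block exit
}
```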

    This fork is developed primarily using AI coding agents (GitHub Copilot, Claude, GPT, Gemini) with gentle human supervision. All changes go through the full CI pipeline (pnpm test, pnpm check) and the pre-commit hook runs the complete test suite before every commit.