jax-js-nonconsuming (Fork)
    customVjp: <F extends (...args: any[]) => any>(
        fwd: (...args: Parameters<F>) => [ReturnType<F>, any],
        bwd: (residuals: any, cotangents: ReturnType<F>) => any,
    ) => (...args: Parameters<F>) => ReturnType<F> = linearizeModule.customVjp

    Create a function with a custom vector-Jacobian product (VJP) rule.

    When the returned function is differentiated via vjp, grad, or valueAndGrad, the supplied bwd function is invoked instead of the standard autodiff backward pass. This enables implicit differentiation, numerically stable gradients, and memory-efficient backpropagation.

    Type Declaration

      • <F extends (...args: any[]) => any>(
            fwd: (...args: Parameters<F>) => [ReturnType<F>, any],
            bwd: (residuals: any, cotangents: ReturnType<F>) => any,
        ): (...args: Parameters<F>) => ReturnType<F>
      • Create a function with a custom vector-Jacobian product (VJP) rule.

        This is the jax-js equivalent of JAX's jax.custom_vjp. It lets you define exactly how reverse-mode differentiation should behave for a function, which is essential for:

        • Implicit differentiation (optimizers, fixed-point solvers)
        • Numerically stable gradients (log-sum-exp, softmax)
        • Memory-efficient backprop (checkpointing, not saving intermediates; see the sketch after the examples below)

        Type Parameters

        • F extends (...args: any[]) => any

        Parameters

        • fwd: (...args: Parameters<F>) => [ReturnType<F>, any]

          Forward function. Called with the same arguments as the original function. Must return [outputs, residuals] where residuals is any pytree of values needed by the backward pass. The outputs are returned to the caller.

        • bwd: (residuals: any, cotangents: ReturnType<F>) => any

          Backward function. Called with (residuals, cotangents) where residuals is the second element returned by fwd, and cotangents has the same pytree structure as outputs. Must return an array (or pytree) of cotangents with the same structure as the positional arguments to fwd.
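          For a function of several arguments, fwd can bundle whatever bwd will need into the residuals, and bwd returns one cotangent per positional argument. A minimal sketch (the mul example below is illustrative, not taken from the library's docs):

        // Residual/cotangent plumbing for a two-argument function: fwd stashes both
        // inputs, bwd returns a two-element array of cotangents (one per argument).
        const mul = customVjp(
          (a: np.Array, b: np.Array) => {
            const out = np.multiply(a, b);
            return [out, [a, b]]; // residuals: both inputs
          },
          ([a, b]: np.Array[], g: np.Array) => {
            return [np.multiply(g, b), np.multiply(g, a)]; // d(out)/da = b, d(out)/db = a
          },
        );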

        Returns (...args: Parameters<F>) => ReturnType<F>

        A function with the same call signature as the original function, returning outputs. When differentiated via vjp(), grad(), or valueAndGrad(), the custom bwd function is invoked instead of the standard autodiff backward pass.

        // Numerically stable log1pexp with custom gradient
        const log1pexp = customVjp(
          // fwd: compute output + save residuals for backward
          (x: np.Array) => {
            const out = np.log(np.add(np.array(1), np.exp(x)));
            return [out, out]; // save the output as the residual: d/dx log(1 + e^x) = 1 - e^(-out)
          },
          // bwd: compute gradient using residuals
          (out: np.Array, g: np.Array) => {
            return np.multiply(g, np.subtract(np.array(1), np.exp(np.negative(out))));
          },
        );

        const dx = grad(log1pexp)(np.array(100.0));
        // → 1.0 (correct), not NaN (what naive autodiff of log(1 + exp(100)) gives)
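
        Calling the wrapped function outside of differentiation behaves like an ordinary call: it returns the outputs from fwd, and the residuals stay internal.

        // Plain (non-differentiated) call: returns the forward outputs only.
        const y = log1pexp(np.array(0.0)); // ≈ log(2) ≈ 0.693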

        // Implicit differentiation for an optimizer
        const solve = customVjp(
          (params: np.Array) => {
            const solution = runOptimizer(params); // opaque solve
            return [solution, { solution, params }];
          },
          (res, g) => {
            // Use the implicit function theorem: dx_star/dtheta = -(dF/dx)^-1 (dF/dtheta)
            const { solution, params } = res;
            const implicitGrad = computeImplicitGrad(solution, params, g);
            return implicitGrad;
          },
        );
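
        The memory-efficient use case from the list above follows the same pattern: save only cheap residuals (such as the inputs) and recompute expensive intermediates inside bwd. A minimal sketch, with np.exp standing in for an expensive intermediate:

        // Checkpointing-style customVjp: the intermediate h is recomputed in bwd
        // rather than stored as a residual.
        const checkpointed = customVjp(
          (x: np.Array) => {
            const h = np.exp(x);           // stand-in for an expensive intermediate
            const out = np.multiply(h, h); // out = exp(2x)
            return [out, x];               // residual: just the input
          },
          (x: np.Array, g: np.Array) => {
            const h = np.exp(x);           // recompute instead of saving
            // d(out)/dx = 2 * exp(2x) = 2 * h * h
            return np.multiply(g, np.multiply(np.array(2), np.multiply(h, h)));
          },
        );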

    fwd: Forward function returning [outputs, residuals].

    bwd: Backward function (residuals, cotangents) => inputCotangents.

    const stableLog1pExp = customVjp(
      (x) => {
        const out = np.log(np.add(np.array(1), np.exp(x)));
        return [out, out]; // save the output as the residual
      },
      (out, g) => np.multiply(g, np.subtract(np.array(1), np.exp(np.negative(out)))),
    );

    const dx = grad(stableLog1pExp)(np.array(100.0)); // 1.0, not NaN
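
    The custom rule also participates in larger differentiated computations; a usage sketch (not from the original docs), where grad flows through an ordinary multiply and into stableLog1pExp's bwd:

    // The custom bwd is used wherever stableLog1pExp appears in the differentiated function.
    const scaled = (x: np.Array) => np.multiply(np.array(2), stableLog1pExp(x));
    const dScaled = grad(scaled)(np.array(100.0)); // ≈ 2.0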