Create a function with a custom vector-Jacobian product (VJP) rule.
This is the jax-js equivalent of JAX's jax.custom_vjp. It lets you
define exactly how reverse-mode differentiation should behave for a
function, which is essential for implicit differentiation, numerically
stable gradients, and memory-efficient backpropagation.
fwd — Forward function. Called with the same arguments as the
original function. Must return [outputs, residuals], where
residuals is any pytree of values needed by the backward
pass. The outputs are returned to the caller.
bwd — Backward function. Called with (residuals, cotangents), where
residuals is the second element returned by fwd and
cotangents has the same pytree structure as outputs.
Must return cotangents (an array or pytree) with the same
structure as the positional arguments to fwd; a minimal
sketch of this contract follows the return description below.
Returns — A function with the same signature as the original,
(...args) => outputs. When it is differentiated via vjp(), grad(),
or valueAndGrad(), the custom bwd function is invoked instead of
the standard autodiff backward pass.
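A minimal sketch of the fwd/bwd contract, using a throwaway squaring function (illustrative only; it relies on the same np and customVjp API used in the examples below):

// fwd returns [outputs, residuals]; bwd maps (residuals, cotangents)
// to cotangents for the inputs.
const square = customVjp(
  (x: np.Array) => {
    const out = np.multiply(x, x);
    return [out, x]; // residual: the input, needed for d(x^2)/dx = 2x
  },
  (x: np.Array, g: np.Array) => {
    return np.multiply(g, np.multiply(np.array(2), x)); // g * 2x
  },
);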
// Numerically stable log1pexp with a custom gradient
const log1pexp = customVjp(
  // fwd: compute the output and save residuals for the backward pass
  (x: np.Array) => {
    const out = np.log(np.add(np.array(1), np.exp(x)));
    return [out, out]; // save the output as the residual
  },
  // bwd: compute the gradient from the saved residual
  (out: np.Array, g: np.Array) => {
    // d/dx log(1 + exp(x)) = 1 - exp(-out), where out = log(1 + exp(x))
    return np.multiply(g, np.subtract(np.array(1), np.exp(np.negative(out))));
  },
);

const dx = grad(log1pexp)(np.array(100.0));
// → 1.0 (correct); naive autodiff of log(1 + exp(100)) gives NaN
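The same custom rule is picked up by the other reverse-mode entry points. A brief usage sketch, assuming valueAndGrad returns a [value, gradient] pair (an assumption modeled on JAX's value_and_grad; the exact return shape in jax-js may differ):

// Hypothetical usage: value and gradient come back together, and the
// gradient flows through the custom bwd rule defined above.
const [y, dy] = valueAndGrad(log1pexp)(np.array(3.0));
// y ≈ 3.049 (log(1 + e^3)), dy ≈ 0.953 (sigmoid(3))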
// Implicit differentiation for an optimizer
const solve = customVjp(
  (params: np.Array) => {
    const solution = runOptimizer(params); // opaque inner solve
    return [solution, { solution, params }];
  },
  (res, g) => {
    // Use the implicit function theorem: dx_star/dtheta = -(dF/dx)^-1 (dF/dtheta)
    const { solution, params } = res;
    const implicitGrad = computeImplicitGrad(solution, params, g);
    return implicitGrad;
  },
);
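One more sketch, for the memory-efficient backpropagation use case named in the description: the backward pass can recompute a large intermediate from a small residual instead of storing it. This assumes np.sin and np.cos exist in the np namespace; the recompute pattern, not the specific ops, is what matters here.

// Save only the small input x; recompute the big intermediate in bwd.
const rematerialized = customVjp(
  (x: np.Array) => {
    const inner = np.exp(x);   // pretend this intermediate is huge
    return [np.sin(inner), x]; // residual: just x, not inner
  },
  (x: np.Array, g: np.Array) => {
    const inner = np.exp(x);   // recomputed here instead of stored
    // d/dx sin(exp(x)) = cos(exp(x)) * exp(x)
    return np.multiply(g, np.multiply(np.cos(inner), inner));
  },
);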