# Copyright 2025 Pasteur Labs. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""JAX gradient endpoint helpers, used in the JAX recipe.
These functions remove the boilerplate from a JAX-backed Tesseract
``tesseract_api.py`` by providing one-line implementations of the
``apply``, ``jacobian``, ``jacobian_vector_product``,
``vector_jacobian_product`` and ``abstract_eval`` endpoints.
"""
from collections.abc import Callable
from typing import Any
import equinox as eqx
import jax
import jax.numpy as jnp
from pydantic import BaseModel
from tesseract_core.runtime.array_encoding import _fast_tobytes
from tesseract_core.runtime.tree_transforms import (
LRUCache,
filter_func,
flatten_with_paths,
set_at_path,
)
# Module-level VJP cache shared by ``jax_apply`` and ``jax_vjp``.
# ``None`` means caching is disabled. Otherwise entries are keyed by
# ``_hash_tree(inputs.model_dump())`` and hold a
# ``(vjp_func, cotangent_template)`` pair written by ``jax_apply``.
_jax_vjp_cache: LRUCache | None = None
def _set_jax_vjp_cache_size(size: int) -> None:
    """Replace the module-level VJP cache with a fresh one of the given size.

    Any previously cached entries are dropped because the cache object itself
    is replaced. A ``size`` of zero or less disables caching entirely.

    Public entry point: :func:`tesseract_core.runtime.experimental.set_jax_vjp_cache_size`.

    Args:
        size: Maximum number of entries to keep; non-positive disables the cache.
    """
    global _jax_vjp_cache
    if size > 0:
        _jax_vjp_cache = LRUCache(maxsize=size)
    else:
        _jax_vjp_cache = None
def _hash_tree(tree: Any) -> int:
    """Hash a pytree by its structure together with the content of its leaves.

    The result is intended as an :class:`LRUCache` key. Leaves exposing a
    ``tobytes`` attribute are treated as arrays and contribute
    ``(dtype.str, shape, raw bytes)``, so two leaves with the same bytes but a
    different interpretation (e.g. ``int64[4]`` vs ``int64[2,2]``) hash
    differently. Every other leaf is hashed as-is and must therefore be
    hashable.

    Args:
        tree: Arbitrary pytree to fingerprint.

    Returns:
        The Python ``hash`` of the tuple ``(str(treedef), *leaf_keys)``.
    """
    leaves, treedef = jax.tree.flatten(tree)

    def _leaf_key(leaf: Any) -> Any:
        # Non-array leaves stand in for themselves.
        if not hasattr(leaf, "tobytes"):
            return leaf
        return (leaf.dtype.str, leaf.shape, bytes(_fast_tobytes(leaf)))

    # jax.PyTreeDef's __hash__ collides on dicts with different keys, so the
    # string form of the treedef is used as the structural discriminator.
    return hash((str(treedef), *map(_leaf_key, leaves)))
[docs]
def jax_apply(apply_jit: Callable, inputs: BaseModel) -> dict:
    """Run ``apply_jit`` and, if caching is enabled, populate the VJP cache.

    ``apply_jit`` is assumed to already be JIT-compiled (e.g. wrapped with
    ``@eqx.filter_jit``); this helper does not jit it. The user-facing
    ``apply`` endpoint may want to do pre/post-processing around the call,
    so we cannot wrap it in a jit internally.

    When :data:`_jax_vjp_cache` is set (see
    :func:`tesseract_core.runtime.experimental.set_jax_vjp_cache_size`),
    the forward pass is run via ``jax.vjp`` so the resulting backward
    function can be stashed and reused by a later :func:`jax_vjp` call.
    Otherwise this is just ``apply_jit(inputs.model_dump())``.

    Only input leaves that can be interpreted as JAX arrays are handed to
    ``jax.vjp`` as primals. Any other leaf (for example a string option) is
    closed over unchanged, so enabling the cache does not make ``apply`` fail
    on inputs that the uncached path accepts.

    Args:
        apply_jit: Already-jitted apply function taking the dumped inputs dict.
        inputs: Validated input model; only ``model_dump()`` is used.

    Returns:
        The output pytree of ``apply_jit``.
    """
    inputs_dict = inputs.model_dump()

    # Read the module-level cache once so the ``None`` check and the later
    # ``put`` are guaranteed to refer to the same object.
    cache = _jax_vjp_cache
    if cache is None:
        return apply_jit(inputs_dict)

    # jax.vjp rejects primal leaves that are not valid JAX types, so split
    # those off and feed them back in through the closure below. For inputs
    # made only of arrays / Python scalars this partition is a no-op.
    traced_inputs, closed_over_inputs = eqx.partition(inputs_dict, eqx.is_array_like)

    # Compute forward pass via jax.vjp to cache residuals for a potential
    # subsequent vector_jacobian_product call. eqx.partition separates
    # array (differentiable) from non-array outputs; has_aux tells jax.vjp
    # to only differentiate through the array outputs.
    def _apply_for_vjp(traced: dict) -> tuple:
        out = apply_jit(eqx.combine(traced, closed_over_inputs))
        diff_out, static_out = eqx.partition(out, eqx.is_array)
        return diff_out, static_out

    diff_primals, vjp_func, static_primals = jax.vjp(
        _apply_for_vjp, traced_inputs, has_aux=True
    )
    out = eqx.combine(diff_primals, static_primals)

    # Zero tree with the structure of the differentiable outputs; jax_vjp
    # fills in the requested cotangents on top of it.
    cotangent_template = jax.tree.map(jnp.zeros_like, diff_primals)
    cache.put(_hash_tree(inputs_dict), (vjp_func, cotangent_template))
    return out
[docs]
def jax_vjp(
    apply_jit: Callable,
    inputs: BaseModel,
    vjp_inputs: set[str],
    vjp_outputs: set[str],
    cotangent_vector: dict[str, Any],
) -> dict[str, Any]:
    """Compute the vector-Jacobian product.

    Reuses the cached backward from a prior :func:`jax_apply` call when one is
    available (see
    :func:`tesseract_core.runtime.experimental.set_jax_vjp_cache_size`);
    otherwise falls through to a JIT-compiled ``jax.vjp`` evaluation. The JIT
    compilation happens internally on the first miss for a given
    (input shape/dtype, path subset) combination and is cached for reuse.

    The path sets are sorted before being passed to the jitted fallback. They
    are non-array (hence static) arguments there, and ``tuple(set)`` order is
    not guaranteed to be the same for two equal sets, which could otherwise
    trigger a recompile for a combination that was already compiled.

    Args:
        apply_jit: Already-jitted apply function taking the dumped inputs dict.
        inputs: Validated input model; only ``model_dump()`` is used.
        vjp_inputs: Paths of the inputs to differentiate with respect to.
        vjp_outputs: Paths of the outputs the cotangent refers to. Only used
            on the fallback path.
        cotangent_vector: Cotangent values keyed by output path.

    Returns:
        The input cotangents restricted to ``vjp_inputs``.
    """
    inputs_dict = inputs.model_dump()

    # Read the module-level cache once so the check and the lookup use the
    # same object.
    cache = _jax_vjp_cache

    # Use get (not pop) so the cached residuals can serve multiple sequential
    # vjp calls on the same inputs -- for example, when tesseract-jax's
    # value_and_grad is followed by jax.jacrev, which decomposes into many
    # vjp calls per output basis vector.
    cached = cache.get(_hash_tree(inputs_dict)) if cache is not None else None
    if cached is not None:
        vjp_func, cotangent_template = cached
        # Start from a fresh zero tree on every call so the object stored in
        # the cache is never the one handed to set_at_path.
        full_cotangent = jax.tree.map(jnp.zeros_like, cotangent_template)
        full_cotangent = set_at_path(full_cotangent, cotangent_vector)
        (all_input_cotangents,) = vjp_func(full_cotangent)
        return flatten_with_paths(all_input_cotangents, include_paths=vjp_inputs)

    # Cache disabled or cache miss: fall back to original JIT-compiled path.
    return _vjp_jit(
        apply_jit,
        inputs_dict,
        tuple(sorted(vjp_inputs)),
        tuple(sorted(vjp_outputs)),
        cotangent_vector,
    )
[docs]
def jax_jvp(
    apply_jit: Callable,
    inputs: BaseModel,
    jvp_inputs: set[str],
    jvp_outputs: set[str],
    tangent_vector: dict[str, Any],
) -> dict[str, Any]:
    """Compute the Jacobian-vector product via :func:`jax.jvp`.

    JIT compilation is applied internally and cached per
    ``(input shape/dtype, jvp_inputs, jvp_outputs)`` combination.

    The path sets are sorted before being turned into tuples. They are
    non-array (hence static) arguments of the jitted helper, and
    ``tuple(set)`` order is not guaranteed to be the same for two equal sets,
    which could otherwise trigger a recompile for an already-compiled
    combination.

    Args:
        apply_jit: Already-jitted apply function taking the dumped inputs dict.
        inputs: Validated input model; only ``model_dump()`` is used.
        jvp_inputs: Paths of the inputs the tangent refers to.
        jvp_outputs: Paths of the outputs to propagate the tangent to.
        tangent_vector: Tangent values keyed by input path.

    Returns:
        The output tangents as returned by ``jax.jvp``.
    """
    return _jvp_jit(
        apply_jit,
        inputs.model_dump(),
        tuple(sorted(jvp_inputs)),
        tuple(sorted(jvp_outputs)),
        tangent_vector,
    )
[docs]
def jax_jacobian(
    apply_jit: Callable,
    inputs: BaseModel,
    jac_inputs: set[str],
    jac_outputs: set[str],
) -> dict[str, dict[str, Any]]:
    """Compute the Jacobian via :func:`jax.jacrev`.

    JIT compilation is applied internally and cached per
    ``(input shape/dtype, jac_inputs, jac_outputs)`` combination.

    The path sets are sorted before being turned into tuples. They are
    non-array (hence static) arguments of the jitted helper, and
    ``tuple(set)`` order is not guaranteed to be the same for two equal sets,
    which could otherwise trigger a recompile for an already-compiled
    combination.

    Args:
        apply_jit: Already-jitted apply function taking the dumped inputs dict.
        inputs: Validated input model; only ``model_dump()`` is used.
        jac_inputs: Paths of the inputs to differentiate with respect to.
        jac_outputs: Paths of the outputs to differentiate.

    Returns:
        The Jacobian pytree produced by ``jax.jacrev``.
    """
    return _jac_jit(
        apply_jit,
        inputs.model_dump(),
        tuple(sorted(jac_inputs)),
        tuple(sorted(jac_outputs)),
    )
[docs]
def jax_abstract_eval(apply_jit: Callable, abstract_inputs: BaseModel) -> dict:
    """Derive the output shapes/dtypes of ``apply_jit`` from abstract inputs.

    Every plain ``dict`` in the dumped inputs whose keys are exactly
    ``{"shape", "dtype"}`` is converted to a :class:`jax.ShapeDtypeStruct`.
    Those structs are passed through :func:`jax.eval_shape`, while all
    remaining leaves are closed over unchanged. Structs in the result are
    converted back to ``{"shape": ..., "dtype": ...}`` dicts.

    Args:
        apply_jit: Apply function taking the dumped inputs dict.
        abstract_inputs: Model whose array fields are given as shape/dtype dicts.

    Returns:
        The output pytree with each array replaced by a shape/dtype dict.
    """

    def _is_spec_dict(node: Any) -> bool:
        # Exact ``dict`` type on purpose: subclasses are not treated as specs.
        return type(node) is dict and (node.keys() == {"shape", "dtype"})

    def _is_spec_struct(node: Any) -> bool:
        return isinstance(node, jax.ShapeDtypeStruct)

    def _to_struct(node: Any) -> Any:
        if _is_spec_dict(node):
            return jax.ShapeDtypeStruct(**node)
        return node

    def _to_spec_dict(node: Any) -> Any:
        if _is_spec_struct(node):
            return {"shape": node.shape, "dtype": str(node.dtype)}
        return node

    abstract_tree = jax.tree.map(
        _to_struct,
        abstract_inputs.model_dump(),
        is_leaf=_is_spec_dict,
    )
    traced_part, static_part = eqx.partition(
        abstract_tree, filter_spec=_is_spec_struct
    )

    def _call_with_statics(traced: dict) -> dict:
        # Re-attach the non-abstract leaves before calling the user function.
        return apply_jit(eqx.combine(static_part, traced))

    out_shapes = jax.eval_shape(_call_with_statics, traced_part)
    return jax.tree.map(_to_spec_dict, out_shapes, is_leaf=_is_spec_struct)
#
# Internal jit-compiled fallbacks (used on cache miss or when caching is disabled).
#
@eqx.filter_jit
def _jac_jit(
    apply_jit: Callable,
    inputs: dict,
    jac_inputs: tuple[str, ...],
    jac_outputs: tuple[str, ...],
) -> dict:
    """Jitted reverse-mode Jacobian of selected outputs w.r.t. selected inputs.

    Args:
        apply_jit: Apply function taking the full inputs dict.
        inputs: Full dumped inputs.
        jac_inputs: Paths of the inputs to differentiate with respect to.
        jac_outputs: Paths of the outputs to differentiate.

    Returns:
        The result of ``jax.jacrev`` on the filtered apply function, evaluated
        at the inputs selected by ``jac_inputs``.
    """
    restricted_apply = filter_func(apply_jit, inputs, jac_outputs)
    selected_inputs = flatten_with_paths(inputs, include_paths=jac_inputs)
    jacobian_fn = jax.jacrev(restricted_apply)
    return jacobian_fn(selected_inputs)
@eqx.filter_jit
def _jvp_jit(
    apply_jit: Callable,
    inputs: dict,
    jvp_inputs: tuple[str, ...],
    jvp_outputs: tuple[str, ...],
    tangent_vector: dict,
) -> dict:
    """Jitted forward-mode product of the Jacobian with ``tangent_vector``.

    Args:
        apply_jit: Apply function taking the full inputs dict.
        inputs: Full dumped inputs.
        jvp_inputs: Paths of the inputs the tangent refers to.
        jvp_outputs: Paths of the outputs to propagate the tangent to.
        tangent_vector: Tangent values matching the inputs selected by
            ``jvp_inputs``.

    Returns:
        The tangent half of the ``jax.jvp`` result; the primal output is
        discarded.
    """
    restricted_apply = filter_func(apply_jit, inputs, jvp_outputs)
    selected_inputs = flatten_with_paths(inputs, include_paths=jvp_inputs)
    _primals_out, tangents_out = jax.jvp(
        restricted_apply,
        [selected_inputs],
        [tangent_vector],
    )
    return tangents_out
@eqx.filter_jit
def _vjp_jit(
    apply_jit: Callable,
    inputs: dict,
    vjp_inputs: tuple[str, ...],
    vjp_outputs: tuple[str, ...],
    cotangent_vector: dict,
) -> dict:
    """Jitted reverse-mode product of ``cotangent_vector`` with the Jacobian.

    Args:
        apply_jit: Apply function taking the full inputs dict.
        inputs: Full dumped inputs.
        vjp_inputs: Paths of the inputs to differentiate with respect to.
        vjp_outputs: Paths of the outputs the cotangent refers to.
        cotangent_vector: Cotangent values matching the outputs selected by
            ``vjp_outputs``.

    Returns:
        The cotangent for the single primal argument, i.e. the inputs selected
        by ``vjp_inputs``.
    """
    restricted_apply = filter_func(apply_jit, inputs, vjp_outputs)
    selected_inputs = flatten_with_paths(inputs, include_paths=vjp_inputs)
    _primals_out, pullback = jax.vjp(restricted_apply, selected_inputs)
    # One primal argument was given, so the pullback returns a 1-tuple.
    input_cotangents = pullback(cotangent_vector)
    return input_cotangents[0]