unxt.experimental 🧪


Experimental features.

Warning

These features may be removed or changed in the future without notice.

On some occasions JAX's automatic differentiation (AD) functions do not work well with quantities. This is detected by enabling runtime type-checking (see the docs), which raises an error if a quantity's units do not match the expected input or output units of a function. In these cases, you can use the functions in this module to provide the units to the AD functions explicitly. Instead of propagating the units directly through the AD functions, these functions strip the units, differentiate, and re-apply the units afterwards, while also providing the units within the function being differentiated.
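The strip-and-reapply idea can be sketched in plain JAX. This is a minimal illustration, not unxt's implementation: the unit representation (a dict of base-unit exponents) and the helpers `unit_div` and `strip_and_grad` are hypothetical and exist only for this sketch. The AD transformation sees plain arrays, and the unit of the derivative is computed separately as output unit divided by input unit.

```python
import jax

# Hypothetical unit representation for this sketch only: a dict mapping
# base-unit symbols to integer exponents, e.g. {"m": 3} for cubic metres.
# This is NOT how unxt represents units.


def unit_div(num, den):
    """Unit of num / den, dropping zero exponents."""
    out = dict(num)
    for sym, exp in den.items():
        out[sym] = out.get(sym, 0) - exp
    return {s: e for s, e in out.items() if e != 0}


def strip_and_grad(fun_bare, in_unit, out_unit):
    """Differentiate a unit-free function, then re-attach units.

    fun_bare takes and returns plain magnitudes expressed in
    in_unit and out_unit, respectively.
    """
    g = jax.grad(fun_bare)  # AD only ever sees plain arrays

    def wrapped(value, unit):
        if unit != in_unit:
            raise ValueError(f"expected {in_unit}, got {unit}")
        magnitude = g(value)
        # A derivative carries units of output / input.
        return magnitude, unit_div(out_unit, in_unit)

    return wrapped


# x**3 with x in metres returns cubic metres.
grad_volume = strip_and_grad(lambda x: x**3, {"m": 1}, {"m": 3})
value, unit = grad_volume(2.0, {"m": 1})
print(float(value), unit)  # 12.0 {'m': 2}
```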

To import this experimental module:

>>> from unxt import experimental

unxt._src.experimental.grad(fun: Callable[[Unpack[Args]], R], argnums: int = 0, *, units: tuple[Unit | UnitBase | CompositeUnit | str, ...])

Gradient of a function with units.

In general, if you can use quax.quaxify(jax.grad(func)) (or the syntactic sugar quaxed.grad(func)), that's the better option! The difference from those functions is how units are supported. quaxify propagates the units directly through the automatic differentiation functions, but sometimes that doesn't work and the units must instead be stripped and re-applied. This function does that, using the units keyword argument.
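Stripping units only makes sense once every input is expressed in one fixed unit, which is the role the declared units play. The plain-JAX sketch below assumes that inputs are converted to the declared unit (metres) before stripping; `TO_METRES`, `cube_volume_bare` and `grad_in_declared_units` are hypothetical names, and units are tracked by hand rather than with unxt.

```python
import jax

# Hypothetical conversion factors to metres, for this sketch only.
TO_METRES = {"m": 1.0, "cm": 0.01, "km": 1000.0}


def cube_volume_bare(x_m):
    # Unit-free body: x_m is a length in metres, the result is in m**3.
    return x_m**3


d_volume_dx = jax.grad(cube_volume_bare)


def grad_in_declared_units(magnitude, unit):
    # Convert to the declared input unit (metres) *before* stripping,
    # so the differentiated function always sees magnitudes in metres.
    x_m = magnitude * TO_METRES[unit]
    return d_volume_dx(x_m), "m2"  # m**3 / m


g_m, unit_m = grad_in_declared_units(2.0, "m")
g_cm, unit_cm = grad_in_declared_units(200.0, "cm")
# Both give a magnitude of 12 with unit "m2": the same physical input
# yields the same derivative once it is expressed in metres.
```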

See also

jax.grad

The original JAX gradient function.

Examples

>>> import jax.numpy as jnp
>>> import unxt as u
>>> def cube_volume(x: u.Quantity["length"]) -> u.Quantity["volume"]:
...     return x**3
>>> grad_cube_volume = u.experimental.grad(cube_volume, units=("m",))
>>> grad_cube_volume(u.Q(2.0, "m"))
Quantity(Array(12., dtype=float32, weak_type=True), unit='m2')

Parameters:
Return type:

Callable[[Unpack[TypeVarTuple]], TypeVar(R, bound=AbstractQuantity)]

unxt._src.experimental.hessian(fun: Callable[[Unpack[Args]], R], argnums: int = 0, *, units: tuple[Unit | UnitBase | CompositeUnit | str, ...])

Hessian of a function with units.

In general, if you can use quax.quaxify(jax.hessian(func)) (or the syntactic sugar quaxed.hessian(func)), that's the better option! The difference from those functions is how units are supported. quaxify propagates the units directly through the automatic differentiation functions, but sometimes that doesn't work and the units must instead be stripped and re-applied. This function does that, using the units keyword argument.
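For a Hessian, each entry is a second derivative, so its unit is the output unit divided by the product of two input units. The plain-JAX sketch below illustrates this on a unit-free box-volume function; `box_volume_bare` is a hypothetical example, the units are tracked only in comments, and nothing here uses unxt.

```python
import jax
import jax.numpy as jnp


def box_volume_bare(sides_m):
    # sides_m: three edge lengths in metres; returns the volume in m**3.
    return sides_m[0] * sides_m[1] * sides_m[2]


H = jax.hessian(box_volume_bare)(jnp.array([1.0, 2.0, 3.0]))
# Every entry d2V / (dx_i dx_j) carries units m**3 / (m * m) = m.
# Off-diagonal entries equal the remaining edge length; the diagonal is 0:
# [[0, 3, 2],
#  [3, 0, 1],
#  [2, 1, 0]]
```

For the single-variable cube example in this entry, the same rule gives m**3 / m**2 = m, matching the unit in the example output.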

See also

jax.hessian

The original JAX hessian function.

Examples

>>> import jax.numpy as jnp
>>> import unxt as u
>>> def cube_volume(x: u.Quantity["length"]) -> u.Quantity["volume"]:
...     return x**3
>>> hessian_cube_volume = u.experimental.hessian(cube_volume, units=("m",))
>>> hessian_cube_volume(u.Q(2.0, "m"))
BareQuantity(Array(12., dtype=float32, weak_type=True), unit='m')

Parameters:
Return type:

Callable[[Unpack[TypeVarTuple]], TypeVar(R, bound=AbstractQuantity)]

unxt._src.experimental.jacfwd(fun: Callable[[Unpack[Args]], R], argnums: int = 0, *, units: tuple[Unit | UnitBase | CompositeUnit | str, ...])

Jacobian of fun evaluated column-by-column using forward-mode AD.

In general, if you can use quax.quaxify(jax.jacfwd(func)) (or the syntactic sugar quaxed.jacfwd(func)), that's the better option! The difference from those functions is how units are supported. quaxify propagates the units directly through the automatic differentiation functions, but sometimes that doesn't work and the units must instead be stripped and re-applied. This function does that, using the units keyword argument.
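"Column-by-column using forward-mode AD" means one Jacobian-vector product (JVP) per input direction. The plain-JAX sketch below compares `jax.jacfwd` with an explicit loop of `jax.jvp` calls on standard-basis tangents for a hypothetical unit-free function `cube_volumes_bare`; units are tracked only in comments, and this does not show how unxt wraps `jacfwd`.

```python
import jax
import jax.numpy as jnp


def cube_volumes_bare(sides_m):
    # Element-wise volumes (m**3) of cubes with side lengths sides_m (m).
    return sides_m**3


x = jnp.array([1.0, 2.0, 3.0])

# jax.jacfwd builds the Jacobian with forward-mode AD.
J = jax.jacfwd(cube_volumes_bare)(x)

# The same thing by hand: one JVP per standard-basis tangent e_j
# gives column j of the Jacobian, dV/dx_j.
columns = [jax.jvp(cube_volumes_bare, (x,), (e,))[1] for e in jnp.eye(3)]
J_manual = jnp.stack(columns, axis=1)

# Each entry dV_i/dx_j carries units m**3 / m = m**2. Only the diagonal
# is non-zero here: 3 * x_i**2, i.e. 3, 12 and 27.
```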

See also

jax.jacfwd

The original JAX jacfwd function.

Examples

>>> import jax.numpy as jnp
>>> import unxt as u
>>> def cube_volume(x: u.Quantity["length"]) -> u.Quantity["volume"]:
...     return x**3
>>> jacfwd_cube_volume = u.experimental.jacfwd(cube_volume, units=("m",))
>>> jacfwd_cube_volume(u.Q(2.0, "m"))
BareQuantity(Array(12., dtype=float32, weak_type=True), unit='m2')

Parameters:
Return type:

Callable[[Unpack[TypeVarTuple]], TypeVar(R, bound=AbstractQuantity)]