Unxt: The Sharp Bits#

This guide covers common pitfalls and surprising behaviors when working with unxt quantities in JAX. Like JAX itself, unxt has some “sharp bits” — behaviors that might surprise you if you’re coming from NumPy or non-JAX Python scientific computing.

Tip

If you’re new to unxt, start with the Quantity guide first. This guide assumes you’re familiar with basic unxt usage.

Pure Functions and Immutability#

❌ Problem: Trying to Mutate Quantities#

Coming from NumPy or Astropy, you might expect to modify quantities in place:

import jax.numpy as jnp
import unxt as u

# This doesn't work as expected!
q = u.Q([1.0, 2.0, 3.0], "m")
try:
    q[0] = u.Q(5.0, "m")  # ❌ Error or doesn't modify in place
except Exception as e:
    print(f"Error: {e}")

✅ Solution: Use Functional Updates#

Quantities are immutable. Use JAX’s functional update methods:

q = u.Q([1.0, 2.0, 3.0], "m")
new_q = q.at[0].set(u.Q(5.0, "m"))

Or use dataclasses.replace() (or dataclassish.replace()) for more complex updates:

from dataclasses import replace

new_q = replace(q, value=q.value.at[0].set(5.0))
from dataclassish import replace

new_q = replace(q, value=q.value.at[0].set(5.0))

Why? JAX requires pure functions for transformations like jit and grad. Immutability ensures your functions have no side effects.

JAX Control Flow#

❌ Problem: Control Flow on Quantity Values#

JAX control flow requires special handling, independent of unit considerations:

import jax


@jax.jit
def bad_clamp(x: u.Q):
    # ❌ Python if statement with traced values doesn't work
    if x.value > 10.0:
        return u.Q(10.0, x.unit)
    else:
        return x

✅ Solution: Use JAX Control Flow Primitives#

Use jax.lax.cond for traced values, or use jax.numpy.where:

import jax.lax


@jax.jit
def good_clamp(x: u.Q):
    # ✅ Use jax.lax.cond for control flow
    return jax.lax.cond(x.value > 10.0, lambda x: u.Q(10.0, x.unit), lambda x: x, x)


# Or use jax.numpy.where for simple cases
@jax.jit
def clamp_with_where(x: u.Q):
    # ✅ jnp.where works with quantities
    import quaxed.numpy as jnp

    return jnp.where(x.value > 10.0, u.Q(10.0, x.unit), x)

Note: Checking dimensions in control flow is fine because dimensions are static:

@jax.jit
def process(x: u.Q):
    # ✅ This works! Dimension check happens at trace time
    if u.dimension_of(x) == u.dimension("length"):
        return x * 2  # This branch traces
    else:
        return x  # This branch is never traced for length inputs

Operations on Quantities#

❌ Problem: Operating on Quantities with JAX Functions#

Most direct JAX operations don’t work:

import jax.numpy as jnp

q = u.Q([1.0, 2.0, 3.0], "m")

# ❌ These might not preserve units as expected
try:
    jnp.concatenate([q, q])
except Exception as e:
    print(f"Error: {e}")

✅ Solution: Use Quaxified Functions#

Use quaxed for pre-quaxified JAX functions that handle units:

import quaxed.numpy as jnp  # Note: quaxed, not jax

q = u.Q([1.0, 2.0, 3.0], "m")

# ✅ These preserve quantities correctly
result = jnp.concat([q, q])  # Still Quantity
result = jnp.stack([q, q])  # Still Quantity

General rule: Import from quaxed when working with unxt quantities:

# ✅ Do this
import quaxed.numpy as jnp
from quaxed import lax
from quaxed.scipy import special

# ❌ Not this (unless you manually quaxify)
import jax.numpy as jnp

Alternative: You can also quaxify individual functions instead of using quaxed:

import jax.numpy as jnp
import quax

# Quaxify a specific function
quaxified_sum = quax.quaxify(jnp.sum)

positions = u.Q([1.0, 2.0, 3.0], "m")
total = quaxified_sum(positions)  # Preserves units


# Or use as a decorator
@quax.quaxify
def my_function(x):
    return jnp.sum(x**2)


result = my_function(positions)  # Works with quantities

✅ Dimension Checking Works in JIT#

Good news! Dimensions are checked inside JIT:

import jax


@jax.jit
def add_quantities(x, y):
    return x + y


length = u.Q(5.0, "m")
time = u.Q(2.0, "s")

# ✅ This will raise an error at trace time
try:
    add_quantities(length, time)
except Exception as e:
    print(e)

Why it works: The units are static on the Quantity PyTree. unxt can catch dimension mismatches during tracing.

❌ Problem: Units Triggering Recompilation#

The catch is that functions compile separately for each unit, not just dimension:

@jax.jit
def add_lengths(x: u.Q, y: u.Q):
    return x + y


# First call: compiles for meters
result1 = add_lengths(u.Q(5.0, "m"), u.Q(3.0, "m"))

# Second call: RECOMPILES for kilometers (different unit!)
result2 = add_lengths(u.Q(1.0, "km"), u.Q(2.0, "km"))

# Third call: RECOMPILES for mixed units (m and km)
result3 = add_lengths(u.Q(5.0, "m"), u.Q(3.0, "km"))

✅ Solution: Use Consistent Units#

To avoid recompilation, standardize units before calling JIT functions:

@jax.jit
def add_lengths_m(x: u.Q, y: u.Q):
    """Expects both inputs in meters."""
    return x + y


# Convert to standard units before JIT
length_km = u.Q(3.0, "km")
length_m_input = length_km.uconvert("m")

result = add_lengths_m(u.Q(5.0, "m"), length_m_input)

Key insight: Dimensions are checked statically, but each unique combination of units creates a new compiled version.

Mixing Quantity Types#

❌ Problem: Confused by BareQuantity vs Quantity#

Different quantity types have different guarantees:

# What's the difference?
q1 = u.Q(5.0, "m")
q2 = u.quantity.BareQuantity(5.0, "m")
q3 = u.quantity.StaticQuantity(5.0, "m")

❌ Problem: Quantities are Dynamic#

import functools as ft


@ft.partial(jax.jit, static_argnames=("constant",))
def function(x, *, constant=u.Quantity(3.26, "lyr")):
    ...

✅ Solution: Choose the Right Type#

Quantity — Standard choice with full dimension checking:

length = u.Q(5.0, "m")
time = u.Q(2.0, "s")
speed = length / time  # ✅ Creates Quantity with correct dimension

BareQuantity — No dimension checking, just unit tracking:

# Use when you need raw speed, trust your dimensions
length = u.quantity.BareQuantity(5.0, "m")
time = u.quantity.BareQuantity(2.0, "s")
speed = length / time  # Faster, but no dimension validation

StaticQuantity — For compile-time constants:

# Use for constants that won't change
G = u.quantity.StaticQuantity(6.674e-11, "m^3 kg^-1 s^-2")
@ft.partial(jax.jit, static_argnames=("constant",))
def function(x, *, constant=u.StaticQuantity(3.26, "lyr")):
    ...

When to use each:

Type

Use Case

Dimension Checking

Performance

Quantity

Default choice

✅ Full

Good

BareQuantity

Trust your math

❌ None

Better

StaticQuantity

Constants

✅ Full

Best (no tracer)

Dimension Checking Overhead#

❌ Problem: Slow Tests or Development#

Dimension checking uses beartype for runtime validation, which can add overhead:

✅ Solution: Control Runtime Type Checking#

Set the environment variable to control checking:

# Disable for production (faster)
export UNXT_ENABLE_RUNTIME_TYPECHECKING=False

# Enable for testing (safer)
export UNXT_ENABLE_RUNTIME_TYPECHECKING=beartype.beartype

Or in code:

import os

# Fast mode for production
os.environ["UNXT_ENABLE_RUNTIME_TYPECHECKING"] = "False"

# Safe mode for testing
os.environ["UNXT_ENABLE_RUNTIME_TYPECHECKING"] = "beartype.beartype"

Default: Runtime checking is False unless you’re running tests.

Quantity as a PyTree: JAX flattening overhead#

See the Performance Guide

❌ Problem: Quantity is slower than Array#

For most functions, Quantity input is slower than an Array. This is because Quantities are PyTrees that combine a value and a unit. When a PyTree passes through a jax.jit() boundary it is de-structured then re-structured. This process has an associated overhead.

@jax.jit
@quax.quaxify
def func(x, y):
    return jnp.sum((x**3 - y**3) / (x**2 + y**2))


x, y = jnp.asarray([1, 2, 3]), jnp.asarray([4, 5, 6])
func(x, y)

# vs
qx, qy = u.Q(x, "m"), u.Q(y, "m")
func(qx, qy)

✅ Solution: Don’t pass through the outermost jax.jit boundary#

If the PyTree is formed within the jit context then all the nodes of the PyTree (the static parts) are constant-folded by JAX and will not contribute to the run-time, only the time for first compilation.

@ft.partial(jax.jit, static_argnames=("usys",))
def func(x, y, *, usys):
    x = u.Q.from_(x, usys["length"])
    y = u.Q.from_(y, usys["length"])
    return quax.quaxify(jnp.sum)((x**3 - y**3) / (x**2 + y**2))


x, y = jnp.asarray([1, 2, 3]), jnp.asarray([4, 5, 6])
func(x, y, usys=u.unitsystems.si)

This only applies to the outer-most function. Nesting jitted and quaxified functions are fine. The outermost jit boundary handles the constant-folding.

See Also#