Performance Optimization with Unitful Quantities
In this guide, we’ll explore how to think about performance optimization when working with unxt Quantities in JAX. The key insight is understanding where the overhead lives and when it matters.
Key Concepts
Wrapper overhead: Operations on Quantities have overhead compared to raw JAX arrays – they’re wrapped with unit information.
JIT removes overhead: JAX’s JIT compiler can eliminate much of this wrapper overhead by tracing through the code.
Pytree complexity: Quantities are JAX pytrees, which adds cost when crossing JIT boundaries (converting between traced and non-traced values).
Strategy: The secret to performance is to minimize pytree conversions at the boundary between traced and non-traced code.
Let’s start by importing the libraries we’ll need and setting up some test data.
```python
import functools as ft
import jax
import jax.numpy as jnp
import quax
import unxt as u
```
We’ll create 1000-element arrays with physical units – these will be our test data throughout this guide.
```python
x = jnp.linspace(0.1, 10.0, 1000)
y = jnp.linspace(11, 100.0, 1000)
qx = u.Q(x, "m")
qy = u.Q(y, "m")
```
Baseline: Raw JAX Performance
First, let’s establish our baseline by measuring the performance of a plain JAX function on raw arrays, run eagerly without JIT. We’ll time both:
First call: Includes JIT compilation overhead
Repeated calls: Shows the performance after compilation
```python
def func(x, y):
    return jnp.sum((x**3 - y**3) / (x**2 + y**2))

%time jax.block_until_ready(func(x, y))
%timeit jax.block_until_ready(func(x, y))
```
CPU times: user 126 ms, sys: 6.34 ms, total: 132 ms
Wall time: 122 ms
119 μs ± 432 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
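The %time and %timeit magics only work in IPython and Jupyter. In a plain script, a minimal sketch of the same two measurements might look like the helper below; bench is a hypothetical name and the repeat count is an arbitrary choice. The jax.block_until_ready call matters because JAX dispatches work asynchronously, so without it you would time only the dispatch, not the computation.

```python
import time

import jax
import jax.numpy as jnp


def bench(fn, *args, repeats=100):
    """Hypothetical helper: time the first call and the mean of repeated calls.

    Returns (first_call_seconds, mean_repeated_seconds).
    """
    # First call: for a jitted fn this includes tracing and compilation.
    start = time.perf_counter()
    jax.block_until_ready(fn(*args))
    first = time.perf_counter() - start

    # Repeated calls: reuse the compiled version, if there is one.
    start = time.perf_counter()
    for _ in range(repeats):
        # Block so we measure the computation, not just the async dispatch.
        jax.block_until_ready(fn(*args))
    mean = (time.perf_counter() - start) / repeats
    return first, mean


def func(x, y):
    return jnp.sum((x**3 - y**3) / (x**2 + y**2))


x = jnp.linspace(0.1, 10.0, 1000)
y = jnp.linspace(11.0, 100.0, 1000)

first, mean = bench(jax.jit(func), x, y)
print(f"first call: {first * 1e3:.2f} ms, repeated: {mean * 1e6:.2f} μs")
```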
With JIT Compilation
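Now let’s compile the same function with jax.jit. The speedup is dramatic, from about 119 μs to 7.7 μs per call: jax.jit traces the function once and compiles the whole chain of array operations into a single optimized XLA computation for the CPU or GPU, instead of dispatching each jnp operation separately from Python.
Key insight: for functions you call repeatedly, JIT is almost always worth it. The first call takes longer because it includes tracing and compilation, but subsequent calls reuse the compiled version and are much faster.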
```python
jitted_func = jax.jit(func)
%time jax.block_until_ready(jitted_func(x, y))
%timeit jax.block_until_ready(jitted_func(x, y))
```
CPU times: user 47.5 ms, sys: 3.98 ms, total: 51.4 ms
Wall time: 35.6 ms
7.72 μs ± 81.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
quaxify for unit support
Now let’s apply the same function to Quantities using quax.quaxify. This wraps the function so it can handle Quantity inputs.
```python
quax_func = quax.quaxify(func)
%time jax.block_until_ready(quax_func(qx, qy))
%timeit jax.block_until_ready(quax_func(qx, qy))
```
CPU times: user 125 ms, sys: 3.01 ms, total: 128 ms
Wall time: 119 ms
12.7 ms ± 46 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
The Problem: Wrapper Overhead
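Wow! At about 12.7 ms per call, that’s roughly 100 times slower than the un-jitted raw-array function and over 1,000 times slower than the jitted one. This is the wrapper overhead in action. Without JIT, the quaxify’d function has to:
Unwrap the Quantities into arrays
Dispatch each JAX operation to the Quantity rules and track the unit information
Re-wrap the result back into a Quantity
Do all of this in Python, for every operation, on EVERY call, with nothing compiled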
This is why JIT is a necessary ingredient – let’s see what happens when we add JIT:
Solution 1: JIT the Quaxified Function
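By combining jax.jit with quax.quaxify, we eliminate much of the wrapper overhead: the dispatch and wrapping logic runs once, while JAX traces the function, and the compiled function contains only the underlying array operations. As a first check, we call it with raw arrays. With no Quantities among the inputs, quaxify has nothing to dispatch on, so the timing matches the plain jitted function.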
```python
jitted_quax_func = jax.jit(quax.quaxify(func))
%time jax.block_until_ready(jitted_quax_func(x, y))
%timeit jax.block_until_ready(jitted_quax_func(x, y))
```
CPU times: user 57.6 ms, sys: 0 ns, total: 57.6 ms
Wall time: 41.5 ms
7.65 μs ± 55.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
Now let’s pass actual Quantities to the jitted function. At about 20.3 μs per call it is slower than with raw arrays (7.65 μs), but more than 600 times faster than the un-jitted quaxify’d version (12.7 ms).
```python
%time jax.block_until_ready(jitted_quax_func(qx, qy))
%timeit jax.block_until_ready(jitted_quax_func(qx, qy))
```
CPU times: user 66.4 ms, sys: 3.99 ms, total: 70.4 ms
Wall time: 51.3 ms
20.3 μs ± 215 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
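The gap between 12.7 ms and 20.3 μs comes from when the Python-level work runs. In the sketch below, a counter in the function body stands in for quaxify’s unwrapping, dispatch, and rewrapping (an illustrative assumption: the real logic runs per operation and is much heavier). Called eagerly, the body runs on every call; under jax.jit it runs only while tracing, and later calls with the same input shapes and dtypes go straight to the compiled code.

```python
import jax
import jax.numpy as jnp

# Counts how many times the Python body actually executes.
python_runs = {"count": 0}


def wrapped(x, y):
    # Stand-in for per-call wrapper work (unwrapping, unit checks, rewrapping).
    python_runs["count"] += 1
    return jnp.sum(x - y)


x = jnp.linspace(0.1, 10.0, 1000)
y = jnp.linspace(11.0, 100.0, 1000)

# Eager: the Python body runs on every call.
for _ in range(3):
    wrapped(x, y)
eager_runs = python_runs["count"]  # 3

# Jitted: the body runs once, during tracing; later calls reuse the compiled code.
python_runs["count"] = 0
jitted = jax.jit(wrapped)
for _ in range(3):
    jitted(x, y)
jit_runs = python_runs["count"]  # 1
```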
Solution 2: Minimize Pytree Conversions at the Boundary
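There’s still a small overhead when passing Quantities across the JIT boundary. Quantities are JAX pytrees, so on every call jax.jit has to flatten each Quantity argument into its array leaves and metadata before it can run the compiled code, and then rebuild a Quantity from the output. This happens in Python, outside the compiled computation, so the compiler cannot remove it.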
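To see what crossing the boundary involves, here is a toy unit-carrying pytree. ToyQuantity is hypothetical and much simpler than unxt’s Quantity: it stores the unit as static auxiliary data and the array as its only leaf. Whenever a ToyQuantity is passed into or returned from a jitted function, JAX calls its tree_flatten and tree_unflatten methods in Python.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyQuantity:
    """Hypothetical stand-in for a Quantity: an array value plus a unit string."""

    def __init__(self, value, unit):
        self.value = value
        self.unit = unit

    def tree_flatten(self):
        # The array is the only leaf; the unit is static auxiliary data.
        return (self.value,), self.unit

    @classmethod
    def tree_unflatten(cls, unit, children):
        # Rebuild the object from its auxiliary data and leaves.
        return cls(children[0], unit)


q = ToyQuantity(jnp.arange(3.0), "m")

# What jit does with a pytree argument: split it into leaves and structure...
leaves, treedef = jax.tree_util.tree_flatten(q)
# ...and, for outputs, rebuild it from the structure and new leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)


@jax.jit
def double(q):
    # Inside the trace, q.value is a tracer and q.unit is a plain string.
    return ToyQuantity(q.value * 2, q.unit)


out = double(q)  # flattened on the way in, unflattened on the way out
```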
The trick is to move the pytree conversion inside the JIT. We create a thin outer jitted wrapper that takes raw arrays, turns them into Quantities at the start, calls the inner unitful function, and strips the units from the result. Only raw arrays cross the boundary, and all the Quantity construction and unwrapping happens inside the trace, where it is compiled away.
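```python
@jax.jit
def outer_func(x, y):
    # Wrap the raw arrays with units inside the trace
    qx = u.Q(x, "m")
    qy = u.Q(y, "m")
    # This calls the unitful function inside a jitted context
    out = jitted_quax_func(qx, qy)
    # Strip the units so only a raw array crosses the JIT boundary
    return out.ustrip("m")
```
```python
%time jax.block_until_ready(outer_func(x, y))
%timeit jax.block_until_ready(outer_func(x, y))
```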
CPU times: user 53 ms, sys: 3.01 ms, total: 56 ms
Wall time: 41.6 ms
7.68 μs ± 56.1 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
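The pattern carries over to larger computations. Here func2 calls the jitted unitful function twice and combines the results, and the outer wrapper takes a unit system instead of a hard-coded unit: it builds the Quantities in the unit system’s length unit and strips the result into the same unit system. Because usys is listed in static_argnames, it is part of JAX’s compilation cache key rather than a traced value, so it must be hashable, and each distinct unit system triggers its own compilation.
```python
def func2(x, y):
    out1 = jitted_quax_func(x, y)
    out2 = jitted_quax_func(x, y)
    return out1 + 3 * out2

jitted_quax_func2 = jax.jit(quax.quaxify(func2))

@ft.partial(jax.jit, static_argnames=("usys",))
def outer_func3(x, y, *, usys):
    # Build Quantities in the unit system's length unit, inside the trace
    qx = u.Q.from_(x, usys["length"])
    qy = u.Q.from_(y, usys["length"])
    out = jitted_quax_func2(qx, qy)
    # Strip units by expressing the result in the unit system
    return out.ustrip(usys)
```
```python
usys = u.unitsystems.si
%time jax.block_until_ready(outer_func3(x, y, usys=usys))
%timeit jax.block_until_ready(outer_func3(x, y, usys=usys))
```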
CPU times: user 163 ms, sys: 8.99 ms, total: 172 ms
Wall time: 150 ms
8.73 μs ± 60.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
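Because usys is declared in static_argnames, JAX compiles a separate version of outer_func3 for each distinct unit system it sees. The sketch below shows that behavior with a hypothetical to_unit function whose static argument is a unit name string; its conversion table is made up for illustration, and the counter records how often JAX traces the body.

```python
import functools as ft

import jax
import jax.numpy as jnp

n_traces = {"count": 0}

# Hypothetical conversion factors from metres, for illustration only.
FROM_METRES = {"m": 1.0, "km": 1e-3}


@ft.partial(jax.jit, static_argnames=("unit",))
def to_unit(x, *, unit):
    # Runs only when JAX traces a new combination of input shape/dtype and static value.
    n_traces["count"] += 1
    return x * FROM_METRES[unit]


x = jnp.ones(3)
to_unit(x, unit="m")
to_unit(x, unit="m")   # same static value: cached, no retrace
to_unit(x, unit="km")  # new static value: traced and compiled again
```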
The Strategy: Outer Wrapper Pattern
Here’s the general pattern:
Accept raw arrays at the outermost function boundary
Create Quantities inside the JIT by wrapping the arrays with units
Call the inner unitful function inside the JIT
Strip the units from the result (for example with ustrip) and return a raw array
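This way, unit bookkeeping (dimension checks, conversion factors, wrapping and unwrapping) is resolved while JAX traces the function. The compiled code contains only array operations, and no Quantity pytrees cross the function call boundary, so you pay for thinking in units at trace time rather than on every call.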
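The four steps can be packaged as a small factory. make_outer below is a hypothetical helper, not part of unxt or quax: wrap turns each raw input into the unitful type, and unwrap strips the units from the result. Both run inside the jax.jit trace, so only raw arrays cross the boundary.

```python
import jax
import jax.numpy as jnp


def make_outer(inner, wrap, unwrap):
    """Hypothetical helper implementing the outer wrapper pattern.

    inner:  function operating on unitful values (e.g. a jitted, quaxified function)
    wrap:   converts one raw input array into a unitful value
    unwrap: strips the units from inner's result, returning a raw array
    """

    @jax.jit
    def outer(*arrays):
        # Wrapping, the unitful call, and unwrapping all happen inside the trace.
        unitful_args = [wrap(a) for a in arrays]
        return unwrap(inner(*unitful_args))

    return outer
```

With unxt, something like make_outer(jitted_quax_func, lambda a: u.Q(a, "m"), lambda q: q.ustrip("m")) should behave like outer_func above; treat it as a sketch to check against your own code.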
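Results: Overhead Eliminated
Notice the speedup: outer_func runs in about 7.7 μs per call, the same as the jitted raw-array function and well below the 20.3 μs of passing Quantities straight into jitted_quax_func. Even outer_func3, which calls the unitful function twice and works through a unit system, takes only about 8.7 μs.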
Why this works:
The outer JIT compiles away all the unit wrapping and unwrapping
The jitted_quax_func call is traced into the outer computation, so it runs as part of a single compiled operation
The only remaining overhead is JIT’s normal handling of the raw-array arguments, which is minimal
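One way to check that the unit handling really is compiled away is to compare jaxprs. The sketch below uses a hypothetical Tagged class (an array plus a unit string, with a unit check in __add__) that is created and consumed inside the traced function. Because it never crosses the boundary it does not even need to be a pytree, and the traced program contains the same primitives as the raw version.

```python
import jax
import jax.numpy as jnp


class Tagged:
    """Hypothetical minimal unit-carrying wrapper (not a pytree)."""

    def __init__(self, value, unit):
        self.value = value
        self.unit = unit

    def __add__(self, other):
        # The unit check is ordinary Python, so it runs only while tracing.
        if self.unit != other.unit:
            raise ValueError(f"unit mismatch: {self.unit} vs {other.unit}")
        return Tagged(self.value + other.value, self.unit)


def raw(x, y):
    return jnp.sum(x + y)


def outer(x, y):
    # Outer wrapper pattern: wrap, compute with units, strip.
    total = Tagged(x, "m") + Tagged(y, "m")
    return jnp.sum(total.value)


x = jnp.linspace(0.1, 10.0, 1000)
y = jnp.linspace(11.0, 100.0, 1000)

# Primitive names in each traced program.
raw_prims = [eqn.primitive.name for eqn in jax.make_jaxpr(raw)(x, y).jaxpr.eqns]
outer_prims = [eqn.primitive.name for eqn in jax.make_jaxpr(outer)(x, y).jaxpr.eqns]
```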
Important caveat: the boundary cost is fixed per call of the outermost function and does not grow with array size (here about 20.3 μs versus 7.7 μs, so roughly 12.6 μs per call). If your function is called once with a million-element array, that cost is negligible next to the computation. If your function is called a million times with scalar inputs, the overhead adds up, and the outer wrapper pattern pays off.
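A back-of-the-envelope estimate makes the fixed cost concrete. Using the timings reported above (which will differ on other hardware), the boundary overhead is the difference between passing Quantities to jitted_quax_func and passing raw arrays to outer_func:

```python
# Timings reported above, in microseconds per call (hardware-dependent).
t_quantity_args = 20.3  # jitted_quax_func(qx, qy)
t_raw_args = 7.68       # outer_func(x, y)

# Fixed cost per call of passing Quantities across the JIT boundary.
overhead_us = t_quantity_args - t_raw_args  # about 12.6 μs


def total_overhead_seconds(n_calls):
    """Accumulated boundary overhead over n_calls calls, in seconds."""
    return n_calls * overhead_us * 1e-6


one_big_call = total_overhead_seconds(1)          # about 1.3e-5 s
many_small_calls = total_overhead_seconds(10**6)  # about 12.6 s
```

A single call on a large array pays about 13 μs extra, which disappears next to the computation; a million small calls pay about 12.6 s.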
Summary: How to Think About Performance
Here are the key takeaways for optimizing performance with unxt Quantities:
Always use JIT for hot code - The overhead of quaxify is negligible inside a JIT’d context
Minimize pytree boundary crossings - Use the outer wrapper pattern where you pass raw arrays to the outermost function
Create Quantities inside JIT - This lets the compiler optimize away unit handling
It’s a fixed cost per call - The optimization matters most for functions that are called many times on small inputs; for a few calls on large arrays, the fixed cost is amortized
Don’t micro-optimize prematurely - Write correct code first. If units make your code clearer, use them. Only optimize the outermost layer if profiling shows it’s necessary.
The bottom line: Use Quantities freely in your code—they’re designed to work well with JAX. When you need performance, apply the outer wrapper pattern to your hot functions. The rest of your codebase can stay clean and unit-aware.