Type Checking
TL;DR
You can tell functions about the dtype, shape, and dimensions of a Quantity. The dtype and shape information can be checked statically, and all three can be checked at runtime.
In the following example we will define a function that operates on two length-‘N’ (1-D and equally-shaped) float dtype arrays. The function takes a length and a time and returns a velocity.
from jaxtyping import Float
import unxt as u


def function(
    x: Float[u.Quantity["length"], "N"],
    t: Float[u.Quantity["time"], "N"],
) -> Float[u.Quantity["speed"], "N"]:
    return x / t
For information on typing in Python see the built-in typing module. Refer to the jaxtyping library for information on how to annotate the dtype and shape of a Quantity, for example integer arrays or variable / context-dependent shapes. jaxtyping also powers unxt’s runtime typechecking, discussed next.
Runtime Type Checking
Using jaxtyping, unxt supports runtime type checking, where type annotations are enforced during execution. This is very useful for finding and preventing type-related errors, such as passing the wrong type of argument to a function or returning the wrong type of value. To enable runtime type checking on all of unxt, set the environment variable UNXT_ENABLE_RUNTIME_TYPECHECKING to beartype.beartype or any other runtime typecheck backend supported by jaxtyping.
# Enable runtime type checking
export UNXT_ENABLE_RUNTIME_TYPECHECKING="beartype.beartype"
Attention
We recommend enabling runtime type checking during development.
For normal use, try enabling and disabling runtime type checking to assess any performance impact.
The performance overhead associated with runtime type checking should be small but isn’t always; in particular, it can increase the time JAX takes to jit code. To turn off runtime type checking, set the environment variable to None.
# Disable runtime type checking
export UNXT_ENABLE_RUNTIME_TYPECHECKING="None"
If the environment variable is not set, runtime type checking is disabled by default.
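As a rough illustration of the two settings above, the variable could be interpreted as in the following sketch. This is not unxt's actual implementation: the helper name read_typechecker_setting is hypothetical, and treating surrounding whitespace as insignificant is an assumption. It only shows that "None" (or an unset variable) means no checking, while any other value names a typechecker.

```python
import os

ENV_VAR = "UNXT_ENABLE_RUNTIME_TYPECHECKING"


def read_typechecker_setting(environ=None):
    """Hypothetical helper: return a typechecker path or None.

    Returns a string such as "beartype.beartype" when checking is
    requested, and None when the variable is unset or set to "None".
    """
    env = os.environ if environ is None else environ
    raw = env.get(ENV_VAR, "None").strip()
    return None if raw == "None" else raw


# The result is the kind of value jaxtyping's import hook accepts as its
# `typechecker` argument: a string path to a typechecker, or None.
enabled = read_typechecker_setting({ENV_VAR: "beartype.beartype"})
disabled = read_typechecker_setting({ENV_VAR: "None"})
unset = read_typechecker_setting({})
```

Passing a plain dict instead of os.environ keeps the sketch easy to try without touching the real environment.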
Tip
You can set environment variables directly in Python. Execute the following before importing unxt (or any library that imports unxt).
import os
os.environ["UNXT_ENABLE_RUNTIME_TYPECHECKING"] = "beartype.beartype"
In the background unxt checks for the UNXT_ENABLE_RUNTIME_TYPECHECKING environment variable and passes it to jaxtyping’s import hook. jaxtyping also offers function-specific checking through the jaxtyped decorator.
Here’s an example:
>>> from jaxtyping import Shaped, jaxtyped
>>> from beartype import beartype as typechecker # or use any supported typechecker
>>> import unxt as u
>>> @jaxtyped(typechecker=typechecker)
... def velocity(
...     x: Shaped[u.Quantity["length"], "N"],
...     t: Shaped[u.Quantity["time"], "N"],
... ) -> Shaped[u.Quantity["speed"], "N"]:
...     return x / t
>>> x = u.Q([2.], "m")
>>> t = u.Q([1.], "s")
>>> velocity(x, t)
Quantity(Array([2.], dtype=float32), unit='m / s')
Dimension Annotations to Quantity
In the previous sections Quantity annotations had strings specifying the dimensions of that Quantity. Let’s explore this a little more deeply.
First the theory. Python classes can be ‘parametric’, where the class is parametrized by a set of metadata. The most common example is generics in the built-in typing library, where the metadata is type information about a function or object. This is useful for static type checking. However, we are not limited to type information: classes can implement any form of parametric design (see here). We use the library plum, on which unxt depends, to enhance Python’s parametric functionality and enable Quantity classes to be parametrized by their unit’s dimensions in a way that can be checked by runtime type checkers.
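To make the idea concrete, here is a minimal pure-Python sketch of a parametric class. It is a toy, not how plum or unxt implement this: the class Measure, the UNIT_DIMENSIONS table, and the cache are all hypothetical names introduced for illustration. Indexing the class creates (and caches) a subclass tagged with a dimension, so ordinary isinstance checks can distinguish dimensions at runtime.

```python
# Hypothetical lookup table standing in for a real unit system.
UNIT_DIMENSIONS = {"m": "length", "km": "length", "s": "time"}

_PARAMETRIZED = {}  # cache so Measure["length"] is always the same class


class Measure:
    """Toy class parametrized by the dimension of its unit."""

    dimension = None  # None marks the unparametrized base class

    def __class_getitem__(cls, dimension):
        # Create one subclass per dimension and reuse it afterwards.
        key = (cls, dimension)
        if key not in _PARAMETRIZED:
            name = f"{cls.__name__}[{dimension!r}]"
            _PARAMETRIZED[key] = type(name, (cls,), {"dimension": dimension})
        return _PARAMETRIZED[key]

    def __new__(cls, value, unit):
        # Infer the parameter from the unit when none was given explicitly.
        if cls.dimension is None:
            cls = cls[UNIT_DIMENSIONS[unit]]
        return object.__new__(cls)

    def __init__(self, value, unit):
        # An explicit parameter must agree with the unit's dimension.
        actual = UNIT_DIMENSIONS[unit]
        expected = type(self).dimension
        if actual != expected:
            raise ValueError(f"expected {expected!r}, got {actual!r}")
        self.value = value
        self.unit = unit


inferred = Measure(1, "m")             # parameter inferred from the unit
explicit = Measure["length"](2, "km")  # parameter given explicitly
same_class = type(inferred) is type(explicit)
```

Because each dimension maps to a distinct class, a runtime type checker only needs isinstance to tell a length from a time.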
Now for some examples.
>>> import unxt as u
When a Quantity is constructed it is parametrized by the unit’s dimension. This can be specified explicitly
>>> u.Quantity["length"](1, "m")
Quantity['length'](Array(1, dtype=int32, ...), unit='m')
or inferred.
>>> u.Q(1, "m")
Quantity['length'](Array(1, dtype=int32, ...), unit='m')
When given explicitly, Quantity checks the input dimensions. Here a length-parametrized Quantity (correctly) refuses dimensions of time.
>>> try:
...     u.Quantity["length"](1, "s")
... except Exception as e:
...     print(e)
Physical type mismatch.
That should catch some bugs!
Filling a Quantity’s parameters and constructing it may be done as separate steps:
>>> LengthQuantity = u.Quantity["length"]
>>> LengthQuantity
<class 'unxt...Quantity[PhysicalType('length')]'>
This parametric design is how unxt supports runtime type checking.
In unxt, not all Quantity classes are parametric. The base class, unxt.quantity.AbstractQuantity, is not parametric, nor is the concrete class unxt.quantity.BareQuantity. Parametric classes incur a small performance overhead (generally eliminated in jitted code), which ultra-performance-optimized code might want to avoid, at the cost of losing inference and checking of the dimensions.
Note
BareQuantity[<dimension>] does nothing and is for informational purposes only.
Check out plum to explore more powerful features of parametric classes.