Quantity
Creating Quantity Instances
Quantity objects are created by passing a value and a unit to the Quantity constructor (with Q as a shorthand).
>>> import unxt as u
>>> u.Quantity(5, "m")
Quantity(Array(5, dtype=int32, weak_type=True), unit='m')
The constructor will automatically convert the value to a jax.Array (if it is not already one) and convert the unit to a Unit object.
The value and unit of a Quantity object can be accessed using the value and unit attributes, respectively:
>>> q = u.Q([1, 2, 3, 5], "m")
>>> q.value
Array([1, 2, 3, 5], dtype=int32)
>>> q.unit
Unit("m")
If you want more flexible options to create a Quantity, you can use the Quantity.from_ class method. This uses multiple dispatch to determine the appropriate constructor based on the input arguments.
>>> u.Q.from_(5, "m") # same as Quantity(5, "m")
Quantity(Array(5, dtype=int32, ...), unit='m')
>>> u.Q.from_({"value": [1, 2, 3], "unit": "m"})
Quantity(Array([1, 2, 3], dtype=int32), unit='m')
>>> u.Q.from_(q) # from another Quantity object
Quantity(Array([1, 2, 3, 5], dtype=int32), unit='m')
>>> u.Q.from_(5, "m", dtype=float) # specify the dtype
Quantity(Array(5., dtype=float32), unit='m')
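Quantity.from_ picks a constructor based on the types of its arguments. The sketch below illustrates only that idea. It uses the standard library's functools.singledispatch, which dispatches on the type of the first argument alone. ToyQuantity and from_ are hypothetical names. unxt's real from_ uses multiple dispatch, so the types of all arguments are considered, and it handles many more input types.

```python
from dataclasses import dataclass
from functools import singledispatch
from typing import Any


@dataclass
class ToyQuantity:
    """Hypothetical stand-in for Quantity: a value paired with a unit string."""

    value: Any
    unit: str


@singledispatch
def from_(obj, unit=None):
    # Fallback for raw values (numbers, lists, arrays): a unit must be given.
    if unit is None:
        raise TypeError(f"a unit is required for {type(obj).__name__} input")
    return ToyQuantity(obj, unit)


@from_.register
def _(obj: dict, unit=None):
    # A mapping carries its own value and unit.
    return ToyQuantity(obj["value"], obj["unit"])


@from_.register
def _(obj: ToyQuantity, unit=None):
    # An existing quantity is passed through; this toy version ignores `unit`.
    return obj


print(from_(5, "m"))
print(from_({"value": [1, 2, 3], "unit": "m"}))
```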
There are many more options available with Quantity.from_. For a complete list, inspect Quantity.from_.methods, for example in an interactive session:
>>> u.Quantity.from_.methods
List of 9 method(s):
[0] from_(cls: type, value: typing.Union[ArrayLike, ...], unit: typing.Any, *,
dtype) -> unxt...quantity...AbstractQuantity
<function AbstractQuantity.from_ at ...>
...
Quantity.from_ also helps when interfacing with other libraries; see, for example, Interop with Astropy.
Converting to Different Units
Quantity objects can be converted to different units, either as a new Quantity or as a plain value in the new units. If you prefer an object-oriented approach, use the uconvert method.
>>> q = u.Q(5, "m")
>>> q.uconvert("cm")
Quantity(Array(500., dtype=float32, ...), unit='cm')
Note
The Astropy API .to is also available for Quantity objects.
>>> q.to("cm")
Quantity(Array(500., dtype=float32, ...), unit='cm')
If you prefer a more functional approach, use the uconvert function.
>>> u.uconvert("cm", q)
Quantity(Array(500., dtype=float32, ...), unit='cm')
Low-Level Value Conversion with uconvert_value
For operations on raw numerical values without Quantity wrapping, use the lower-level uconvert_value function. This is useful when you need to perform batch conversions or work with JAX transformations.
Basic usage with unit strings:
>>> # Convert raw values between units
>>> u.uconvert_value("km", "m", 1000)
1.0
>>> import jax.numpy as jnp
>>> u.uconvert_value("km", "m", jnp.array([1000, 2000, 5000]))
Array([1., 2., 5.], dtype=float32, ...)
With unit objects:
>>> u.uconvert_value(u.unit("km"), u.unit("m"), 5000)
5.0
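For purely multiplicative units such as lengths, converting a raw value amounts to multiplying by a ratio of unit scales. The sketch below shows that arithmetic. It uses a hypothetical hand-written scale table and a hypothetical toy_uconvert_value function that mirrors the (to_unit, from_unit, value) argument order used above. It is not how unxt looks up conversion factors, and it ignores units with offsets such as temperature scales.

```python
import numpy as np

# Hypothetical scale table: the size of each unit, in metres.
SCALE_IN_METRES = {"m": 1.0, "cm": 0.01, "km": 1000.0}


def toy_uconvert_value(to_unit, from_unit, value):
    """Convert a raw value between multiplicative units (illustrative sketch)."""
    factor = SCALE_IN_METRES[from_unit] / SCALE_IN_METRES[to_unit]
    return np.asarray(value) * factor


print(toy_uconvert_value("km", "m", 1000))
print(toy_uconvert_value("km", "m", [1000, 2000, 5000]))
print(toy_uconvert_value("cm", "m", 5))
```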
Convenience dispatch for Quantity objects:
The uconvert_value function also provides a convenience dispatch that accepts Quantity objects directly, so code written against the lower-level function also works when it receives a Quantity:
>>> u.uconvert_value("km", "m", q)
Quantity(Array(0.005, dtype=float32, ...), unit='km')
This dispatch just calls uconvert, so you don’t need to extract the value manually.
Relationship to other functions:
uconvert_value operates on raw numerical values and returns raw values.
uconvert operates on Quantity objects and returns Quantity objects.
Internally, uconvert often delegates to uconvert_value for the numerical conversion step.
Performance considerations:
Use uconvert_value directly when:
Performing batch conversions on arrays
Working inside JAX transformations (jit, vmap, grad)
Avoiding the overhead of Quantity objects
Maximum performance is critical
>>> import jax
>>> @jax.jit
... def batch_convert_to_km(values_in_m):
... return u.uconvert_value("km", "m", values_in_m)
>>> batch_convert_to_km(jnp.array([1000., 5000., 10000.]))
Array([ 1., 5., 10.], dtype=float32)
Converting to Values in Different Units
To get the raw value in the new units, use the ustrip function.
>>> u.ustrip("cm", q)
Array(500., dtype=float32, ...)
Alternatively, use the ustrip method:
>>> q.ustrip("cm")
Array(500., dtype=float32, ...)
When the input may be either an array or a Quantity object, pass the unxt.quantity.AllowValue flag to ustrip. Arrays without units are then accepted and returned unchanged, on the assumption that they are already in the requested output units.
>>> import jax.numpy as jnp
>>> u.ustrip(u.quantity.AllowValue, "cm", 500)
500
Note
The Astropy API .to_value is also available for Quantity objects.
>>> q.to_value("cm")
Array(500., dtype=float32, ...)
With reference to jax.Array
Quantity objects are designed to mirror jax.Array and the Array API.
Note
If you find that a method or property is missing, please open an issue on the GitHub repository.
This means you can perform operations on Quantity objects just like you would with jax.Array.
Arithmetic Operations
You can perform standard mathematical operations on Quantity objects:
>>> q1 = u.Q(5, "m")
>>> q2 = u.Q(10, "m")
>>> q1 + q2
Quantity(Array(15, dtype=int32, ...), unit='m')
>>> q1 * 1.5
Quantity(Array(7.5, dtype=float32, ...), unit='m')
>>> q1 / q2
Quantity(Array(0.5, dtype=float32, ...), unit='')
>>> q1 ** 2
Quantity(Array(25, dtype=int32, ...), unit='m2')
Comparison Operations
>>> q1 = u.Q([1., 2, 3], "m")
>>> q2 = u.Q([100., 201, 300], "cm")
>>> q1 < q2
Array([False, True, False], dtype=bool)
>>> q1 == q2
Array([ True, False, True], dtype=bool)
Indexing and Slicing
>>> q = u.Q([1, 2, 3, 4], "m")
>>> q[1]
Quantity(Array(2, dtype=int32), unit='m')
>>> q[1:]
Quantity(Array([2, 3, 4], dtype=int32), unit='m')
Array Updates
unxt supports JAX-style functional array updates through the .at property. See 🔪 JAX - The Sharp Bits 🔪 for more details.
>>> q = u.Q([1., 2, 3, 4], "m")
>>> q.at[2].set(u.Q(30.1, "cm"))
Quantity(Array([1. , 2. , 0.301, 4. ], dtype=float32), unit='m')
JAX Functions
JAX functions normally only support pure JAX arrays.
>>> import jax.numpy as jnp # regular JAX
>>> x = u.Q([1, 2, 3], "m")
>>> try: jnp.square(x)
... except TypeError: print("not a pure JAX array")
not a pure JAX array
We use quax to enable Quantity support across most of the JAX ecosystem! See the quax docs for implementation details. The short version is that you can use Quantity in JAX functions so long as they pass through a quax.quaxify call. Here are a few examples:
This is the way to “quaxify” a JAX function. A powerful feature of quaxify is that it enables Quantity support in every JAX function called inside the decorated function. With unxt you can use normal JAX!
>>> import jax.numpy as jnp # regular JAX
>>> from quax import quaxify
>>> @quaxify # Now it works with Quantity... that's it!
... def func(x, y):
... return jnp.square(x) + jnp.multiply(x, y) # normal JAX
>>> y = u.Q([4, 5, 6], "m")
>>> func(x, y)
Quantity(Array([ 5, 14, 27], dtype=int32), unit='m2')
quaxed is a convenience library that pre-“quaxify”s JAX functions. It’s a drop-in replacement for much of JAX.
>>> import quaxed.numpy as jnp # pre-quaxified JAX
>>> jnp.square(x) + jnp.multiply(x, y)
Quantity(Array([ 5, 14, 27], dtype=int32), unit='m2')
quaxed is entirely optional. You can apply quax.quaxify manually instead, either to decorate only your top-level functions or to wrap third-party functions.
Attention
Quantity should support all JAX functions. If you find a function that doesn’t work, please open an issue on the GitHub repository.
Pretty Printing
Quantity objects support the wadler_lindig library for pretty printing.
>>> import wadler_lindig as wl
>>> q = u.Q([1, 2, 3], "m")
>>> wl.pprint(q) # The default pretty printing
Quantity(i32[3], unit='m')
The type parameter can be included in the representation:
>>> wl.pprint(q, include_params=True)
Quantity['length'](i32[3], unit='m')
str() also includes the type parameter, and prints the values compactly:
>>> print(q)
Quantity['length']([1, 2, 3], unit='m')
Arrays can be printed in full:
>>> wl.pprint(q, short_arrays=False)
Quantity(Array([1, 2, 3], dtype=int32), unit='m')
repr() uses this full-array setting:
>>> print(repr(q))
Quantity(Array([1, 2, 3], dtype=int32), unit='m')
The units can be turned from a named argument to a positional argument by setting named_unit=False:
>>> wl.pprint(q, named_unit=False)
Quantity(i32[3], 'm')
Instead of printing the value as a full Array or as a short array summary, you can print just its elements in compact form by setting short_arrays="compact":
>>> wl.pprint(q, short_arrays="compact")
Quantity([1, 2, 3], unit='m')
For more compact output, the Quantity class has a short name Q that can be used by setting use_short_name=True:
>>> wl.pprint(q, use_short_name=True)
Q(i32[3], unit='m')
The short name can be combined with other printing options:
>>> wl.pprint(q, use_short_name=True, include_params=True)
Q['length'](i32[3], unit='m')
>>> wl.pprint(q, use_short_name=True, short_arrays="compact")
Q([1, 2, 3], unit='m')
See the wadler_lindig documentation for more details on the pretty printing options.
Specialized Quantity Objects
Working with Angle Objects
The Angle class is a specialized quantity for representing angular measurements, similar to Quantity but with additional features and constraints tailored for angles.
Creating Angles
You can create an Angle just like a Quantity, by specifying a value and a unit with angular dimensions:
>>> a = u.Angle(45, "deg")
>>> a
Angle(Array(45, dtype=int32, weak_type=True), unit='deg')
Just like Quantity, you can flexibly create Angle objects using the from_() constructor:
>>> u.Angle.from_(45, "deg")
Angle(Array(45, dtype=int32, weak_type=True), unit='deg')
>>> u.Angle.from_([45, 90], "deg")
Angle(Array([45, 90], dtype=int32), unit='deg')
>>> u.Angle.from_(jnp.array([10, 15, 20]), "deg")
Angle(Array([10, 15, 20], dtype=int32), unit='deg')
Mathematical Operations
Angle objects support arithmetic operations, broadcasting, and most mathematical functions, just like Quantity:
>>> b = u.Angle(30, "deg")
>>> a + b
Angle(Array(75, dtype=int32, weak_type=True), unit='deg')
>>> 2 * a
Angle(Array(90, dtype=int32, weak_type=True), unit='deg')
>>> a.to("rad")
Angle(Array(0.7853982, dtype=float32, weak_type=True), unit='rad')
For more information on mathematical operations, see the unxt documentation.
Enforced Dimensionality
Unlike a generic Quantity, the Angle class enforces that the unit must be angular (e.g., degrees, radians). Attempting to use a non-angular unit will raise an error:
>>> try: u.Angle(1, "m")
... except ValueError as e: print(e)
Angle must have units with angular dimensions.
Wrapping Angles
A key feature of Angle is the ability to wrap values to a specified range, which is useful for keeping angles within a branch cut:
>>> a = u.Angle(370, "deg")
>>> a.wrap_to(u.Q(0, "deg"), u.Q(360, "deg"))
Angle(Array(10, dtype=int32, weak_type=True), unit='deg')
The wrap_to() method also has a function counterpart, u.quantity.wrap_to:
>>> u.quantity.wrap_to(a, u.Q(0, "deg"), u.Q(360, "deg"))
Angle(Array(10, dtype=int32, weak_type=True), unit='deg')
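Wrapping into a half-open interval [lower, upper) is usually done with modular arithmetic. The NumPy sketch below shows that formula using a hypothetical wrap_to_range helper. Whether unxt's wrap_to uses exactly this expression, in particular at the upper endpoint, is an assumption.

```python
import numpy as np


def wrap_to_range(x, lower, upper):
    """Wrap x into the half-open interval [lower, upper) (hypothetical helper)."""
    # np.mod takes the sign of the divisor, so negative offsets also land in [0, upper - lower).
    return lower + np.mod(np.asarray(x) - lower, upper - lower)


print(wrap_to_range(370, 0, 360))
print(wrap_to_range(np.array([370, -90, 720]), 0, 360))
print(wrap_to_range(190, -180, 180))
```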
Working with StaticQuantity Objects
The StaticQuantity class is a parametric quantity with a static value stored as a NumPy array. It accepts Python scalars and NumPy arrays only, rejecting JAX arrays. This makes it convenient for static arguments in jax.jit or jax.vmap.
>>> import numpy as np
>>> import jax
>>> import jax.numpy as jnp
>>> from functools import partial
>>> import unxt as u
>>> sq = u.StaticQuantity(np.array([1.0, 2.0]), "m")
>>> jq = u.Q(jnp.array([1.0, 1.0]), "m")
>>> @partial(jax.jit, static_argnames=("sq",))
... def add(jq, sq):
... return jq + u.Q(jnp.asarray(sq.value), sq.unit)
>>> add(jq, sq)
Quantity(Array([2., 3.], dtype=float32), unit='m')
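jax.jit requires static arguments to be hashable, and a bare NumPy array is not hashable, so a static value needs some kind of wrapper. The sketch below shows one way to build such a wrapper, by hashing an array's shape, dtype and raw bytes. HashableArray is hypothetical, and unxt's StaticQuantity and StaticValue may be implemented differently.

```python
import numpy as np


class HashableArray:
    """Hypothetical wrapper that hashes and compares a NumPy array by content."""

    def __init__(self, array):
        self.array = np.asarray(array)

    def _key(self):
        # Shape, dtype and raw bytes together identify the contents.
        return (self.array.shape, self.array.dtype.str, self.array.tobytes())

    def __hash__(self):
        return hash(self._key())

    def __eq__(self, other):
        return isinstance(other, HashableArray) and self._key() == other._key()


try:
    hash(np.array([1.0, 2.0]))
except TypeError as err:
    print("plain ndarray:", err)

a = HashableArray([1.0, 2.0])
b = HashableArray(np.array([1.0, 2.0]))
print(a == b, hash(a) == hash(b))
```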
Working with StaticValue in Quantity
If you want a regular Quantity but need its value to be static (for hashing or static JAX arguments), wrap the value with StaticValue. Arithmetic behaves like the wrapped array, and StaticValue + StaticValue returns a StaticValue:
>>> import numpy as np
>>> import jax.numpy as jnp
>>> import unxt as u
>>> sv = u.quantity.StaticValue(np.array([1.0, 2.0]))
>>> q_static = u.Q(sv, "m")
>>> q = u.Q(jnp.array([3.0, 4.0]), "m")
>>> q_static + q
Quantity(Array([4., 6.], dtype=float32), unit='m')
See also