What is JAX? JAX is NumPy with extras: a set of features designed to make machine learning research faster. It introduces a functional programming paradigm and has other valuable features for high-performance machine learning training.
But before we get into all of those details, I want to give an unjaxy introduction to set the stage for those new to ML. We will write a tiny neural network and see where we run into (design) issues, and learn a thing or two about JAX by addressing those issues.
You may want to skip to the last portion of the following section if you are familiar with the basics.
An unjaxy introduction
For now, we can treat JAX just like NumPy.
Let’s take this opportunity to construct a tiny neural network that outputs a probability value.
Matmul
The most basic building block is matrix multiplication. Remember, JAX is like NumPy. So you can do a lot of the stuff you do in NumPy similarly in JAX.
import jax.numpy as jnp
input = jnp.array([[2., 3.], [4., 5.], [6., 7]])
kernel = jnp.array([[9., 1., 0., 0.], [1., 0., 1., 0]])
In fact, you can use
import jax.numpy as np
and use JAX like NumPy. It should work for the vast majority of cases, but let’s stick with importing as jnp for clarity.
Now, you can multiply these two matrices as follows:
jnp.matmul(input, kernel)
or
jnp.dot(input, kernel)
You can also use the @ operator:
input @ kernel
Here is the output:
DeviceArray([[21.,  2.,  3.,  0.],
             [41.,  4.,  5.,  0.],
             [61.,  6.,  7.,  0.]], dtype=float32)
Curiously, the output is now an instance of DeviceArray instead of a list or NumPy array. We will learn more about DeviceArray later.
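As a quick sanity check, here is a minimal sketch showing that the three spellings above give the same result, and that the result agrees with plain NumPy on the same values. The variable names (x, via_matmul and so on) are only for illustration.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.array([[2., 3.], [4., 5.], [6., 7.]])
kernel = jnp.array([[9., 1., 0., 0.], [1., 0., 1., 0.]])

# Three equivalent ways to multiply a (3, 2) matrix by a (2, 4) matrix.
via_matmul = jnp.matmul(x, kernel)
via_dot = jnp.dot(x, kernel)
via_operator = x @ kernel

# Reference result computed with plain NumPy on the same values.
reference = np.asarray(x) @ np.asarray(kernel)
print(via_operator.shape)
```

The result has shape (3, 4): one row per input example and one column per kernel column.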
The forward pass
Let’s specify a list of layers and then run the inputs through them in order. Every layer except the last is followed by a ReLU nonlinearity, and the last layer is followed by a sigmoid.
import jax
import jax.numpy as jnp

def predict(input, layers):
    for layer_index, kernel in enumerate(layers):
        input = input @ kernel
        if layer_index != len(layers) - 1:
            # Hidden layers use ReLU.
            input = jax.nn.relu(input)
        else:
            # The final layer outputs a probability.
            input = jax.nn.sigmoid(input)
    return input

kernel1 = jnp.array([[9., 1., 0., 0.], [1., 0., 1., 0.]])
kernel2 = jnp.array([[1.], [-10.], [0.1], [2.]])
layers = [kernel1, kernel2]
input = jnp.array([[2., 3.], [4., 5.], [6., 7.]])
predict(input, layers)
Output:
DeviceArray([[0.785835 ],
[0.81757444],
[0.8455348 ]], dtype=float32)
Loss
We will use the binary cross-entropy loss:
eps = 1e-8  # avoids log(0)
def loss_fn(input, target, layers):
    output = predict(input, layers)
    return -jnp.mean(
        target * jnp.log(output + eps) + (1 - target) * jnp.log(1 - output + eps))
In practice, we may use a numerically stable version of the above loss, which takes the pre-sigmoid values as input and computes the loss from them directly.
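To make the "numerically stable version" a bit more concrete, here is one possible sketch. It assumes we have the pre-sigmoid values (logits) at hand and uses jax.nn.log_sigmoid, so no eps is needed. The function name bce_from_logits and the example logits and labels are made up for illustration; this is not the only way to write such a loss.

```python
import jax
import jax.numpy as jnp

def bce_from_logits(logits, target):
    # log(sigmoid(z)) and log(1 - sigmoid(z)) = log(sigmoid(-z)) are computed
    # directly from the logits, so no eps is required.
    log_p = jax.nn.log_sigmoid(logits)
    log_not_p = jax.nn.log_sigmoid(-logits)
    return -jnp.mean(target * log_p + (1 - target) * log_not_p)

logits = jnp.array([[1.3], [1.5], [1.7]])  # example pre-sigmoid values
target = jnp.array([[1.], [0.], [1.]])     # hypothetical labels
stable = bce_from_logits(logits, target)

# For moderate logits this agrees with the eps-based formula.
probs = jax.nn.sigmoid(logits)
naive = -jnp.mean(
    target * jnp.log(probs + 1e-8) + (1 - target) * jnp.log(1 - probs + 1e-8))

# A very confident wrong prediction still gives a finite loss.
extreme = bce_from_logits(jnp.array([[100.]]), jnp.array([[0.]]))
```

The two versions match closely for ordinary inputs. The logits-based one stays finite even when the sigmoid would saturate to exactly 0 or 1 in float32.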
Backward pass
OK, so now we have to calculate the gradient of the loss with respect to all the variables, and then update the layer weights based on the gradient.
So how do we calculate the gradient?
In TensorFlow, the way this is done is with a tf.GradientTape.
with tf.GradientTape() as g:
    output = predict(input, ..)
    loss = loss_fn(output, label)
grads = g.gradient(loss, weights)
and so on.
In JAX, things are a bit different. All you have to do is use jax.grad() to calculate the gradient.
And, while in TF we calculated the gradient on the actual loss value computed, in JAX we call jax.grad on the loss function itself (jax.grad(loss_fn)).
But if you call jax.grad(loss_fn)(input, target, layers), you will see the following output:
DeviceArray([[0.0397228 , 0.04369506],
[0.05030257, 0.05533284],
[0.0596227 , 0.06558497]], dtype=float32)
But what we want is to take the derivative with respect to the weights (i.e. layers).
By default, jax.grad takes the derivative with respect to the 0th argument, which in this case is the input. That’s why what you see above has the same shape as the input array. To fix this, we need to pass the argnums argument to jax.grad.
jax.grad(loss_fn, argnums=2)(input, target, layers)
Output:
[DeviceArray([[ 0.6383921 , -6.3839216 ,  0.06383922,  0.        ],
              [ 0.78804016, -7.880402  ,  0.07880402,  0.        ]], dtype=float32),
 DeviceArray([[6.5335693 ],
              [0.6383921 ],
              [0.78804016],
              [0.        ]], dtype=float32)]
We can now see that the shape of the gradients matches the shape of the layers.
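Here is a self-contained sketch of this step. It uses jax.value_and_grad, which returns the loss value together with the gradients, and passes argnums=2 so that the derivative is taken with respect to the layers. The labels in target are never shown in the text, so the values here are an assumption, chosen to be consistent with the gradients printed above. The -10. entry in the second kernel is likewise assumed from those outputs.

```python
import jax
import jax.numpy as jnp

eps = 1e-8

def predict(x, layers):
    for i, kernel in enumerate(layers):
        x = x @ kernel
        # ReLU on hidden layers, sigmoid on the last one.
        x = jax.nn.relu(x) if i != len(layers) - 1 else jax.nn.sigmoid(x)
    return x

def loss_fn(x, target, layers):
    output = predict(x, layers)
    return -jnp.mean(
        target * jnp.log(output + eps) + (1 - target) * jnp.log(1 - output + eps))

layers = [jnp.array([[9., 1., 0., 0.], [1., 0., 1., 0.]]),
          jnp.array([[1.], [-10.], [0.1], [2.]])]   # -10. is an assumption
x = jnp.array([[2., 3.], [4., 5.], [6., 7.]])
target = jnp.array([[1.], [0.], [1.]])              # assumed labels

# argnums=2 selects `layers`; the gradients come back as a list with the
# same structure as `layers`.
loss, grads = jax.value_and_grad(loss_fn, argnums=2)(x, target, layers)
grad_shapes = [g.shape for g in grads]
print(grad_shapes)
```

Because layers is a plain Python list of arrays, the gradients are returned as a list of arrays with matching shapes.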
Update step
Putting all of this together, our update step can look like this:
def update_step(input, target, layers):
    grads = jax.grad(loss_fn, argnums=2)(input, target, layers)
    for index in range(len(layers)):
        # Gradient descent: move each kernel against its gradient.
        layers[index] = layers[index] - 0.1 * grads[index]
    return layers
This is basically SGD with a learning rate of 0.1.
If you run
update_step(input, target, layers)
you will get
[DeviceArray([[ 8.9361610e+00,  1.6383922e+00, -6.3839220e-03,  0.0000000e+00],
              [ 9.2119598e-01,  7.8804022e-01,  9.9211961e-01,  0.0000000e+00]], dtype=float32),
 DeviceArray([[  0.34664303],
              [-10.063839  ],
              [  0.02119599],
              [  2.        ]], dtype=float32)]
We have successfully taken one training step of our tiny neural network!
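The update step above changes the list of layers in place. A more functional alternative, sketched here under the assumption that the parameters are a list (or any nested structure) of arrays, is to build a new structure with jax.tree_util.tree_map. The name sgd_update and the toy values are hypothetical.

```python
import jax
import jax.numpy as jnp

def sgd_update(params, grads, learning_rate=0.1):
    # Build a new list of arrays instead of mutating `params` in place.
    return jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads)

params = [jnp.ones((2, 4)), jnp.ones((4, 1))]
grads = [jnp.full((2, 4), 0.5), jnp.full((4, 1), 2.0)]
new_params = sgd_update(params, grads)
```

After the call, params still holds the old values and new_params holds the updated ones, which fits better with the side-effect-free style that JAX prefers.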
Reflecting back
Let’s stop here with this toy example and reflect on a few things:

From a quick glance at the predict method, it is hard to tell what the network architecture is. If you want to mix in other kinds of layers, like convolutional layers, the code will become completely unreadable. If you have used Keras, you will know that there are better ways of expressing models, and the above is not one of them. So we need a better way of expressing models. As with Keras, there are many high-level libraries that make the code readable. We will look at them later.
The jax.grad example above demonstrates a key feature of JAX: the functional programming paradigm. You can also compose transformations, e.g. jax.grad(jax.grad(loss_fn)) will create a function that computes the second derivative of loss_fn.
This brings us to the other philosophy that JAX embodies: code should read like a mathematical equation. In TensorFlow, we saw that the gradient was taken on the output of the function. In JAX, taking the gradient of loss_fn itself mirrors how we write d(loss_fn)/dw, which makes a lot of sense.
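To see the composition idea on something small, here is a sketch with a scalar function instead of loss_fn (jax.grad needs a scalar output at each level, so a simple cubic keeps things easy to check by hand). The function cube is just an illustrative choice.

```python
import jax

def cube(x):
    return x ** 3

first = jax.grad(cube)             # derivative: 3 * x**2
second = jax.grad(jax.grad(cube))  # second derivative: 6 * x

d1 = first(3.0)
d2 = second(3.0)
print(d1, d2)
```

At x = 3 the first derivative is 27 and the second derivative is 18, matching the formulas in the comments.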
An actual introduction to JAX
We can now start to describe JAX in more detail, and you should be able to see how JAX addresses the problems we have seen.
As we said before, JAX is basically like NumPy, but more. In a nutshell,
JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.
Let’s dive in:

JAX stands for JAX is Autograd and XLA.

As the name indicates, a central feature of JAX is Autograd, which is used for automatic differentiation. What is cool is that, with JAX, you can differentiate native Python and NumPy code. You can compute higher-order derivatives by calling the jax.grad function repeatedly.
JAX uses XLA to compile and run your code, which means the code can run on GPUs and TPUs as well. In fact, since XLA was originally built for TPUs, you will get very good performance on TPUs when using JAX. Even on CPUs, JAX can be much faster than NumPy, because XLA can apply all sorts of optimizations. For example, if you express an expensive computation over a large array but then slice out only a portion of the array to return, XLA can notice this and run the expensive computation only for the required portion. Another byproduct of using XLA underneath is that the same code can run on CPU, GPU, and TPU. As someone who has spent a fair amount of time getting things to work on TPUs with TensorFlow, I can say that porting code written for CPUs to TPUs usually involves quite a bit of work.

Just-in-time (JIT) compilation. If you profile JAX’s matmul and compare it against NumPy’s matmul, you will see a speedup, thanks to XLA. But the code above still underutilizes XLA: each operation is compiled and dispatched on its own. You can also compile whole blocks of code with XLA for a further speedup. This is done using jax.jit, which you can call as a function or use as a decorator.
> %timeit -n 100 update_step(input, target, layers)
100 loops, best of 5: 10.3 ms per loop
> update_step_jit = jax.jit(update_step)
> %timeit -n 100 update_step_jit(input, target, layers)
The slowest run took 256.52 times longer than the fastest. This could mean that an intermediate result is being cached
100 loops, best of 5: 4.39 µs per loop
JAX embraces the functional programming paradigm. In particular, when we have functions without side effects, i.e. functions that don’t modify some global state, XLA can do a good job of optimizing them.
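Here is a small sketch of the two ways to use jax.jit mentioned above: calling it on an existing function and using it as a decorator. The function scaled_sum is a made-up example. The first call of a jitted function includes tracing and compilation, which is likely what the "slowest run" warning in the timing output reflects. Since JAX dispatches work asynchronously, calling block_until_ready() on the result is a common way to make timings meaningful.

```python
import jax
import jax.numpy as jnp

def scaled_sum(x, w):
    return jnp.sum(jax.nn.relu(x @ w))

# Form 1: wrap an existing function.
scaled_sum_jit = jax.jit(scaled_sum)

# Form 2: use jax.jit as a decorator.
@jax.jit
def scaled_sum_decorated(x, w):
    return jnp.sum(jax.nn.relu(x @ w))

x = jnp.ones((3, 2))
w = jnp.ones((2, 4))

eager = scaled_sum(x, w)
# block_until_ready() waits for the device computation to finish.
compiled = scaled_sum_jit(x, w).block_until_ready()
decorated = scaled_sum_decorated(x, w)
```

All three calls compute the same value. Only how the computation is compiled and dispatched differs.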
Intermediate JAX
DeviceArray
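When you encounter an array stored as a DeviceArray, it means that the array lives on the device (e.g. in TPU memory). By default, its value is not copied back to the host unless you ask for it, for example by printing it. This avoids expensive back-and-forth transfers. Recent JAX releases have replaced DeviceArray with a unified jax.Array type, so the printed output there starts with Array instead.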
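If you do want the values on the host, you can ask for them explicitly. The sketch below shows two ways I am aware of, np.asarray and jax.device_get. The variable names are for illustration only.

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.arange(6.0).reshape(2, 3)  # created on the default device
y = x * 2                          # computed on the device

# Explicitly bring the result back to the host as a NumPy array.
host_copy = np.asarray(y)
also_host = jax.device_get(y)
```

Both results are ordinary NumPy arrays, so they can be passed to any library that expects NumPy input.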
PRNGKey
For best performance, functions should be free of side effects and should not store state or access global variables. This means that the usual way of seeding a random number generator (call a function once to set a global seed that many other functions then use) does not work. So in JAX we have to pass the random key explicitly.
jax.random.PRNGKey(0)
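A short sketch of how a key is typically used: instead of reusing one key everywhere, we split it with jax.random.split and hand a fresh subkey to each random call. The variable names here are only for illustration.

```python
import jax

key = jax.random.PRNGKey(0)

# Split off a subkey for one random call and keep `key` for later splits.
key, subkey = jax.random.split(key)
a = jax.random.normal(subkey, (3,))
b = jax.random.normal(subkey, (3,))   # same subkey, so the same numbers

# A fresh split gives different numbers.
key, subkey2 = jax.random.split(key)
c = jax.random.normal(subkey2, (3,))
```

Because the key is an explicit argument, the same key always produces the same values, which keeps random functions free of hidden state.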
jax.vmap
jax.vmap is like Python’s map, but adds vectorization.
You can express your computation for a single example, and then use vmap to run the computation for multiple examples at a time (aka adding batch dimension).
This can help with performance as well.
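As a minimal sketch, here is a function written for a single example of shape (2,), which jax.vmap then applies across a batch. The kernel reuses the values from the toy network earlier, and single_example is a hypothetical name.

```python
import jax
import jax.numpy as jnp

kernel = jnp.array([[9., 1., 0., 0.], [1., 0., 1., 0.]])

def single_example(x):
    # x has shape (2,), with no batch dimension.
    return jax.nn.relu(x @ kernel)

batch = jnp.array([[2., 3.], [4., 5.], [6., 7.]])

# vmap maps over the leading axis of `batch` and stacks the results.
batched = jax.vmap(single_example)(batch)
print(batched.shape)
```

The result has shape (3, 4), the same as applying the matmul directly to the whole batch.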
FLAX: A high-level library for neural networks
As we saw above, writing matmuls by hand to define neural networks gets old very quickly. There are in fact several high-level libraries that improve the ergonomics, and the two most popular ones are FLAX and Haiku. We will look at FLAX here.
Here is an example of how we can define the same neural network we had before.
import flax.linen as nn

class Net(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(4)(x)   # hidden layer with 4 units
        x = nn.relu(x)
        x = nn.Dense(1)(x)   # single output unit
        return nn.sigmoid(x)
As you can see, this is a lot more readable than what we had. While this looks like an object, which has state, what actually happens is that there are other functions in nn.Module that will convert this class into pure functions.
And since we don’t want to store any state in the function, we need to explicitly get the variables of our neural network by calling init.
vars = Net().init(jax.random.PRNGKey(0), input)
You can see here that we have to pass the random key as we noted before. If you inspect the vars, they will look like this:
FrozenDict({
params: {
Dense_0: {
kernel: DeviceArray([[ 1.1381536 , 1.0838526 , 0.37998098, 0.15393464],
[ 0.17555283, 0.3848625 , 0.52419275, 1.4104135 ]], dtype=float32),
bias: DeviceArray([0., 0., 0., 0.], dtype=float32),
},
Dense_1: {
kernel: DeviceArray([[ 0.22024584],
[ 0.5676514 ],
[ 0.4185372 ],
[0.3969197 ]], dtype=float32),
bias: DeviceArray([0.], dtype=float32),
},
},
})
It also has the bias terms, which we didn’t include before. One quick thing to note is that the shape of kernel is (2, 4), and the shape of the bias is (4,). If the input is of shape (B, 2) [B was set to 3 in our example], then input times kernel should be of shape (B, 4). So how can we add a vector of shape (4,) to that? The answer is broadcasting. This array will get added to each of the B rows, which is exactly what we want.
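Here is a tiny sketch of that broadcasting step with made-up values (ones for the input and kernel, and a bias of 0, 1, 2, 3) so the effect is easy to see.

```python
import jax.numpy as jnp

x = jnp.ones((3, 2))                  # (B, 2) with B = 3
kernel = jnp.ones((2, 4))             # (2, 4)
bias = jnp.array([0., 1., 2., 3.])    # (4,)

# (B, 4) + (4,) broadcasts the bias across all B rows.
out = x @ kernel + bias
print(out.shape)
```

Every row of out equals the matmul result for that row plus the same bias vector.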
The other thing to note is that the input’s batch size doesn’t matter.
So, the following, where the input shape is (1, 2), will produce the same result as well:
vars = Net().init(jax.random.PRNGKey(0), jnp.array([[1., 2.]]))
vars
FrozenDict({
params: {
Dense_0: {
kernel: DeviceArray([[ 1.1381536 , 1.0838526 , 0.37998098, 0.15393464],
[ 0.17555283, 0.3848625 , 0.52419275, 1.4104135 ]], dtype=float32),
bias: DeviceArray([0., 0., 0., 0.], dtype=float32),
},
Dense_1: {
kernel: DeviceArray([[ 0.22024584],
[ 0.5676514 ],
[ 0.4185372 ],
[0.3969197 ]], dtype=float32),
bias: DeviceArray([0.], dtype=float32),
},
},
})
And in fact, if we use an input of shape (2,), we will still see the same output:
vars = Net().init(jax.random.PRNGKey(0), jnp.array([1., 2.]))
vars
FrozenDict({
params: {
Dense_0: {
kernel: DeviceArray([[ 1.1381536 , 1.0838526 , 0.37998098, 0.15393464],
[ 0.17555283, 0.3848625 , 0.52419275, 1.4104135 ]], dtype=float32),
bias: DeviceArray([0., 0., 0., 0.], dtype=float32),
},
Dense_1: {
kernel: DeviceArray([[ 0.22024584],
[ 0.5676514 ],
[ 0.4185372 ],
[0.3969197 ]], dtype=float32),
bias: DeviceArray([0.], dtype=float32),
},
},
})
To get the prediction, we can just call:
> Net().apply(vars, input)
DeviceArray([[0.58010364],
[0.4232422 ],
[0.24422583],
[0.60353917],
[0.505488 ],
[0.572509 ],
[0.67478746],
[0.72071725],
[0.53566575],
[0.44790095]], dtype=float32)
You may be wondering why we are creating a new Net object here. This is merely for convenience: since the object is stateless, it is fine to create a new one.
In fact, this is a great opportunity to show that jax.vmap can be used to add a batch dimension.
> jax.vmap(lambda x: Net().apply(vars, x))(input)
DeviceArray([[0.58010364],
[0.4232422 ],
[0.24422583],
[0.60353917],
[0.505488 ],
[0.572509 ],
[0.67478746],
[0.72071725],
[0.53566575],
[0.44790095]], dtype=float32)
Conclusion
So, there you go. You have learnt the basics of JAX. There is a lot more that we didn’t cover, like using optax for defining optimizers and loss functions, or how to create realistic training loops when using FLAX.
Let’s reserve that for another day, but I hope this helps you get started!