An introduction to JAX

Mahesh · Apr 04, 2021 · 13 mins read

What is JAX? JAX is NumPy plus a set of features designed to make machine learning research faster. It encourages a functional programming style and offers other valuable features for high-performance machine learning training.

But before we get into all of those details, I want to give an unjaxy introduction to set the stage for those new to ML. We will write a tiny neural network and see where we run into (design) issues, and learn a thing or two about JAX by addressing those issues.

You may want to skip to the last portion of the following section if you are familiar with the basics.

An unjaxy introduction

For now, we can treat JAX just like NumPy.

Let’s take this opportunity to construct a tiny neural network that outputs a probability value.

Matmul

The most basic building block is matrix multiplication. Remember, JAX is like NumPy. So you can do a lot of the stuff you do in NumPy similarly in JAX.

import jax.numpy as jnp

input = jnp.array([[2., 3.], [4., 5.], [6., 7.]])
kernel = jnp.array([[9., 1., 0., 0.], [1., 0., 1., 0.]])

In fact, you can use

import jax.numpy as np

and use JAX like NumPy. It should work for the vast majority of the cases, but let’s stick with importing as jnp for clarity.

Now, you can multiply these two matrices as follows:

jnp.matmul(input, kernel)

or

jnp.dot(input, kernel)

In fact, you can also do the following as well:

input @ kernel

Here is the output:

DeviceArray([[21.,  2.,  3.,  0.],
             [41.,  4.,  5.,  0.],
             [61.,  6.,  7.,  0.]], dtype=float32)

Curiously, the output is now an instance of DeviceArray instead of a list or NumPy/JAX array. We will learn more about DeviceArray later.
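As a quick sanity check, the three spellings can be compared directly. This is only a minimal sketch; the names `x` and `w` are my own stand-ins for `input` and `kernel` above.

```python
import jax.numpy as jnp

x = jnp.array([[2., 3.], [4., 5.], [6., 7.]])          # shape (3, 2)
w = jnp.array([[9., 1., 0., 0.], [1., 0., 1., 0.]])    # shape (2, 4)

a = jnp.matmul(x, w)
b = jnp.dot(x, w)
c = x @ w

# For 2-D arrays all three compute the same matrix product.
same = bool(jnp.array_equal(a, b) and jnp.array_equal(a, c))
print(same, c.shape)
```

For arrays with more than two dimensions, `jnp.dot` and `jnp.matmul` follow different rules (as they do in NumPy), so `@` or `jnp.matmul` is usually the safer default.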

The forward pass

Let’s specify a list of layers and then run the inputs through them. Each layer except the last is followed by a ReLU non-linearity, and the final layer is followed by a sigmoid so that the output is a probability.

import jax

def predict(input, layers):
  for layer_index, kernel in enumerate(layers):
    input = input @ kernel
    if layer_index != len(layers) - 1:
      input = jax.nn.relu(input)
    else:
      input = jax.nn.sigmoid(input)
  return input

kernel1 = jnp.array([[9., 1., 0., 0.], [1., 0., 1., 0.]])
kernel2 = jnp.array([[1.], [-10.], [0.1], [2.]])

layers = [kernel1, kernel2]
input = jnp.array([[2., 3.], [4., 5.], [6., 7.]])

predict(input, layers)

Output:

DeviceArray([[0.785835  ],
             [0.81757444],
             [0.8455348 ]], dtype=float32)
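To see where the first of these numbers comes from, the first example can be worked through by hand and compared with JAX. The sketch below only re-uses the values defined above; `row`, `hidden`, `logit` and `prob` are names I introduce for illustration.

```python
import jax
import jax.numpy as jnp

kernel1 = jnp.array([[9., 1., 0., 0.], [1., 0., 1., 0.]])
kernel2 = jnp.array([[1.], [-10.], [0.1], [2.]])
row = jnp.array([[2., 3.]])             # first example only

hidden = jax.nn.relu(row @ kernel1)     # [[21., 2., 3., 0.]]
logit = hidden @ kernel2                # 21*1 + 2*(-10) + 3*0.1 + 0*2 = 1.3
prob = jax.nn.sigmoid(logit)            # sigmoid(1.3), roughly 0.786
print(hidden, logit, prob)
```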

Loss

We will use the binary cross-entropy loss:

eps = 1e-8
def loss_fn(input, target, layers):
  output = predict(input, layers)
  return -jnp.mean(
      target * jnp.log(output + eps) + (1 - target) * jnp.log(1 - output + eps))

In practice, we may use a numerically stable version of the above loss, which will take as input pre-sigmoid values and then calculate the loss.
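One possible shape of such a numerically stable loss is sketched below. It is not the article's own code: `bce_from_logits` is a hypothetical helper name, and the labels in `target` are made up for illustration. It relies on the identity log(1 - sigmoid(z)) = log(sigmoid(-z)) and on `jax.nn.log_sigmoid`, so no `eps` is needed.

```python
import jax
import jax.numpy as jnp

def bce_from_logits(logits, target):
  # log(sigmoid(z)) and log(1 - sigmoid(z)) = log(sigmoid(-z)),
  # both computed without forming sigmoid(z) explicitly.
  log_p = jax.nn.log_sigmoid(logits)
  log_not_p = jax.nn.log_sigmoid(-logits)
  return -jnp.mean(target * log_p + (1 - target) * log_not_p)

logits = jnp.array([[1.3], [1.5], [1.7]])   # pre-sigmoid values
target = jnp.array([[1.], [0.], [1.]])      # hypothetical labels

stable = bce_from_logits(logits, target)

# Naive version for comparison; the two agree for moderate logits.
p = jax.nn.sigmoid(logits)
naive = -jnp.mean(target * jnp.log(p) + (1 - target) * jnp.log(1 - p))
print(stable, naive)
```

To use it with the network above, `predict` would have to return the pre-sigmoid values instead of applying `jax.nn.sigmoid` in the last layer.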

Backward pass

Ok, so now we have to calculate the gradient of the loss with respect to all the variables, and then update the layer weights based on the gradient. (Our tiny network has only kernels and no biases.)

So how do we calculate the gradient?

In TensorFlow, the way this is done is with a tf.GradientTape.

with tf.GradientTape() as g:
  output = predict(input, ..)
  loss = loss_fn(output, label)
  grads = g.gradient(loss, weights)

and so on.

In JAX, things are a bit different. All you have to do is to use jax.grad() to calculate the gradient. And, while in TF we calculated the gradient on the actual loss computed, in JAX, we will call jax.grad on the loss function itself (jax.grad(loss_fn)).

But if you call jax.grad(loss_fn)(input, target, layers), where target holds the labels for the examples, you will see the following output:

DeviceArray([[-0.0397228 ,  0.04369506],
             [-0.05030257,  0.05533284],
             [-0.0596227 ,  0.06558497]], dtype=float32)

But what we want is to take the derivative with respect to the weights (i.e. layers).

By default, jax.grad takes the derivative with respect to the 0th argument, which in this case is the input. That’s why what you see above has the same shape as the input array. To fix this, we need to pass the argnums argument to jax.grad, giving the position of the argument we want to differentiate with respect to.

jax.grad(loss_fn, argnums=2)(input, target, layers)

Output:

[DeviceArray([[ 0.6383921 , -6.3839216 ,  0.06383922,  0.        ],
              [ 0.78804016, -7.880402  ,  0.07880402,  0.        ]],            dtype=float32),
 DeviceArray([[6.5335693 ],
              [0.6383921 ],
              [0.78804016],
              [0.        ]], dtype=float32)]

We can now see that the shape of the gradients matches the shape of the layers.
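The following self-contained sketch puts the pieces so far together and checks the gradient shapes. The article does not show how `target` was defined, so the labels here are a hypothetical choice of mine. I also use `jax.value_and_grad`, which returns the loss together with the gradient in one call.

```python
import jax
import jax.numpy as jnp

def predict(x, layers):
  for i, kernel in enumerate(layers):
    x = x @ kernel
    x = jax.nn.relu(x) if i != len(layers) - 1 else jax.nn.sigmoid(x)
  return x

def loss_fn(x, target, layers, eps=1e-8):
  out = predict(x, layers)
  return -jnp.mean(
      target * jnp.log(out + eps) + (1 - target) * jnp.log(1 - out + eps))

layers = [jnp.array([[9., 1., 0., 0.], [1., 0., 1., 0.]]),
          jnp.array([[1.], [-10.], [0.1], [2.]])]
x = jnp.array([[2., 3.], [4., 5.], [6., 7.]])
target = jnp.array([[1.], [0.], [1.]])   # hypothetical labels, shape (3, 1)

# argnums=2 differentiates with respect to the third argument, `layers`.
loss, grads = jax.value_and_grad(loss_fn, argnums=2)(x, target, layers)
print(loss, [g.shape for g in grads])
```

Note that the labels are given the same shape as the network output, (3, 1). A target of shape (3,) would silently broadcast against the (3, 1) output and produce a (3, 3) array inside the loss.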

Update step

Putting all of this together, our update step can look like this:

def update_step(input, target, layers):
  grads = jax.grad(loss_fn, argnums=2)(input, target, layers)
  for index in range(len(layers)):
    layers[index] -= 0.1 * grads[index]
  return layers

This is basically gradient descent with a learning rate of 0.1.

If you run

update_step(input, target, layers)

you will get

[DeviceArray([[ 8.9361610e+00,  1.6383922e+00, -6.3839220e-03,
                0.0000000e+00],
              [ 9.2119598e-01,  7.8804022e-01,  9.9211961e-01,
                0.0000000e+00]], dtype=float32), DeviceArray([[  0.34664303],
              [-10.063839  ],
              [  0.02119599],
              [  2.        ]], dtype=float32)]

We have successfully taken one training step of our tiny neural network!
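One thing to be aware of is that the update step above overwrites the entries of the `layers` list that is passed in. A side-effect-free variant, which sits better with the functional style JAX encourages, might look like the sketch below. `update_step_pure`, the `lr` parameter and the labels in `target` are my own assumptions, not part of the original code.

```python
import jax
import jax.numpy as jnp

def predict(x, layers):
  *hidden, last = layers
  for kernel in hidden:
    x = jax.nn.relu(x @ kernel)
  return jax.nn.sigmoid(x @ last)

def loss_fn(x, target, layers, eps=1e-8):
  out = predict(x, layers)
  return -jnp.mean(
      target * jnp.log(out + eps) + (1 - target) * jnp.log(1 - out + eps))

def update_step_pure(x, target, layers, lr=0.1):
  grads = jax.grad(loss_fn, argnums=2)(x, target, layers)
  # Build and return a new list instead of modifying `layers` in place.
  return [w - lr * g for w, g in zip(layers, grads)]

layers = [jnp.array([[9., 1., 0., 0.], [1., 0., 1., 0.]]),
          jnp.array([[1.], [-10.], [0.1], [2.]])]
x = jnp.array([[2., 3.], [4., 5.], [6., 7.]])
target = jnp.array([[1.], [0.], [1.]])   # hypothetical labels

new_layers = update_step_pure(x, target, layers)
print(layers[1][0, 0], new_layers[1][0, 0])   # the original list is unchanged
```

A training loop would then simply rebind the name on each iteration, e.g. `layers = update_step_pure(x, target, layers)`.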

Reflecting back

Let’s stop here with this toy example and reflect on a few things:

  1. From a quick glance at the predict function, it is hard to tell what the network architecture is. If you want to mix in other kinds of layers, like convolutional layers, the code will become completely unreadable. If you have used Keras, you will know that there are better ways of expressing models, and the above is not one of them. So we need a better way of expressing models. As with Keras, there are high-level libraries that make the code more readable. We will look at them later.

  2. The jax.grad example above demonstrates a key feature of JAX — the functional programming paradigm. You can also use composition to make nested functions, e.g. jax.grad(jax.grad(loss_fn)) will create a function that takes the second derivative of loss_fn.

  3. This brings us to the other philosophy that JAX embodies: code should read like a mathematical equation. In TensorFlow, we saw that the gradient was computed from the output of the function. In JAX, the gradient is taken of the function itself, so jax.grad(loss_fn) reads like d(loss_fn)/dw, which makes a lot of sense.
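The composition mentioned in point 2 is easiest to see on a scalar function whose derivatives are known. This is a minimal sketch with a made-up function `f`:

```python
import jax

def f(x):
  return x ** 3

df = jax.grad(f)               # derivative: 3 * x**2
d2f = jax.grad(jax.grad(f))    # second derivative: 6 * x

print(df(3.0), d2f(3.0))       # 27.0 and 18.0
```

Note that jax.grad expects a floating-point input here; passing the integer `3` would raise an error.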

An actual introduction to JAX

We can now start to describe JAX in more detail, and you should be able to see how JAX addresses the problems we have seen.

As we said before, JAX is basically like NumPy, but more. In a nutshell,

JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.

Let’s dive in:

  1. JAX stands for JAX is Autograd and XLA.

  2. As the name indicates, a central feature of JAX is Autograd, which is used for automatic differentiation. What is cool is that, with JAX, you can differentiate native Python and NumPy code. You can compute higher-order derivatives by calling the jax.grad function repeatedly.

  3. JAX uses XLA to compile and run your code, which means the same code can run on CPUs, GPUs, and TPUs. In fact, since XLA was originally built for TPUs, you will get very good performance on TPUs when using JAX. Even on CPUs, JAX can be much faster than NumPy thanks to XLA's optimizations. For example, if you express an expensive computation over a large array but then slice out only a portion of the result to return, XLA can notice this and run the expensive computation only for the required portion. The portability matters in practice: in my experience of getting code to work on TPUs with TensorFlow, quite a bit of work goes into adapting code that was written for CPUs.

  4. Just-in-time compilation (JIT). If you profile JAX's matmul and compare it against NumPy's, you will see a speedup, thanks to XLA. But the code above still underutilizes XLA: each operation is compiled and dispatched on its own. You can also compile whole blocks of code with XLA for a further speedup. This is done using jax.jit, which you can call as a function or use as a decorator.

> %timeit -n 100 update_step(input, target, layers)
100 loops, best of 5: 10.3 ms per loop
> update_step_jit = jax.jit(update_step)
> %timeit -n 100 update_step_jit(input, target, layers)
The slowest run took 256.52 times longer than the fastest. This could mean that an intermediate result is being cached 
100 loops, best of 5: 4.39 µs per loop

  5. JAX embraces the functional programming paradigm. In particular, when we have functions without side effects, i.e. functions that don’t modify some global state, XLA can do a good job of optimizing these functions.
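Both ways of applying jax.jit from point 4 are sketched below on a made-up function (`scaled_sum` is not from the article). JAX dispatches work asynchronously, so when timing jitted code it is worth calling `block_until_ready()` on the result. Otherwise the measurement can reflect mostly the dispatch and not the computation.

```python
import jax
import jax.numpy as jnp

def scaled_sum(x):
  # Several element-wise ops that XLA can fuse once the function is jitted.
  return jnp.sum(jnp.tanh(x) * 2.0 + 1.0)

scaled_sum_jit = jax.jit(scaled_sum)    # used as a function

@jax.jit                                # used as a decorator
def scaled_sum_decorated(x):
  return jnp.sum(jnp.tanh(x) * 2.0 + 1.0)

x = jnp.arange(1000.0)
a = scaled_sum(x)
b = scaled_sum_jit(x).block_until_ready()        # first call compiles, later calls reuse it
c = scaled_sum_decorated(x).block_until_ready()
print(a, b, c)
```

The first call to a jitted function includes compilation time, which is one reason a timing run can report that the slowest run took far longer than the fastest.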

Intermediate JAX

DeviceArray

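When you encounter an array stored as DeviceArray, it means that the array lives on the device (e.g. in TPU memory). Its values are copied back to the host only when you ask for them, for example by printing the array or converting it to a NumPy array, which avoids expensive back-and-forth transfers. (In newer JAX releases, the same kind of array is called jax.Array.)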
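A small sketch of the device/host distinction, assuming nothing beyond a default JAX install (the variable names are mine):

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.ones((2, 3)) * 2.0     # lives on the default device
print(jax.devices())           # e.g. a CPU device when no accelerator is present

y = x @ x.T                    # computed on the device, no host copy needed
host = np.asarray(y)           # explicit copy back to host memory as a NumPy array
print(type(host), host)
```

Chaining operations on device arrays and converting to NumPy only at the end keeps transfers to a minimum.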

PRNGKey

For best performance, it is preferred that functions are free of side effects and don’t store state or access global variables. This means that the usual way of seeding the random number generator wouldn’t work (we typically call one function to set a global seed that many other functions then use). So in JAX we create a key from the seed and pass it explicitly to every function that needs randomness.

jax.random.PRNGKey(0)
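A short sketch of what passing the key explicitly looks like in practice. `jax.random.split` and `jax.random.normal` are standard JAX functions; the variable names are mine.

```python
import jax

key = jax.random.PRNGKey(0)

# The same key always yields the same numbers.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# To get fresh, independent numbers, split the key and use the subkeys.
key, subkey1, subkey2 = jax.random.split(key, 3)
c = jax.random.normal(subkey1, (3,))
d = jax.random.normal(subkey2, (3,))
print(a, b, c, d)
```

The usual convention is never to reuse a key for two different random draws, and to split it whenever new randomness is needed.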

jax.vmap

jax.vmap is like Python's map, but vectorized. You can express your computation for a single example, and then use vmap to run the computation for multiple examples at a time (i.e. add a batch dimension). This can help with performance as well.
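As a minimal illustration, the function below is written for a single example and then batched with jax.vmap. The weights `w` and the function `score` are made up for this sketch.

```python
import jax
import jax.numpy as jnp

w = jnp.array([1., -2.])

def score(x):
  # Written for a single example of shape (2,).
  return jnp.dot(w, x)

batch = jnp.array([[2., 3.], [4., 5.], [6., 7.]])   # shape (3, 2)

batched_score = jax.vmap(score)    # maps over the leading axis by default
out = batched_score(batch)
print(out)                         # [-4. -6. -8.]
```

The `in_axes` argument of jax.vmap controls which axis of each argument is treated as the batch axis, if it is not the leading one.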

FLAX: A high-level library for neural networks

As we saw above, writing matmuls by hand to define neural networks gets old very quickly. There are in fact several high-level libraries to improve the ergonomics, and the two most popular ones are FLAX and Haiku. We will look at FLAX here.

Here is an example of how we can define the same neural network we had before.

import flax.linen as nn

class Net(nn.Module):
  
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(4)(x)
    x = nn.relu(x)
    x = nn.Dense(1)(x)
    return nn.sigmoid(x)

As you can see, this is a lot more readable than what we had. While this looks like an object, which has state, what actually happens is that there are other functions in nn.Module that will convert this class into pure functions.

And since we don’t want to store any state in the function, we need to explicitly get the variables of our neural network by calling init.

vars = Net().init(jax.random.PRNGKey(0), input)

You can see here that we have to pass the random key as we noted before. If you inspect the vars, they will look like this:

FrozenDict({
    params: {
        Dense_0: {
            kernel: DeviceArray([[ 1.1381536 , -1.0838526 ,  0.37998098,  0.15393464],
                         [ 0.17555283, -0.3848625 ,  0.52419275, -1.4104135 ]],            dtype=float32),
            bias: DeviceArray([0., 0., 0., 0.], dtype=float32),
        },
        Dense_1: {
            kernel: DeviceArray([[ 0.22024584],
                         [ 0.5676514 ],
                         [ 0.4185372 ],
                         [-0.3969197 ]], dtype=float32),
            bias: DeviceArray([0.], dtype=float32),
        },
    },
})

It also has the bias terms, which we didn’t include before. One quick thing to note is that the shape of kernel is (2, 4), and the shape of the bias is (4,). If the input is of shape (B, 2) [B was set to 3 in our example], then input times kernel should be of shape (B, 4). So how can we add a vector of shape (4,) to that? The answer is broadcasting. This array will get added to each of the B rows, which is exactly what we want.
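The broadcasting step can be checked in isolation. In this sketch the kernel and bias values are arbitrary; only the shapes matter.

```python
import jax.numpy as jnp

x = jnp.array([[2., 3.], [4., 5.], [6., 7.]])   # (B, 2) with B = 3
kernel = jnp.ones((2, 4))                       # (2, 4)
bias = jnp.array([0.1, 0.2, 0.3, 0.4])          # (4,)

out = x @ kernel + bias    # (B, 4) + (4,) broadcasts to (B, 4)
print(out.shape)
print(out[0])              # first row: 5 + each bias entry
```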

The other thing to note is that the input’s batch size doesn’t matter.

So, the following, where the input shape is (1, 2), will produce the same result as well:

vars = Net().init(jax.random.PRNGKey(0), jnp.array([[1., 2.]]))
vars
FrozenDict({
    params: {
        Dense_0: {
            kernel: DeviceArray([[ 1.1381536 , -1.0838526 ,  0.37998098,  0.15393464],
                         [ 0.17555283, -0.3848625 ,  0.52419275, -1.4104135 ]],            dtype=float32),
            bias: DeviceArray([0., 0., 0., 0.], dtype=float32),
        },
        Dense_1: {
            kernel: DeviceArray([[ 0.22024584],
                         [ 0.5676514 ],
                         [ 0.4185372 ],
                         [-0.3969197 ]], dtype=float32),
            bias: DeviceArray([0.], dtype=float32),
        },
    },
})

And in fact, if we use an input of shape (2,), we will still see the same output

vars = Net().init(jax.random.PRNGKey(0), jnp.array([1., 2.]))
vars
FrozenDict({
    params: {
        Dense_0: {
            kernel: DeviceArray([[ 1.1381536 , -1.0838526 ,  0.37998098,  0.15393464],
                         [ 0.17555283, -0.3848625 ,  0.52419275, -1.4104135 ]],            dtype=float32),
            bias: DeviceArray([0., 0., 0., 0.], dtype=float32),
        },
        Dense_1: {
            kernel: DeviceArray([[ 0.22024584],
                         [ 0.5676514 ],
                         [ 0.4185372 ],
                         [-0.3969197 ]], dtype=float32),
            bias: DeviceArray([0.], dtype=float32),
        },
    },
})

To get the prediction, we can just call apply (the output below is for an input batch of ten examples):

> Net().apply(vars, input)

DeviceArray([[0.58010364],
             [0.4232422 ],
             [0.24422583],
             [0.60353917],
             [0.505488  ],
             [0.572509  ],
             [0.67478746],
             [0.72071725],
             [0.53566575],
             [0.44790095]], dtype=float32)

You may be wondering why we are instantiating Net again here. This is merely for convenience: since the object is stateless, it is fine to create a new one.

In fact, this is a great opportunity to show that jax.vmap can be used as a way to add batch dimension.

> jax.vmap(lambda x: Net().apply(vars, x))(input)

DeviceArray([[0.58010364],
             [0.4232422 ],
             [0.24422583],
             [0.60353917],
             [0.505488  ],
             [0.572509  ],
             [0.67478746],
             [0.72071725],
             [0.53566575],
             [0.44790095]], dtype=float32)

Conclusion

So, there you go. You have learnt the basics of JAX. There is a lot more that we didn’t cover, like using optax for defining optimizers and loss functions, or how to create realistic training loops when using FLAX. Let’s reserve that for another day, but I hope this helps you get started!
