Compare Deep Learning Toolkits: Theano, TensorFlow, TensorFlow 2.0, PyTorch, and JAX
My recent work on PyTorch Distributed and TorchRec requires me to learn PyTorch 2.0. At the same time, I am learning JAX and XLA from the Alpa authors in my spare time. Looking back in 2022 from these technologies at older generations, it seems that the various deep learning toolkits have all been trying to address two critical challenges:
- functional transformations, including autograd and parallelizations such as vmap, pmap, and pjit, and
- heterogeneous computing, where the CPU takes care of control flow while the GPU/TPU takes care of tensor computation and collective communication.
All examples in this document are executable in Colab.
Functional Transformation
I use the term “functional transformation” here to mean changing one procedure into another. The most common example is autograd, which takes the forward procedure written by users and creates the backward procedure, which is usually too complex for users to write. Functional transformation raises the question of how to represent the input and output procedures so that it is easy to write the functional transformation algorithm.
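To make the idea concrete, here is a toy sketch of my own (no toolkit involved): a functional transformation is just a higher-order function, one Python function goes in and another comes out. The transformation here is numerical differentiation rather than autograd.
# A toy functional transformation: given f, return a function that
# approximates df/dx with finite differences. Real toolkits derive the
# backward procedure symbolically or by tracing, but the shape of the
# API is the same: a function goes in, a function comes out.
def numerical_grad(f, eps=1e-6):
    def df(x):
        return (f(x + eps) - f(x - eps)) / (2 * eps)
    return df

square = lambda x: x * x
d_square = numerical_grad(square)  # a new function, the "derivative" of square
print(d_square(3.0))               # approximately 6.0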
Theano: Explicitly Build the IR
Theano, now continued as the Aesara project, was one of the first deep learning toolkits. Its API lets users build the IR as a data structure in memory. We can then ask Theano to run autograd over the IR and turn the resulting IR into a Python function.
import aesara
from aesara import tensor as at

a = at.dscalar("a")  # Define placeholders, which have no values.
b = at.dscalar("b")

c = a * b               # c now contains the IR of an expression.
dc = aesara.grad(c, a)  # Convert the IR in c into another one, dc.

f_dc = aesara.function([a, b], dc)  # Convert the IR into a Python function,
assert f_dc(1.5, 2.5) == 2.5        # so we can call it.
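Because the IR is an ordinary in-memory data structure, we can also inspect it. The two lines below are an addition of mine, continuing the snippet above, and assume Aesara keeps Theano's debugprint helper.
# Inspect the in-memory IR of the gradient expression (assumes
# aesara.printing.debugprint, the counterpart of Theano's debugprint).
from aesara.printing import debugprint
debugprint(dc)  # prints the graph of operations that computes dc/da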
TensorFlow 1.x: A VM to Run the IR
TensorFlow 1.x keeps the idea of building the IR explicitly, and the above example looks almost the same in TensorFlow 1.x. The main difference is that we do not turn the backward IR into a Python function and run it with the Python interpreter. Instead, we send the IR to the TensorFlow runtime service, which executes it.
import tensorflow.compat.v1 as tf # TensorFlow 1.x API
import numpy as np
tf.disable_eager_execution()

a = tf.placeholder(tf.float32, shape=())
b = tf.placeholder(tf.float32, shape=())
c = a * b
dc = tf.gradients(c, [a], stop_gradients=[a, b])

with tf.compat.v1.Session() as sess:  # TensorFlow has a runtime to execute the IR,
    x = np.single(2)                  # so there is no need to convert it into Python code.
    y = np.single(3)
    print(sess.run(dc, feed_dict={a: x, b: y}))
PyTorch 1.x: No IR for Forward
PyTorch does not turn the forward pass into an IR like Theano or TensorFlow does. Instead, it uses the Python interpreter to run the forward pass. During this run, an IR representing the backward pass is built as a side effect. This is known as the “eager mode”.
import torch

a = torch.tensor(1.0, requires_grad=True)  # These are not placeholders, but values.
b = torch.tensor(2.0)

c = a * b      # Evaluates c and, as a side effect, builds the IR of the backward pass in c.grad_fn.
c.backward()   # Executes c.grad_fn.
print(a.grad)  # The gradient of c w.r.t. a, which equals b, i.e., 2.
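We can peek at the backward IR that eager execution builds as a side effect; the two lines below are an addition of mine, continuing the snippet above.
# The backward IR is a graph of grad_fn nodes; each node points to the
# backward nodes of its inputs via next_functions.
print(c.grad_fn)                 # a MulBackward0 node
print(c.grad_fn.next_functions)  # edges toward the inputs' backward nodes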
TensorFlow 2.x: Gradient Tape
TensorFlow 2.x adds an eager mode API like PyTorch’s. This API traces how the forward pass was run into an IR called the GradientTape. TensorFlow 2.x can figure out the backward pass from this trace.
import tensorflow as tf

a = tf.Variable(1.0)  # Like PyTorch, these are values, not placeholders.
b = tf.Variable(2.0)

with tf.GradientTape() as tape:
    c = a * b
dcda = tape.gradient(c, a)
print(dcda)
JAX
JAX does not expose low-level details like GradientTape to users. Instead, the JAX way of thinking is that both the input and the output of a transformation are just Python functions.
import jax

a = 2.0
b = 3.0

jax.grad(jax.lax.mul)(a, b)  # Differentiates c = a * b w.r.t. a; the result is b, i.e., 3.
jax.jit(jax.grad(jax.lax.mul))(a, b)
# The next line is a sketch: pjit also needs a device mesh, and
# device_mesh(ntpus) is pseudocode rather than a real API.
jax.experimental.pjit(jax.grad(jax.lax.mul), device_mesh(ntpus))(a, b)
Advanced users who want to write their own functional transformations can call low-level APIs like make_jaxpr to get access to the IR, which is known as JAXPR.
jax.make_jaxpr(jax.lax.mul)(2.0, 3.0) # Returns the IR representing jax.lax.mul(2,3)
jax.make_jaxpr(jax.grad(jax.lax.mul))(2.0, 3.0) # Returns the IR of grad(mul)(2,3)
functorch
functorch provides JAX-like functional transformations on top of PyTorch.
import torch
import functorch

a = torch.tensor([2.0])
b = torch.tensor([3.0])
functorch.grad(torch.dot)(a, b)
JAX’s make_jaxpr is analogous to functorch’s make_fx.
def f(a, b):
    return torch.dot(a, b)  # Have to wrap the builtin function dot into f.
print(functorch.make_fx(f)(a, b).code)
print(functorch.make_fx(functorch.grad(f))(a, b).code)
TensorFlow 2.x, JAX, and functorch all build an IR for the forward pass, but PyTorch eager mode does not. Not only is the IR useful for autograd, but it is also useful for other kinds of functional transformations. In the following example, functorch.compile.aot_function invokes the callback print_compile_fn twice, once for the forward pass and once for the backward pass.
from functorch.compile import aot_function
import torch.fx as fx

def print_compile_fn(fx_module, args):
    print(fx_module)
    return fx_module
aot_fn = aot_function(torch.dot, print_compile_fn)
aot_fn(a, b)
Higher-Order Derivatives
PyTorch
import torch
from torch import autograd

x = torch.tensor(1., requires_grad=True)
y = 2*x**3 + 8

first_derivative = autograd.grad(y, x, create_graph=True)
print(first_derivative)

second_derivative = autograd.grad(first_derivative, x)
print(second_derivative)
TensorFlow 2.x
import tensorflow as tf

x = tf.Variable(1.0)

with tf.GradientTape() as outer_tape:
    with tf.GradientTape() as tape:
        y = 2*x**3 + 8
    dy_dx = tape.gradient(y, x)
    print(dy_dx)
d2y_dx2 = outer_tape.gradient(dy_dx, x)
print(d2y_dx2)
JAX
def f(a):
    return 2*a**3 + 8

print(jax.grad(f)(1.0))
print(jax.grad(jax.grad(f))(1.0))
Dynamic Control Flows
There are two levels of dynamic control flow: a coarse level running on the CPU and a fine-grained level on the GPU or TPU. This section covers the coarse, CPU-level one, using the conditional (if/else) as an example to review the various toolkits.
TensorFlow 1.x
With TensorFlow 1.x, we need to build the conditional into the IR explicitly, using the special operator tf.cond.
def f1(): return tf.multiply(a, 17)
def f2(): return tf.add(b, 23)
r = tf.cond(tf.less(a, b), f1, f2)

with tf.compat.v1.Session() as sess:  # TensorFlow has a runtime to execute the IR,
    print(sess.run(r, feed_dict={a: x, b: y}))
TensorFlow 2.x
TensorFlow 2.x supports building control flow into the IR explicitly using tf.cond and tf.while_loop. In addition, a feature called AutoGraph, derived from the experimental project google/tangent, can convert Python control flow into tf.cond or tf.while_loop. AutoGraph makes use of the fact that the Python interpreter has access to the source code of functions, for example, the function g in the following example. It calls Python's standard library to parse the source code into an AST, converts it into SSA form, and from that understands the control flow.
def g(x, y):
    if tf.reduce_any(x < y):
        return tf.multiply(x, 17)
    return tf.add(y, 23)

converted_g = tf.autograph.to_graph(g)

import inspect
print(inspect.getsource(converted_g))
JAX
In my opinion, Python's very rich syntax makes it challenging to understand control flow by parsing source code, and AutoGraph often fails at the job. Indeed, if this approach were easy, the Python developer community would not have failed so many times at building a Python compiler. Given this challenge, it makes sense to come back to building the control flow into the IR explicitly. For this purpose, JAX provides jax.lax.cond and jax.lax.fori_loop.
jax.lax.cond(a < b, lambda: a * 17, lambda: b + 23)
Given the conditional, you might think that we could build recursion. Unfortunately, the following recursion for computing the factorial is not traceable by JAX.
def factorial(r, x):
    return jax.lax.cond(x <= 1.0, lambda: r, lambda: factorial(r*x, x-1))
factorial(1.0, 3.0)
You might expect the above call to factorial to compute 3! = 6. However, it throws an exception complaining that the maximum recursion depth has been exceeded. This is because recursion relies not only on the conditional but also on function definition and invocation.
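Loops face the same constraint and must also be built into the IR explicitly. The following is a minimal sketch of mine that expresses the factorial with jax.lax.fori_loop instead of recursion.
# Express the factorial as a traced loop rather than Python-level recursion.
# fori_loop carries the accumulator acc through the iterations.
import jax

def factorial_loop(n):
    return jax.lax.fori_loop(1, n + 1, lambda i, acc: acc * i, 1)

print(factorial_loop(3))           # 6
print(jax.jit(factorial_loop)(3))  # also works under jit, since the loop is in the IR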
PyTorch
PyTorch used to be purely Python-native: dynamic control flow was just Python control flow executed eagerly. As we explained earlier in this document, the functional transformations grad and vmap happen on the fly, thanks to the versatile dispatcher mechanism. Here it is important to note that:
- On-the-fly functional transformation is more efficient than the transform-after-building-an-IR approach taken by Theano and TensorFlow.
- JAX also transforms on the fly when doing grad and vmap (see the vmap sketch after this list). However, more complex transformations like pmap and pjit require an overview of the whole computation process, so an IR is inevitable.
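To illustrate what an on-the-fly transformation looks like from the user's side, here is a minimal vmap sketch; the function per_example and the shapes are illustrative assumptions of mine.
# vmap turns a per-example function into a batched one without the user
# ever touching an IR; per_example and the shapes are made up for illustration.
import jax
import jax.numpy as jnp

def per_example(x, w):
    return jnp.dot(x, w)  # works on a single example

batched = jax.vmap(per_example, in_axes=(0, None))  # map over the batch axis of x
x = jnp.ones((4, 3))  # a batch of four examples
w = jnp.ones((3,))
print(batched(x, w))  # shape (4,)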
Due to the necessity of the IR for pmap and pjit, the PyTorch community recently added torch.cond: https://github.com/pytorch/pytorch/pull/83154
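For a taste of the API, here is a hedged sketch based on how torch.cond is documented in recent PyTorch releases; the exact signature in the linked pull request may differ, so treat this as an assumption rather than the definitive interface.
# A sketch of torch.cond as documented in recent PyTorch releases; the
# signature in the linked PR may differ. Both branches are captured, and
# the predicate selects one at run time.
import torch

a = torch.tensor(2.0)
b = torch.tensor(3.0)
r = torch.cond(a < b, lambda x, y: x * 17, lambda x, y: y + 23, (a, b))
print(r)  # tensor(34.)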