Custom Gradients in TensorFlow

TensorFlow defines deep learning models as computational graphs, where nodes are called ops (short for operations) and the data flowing between these ops are called tensors. Given a graph of ops, TensorFlow uses automatic differentiation to compute gradients. The theory behind automatic differentiation is that all numeric computations are composed of a finite set of elementary operations for which the gradient is well defined. For automatic differentiation to work properly, each op in TensorFlow must therefore have a well-defined gradient.
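
As a quick illustration of automatic differentiation at work, here is a minimal sketch (illustrative, not taken from any particular codebase): tf.gradients traverses the graph backwards and chains the registered gradient of each elementary op.

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[3])
y = tf.reduce_sum(3.0 * tf.square(x))  # y = sum(3 * x^2)

# tf.gradients chains the registered gradients of reduce_sum, mul and square
dy_dx = tf.gradients(y, x)[0]  # analytically: 6 * x

with tf.Session() as sess:
    print(sess.run(dy_dx, feed_dict={x: [1.0, 2.0, 3.0]}))  # [ 6. 12. 18.]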

When adding new ops in TensorFlow, you must use tf.RegisterGradient to register a gradient function, which computes gradients with respect to the op’s input tensors given gradients with respect to the op’s output tensors. For example, let’s say we have an operation Square which computes the square of its input. Its forward and backward behaviour are defined as follows:

\begin{equation} \textbf{Forward: } y = x^2 \end{equation}

\begin{equation} \textbf{Backward: } \frac{dy}{dx} = 2x \end{equation}

The gradient of this function is registered like this:

@tf.RegisterGradient("Square")
def _square_grad(op, grad):
    # dy/dx = 2x, scaled by the incoming gradient
    x = op.inputs[0]
    return grad * 2 * x
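
TensorFlow already registers a gradient for the built-in Square op, so the snippet above is illustrative rather than something to run against it. A quick way to sanity-check a gradient is to compare tf.gradients against the analytic derivative; a minimal sketch:

x = tf.constant([1.0, 2.0, 3.0])
y = tf.square(x)
dy_dx = tf.gradients(y, x)[0]  # uses the registered Square gradient: 2 * x

with tf.Session() as sess:
    print(sess.run(dy_dx))  # [2. 4. 6.]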

However, sometimes we aren’t interested in creating new ops, but in changing the gradient behaviour of an existing op. For example, let’s say we wanted to quantize a 32-bit floating-point weight tensor in our graph down to 1 bit each time it is used, following the XNOR-Net quantization scheme:

\begin{equation} \textbf{Forward: } T_O = \text{sign}(T_I) * \textbf{E}(|T_I|) \end{equation}

\begin{equation} \textbf{Backward: } \frac{\partial C}{\partial T_I} = \frac{\partial C}{\partial T_O} \end{equation}

where $T_I$ is the input tensor, $T_O$ is the output tensor, and the scaling factor $\textbf{E}(|T_I|)$ is the average of the absolute weight values. In the forward pass this quantization scheme uses the tf.sign function and the scaling factor to perform the rounding. In the backward pass, however, we want to use a straight-through estimator to circumvent the issue of the gradient of the sign function being zero almost everywhere. To achieve this, we can use the identity function to simply pass the gradient through untouched (as seen in the DoReFa-Net code).

G = tf.get_default_graph()
def quantize(x): 
    with G.gradient_override_map({"Sign": "Identity"}):
        E = tf.reduce_mean(tf.abs(x))
        return tf.sign(x) * E

Here, we use the gradient_override_map context manager to replace the gradient of the Sign op with the gradient of the Identity op within the scope of the context. However, this code snippet has two problems. First, in the backward pass the gradient will still be scaled by E. To deal with this, we can use a clever trick and normalize the input by E, which cancels this effect without changing the sign. Second, tf.reduce_mean will contribute gradient activity which we do not want. To prevent this, we can use tf.stop_gradient, which acts like the identity function in the forward direction but stops the accumulated gradient from flowing through that operator in the backward direction.

G = tf.get_default_graph()
def quantize(x):
    with G.gradient_override_map({"Sign": "Identity"}):
        # scaling factor E; stop_gradient blocks gradient flow through the mean
        E = tf.stop_gradient(tf.reduce_mean(tf.abs(x)))
        # dividing by E cancels the backward scaling by E without changing the sign
        return tf.sign(x / E) * E
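
We can check the straight-through behaviour directly: with the identity override and the division by E, the gradient of the quantized output with respect to the input should come out as all ones. A minimal sketch using the quantize function above:

w = tf.constant([-0.8, 0.3, 1.5])
q = quantize(w)
dq_dw = tf.gradients(q, w)[0]

with tf.Session() as sess:
    print(sess.run(q))      # sign(w) * mean(|w|)
    print(sess.run(dq_dw))  # [1. 1. 1.] -- the gradient passes straight through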

This code can be used to implement the XNOR-Net quantization scheme for 1-bit weights. However, what if we wish to use a custom gradient for this quantization instead of the identity function? Once again we can use the tf.RegisterGradient decorator:

@tf.RegisterGradient("QuantizeGrad")
def quantize_grad(op, grad):
    # compute and return the custom gradient with respect to the op's input
    ...

G = tf.get_default_graph()
def quantize(x):
    with G.gradient_override_map({"Sign": "QuantizeGrad"}):
        E = tf.stop_gradient(tf.reduce_mean(tf.abs(x)))
        return tf.sign(x / E) * E
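
As one illustrative choice of custom gradient (not part of the XNOR-Net scheme itself, and ClippedQuantizeGrad is just a placeholder name), a clipped straight-through estimator passes the gradient through only where the op’s input lies in [-1, 1], a common variant in binarized networks:

@tf.RegisterGradient("ClippedQuantizeGrad")
def clipped_quantize_grad(op, grad):
    # op.inputs[0] is the input to the overridden Sign op (x / E above);
    # pass the gradient through inside [-1, 1] and zero it elsewhere
    x = op.inputs[0]
    return grad * tf.cast(tf.abs(x) <= 1.0, tf.float32)

Mapping {"Sign": "ClippedQuantizeGrad"} in gradient_override_map then applies this gradient in place of the identity.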

One last note: tf.identity is incredibly useful. In our examples here we have seen it used to implement straight-through estimators, but it is also useful for managing the flow of tensors between devices and for explicitly adding nodes to the graph. For example, it is used in the CIFAR-10 example code to make sure total_loss is computed by adding a node to the graph. TensorFlow, as with deep learning in general, rewards attention to even the most subtle details.
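
The CIFAR-10 pattern looks roughly like this (a paraphrase rather than a verbatim copy, assuming loss_averages_op and total_loss have been defined earlier):

# force the loss-averaging op to run before total_loss is used downstream
with tf.control_dependencies([loss_averages_op]):
    total_loss = tf.identity(total_loss)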