Custom Gradients in TensorFlow
TensorFlow defines deep learning models as computational graphs, where the nodes are called ops (short for operations) and the data flowing between them are called tensors. Given a graph of ops, TensorFlow uses automatic differentiation to compute gradients. The theory behind automatic differentiation is that all numeric computations are composed of a finite set of elementary operations for which the gradient is well defined. For automatic differentiation to work properly, each op in TensorFlow must therefore have a well-defined gradient.
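To make this concrete, here is a minimal sketch (using the TF1 graph API) of automatic differentiation at work: tf.gradients walks the graph backwards and chains together the registered gradient of each op.
import tensorflow as tf

x = tf.constant(3.0)
y = tf.square(x) + 2.0 * x     # a small graph built from elementary ops
dy_dx, = tf.gradients(y, x)    # chains each op's registered gradient: 2x + 2

with tf.Session() as sess:
    print(sess.run(dy_dx))     # 8.0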
When adding new ops in TensorFlow, you must use tf.RegisterGradient to register a gradient function, which computes gradients with respect to the op’s input tensors given gradients with respect to the op’s output tensors. For example, let’s say we have an operation Square which computes the square of its input. Its forward and backward computations are defined as follows:
\begin{equation} \textbf{Forward: } y = x^2 \end{equation}
\begin{equation} \textbf{Backward: } \frac{dy}{dx} = 2x \end{equation}
The gradient of this function is registered like this:
@tf.RegisterGradient("Square")
def _square_grad(op, grad):
return grad * 2
However, sometimes we aren’t interested in creating new ops, but in changing the gradient behaviour of an existing op. For example, let’s say we wanted to quantize a 32-bit floating-point weight tensor in our graph to 1 bit each time it was computed, using the XNOR-Net quantization scheme:
\begin{equation} \textbf{Forward: } T_O = \text{sign}(T_I) \cdot \mathbf{E}(|T_I|) \end{equation}
\begin{equation} \textbf{Backward: } \frac{\partial C}{\partial T_I} = \frac{\partial C}{\partial T_O} \end{equation}
where T_I is the input tensor, T_O is the output tensor, and the scaling factor E(|T_I|) is the average of the absolute weight values. In the forward pass this quantization scheme uses the tf.sign function and the scaling factor to perform the rounding. However, in the backward pass we want to use a straight-through estimator to circumvent the issue of the gradient of the sign function being zero almost everywhere. To achieve this, we can use the identity function to simply pass the gradient through untouched (as seen in the DoReFa-Net code).
G = tf.get_default_graph()
def quantize(x):
    with G.gradient_override_map({"Sign": "Identity"}):
        E = tf.reduce_mean(tf.abs(x))
        return tf.sign(x) * E
Here, we use the gradient_override_map context manager so that, within the scope of the context, the gradient of the Identity op is used in place of the gradient of the Sign op. However, this code snippet has two problems. First, in the backward pass the incoming gradient will still be scaled by E. To deal with this, we can use a clever trick and normalize the input by E, which cancels this effect without changing the sign. Second, the tf.reduce_mean will contribute gradient activity which we do not want. To prevent this, we can use tf.stop_gradient, which acts like the identity function in the forward direction but stops the accumulated gradient from flowing through that operator in the backward direction.
G = tf.get_default_graph()
def quantize(x):
    with G.gradient_override_map({"Sign": "Identity"}):
        E = tf.stop_gradient(tf.reduce_mean(tf.abs(x)))
        return tf.sign(x / E) * E
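As a quick sanity check (a minimal sketch; w is a hypothetical weight variable, and quantize and the import come from the snippets above), the gradient that reaches the weights is now exactly the incoming gradient:
w = tf.Variable([[0.5, -1.2], [2.0, -0.3]])  # hypothetical weight tensor
q = quantize(w)
grad_w, = tf.gradients(tf.reduce_sum(q), w)  # straight-through estimator in action

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(grad_w))  # [[1. 1.] [1. 1.]] -- the scaling by E cancels out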
The quantize function above implements the XNOR-Net quantization scheme for 1-bit
weights. However, what if we wish to use a custom gradient for this
quantization instead of the identity function? Once again we can use the
tf.RegisterGradient
decorator:
@tf.RegisterGradient("QuantizeGrad")
def quantize_grad(op, x):
# compute custom gradient
# ...
G = tf.get_default_graph()
def quantize(x)
with G.gradient_override_map({"Sign": "QuantizeGrad"}):
E = tf.stop_gradient(tf.reduce_mean(tf.abs(x)))
return tf.sign(x / E) * E
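For instance, filling in the placeholder above, one possible custom gradient is a clipped straight-through estimator in the style of BinaryNet (an illustrative choice, not part of the XNOR-Net scheme itself), which zeroes the gradient wherever the scaled input saturates:
@tf.RegisterGradient("QuantizeGrad")
def quantize_grad(op, grad):
    # op is the overridden Sign op, so op.inputs[0] is the scaled input x / E.
    # Pass the gradient through only where |x / E| <= 1 (clipped straight-through).
    return grad * tf.cast(tf.abs(op.inputs[0]) <= 1.0, grad.dtype)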
One last note: tf.identity is incredibly useful. In our examples here, we have seen it used to implement straight-through estimators, but it is also useful for managing the flow of tensors between devices and for explicitly adding nodes to the graph. For example, it is used in the CIFAR-10 example code to make sure total_loss is computed by explicitly adding a node to the graph.
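That pattern typically pairs tf.identity with tf.control_dependencies. Here is a minimal, self-contained sketch of the idea (the names are illustrative stand-ins, not the exact CIFAR-10 code):
import tensorflow as tf

total_loss = tf.constant(1.0, name="total_loss")  # stand-in for a real loss tensor
counter = tf.Variable(0, trainable=False)
update_op = tf.assign_add(counter, 1)             # stand-in for the loss-averaging update

# tf.identity explicitly adds a node to the graph; combined with the control
# dependency, evaluating total_loss now also forces update_op to run first.
with tf.control_dependencies([update_op]):
    total_loss = tf.identity(total_loss)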
In TensorFlow, as in deep learning more generally, even the most subtle details can matter.