Getting Started with Training Large Models in TensorFlow

One of the most common questions from people new to deep learning is “how do I get started training large models?” Given the growing number of frameworks and tutorials available, this question is not always easy to answer. One consequence is a lack of collaboration: many researchers waste time implementing their own solutions to the same problem. Ideally, research groups would share common framework examples optimized for their local hardware and software, giving new research a highly efficient starting point. This blog post is an attempt to move the MLRG in this direction by demonstrating how to train a large model in TensorFlow on the SHARCNET copper cluster.

Here we train the Trained Ternary Quantization (TTQ) model by Zhu et al. as an example. This model is based on the AlexNet convolutional neural network (CNN) and is trained from scratch on the ImageNet ILSVRC12 dataset. The open-source TensorFlow implementation is written with Tensorpack, a neural network toolbox with convenient functions for training large TensorFlow models. We recommend Tensorpack because it bundles common functionality that matters at this scale, such as data prefetching and multi-GPU training.
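For readers new to the model, the core idea of TTQ is that each layer's weights are mapped to one of three values {-W_n, 0, +W_p}, where the two scaling factors are learned during training. The plain-Python sketch below shows only that forward mapping, assuming a per-layer threshold of t × max|w| with t = 0.05 chosen here for illustration. It omits the training of the scaling factors and of the latent full-precision weights, and `ternarize` is a hypothetical helper, not code from the ternarynet repository.

```python
def ternarize(weights, w_p, w_n, t=0.05):
    """Map full-precision weights to {-w_n, 0, +w_p} (forward pass only).

    Assumption: threshold delta = t * max(|w|) per layer; w_p and w_n are
    the positive and negative scaling factors (learned in TTQ, fixed here).
    """
    if not weights:
        return []
    delta = t * max(abs(w) for w in weights)
    ternary = []
    for w in weights:
        if w > delta:
            ternary.append(w_p)      # large positive weights become +w_p
        elif w < -delta:
            ternary.append(-w_n)     # large negative weights become -w_n
        else:
            ternary.append(0.0)      # small weights are pruned to zero
    return ternary


print(ternarize([0.9, -0.02, 0.3, -0.7, 0.01], w_p=0.5, w_n=0.4))
# [0.5, 0.0, 0.5, -0.4, 0.0]
```

Because every weight in a layer takes one of only three values, the layer can be stored with two bits per weight plus two floating-point scales.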

Where Can I Find the Example Code?

I am hosting the ternarynet example code on my GitHub, along with instructions for training this network from scratch on copper. The basic steps are:

More Tensorpack examples are also available, including ResNet, Generative Adversarial Networks (GANs), and Long Short-Term Memory (LSTM) networks.

How Do I Build My Model?

The ternarynet model is constructed using a model description class, with methods for creating the input placeholders and building the symbolic graph, as shown below. Like TF-Slim, Tensorpack allows more compact network definitions than the built-in TensorFlow libraries, which improves readability.

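import tensorflow as tf
from tensorpack import *  # ModelDesc, InputVar (early Tensorpack API)

class Model(ModelDesc):
    """ Model description class """

    def _get_input_vars(self):
        """ Create or return input placeholders """
        return [InputVar(tf.float32, [None, 224, 224, 3], 'input'),
                InputVar(tf.int32, [None], 'label')]

    def _build_graph(self, input_vars):
        """ Build TensorFlow symbolic graph """
        image, label = input_vars
        # ... AlexNet layers with ternary weights and self.cost are defined in the repo ...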
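Splitting a model into "declare the inputs" and "build the graph" is what lets a generic trainer drive any model. The sketch below imitates that contract in plain Python so it runs without TensorFlow. `ModelSpec`, `train_step`, and `ToySumModel` are hypothetical names, not Tensorpack classes, and the "graph" is an ordinary function rather than a symbolic one.

```python
from abc import ABC, abstractmethod


class ModelSpec(ABC):
    """Hypothetical stand-in for a model description class (not Tensorpack's API)."""

    @abstractmethod
    def inputs(self):
        """Return (name, shape) pairs; None marks a variable batch dimension."""

    @abstractmethod
    def build(self, tensors):
        """Given one value per declared input, return the training cost."""


def train_step(model, batch):
    """A generic 'trainer': feeds inputs by name without knowing the model's internals."""
    tensors = [batch[name] for name, _shape in model.inputs()]
    return model.build(tensors)


class ToySumModel(ModelSpec):
    def inputs(self):
        return [("input", (None, 3)), ("label", (None,))]

    def build(self, tensors):
        features, labels = tensors
        preds = [sum(row) for row in features]  # toy "network": sum the features
        # mean squared error between predictions and labels
        return sum((p - y) ** 2 for p, y in zip(preds, labels)) / len(labels)


cost = train_step(ToySumModel(), {"input": [[1, 2, 3], [0, 0, 1]], "label": [6, 0]})
print(cost)
# 0.5
```

The trainer only relies on the names returned by `inputs()`, so swapping in a different model requires no change to the training loop.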

How Do I Build My Data Pipeline?

One of the most tedious and time-consuming aspects of deep learning is building an efficient data pipeline. While TensorFlow provides some useful methods for reading data, optimizing a pipeline for large datasets still takes considerable effort. Tensorpack provides a standard interface for typical data pipelines, which consist of reading from disk, applying augmentations, grouping into batches, and prefetching. In the ternarynet data pipeline, the dataset is implemented as a Python generator, combined with a number of augmentors for resizing, cropping, brightening, etc.

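import multiprocessing
from tensorpack.dataflow import dataset, AugmentImageComponent, BatchData, PrefetchDataZMQ

# isTrain (bool) and augmentors (list of image augmentors) are defined elsewhere in the script
# define a Dataflow which produces image-label pairs from the ILSVRC12 database
ds = dataset.ILSVRC12('/path/to/ilsvrc12', 'train' if isTrain else 'val', shuffle=isTrain)
# apply augmentors
ds = AugmentImageComponent(ds, augmentors)
# group data into batches of size 128
ds = BatchData(ds, 128)
# prefetch in parallel with up to 12 worker processes (data is transferred with ZeroMQ)
ds = PrefetchDataZMQ(ds, min(12, multiprocessing.cpu_count()))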
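The same read, augment, batch, prefetch chain can be sketched with plain Python generators, which makes the composition easy to see. This is an illustrative stand-in, not Tensorpack code: `read_pairs`, `augment`, `batch`, `prefetch`, and `brighten` are hypothetical helpers, and prefetching here uses a single background thread with a bounded queue rather than the multiple ZeroMQ-connected processes that PrefetchDataZMQ uses.

```python
import queue
import threading


def read_pairs(n):
    """Stand-in dataset reader: yields (image, label) pairs from a generator."""
    for i in range(n):
        yield [float(i)] * 4, i % 10  # hypothetical 4-pixel "image"


def augment(pairs, augmentors):
    """Apply each augmentor (a function image -> image) in order."""
    for image, label in pairs:
        for aug in augmentors:
            image = aug(image)
        yield image, label


def batch(pairs, size):
    """Group pairs into lists of `size`; an incomplete final batch is dropped."""
    buf = []
    for pair in pairs:
        buf.append(pair)
        if len(buf) == size:
            yield buf
            buf = []


_DONE = object()


def prefetch(iterable, capacity=4):
    """Run the upstream pipeline in a background thread, buffering up to `capacity` items."""
    q = queue.Queue(maxsize=capacity)

    def worker():
        try:
            for item in iterable:
                q.put(item)
        finally:
            q.put(_DONE)  # always signal the consumer, even if the upstream raised

    threading.Thread(target=worker, daemon=True).start()
    while True:
        item = q.get()
        if item is _DONE:
            return
        yield item


def brighten(image):
    return [min(p + 1.0, 255.0) for p in image]


batches = list(prefetch(batch(augment(read_pairs(10), [brighten]), 4)))
print(len(batches), batches[0][0])
# 2 ([1.0, 1.0, 1.0, 1.0], 0)
```

Each stage only consumes the iterator of the previous one, so stages can be reordered or swapped without touching the others; the bounded queue keeps the producer from running arbitrarily far ahead of the consumer.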

How Do I Train My Model?

Training with Tensorpack is done with trainers, which abstract away much of the repetitive code for saving the model, computing validation statistics, etc. Most importantly, Tensorpack implements multi-GPU trainers that perform synchronous or asynchronous data-parallel training on a single node. A condensed version of the ternarynet trainer configuration is shown below.

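# ds_val: validation Dataflow (illustrative name); keyword names vary between Tensorpack versions
config = TrainConfig(model=Model(), dataflow=ds,
                     callbacks=[ModelSaver(), InferenceRunner(ds_val,
                         [ClassificationError('wrong-top1', 'val-error-top1')])])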
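Synchronous multi-GPU training on one node typically works by splitting each batch across GPU "towers", computing a gradient on each shard, averaging those gradients, and applying one shared update. The plain-Python sketch below illustrates that idea under simplifying assumptions: a one-parameter linear model with mean squared error and simulated towers (no GPUs, no TensorFlow). `grad_mse` and `sync_step` are hypothetical names.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean squared error for the one-parameter model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def sync_step(w, xs, ys, num_towers, lr):
    """One synchronous data-parallel step over simulated towers."""
    if len(xs) % num_towers:
        raise ValueError("batch size must be divisible by the number of towers")
    shard = len(xs) // num_towers
    # each "tower" computes a gradient on its own slice of the batch
    grads = [grad_mse(w, xs[i * shard:(i + 1) * shard], ys[i * shard:(i + 1) * shard])
             for i in range(num_towers)]
    # wait for all towers, average their gradients, and apply a single shared update
    return w - lr * sum(grads) / num_towers


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [3.0 * x for x in xs]  # data generated by w = 3
w = 0.0
for _ in range(50):
    w = sync_step(w, xs, ys, num_towers=4, lr=0.01)
print(round(w, 6))
# 3.0
```

With equal-sized shards the averaged gradient equals the full-batch gradient, so each synchronous step matches a single-device step on the whole batch. An asynchronous trainer instead lets each tower apply its update without waiting for the others, giving up that equivalence in exchange for less idle time.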


What Else Should I Be Doing?

Tensorpack offers several other conveniences that help when training large models, especially under job schedulers such as the one on copper. Organized logging, regular model saving, and restoring training from checkpoints can save valuable time when a job is unexpectedly cancelled:

config.session_init = SaverRestore(checkpoint_file)
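SaverRestore takes care of this for TensorFlow checkpoints. The underlying pattern is useful for any script the scheduler might kill: write checkpoints atomically and resume from the last one on startup. Below is a minimal standard-library sketch, assuming the training state is JSON-serializable. `save_checkpoint`, `load_checkpoint`, and `train` are hypothetical helpers, and the "training" is a placeholder counter.

```python
import json
import os
import tempfile


def save_checkpoint(path, state):
    """Write `state` to a temp file, then rename it over `path`.

    os.replace is atomic on POSIX filesystems, so a job killed mid-write
    leaves the previous checkpoint intact instead of a truncated file.
    """
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    with os.fdopen(fd, "w") as f:
        json.dump(state, f)
    os.replace(tmp_path, path)


def load_checkpoint(path, default):
    """Return the saved state, or `default` when no checkpoint exists yet."""
    if not os.path.exists(path):
        return default
    with open(path) as f:
        return json.load(f)


def train(path, max_epoch, cancel_at=None):
    """Resume from `path` if present; `cancel_at` simulates the scheduler killing the job."""
    state = load_checkpoint(path, {"epoch": 0, "w": 0.0})
    for epoch in range(state["epoch"], max_epoch):
        if epoch == cancel_at:
            return state
        state = {"epoch": epoch + 1, "w": state["w"] + 1.0}  # placeholder for one epoch
        save_checkpoint(path, state)
    return state


with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt = os.path.join(tmp_dir, "state.json")
    first = train(ckpt, max_epoch=10, cancel_at=4)  # job cancelled after 4 epochs
    second = train(ckpt, max_epoch=10)              # resubmitted job resumes at epoch 4
print(first, second)
# {'epoch': 4, 'w': 4.0} {'epoch': 10, 'w': 10.0}
```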

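Running inference with a pre-trained model is often important. In the example script, the run_image helper restores saved parameters and runs the model on the given images: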

run_image(Model(), ParamRestore(np.load('/path/to/model.npy', allow_pickle=True).item()),
          ['/path/to/data'])
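Inference on ILSVRC12 produces one score per class for each image, and results are usually reported as top-1 and top-5 predictions or error rates (the 'val-error-top1' statistic in the trainer configuration is the validation top-1 error). The plain-Python sketch below turns raw scores into top-k predictions and a top-k error. `softmax`, `top_k`, and `top_k_error` are hypothetical helpers, not part of Tensorpack.

```python
import math


def softmax(scores):
    """Convert raw class scores to probabilities (shifted by the max for numerical stability)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(scores, k=5):
    """Return the k most likely (class_index, probability) pairs, best first."""
    probs = softmax(scores)
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in order[:k]]


def top_k_error(batch_scores, labels, k):
    """Fraction of images whose true label is not among the k best predictions."""
    wrong = 0
    for scores, label in zip(batch_scores, labels):
        if label not in [i for i, _ in top_k(scores, k)]:
            wrong += 1
    return wrong / len(labels)


scores = [[1.0, 3.0, 2.0], [5.0, 0.0, 1.0]]  # two images, three classes
labels = [2, 0]
print(top_k_error(scores, labels, k=1), top_k_error(scores, labels, k=2))
# 0.5 0.0
```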

Tensorpack even supports sending training statistics to your mobile phone:

SendStat('curl -u your_id: \
          -d type=note -d title="validation error" \
          -d body={validation_error} > /dev/null 2>&1',
         'validation_error')
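As we understand SendStat, its command is a format string whose {name} placeholders are filled with the named statistics before the command runs. The plain-Python sketch below shows only that substitution step. `format_stat_command` and the `notify` command are hypothetical, and values are passed through `shlex.quote` as a precaution because the resulting string is meant to be run by a shell.

```python
import shlex


def format_stat_command(template, stats, names):
    """Fill {name} placeholders in `template` with the requested statistics.

    Values are shell-quoted because the resulting string is meant to be run by a shell.
    """
    values = {name: shlex.quote(str(stats[name])) for name in names}
    return template.format(**values)


latest = {"validation_error": 0.25, "cost": 1.3}
cmd = format_stat_command('notify -d title="validation error" -d body={validation_error}',
                          latest, ["validation_error"])
print(cmd)
# notify -d title="validation error" -d body=0.25
# subprocess.run(cmd, shell=True) would then execute it
```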

While it is hard to predict which tools and workflows a project will end up using, we hope this post helps with where to start.