
How Transfer Learning Can Make Machine Learning More Efficient

4 Dec 2020 12:30pm, by

Ever wonder how to scale the process of training machine learning models, without having to use a new dataset each time? Transfer learning is a machine learning technique used to solve a task quickly by leveraging knowledge gained from solving a related task. Pre-trained models can be re-purposed in a variety of ways, depending on the relatedness of the task, so only a small number of labeled examples from the new task are needed. Transfer learning can be a powerful tool for data scientists and engineers alike, enabling those without the means to train a model from scratch to benefit from the powerful features learned by deep models.

What Is Transfer Learning?

Mark Kurtz
Mark Kurtz is the Machine Learning Lead at Neural Magic. He's an experienced software and machine learning leader with a track record of making machine learning models successful and performant. Mark manages teams and efforts that ensure organizations realize high returns from their machine learning investments. He is currently building a software AI engine at Neural Magic, with a goal to bring GPU-class performance for deep learning to commodity CPUs.

Supervised learning is the problem of learning a function that maps inputs (observations) to outputs (labels) based on example pairs. Transfer learning is a variant of supervised learning that we can use when a task has only a limited number of these labeled examples. Even when data is not scarce, we can use transfer learning to avoid spending the large amount of computational resources required to train a data-hungry model from scratch.

A lack of training data can arise when labeled examples are difficult or expensive to collect or annotate, even though the task may still require a large (and therefore data-hungry) machine learning model to solve it. In such cases, there is often not enough data to train a model from scratch to an acceptable level of accuracy (or another performance criterion).

To overcome data scarcity, or simply to avoid training a model from scratch, we can leverage knowledge gained from training a model on a related task with many labeled examples (the source task) to solve the original task at hand (the target task). This is the central idea of transfer learning, and it is often successful when the source and target tasks require similar information to solve.

Diagram of the general transfer learning approach (image source). We are interested in leveraging knowledge contained in a model trained on one task to inform a model used to solve another.

For example, machine learning models trained on images learn similar features (edges, corners, gradients, simple shapes, etc.) from different image datasets, suggesting that these features can be reused to solve other image recognition tasks.

Example ImageNet features learned by AlexNet (image source).
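To make the idea of reusable low-level features concrete, the sketch below applies one fixed 3×3 vertical-edge filter to two unrelated toy images. The filter is hand-written and only stands in for the kind of edge detector a CNN's first layer often learns; it is not taken from AlexNet or any real model, and the names `EDGE` and `correlate` are hypothetical. The same filter responds wherever either image has a vertical edge, with the sign indicating the direction of the intensity change.

```python
from typing import List

Grid = List[List[float]]

# Hand-written vertical-edge filter (an assumption, standing in for a
# learned first-layer filter; not extracted from a real network).
EDGE: Grid = [[-1.0, 0.0, 1.0],
              [-1.0, 0.0, 1.0],
              [-1.0, 0.0, 1.0]]


def correlate(img: Grid, k: Grid) -> Grid:
    """Valid 2-D cross-correlation, the operation a conv layer computes."""
    kh, kw = len(k), len(k[0])
    return [
        [
            sum(k[a][b] * img[i + a][j + b] for a in range(kh) for b in range(kw))
            for j in range(len(img[0]) - kw + 1)
        ]
        for i in range(len(img) - kh + 1)
    ]


# Two unrelated toy "datasets": a dark-to-bright step and a bright-to-dark step.
step_up = [[0.0, 0.0, 1.0, 1.0] for _ in range(4)]
step_down = [[5.0, 5.0, 5.0, 0.0] for _ in range(4)]

print(correlate(step_up, EDGE))    # [[3.0, 3.0], [3.0, 3.0]]
print(correlate(step_down, EDGE))  # [[0.0, -15.0], [0.0, -15.0]]
```

The filter itself never changes between the two inputs; only the images do, which is the intuition behind reusing such features for new image tasks.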

Transfer learning can be further broken down depending on the similarities and differences between the source and target tasks:

  • Source and target tasks may or may not share a common input (feature) space (example: the source task’s input is 32×32 RGB images, while the target task’s input is 128×128 grayscale images)
  • The distribution of source and target task features may or may not differ (example: the source task is to classify cartoon drawings of cats and dogs, while the target task is to classify real images of cats and dogs)
  • Source and target tasks may or may not share a common output (label) space (example: the source task involves classifying images of cats and dogs, while the target task is to draw bounding boxes around (detect) any cats or dogs in an image)
  • The conditional distribution of source and target task labels given features may or may not differ (example: both source and target tasks involve classification with the same label space, but in the target task, some labels are much rarer than others)
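When only the input space differs, one simple option is to convert target inputs into the format the source model expects. The sketch below uses the hypothetical sizes from the first bullet: it shrinks a 128×128 grayscale image to 32×32 by averaging each 4×4 block of pixels, then copies the single gray channel into three identical RGB channels. The function name `gray_to_source_input` and the nested-list image format are assumptions made for illustration; a real pipeline would normally use an image-processing library for resizing.

```python
from typing import List

Image2D = List[List[float]]          # height x width grayscale pixels
ImageRGB = List[List[List[float]]]   # height x width x 3 channels


def gray_to_source_input(img: Image2D, out_size: int = 32) -> ImageRGB:
    """Hypothetical adapter: map a square grayscale image to an
    out_size x out_size RGB image by block averaging and channel copying."""
    in_size = len(img)
    if in_size % out_size != 0:
        raise ValueError("input size must be a multiple of out_size")
    block = in_size // out_size  # 128 // 32 == 4, so 4x4 blocks

    out: ImageRGB = []
    for i in range(out_size):
        row = []
        for j in range(out_size):
            # Average the block x block patch of input pixels.
            total = 0.0
            for di in range(block):
                for dj in range(block):
                    total += img[i * block + di][j * block + dj]
            mean = total / (block * block)
            # Replicate the gray value into three identical channels.
            row.append([mean, mean, mean])
        out.append(row)
    return out


# Usage: a synthetic 128x128 two-tone image (left half dark, right half bright).
gray = [[0.0 if c < 64 else 1.0 for c in range(128)] for _ in range(128)]
rgb = gray_to_source_input(gray)
print(len(rgb), len(rgb[0]), len(rgb[0][0]))  # 32 32 3
print(rgb[0][0], rgb[0][31])                  # [0.0, 0.0, 0.0] [1.0, 1.0, 1.0]
```

Block averaging assumes the source resolution divides the target resolution evenly; other resizing methods, or adapting the model's first layer instead, are also possible.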

How Is Transfer Learning Implemented?

A basic approach to transfer learning with neural networks is laid out below, assuming that the source and target tasks share common feature and label spaces:

  1. Train a neural network to a high level of accuracy on a source task that is sufficiently related to the target task at hand; this is known as the pre-training phase. Alternatively, use a model that has already been pre-trained, for example one taken from a model repository.
  2. Optionally, fix, or “freeze,” the early layers of the pre-trained network so that their parameters do not change during retraining (for example, freeze every layer except the last one). Depending on the setup, freezing can help or hurt the final result.
  3. Retrain the later (unfrozen) layers of the network using the labeled examples available for the target task; this is known as the fine-tuning phase.
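The freeze-then-fine-tune bookkeeping can be sketched without any framework. In the minimal Python example below, each layer is a hypothetical `Layer` object holding a flat weight list and a `trainable` flag; freezing sets that flag to `False`, and the update rule skips frozen layers. The weights and gradients are made-up placeholder numbers standing in for a pre-trained model and for whatever backpropagation would produce on target-task examples, so this shows only the mechanics of freezing and fine-tuning, not a real training loop. Frameworks expose the same idea through their own mechanisms, for example PyTorch's `requires_grad` attribute on parameters.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Layer:
    """Hypothetical stand-in for a network layer: a flat list of weights."""
    name: str
    weights: List[float]
    trainable: bool = True


def freeze_all_but_last(layers: List[Layer]) -> None:
    """Freeze every layer except the final one ("all layers but the last")."""
    for layer in layers[:-1]:
        layer.trainable = False
    layers[-1].trainable = True


def sgd_step(layers: List[Layer], grads: List[List[float]], lr: float = 0.1) -> None:
    """One gradient-descent update that leaves frozen layers untouched."""
    for layer, layer_grads in zip(layers, grads):
        if not layer.trainable:
            continue  # frozen: keep the pre-trained weights as they are
        layer.weights = [w - lr * g for w, g in zip(layer.weights, layer_grads)]


# "Pre-trained" weights (placeholder values, not from a real model).
net = [
    Layer("early", [0.5, -0.2, 0.8]),
    Layer("middle", [0.1, 0.4]),
    Layer("head", [0.3, -0.7]),
]
freeze_all_but_last(net)

# Placeholder gradients standing in for backprop on target-task examples.
grads = [[1.0, 1.0, 1.0], [1.0, 1.0], [1.0, -1.0]]
sgd_step(net, grads, lr=0.1)

for layer in net:
    print(layer.name, layer.trainable, layer.weights)
```

Only the `head` weights move after the update. Skipping the freeze (leaving every `trainable` flag set) turns the same loop into full fine-tuning, which, as noted above, may work better or worse depending on the setup.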

Using this approach, the neural network at the end of step #1 contains a great deal of information related to solving the source task. Step #2 “saves” the representation learned by the network in its early layers, and during step #3, it is used as a starting point for learning to solve the target task. In this way, we require only a small number of examples to fine-tune the parameters of the later layers of the network, rather than the large number of examples needed to pre-train the entire model. This basic approach can be easily adapted to suit the different kinds of transfer learning outlined in the previous section.
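A quick parameter count gives one intuition for why fine-tuning only the later layers needs fewer labeled examples: there are far fewer parameters to fit. The sketch below assumes a hypothetical fully connected network that takes a flattened 32×32 RGB image (3,072 values) through hidden layers of 512 and 256 units to 10 output classes; each dense layer has in × out weights plus one bias per output. The layer sizes are illustrative, not taken from any particular model.

```python
from typing import List, Tuple

# Hypothetical dense-layer shapes: (inputs, outputs).
LAYERS: List[Tuple[int, int]] = [(3072, 512), (512, 256), (256, 10)]


def dense_params(n_in: int, n_out: int) -> int:
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out


total = sum(dense_params(i, o) for i, o in LAYERS)
head_only = dense_params(*LAYERS[-1])  # only the last layer is trained

print(f"all layers: {total:,} parameters")
print(f"last layer: {head_only:,} parameters ({head_only / total:.2%} of the total)")
```

This prints `all layers: 1,707,274 parameters` and `last layer: 2,570 parameters (0.15% of the total)`. Real convolutional networks have different layer types, but the same arithmetic applies to whichever layers are left unfrozen.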

If the source and target tasks have similar feature spaces, we can expect that the retrained network from step 3 will be able to leverage the representation learned from the source task to solve the target task.

A compelling example of this approach uses the learned features of a convolutional neural network (CNN) trained to classify the ImageNet or Open Images datasets. Here, the convolutional features learned by the CNN are treated as a general-purpose image representation and re-purposed to solve image classification, scene recognition, fine-grained recognition, attribute detection, and image retrieval tasks on a diverse collection of datasets, often matching or outperforming the state-of-the-art approaches that solve them from scratch. Many of these datasets had far fewer labeled examples than were available for training in ILSVRC13 (the 2013 ImageNet Large Scale Visual Recognition Challenge), empirically demonstrating the statistical efficiency that transfer learning offers.
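Re-purposing a CNN's learned features in this way can be sketched as: run each target-task example through a frozen feature extractor once, then fit a lightweight classifier on the resulting vectors. In the Python example below, `frozen_features` is a hypothetical placeholder for a pre-trained CNN's output with its classification head removed (here it just returns the mean and range of the input), and the classifier is a nearest-centroid rule chosen for brevity; linear classifiers such as logistic regression or SVMs are more common choices in practice.

```python
import math
from typing import Dict, List, Sequence

Vector = List[float]


def frozen_features(x: Sequence[float]) -> Vector:
    """Hypothetical stand-in for a frozen pre-trained feature extractor.
    A real one would be a CNN with its classification head removed."""
    return [sum(x) / len(x), max(x) - min(x)]  # toy features: mean and range


def fit_centroids(feats: List[Vector], labels: List[str]) -> Dict[str, Vector]:
    """Average the feature vectors of each class (nearest-centroid classifier)."""
    sums: Dict[str, Vector] = {}
    counts: Dict[str, int] = {}
    for f, y in zip(feats, labels):
        if y not in sums:
            sums[y] = [0.0] * len(f)
            counts[y] = 0
        sums[y] = [s + v for s, v in zip(sums[y], f)]
        counts[y] += 1
    return {y: [s / counts[y] for s in sums[y]] for y in sums}


def predict(centroids: Dict[str, Vector], f: Vector) -> str:
    """Return the label whose centroid is closest in Euclidean distance."""
    return min(centroids, key=lambda y: math.dist(centroids[y], f))


# A handful of labeled target-task examples (toy 1-D "images").
train_x = [[0.1, 0.2, 0.1], [0.0, 0.1, 0.2], [0.9, 0.8, 1.0], [1.0, 0.9, 0.9]]
train_y = ["dark", "dark", "bright", "bright"]

centroids = fit_centroids([frozen_features(x) for x in train_x], train_y)
print(predict(centroids, frozen_features([0.2, 0.1, 0.1])))  # dark
print(predict(centroids, frozen_features([0.8, 0.9, 0.9])))  # bright
```

Because the extractor never changes, features can be computed once and cached, so only the small classifier on top is fit to the target examples.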

How Will Transfer Learning Democratize Machine Learning Applications?

As we’ve discussed, transfer learning is useful in cases where (1) we don’t have the means to curate a large enough dataset to train a model from scratch, or (2) we wish to avoid expending the computational resources or time required to train a model from scratch. In the right situations, transfer learning can open doors for more engineers to experiment with novel deep learning applications. For example, one application (see case study #1 in this article) fine-tunes a pre-trained ImageNet model to achieve high performance with few training examples in a cats vs. dogs image classification task.

Even without deep knowledge of dataset curation, model building, and optimization, engineers can fine-tune existing models on a small number of labeled examples from the target task domain until they reach a satisfactory level of performance. Transfer learning enables individuals and small organizations to benefit from the powerful representational capacity of deep models without the time or budget that training such models from scratch typically requires.

