Machine Learning

Transfer learning

Naranjito 2024. 1. 9. 15:27
  • Transfer learning

 

A technique in machine learning (ML) in which knowledge learned from a task is re-used in order to boost performance on a related task. In other words, the re-use of a pre-trained model on a new problem. 

For example, for image classification, knowledge gained while learning to recognize cars could be applied when trying to recognize trucks.

 

The general idea is to use the knowledge a model has learned from a task with a lot of available labeled training data on a new task that doesn't have much data, instead of starting the learning process from scratch.


 

  • low-level feature filter: extracts color changes and the direction of boundaries, i.e. detects edges.
  • high-level feature filter: detects specific features, such as patterns of repeated circles or birds' beaks (see the sketch below).
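
This feature hierarchy can be inspected directly. The snippet below is a minimal sketch, assuming PyTorch and a torchvision ResNet-18 as the pre-trained model; the layer names layer1 and layer4 and the printed shapes are specific to that architecture.

```python
import torch
from torchvision import models

model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT).eval()

activations = {}

def save(name):
    def hook(module, inputs, output):
        activations[name] = output.detach()
    return hook

# layer1 sits near the input and holds low-level features (edges, color changes);
# layer4 sits near the output and holds high-level features (object parts)
model.layer1.register_forward_hook(save("low_level"))
model.layer4.register_forward_hook(save("high_level"))

with torch.no_grad():
    model(torch.randn(1, 3, 224, 224))  # dummy image batch

print(activations["low_level"].shape)   # torch.Size([1, 64, 56, 56])
print(activations["high_level"].shape)  # torch.Size([1, 512, 7, 7])
```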

 

https://dacon.io/en/forum/405988

https://builtin.com/data-science/transfer-learning

 


  • How to use Transfer learning

 

  • Case 1: Small Data Set, Similar Data

 

- slice off the end of the neural network

- add a new fully connected layer that matches the number of classes in the new data set

- randomize the weights of the new fully connected layer; freeze all the weights from the pre-trained network

- train the network to update the weights of the new fully connected layer

 

To avoid overfitting on the small data set, the weights of the original network are held constant rather than re-trained.

Since the data sets are similar, images from each data set will have similar higher-level features. Therefore, most or all of the pre-trained layers already contain relevant information about the new data set and should be kept.
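
A minimal sketch of this case, assuming PyTorch and a torchvision ResNet-18 as the pre-trained network; num_classes and train_loader are placeholders for the new data set.

```python
import torch
import torch.nn as nn
from torchvision import models

num_classes = 10  # placeholder: number of classes in the new data set

# load the pre-trained network and freeze all of its weights
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for param in model.parameters():
    param.requires_grad = False

# slice off the old final layer; the replacement is randomly initialized
# and matches the number of classes in the new data set
model.fc = nn.Linear(model.fc.in_features, num_classes)

# only the new fully connected layer is updated during training
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# train_loader would be a DataLoader over the new (small) data set
# for images, labels in train_loader:
#     optimizer.zero_grad()
#     loss = criterion(model(images), labels)
#     loss.backward()
#     optimizer.step()
```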


  • Case 2: Small Data Set, Different Data

 

- slice off most of the pre-trained layers near the beginning of the network

- add to the remaining pre-trained layers a new fully connected layer that matches the number of classes in the new data set

- randomize the weights of the new fully connected layer; freeze all the weights from the pre-trained network

- train the network to update the weights of the new fully connected layer

 

Because the data set is small, overfitting is still a concern. To combat it, the weights of the original neural network are held constant, as in the first case.

But the original training set and the new data set do not share higher-level features, so the new network only re-uses the layers containing lower-level features.
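
A sketch of this case under the same assumptions (PyTorch, torchvision ResNet-18): only the early, low-level layers are kept and frozen, and a new classifier is added on top. The cut-off point and the channel count (64) are specific to ResNet-18.

```python
import torch
import torch.nn as nn
from torchvision import models

num_classes = 10  # placeholder

pretrained = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# keep only the early layers (conv1 .. layer1), which hold low-level features
low_level = nn.Sequential(*list(pretrained.children())[:5])
for param in low_level.parameters():
    param.requires_grad = False  # freeze the pre-trained weights

# add a new, randomly initialized fully connected layer on top
model = nn.Sequential(
    low_level,
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(64, num_classes),  # layer1 of ResNet-18 outputs 64 channels
)

# train only the new fully connected layer
optimizer = torch.optim.Adam(model[-1].parameters(), lr=1e-3)
```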


  • Case 3: Large Data Set, Similar Data

 

- remove the last fully connected layer and replace with a layer matching the number of classes in the new data set

- randomly initialize the weights in the new fully connected layer

- initialize the rest of the weights using the pre-trained weights

- re-train the entire neural network

 

Overfitting is not as much of a concern when training on a large data set; therefore, you can re-train all of the weights.

Because the original training set and the new data set share higher-level features, the entire pre-trained network is kept and fine-tuned.
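
A sketch of this case under the same assumptions: the pre-trained weights initialize everything except the new final layer, nothing is frozen, and all weights are updated, typically with a small learning rate.

```python
import torch
import torch.nn as nn
from torchvision import models

num_classes = 10  # placeholder

# initialize all layers from the pre-trained weights
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# replace the last fully connected layer; its weights are randomly initialized
model.fc = nn.Linear(model.fc.in_features, num_classes)

# re-train the entire network: every weight is updated
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
```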


  • Case 4: Large Data Set, Different Data

 

 

- remove the last fully connected layer and replace with a layer matching the number of classes in the new data set

- re-train the network from scratch with randomly initialized weights

- alternatively, you could just use the same strategy as the "large and similar" data case

 

Even though the new data set is different from the original training data, initializing the weights from the pre-trained network might still make training faster. In that case, the procedure is exactly the same as for a large, similar data set.

If using the pre-trained network as a starting point does not produce a successful model, another option is to randomly initialize the convolutional neural network weights and train the network from scratch.
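
A sketch of the two options for this case, under the same assumptions as above: training from scratch with randomly initialized weights, or falling back on the pre-trained initialization from the previous case if it trains faster or works better.

```python
import torch
import torch.nn as nn
from torchvision import models

num_classes = 10  # placeholder

# Option A: randomly initialized network, trained from scratch
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Option B: start from the pre-trained weights, exactly as in Case 3
# model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
# model.fc = nn.Linear(model.fc.in_features, num_classes)

# either way, the whole network is trained
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
```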