Transformers: Fine-tune specific layers

Created on 6 Oct 2019 · 3 Comments · Source: huggingface/transformers

Is there any easy way to fine-tune specific layers of the model instead of fine-tuning the complete model?


All 3 comments

In PyTorch or TensorFlow? If PyTorch, this issue might be of help.

In my scripts, I use the following code, passing a 'freeze' parameter (a list of strings) through the config. All layers whose names start with any of the given prefixes will be frozen.

```
# Freeze parts of the pretrained model.
# config['freeze'] can be "all" to freeze all layers,
# or any number of prefixes, e.g. ['embeddings', 'encoder']
if 'freeze' in config and config['freeze']:
    for name, param in self.base_model.named_parameters():
        if config['freeze'] == 'all' or 'all' in config['freeze'] or name.startswith(tuple(config['freeze'])):
            param.requires_grad = False
            logging.info(f"Froze layer {name}...")
```
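The prefix match above relies on `str.startswith` accepting a tuple of prefixes. A minimal sketch of that rule, using hypothetical parameter names (real names would come from `model.named_parameters()` in PyTorch):

```python
# Hypothetical BERT-style parameter names for illustration only.
freeze = ["embeddings", "encoder.layer.0"]
names = [
    "embeddings.word_embeddings.weight",
    "encoder.layer.0.attention.self.query.weight",
    "encoder.layer.11.output.dense.bias",
    "pooler.dense.weight",
]

# str.startswith(tuple) returns True if the name matches ANY prefix.
frozen = [n for n in names if n.startswith(tuple(freeze))]
# frozen contains only the first two names; the other two stay trainable.
```

One caveat with plain prefix matching: a prefix like `encoder.layer.1` would also match `encoder.layer.11`, so a trailing dot (`encoder.layer.1.`) is safer when targeting a single layer.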

Thanks. Your code works fine. I did the following:

```
# freeze_layers is a comma-separated string of layer indices, e.g. "1,2,3"

if freeze_embeddings:
    for param in model.bert.embeddings.parameters():
        param.requires_grad = False
    print("Froze embedding layer")

if freeze_layers != "":
    layer_indexes = [int(x) for x in freeze_layers.split(",")]
    for layer_idx in layer_indexes:
        for param in model.bert.encoder.layer[layer_idx].parameters():
            param.requires_grad = False
        print("Froze layer:", layer_idx)
```
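One follow-up worth noting: after freezing, it is common to pass only the still-trainable parameters to the optimizer. Frozen parameters receive no gradients anyway, but excluding them avoids allocating optimizer state for them. A minimal sketch, using a tiny hypothetical two-layer model in place of BERT (the same pattern applies to any `nn.Module`):

```python
import torch

# Tiny stand-in model: layer 0 plays the role of a frozen pretrained
# block, layer 1 the trainable task head.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 4),
    torch.nn.Linear(4, 2),
)
for param in model[0].parameters():
    param.requires_grad = False

# Hand the optimizer only parameters with requires_grad == True.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)
```

Here the optimizer ends up tracking just the head's weight and bias, while the frozen layer is skipped entirely.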
