Transformers: Confused about the prune heads operation.

Created on 21 Jul 2019  ·  4 Comments  ·  Source: huggingface/transformers

In the code there is a 'prune_heads' method in the 'BertAttention' class, which calls the 'prune_linear_layer' operation. I don't understand what this operation does. The code of 'prune_linear_layer' is listed below. Thanks for any help!

from torch import nn

def prune_linear_layer(layer, index, dim=0):
    """ Prune a linear layer (model parameters) to keep only entries in index.
    Return the pruned layer as a new layer with requires_grad=True.
    Used to remove heads.
    """
    index = index.to(layer.weight.device)
    # Keep only the selected rows (dim=0, output features) or columns (dim=1, input features).
    W = layer.weight.index_select(dim, index).clone().detach()
    if layer.bias is not None:
        if dim == 1:
            # Pruning input features leaves the output size, and hence the bias, unchanged.
            b = layer.bias.clone().detach()
        else:
            b = layer.bias[index].clone().detach()
    # nn.Linear stores its weight as [out_features, in_features].
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
    # Turn off grad tracking while copying into the new leaf parameters, then turn it back on.
    new_layer.weight.requires_grad = False
    new_layer.weight.copy_(W.contiguous())
    new_layer.weight.requires_grad = True
    if layer.bias is not None:
        new_layer.bias.requires_grad = False
        new_layer.bias.copy_(b.contiguous())
        new_layer.bias.requires_grad = True
    return new_layer
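To make the operation concrete, here is a minimal pure-Python sketch of what prune_linear_layer does, with plain lists standing in for tensors. The helper name prune_weight is hypothetical and not part of transformers. With dim=0 it keeps the selected output rows of the [out_features, in_features] weight plus the matching bias entries; with dim=1 it keeps the selected input columns and leaves the bias untouched, mirroring the two branches above.

```python
# Hypothetical pure-Python analogue of prune_linear_layer, for illustration only.
# A weight is a list of rows with shape [out_features][in_features],
# matching the nn.Linear convention (weight has shape [out, in]).

def prune_weight(weight, bias, index, dim=0):
    """Keep only the entries listed in `index` along `dim`.

    dim=0 keeps the selected output rows (and the matching bias entries);
    dim=1 keeps the selected input columns (the bias is left unchanged).
    """
    if dim == 0:
        new_weight = [list(weight[i]) for i in index]
        new_bias = None if bias is None else [bias[i] for i in index]
    elif dim == 1:
        new_weight = [[row[j] for j in index] for row in weight]
        new_bias = None if bias is None else list(bias)
    else:
        raise ValueError("dim must be 0 or 1")
    return new_weight, new_bias


# A 4x3 layer: 4 output features, 3 input features.
W = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9],
     [10, 11, 12]]
b = [0.1, 0.2, 0.3, 0.4]

rows_W, rows_b = prune_weight(W, b, [0, 2], dim=0)
print(rows_W, rows_b)  # [[1, 2, 3], [7, 8, 9]] [0.1, 0.3]

cols_W, cols_b = prune_weight(W, b, [0, 2], dim=1)
print(cols_W, cols_b)  # [[1, 3], [4, 6], [7, 9], [10, 12]] [0.1, 0.2, 0.3, 0.4]
```

The real function does the same selection with index_select on tensors and wraps the result in a fresh, smaller nn.Linear.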

All 4 comments

Yes, I'll add a detailed example for this method in the coming weeks (update of the bertology script).

This can be used to remove heads in the model following the work of Michel et al. (Are Sixteen Heads Really Better than One?) among others.
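As far as I can tell from the BertAttention code, removing heads means computing which rows of the query, key and value projections belong to the heads you keep, pruning those three layers with dim=0, and pruning the attention output dense layer with dim=1 (its input columns). The sketch below only computes that index. The helper kept_indices is hypothetical, and the BERT-base layout (12 heads of size 64, hidden size 768, head h owning rows h*64 to h*64+63) is an assumption.

```python
# Hypothetical helper for illustration: which rows of the query/key/value
# weights survive when some heads are pruned. Assumes the BERT-base layout
# (12 heads of size 64, hidden size 768), where head h owns rows
# h*64 .. h*64+63 of each [768, 768] projection.

def kept_indices(num_heads, head_size, heads_to_prune):
    prune = set(heads_to_prune)
    return [
        h * head_size + k
        for h in range(num_heads)
        if h not in prune
        for k in range(head_size)
    ]


index = kept_indices(num_heads=12, head_size=64, heads_to_prune=[1, 2, 3])
print(len(index))    # 576 rows kept = 768 - 3 * 64
print(index[63:66])  # [63, 256, 257]: rows 64..255 (heads 1-3) are gone
```

Pruning heads 1, 2 and 3 of a 12-head layer therefore leaves 576 rows, which is where the torch.Size([576, 768]) in the error messages later in this thread comes from.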

Thanks a lot!

Hi @thomwolf, would it be possible to provide an example of how to prune or select some heads for a layer? When I just change the config file by setting
config.pruned_heads = {11:[1,2,3]} and use it to initialize the model, it throws an error:

size mismatch for bert.encoder.layer.11.attention.self.query.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([576, 768]).

(and more errors like it)

So the default query, key and value layers are set up with 768 dimensions.
I assume we cannot just prune heads and still load the pre-trained model, because the word embeddings and layer norm were set up with 768 dimensions.

Meanwhile, I came across the bertology.py script and realized that we can save a model after pruning. That works fine for me. Now I'm trying to load the saved model, and I get the opposite error:

size mismatch for bert.encoder.layer.11.attention.self.query.weight: copying a param with shape torch.Size([576, 768]) from checkpoint, the shape in current model is torch.Size([768, 768]).

The error doesn't go away even after changing the config file.
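Both errors look like the same shape check seen from opposite sides: unpruned weights loaded into a model whose config already declares pruned heads, and pruned weights loaded into a model built without them. The sketch below only models that comparison for the query weights; the helper names qkv_shapes and mismatches are hypothetical, and the BERT-base sizes (12 layers, 12 heads, hidden size 768) are assumptions. It does not reproduce how transformers actually builds or loads a model.

```python
# Hypothetical sketch: expected query weight shapes per layer for a
# BERT-base-sized model built from a config with `pruned_heads`, compared
# with the shapes stored in a checkpoint. Names and sizes are assumptions.

HIDDEN, NUM_HEADS, NUM_LAYERS = 768, 12, 12
HEAD_SIZE = HIDDEN // NUM_HEADS  # 64


def qkv_shapes(pruned_heads):
    """Map layer index -> (out, in) shape of its query weight."""
    shapes = {}
    for layer in range(NUM_LAYERS):
        kept = NUM_HEADS - len(set(pruned_heads.get(layer, [])))
        shapes[layer] = (kept * HEAD_SIZE, HIDDEN)
    return shapes


def mismatches(checkpoint, model):
    """List every layer whose checkpoint shape differs from the model's."""
    return [
        f"layer {i}: checkpoint {checkpoint[i]} vs model {model[i]}"
        for i in checkpoint
        if checkpoint[i] != model[i]
    ]


original = qkv_shapes({})             # unpruned pre-trained checkpoint
pruned = qkv_shapes({11: [1, 2, 3]})  # model built with pruned_heads

print(mismatches(original, pruned))
# ['layer 11: checkpoint (768, 768) vs model (576, 768)']
print(mismatches(pruned, original))
# ['layer 11: checkpoint (576, 768) vs model (768, 768)']
```

In this model, loading only succeeds when the checkpoint and the freshly built model agree on which heads are pruned.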
