In the code there is a 'prune_heads' method on the 'BertAttention' class, which calls the 'prune_linear_layer' operation. I don't understand what this operation does. The code of 'prune_linear_layer' is listed below. Thanks for any help!
def prune_linear_layer(layer, index, dim=0):
    """ Prune a linear layer (a model parameters) to keep only entries in index.
        Return the pruned layer as a new layer with requires_grad=True.
        Used to remove heads.
    """
    index = index.to(layer.weight.device)
    W = layer.weight.index_select(dim, index).clone().detach()
    if layer.bias is not None:
        if dim == 1:
            b = layer.bias.clone().detach()
        else:
            b = layer.bias[index].clone().detach()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
    new_layer.weight.requires_grad = False
    new_layer.weight.copy_(W.contiguous())
    new_layer.weight.requires_grad = True
    if layer.bias is not None:
        new_layer.bias.requires_grad = False
        new_layer.bias.copy_(b.contiguous())
        new_layer.bias.requires_grad = True
    return new_layer
Yes, I'll add a detailed example for this method in the coming weeks (update of the bertology script).
This can be used to remove heads in the model following the work of Michel et al. (Are Sixteen Heads Really Better than One?) among others.
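Concretely, prune_linear_layer builds a new, smaller nn.Linear that keeps only the rows (dim=0) or columns (dim=1) of the weight selected by index (plus the matching bias entries), copies the kept values over, and returns it with gradients enabled again. A minimal sketch of the effect, assuming the prune_linear_layer function quoted above is defined in the current scope:

import torch
from torch import nn

layer = nn.Linear(8, 6)                       # weight shape (6, 8), bias shape (6,)

# Keep only output units 0, 2 and 5 (dim=0 selects rows of the weight).
index = torch.tensor([0, 2, 5])
pruned = prune_linear_layer(layer, index, dim=0)
print(pruned.weight.shape)                    # torch.Size([3, 8])
print(pruned.bias.shape)                      # torch.Size([3])

# Keep only input features 1 and 4 (dim=1 selects columns; the bias is left as is).
pruned_in = prune_linear_layer(layer, torch.tensor([1, 4]), dim=1)
print(pruned_in.weight.shape)                 # torch.Size([6, 2])

In BertAttention.prune_heads, the query/key/value projections are pruned along dim=0 and the attention output projection along dim=1, so dropping one head removes its 64-dimensional slice (for bert-base) from each of those four matrices.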
Thanks a lot!
Hi @thomwolf, would it be possible to provide an example of how to prune or select some heads for a layer? When I just change the config by setting config.pruned_heads = {11: [1, 2, 3]} and use it to initialize the model, it throws an error:
size mismatch for bert.encoder.layer.11.attention.self.query.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([576, 768]).
and more errors like it.
So the default query, key and value projections are set up with 768 dimensions. I assume we cannot just prune heads and still load the pretrained model, because the word embeddings and layer norm were set up with 768 dimensions.
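The 576 in that error is just 768 minus three pruned heads of size 64. One way to avoid the mismatch is to load the full checkpoint first and prune afterwards, so the pretrained weights are copied before any shapes change. A rough sketch, assuming a transformers-style API (in older releases the package is pytorch_transformers instead):

from transformers import BertModel

# Load the unpruned checkpoint first; every projection is still 768 x 768 here.
model = BertModel.from_pretrained("bert-base-uncased")

# Then drop heads 1, 2 and 3 of layer 11. With 12 heads of size 64, the
# query/key/value weights of that layer shrink from (768, 768) to (576, 768).
model.prune_heads({11: [1, 2, 3]})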
Meanwhile, I came across the bertology.py script and realized that we can save a model after pruning; that works fine for me. Now I'm trying to load the saved model, and I get the opposite error:
size mismatch for bert.encoder.layer.11.attention.self.query.weight: copying a param with shape torch.Size([576, 768]) from checkpoint, the shape in current model is torch.Size([768, 768]).
The error doesn't go away even after changing the config file.
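For what it's worth, save_pretrained stores the pruned weights together with a config.json whose pruned_heads entry records which heads were removed, and reloading should then go through that saved directory rather than the original checkpoint name. A sketch of the round trip, continuing from the snippet above and assuming a version where from_pretrained applies config.pruned_heads while building the model (older releases did not do this automatically, which would produce exactly this 768-vs-576 mismatch):

# Save the pruned model: the weights now contain the 576-dim projections and the
# saved config.json records pruned_heads for layer 11.
model.save_pretrained("./bert-base-pruned")

# Reload from that directory. If from_pretrained honours config.pruned_heads,
# layer 11 is rebuilt with 576-dim projections before the checkpoint is loaded,
# so the shapes match again.
reloaded = BertModel.from_pretrained("./bert-base-pruned")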