I've managed to get model parallelism working on GPT-2 for forward inference by modifying the GPT2Model class and adding a few lines to the generate method to ensure that tensors that need to be on the same device always are. It automatically distributes the blocks evenly across however many GPUs are detected. I had to add an additional argument to Trainer (model_parallel) to avoid conflicting distribution behavior. Unfortunately, I'm stuck on backprop, specifically in Trainer.training_step on the line loss.backward().
The loss is tensor(71.5152, device='cuda:3', grad_fn=<NllLossBackward>) and the error is:
RuntimeError: expected device cuda:3 but got device cuda:0 (compute_types at ..\aten\src\ATen\native\TensorIterator.cpp:246)
(no backtrace available)
So something somewhere is on the wrong device. It would be a miracle if someone knows how to fix this, but more realistically I'm hoping for a list of things that might be wrong which I can check. I can also do a code review with someone from the transformers team. This could be the pattern to enable model parallelism on all PyTorch transformers.
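In case it helps anyone check the same thing, here is the kind of device audit I have been running before calling backward. It is a rough diagnostic sketch rather than part of the patch, and model is whatever object wraps the modified GPT2Model:

from collections import defaultdict
import torch

def summarize_devices(model: torch.nn.Module):
    """Group parameter and buffer names by the device they live on."""
    by_device = defaultdict(list)
    for name, param in model.named_parameters():
        by_device[str(param.device)].append(name)
    for name, buf in model.named_buffers():
        by_device[str(buf.device)].append(name)
    for device, names in sorted(by_device.items()):
        print(f"{device}: {len(names)} tensors, e.g. {names[:3]}")

# summarize_devices(model)  # anything unexpectedly left on cuda:0 is a candidate culprit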
Hey @alexorona, it's great that you are working on model parallelism! Could you open a PR with the proposed changes to GPT2 and maybe post a code snippet to reproduce your error with the code in your PR?
I'm happy to take a look :-)
@patrickvonplaten An update: I managed to get around the problem by carefully following every tensor in the GPT2 model. I had to place the lm_head on the first device because it is tied to the wte layer. Model parallelism is now working, confirmed with nvidia-smi: tensors are moving appropriately, memory is well balanced across the GPUs, and models are training. It's not useful at all to create a PR right now: I'm using a version of transformers that's probably a month old and the code is barely holding together.
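For context, here is a minimal sketch of what that placement looks like in an LM-head wrapper. Names like MyGPT2LMHeadModel are illustrative rather than the actual patch; the two points that matter are that lm_head shares its weight with wte (and therefore its device) and that labels are moved to the logits' device before the loss is computed:

class MyGPT2LMHeadModel(GPT2PreTrainedModel):
    def __init__(self, config, layers_map):
        super().__init__(config)
        self.transformer = GPT2Model(config, layers_map)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # tie lm_head to wte; since wte already lives on the first GPU, so does lm_head
        self.lm_head.weight = self.transformer.wte.weight

    def forward(self, input_ids=None, labels=None, **kwargs):
        hidden_states = self.transformer(input_ids, **kwargs)[0]
        # hidden states come back from the last GPU; bring them to the lm_head's device
        hidden_states = hidden_states.to(self.lm_head.weight.device)
        lm_logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss = nn.CrossEntropyLoss()(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
        return (loss, lm_logits) if loss is not None else (lm_logits,)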
I'd like to get the latest (and hopefully last) functional challenge solved before putting together a PR. This latest problem is extremely challenging. Only someone with a very deep knowledge of the transformers implementation of the Attention class, Trainer and possibly modeling_utils.py can provide an intuition as to what's happening.
Here's the problem: The same model on the same GPU with the same token size consumes more memory while training if there are more GPUs. For example, the first attention block will consume 2.2 GB of GPU memory on a Tesla v100 if there are 4 Tesla v100s on the instance. Meanwhile, the same block will consume 4.2 GB of GPU memory on a Tesla v100 if there are 8 Tesla v100s on the instance. It makes no sense. I believe the behavior is coming from Attention._attn. Does anyone know whether there's something in the implementation that would cause tensors to use up more GPU memory if more GPUs are added? Note: I've disabled all of the data parallelism in Trainer, which would be the obvious source.
Some additional details:
# Running gpt2-xl on 4 GPUs. Model uses 2.2 GB of memory per attention block.
Block: 0
Total GPU Memory Usage: 1.40915456
Block: 1
Total GPU Memory Usage: 3.604413952
Block: 2
Total GPU Memory Usage: 5.803867648
Block: 3
Total GPU Memory Usage: 8.003321344
Block: 4
Total GPU Memory Usage: 10.20277504
Block: 5
Total GPU Memory Usage: 12.402228736
Block: 6
Total GPU Memory Usage: 14.601682432

# Running gpt2-xl on 8 GPUs. Model uses 4.2 GB of memory per attention block.
Block: 0
Total GPU Memory Usage: 1.468251648
Block: 1
Total GPU Memory Usage: 5.847236096
Block: 2
Total GPU Memory Usage: 10.226220544
Block: 3
Total GPU Memory Usage: 14.605204992
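For reference, the figures above come from torch.cuda.memory_allocated inside the forward loop (see the code below). A small standalone helper along the same lines can be used to check the balance across all visible GPUs; this is a sketch, not the exact logging code from the run:

import torch

def report_gpu_memory(tag=""):
    """Print allocated and reserved memory (GB) for every visible CUDA device."""
    for device_id in range(torch.cuda.device_count()):
        allocated = torch.cuda.memory_allocated(device_id) / 1e9
        reserved = torch.cuda.memory_reserved(device_id) / 1e9
        print(f"{tag} cuda:{device_id} allocated={allocated:.3f} GB reserved={reserved:.3f} GB")

# report_gpu_memory("after block 0")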
class GPT2Model(GPT2PreTrainedModel):
    def __init__(self, config, layers_map):
        super().__init__(config)
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
        # Layers map, e.g. for 4 GPUs:
        # {0: [0, ..., 10], 1: [11, ..., 23], 2: [24, ..., 36], 3: [37, ..., 47]}
        self.layers_map = layers_map
        first_device = 'cuda:' + str(min(self.layers_map.keys()))
        # Embeddings and dropout live on the first GPU in the map
        self.wte = self.wte.to(first_device)
        self.wpe = self.wpe.to(first_device)
        self.drop = self.drop.to(first_device)
        # Move each transformer block to the GPU it is assigned to
        for device_id, layer_ids in self.layers_map.items():
            for layer_id in layer_ids:
                self.h[layer_id] = self.h[layer_id].to('cuda:' + str(device_id))
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        # Skipping over some details in the forward method
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            print('Block:', i)
            gpu_memory = torch.cuda.memory_allocated(device=hidden_states.device) / 1e9
            print("GPU Memory:", gpu_memory)

            if output_hidden_states:
                print('output_hidden_states is true')
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            # Everything the block consumes must live on the same device as hidden_states
            if layer_past is not None:
                layer_past = layer_past.to(hidden_states.device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(hidden_states.device)

            if i > 0:
                del outputs  # drop the previous block's outputs before allocating the next

            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states, present = outputs[:2]
            if use_cache is True:
                presents = presents + (present,)

            if output_attentions:
                all_attentions = all_attentions + (outputs[2],)

            # At the boundary between two GPUs, hand the hidden states off to the next device
            for k, v in self.layers_map.items():
                if i == v[-1] and k != max(self.layers_map.keys()):
                    hidden_states = hidden_states.to('cuda:' + str(k + 1))
class Block(nn.Module):
    def __init__(self, n_ctx, config, scale=True):
        super().__init__()
        hidden_size = config.n_embd
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = Attention(hidden_size, n_ctx, config, scale)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        if config.add_cross_attention:
            self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True)
            self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = MLP(inner_dim, config)

    def forward(
        self,
        hidden_states,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        attn_outputs = self.attn(
            self.ln_1(hidden_states),
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        outputs = attn_outputs[1:]
        # residual connection
        hidden_states = attn_output + hidden_states

        if encoder_hidden_states is not None:
            # add one self-attention block for cross-attention
            assert hasattr(
                self, "crossattention"
            ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
            cross_attn_outputs = self.crossattention(
                self.ln_cross_attn(hidden_states),
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attn_output = cross_attn_outputs[0]
            # residual connection
            hidden_states = hidden_states + attn_output
            outputs = outputs + cross_attn_outputs[1:]  # add cross attentions if we output attention weights

        feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
        # residual connection
        hidden_states = hidden_states + feed_forward_hidden_states

        outputs = [hidden_states] + outputs
        return outputs  # hidden_states, present, (cross_attentions, attentions)
class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
        super().__init__()

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        self.register_buffer(
            "bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale
        self.is_cross_attention = is_cross_attention
        if self.is_cross_attention:
            self.c_attn = Conv1D(2 * n_state, nx)
            self.q_attn = Conv1D(n_state, nx)
        else:
            self.c_attn = Conv1D(3 * n_state, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.pruned_heads = set()
        self.softmax = nn.Softmax(dim=-1)

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
        )
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
        self.n_head = self.n_head - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)
    def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / (float(v.size(-1)) ** 0.5)
        nd, ns = w.size(-2), w.size(-1)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            mask = self.bias[:, :, ns - nd : ns, :ns]
            # keep the causal mask and masked_bias on the same device as the scores
            mask = mask.to(w.device)
            self.masked_bias = self.masked_bias.to(w.device)
            w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))

        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask

        w = self.softmax(w)
        w = self.attn_dropout(w)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [torch.matmul(w, v)]
        if output_attentions:
            outputs.append(w)

        # Free local references and clear the CUDA cache (added while chasing the memory growth)
        del mask, nd, ns, v, q, k, attention_mask, head_mask, output_attentions, w
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        return outputs
# Layers map for 4 x Tesla V100 GPUs
layers_map = {0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
              1: [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
              2: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36],
              3: [37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]}

# Layers map for 8 x Tesla V100 GPUs
layers_map = {0: [0, 1, 2, 3, 4],
              1: [5, 6, 7, 8, 9, 10],
              2: [11, 12, 13, 14, 15, 16],
              3: [17, 18, 19, 20, 21, 22, 23],
              4: [24, 25, 26, 27, 28, 29],
              5: [30, 31, 32, 33, 34, 35],
              6: [36, 37, 38, 39, 40, 41],
              7: [42, 43, 44, 45, 46, 47]}

model = TransformersModel(layers_map=layers_map)
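Since these maps are just a roughly even split of the 48 gpt2-xl blocks (the first GPU gets slightly fewer because it also holds the embeddings and lm_head), an even split can be generated for an arbitrary GPU count. A sketch of that, with a helper name of my own choosing:

import torch

def build_layers_map(n_layer, n_gpu=None):
    """Spread transformer block indices as evenly as possible across the visible GPUs."""
    n_gpu = n_gpu or torch.cuda.device_count()
    per_gpu = n_layer // n_gpu
    remainder = n_layer % n_gpu
    layers_map, start = {}, 0
    for device_id in range(n_gpu):
        # the last `remainder` GPUs get one extra block each
        count = per_gpu + (1 if device_id >= n_gpu - remainder else 0)
        layers_map[device_id] = list(range(start, start + count))
        start += count
    return layers_map

# build_layers_map(48, 4) -> {0: [0..11], 1: [12..23], 2: [24..35], 3: [36..47]}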
Got it working. The TrainingArguments object has data parallelism baked into it (along with a lot of other things), so my manual override of the batch size was failing. The tensor sizes were exploding because TrainingArguments automatically scales the effective batch size by the number of GPUs it detects. Fine-tuned a gpt2-xl model with 1024 tokens with good results in just 15 minutes.
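To illustrate what was happening, here is the effective-batch arithmetic in the Trainer's data-parallel path (a sketch of the behavior; exact attribute names vary across transformers versions):

# With data parallelism, the batch actually fed to the model per step is
#     per_device_train_batch_size * n_gpu
# so with per_device_train_batch_size=1 the forward pass still sees
# 4 sequences on a 4-GPU machine and 8 on an 8-GPU machine.
per_device_train_batch_size = 1

for n_gpu in (4, 8):
    effective_batch = per_device_train_batch_size * max(1, n_gpu)
    print(f"{n_gpu} GPUs -> effective batch of {effective_batch} sequences per step")

# With model parallelism there is only one replica, so this scaling has to be
# disabled (e.g. by forcing the GPU count used in the calculation to 1).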
@alexorona Can you please share the code for some example(s) of the parallelism you got to work (maybe through a PR to the repo examples)?
@patrickvonplaten @LSinev Greatly simplified the working code and refined it so that the same basic approach can be used for other models as well. I took a look at T5 and am 99% confident I can use the same approach to make it parallelizable. Will get a PR up this week, probably by Sunday.
Model parallel PR merged to transformers.