Transformers: Almost Have Model Parallelism Working on GPT2 Fine-Tuning

Created on 1 Oct 2020 · 6 Comments · Source: huggingface/transformers

โ“ Questions & Help

I've managed to get model parallelism working on gpt2 for forward inference by modifying the GPT2Model class and adding a few lines to the generate method to ensure that tensors that need to be on the same device always are. The model automatically distributes its blocks evenly across any number of GPUs that are detected. I had to add an additional argument to Trainer (model_parallel) to avoid conflicting distribution behavior. Unfortunately, I'm stuck on backprop, specifically in Trainer.training_step on the line loss.backward().

loss is tensor(71.5152, device='cuda:3', grad_fn=<NllLossBackward>)

The error is:

RuntimeError: expected device cuda:3 but got device cuda:0 (compute_types at ..\aten\src\ATen\native\TensorIterator.cpp:246)
(no backtrace available)

So something somewhere is on the wrong device. It would be a miracle if someone knows how to fix this outright, but more realistically I'm hoping for a list of things that might be wrong which I can check. I'm also happy to do a code review with someone from the transformers team. This could become the pattern for enabling model parallelism on all of the PyTorch transformers models.
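One way to narrow this down is to list which device every parameter and registered buffer of the model ends up on, so a stray cuda:0 tensor stands out. A minimal debugging sketch, assuming model is the already-distributed GPT-2 model (the helper name is mine, not part of the patch):

import collections

# Group every parameter and buffer by the device it lives on.
# Assumes `model` is the already-distributed GPT-2 model.
def audit_devices(model):
    by_device = collections.defaultdict(list)
    for name, param in model.named_parameters():
        by_device[str(param.device)].append(name)
    for name, buf in model.named_buffers():
        by_device[str(buf.device)].append(name)
    for device, names in sorted(by_device.items()):
        print(device, '->', len(names), 'tensors, e.g.', names[:3])

Registered buffers like Attention.bias and Attention.masked_bias are easy to miss: moving a whole submodule with .to() carries them along, but tensors created outside a module have to be moved by hand.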


All 6 comments

Hey @alexorona, it's great that you are working on model parallelism! Could you open a PR with the proposed changes to GPT2 and maybe post a code snippet to reproduce your error with the code in your PR?

I'm happy to take a look :-)

@patrickvonplaten An update: I managed to get around the problem by carefully following every tensor through the GPT2 model; I also had to place the lm_head on the first device because it shares its weights with the wte layer. Model parallelism is now working, confirmed with nvidia-smi: tensors are moving appropriately and are well-balanced across the GPUs, and models are training. It's not useful at all to create a PR right now: I'm using a version of transformers that's probably a month old and the code is barely holding together.
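For concreteness, the placement constraint looks roughly like this (a sketch assuming a GPT2LMHeadModel-style model with weight tying; the helper name is hypothetical): because lm_head shares its weight with wte, both have to sit on the same device.

import torch.nn as nn

# Sketch of the shared-weight constraint (assumes a GPT2LMHeadModel-style model):
# lm_head is weight-tied to wte, so both must live on the first device in the map.
def place_tied_embeddings(model: nn.Module, layers_map: dict):
    first_device = 'cuda:' + str(min(layers_map.keys()))
    model.transformer.wte = model.transformer.wte.to(first_device)
    model.lm_head = model.lm_head.to(first_device)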

I'd like to get the latest (and hopefully last) functional challenge solved before putting together a PR. This one is extremely difficult: only someone with very deep knowledge of the transformers implementation of the Attention class, Trainer, and possibly modeling_utils.py can provide an intuition as to what's happening.

Here's the problem: the same model on the same GPU with the same token size consumes more memory while training if there are more GPUs. For example, the first attention block will consume 2.2 GB of GPU memory on a Tesla V100 if there are 4 Tesla V100s on the instance. Meanwhile, the same block will consume 4.2 GB of GPU memory on a Tesla V100 if there are 8 Tesla V100s on the instance. It makes no sense. I believe the behavior is coming from Attention._attn. Does anyone know whether there's something in the implementation that would cause tensors to use up more GPU memory when more GPUs are added? Note: I've disabled all of the data parallelism in Trainer, which would be the obvious source.

Some additional details:

# Running gpt2-xl on 4 GPUs. Model uses 2.2 GB of memory per attention block
# (the totals below are cumulative allocated memory in GB on the block's device).
Block: 0
Total GPU Memory Usage: 1.40915456
Block: 1
Total GPU Memory Usage: 3.604413952
Block: 2
Total GPU Memory Usage: 5.803867648
Block: 3
Total GPU Memory Usage: 8.003321344
Block: 4
Total GPU Memory Usage: 10.20277504
Block: 5
Total GPU Memory Usage: 12.402228736
Block: 6
Total GPU Memory Usage: 14.601682432
# Running gpt2-xl on 8 GPUs. Model uses 4.2 GB of memory per attention block.
Block: 0
Total GPU Memory Usage: 1.468251648
Block: 1
Total GPU Memory Usage: 5.847236096
Block: 2
Total GPU Memory Usage: 10.226220544
Block: 3
Total GPU Memory Usage: 14.605204992
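For comparison, a short sketch of how allocated memory could be snapshotted on every visible GPU at once (not the exact code that produced the numbers above):

import torch

# Report allocated memory in GB on every visible device.
def report_gpu_memory(tag=''):
    for i in range(torch.cuda.device_count()):
        gb = torch.cuda.memory_allocated(i) / 1e9
        print(f'{tag} cuda:{i}: {gb:.3f} GB allocated')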
class GPT2Model(GPT2PreTrainedModel):
    def __init__(self, config, layers_map):
        super().__init__(config)

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])

        # Example layers map for 4 GPUs:
        # {0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        #  1: [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
        #  2: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36],
        #  3: [37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]}
        self.layers_map = layers_map

        # The embeddings and dropout live on the first device in the map.
        first_device = 'cuda:' + str(min(self.layers_map.keys()))
        self.wte = self.wte.to(first_device)
        self.wpe = self.wpe.to(first_device)
        self.drop = self.drop.to(first_device)

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        # Skipping over some details in the forward method

        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            print('Block:', i)
            gpu_memory = torch.cuda.memory_allocated(device=hidden_states.device) / 1e9
            print('Total GPU Memory Usage:', gpu_memory)
            if output_hidden_states:
                print('output_hidden_states is True')
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            if layer_past is not None:
                layer_past = layer_past.to(hidden_states.device)

            if attention_mask is not None:
                attention_mask = attention_mask.to(hidden_states.device)

            if i > 0:
                del outputs  # free the previous block's outputs before running the next one
            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

            hidden_states, present = outputs[:2]

            if use_cache is True:
                presents = presents + (present,)

            if output_attentions:
                all_attentions = all_attentions + (outputs[2],)

            # If this was the last block on its GPU, move hidden_states to the next GPU.
            for k, v in self.layers_map.items():
                if i == v[-1] and k != max(self.layers_map.keys()):
                    hidden_states = hidden_states.to('cuda:' + str(k + 1))

class Block(nn.Module):
    def __init__(self, n_ctx, config, scale=True):
        super().__init__()
        hidden_size = config.n_embd
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = Attention(hidden_size, n_ctx, config, scale)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        if config.add_cross_attention:
            self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True)
            self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = MLP(inner_dim, config)

    def forward(
        self,
        hidden_states,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        attn_outputs = self.attn(
            self.ln_1(hidden_states),
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        outputs = attn_outputs[1:]
        # residual connection
        hidden_states = attn_output + hidden_states

        if encoder_hidden_states is not None:
            # add one self-attention block for cross-attention
            assert hasattr(
                self, "crossattention"
            ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
            cross_attn_outputs = self.crossattention(
                self.ln_cross_attn(hidden_states),
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attn_output = cross_attn_outputs[0]
            # residual connection
            hidden_states = hidden_states + attn_output
            outputs = outputs + cross_attn_outputs[1:]  # add cross attentions if we output attention weights

        feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
        # residual connection
        hidden_states = hidden_states + feed_forward_hidden_states

        outputs = [hidden_states] + outputs
        return outputs  # hidden_states, present, (cross_attentions, attentions)

class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
        super().__init__()

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        self.register_buffer(
            "bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale
        self.is_cross_attention = is_cross_attention
        if self.is_cross_attention:
            self.c_attn = Conv1D(2 * n_state, nx)
            self.q_attn = Conv1D(n_state, nx)
        else:
            self.c_attn = Conv1D(3 * n_state, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.pruned_heads = set()
        self.softmax = nn.Softmax(dim=-1)

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
        )
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
        self.n_head = self.n_head - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / (float(v.size(-1)) ** 0.5)
        nd, ns = w.size(-2), w.size(-1)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            mask = self.bias[:, :, ns - nd : ns, :ns]
            mask = mask.to(w.device)

            self.masked_bias = self.masked_bias.to(w.device)
            w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask
        w = self.softmax(w)
        w = self.attn_dropout(w)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [torch.matmul(w, v)]
        if output_attentions:
            outputs.append(w)
        del mask, nd, ns, v, q, k, attention_mask, head_mask, output_attentions, w
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        return outputs

# Layers map for 4 x Tesla V100 GPUs
layers_map = {0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
              1: [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
              2: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36],
              3: [37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]}

# Layers map for 8 x Tesla V100 GPUs
layers_map = {0: [0, 1, 2, 3, 4],
              1: [5, 6, 7, 8, 9, 10],
              2: [11, 12, 13, 14, 15, 16],
              3: [17, 18, 19, 20, 21, 22, 23],
              4: [24, 25, 26, 27, 28, 29],
              5: [30, 31, 32, 33, 34, 35],
              6: [36, 37, 38, 39, 40, 41],
              7: [42, 43, 44, 45, 46, 47]}

model = TransformersModel(layers_map=layers_map)
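For reference, a balanced layers_map for any detected GPU count could be produced with something like the following sketch (hypothetical helper names, not the code from the working branch), which also shows where each Block would be moved:

import torch
import torch.nn as nn

# Split n_layer block indices as evenly as possible across the detected GPUs.
def make_layers_map(n_layer, n_gpu=None):
    n_gpu = n_gpu or torch.cuda.device_count()
    sizes = [n_layer // n_gpu + (1 if i < n_layer % n_gpu else 0) for i in range(n_gpu)]
    layers_map, start = {}, 0
    for gpu, count in enumerate(sizes):
        layers_map[gpu] = list(range(start, start + count))
        start += count
    return layers_map

# Move each Block onto the device assigned to it by the map.
def place_blocks(blocks: nn.ModuleList, layers_map):
    for gpu, layer_idxs in layers_map.items():
        for idx in layer_idxs:
            blocks[idx].to('cuda:' + str(gpu))

# e.g. make_layers_map(48, 4) gives 12 consecutive blocks per GPU for gpt2-xl.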

Got it working. The TrainingArguments object has data parallelism baked into it (along with a lot of other things), so my manual override of the batch size was failing: TrainingArguments was automatically adjusting the minimum batch size up to the number of GPUs, which is why the tensor sizes were exploding. Fine-tuned a gpt2-xl model with 1024 tokens with good results in just 15 minutes.
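A rough illustration of the scaling involved (a sketch that mirrors, not quotes, the TrainingArguments behavior of that era): the effective batch size the Trainer feeds the model is derived from the per-device batch size multiplied by the number of visible GPUs, so per-block activation memory grows with the GPU count unless that scaling is overridden.

import torch

# Illustration of effective batch size scaling with visible GPU count.
per_device_train_batch_size = 1
n_gpu = torch.cuda.device_count()  # 4 vs. 8 in the runs above
effective_batch_size = per_device_train_batch_size * max(1, n_gpu)
print(effective_batch_size)        # 4 vs. 8 -> roughly twice the activation memory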

@alexorona Can you please share the code for some example(s) of the parallelism you got working (maybe through a PR to the repo's examples)?

@patrickvonplaten @LSinev Greatly simplified the working code and refined it so that the same basic approach can be used for other models as well. I took a look at T5 and am 99% confident I can use the same approach to make it parallelizable. Will get a PR up this week, probably by Sunday.

Model parallel PR merged to transformers.
