Hi, I'd like to know whether you plan to release fine-tuning and evaluation code for the SWAG dataset.
If not, is the training procedure the same as for MRPC (more specifically, label 0 for the distractors and 1 for the gold ending)?
For maintainability reasons we don't plan on releasing more code than what we've already released (except for the gradient accumulation code that we've promised). You could train it as a binary classification problem, but we actually did something different: we compute a softmax over the logits of the different candidate endings for each example. This only requires a few lines of code, but it does require changing the input processing.
Let's assume your batch size is 8 and your sequence length is 128. Each SWAG example has 4 candidate endings: the correct one and 3 incorrect ones.
Instead of your `input_fn` returning an `input_ids` of size `[128]`, it should return one of size `[4, 128]`. The same goes for the input mask and the segment ids. So for each example, you will generate the four sequences `predicate ending0`, `predicate ending1`, `predicate ending2`, `predicate ending3`, i.e., the same context paired with each candidate ending. Also return a label scalar, an integer in the range `[0, 3]` indicating which ending is the gold one.
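A minimal, framework-free sketch of this per-example input processing, written in plain Python for illustration only. It assumes the context and the four endings have already been converted to WordPiece ids by a tokenizer, and it follows the `[CLS] A [SEP] B [SEP]` layout and longest-first truncation used by `run_classifier.py`. The function names and the `cls_id`/`sep_id`/`pad_id` parameters are hypothetical, not part of the released code.

```python
from typing import List, Tuple


def truncate_pair(a: List[int], b: List[int], max_len: int) -> Tuple[List[int], List[int]]:
    """Drop tokens from the end of the longer list until len(a) + len(b) <= max_len."""
    a, b = list(a), list(b)
    while len(a) + len(b) > max_len:
        if len(a) > len(b):
            a.pop()
        else:
            b.pop()
    return a, b


def build_swag_features(context_ids, ending_ids_list, label, max_seq_length,
                        cls_id, sep_id, pad_id=0):
    """Hypothetical helper: returns three [4, max_seq_length] lists plus the label."""
    assert len(ending_ids_list) == 4, "SWAG has exactly 4 candidate endings"
    assert 0 <= label < 4, "label is the index of the gold ending"
    input_ids, input_mask, segment_ids = [], [], []
    for ending_ids in ending_ids_list:
        # Reserve 3 positions for [CLS], [SEP], [SEP].
        ctx, end = truncate_pair(context_ids, ending_ids, max_seq_length - 3)
        ids = [cls_id] + ctx + [sep_id] + end + [sep_id]
        # Segment 0 covers [CLS] context [SEP]; segment 1 covers ending [SEP].
        seg = [0] * (len(ctx) + 2) + [1] * (len(end) + 1)
        mask = [1] * len(ids)
        pad = max_seq_length - len(ids)
        input_ids.append(ids + [pad_id] * pad)
        input_mask.append(mask + [0] * pad)
        segment_ids.append(seg + [0] * pad)
    return input_ids, input_mask, segment_ids, label
```

Each of the four rows is one complete BERT input, so the per-example tensors have shape `[4, max_seq_length]` while the label stays a scalar.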
After batching, your `model_fn` will get an input of shape `[8, 4, 128]`. Reshape it to `[32, 128]` before passing it into `BertModel`, so that BERT encodes each of these 32 sequences independently.
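The reshape only merges the first two dimensions. A plain-Python sketch of what it does to the data (the helper name `flatten_choices` is hypothetical; in the TensorFlow model the equivalent would be `tf.reshape(input_ids, [-1, max_seq_length])`, applied likewise to the mask and segment ids):

```python
def flatten_choices(batch):
    """[batch_size, num_choices, seq_len] -> [batch_size * num_choices, seq_len]."""
    # Example-major order: all 4 endings of example 0, then all 4 of example 1, ...
    return [seq for example in batch for seq in example]
```

Rows come out example-major, so row `i` of the result is ending `i % 4` of example `i // 4`. The later reshape of the logits back to `[8, 4]` relies on this same ordering.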
Compute the logits as in `run_classifier.py`, except that your "classifier layer" is just a single vector of size `[768]` (or whatever your hidden size is), which produces one scalar logit per sequence.
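A sketch of this single-vector classifier layer, assuming each of the 32 sequences is represented by BERT's pooled `[CLS]` output of size `hidden_size`. `choice_logits` and its arguments are hypothetical names; in TensorFlow this amounts to a matmul of the `[32, hidden_size]` pooled output with a `[hidden_size, 1]` weight.

```python
def choice_logits(pooled_outputs, weight, bias=0.0):
    """One scalar score per (context, ending) sequence: dot(pooled, weight) + bias."""
    return [sum(h * w for h, w in zip(pooled, weight)) + bias for pooled in pooled_outputs]
```

The bias is optional here: because the same scalar is added to all four endings of an example, it cancels out in the softmax over endings.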
Now you have a set of logits of size `[32]`. Reshape these back into `[8, 4]` and compute `tf.nn.log_softmax()` over the 4 endings of each example. You now have log probabilities of shape `[8, 4]` and a label tensor of shape `[8]`, so compute the loss exactly as you would for a classification problem: the negative log probability of the gold ending, averaged over the batch.
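The final step, again as a plain-Python sketch with hypothetical helper names, assuming the logits arrive in the example-major order produced by the reshape to `[32, 128]`. `log_softmax` subtracts the row maximum for numerical stability, and the loss is the mean negative log probability of the gold ending, which matches what the one-hot classification loss in `run_classifier.py` computes.

```python
import math


def unflatten(values, num_choices=4):
    """[batch_size * num_choices] -> [batch_size, num_choices]."""
    return [values[i:i + num_choices] for i in range(0, len(values), num_choices)]


def log_softmax(row):
    """Numerically stable log-softmax over one example's endings."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def swag_loss(flat_logits, labels, num_choices=4):
    """Mean negative log probability of the gold ending, plus argmax predictions."""
    log_probs = [log_softmax(row) for row in unflatten(flat_logits, num_choices)]
    per_example = [-lp[label] for lp, label in zip(log_probs, labels)]
    loss = sum(per_example) / len(per_example)
    predictions = [max(range(num_choices), key=lambda j: lp[j]) for lp in log_probs]
    return loss, predictions
```

With uniform logits every ending gets probability 1/4, so the loss is `log(4)`; at evaluation time the predicted ending is simply the argmax over each row.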