Hey there, I have a question regarding the unsupervised fine-tuning of the wav2vec 2.0 models. As expected, the results that the English-pretrained model achieves on other languages are not that groundbreaking out of the box, at least for the small model pretrained on Libri.
In the readme, you provide an example of how to either train a new model completely from scratch or fine-tune a pretrained model with CTC on labeled data. While the latter works really well and achieves satisfying results on English datasets such as Common Voice or Voxforge with a fraction of the labeled data that would normally be required, the results I got on a different language (Polish), with completely different phonetics, are not that good. So naturally, the first thing I want to try is to adapt the domain of the unsupervised model so it "gets used to" the sound of Polish speech. While I could try to train such a model from scratch, in the paper you mention that it took 1.6 days on 64 V100 GPUs, so I imagine that in order to get satisfying quality I would need to train for at least a week on the 4 RTX 2080 Ti cards I have available, and that's something I cannot really afford at the moment. That's why I wanted to try to fine-tune the existing model on the target domain, hoping that this way I could improve results on Polish data with a fraction of the training time.
Soooo my questions are:
I have already launched the procedure by renaming wav2vec_small.pt to checkpoint_last.pt and starting from that directory as the --save-dir. However, I had to pass the --reset-optimizer flag, because apparently the criterions did not match (the code you have in the readme uses --criterion wav2vec, but the loaded checkpoint had BinaryCrossEntropyCriterion for some reason).
i think it's worth a try but this is an open research question. you can do this with --restore-file pointing to the pretrained model. don't forget the various --reset-* flags (not just the optimizer, you want to reset the lr scheduler and so on as well).
you probably want to define some training budget, then set --max-update to the number of updates you will use. the learning rate will decay from the initial value to 0 over the course of training (unless you change the lr scheduler). you probably want to use a lower starting lr, but i can't say what a good value is - you have to experiment (rough sketch below)
make sure your data is 16khz single channel
some of those models were trained on an old branch of the code, which is why the criterion names are different
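to make that concrete, something along these lines (untested sketch - the lr / update counts and paths are placeholders you have to tune, and the remaining model flags should match the readme pretraining example):

```bash
# continue pretraining from a released checkpoint instead of from scratch.
# sketch only: paths, lr and update counts are placeholders to experiment with.
python train.py /path/to/manifest \
    --task audio_pretraining --criterion wav2vec --arch wav2vec2 \
    --restore-file /path/to/wav2vec_small.pt \
    --reset-optimizer --reset-lr-scheduler --reset-dataloader --reset-meters \
    --optimizer adam --lr 0.0001 --lr-scheduler polynomial_decay \
    --warmup-updates 5000 --total-num-update 50000 --max-update 50000
    # plus the remaining model/masking flags from the readme pretraining example
```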
Thanks @alexeib, I will apply your suggestions and let you know how the training went :D BTW, obviously for any kind of fine-tuning of the model you provide, I need to keep the data in 16kHz single-channel format. However, if I wanted to train a new model for 8kHz data, better suited to telephone speech, I guess I should adjust the parameters of the convolutional feature extractor network so that parameters such as the effective window length and resolution roughly match those of the 16kHz model?
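For 8kHz, I imagine something like this (a pure sketch on my side, not verified - the idea is just to keep the same ~20 ms frame rate):

```bash
# default 16 kHz extractor: strides 5*2*2*2*2*2*2 = 320 samples per frame,
# i.e. 320 / 16000 = 20 ms. for 8 kHz data, dropping one stride-2 layer
# gives a total stride of 160 samples = 160 / 8000 = 20 ms again
# (my assumption, not verified)
CONV_LAYERS_8K='[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)]'
# pass it to train.py via: --conv-feature-layers "$CONV_LAYERS_8K"
```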
Also, BTW_2: I would be happy if you could tell me what the valid_raw_wer metric is and how it differs from valid_wer.
valid_wer is the wer decoded with an lm, if you provide one via --wer-args. otherwise it is the same as valid_raw_wer, which is the so-called "viterbi" wer (but really it's just argmax, since there is no decoding)
Sorry if my question is out of place, but shouldn't wer be less than or equal to 100? How are values above 100 possible?
wer is edit distance normalized by target length. if your target length is 1 but you predicted 1000 incorrect words, then your wer will be 1000
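spelled out with those toy numbers:

```bash
# wer = edit distance normalized by target length - nothing caps it at 100%
edit_dist=1000   # predicted 1000 words, all wrong
target_len=1     # reference is a single word
echo $(( 100 * edit_dist / target_len ))   # 100000 (percent) - far above 100
```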
You guys actually have a --finetune-from-model /path/to/model.pt flag that resets all the relevant stuff when specified!
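So presumably the whole renaming dance can be replaced with just this (sketch, untested):

```bash
# one flag instead of --restore-file + the individual --reset-* flags
python train.py /path/to/manifest \
    --finetune-from-model /path/to/wav2vec_small.pt \
    --task audio_pretraining --criterion wav2vec --arch wav2vec2
    # plus the remaining pretraining flags as before
```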
Sorry for spamming you here, but you may be a better source of information than going through the code and configs looking for a bug. I ran the fine-tuning with roughly the parameters specified in the readme, with only a few changes:
```bash
#!/bin/bash
data_dir=$1
base_model=$2
python train.py --distributed-world-size 2 "${data_dir}" \
    --save-dir "${base_model}" --fp16 --num-workers 6 --task audio_pretraining --criterion wav2vec --arch wav2vec2 \
    --log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --extractor-mode default \
    --conv-feature-layers '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2' --final-dim 256 --latent-vars 320 \
    --latent-groups 2 --latent-temp '(2,0.5,0.999995)' --infonce --optimizer adam \
    --adam-betas '(0.9,0.98)' --adam-eps 1e-06 --lr-scheduler polynomial_decay --total-num-update 100000 \
    --lr 0.0005 --warmup-updates 8000 --mask-length 10 --mask-prob 0.65 --mask-selection static --mask-other 0 \
    --encoder-layerdrop 0.05 --dropout-input 0.1 --dropout-features 0.1 --feature-grad-mult 0.1 \
    --loss-weights '[0.1, 10]' --conv-pos 128 --conv-pos-groups 16 --num-negatives 100 --cross-sample-negatives 0 \
    --max-sample-size 250000 --min-sample-size 32000 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
    --max-tokens 640000 --max-update 100000 --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d --update-freq 32 --reset-optimizer
```
Let's call it the iteration_1 model. I let it run for around 50k updates and then terminated it to run the same training with --finetune-from-model /path/to/model.pt.
The iteration_1 model reached the following metrics (not sure how good/bad that is); the accuracy improved from an initial 42% to around 52% on the validation set, so I suppose it learned something:
```
2020-11-17 16:58:58 | INFO | valid | epoch 641 | valid on 'valid' subset | loss 2.711 | ntokens 1484.41 | nsentences 12.2085 | prob_perplexity 351.996 | code_perplexity 302.738 | temp 1.475 | loss_0 1.31797 | loss_1 0.0323925 | loss_2 0.00514853 | accuracy 0.52731 | wps 18284.5 | wpb 1484.4 | bsz 12.2 | num_updates 60835 | best_loss 2.711
2020-11-17 16:58:58 | INFO | fairseq_cli.train | begin save checkpoint
2020-11-17 16:59:07 | INFO | fairseq.checkpoint_utils | saved checkpoint models/201112_base_960_finetune_unsupervised_polish_16k/checkpoint641.pt (epoch 641 @ 60835 updates, score 2.711) (writing took 9.34416126087308 seconds)
```

I then tried to run the CTC fine-tuning with the `iteration_1` model. There was, however, a problem: the checkpoint from training did not contain `args`, and loading broke (the second time I ran the same train script with the same model it worked for some reason; perhaps some defaults were set and saved?). I ran the CTC training, and it looks like the `iteration_1` model works better in terms of the raw argmax/viterbi output than the raw English model without any additional unsupervised fine-tuning, which is promising. However, for some reason, the model stopped working with the ARPA language model. Looking at the transcriptions, it seems that when the LM is used, it inserts a lot of random characters, as seen below:

```
target:['dwieście', 'czterdzieści', 'pięć']
pred_raw:['dwieście', 'czterdzieści', 'pięć']
pred_decode:['dwieście', 'czterdzieści', 'i', 'pięć', 'w', 'w', 'w', 'w', 'w', 'w']
```
Have you ever encountered something like that? I didn't have any problems with LM decoding when I CTC-finetuned the wav2vec_small model you provide.
are you using your own language model on your target language?
you might have to retune the language model weights now that you've updated the model (i.e. --lm-weight and --word-score). the 2nd one in particular is the word insertion penalty, which may help you get better results
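e.g. a crude sweep along these lines (paths are placeholders, flags as in the readme kenlm eval example - adjust if yours differ):

```bash
# grid over lm weight and word insertion penalty; keep the pair with lowest wer
for lm_weight in 0.5 1.0 1.5 2.0 3.0; do
    for word_score in -3 -1 0 1; do
        python examples/speech_recognition/infer.py /path/to/manifest \
            --task audio_pretraining --path /path/to/ctc_model.pt \
            --gen-subset valid --w2l-decoder kenlm \
            --lm-model /path/to/polish.arpa --lexicon /path/to/lexicon.txt \
            --lm-weight "$lm_weight" --word-score "$word_score" \
            --criterion ctc --labels ltr --post-process letter \
            --results-path "results/lmw${lm_weight}_ws${word_score}"
    done
done
```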
Yup, I am using my own ARPA model for Polish. I will try to debug the issue; for starters, tweaking the decoding hyperparameters does not seem to improve the situation. Maybe I will try checking out master and running the unsupervised fine-tuning from there; maybe something got fixed in the meantime.
I can say, however, that the unsupervised fine-tuning on domain data seems to work quite well: using just 20% of the available data (around 100 hours) gives better CTC output than using 100% of the data with the wav2vec_small model that was only pretrained on Libri. I just need to fix the issue with LMs.
@alexeib, sorry for relying on you so much :< If I may, I would like to know what, semantically, the difference is between the <s> token and the | token in the emissions. When dumping raw emissions from the model, it seems that the old model, i.e. the CTC-finetuned wav2vec_small from the readme, is more likely to output sequences with <s> between the relevant symbols, with something like two | symbols at the end of the utterance, while the iteration_1 CTC-finetuned model outputs | much more often.
So the emissions for the old model look like this:
```
dict apply ['<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', 's', 'z', 'z', 'e', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', 'ś', 'ś', '<s>', '<s>', '<s>', '<s>', 'ć', 'ć', '<s>', '|', '|', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>']
[ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0 10  8  8  5  0  0  0  0  0  0  0  0  0  0  0 27
 27  0  0  0  0 25 25  0  4  4  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
```
and the decoded transcription matches the reference word: sześć (six).
The new model, however, for the same wave would produce the following output:
```
dict apply ['<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', 's', 'z', '<s>', 'e', '<s>', 'ś', 'ś', 'ś', 'ć', 'ć', '|', '|', '|', '|', '|', '|', '|', '|', '|', '|', '|', '|', '|', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>', '<s>']
[ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0 10  8  0  5  0 27 27 27 25 25  4  4  4  4  4  4
  4  4  4  4  4  4  4  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
```
and the decoded transcription is sześć t t t t z, which is nonsense, because a single t is not even a valid word in Polish; we only have it in the LM lexicon for spelling purposes.
<s> is interpreted by fairseq as a "beginning of sentence" token, which i have hijacked to use as a ctc blank token. so if you want to decode a ctc output, you collapse consecutive duplicates, then remove blanks.
| is a word boundary token (as defined by the lexicon/training data - nothing special about it code-wise, except during eval)
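as a toy illustration of that rule on your example (plain shell, nothing fairseq-specific):

```bash
# collapse consecutive duplicates, drop the <s> blanks; | marks word boundaries
echo '<s> <s> s z <s> e <s> ś ś ś ć ć | |' \
    | tr ' ' '\n' | uniq | grep -v '^<s>$' | tr '\n' ' '
# prints: s z e ś ć |   -> the word "sześć" plus a trailing word boundary
```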