Tensor2tensor: Looking for pre-trained transformer checkpoints for translation

Created on 8 Mar 2018 · 7 Comments · Source: tensorflow/tensor2tensor

The machine I have right now is too slow for training. Would anyone be willing to share pre-trained transformer model checkpoints for the translation task, for any language pair?

All 7 comments

English-German (en-de) checkpoints are available via Google Colab.

@martinpopel How exactly do I get these checkpoints from Google Colab?
Running files.download('/content/t2t/checkpoints/transformer_ende_test/model.ckpt-350855') throws MessageError: Error: Failed to download:

You could try to "upload" the model to your Google Drive; see the manual here :)

@stefan-it

I've tried, but it can't find the files on the notebook server for some reason. This is the code:

import os
import tensorflow as tf
from tensor2tensor import problems
!pip install -U -q PyDrive

from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

ckpt_name = "transformer_ende_test"
gs_ckpt_dir = "gs://tensor2tensor-checkpoints/"
gs_ckpt = os.path.join(gs_ckpt_dir, ckpt_name)

# Copy the pretrained checkpoint locally
checkpoint_dir = os.path.expanduser("~/t2t/checkpoints")
tf.gfile.MakeDirs(checkpoint_dir)
!gsutil -q cp -R {gs_ckpt} {checkpoint_dir}

ckpt_path = tf.train.latest_checkpoint(os.path.join(checkpoint_dir, ckpt_name))
ckpt_path

# 1. Authenticate and create the PyDrive client.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

# PyDrive reference:
# https://googledrive.github.io/PyDrive/docs/build/html/index.html

# 2. Create & upload a text file.
uploaded = drive.CreateFile({'title': 'checkpoint.txt'})
uploaded.SetContentFile(ckpt_path)
uploaded.Upload()
print('Uploaded file with ID {}'.format(uploaded.get('id')))

# 3. Load a file by ID and print its contents.
downloaded = drive.CreateFile({'id': uploaded.get('id')})
print('Downloaded content "{}"'.format(downloaded.GetContentString()))

uploaded.SetContentFile(ckpt_path) throws:

IOError: [Errno 2] No such file or directory: u'/content/t2t/checkpoints/transformer_ende_test/averaged.ckpt-0'

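The filename is not correct: tf.train.latest_checkpoint returns a checkpoint prefix (averaged.ckpt-0), not the name of a file on disk. Just check: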

! ls -l /content/t2t/checkpoints/transformer_ende_test

The necessary files are:

  • /content/t2t/checkpoints/transformer_ende_test/averaged.ckpt-0.data-00000-of-00001
  • /content/t2t/checkpoints/transformer_ende_test/averaged.ckpt-0.index
  • /content/t2t/checkpoints/transformer_ende_test/averaged.ckpt-0.meta
  • /content/t2t/checkpoints/transformer_ende_test/checkpoint

Once you copy these files into your Google Drive, you should be able to start decoding on your own machine :)
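To copy all of these files in one go, something like the following might work. It is only a sketch: the helper names checkpoint_files and upload_checkpoint are made up for illustration, it assumes Python 3, and it assumes drive is the authenticated PyDrive GoogleDrive client created in the code earlier in the thread.

```python
import glob
import os


def checkpoint_files(ckpt_prefix):
    """Return the on-disk files that belong to the checkpoint prefix."""
    # A prefix such as ".../averaged.ckpt-0" is not a file itself; the data
    # lives in sibling files like "averaged.ckpt-0.index".
    files = sorted(glob.glob(glob.escape(ckpt_prefix) + ".*"))
    # The small "checkpoint" text file sits in the same directory.
    state_file = os.path.join(os.path.dirname(ckpt_prefix), "checkpoint")
    if os.path.isfile(state_file):
        files.append(state_file)
    return files


def upload_checkpoint(drive, ckpt_prefix):
    """Upload every file of the checkpoint using a PyDrive client."""
    for path in checkpoint_files(ckpt_prefix):
        uploaded = drive.CreateFile({"title": os.path.basename(path)})
        uploaded.SetContentFile(path)
        uploaded.Upload()
        print("Uploaded {} with ID {}".format(path, uploaded.get("id")))
```

Calling upload_checkpoint(drive, ckpt_path) with the ckpt_path from the earlier code would then upload the .data, .index and .meta files plus the checkpoint file, instead of trying to open the bare prefix.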

@stefan-it It's working. Thanks!

Has anyone tried to convert this checkpoint to a .pb file? I tried:

import tensorflow as tf

meta_path = './model.ckpt-1421000.meta'  # Your .meta file

with tf.Session() as sess:
    # Restore the graph
    saver = tf.train.import_meta_graph(meta_path, clear_devices=True)

    # Load weights
    print(tf.train.latest_checkpoint('.'))
    saver.restore(sess, tf.train.latest_checkpoint('.'))

    # Output nodes (note: this lists every node in the graph)
    output_node_names = [n.name for n in tf.get_default_graph().as_graph_def().node]

    # Freeze the graph
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names)

# Save the frozen graph
with open('output_graph.pb', 'wb') as f:
    f.write(frozen_graph_def.SerializeToString())

I managed to produce a .pb model with this script, but I cannot load that model in Python.
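For reference, a frozen .pb written this way is a serialized GraphDef, so one common way to load it in TensorFlow 1.x is to parse the bytes and import them into a fresh graph. A minimal sketch, assuming the TensorFlow 1.x API used in this thread (under TensorFlow 2.x the same calls live under tf.compat.v1); the helper name load_frozen_graph is made up for illustration, and whether this works for this particular model is not confirmed anywhere in the thread.

```python
def load_frozen_graph(pb_path):
    """Parse a frozen GraphDef file and import it into a new tf.Graph."""
    # Imported lazily so the helper can be defined without loading TensorFlow.
    import tensorflow as tf  # assumes TensorFlow 1.x

    # Read the serialized GraphDef protobuf from disk.
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())

    # Import into a fresh graph; name="" keeps the original node names
    # instead of prefixing them with "import/".
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name="")
    return graph
```

The returned graph can be passed to tf.Session(graph=graph); input and output tensors then have to be looked up by name with graph.get_tensor_by_name, and those names depend on the exported model.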
