Ml-agents: Using a Keras/Tensorflow Model as Brain

Created on 29 Mar 2018  ·  7 Comments  ·  Source: Unity-Technologies/ml-agents

Does anyone know whether a model previously trained in Keras/TensorFlow can be exported into the required .bytes format?

I.e.: train the model in Keras using the TF backend -> call TF to export -> place the exported file into models/ in the ml-agents folder.

help-wanted

All 7 comments

I believe I'll need to load the saved model and weights in python, name the nodes in the graph, then call tensorflow's freeze_graph function. This seems to be outlined in: https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Using-TensorFlow-Sharp-in-Unity.md

Edit: You can get the TF graph from keras by calling keras.backend.get_session().graph
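Before calling freeze_graph it can help to check which node names actually ended up in the graph. Below is a minimal sketch, assuming Keras on a TensorFlow 1.x backend with the model already loaded in the current session. The helper names `list_nodes` and `keras_graph_def` are made up for illustration.

```python
def list_nodes(graph_def, ops=None):
    """Return (name, op) pairs for the nodes of a GraphDef-like object.

    If `ops` is given, only nodes whose op type is in `ops` are returned.
    """
    return [(n.name, n.op) for n in graph_def.node
            if ops is None or n.op in ops]


def keras_graph_def():
    """Fetch the GraphDef behind the current Keras session.

    Assumption: Keras with a TensorFlow 1.x backend. The import is kept
    local so list_nodes() can be used without Keras installed.
    """
    from keras import backend as K
    return K.get_session().graph.as_graph_def()
```

For example, `list_nodes(keras_graph_def(), ops={"Placeholder"})` should list the input placeholders, which is one way to confirm the name of the observation input.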

Hi @ohernpaul,

Were you able to get it working? You are exactly right that you can use the freeze_graph function to convert a graph to a .bytes file. By default the TensorFlow function produces a protobuf file, which we rename to .bytes so that the Unity asset importer can understand it.
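The rename step can be sketched as a small copy helper. This assumes the .pb file is already a frozen graph, i.e. the output of freeze_graph with the weights baked in. The function name `export_as_bytes` and the paths in the usage note are hypothetical.

```python
import shutil
from pathlib import Path


def export_as_bytes(pb_path, models_dir):
    """Copy a frozen .pb graph into models_dir with a .bytes extension.

    The file content is unchanged; only the extension differs.
    """
    pb_path = Path(pb_path)
    models_dir = Path(models_dir)
    models_dir.mkdir(parents=True, exist_ok=True)
    target = models_dir / (pb_path.stem + ".bytes")
    shutil.copyfile(str(pb_path), str(target))
    return target
```

A call such as `export_as_bytes("frozen_graph.pb", "models")` would leave the original file in place and write `models/frozen_graph.bytes` next to any other models.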

@awjuliani I'm currently working on it, will post soon.

I had to modify my model to use 'observation' and 'action' as layer names, and to add a reshaping layer, because I had been reshaping the data outside the graph.
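A rough sketch of what such a model might look like, assuming Keras 2.x with the default channels-last image format. The layer sizes, the convolution, and the `build_model` helper are invented for illustration and are not the model from this thread.

```python
OBSERVATION_NAME = "observation"  # input layer name mentioned in the thread
ACTION_NAME = "action"            # output layer name mentioned in the thread


def build_model(side=8, num_actions=4):
    """Build a small Keras model that reshapes its input inside the graph.

    `side` and `num_actions` are made-up sizes for illustration.
    """
    # Imported locally so this file can be imported without Keras installed.
    from keras.layers import Conv2D, Dense, Flatten, Input, Reshape
    from keras.models import Model

    # Flat observation vector in, named so it can be found in the graph.
    obs = Input(shape=(side * side,), name=OBSERVATION_NAME)
    # Reshape inside the model instead of in external preprocessing code.
    x = Reshape((side, side, 1))(obs)
    x = Conv2D(8, (3, 3), activation="relu")(x)
    x = Flatten()(x)
    action = Dense(num_actions, activation="softmax", name=ACTION_NAME)(x)
    return Model(inputs=obs, outputs=action)
```

Note that a Keras layer name is not always identical to the graph node name: the output op of a layer named action may end up as something like action/Softmax, so it is worth checking the node names before passing them to freeze_graph.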

Great to hear that it worked! I will close the issue for now. If you have additional questions, feel free to either re-open this issue or make another one.

@awjuliani Will this line work for serializing the model?

tf.train.Saver().save(sess, 'simple.ckpt')
tf.train.write_graph(graph_def=sess.graph.as_graph_def(), logdir='.', name='simple_as_binary.pb', as_text=False)

This produces a protobuf file, which you said you can rename to .bytes

Rather than using

from tensorflow.python.tools import freeze_graph

freeze_graph.freeze_graph(input_graph=model_path + '/raw_graph_def.pb',
                          input_binary=True,
                          input_checkpoint=last_checkpoint,
                          output_node_names="action",
                          output_graph=model_path + '/your_name_graph.bytes',
                          clear_devices=True, initializer_nodes="", input_saver="",
                          restore_op_name="save/restore_all", filename_tensor_name="save/Const:0")

From: https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Using-TensorFlow-Sharp-in-Unity.md

I found that it does work. Here is my code:

import os

import tensorflow as tf
from keras.models import model_from_json
from keras.optimizers import SGD

studypath = 'some\\path'

with tf.Session() as sess:
    h5 = os.path.join(studypath, 'somefile.h5')
    json_path = h5.replace(".h5", "_.json")

    # Rebuild the architecture from its JSON description.
    sgd = SGD(lr=0.1, decay=.1, momentum=.9, nesterov=False)
    with open(json_path) as f:
        model = model_from_json(f.read())
    model.compile(loss='categorical_crossentropy',
                  optimizer=sgd,
                  metrics=['accuracy'])

    # Initialize variables first, then load the weights, so the
    # initializer does not overwrite the trained values.
    sess.run(tf.global_variables_initializer())
    model.load_weights(h5)

    # Save a checkpoint (variable values) and the binary GraphDef (structure).
    tf.train.Saver().save(sess, os.path.join(studypath, 'simple.ckpt'))
    tf.train.write_graph(sess.graph.as_graph_def(), logdir=studypath,
                         name='simple_as_binary.pb', as_text=False)

This thread has been automatically locked since there has not been any recent activity after it was closed. Please open a new issue for related bugs.
