Keras: Is it possible to load tensorflow .pb file into Keras as weight for model?

Created on 1 May 2017  ·  19 Comments  ·  Source: keras-team/keras

I have followed the TensorFlow retraining example for my specific classification task, and I have Grad-CAM visualization code written in Keras.

Usually I load pre-trained weights such as VGG16 or Inception-v3 in .h5 format, and they work very well with my Grad-CAM code. The problem is the retrained_graph.pb produced by the TensorFlow retraining process, and I have no idea if there is any workaround such as:

  • mapping the .pb file to .h5?

  • or does Keras have any interface to load a .pb file in the same manner as loading an .h5 file?

Note: I use TensorFlow as the backend.

Please advise.

stale

All 19 comments

+1
I am also having a similar problem. Any suggestions are appreciated.
Thanks.

This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs, but feel free to re-open a closed issue if needed.

I'm also interested in this +1

@sakares have you been able to solve this issue? I'm trying to do exactly the same thing (Grad-CAM in Keras from a TensorFlow-trained model saved as a .pb file).

@cmosquer No luck yet. We probably need to write our own TF ".pb" to Keras ".h5" conversion.

You need to know the graph definition of your .pb file and copy the weights into each Keras layer.

import tensorflow as tf
from tensorflow.python.platform import gfile
from tensorflow.python.framework import tensor_util
from keras.applications.resnet50 import ResNet50
from keras.layers import Dense, Convolution2D, BatchNormalization

GRAPH_PB_PATH = 'xxx.pb'  # path to your .pb file
with tf.Session() as sess:
    print("load graph")
    with gfile.FastGFile(GRAPH_PB_PATH, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')
        graph_nodes = [n for n in graph_def.node]

# In a frozen graph the weights are stored as Const nodes.
wts = [n for n in graph_nodes if n.op == 'Const']
weight_dict = {n.name: i for i, n in enumerate(wts)}

def get_weight(name):
    # Extract the Const node's tensor proto as a NumPy array.
    wtensor = wts[weight_dict[name]].attr['value'].tensor
    return tensor_util.MakeNdarray(wtensor)

model = ResNet50(input_shape=(224, 224, 3), include_top=True)
model.summary()

for layer in model.layers:
    if not layer.get_weights():
        continue  # layer has no weights (e.g. pooling, activation)
    name = layer.name
    if isinstance(layer, (Convolution2D, Dense)):
        names = [name + '/kernel', name + '/bias']
    elif isinstance(layer, BatchNormalization):
        # Keras expects BatchNormalization weights in this order.
        names = [name + '/gamma', name + '/beta',
                 name + '/moving_mean', name + '/moving_variance']
    else:
        print(name)
        continue
    if any(n not in weight_dict for n in names):
        print('missing weights for', name)
        continue
    layer.set_weights([get_weight(n) for n in names])

model.save('tmp.h5')
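As an aside, the layer-to-node name matching that drives the loop above can be sketched without TensorFlow. The node names below are hypothetical stand-ins for what a frozen graph might contain; the real names depend on how your model was built:

```python
# Hypothetical Const-node names, as they might appear in a frozen graph.
const_names = ['conv1/kernel', 'conv1/bias',
               'bn1/gamma', 'bn1/beta', 'bn1/moving_mean', 'bn1/moving_variance']
# Map each node name to its index, mirroring weight_dict in the script above.
weight_dict = {n: i for i, n in enumerate(const_names)}

def expected_keys(layer_name, layer_kind):
    """Return the Const-node names a layer of this kind should look up."""
    if layer_kind in ('conv', 'dense'):
        return [layer_name + '/kernel', layer_name + '/bias']
    if layer_kind == 'batchnorm':
        return [layer_name + '/' + s
                for s in ('gamma', 'beta', 'moving_mean', 'moving_variance')]
    return []

print(all(k in weight_dict for k in expected_keys('conv1', 'conv')))  # True
print(all(k in weight_dict for k in expected_keys('fc1', 'dense')))   # False: fc1 missing
```

A layer is only filled when every expected node name is present; otherwise it is skipped and reported, which is how the script surfaces naming mismatches between the .pb graph and the Keras model.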

Any progress made here, @sakares?

@sakares did you solve it? Need the same here

+1 Facing the same issue

I'm facing the same issue. I saw a post recommending loading the weights of the .pb model into the .h5 model, but in my case I can't even load the model. I don't know if anyone else is in the same situation.
