Incubator-mxnet: Save / Load model (Python)

Created on 7 Jul 2017 · 2 Comments · Source: apache/incubator-mxnet

After training my feed-forward model, I want to save it to disk so I can load it
back later. The way I do it seems rather cumbersome, but I could not figure out
a better way. Any advice would be appreciated.

import mxnet as mx

# train_iter and test_iter are NDArrayIter objects, and BATCH_SIZE and
# n_classes are defined, before this code snippet

# model definition
data = mx.sym.var('data')

fc1 = mx.sym.FullyConnected(data=data, num_hidden=10)
relu1 = mx.sym.Activation(data=fc1, act_type="relu")

fc2 = mx.sym.FullyConnected(data=relu1, num_hidden=12)
relu2 = mx.sym.Activation(data=fc2, act_type="relu")

yhat = mx.sym.FullyConnected(data=relu2, num_hidden=n_classes)
prediction = mx.sym.SoftmaxOutput(data=yhat, name="softmax")

model = mx.mod.Module(symbol=prediction, context=mx.cpu())
model.fit(train_iter,
          eval_data=test_iter,
          optimizer="sgd",
          optimizer_params={'learning_rate': 0.1},
          eval_metric="acc",
          batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 300),
          num_epoch=50)

model.save_checkpoint("myfile", 50)

# Now assume this is a new session where I want to load the model saved above
sym, arg_params, aux_params = mx.model.load_checkpoint("myfile", 50)
new_model = mx.mod.Module(symbol=sym)
# The label shapes come from provide_label, not provide_data
new_model.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
new_model.set_params(arg_params, aux_params)
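The reason the epoch keeps coming back is that the checkpoint file names are built from the prefix and the epoch. The sketch below reproduces that naming in plain Python, assuming MXNet's '<prefix>-symbol.json' and '<prefix>-%04d.params' patterns; checkpoint_paths is a hypothetical helper, not an MXNet function.

```python
def checkpoint_paths(prefix, epoch):
    """Return the (symbol, params) file names for a checkpoint.

    Assumes MXNet's naming scheme: '<prefix>-symbol.json' for the network
    graph and '<prefix>-<epoch zero-padded to 4 digits>.params' for weights.
    """
    symbol_file = '%s-symbol.json' % prefix
    params_file = '%s-%04d.params' % (prefix, epoch)
    return symbol_file, params_file


print(checkpoint_paths("myfile", 50))
# ('myfile-symbol.json', 'myfile-0050.params')
```

Only the .params file carries the epoch; the symbol file is the same for every epoch, which is why load_checkpoint needs the epoch number to find the weights.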

This all works, but the epoch parameter has to be specified when loading the
files. Is there a way to save/load models independently of the epoch?
I can think of two obvious but "hacky" solutions:
a) Save with a fixed epoch number that does not reflect the actual epoch of training.
b) Add an additional model.save_params() call after save_checkpoint().
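Another option, sketched here under the assumption that checkpoints follow the '<prefix>-%04d.params' naming, is to look up the newest epoch on disk instead of remembering it. latest_epoch is a hypothetical helper, not part of MXNet.

```python
import os
import re


def latest_epoch(prefix):
    """Return the highest epoch N for which '<prefix>-NNNN.params' exists.

    Hypothetical helper (not an MXNet API). Assumes the '%s-%04d.params'
    naming used by save_checkpoint; returns None if no file matches.
    """
    directory = os.path.dirname(prefix) or '.'
    base = os.path.basename(prefix)
    # Four or more digits, since %04d only pads and never truncates
    pattern = re.compile(r'^%s-(\d{4,})\.params$' % re.escape(base))
    epochs = []
    for fname in os.listdir(directory):
        match = pattern.match(fname)
        if match:
            epochs.append(int(match.group(1)))
    return max(epochs) if epochs else None
```

With such a helper, loading could look like mx.model.load_checkpoint("myfile", latest_epoch("myfile")), so the caller no longer has to know which epoch was saved.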

Environment info

Package used: Python

MXNet version: 0.10.0

All 2 comments

Try writing your own save and load functions:

import mxnet as mx

def save(mod, prefix):
    # Save the network graph and the trained weights under epoch-free names
    mod._symbol.save('%s-symbol.json' % prefix)
    mod.save_params('%s.params' % prefix)

def load(prefix):
    # Read the graph and the flat {'arg:...'/'aux:...': NDArray} dict back in
    symbol = mx.sym.load('%s-symbol.json' % prefix)
    save_dict = mx.nd.load('%s.params' % prefix)
    arg_params = {}
    aux_params = {}
    for k, v in save_dict.items():
        # Keys look like 'arg:<name>' (weights) or 'aux:<name>' (auxiliary states)
        tp, name = k.split(':', 1)
        if tp == 'arg':
            arg_params[name] = v
        elif tp == 'aux':
            aux_params[name] = v
    return (symbol, arg_params, aux_params)
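The key split in load() relies on Module.save_params() prefixing learned weights with 'arg:' and auxiliary states (for example BatchNorm running statistics) with 'aux:'. The plain-Python sketch below isolates that step so it can be tried without MXNet; split_params is a hypothetical name, and the parameter names and string values in the demo are made-up stand-ins for the NDArrays that mx.nd.load() would return.

```python
def split_params(save_dict):
    """Split a flat {'arg:name': value, 'aux:name': value} mapping in two."""
    arg_params, aux_params = {}, {}
    for key, value in save_dict.items():
        # Split only on the first ':' so names containing ':' stay intact
        kind, name = key.split(':', 1)
        if kind == 'arg':
            arg_params[name] = value
        elif kind == 'aux':
            aux_params[name] = value
    return arg_params, aux_params


# Strings stand in for NDArrays; the names are illustrative only
demo = {
    'arg:fullyconnected0_weight': 'W0',
    'arg:fullyconnected0_bias': 'b0',
    'aux:batchnorm0_moving_mean': 'mu',
}
args, auxs = split_params(demo)
print(sorted(args))  # ['fullyconnected0_bias', 'fullyconnected0_weight']
print(auxs)          # {'batchnorm0_moving_mean': 'mu'}
```

The two returned dicts have the same shape as the arg_params and aux_params that set_params() expects.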

This issue is closed due to lack of activity in the last 90 days. Feel free to ping me to reopen if this is still an active issue. Thanks!
