After training my feed-forward model, I want to save it to disk so I can load it
again later. The way I do it seems rather cumbersome, but I could
not figure out a better way. Any advice would be appreciated.
# train_iter and test_iter are NDArrayIter and taken care
# of before this code snippet
# model definition
data = mx.sym.var('data')
fc1 = mx.sym.FullyConnected(data=data, num_hidden=10)
relu1 = mx.sym.Activation(data=fc1, act_type="relu")
fc2 = mx.sym.FullyConnected(data=relu1, num_hidden=12)
relu2 = mx.sym.Activation(data=fc2, act_type="relu")
yhat = mx.sym.FullyConnected(data=relu2, num_hidden=n_classes)
prediction = mx.sym.SoftmaxOutput(data=yhat, name="softmax")
model = mx.mod.Module(symbol=prediction, context=mx.cpu())
model.fit(train_iter,
          eval_data=test_iter,
          optimizer="sgd",
          optimizer_params={'learning_rate': 0.1},
          eval_metric="acc",
          batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 300),
          num_epoch=50)
model.save_checkpoint("myfile", 50)
# Now assume this is a new session where I want to load the model from before
sym, arg_params, aux_params = mx.model.load_checkpoint("myfile", 50)
new_model = mx.mod.Module(symbol=sym)
new_model.bind(train_iter.provide_data, train_iter.provide_label)
new_model.set_params(arg_params, aux_params)
This all works, but the epoch parameter has to be specified when loading the
files. Is there a way to save and load models that is independent of the epoch?
I can think of two obvious but "hacky" solutions:
a) Save with a fixed epoch number that does not reflect the actual training epoch.
b) Add an additional model.save_params() after the save_checkpoint()
Package used: Python
MXNet version: 0.10.0
Try writing your own save and load functions:
import mxnet as mx

def save(mod, prefix):
    # Write the network definition and the weights without an epoch suffix.
    mod._symbol.save('%s-symbol.json' % prefix)
    mod.save_params('%s.params' % prefix)

def load(prefix):
    symbol = mx.sym.load('%s-symbol.json' % prefix)
    save_dict = mx.nd.load('%s.params' % prefix)
    arg_params = {}
    aux_params = {}
    # Keys are stored as 'arg:<name>' or 'aux:<name>'; split them back apart.
    for k, v in save_dict.items():
        tp, name = k.split(':', 1)
        if tp == 'arg':
            arg_params[name] = v
        elif tp == 'aux':
            aux_params[name] = v
    return (symbol, arg_params, aux_params)
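If you would rather keep save_checkpoint() as it is, another option is to look up the epoch on disk instead of hard-coding it. The sketch below is only an illustration, not part of MXNet: latest_epoch is a hypothetical helper name, and it assumes the checkpoint files follow the '<prefix>-NNNN.params' naming that save_checkpoint() produces (for example myfile-0050.params). It uses only the standard library.

```python
import glob
import os
import re


def latest_epoch(prefix):
    """Return the highest epoch that has a '<prefix>-NNNN.params' file, or None.

    Hypothetical helper: assumes the zero-padded naming used by save_checkpoint().
    """
    name = os.path.basename(prefix)
    # Match only '<name>-<digits>.params', so '<name>-symbol.json' and
    # files belonging to other prefixes are ignored.
    pattern = re.compile(re.escape(name) + r"-(\d{4,})\.params$")
    epochs = []
    for path in glob.glob(glob.escape(prefix) + "-*.params"):
        match = pattern.match(os.path.basename(path))
        if match:
            epochs.append(int(match.group(1)))
    return max(epochs) if epochs else None


# Usage, assuming MXNet is installed and a checkpoint was saved under "myfile":
# epoch = latest_epoch("myfile")
# sym, arg_params, aux_params = mx.model.load_checkpoint("myfile", epoch)
```

This keeps the real epoch in the file name, so nothing about the saved files is misleading, and the loading code no longer needs to know the number in advance.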
This issue is closed due to lack of activity in the last 90 days. Feel free to ping me to reopen if this is still an active issue. Thanks!