Onnxruntime: ORT Training: Training mnist using provided sample?

Created on 26 Apr 2020 · 4 Comments · Source: microsoft/onnxruntime

Is your feature request related to a problem? Please describe.
I see that the recently added ORT training feature comes with some samples that use it to train, e.g., an MNIST network:

https://github.com/microsoft/onnxruntime/blob/78fde2c4cbeaaaf46fa97e1ffaea0fa01cbbcb33/orttraining/orttraining/models/mnist/main.cc#L96

These take as input an ONNX model "to be trained". Is there additional documentation on the expected format of that input model? I'm assuming it needs to be augmented with the gradient ops somehow. Is there a canonical way to export a training model from PyTorch in ONNX format, or is there a provided tool to convert a standard inference ONNX model into a training one suitable for use with this runner?

Describe the solution you'd like
Documentation on the type of input model that the ort trainer examples expect.

training-core


All 4 comments

The recommended way to use the ORT training feature is through the ORT Python front end. You can use ORTTrainer to train either a PyTorch model or an ONNX model.

There is an MNIST example that trains a PyTorch model:
https://github.com/microsoft/onnxruntime/blob/78fde2c4cbeaaaf46fa97e1ffaea0fa01cbbcb33/onnxruntime/test/python/onnxruntime_test_ort_trainer.py#L250
Model conversion and augmentation with gradient ops are handled by ORTTrainer and the ORT backend.

If you already have an ONNX model, there is an example for that too (it also uses ORTTrainer). However, in this case the ONNX model needs to output the loss as well. Please note that this approach is only used for testing purposes for now; we will see whether there is a strong need to fully support this use case:
https://github.com/microsoft/onnxruntime/blob/78fde2c4cbeaaaf46fa97e1ffaea0fa01cbbcb33/onnxruntime/test/python/onnxruntime_test_ort_trainer.py#L342
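To make "the ONNX model needs to output the loss as well" concrete: gradients are taken with respect to a single scalar, so the forward graph has to end in a loss value that backprop can start from. The pure-Python sketch below is an illustration of the math only, not ORT code; the one-Gemm model, the softmax cross-entropy loss, and all names and values are hypothetical. Its forward pass returns both logits and a scalar loss, and the backward function computes the gradients a generated backprop graph would produce from that loss, checked against a finite difference.

```python
import math

# Hypothetical toy model (not the ORT API): logits = x @ W + b, i.e. a single
# Gemm like a minimal MNIST classifier, followed by softmax cross-entropy.

def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def forward(x, W, b, label):
    """Return (logits, loss): the model outputs the scalar loss as well."""
    logits = [sum(x[i] * W[i][j] for i in range(len(x))) + b[j]
              for j in range(len(b))]
    loss = -math.log(softmax(logits)[label])
    return logits, loss

def backward(x, W, b, label):
    """Gradients of the loss w.r.t. W and b (what a backprop graph computes)."""
    logits, _ = forward(x, W, b, label)
    p = softmax(logits)
    # dL/dlogits for softmax cross-entropy is softmax(logits) - one_hot(label)
    g = [p[j] - (1.0 if j == label else 0.0) for j in range(len(b))]
    dW = [[x[i] * g[j] for j in range(len(b))] for i in range(len(x))]
    db = g[:]
    return dW, db

x = [0.5, -1.0, 2.0]
W = [[0.1, -0.2], [0.3, 0.0], [-0.1, 0.2]]
b = [0.05, -0.05]
label = 1

dW, db = backward(x, W, b, label)

# Finite-difference check on one weight to confirm the analytic gradient.
eps = 1e-6
W_plus = [row[:] for row in W]
W_plus[2][1] += eps
numeric = (forward(x, W_plus, b, label)[1] - forward(x, W, b, label)[1]) / eps
print(abs(numeric - dW[2][1]) < 1e-4)  # prints True
```

Without the loss output there is no scalar to differentiate, which is why both the Python path and the C++ MNIST sample expect the graph to produce one.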

In both cases, we'd like to make the training script feel as natural to PyTorch training as possible. Please give it a try and let us know.
Thanks

Thank you for the response!

recommended way to use ort training feature is with ORT python front end

Our use case (training on embedded/edge devices) prevents us from using the Python API; instead we're limited to native binaries and therefore need to call into the C/C++ APIs directly. It seems like this should be possible given the existence of the training example in

https://github.com/microsoft/onnxruntime/blob/78fde2c4cbeaaaf46fa97e1ffaea0fa01cbbcb33/orttraining/orttraining/models/mnist/main.cc#L40

Does this example require "the ONNX model... to output the loss as well", as you mentioned? If so, how would one go about adding the necessary gradient ops to construct such a model? I noticed there's also:

https://github.com/microsoft/onnxruntime/blob/78fde2c4cbeaaaf46fa97e1ffaea0fa01cbbcb33/orttraining/orttraining/models/mnist/test_grad_graph_builder.cc#L72

Is this meant to be used with orttraining/orttraining/models/mnist/main.cc?

Additionally, I tried running the onnxruntime_training_mnist example binary (the code below) https://github.com/microsoft/onnxruntime/blob/78fde2c4cbeaaaf46fa97e1ffaea0fa01cbbcb33/orttraining/orttraining/models/mnist/main.cc#L40

on the Conv/Relu/MaxPool MNIST model produced by the model builder at orttraining/tools/mnist_model_builder/mnist_conv_builder.ipynb.

This seems to throw an exception:

terminate called after throwing an instance of 'onnxruntime::OnnxRuntimeException'
  what():  /home/onnxruntime/onnxruntime/orttraining/orttraining/core/graph/gradient_builder_base.h:63 onnxruntime::training::ArgDef onnxruntime::training::GradientBuilderBase::O(size_t) const i < node_->OutputDefs().size() was false. 

at the line
https://github.com/microsoft/onnxruntime/blob/f1a948fd62da4a66b37a5e6d1717519042c15d32/orttraining/orttraining/core/framework/gradient_graph_builder.cc#L190

which calls into

https://github.com/microsoft/onnxruntime/blob/f1a948fd62da4a66b37a5e6d1717519042c15d32/orttraining/orttraining/core/graph/gradient_builder_base.h#L39

For reference, the node type and input/output args that need grad are:

Node type MaxPool
Output args need grad:
         T3
Input args need grad:
         T2

This seems to be because the gradient builder for MaxPool is defined as:

https://github.com/microsoft/onnxruntime/blob/9a4d1c772094ba2518b39d77207a2f852a9fedcb/orttraining/orttraining/core/graph/gradient_builder.cc#L492-L498

even though our MaxPool node has only one input def and one output def. Given that AveragePoolGradient doesn't use O(1), I'm not sure whether this is an issue with my model or a possible bug. (The ONNX operator spec says that MaxPool's second output, "Indices (optional) : I", is optional.)
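The failure mode described above can be reproduced in a simplified, hypothetical model. The classes below are not ORT's; they only mirror the enforced condition "i < node_->OutputDefs().size()" from the error message, and they assume (as the linked builder suggests) that the MaxPool gradient reads GO(0) and O(1), the optional Indices output. A node that exposes only Y fails the check, while a node that also exposes Indices passes it. Whether adding Indices to the real model resolves the real ORT error is an untested assumption here.

```python
# Simplified, hypothetical model of the failing check; not ORT's actual classes.

class Node:
    def __init__(self, op_type, inputs, outputs):
        self.op_type = op_type
        self.inputs = inputs
        self.outputs = outputs

class GradientBuilder:
    def __init__(self, node):
        self.node = node

    def O(self, i):
        # Mirrors the enforced condition `i < node_->OutputDefs().size()`.
        if not i < len(self.node.outputs):
            raise IndexError(f"{self.node.op_type}: output {i} requested, "
                             f"but node has {len(self.node.outputs)} output(s)")
        return self.node.outputs[i]

    def GO(self, i):
        # Name of the incoming gradient for forward output i.
        return self.O(i) + "_grad"

    def GI(self, i):
        # Name of the gradient produced for forward input i.
        return self.node.inputs[i] + "_grad"

def max_pool_gradient(builder):
    # Assumed shape of the MaxPool builder: MaxPoolGrad consumes dY and the
    # Indices output O(1), and produces dX.
    return ("MaxPoolGrad", [builder.GO(0), builder.O(1)], [builder.GI(0)])

# As in the notebook model: MaxPool exposes only Y (T3), not Indices.
pool = Node("MaxPool", ["T2"], ["T3"])
try:
    max_pool_gradient(GradientBuilder(pool))
except IndexError as err:
    print("error:", err)

# Hypothetical variant: the node also exposes its optional second output.
pool_with_indices = Node("MaxPool", ["T2"], ["T3", "T3_indices"])
print(max_pool_gradient(GradientBuilder(pool_with_indices)))
```

AveragePool has no Indices output, so a builder that never asks for O(1) cannot hit this check, which is consistent with the observation that AveragePoolGradient works on the same kind of node.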

I attached the MNIST test data and models. Please run them with the command line below. To work with the C++ API, you need to follow the MNIST example and build the graph with a loss output so that the backprop graph can be constructed.

mnist.zip
./onnxruntime_training_mnist --model_name ~/mnist/mnist_gemm_simple --train_data_dir ~/mnist/mnist_data/

Thanks
