Is your feature request related to a problem? Please describe.
I see that the recently added ORT training support includes some sample programs that use it to train, e.g., an MNIST network:
These take as input an ONNX model "to be trained." Is there additional documentation on the expected format of the input model? I assume it needs to be augmented with gradient ops somehow. Is there a canonical way to export a training model from PyTorch in ONNX format, or is there a provided tool to convert a standard inference ONNX model into a training model suitable for use with this runner?
Describe the solution you'd like
Documentation on the type of input model that the ort trainer examples expect.
The recommended way to use the ORT training feature is through the ORT Python front end. You can use ORTTrainer to train either a PyTorch model or an ONNX model.
There is an MNIST example that trains a PyTorch model:
https://github.com/microsoft/onnxruntime/blob/78fde2c4cbeaaaf46fa97e1ffaea0fa01cbbcb33/onnxruntime/test/python/onnxruntime_test_ort_trainer.py#L250
Model conversion and augmentation with gradient ops are handled by ORTTrainer and the ORT backend.
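To make "augmentation with gradient ops" more concrete, here is a toy sketch in plain Python. It is not onnxruntime code. The forward graph is a list of nodes, each op type has a gradient builder, and the backward pass walks the nodes in reverse starting from a scalar loss. The node format, op set, and function names are hypothetical. ORT's real gradient graph builder works on ONNX graphs and emits gradient nodes rather than computing numbers directly.

```python
from collections import defaultdict

# Toy forward graph: each node is (op_type, input_names, output_name); all values are scalars.
# Illustrative only; this is not how onnxruntime represents or differentiates graphs.

def run_forward(nodes, feeds):
    """Evaluate nodes in order and return every named value."""
    values = dict(feeds)
    for op, ins, out in nodes:
        a = [values[n] for n in ins]
        if op == "Mul":
            values[out] = a[0] * a[1]
        elif op == "Add":
            values[out] = a[0] + a[1]
        elif op == "Sub":
            values[out] = a[0] - a[1]
        elif op == "Relu":
            values[out] = max(a[0], 0.0)
        else:
            raise ValueError(f"no forward kernel for {op}")
    return values

# Hypothetical per-op gradient builders: (input values, dLoss/dOutput) -> dLoss/dInputs.
GRAD_BUILDERS = {
    "Mul": lambda x, dy: [dy * x[1], dy * x[0]],
    "Add": lambda x, dy: [dy, dy],
    "Sub": lambda x, dy: [dy, -dy],
    "Relu": lambda x, dy: [dy if x[0] > 0 else 0.0],
}

def backward(nodes, values, loss_name):
    """Walk the forward nodes in reverse, seeded with dLoss/dLoss = 1."""
    grads = defaultdict(float)
    grads[loss_name] = 1.0
    for op, ins, out in reversed(nodes):
        if op not in GRAD_BUILDERS:
            raise ValueError(f"no gradient builder registered for {op}")
        x = [values[n] for n in ins]
        for name, g in zip(ins, GRAD_BUILDERS[op](x, grads[out])):
            grads[name] += g  # accumulate when a value feeds several inputs (e.g. d * d)
    return dict(grads)

# Squared-error "model with loss output": loss = (w * x + b - t) ** 2
nodes = [
    ("Mul", ["w", "x"], "wx"),
    ("Add", ["wx", "b"], "pred"),
    ("Sub", ["pred", "t"], "d"),
    ("Mul", ["d", "d"], "loss"),
]
values = run_forward(nodes, {"w": 2.0, "x": 3.0, "b": 1.0, "t": 4.0})
grads = backward(nodes, values, "loss")
print(values["loss"], grads["w"], grads["b"])  # 9.0 18.0 6.0
```

The backward pass only has a starting point because the graph ends in a named scalar loss. Without one, there is nothing to seed with 1.0.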
If you already have an ONNX model, there is an example for that too (it also uses ORTTrainer). In this case, however, the ONNX model must also output the loss. Please note that this approach is currently used only for testing; we will see whether there is a strong need to fully support this use case:
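An inference model typically ends at logits. A model "that outputs loss" has an extra loss stage appended: it takes the logits plus a label input and produces a scalar, which is where backpropagation starts. The following minimal pure-Python sketch shows such a loss stage. It assumes a softmax cross-entropy loss, as is typical for MNIST classifiers. The function names are hypothetical and are not ORT or ONNX APIs.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_cross_entropy(logits, label):
    """Scalar loss for one example plus dLoss/dlogits, which seeds the backward pass."""
    probs = softmax(logits)
    loss = -math.log(probs[label])
    dlogits = [p - (1.0 if i == label else 0.0) for i, p in enumerate(probs)]
    return loss, dlogits

def training_forward(inference_model, x, label):
    """Hypothetical wrapper: inference graph -> logits, appended loss stage -> loss."""
    logits = inference_model(x)
    loss, dlogits = softmax_cross_entropy(logits, label)
    return {"logits": logits, "loss": loss, "dlogits": dlogits}

# Usage with a stand-in "model" that returns its input as logits.
out = training_forward(lambda x: x, [0.0, 0.0], label=0)
print(out["loss"], out["dlogits"])  # loss = ln 2, dlogits = [-0.5, 0.5]
```

In an actual ONNX training graph, the equivalent would be a loss node added after the logits output, with the label declared as an extra graph input.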
https://github.com/microsoft/onnxruntime/blob/78fde2c4cbeaaaf46fa97e1ffaea0fa01cbbcb33/onnxruntime/test/python/onnxruntime_test_ort_trainer.py#L342
In both cases, we would like the training script to feel as close to natural PyTorch training as possible. Please give it a try and let us know.
Thanks
Thank you for the response!
recommended way to use ort training feature is with ORT python front end
Our use case (training on embedded/edge devices) prevents us from using the Python API; instead we're limited to native binaries and therefore need to call into the C/C++ APIs directly. It seems like this should be possible given the existence of the training example in
Does this example require "the ONNX model.. to output loss as well" as you mentioned? If so, how would one go about adding the necessary gradient ops to construct such a model? I noticed there's also:
Is this meant to be used with orttraining/orttraining/models/mnist/main.cc?
Adding on, I tried running the onnxruntime_training_mnist example binary (code linked below) https://github.com/microsoft/onnxruntime/blob/78fde2c4cbeaaaf46fa97e1ffaea0fa01cbbcb33/orttraining/orttraining/models/mnist/main.cc#L40
on the Conv/Relu/MaxPool MNIST model generated by the builder notebook at orttraining/tools/mnist_model_builder/mnist_conv_builder.ipynb
This seems to throw an exception:
terminate called after throwing an instance of 'onnxruntime::OnnxRuntimeException'
what(): /home/onnxruntime/onnxruntime/orttraining/orttraining/core/graph/gradient_builder_base.h:63 onnxruntime::training::ArgDef onnxruntime::training::GradientBuilderBase::O(size_t) const i < node_->OutputDefs().size() was false.
which calls into
For reference, the node type and input/output args that need grad are:
Node type MaxPool
Output args need grad:
T3
Input args need grad:
T2
This seems to be because the gradient builder for MaxPool is defined as:
even though our MaxPool node has only one input def and one output def. Given that AveragePoolGradient doesn't use O(1), I'm not sure whether this is an issue with my model or a bug. (The ONNX operator spec lists MaxPool's second output, "Indices (optional) : I", as optional.)
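One plausible reason a MaxPool gradient builder would reach for O(1) is that a common way to implement the backward pass routes each incoming gradient to the input element that produced the max. The optional Indices output records exactly those positions. Below is a simplified pure-Python, 1-D sketch with non-overlapping windows. It is not ORT's MaxPool or MaxPoolGrad kernel, and the function names are hypothetical. If the forward node was exported without Indices, a builder that requests output 1 has nothing to reference. That would be consistent with the failed i < node_->OutputDefs().size() check in the error.

```python
def maxpool1d(x, k):
    """Non-overlapping 1-D max pool. Returns (Y, Indices), where Indices plays the role
    of MaxPool's optional second output: the input position each max came from."""
    y, indices = [], []
    for start in range(0, len(x) - k + 1, k):
        best = max(range(start, start + k), key=lambda i: x[i])  # first max wins on ties
        y.append(x[best])
        indices.append(best)
    return y, indices

def maxpool1d_grad(dy, indices, input_len):
    """Backward pass: scatter each output gradient to the input position that won the max."""
    dx = [0.0] * input_len
    for g, i in zip(dy, indices):
        dx[i] += g
    return dx

x = [1.0, 3.0, 2.0, 0.5, 4.0, 4.0]
y, idx = maxpool1d(x, 2)
dx = maxpool1d_grad([1.0, 2.0, 3.0], idx, len(x))
print(y, idx, dx)  # [3.0, 2.0, 4.0] [1, 2, 4] [0.0, 1.0, 2.0, 0.0, 3.0, 0.0]
```

Without the indices, the backward pass would have to recompute the argmax positions from X and Y, which is why a builder might prefer the forward node's second output.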
I attached the MNIST test data and models. Please run them with the command line below. As you can see, to work with the C++ API you need to follow the MNIST example and build the graph with a loss output so that the backprop graph can be constructed.
mnist.zip
./onnxruntime_training_mnist --model_name ~/mnist/mnist_gemm_simple --train_data_dir ~/mnist/mnist_data/
Thanks