Onnxruntime: Java API OnnxTensor.createTensor() uses buffer incorrectly

Created on 28 Aug 2020 · 14 Comments · Source: microsoft/onnxruntime

Describe the bug
When constructing a tensor from a buffer (of any type), OnnxTensor.createTensor() uses data.capacity() rather than data.remaining() to determine the buffer size. This results in an exception if the received buffer covers only part of an array (e.g. one created via IntBuffer.wrap(data, offset, len)).

Urgency
I can work around it, but it'll have some perf impact.

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Java API
  • ONNX Runtime installed from (source or binary): binary 1.4.0
  • ONNX Runtime version: 1.4.0

To Reproduce

Here's some sample code that fails:

final IntBuffer src = IntBuffer.wrap(new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8}, 3, 4).asReadOnlyBuffer();
OnnxTensor.createTensor(OrtEnvironment.getEnvironment(), src, new long[]{4});

produces this exception:

ai.onnxruntime.OrtException: Shape [4], requires 4 elements but the buffer has 9 elements.
    at ai.onnxruntime.TensorInfo.constructFromBuffer(TensorInfo.java:259)
    at ai.onnxruntime.OnnxTensor.createTensor(OnnxTensor.java:731)
    at ai.onnxruntime.OnnxTensor.createTensor(OnnxTensor.java:701)

Expected behavior

The above example should create a 1-d tensor with the values [3, 4, 5, 6].

Java bug

All 14 comments

@yuslepukhin Is it correct?

Hmm, I wrote it as capacity because the offset into a Buffer wasn't a concept that I thought it should track, as it's essentially treating a java.nio.Buffer as a complete ndarray that you can't slice. I didn't realise the Buffer.wrap call set the offset and that people would expect it to be semantically meaningful. Note that the IntBuffer.wrap call uses the whole backing array as the buffer, not the slice that's selected; it just sets the buffer's position to the specified offset and its limit to offset + length. I guess at least it should state in the docs that it copies the whole buffer.

What's the use case for passing in chunks of a Buffer?

I have a large array in memory and would like to create tensors from parts of it without another array allocation (I understand that OnnxTensor will have to copy that bit of data at some point, in particular the non-data.isDirect() path).
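For readers hitting the same exception, here is a minimal sketch of two possible ways to hand over only part of a large array, using nothing but java.nio. The class and method names (SliceWorkaround, copyWindow, viewWindow) are hypothetical. Whether a given ONNX Runtime version accepts the slice() view is an assumption this thread does not confirm; the copy is the allocation-heavy fallback.

```java
import java.nio.IntBuffer;

// Hypothetical helper, not part of ONNX Runtime.
final class SliceWorkaround {
    /** Copies only the readable window (position..limit) into a fresh, exactly sized buffer. */
    static IntBuffer copyWindow(IntBuffer src) {
        int[] window = new int[src.remaining()];
        src.duplicate().get(window);   // duplicate() keeps the caller's position untouched
        return IntBuffer.wrap(window); // position = 0, limit = capacity = window.length
    }

    /** Zero-copy alternative: a view that starts at the current position. */
    static IntBuffer viewWindow(IntBuffer src) {
        return src.slice();            // position = 0, limit = capacity = src.remaining()
    }

    public static void main(String[] args) {
        int[] big = {0, 1, 2, 3, 4, 5, 6, 7, 8};
        IntBuffer part = IntBuffer.wrap(big, 3, 4);
        System.out.println(copyWindow(part).capacity()); // 4
        System.out.println(viewWindow(part).capacity()); // 4
        System.out.println(viewWindow(part).get(0));     // 3
    }
}
```

Both results report a capacity equal to the number of elements in the window, so a capacity-based size check would see 4 elements rather than 9.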

The contract of reading from a buffer is that you can read from the current position up to the limit (as reported by remaining() or hasRemaining()). It'll throw BufferUnderflowException if you try to read past its limit.
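The contract described above can be checked with a few lines of plain java.nio. This is only an illustrative sketch; the class and method names (BufferContractDemo, underflowsPastLimit) are made up for the example, and the buffer is the one from the reproduction.

```java
import java.nio.BufferUnderflowException;
import java.nio.IntBuffer;

// Hypothetical demo class, not part of ONNX Runtime.
final class BufferContractDemo {
    /** Returns true if one relative read past the limit raises BufferUnderflowException. */
    static boolean underflowsPastLimit(IntBuffer b) {
        IntBuffer view = b.duplicate();   // independent position, shared content
        while (view.hasRemaining()) {
            view.get();                   // consume position..limit
        }
        try {
            view.get();                   // one read too many
            return false;
        } catch (BufferUnderflowException e) {
            return true;
        }
    }

    public static void main(String[] args) {
        IntBuffer src = IntBuffer.wrap(new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8}, 3, 4);
        System.out.println(src.position());           // 3
        System.out.println(src.limit());              // 7
        System.out.println(src.capacity());           // 9
        System.out.println(src.remaining());          // 4
        System.out.println(underflowsPastLimit(src)); // true
    }
}
```

The capacity of 9 matches the "buffer has 9 elements" in the exception message, while remaining() reports the 4 elements the caller intended to pass.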

I think you'll find replacing capacity() with remaining() will do the right thing. For example the tmp.put(data) here will only read from the current position to the limit:

ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
tmp = buffer.asIntBuffer();
tmp.put(data);

The behavior of the isDirect() path, I guess, depends on how the buffer is used in the native code, but I don't have a direct buffer, so it isn't affecting me.
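To make the point about tmp.put(data) concrete, here is a self-contained sketch of the same copy with the destination sized from remaining(). It is modelled on the snippet quoted above, not taken from the ONNX Runtime source; DirectCopySketch and toDirect are hypothetical names.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.IntBuffer;

// Hypothetical sketch, not the ONNX Runtime implementation.
final class DirectCopySketch {
    /** Copies position..limit of data into a new direct, native-order IntBuffer. */
    static IntBuffer toDirect(IntBuffer data) {
        int bufferSize = data.remaining() * Integer.BYTES; // sized from remaining(), not capacity()
        ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
        IntBuffer tmp = buffer.asIntBuffer();
        tmp.put(data); // relative bulk put: reads data from its position up to its limit
        tmp.rewind();  // make the copy readable from index 0
        return tmp;
    }

    public static void main(String[] args) {
        IntBuffer src = IntBuffer.wrap(new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8}, 3, 4);
        IntBuffer direct = toDirect(src);
        System.out.println(direct.capacity()); // 4
        System.out.println(direct.get(0));     // 3
        System.out.println(src.position());    // 7, the bulk put consumed the source
    }
}
```

Note the last line: the relative bulk put advances the source buffer's position to its limit, which matters for the later question of whether the caller's buffer state should be preserved.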

The method should probably use buffer.limit() rather than buffer.capacity(), which is a straightforward fix, and definitely should be documented that it operates on the whole of the filled buffer. Potentially it could also have an optional parameter which controlled if it looks at the buffer position or not.

When I use Buffers for moving around tensor data I don't tend to consider the position within that buffer to be a useful piece of state, it's not got a good analogy to an ndarray (as it's too primitive to be a functional slice). Given you can pass in whole batches of data to be scored at once, rather than slicing it into separate buffers inducing a copy, it seems better to do that.

Edit: I'm happy to switch it over to use limit (as that's a bug on my part, though not the issue you're hitting), and if we add a boolean flag that lets it read the whole buffer vs from the position (i.e. switch limit vs remaining with the limit one doc'd that it will ignore the position) would that meet your need?

The flag would work for me assuming the example above produces the expected behavior. Sorry to be a pest, but why not just follow the contract of Buffer so there are no surprises?

Because the buffer semantics aren't ideal for an interchange format for numerical data, as the position is usually irrelevant (and causes issues like needing to be reset each time the buffer is converted into an OnnxTensor, which is very confusing if you're coming at it from an ML or Python perspective). Unfortunately, until MemorySegment lands it's the only option inside the JDK.

I'll work up a PR today or Monday, and then we can discuss the exact semantics in reference to some code?

I get that the abstraction isn't perfect for what you want to do, but I don't understand what you gain by ignoring the contract. If you follow the contract it'll still work transparently for everyone that's using it naively (position = 0, limit = capacity), but it'll also work transparently for those who are using more "advanced" functionality. No flags required. FWIW, the equivalent TF Java API just works.

Either way, I appreciate you looking into it and providing a workaround. Thanks!
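The argument that a position-aware check stays transparent for naive callers can be sketched as follows. This is an assumption about what such a check might look like, not the code in TensorInfo; ShapeCheckSketch, elementCount and fits are hypothetical names.

```java
import java.nio.Buffer;
import java.nio.IntBuffer;

// Hypothetical sketch of a position-aware size check, not the ONNX Runtime code.
final class ShapeCheckSketch {
    /** Number of elements implied by a shape (product of its dimensions). */
    static long elementCount(long[] shape) {
        long count = 1;
        for (long dim : shape) {
            count *= dim;
        }
        return count;
    }

    /** Compares the shape against remaining() instead of capacity(). */
    static boolean fits(Buffer data, long[] shape) {
        return elementCount(shape) == data.remaining();
    }

    public static void main(String[] args) {
        long[] shape = {4};
        IntBuffer whole = IntBuffer.wrap(new int[] {3, 4, 5, 6});                     // position 0, limit == capacity
        IntBuffer window = IntBuffer.wrap(new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8}, 3, 4);
        System.out.println(fits(whole, shape));                       // true
        System.out.println(fits(window, shape));                      // true
        System.out.println(elementCount(shape) == window.capacity()); // false, the capacity-based comparison
    }
}
```

For a buffer with position 0 and limit equal to capacity, remaining() and capacity() agree, so only callers that rely on position and limit see a difference.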

@daveray would you like to create a PR?

On the JNI side it's harder to access the position of a direct ByteBuffer (I'll have to pass in position and remaining into the JNI call to avoid calling back into the JVM) and it'll have to do the pointer arithmetic on the C side to get the right position out, so it's going to be a bigger patch than just a Java side fix.
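The arithmetic mentioned here can be previewed on the Java side. The sketch below computes the byte offset and byte length that would have to accompany a direct buffer into a native call, and checks them against the underlying ByteBuffer. NativeWindowSketch and its methods are hypothetical, the actual JNI signature is not shown in this thread, and the sketch assumes the IntBuffer view starts at byte 0 of its storage.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.IntBuffer;

// Hypothetical sketch; the real JNI interface may look different.
final class NativeWindowSketch {
    /** Byte offset of the current position from the start of the view's storage. */
    static long byteOffset(IntBuffer b) {
        return (long) b.position() * Integer.BYTES;
    }

    /** Byte length of the readable window (position..limit). */
    static long byteLength(IntBuffer b) {
        return (long) b.remaining() * Integer.BYTES;
    }

    public static void main(String[] args) {
        ByteBuffer storage = ByteBuffer.allocateDirect(9 * Integer.BYTES).order(ByteOrder.nativeOrder());
        IntBuffer ints = storage.asIntBuffer(); // view starts at byte 0 of storage
        for (int i = 0; i < 9; i++) {
            ints.put(i, i);                     // absolute put, position stays at 0
        }
        ints.position(3);
        ints.limit(7);
        System.out.println(byteOffset(ints));                     // 12
        System.out.println(byteLength(ints));                     // 16
        System.out.println(storage.getInt((int) byteOffset(ints))); // 3
    }
}
```

On the native side the equivalent step would be adding the byte offset to the base address of the direct buffer (the JNI function GetDirectBufferAddress returns such an address) before reading byteLength bytes.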

There's also another issue, which is whether the create method should leave the state of the Buffer unchanged (i.e. restore it to the incoming position). It doesn't do that at the moment, because it's not sensitive to the position. I would prefer it to not change the state of the buffer in an observable way (to ONNX Runtime at least), but I'm not sure if that would cause issues for your use case. @daveray what do you think?

Hey. Sorry about the delayed response. I think leaving the state of the buffer unchanged would be fine. Thanks!
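One way to leave the caller's buffer unchanged, as agreed above, is to remember the incoming position and restore it after the read. This is a sketch of the idea under that assumption, not the code from the eventual PR; PreserveStateSketch and readAndRestore are hypothetical names. Note that Buffer.rewind() would not do here, since it resets the position to zero rather than to the incoming value.

```java
import java.nio.IntBuffer;

// Hypothetical sketch, not the ONNX Runtime implementation.
final class PreserveStateSketch {
    /** Reads position..limit from the caller's buffer, then restores the incoming position. */
    static int[] readAndRestore(IntBuffer data) {
        int start = data.position();
        int[] out = new int[data.remaining()];
        try {
            data.get(out);        // relative bulk get advances position to limit
        } finally {
            data.position(start); // put the position back where the caller left it
        }
        return out;
    }

    public static void main(String[] args) {
        IntBuffer src = IntBuffer.wrap(new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8}, 3, 4);
        int[] values = readAndRestore(src);
        System.out.println(values.length);   // 4
        System.out.println(values[0]);       // 3
        System.out.println(src.position());  // 3
        System.out.println(src.remaining()); // 4
    }
}
```

Because the position is restored, the same buffer can be passed in repeatedly without the caller having to reset it between calls.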

On the JNI side it's harder to access the position of a direct ByteBuffer (I'll have to pass in position and remaining into the JNI call to avoid calling back into the JVM) and it'll have to do the pointer arithmetic on the C side to get the right position out, so it's going to be a bigger patch than just a Java side fix.

Ah, the joys of refusing to use a tool like JavaCPP... That's one more thing you're going to need to reinvent! Like I keep telling you, why not put all that stuff in an external library so that it could be useful to other libraries than ORT? I see you've started to rely on multiple libraries in Tribuo. (BTW, is this one an official Oracle project? Or is that also a personal hobby that just happens to be in a repository under Oracle but that isn't endorsed by Oracle in any way?) Do you realize that you'll need to keep patching all the libraries that you support there one by one to fix these kinds of issues when they pop up?

@daveray FYI, the C/C++ API mapped with JavaCPP has no such issues:
https://github.com/bytedeco/javacpp-presets/tree/master/onnxruntime

@saudet Tribuo will be officially launched at the Java developer day in a few weeks. It's been used in production inside Oracle for several years and we're excited to share it with everyone. I look forward to your contributions as we help to build the ML ecosystem on the Java platform.

I've made a PR which should fix this issue. @daveray could you check it over? I think the test case I added accurately captures your use case, but let me know if it doesn't. Sorry it took a while, I got caught up with Tribuo launch stuff.

I think we may close it if everyone agrees?
