Sum currently supports only 1 or 2 inputs. Some use cases have more than that, and we need to support those too. Example:
op {
  input: "__m21_shared"
  input: "__m3_shared"
  input: "__m11_shared"
  input: "__m9_shared"
  input: "__m14_shared"
  input: "__m10_shared"
  input: "__m22_shared"
  input: "__m20_shared"
  input: "__m24_shared"
  input: "__m25_shared"
  input: "__m26_shared"
  input: "__m27_shared"
  input: "__m28_shared"
  output: "__m29_shared"
  name: ""
  type: "Sum"
  device_option {
    device_type: 0
  }
}
What do you think about loading this as a tree of binary Sums?
Yeah. How do we normally name those internal nodes?
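For reference, here is a rough sketch of what the tree-of-binary-Sums lowering could look like. It is not loader code: it works on plain tensor names, and both the `BinarySum` struct and the `<output>__tmp<k>` naming of internal nodes are hypothetical, chosen only to make the naming question concrete.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// One binary Sum: out = lhs + rhs. (Hypothetical struct, for illustration.)
struct BinarySum {
  std::string lhs;
  std::string rhs;
  std::string out;
};

// Lower an N-input Sum (N >= 2) into a balanced tree of binary Sums.
// Intermediate results get the hypothetical names "<output>__tmp<k>";
// the last op writes to `output` itself.
std::vector<BinarySum> lowerSumToTree(std::vector<std::string> inputs,
                                      const std::string &output) {
  assert(inputs.size() >= 2 && "need at least two inputs");
  std::vector<BinarySum> ops;
  std::size_t tmpId = 0;
  while (inputs.size() > 1) {
    // When only two values remain, their sum is the root of the tree.
    const bool isRoot = inputs.size() == 2;
    std::vector<std::string> next;
    for (std::size_t i = 0; i + 1 < inputs.size(); i += 2) {
      std::string out =
          isRoot ? output : output + "__tmp" + std::to_string(tmpId++);
      ops.push_back({inputs[i], inputs[i + 1], out});
      next.push_back(out);
    }
    // An odd leftover input is carried to the next level unchanged.
    if (inputs.size() % 2 == 1) {
      next.push_back(inputs.back());
    }
    inputs = std::move(next);
  }
  return ops;
}
```

An N-input Sum always becomes N - 1 binary Sums this way, so the 13-input op above would turn into 12 nodes, 11 of which need a generated name.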
Another possibility is to represent this as concat + batched add.
Do we have batched add? And concat will create a fairly large intermediate tensor, right?
@yinghai
Yes, we do:
BB.newNode("BatchedAdd")
    .addInput("Batch")
    .addInput("Slice")
    .addResultFromCtorArg()
    .setDocstring(
        "Adds the 'Slice' operand to each one of the slices in the batch.");
Actually wait, maybe @nadavrot meant to say BatchedReduceAdd?
Concat often doesn't take any extra memory, because we just direct its inputs to be written consecutively in memory. If that's the case, concat+batched reduce add is probably the way to go. (Boo, I thought I had it right for once :-p ).
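To make the concat + batched reduce add idea concrete, here is a minimal sketch of the semantics on plain `std::vector<float>` buffers. It does not use the actual graph nodes; `concatInputs` and `batchedReduceAdd` are hypothetical helper names, and the sketch assumes all inputs have the same number of elements.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Concatenate N equally sized inputs into one [N x size] row-major buffer.
// (Hypothetical helper; a real concat may avoid the copy entirely.)
std::vector<float> concatInputs(const std::vector<std::vector<float>> &inputs) {
  assert(!inputs.empty() && "need at least one input");
  const std::size_t size = inputs[0].size();
  std::vector<float> batch;
  batch.reserve(inputs.size() * size);
  for (const auto &in : inputs) {
    assert(in.size() == size && "all inputs must have the same size");
    batch.insert(batch.end(), in.begin(), in.end());
  }
  return batch;
}

// Reduce the batch along its first dimension:
// out[j] = batch[0][j] + batch[1][j] + ... + batch[numSlices - 1][j].
std::vector<float> batchedReduceAdd(const std::vector<float> &batch,
                                    std::size_t numSlices,
                                    std::size_t sliceSize) {
  assert(batch.size() == numSlices * sliceSize && "shape mismatch");
  std::vector<float> out(sliceSize, 0.0f);
  for (std::size_t i = 0; i < numSlices; ++i) {
    for (std::size_t j = 0; j < sliceSize; ++j) {
      out[j] += batch[i * sliceSize + j];
    }
  }
  return out;
}
```

Compared with the tree of binary Sums, this needs only two nodes and no generated names for intermediates. The cost is the concatenated buffer, unless the concat can be laid out in place as described above.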
@SplitInfinity Oh, yes!
@bertmaher I wonder if we'll be able to optimize the memory usage, because we won't have temporary storage for the results. We can start with something simple and iterate as necessary.
Thanks, folks. I just added a PR for that. Getting into it. :)