@MikeInnes suggested in MikeInnes/Flux.jl#31 that we change the dot-call lowering in such a way as to allow it to be disabled for custom types (which may reside on a GPU or something and support only a small set of functions).
For example, `x .+ y .* x.^z` is currently lowered to `broadcast((x,y,z) -> x+y*x^z, x,y,z)`. This could instead be lowered to:

```julia
if Base.isfusing(x, y, z)
    broadcast((x,y,z) -> x + y*x^z, x, y, z)
else
    broadcast(+, x, broadcast(*, y, broadcast(^, x, z)))
end
```

with `Base.isfusing(...) = true` being the default. This would also address #22053 (cc @malmaud).
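Under this proposal, a container type that only supports a handful of element-wise kernels might opt out roughly as follows. This is only a hedged sketch: `MyTensor`, `ew_add`, and `ew_mul` are hypothetical names, and `Base.isfusing` is the hook proposed above, not an existing function.

```julia
# Opt out of fusion for MyTensor (sketch; all names here are hypothetical).
Base.isfusing(x::MyTensor, rest...) = false

# Provide broadcast methods only for the small set of functions the
# underlying library can execute efficiently:
Base.broadcast(::typeof(+), a::MyTensor, b::MyTensor) = ew_add(a, b)
Base.broadcast(::typeof(*), a::MyTensor, b::MyTensor) = ew_mul(a, b)
```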
Pro:
- Makes life easier in the short run for containers based on external libraries that only support a few efficient operations. They could overload `isfusing` and `broadcast(::typeof(foo), ...)` for a small number of supported `foo`.
- Makes life easier in specialized applications where it may not be possible to define `broadcast` for general functions, e.g. in Convex.jl where it needs to guarantee convexity.
Con:
- You can no longer look at an expression like `x .+ y .* x` and know that it fuses into a single loop with no temporaries.
- There is no middle ground: a single non-fusable operand will "spoil" fusion for an entire set of nested dot calls. (To fuse a subset of an expression, you'd have to explicitly assign it to a temporary array; see the sketch after this list.) (We could lower to a nested set of `isfusing` calls, of course, but then you'd get an exponential explosion in lowered-code size.)
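For a concrete picture of that second point, here is a small hedged sketch in which `a`, `b`, and `c` are ordinary arrays and `t` is some hypothetical non-fusable container:

```julia
# If t is non-fusable, the whole expression takes the non-fusing branch and
# lowers to nested broadcast calls with intermediate temporaries:
a .+ b .* c .+ t

# To still fuse the fusable part, split it out into an explicit temporary:
tmp = a .+ b .* c    # this sub-expression fuses into a single loop
tmp .+ t             # only this final broadcast takes the non-fusing path
```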
In the long run, I really think that packages like TensorFlow.jl should exploit the underlying library's ability to define custom operations as callback functions to implement `broadcast(f::Function, ...)` for arbitrary `f`, at which point they can define `isfusing(...) = true` and get all the benefits of fusion.
This seems like a good solution.
It reminds me of the solution we had for package precompilation - one non-precompile-enabled package disables all its dependents from precompiling, just like one non-fusing operand spoils a whole call. That put pressure on package authors to support precompiling, but gave them some time and flexibility to do so and preserved backwards compatibility. This solution gives us a similar path forward for supporting fusion across the Julia ecosystem.
That would also fix https://github.com/JuliaLang/julia/issues/19313 (though I'm less interested in it since `Nullable` is going to be replaced with `Union{T, Null}` to represent missingness).
Am I right in saying that the branch in the lowered code would be optimized away during specialization on the input types, if it is used within a function where `x`, `y`, and `z` are type-stable?
@oxinabox, yes.
Couldn't we do the same thing without this feature? For example, if we want to overload `Tensor` for its element-wise operations:
```julia
import Base: broadcast, +, *

function broadcast(f, x::Tensor, y::Tensor)
    f(Wrapper(x), Wrapper(y)).data
end

+(x::Wrapper{Tensor}, y::Wrapper{Tensor}) = element_wise_add(x.data, y.data) |> Wrapper
*(x::Wrapper{Tensor}, y::Wrapper{Tensor}) = element_wise_mult(x.data, y.data) |> Wrapper
```
so for `x .+ y .* z`, `f` becomes `(x,y,z) -> x + y*z` and it is called with the `Wrapper`s. Then `+` and `*` also get called with `Wrapper`s, so in the end the actual `Tensor`s are handled by `element_wise_mult` and `element_wise_add`.
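Here is a self-contained toy version of that trick. It assumes the Julia 0.6-era lowering discussed in this thread, where `x .+ y .* z` becomes a single `broadcast((x,y,z) -> x + y*z, x, y, z)` call, and `Tensor`, `Wrapper`, `element_wise_add`, and `element_wise_mult` are stand-ins for what an external library would provide:

```julia
import Base: broadcast, +, *

struct Tensor                      # stand-in for an external library's type
    vals::Vector{Float64}
end
struct Wrapper                     # marks a Tensor as "scalar" for the fused lambda
    data::Tensor
end

# The only element-wise kernels the "library" supports:
element_wise_add(a::Tensor, b::Tensor)  = Tensor(a.vals .+ b.vals)
element_wise_mult(a::Tensor, b::Tensor) = Tensor(a.vals .* b.vals)

# Intercept broadcast over Tensors: call the fused lambda on Wrappers,
# then pull the resulting Tensor back out.
broadcast(f, xs::Tensor...) = f(Wrapper.(xs)...).data

+(x::Wrapper, y::Wrapper) = Wrapper(element_wise_add(x.data, y.data))
*(x::Wrapper, y::Wrapper) = Wrapper(element_wise_mult(x.data, y.data))

x = Tensor([1.0, 2.0]); y = Tensor([3.0, 4.0]); z = Tensor([5.0, 6.0])
# Under the 0.6-era lowering this becomes broadcast((x,y,z) -> x + y*z, x, y, z)
# and yields Tensor([16.0, 26.0]).
x .+ y .* z
```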
@ylxdzsw, good trick! That might well work.
Could that be seen as "lifting" the `Array`s so they are considered scalars from the p.o.v. of `broadcast`, defining operations on these scalars, and then extracting the `Array`s out again?
It's a neat trick but I'm not sure how it will go in practice; at the least it's going to have an impact on error messages and traces.
How would you get multiple types (which could each be fusing or not) to play nicely together? It seems like you'd need some standard `Wrapper` type to do the dispatch on, and some way of marking what should be wrapped before `broadcast` – which devolves pretty quickly into a more complicated version of this approach :)
@MikeInnes, it is more complicated, but I think it could be done entirely with dispatch by defining a `broadcast(f, ::Union{Tensor,X,Y}...)` method for all `X` and `Y` that your library supports broadcasting tensors with (presumably only a small number), rather than by changes to lowering.
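As a rough, hedged illustration of that dispatch-only variant, building on the toy `Tensor`/`Wrapper` definitions sketched above (including its `import Base` line) and using `Number` as a stand-in for the extra `X`/`Y` types a library might support:

```julia
wrap(x::Tensor)    = Wrapper(x)     # lift Tensors so the fused lambda sees "scalars"
wrap(x::Number)    = x              # real scalars pass through unchanged
unwrap(x::Wrapper) = x.data
unwrap(x)          = x

# One broadcast method covering any mix of Tensors and plain numbers:
broadcast(f, xs::Union{Tensor,Number}...) = unwrap(f(map(wrap, xs)...))

# The library would also need mixed Wrapper/scalar operators, e.g.:
+(x::Wrapper, y::Number) = Wrapper(Tensor(x.data.vals .+ y))
+(x::Number, y::Wrapper) = y + x
```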
The problem with that is that we are talking about a diverse set of array-like containers that come from different libraries. I am definitely in favour of the trait-like implementation.
Indeed, even calling some of them containers is misleading.
E.g. TensorFlow's `Tensor`s do not hold values as such.
They describe a computational graph in terms of tensors.
When that graph is executed, then some containers are filled in the C++ backend.
I guess objects-with-array-like-semantics is a bit of a mouthful though.
I'm interested in @ylxdzsw's "trick"; I think we would need to try it out to get clarity.
cc @madeleineudell
also @denizyuret
This came up quite a few times at JuliaCon so I put something together to hack around it. Please try it out and let me know how it goes.
I have to say, I'm pretty disappointed that Base has left a bunch of packages out in the cold like this. I hope it's worth it to save some characters on `broadcast!` in a few places.
Can we add this to the 1.0 milestone?
While I've added this to the 1.0 milestone, we may want to consider the internals of how dot-fusion syntax works to be unstable.
We've decided that broadcast internals can change over 1.x.