Julia: better internal interface for extending `broadcast`

Created on 25 May 2017 · 16 comments · Source: JuliaLang/julia

@MikeInnes suggested in MikeInnes/Flux.jl#31 that we change the dot-call lowering in such a way as to allow it to be disabled for custom types (which may reside on a GPU or something and support only a small set of functions).

For example, currently x .+ y .* x.^z is lowered to broadcast((x,y,z) -> x+y*x^z, x,y,z). This could instead be lowered to:

if Base.isfusing(x,y,z)
    broadcast((x,y,z) -> x+y*x^z, x,y,z)
else
    broadcast(+, x, broadcast(*, y, broadcast(^, x, z)))
end

with Base.isfusing(...) = true being the default. This would also address #22053 (cc @malmaud).
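A minimal sketch of the proposed lowering, written out by hand. `isfusing` is hypothetical here (it was proposed in this issue, not added to Base), and so are the `Unfused` type and its three `broadcast` methods: they model a container that supports only `+`, `*` and `^`. A single such operand sends the whole expression down the unfused branch. The function calls `broadcast` explicitly, because in Julia 1.x dot syntax lowers to nested `Base.broadcasted` calls wrapped in `Base.materialize` rather than to one `broadcast(f, ...)` call.

```julia
# Hypothetical trait from this proposal (not part of Base): fuse unless a type opts out.
isfusing(x) = true
isfusing(args...) = all(isfusing, args)

# Hand-written version of what `x .+ y .* x.^z` would lower to under the proposal.
function proposed_lowering(x, y, z)
    if isfusing(x, y, z)
        # Single fused kernel: one pass over the data, no temporaries.
        broadcast((x, y, z) -> x + y * x^z, x, y, z)
    else
        # One broadcast per operator, each producing a temporary.
        broadcast(+, x, broadcast(*, y, broadcast(^, x, z)))
    end
end

# Hypothetical container that only supports a few operators, so it opts out of fusion.
struct Unfused
    data::Vector{Float64}
end
isfusing(::Unfused) = false

# Overload broadcast(::typeof(op), ...) for the small set of supported ops.
for op in (:+, :*, :^)
    @eval Base.broadcast(::typeof($op), a::Unfused, b::Unfused) = Unfused(broadcast($op, a.data, b.data))
end
```

Ordinary arrays take the fused branch and give the same result as the dot expression; `Unfused` values take the operator-by-operator branch and never reach the fused closure.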

Pro:

  • Makes life easier in the short run for containers based on external libraries that only support a few efficient operations. They could overload isfusing and broadcast(::typeof(foo), ...) for a small number of supported foo.

  • Makes life easier in specialized applications where it may not be possible to define broadcast for general functions, e.g. in Convex.jl where it needs to guarantee convexity.

Con:

  • You can no longer look at an expression like x .+ y .* x and know that it fuses into a single loop with no temporaries.

  • There is no middle ground. A single non-fusable operand will "spoil" fusion for an entire set of nested dot calls. (To fuse a subset of an expression, you'd have to explicitly assign it to a temporary array.) (We could lower to a nested set of isfusing calls, of course, but then you'd get an exponential explosion in lowered code size.)

In the long run, I really think that packages like TensorFlow.jl should exploit the underlying library's ability to define custom operations as callback functions to implement broadcast(f::Function, ...) for arbitrary f, at which point they can define isfusing(...) = true and get all the benefits of fusion.

Labels: broadcast, gpu, speculative


This seems like a good solution.

It reminds me of the solution we had to package precompilation - one non-precompile-enabled package disables all its dependents from precompiling, just like one non-fusing operand spoils a whole call. That put pressure on package authors to support precompiling, but gave them some time and flexibility to do so and preserved backwards compatibility. This solution gives us a similar path forward for supporting fusing across the Julia ecosystem.

That would also fix https://github.com/JuliaLang/julia/issues/19313 (though I'm less interested in it since Nullable is going to be replaced with Union{T, Null} to represent missingness).

Am I right in saying that the branch in the lowered code would be optimized away during specialization on the input types, if used within a function where x, y, z are type-stable?

@oxinabox, yes.

Couldn't we do the same thing without this feature? For example, if we want Tensor to use its own element-wise operations:

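import Base: broadcast, +, *

struct Wrapper{T}; data::T; end  # lets f treat a whole Tensor as a "scalar"

function broadcast(f, x::Tensor, xs::Tensor...)  # varargs, so x .+ y .* z (3 args) matches
  f(Wrapper(x), map(Wrapper, xs)...).data
end

+(x::Wrapper{Tensor}, y::Wrapper{Tensor}) = element_wise_add(x.data, y.data) |> Wrapper
*(x::Wrapper{Tensor}, y::Wrapper{Tensor}) = element_wise_mult(x.data, y.data) |> Wrapper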

So for x .+ y .* z, f becomes (x,y,z) -> x+y*z and is called with the Wrappers. Then + and * are also called with Wrappers, so the actual Tensors are finally handed to element_wise_mult and element_wise_add.
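A self-contained version of the wrapper trick that runs on current Julia. `Tensor`, `element_wise_add` and `element_wise_mult` are hypothetical stand-ins for a library's own type and kernels; here they just operate on a `Vector{Float64}`. Since Julia 1.x dot syntax lowers through `Base.broadcasted`/`Base.materialize` instead of a single `broadcast(f, ...)` call, the sketch calls `broadcast` explicitly with the fused function to mimic the lowering discussed in this issue.

```julia
# Hypothetical library type (e.g. a GPU array or graph node); here it just holds data.
struct Tensor
    data::Vector{Float64}
end

# Hypothetical backend kernels: the only element-wise operations the library offers.
element_wise_add(a::Tensor, b::Tensor) = Tensor(a.data .+ b.data)
element_wise_mult(a::Tensor, b::Tensor) = Tensor(a.data .* b.data)

# Wrapping a whole Tensor makes it look like a single "scalar" to the fused function.
struct Wrapper{T}
    data::T
end

Base.:+(x::Wrapper{Tensor}, y::Wrapper{Tensor}) = Wrapper(element_wise_add(x.data, y.data))
Base.:*(x::Wrapper{Tensor}, y::Wrapper{Tensor}) = Wrapper(element_wise_mult(x.data, y.data))

# Call the fused function once on the wrapped tensors, then unwrap the result.
Base.broadcast(f, x::Tensor, xs::Tensor...) = f(Wrapper(x), map(Wrapper, xs)...).data
```

For example, `broadcast((a, b, c) -> a + b * c, x, y, z)` on three Tensors calls `element_wise_mult` and then `element_wise_add` once each. If `f` uses an operation with no Wrapper method (say `sin`), the call fails with a `MethodError` mentioning `Wrapper{Tensor}` rather than `Tensor`, which is the error-message concern raised below.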

@ylxdzsw, good trick! That might well work.

Could that be seen as "lifting" the Arrays so they are considered scalars from the point of view of broadcast, defining operations on these scalars and then extracting the Arrays again?

It's a neat trick but I'm not sure how it will go in practice; at the least it's going to have an impact on error messages and traces.

How would you get multiple types (which could each be fusing or not) to play nicely together? It seems like you'd need some standard Wrapper type to do the dispatch on, and some way of marking what should be wrapped before broadcast – which devolves pretty quickly into a more complicated version of this approach :)

@MikeInnes, it is more complicated, but I think it could be done entirely with dispatch by defining a broadcast(f, ::Union{Tensor,X,Y}...) method for all X and Y that your library supports broadcasting tensors with (presumably only a small number), rather than by changes to lowering.
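A sketch of this dispatch-only variant, assuming a library that supports tensor-tensor and tensor-scalar products and nothing else. `Tensor`, `element_wise_mult`, `scale` and `lift` are hypothetical names. Each supported argument pairing (the X and Y above, here just `Float64`) gets its own `broadcast` method, and lowering is untouched. `broadcast` is called explicitly, since in Julia 1.x dot syntax lowers through `Base.broadcasted` instead.

```julia
# Hypothetical library tensor; stands in for a GPU array or graph node.
struct Tensor
    data::Vector{Float64}
end

# Hypothetical backend kernels: tensor-tensor and tensor-scalar products only.
element_wise_mult(a::Tensor, b::Tensor) = Tensor(a.data .* b.data)
scale(a::Tensor, s::Float64) = Tensor(a.data .* s)

struct Wrapper{T}
    data::T
end

# One method per supported operand pairing on the wrapped value.
Base.:*(x::Wrapper{Tensor}, y::Wrapper{Tensor}) = Wrapper(element_wise_mult(x.data, y.data))
Base.:*(x::Wrapper{Tensor}, s::Float64) = Wrapper(scale(x.data, s))
Base.:*(s::Float64, x::Wrapper{Tensor}) = x * s

# Tensors get wrapped; plain scalars pass through unchanged.
lift(x::Tensor) = Wrapper(x)
lift(x::Float64) = x

# One broadcast method per supported argument combination, chosen purely by dispatch.
Base.broadcast(f, x::Tensor, y::Union{Tensor,Float64}) = f(lift(x), lift(y)).data
Base.broadcast(f, x::Float64, y::Tensor) = f(lift(x), lift(y)).data
```

Every additional type a package wants to mix with `Tensor` (an `Int`, another library's array) needs its own methods here, which is the scaling problem raised in the reply below.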

The problem with that is that we are talking about a diverse set of array-like containers that come from different libraries. I am definitely in favour of the trait-like implementation.

Indeed, even calling some of them containers is misleading. E.g., TensorFlow's Tensors do not hold values as such; they describe a computational graph in terms of tensors. Only when that graph is executed are some containers filled in the C++ backend. I guess "objects with array-like semantics" is a bit of a mouthful, though.

I'm interested in @ylxdzsw's "trick"; I think we would need to try it out to get clarity.

cc @madeleineudell

also @denizyuret

This came up quite a few times at JuliaCon, so I put something together to hack around it. Please try it out and let me know how it goes.

I have to say, I'm pretty disappointed that Base has left a bunch of packages out in the cold like this. I hope it's worth it to save some characters on broadcast! in a few places.

Can we add this to the 1.0 milestone?

While I've added this to the 1.0 milestone, we may want to consider the internals of how dot-fusion syntax works to be unstable.

We've decided that broadcast internals can change over 1.x.
