Julia: inference failure for broadcast implementations that use `flatten` (e.g. SparseArrays)

Created on 8 Jul 2018 · 16 Comments · Source: JuliaLang/julia

I've been using the new broadcasting API to implement broadcasting for some custom types (by the way, the new API is great and massively simplified my code, thanks!). However, I ran into a surprising inference failure, which I've reduced to the following code:

import Base.Broadcast: BroadcastStyle, materialize, broadcastable
using Base.Broadcast: Broadcasted, Style, flatten
using Test

struct NumWrapper{T}
    data::T
end

broadcastable(n::NumWrapper) = n
BroadcastStyle(::Type{N}) where {N<:NumWrapper} = Style{N}()
broadcast_data(n::NumWrapper) = (n.data,)
function materialize(bc::Broadcasted{Style{N}}) where {N<:NumWrapper}
    flat_bc = flatten(bc)
    N(broadcast.(flat_bc.f, broadcast_data.(flat_bc.args)...)...)
end

foo(a,b) = @. a * a + (b * a) * b
bar(a,b) = @. a * a + b * a * b

@inferred foo(NumWrapper(1), NumWrapper(2)) #inferred as Any
@inferred bar(NumWrapper(1), NumWrapper(2)) #inferred correctly

As you can see, each NumWrapper just holds a number, and broadcasting over e.g. a::NumWrapper .+ b::NumWrapper becomes NumWrapper(broadcast.(+,(a.data,),(b.data,))...) = NumWrapper((a.data .+ b.data,)...). Note that I need .data wrapped in a tuple because in my real code NumWrapper has multiple fields; this is also crucial to triggering the bug, even though in the simple example above it probably seems unnecessary. In any case, you can see that the seemingly unimportant parentheses around (b*a) spoil type stability. This is on commit 656d58782c.
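A quick way to see why the parentheses matter is to compare how the two expressions parse. The sketch below uses only Meta.parse plus a hypothetical depth helper: b*a*b parses as a single three-argument call to *, whereas (b*a)*b is a call nested inside another. Presumably @. then produces one Broadcasted node for the former and a nested Broadcasted for the latter, which flatten has to unpack; that last step is an interpretation, not something this code checks.

```julia
# Hypothetical helper: how deeply call expressions are nested (symbols count as 0)
depth(ex::Expr) = ex.head === :call ? 1 + maximum(depth, ex.args[2:end]; init = 0) : 0
depth(::Any) = 0

chained = Meta.parse("a * a + b * a * b")    # body of bar, before @. is applied
nested  = Meta.parse("a * a + (b * a) * b")  # body of foo, before @. is applied

# b * a * b is a single call to * with three arguments...
chained.args[3] == Expr(:call, :*, :b, :a, :b)                  # true
# ...while (b * a) * b is a call to * whose first argument is another call to *
nested.args[3] == Expr(:call, :*, Expr(:call, :*, :b, :a), :b)   # true

depth(chained), depth(nested)                                    # (2, 3)
```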

Beyond a possible fix, I'm also curious whether there's a workaround or a better way to code this up. I can't say I'm 100% sure I've used the new API as intended (in particular, I've used flatten, and the docs make it seem like this shouldn't usually be necessary). In any case, hope this helps!

EDIT: Fixed a small mistake in the text describing what a::NumWrapper .+ b::NumWrapper became.

Labels: broadcast, performance

All 16 comments

Excellent bug report! @keno, would be great if you can look at the inference failure. @mbauman, can you give advice on the broadcast API usage?

Other than to say that I suspect you're working harder than you need to, this isn't quite enough to go on. What are the rules for indexing a NumWrapper? What rules of arithmetic do they follow? The main design goal of the new API for extending broadcasting (https://github.com/JuliaLang/julia/pull/23939) was to be "orthogonal" to both indexing and arithmetic, so that you only have to specify things that are specific to broadcasting (like how the indexing should work when combined with other containers).

So, you probably need some indexing rules, maybe like those in number.jl. And you should make sure that arithmetic works on its own:

julia> a, b = NumWrapper(1), NumWrapper(2)
(NumWrapper{Int64}(1), NumWrapper{Int64}(2))

julia> a*a + (b*a)*b
ERROR: MethodError: no method matching *(::NumWrapper{Int64}, ::NumWrapper{Int64})
Closest candidates are:
  *(::Any, ::Any, ::Any, ::Any...) at operators.jl:502
Stacktrace:
 [1] top-level scope at none:0

The absence of these rules may be why you have non-inferrability.
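For instance, the missing rules could be as small as the following sketch, assuming NumWrapper should simply behave like the number it wraps. These definitions are hypothetical (they are not in the original code); the size/axes/getindex methods loosely mirror what number.jl defines for Number.

```julia
struct NumWrapper{T}
    data::T
end

# Arithmetic forwarded to the wrapped value, so a*a + (b*a)*b works without broadcasting
Base.:*(x::NumWrapper, y::NumWrapper) = NumWrapper(x.data * y.data)
Base.:+(x::NumWrapper, y::NumWrapper) = NumWrapper(x.data + y.data)

# Scalar-like shape and zero-argument indexing, loosely modelled on number.jl
Base.size(::NumWrapper) = ()
Base.axes(::NumWrapper) = ()
Base.getindex(x::NumWrapper) = x

a, b = NumWrapper(1), NumWrapper(2)
c = a*a + (b*a)*b   # wraps 1*1 + (2*1)*2 == 5
a[]                 # returns a itself
```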

If they basically act like scalars, then it's just possible that you might need only one definition to support broadcasting:

Base.BroadcastStyle(::Type{<:NumWrapper}) = Base.Broadcast.DefaultArrayStyle{0}()

That declares that they act like 0-dimensional arrays, and I think the normal rules will work to preserve the "eltype." If you do have to give them a specific style, then you're going to have to define binary rules to specify how they combine with other containers.

Thanks, yeah, I think you're right; sorry, the example above has too much stripped out for you to give really meaningful advice.

To say a bit more, the real purpose is to have a structure (what I called NumWrapper above and I'll call Field below) for which broadcast operations are "forwarded" down to the individual fields of the structure. So if you had, e.g.

struct Foo <: Field
    a::Matrix
    b::Matrix
end

struct Bar <: Field
    c::Matrix
end

then you would have basically the following equivalence:

x::Foo .+ y::Foo .+ z::Foo = Foo(x.a .+ y.a .+ z.a, x.b .+ y.b .+ z.b)

Note that this isn't quite what, e.g., tuple broadcasting does, because that doesn't "forward" the broadcast to the individual elements, so it can be inefficient if the elements are things like large matrices, as in my case.

Similarly, you would have,

x::Foo .+ w::Bar = Foo(x.a .+ w.c, x.b .+ w.c)

In the real case I do indeed have a set of binary broadcast rules, for example to decide that the above result is a Foo (plus some others, for reasons beyond what I've shown here).

I also have a materialize function which is in fact identical to the one in my first post, i.e. it calls flatten and then does a broadcasted broadcast call over the fields of the data structure. That's really the piece I'm unsure is the right way to do this. Any advice on this is much appreciated (but by no means expected! The main purpose here was really the bug report.)
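One flatten-free way to do the same forwarding, sketched with hypothetical TwoFields/OneField types standing in for Foo and Bar: instead of flattening the tree, rebuild it once per field with every container leaf swapped for the matching field, then let ordinary array broadcasting fuse each rebuilt tree. This is an alternative under stated assumptions, not what the thread recommends (the replies argue for AbstractArray-based broadcasting). It relies on the same Style{T} and materialize hooks as the original code; in-place .= and mixing with plain numbers would need extra methods or style rules that are not shown.

```julia
using Base.Broadcast: Broadcasted, Style, broadcasted, materialize

# Hypothetical stand-ins for Foo (two fields) and Bar (one field)
struct TwoFields{A,B}
    a::A
    b::B
end
struct OneField{C}
    c::C
end

Base.Broadcast.broadcastable(x::Union{TwoFields,OneField}) = x
Base.Broadcast.BroadcastStyle(::Type{<:TwoFields}) = Style{TwoFields}()
Base.Broadcast.BroadcastStyle(::Type{<:OneField}) = Style{OneField}()
# Binary rule: mixing the two gives a TwoFields result ("Foo .+ Bar -> Foo")
Base.Broadcast.BroadcastStyle(::Style{TwoFields}, ::Style{OneField}) = Style{TwoFields}()

# Rebuild a lazy Broadcasted tree with each container leaf replaced by one field.
# A OneField supplies its only field no matter which name is requested.
select_field(bc::Broadcasted, name) = broadcasted(bc.f, map(x -> select_field(x, name), bc.args)...)
select_field(x::TwoFields, name) = getfield(x, name)
select_field(x::OneField, name) = x.c
select_field(x, name) = x  # anything else (e.g. a plain array) passes through

# Each rebuilt tree is an ordinary array broadcast, fused into one loop per field
Base.Broadcast.materialize(bc::Broadcasted{Style{TwoFields}}) =
    TwoFields(materialize(select_field(bc, :a)), materialize(select_field(bc, :b)))
Base.Broadcast.materialize(bc::Broadcasted{Style{OneField}}) =
    OneField(materialize(select_field(bc, :c)))

x = TwoFields([1.0, 2.0], [3.0, 4.0])
w = OneField([10.0, 100.0])
y = x .+ w .* x   # TwoFields with a == x.a .+ w.c .* x.a and b == x.b .+ w.c .* x.b
```

The design choice here is that no closure chain is built: each field gets its own plain Broadcasted tree, so the per-field work is exactly what Base already does for arrays.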

This helps. There's still some question in my mind about whether you should be expressing those operations as broadcasting .+ or just ordinary +. Does Field support indexing, i.e., can I ask for (x::Foo)[3,5]? Can I set an entry with x[3,5] = (1.7, 3.5) or similar? If the answer is "no", then this might not be broadcasting; you might just be defining your own arithmetic operators, which you happen to be writing in broadcasted notation. And if they don't support indexing, then in my opinion this is an abuse of broadcasting. (It's fine to use broadcasting in the internal implementation of your arithmetic rules, but if Foo doesn't support indexing then summing two Foos isn't a broadcasting operation.)

If they do support indexing, then you might only need something along the lines of

# These act like two-dimensional arrays, and should take precedence over e.g., Array
struct FooStyle <: Broadcast.AbstractArrayStyle{2} end
struct BarStyle <: Broadcast.AbstractArrayStyle{2} end
Base.BroadcastStyle(::Type{<:Foo}) = FooStyle()
Base.BroadcastStyle(::Type{<:Bar}) = BarStyle()

# FooStyle "beats" BarStyle: Foo.+Foo->Foo, Foo.+Bar->Foo, Bar.+Bar->Bar
Base.BroadcastStyle(::FooStyle, ::BarStyle) = FooStyle()

# Teach Julia how to allocate the output container
Base.similar(bc::Broadcasted{FooStyle}, ::Type{ElType}) = Foo{ElType}(undef, size(bc))
Base.similar(bc::Broadcasted{BarStyle}, ::Type{ElType}) = Bar{ElType}(undef, size(bc))

Julia's internals should handle the rest automatically for you :smile:.
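For a concrete, runnable instance of this pattern, here is a sketch along the lines of the "Customizing broadcasting" section of the Julia manual. Wrapped is a hypothetical type, and it uses Broadcast.ArrayStyle (the manual's choice) in place of the hand-written FooStyle/BarStyle above, which saves writing rules for combining with ordinary arrays and numbers.

```julia
# Hypothetical stand-in: an AbstractArray that wraps an ordinary Array
struct Wrapped{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end

# The AbstractArray interface: size plus Cartesian getindex/setindex!
Base.size(w::Wrapped) = size(w.data)
Base.getindex(w::Wrapped{T,N}, I::Vararg{Int,N}) where {T,N} = w.data[I...]
Base.setindex!(w::Wrapped{T,N}, v, I::Vararg{Int,N}) where {T,N} = (w.data[I...] = v)

# Broadcasting: a custom style, plus how to allocate the output container
Base.BroadcastStyle(::Type{<:Wrapped}) = Broadcast.ArrayStyle{Wrapped}()
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{Wrapped}}, ::Type{ElType}) where {ElType} =
    Wrapped(similar(Array{ElType}, axes(bc)))

w = Wrapped([1.0 2.0; 3.0 4.0])
r = w .* w .+ 1   # a Wrapped, filled in a single fused loop
```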

I would agree with Tim on the broadcasting design, although I must confess that this is a rather clever abuse that allows forwarding of arbitrary functions. You may run into trouble if your type hits another that overrides broadcast styles, though.

As far as the inference failure goes, it stems from flatten: we recursively construct a set of closures to do the flattening, and that's where inference loses the trail. It's a rather crazy computation, but the part that's failing is the relatively simpler flattening of argument lists in cat_nested. I did a bit of cursory poking at it, but throwing a few extra @inlines and ::Vararg{Any,N} specializations at it didn't seem to solve the underlying issue. Here's a simpler MWE:

julia> bc =  Broadcast.Broadcasted(+, (Broadcast.Broadcasted(*, (1, 2)), Broadcast.Broadcasted(*, (Broadcast.Broadcasted(*, (3, 4)), 5))));

julia> Broadcast.cat_nested(x->x.args, bc)
(1, 2, 3, 4, 5)

julia> @code_warntype Broadcast.cat_nested(x->x.args, bc)
Body::Tuple{Int64,Int64,Vararg{Any,N} where N}
…
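To make the flattening concrete, here is a small sketch on plain numbers. Note that Base.Broadcast.flatten is internal and unexported, so its exact behavior may differ between Julia versions; as far as I know, on Julia 1.x it returns a Broadcasted whose args are the leaves of the tree and whose f is a closure that reapplies the original nesting. That closure chain is the part inference struggles with above.

```julia
using Base.Broadcast: broadcasted, flatten

# Lazy, unmaterialized form of 1 .+ (2 .* 3): one Broadcasted nested inside another
bc = broadcasted(+, 1, broadcasted(*, 2, 3))

# flatten gathers the leaves into a single argument tuple and folds the
# nesting into a new function f, so that fbc.f(fbc.args...) == 1 + 2 * 3
fbc = flatten(bc)
fbc.args            # (1, 2, 3)
fbc.f(fbc.args...)  # 7
```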

Thanks for the replies. My Fields do represent abstract vectors, the vector components being all the entries in the matrices; e.g. the vector representation of f::Foo would be [f.a[:]; f.b[:]] (the reason to store them as matrices instead of actual vectors is that there are other operations I do on them that make more sense in terms of the separate matrices). This means they could have getindex/setindex! defined and I could implement broadcasting the way @timholy suggested, but in my getindex/setindex! functions I would have to do some index arithmetic to map into the vector representation, and my guess was that this would make broadcasting slower. Instead I went about it in the "top-down" fashion above and forwarded the broadcast directly to the matrices, which I figured wouldn't incur the performance hit (I can confirm this on 0.6, where it currently works) and also felt more intuitive. But maybe there's a way to code the index arithmetic without incurring a performance hit? I admit I have not tried.

I guess the question I have then is: is this a performance bug or not? It sounds like this may be a case where broadcasting is inappropriate but this is an optimization we want to work anyway?

This is a real performance issue, and it'll affect any broadcasting implementation that uses flatten; this notably includes SparseArrays. Here's the minimal example:

using Base: tail
cat_nested(t::Tuple, rest) = (t[1], cat_nested(tail(t), rest)...)
cat_nested(t::Tuple{Tuple,Vararg{Any}}, rest) = cat_nested(cat_nested(t[1], tail(t)), rest)
cat_nested(t::Tuple{}, tail) = tail
t = ((1, 2), ((3, 4), 5))
@code_warntype cat_nested(t, ())

This worked in Julia 0.6, but I imagine that it broke as part of @vtjnash's work in limiting non-terminating recursive inference. Is there a workaround here? Or could we support this? The goal is simply to flatten arbitrarily nested tuples: in this case, just return (1,2,3,4,5).

Ah, here's the workaround: don't be doubly-recursive within a single method:

using Base: tail
cat_nested(t::Tuple, rest) = (t[1], cat_nested(tail(t), rest)...)
cat_nested(t::Tuple{Tuple,Vararg{Any}}, rest) = cat_nested(t[1], (tail(t)..., rest...))
cat_nested(t::Tuple{}, tail) = cat_nested(tail, ())
cat_nested(t::Tuple{}, tail::Tuple{}) = ()
t = ((1, 2), ((3, 4), 5))
@code_warntype cat_nested(t, ())
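A scripted version of the same check, as a sketch: it renames the workaround to the hypothetical flat_nested (so it does not clash with the cat_nested defined just above) and compares the runtime type of the result with what inference reports through Base.return_types. Because inference is sound, the runtime type is always a subtype of the inferred type; the open question is only whether the inferred type is the concrete NTuple{5,Int64} or something wider such as Tuple{Int64,Int64,Vararg{Any,N} where N}. As the next comment shows, the answer can depend on what was compiled first.

```julia
using Base: tail

# Same workaround as above, under a hypothetical name
flat_nested(t::Tuple, rest) = (t[1], flat_nested(tail(t), rest)...)
flat_nested(t::Tuple{Tuple,Vararg{Any}}, rest) = flat_nested(t[1], (tail(t)..., rest...))
flat_nested(t::Tuple{}, rest) = flat_nested(rest, ())
flat_nested(t::Tuple{}, rest::Tuple{}) = ()

t = ((1, 2), ((3, 4), 5))
result = flat_nested(t, ())   # (1, 2, 3, 4, 5)

# What inference concluded for this call signature
inferred = only(Base.return_types(flat_nested, (typeof(t), Tuple{})))
isconcretetype(inferred) || @warn "flat_nested was not fully inferred" inferred
```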

Ok, this is bizarre. I'm getting a state-dependent order-of-compilation difference in inference:

$ ./julia
               _
   _       _ _(_)_     |  A fresh approach to technical computing
  (_)     | (_) (_)    |  Documentation: https://docs.julialang.org
   _ _   _| |_  __ _   |  Type "?" for help, "]?" for Pkg help.
  | | | | | | |/ _` |  |
  | | |_| | | | (_| |  |  Version 0.7.0-beta.283 (2018-07-12 22:44 UTC)
 _/ |\__'_|_|_|\__'_|  |  Commit 98061abb5a* (0 days old master)
|__/                   |  x86_64-linux-gnu

julia> using Base: tail
       cat_nested(t::Tuple) = cat_nested(t, ())
       cat_nested(t::Tuple, rest) = (t[1], cat_nested(tail(t), rest)...)
       cat_nested(t::Tuple{Tuple,Vararg{Any}}, rest) = cat_nested(t[1], (tail(t)..., rest...))
       cat_nested(t::Tuple{}, tail) = cat_nested(tail, ())
       cat_nested(t::Tuple{}, tail::Tuple{}) = ()
       t = ((1, 2), ((3, 4), 5))
((1, 2), ((3, 4), 5))

julia> @code_warntype cat_nested(t, ())
Body::NTuple{5,Int64}
4 1 ─ %1  = Base.getfield(%%t, 1, true)::Tuple{Int64,Int64}            │╻       getindex
  │         getfield(%%t, 1)                                           │╻       tail
  │   %3  = getfield(%%t, 2)::Tuple{Tuple{Int64,Int64},Int64}          ││
  │   %4  = Base.getfield(%1, 1, true)::Int64                          ││╻       getindex
  │         getfield(%1, 1)                                            ││╻       tail
  │   %6  = getfield(%1, 2)::Int64                                     │││
  │   %7  = Base.getfield(%3, 1, true)::Tuple{Int64,Int64}             │││╻╷╷╷    cat_nested
  │         getfield(%3, 1)                                            ││││╻       cat_nested
  │   %9  = getfield(%3, 2)::Int64                                     │││││┃│      cat_nested
  │   %10 = Base.getfield(%7, 1, true)::Int64                          ││││││╻       cat_nested
  │         getfield(%7, 1)                                            │││││││╻       tail
  │   %12 = getfield(%7, 2)::Int64                                     ││││││││
  │   %13 = Core.tuple(%4, %6, %10, %12, %9)::NTuple{5,Int64}          ││
  └──       return %13                                                 │

julia> @code_warntype cat_nested(t)
Body::NTuple{5,Int64}
2 1 ─ %1  = Base.getfield(%%t, 1, true)::Tuple{Int64,Int64}          │╻╷       cat_nested
  │         getfield(%%t, 1)                                         ││╻        tail
  │   %3  = getfield(%%t, 2)::Tuple{Tuple{Int64,Int64},Int64}        │││
  │   %4  = Base.getfield(%1, 1, true)::Int64                        │││╻        getindex
  │         getfield(%1, 1)                                          │││╻        tail
  │   %6  = getfield(%1, 2)::Int64                                   ││││
  │   %7  = Base.getfield(%3, 1, true)::Tuple{Int64,Int64}           ││││╻╷╷╷     cat_nested
  │         getfield(%3, 1)                                          │││││╻        cat_nested
  │   %9  = getfield(%3, 2)::Int64                                   ││││││┃│       cat_nested
  │   %10 = Base.getfield(%7, 1, true)::Int64                        │││││││╻        cat_nested
  │         getfield(%7, 1)                                          ││││││││╻        tail
  │   %12 = getfield(%7, 2)::Int64                                   │││││││││
  │   %13 = Core.tuple(%4, %6, %10, %12, %9)::NTuple{5,Int64}        │││
  └──       return %13                                               │

But if I restart my session and call those two methods in a different order:

$ ./julia -q
julia> using Base: tail
       cat_nested(t::Tuple) = cat_nested(t, ())
       cat_nested(t::Tuple, rest) = (t[1], cat_nested(tail(t), rest)...)
       cat_nested(t::Tuple{Tuple,Vararg{Any}}, rest) = cat_nested(t[1], (tail(t)..., rest...))
       cat_nested(t::Tuple{}, tail) = cat_nested(tail, ())
       cat_nested(t::Tuple{}, tail::Tuple{}) = ()
       t = ((1, 2), ((3, 4), 5))
((1, 2), ((3, 4), 5))

julia> @code_warntype cat_nested(t)
Body::Tuple{Int64,Int64,Int64,Int64,Vararg{Any,N} where N}
2 1 ─ %1  = Base.getfield(%%t, 1, true)::Tuple{Int64,Int64}                     │╻╷  cat_nested
  │         getfield(%%t, 1)                                                    ││╻   tail
  │   %3  = getfield(%%t, 2)::Tuple{Tuple{Int64,Int64},Int64}                   │││
  │   %4  = Core.tuple(%3)::Tuple{Tuple{Tuple{Int64,Int64},Int64}}              ││
  │   %5  = Base.getfield(%1, 1, true)::Int64                                   │││╻   getindex
  │   %6  = Core.tuple(%5)::Tuple{Int64}                                        │││
  │         getfield(%1, 1)                                                     │││╻   tail
  │   %8  = getfield(%1, 2)::Int64                                              ││││
  │   %9  = Core.tuple(%8)::Tuple{Int64}                                        ││││
  │   %10 = invoke Main.cat_nested(%9::Tuple{Int64}, %4::Tuple{Tuple{Tuple{Int64,Int64},Int64}})::Tuple{Int64,Int64,Int64,Vararg{Any,N} where N}
  │   %11 = Core._apply(Core.tuple, %6, %10)::Tuple{Int64,Int64,Int64,Int64,Vararg{Any,N} where N}
  └──       return %11                                                          │

julia> @code_warntype cat_nested(t, ())
Body::Tuple{Int64,Int64,Int64,Int64,Vararg{Any,N} where N}
4 1 ─ %1  = Base.getfield(%%t, 1, true)::Tuple{Int64,Int64}                       │╻  getindex
  │         getfield(%%t, 1)                                                      │╻  tail
  │   %3  = getfield(%%t, 2)::Tuple{Tuple{Int64,Int64},Int64}                     ││
  │   %4  = Core.tuple(%3)::Tuple{Tuple{Tuple{Int64,Int64},Int64}}                │
  │   %5  = Base.getfield(%1, 1, true)::Int64                                     ││╻  getindex
  │   %6  = Core.tuple(%5)::Tuple{Int64}                                          ││
  │         getfield(%1, 1)                                                       ││╻  tail
  │   %8  = getfield(%1, 2)::Int64                                                │││
  │   %9  = Core.tuple(%8)::Tuple{Int64}                                          │││
  │   %10 = invoke Main.cat_nested(%9::Tuple{Int64}, %4::Tuple{Tuple{Tuple{Int64,Int64},Int64}})::Tuple{Int64,Int64,Int64,Vararg{Any,N} where N}
  │   %11 = Core._apply(Core.tuple, %6, %10)::Tuple{Int64,Int64,Int64,Int64,Vararg{Any,N} where N}
  └──       return %11                                                            │

Consider this MWE, a similar case to the OP:

using BenchmarkTools

struct AugmentedState{X, Q}
    x::X
    q::Q
end

_state(x::AugmentedState) = x.x
_quad(x::AugmentedState) = x.q
_state(x) = x
_quad(x) = x


if VERSION > v"0.6.5"
    using Printf
    @inline Broadcast.broadcastable(x::AugmentedState) = x
    Base.ndims(::Type{<:AugmentedState}) = 0

    @inline function Broadcast.materialize!(dest::AugmentedState, bc::Broadcast.Broadcasted)
        bcf = Broadcast.flatten(bc)
        Broadcast.broadcast!(bcf.f, _state(dest), map(_state, bcf.args)...)
        Broadcast.broadcast!(bcf.f, _quad(dest),  map(_quad,  bcf.args)...)
        return dest
    end
else
    @inline function Base.Broadcast.broadcast!(f, dest::AugmentedState, args...)
        broadcast!(f, _state(dest), map(_state, args)...)
        broadcast!(f,  _quad(dest), map(_quad,  args)...)
        return dest
    end
end

a = AugmentedState(rand(100), rand(100))
b = AugmentedState(rand(100), rand(100))
c = AugmentedState(rand(100), rand(100))
d = AugmentedState(rand(100), rand(100))

bar(a, b, c, d) = (a .= a .+ 2.0.*b .+ 5.0.*c .- 7.0.*d; a)

t1 = @belapsed $bar($a, $b, $c, $d)
t2 = @belapsed (bar($a.x, $b.x, $c.x, $d.x); bar($a.q, $b.q, $c.q, $d.q))

@printf "AugmentedState  = %7d ns\n" t1*10^9
@printf "Arrays          = %7d ns\n" t2*10^9
@printf "Ratio           = %7d x  \n"  t1/t2

Using v0.6.4 I get:

AugmentedState =     615 ns
Arrays         =     616 ns
Ratio          =       1 x  

while on latest master I get:

AugmentedState  =   13962 ns
Arrays          =     256 ns
Ratio           =      55 x  

Checking out #28111 does not seem to make a large difference for this case, though.

Closed by accident? @gasagna showed #28111 likely isn't the solution (at least for his related problem).

Edit: I see, that's considered a different issue #28126

Yes, using latest master, the performance of the code in my previous comment does not improve and has high allocations.

Hi @mbauman and others, so #28111 definitely fixed the code in my original post, but the following still does not work:

using Test, SparseArrays
foo(a) = @. a*a*a*a*a*a*a
bar(a) = @. a*(a*a)*a*a*a*a
@inferred foo(spzeros(1)) # inferred
@inferred bar(spzeros(1)) # not inferred

(since it was pointed out above that SparseArrays use flatten, I was able to reduce this further to not bother with the NumWrapper thing).

Oddly enough (and hopefully this is a hint), putting the parentheses around any other pair of a's is fine.

I think this is different from #28126 / @gasagna's example above, because in that case there is no inference failure, while here there is.

Messing around some more, I think I've reduced it further to what might be the same "order-of-compilation difference in inference" mentioned in https://github.com/JuliaLang/julia/issues/27988#issuecomment-404882335 above, although that comment predates #28111, whereas I'm testing after it (specifically on 0.7rc2).

Here's an MWE. This will infer correctly:

using Test
using Base.Broadcast: cat_nested, Broadcasted
@inferred cat_nested(1, Broadcasted(*, (2, 3)), 4, 5, 6, 7)
@inferred cat_nested(Broadcasted(*, (1, Broadcasted(*, (2, 3)), 4, 5, 6, 7)))

and in a fresh session, this will fail (it is the same code as above with the third line commented out):

using Test
using Base.Broadcast: cat_nested, Broadcasted
#@inferred cat_nested(1, Broadcasted(*, (2, 3)), 4, 5, 6, 7)
@inferred cat_nested(Broadcasted(*, (1, Broadcasted(*, (2, 3)), 4, 5, 6, 7)))

with

ERROR: return type NTuple{7,Int64} does not match inferred return type Tuple{Int64,Int64,Int64,Int64,Int64,Vararg{Int64,N} where N}

That failure in the second case is exactly what's causing the failure for a*(a*a)*a*a*a*a above.

Curious if anyone has made progress on this since the dust has settled after JuliaCon?

After more reading and digging, I venture to guess that what is happening here is exactly what is described in the "Independence of the cached result" section of https://juliacomputing.com/blog/2017/05/15/inference-converage2.html, which "remains an unsolved problem". This is probably already clear to the people who might be able to fix it (if I'm right about what it is), but I figured it was worth mentioning.
