Many of our latency issues (e.g. time to first plot) primarily consist of inference time. This issue is for documenting and tracking ideas for improving this. I'll try to put ~one idea per comment.
Idea 1: customize max_methods per module.
The first 4 on this line in base/compiler/params.jl is really important:
#=inline_tupleret_bonus, max_methods, union_splitting, apply_union_enum=#
400, 4, 4, 8,
When an inferred call site has multiple possible targets (due to type imprecision), max_methods limits the number of methods we examine. So the amount of work done by inference can theoretically be max_methods^n, where n is the length of a call chain. Usually the blowup is not that bad, but on rare occasion it is, and we end up inferring tons of methods on very general types, ultimately getting no really useful information. Reducing max_methods reliably speeds up inference, with by far the best results from max_methods=1 (unsurprisingly!).
Setting max_methods=1 globally often results in worse type info. But a lot of code is probably fine with it. The idea here is to allow setting max_methods=1 for all code in a given module, e.g. Plots.jl, similar (but orthogonal) to the per-module @nospecialize we already have.
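To make the max_methods^n estimate concrete, here is a minimal sketch. The names `worst_case_targets` and `halve` are hypothetical and exist only for illustration; `methods` is the real Base reflection function. It also shows how an imprecise argument type turns one call site into several candidate targets.

```julia
# Hypothetical helper: upper bound on the number of call targets inference
# may visit along a chain of n call sites, each with up to max_methods targets.
worst_case_targets(max_methods, n) = max_methods^n

worst_case_targets(4, 5)   # 1024
worst_case_targets(1, 5)   # 1

# A made-up function with two methods.
halve(x::Int) = x ÷ 2
halve(x::Float64) = x / 2

length(methods(halve, (Int,)))    # precise argument type: 1 candidate
length(methods(halve, (Real,)))   # imprecise argument type: 2 candidates
```

With max_methods=1, the second call site would simply be treated as unknown instead of having both candidates inferred.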
Idea 2: speed up matching_methods
The inner loop of inference is looking up the set of method matches for a call site. Our procedure for doing that is fairly bulky. The majority of lookups (~60%) are actually for concrete types with a single method match, which could be handled by a much more efficient table lookup like what we use for method dispatch.
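A rough sketch of the fast path described above, written at user level rather than inside the compiler. The function `lookup_matches`, the cache layout, and `scale` are all assumptions made for illustration; `which`, `methods` and `isconcretetype` are real Base functions. The point is only that a fully concrete signature can be answered from a table after the first query.

```julia
# Hypothetical fast path: memoize the single match for concrete signatures.
function lookup_matches(cache::Dict, f, argtypes::Type{<:Tuple})
    if isconcretetype(argtypes)
        # Concrete argument types: exactly one method is selected by dispatch,
        # so a table lookup suffices once the entry exists.
        # (`which` throws if there is no unique applicable method.)
        m = get!(cache, (f, argtypes)) do
            which(f, argtypes)
        end
        return Method[m]
    end
    # Imprecise argument types: fall back to the general query.
    return collect(methods(f, argtypes))
end

# Made-up function with two methods.
scale(x::Int) = 2x
scale(x::Float64) = 2.0x

cache = Dict{Any,Method}()
lookup_matches(cache, scale, Tuple{Int})    # one match, now cached
lookup_matches(cache, scale, Tuple{Int})    # served from the table
lookup_matches(cache, scale, Tuple{Real})   # two matches, not cached
```

A real implementation would also have to account for world age and for new method definitions invalidating table entries, which this sketch ignores.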
Idea 3: cache matching_methods between inference and inlining
Inference calls _methods_by_ftype to determine matching methods for each call site. Then we do inlining, which calls it again for each call site to see if it can inline anything. We could add a per-statement array to cache those results.
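A toy model of the per-statement cache. Everything here is hypothetical: `make_counting_lookup` stands in for the expensive method-match query and counts how often it runs, and the two "passes" only mimic the order in which inference and inlining would touch each statement.

```julia
# Hypothetical stand-in for the expensive query; counts its invocations.
function make_counting_lookup()
    ncalls = Ref(0)
    lookup(sig) = (ncalls[] += 1; "matches for $sig")
    return lookup, ncalls
end

# "Inference": look up matches for every statement and remember them.
function infer_pass!(cache, stmts, lookup)
    for (i, sig) in enumerate(stmts)
        cache[i] = lookup(sig)
    end
    return cache
end

# "Inlining": reuse the stored result when there is one.
function inline_pass(cache, stmts, lookup)
    return [isassigned(cache, i) ? cache[i] : lookup(sig)
            for (i, sig) in enumerate(stmts)]
end

stmts = ["f(::Int)", "g(::String)", "h(::Float64)"]
lookup, ncalls = make_counting_lookup()
cache = Vector{Any}(undef, length(stmts))
infer_pass!(cache, stmts, lookup)
inline_pass(cache, stmts, lookup)
ncalls[]   # 3 lookups instead of 6
```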
Idea 4: bail out of long abstract call chains
As discussed in idea 1 above, inference sometimes goes on a wild goose chase where f calls g calls h and so on, for a very long chain, potentially all on Any, with nothing useful coming out of it. We could potentially detect that situation and bail out. I imagine we would throw back to the last "useful" function (e.g. one whose signature passes jl_isa_compileable_sig), recording Any as the type of the problematic call site.
This is a little tricky. Most importantly, note that the logic
    if call_depth > max
        return Any
    end
does not work, since it would make inference context-sensitive: we'd infer a different type for a function based on where in the call chain it occurs. So some care is needed.
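To see the context-sensitivity problem, here is a toy model (not the real inference code). The call graph f -> g -> h, the helpers `callee` and `leaf_return`, and the cutoff `MAX_DEPTH` are all made up for illustration.

```julia
# Toy program: f calls g, g calls h, and h returns an Int.
callee(fn::Symbol) = fn === :f ? :g : :h
leaf_return(fn::Symbol) = fn === :h ? Int : nothing

MAX_DEPTH = 1   # hypothetical cutoff

# Naive "inference" with the depth check from the snippet above.
function naive_infer(fn::Symbol, depth::Int=0)
    depth > MAX_DEPTH && return Any
    rt = leaf_return(fn)
    rt === nothing || return rt
    return naive_infer(callee(fn), depth + 1)
end

naive_infer(:h)   # Int: h inferred on its own
naive_infer(:g)   # Int: h is reached at depth 1
naive_infer(:f)   # Any: h is reached at depth 2 and cut off
```

The same function h yields Int or Any depending on where the query started, so a cached result for h would depend on which caller happened to trigger inference first.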
Idea 5: elide recursion checks in abstract_call_method
For each call site, we check whether it is recursive and therefore might need to be limited. That involves stepping back through the inference call stack. However, we might just be trying to look up some straightforward thing for which we already have a good cached inference result. It seems we should be able to elide the recursion check in that case. I haven't thought too much about this one though.
Idea 6: avoid inference work that only matters for compilation
Inference always completely analyzes the callee function, under the assumption that we may later need to optimize or inline it. However, in many cases, we don't end up optimizing or inlining it. Just computing the return type is often easy and can be substantially cheaper.
Idea 7: be able to save inference results in .ji files
E.g. #31466 #32705 (EDIT @timholy) works towards this. Could save a lot of inference time for precompiled packages.
Backedges part 1: reducing graph density
When f calls g, and inference on f depends on information about g, a "backedge" is stored from g to f, so that if g changes we can invalidate the inference result.
Backedges need to be stored in .ji files and processed on load, which takes time, and invalidating too many functions leads to costly re-inferring and compiling.
The idea here is to reduce the number of backedges somehow. For example, find call sites that are not "important" (e.g. they only happen on an error path) and infer them as Any to avoid storing a backedge.
Backedges part 2: precision
Currently any new method causes invalidation via any backedges with an overlapping signature. This could be much more precise if it could take inference results into account. For example, if all we infer about f(x) is that it returns an AbstractArray, and f(x) is dynamically dispatched, then new methods of f should have no effect as long as they also return AbstractArray.
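The check this suggests can be sketched as a subtype test. `needs_invalidation` is a hypothetical helper, not an existing compiler function, and it only covers the dynamically dispatched case described above.

```julia
# `inferred` is what callers assumed about the return type of f(x);
# `new_rt` is the inferred return type of a newly added method of f.
# Callers stay valid as long as the new method's result fits the old assumption.
needs_invalidation(inferred::Type, new_rt::Type) = !(new_rt <: inferred)

needs_invalidation(AbstractArray, Vector{Int})   # false: still an AbstractArray
needs_invalidation(AbstractArray, Int)           # true: assumption violated
```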
Backedges part 3: limit the propagation of invalidation
If a given MethodInstance gets invalidated, all of its callers (direct and indirect) also get invalidated. In many cases this seems wasteful: in particular, if the return type inference is the same and the MethodInstance was not inline_worthy, then it seems that it should be possible (in principle) to just update the immediate callers to call the new version, and thus break the (sometimes very long) chain of invalidations.
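A toy model of the difference, using a plain dictionary as the backedge graph. The function names and the graph layout are assumptions for illustration; the real data structures live in the runtime and look nothing like this.

```julia
# edges[callee] lists the callers whose inference depended on callee.
edges = Dict(:countelements => [:domath],
             :domath        => [:doubleit],
             :doubleit      => Symbol[])

# Current behavior: everything reachable through backedges is invalidated.
function invalidate_transitive(edges, changed)
    invalidated = Symbol[]
    worklist = copy(edges[changed])
    while !isempty(worklist)
        caller = pop!(worklist)
        caller in invalidated && continue
        push!(invalidated, caller)
        append!(worklist, edges[caller])
    end
    return invalidated
end

# Proposed behavior: if the return type is unchanged and the callee was not
# inlined, only the immediate callers need to be re-pointed at the new version.
function invalidate_limited(edges, changed; same_return_type::Bool, was_inlined::Bool)
    if same_return_type && !was_inlined
        return copy(edges[changed])
    end
    return invalidate_transitive(edges, changed)
end

invalidate_transitive(edges, :countelements)
invalidate_limited(edges, :countelements; same_return_type=true, was_inlined=false)
```

The first call returns both `:domath` and `:doubleit`; the second returns only `:domath`, which is where the chain would stop.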
(Cross post from slack, don't want to lose it.) Here's a call chain of 3 methods. We change the lowest one, and if we could rewrite the call from the middle method to call a different instance of the lowest method, we could stop there. This comes up frequently when implementing "fallbacks" and "specializations," where here I've implemented an O(N) fallback and an O(1) specialization:
julia> @noinline function countelements(iter)
           # In general, we just have to count them, O(N)
           n = 0
           for item in iter
               n += 1
           end
           return n
       end
countelements (generic function with 1 method)
julia> @noinline domath(x) = countelements(x)*5.0 - 1.5
domath (generic function with 1 method)
julia> doubleit(x) = 2*domath(x)
doubleit (generic function with 1 method)
julia> x = [1, 2, 3]
3-element Array{Int64,1}:
1
2
3
julia> doubleit(x)
27.0
julia> @code_llvm doubleit(x)
;  @ REPL[3]:1 within `doubleit'
define double @julia_doubleit_214(%jl_value_t* nonnull align 16 dereferenceable(40)) {
top:
  %1 = call double @j_domath_215(%jl_value_t* nonnull %0)
; ┌ @ promotion.jl:312 within `*' @ float.jl:405
   %2 = fmul double %1, 2.000000e+00
; └
  ret double %2
}
Now watch the invalidations when we define the O(1) algorithm:
julia> unsafe_store!(cglobal(:jl_debug_method_invalidation, Cint), 1)
Ptr{Int32} @0x00007f07d4824180
julia> @noinline countelements(iter::AbstractArray) = length(iter)
domath(Array{Int64, 1}) (in Main)
doubleit(Array{Int64, 1}) (in Main)
>> Main.countelements(...) Tuple{typeof(Main.countelements), AbstractArray{T, N} where N where T}
countelements (generic function with 2 methods)
julia> doubleit(x)
27.0
julia> @code_llvm doubleit(x)
;  @ REPL[3]:1 within `doubleit'
define double @julia_doubleit_249(%jl_value_t* nonnull align 16 dereferenceable(40)) {
top:
  %1 = call double @j_domath_250(%jl_value_t* nonnull %0)
; ┌ @ promotion.jl:312 within `*' @ float.jl:405
   %2 = fmul double %1, 2.000000e+00
; └
  ret double %2
}
You can see doubleit got recompiled when all we really had to do was change a single call in domath.
My preliminary sense is that this comes up quite a lot in practice. Currently we get a bunch of invalidations from loading FixedPointNumbers that stem from convert(Type{<:Bool}, ::Bool). Now, that might be fixed by other means, but in any event it would seem that once you fix the direct callers of this method you should be able to stop there.