I'd like to be able to call an arbitrary list of jitted functions inside a function:
from numba import jit

@jit
def func(a):
    return a

functions = (func, func)

@jit(nopython=True)
def aggregate_func(a):
    for f in functions:
        f(a)

aggregate_func(0)
But I get the following error:
Truncated Traceback (Use C-c C-x to view full TB):
/home/dcole/py3env/lib/python3.5/site-packages/numba/typeinfer.py in resolve_value_type(self, inst, val)
959 except ValueError as e:
960 msg = str(e)
--> 961 raise TypingError(msg, loc=inst.loc)
962
963 def typeof_arg(self, inst, target, arg):
TypingError: Failed at nopython (nopython frontend)
Untyped global name 'functions': cannot determine Numba type of <class 'tuple'>
File "<ipython-input-30-bd470e6412ee>", line 11
On the other hand, if I create the functions tuple inside aggregate_func, it works as expected:
@jit(nopython=True)
def aggregate_func(a):
    functions = (func, func)
    for f in functions:
        f(a)
Also, if the tuple consists of non-identical functions, it fails again:
from numba import jit

@jit
def func(a):
    return a

@jit
def gfunc(a):
    return a

@jit(nopython=True)
def aggregate_func(a):
    functions = (func, func, func, gfunc)
    for f in functions:
        f(a)

aggregate_func(0)
But with a different error:
Truncated Traceback (Use C-c C-x to view full TB):
/home/dcole/py3env/lib/python3.5/site-packages/numba/typeinfer.py in propagate(self, raise_errors)
765 if errors:
766 if raise_errors:
--> 767 raise errors[0]
768 else:
769 return errors
TypingError: Failed at nopython (nopython frontend)
Internal error at <numba.typeinfer.CallConstraint object at 0x7feb72c56320>:
--%<-----------------------------------------------------------------
Traceback (most recent call last):
File "/home/dcole/py3env/lib/python3.5/site-packages/numba/errors.py", line 243, in new_error_context
yield
File "/home/dcole/py3env/lib/python3.5/site-packages/numba/typeinfer.py", line 377, in __call__
fnty = typevars[self.func].getone()
File "/home/dcole/py3env/lib/python3.5/site-packages/numba/typeinfer.py", line 97, in getone
assert self.type is not None
AssertionError
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/dcole/py3env/lib/python3.5/site-packages/numba/typeinfer.py", line 128, in propagate
constraint(typeinfer)
File "/home/dcole/py3env/lib/python3.5/site-packages/numba/typeinfer.py", line 379, in __call__
self.resolve(typeinfer, typevars, fnty)
File "/usr/lib/python3.5/contextlib.py", line 77, in __exit__
self.gen.throw(type, value, traceback)
File "/home/dcole/py3env/lib/python3.5/site-packages/numba/errors.py", line 249, in new_error_context
six.reraise(type(newerr), newerr, sys.exc_info()[2])
File "/home/dcole/py3env/lib/python3.5/site-packages/numba/six.py", line 658, in reraise
raise value.with_traceback(tb)
File "/home/dcole/py3env/lib/python3.5/site-packages/numba/errors.py", line 243, in new_error_context
yield
File "/home/dcole/py3env/lib/python3.5/site-packages/numba/typeinfer.py", line 377, in __call__
fnty = typevars[self.func].getone()
File "/home/dcole/py3env/lib/python3.5/site-packages/numba/typeinfer.py", line 97, in getone
assert self.type is not None
numba.errors.InternalError:
[1] During: typing of call at <ipython-input-37-5cfd7eaa5b15> (18)
--%<-----------------------------------------------------------------
File "<ipython-input-37-5cfd7eaa5b15>", line 18
I'm guessing this simply isn't supported, so:
Function types are not first class yet. Neither the type system nor the runtime system handles first-class function objects, so functions must be statically known to the compiler.
There's a workaround, though. In the code below, I made a chain combinator that recursively builds a new function from the given list of functions. The chain combinator must be used outside of numba.jit. If your use case allows the functions to be defined separately first, this should help:
from numba import njit

@njit
def ident(x):
    return x

def chain(fs, inner=ident):
    # Peel off the last function and wrap it around what has been built so far.
    # The last function in fs is applied first: chain((f, g))(x) == f(g(x)).
    head, tail = fs[-1], fs[:-1]

    @njit
    def wrap(x):
        return head(inner(x))

    if tail:
        return chain(tail, wrap)
    else:
        return wrap

@njit
def foo(x):
    return x + 1.2

@njit
def bar(x):
    return x * 2

# must be used outside of the jit
foobar = chain((foo, bar))
foobarfoo = chain((foo, bar, foo))

@njit
def test():
    return foobar(3), foobarfoo(3)

print(test())
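The same closure trick can be adapted to the original question, where every function receives the same argument instead of being composed. The sketch below is only an illustration: `_sum_pair`, `sum_all`, and `run` are hypothetical names, summing the results is an arbitrary choice made so there is something to return, and the `ImportError` fallback is an assumption added so the snippet also runs without Numba installed.

```python
from functools import reduce

try:
    from numba import njit
except ImportError:
    # Assumption: fall back to plain Python so the sketch runs without Numba.
    def njit(f):
        return f


def _sum_pair(first, second):
    # first and second are closure variables, so each one is a
    # statically known function when `both` is compiled.
    @njit
    def both(a):
        return first(a) + second(a)
    return both


def sum_all(fs):
    # Hypothetical helper; like chain, it must be called outside of the jit.
    return reduce(_sum_pair, fs)


@njit
def foo(x):
    return x + 1.2


@njit
def bar(x):
    return x * 2


foo_plus_bar = sum_all((foo, bar))


@njit
def run():
    # Evaluates foo(3) + bar(3) inside jitted code.
    return foo_plus_bar(3)


print(run())
```

As with chain, the tuple of functions is consumed in ordinary Python, so the jitted code only ever sees individual functions that are already known.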
An interesting thing happens in:
@jit(nopython=True)
def aggregate_func(a):
    functions = (func, func)
    for f in functions:
        f(a)
functions is typed as tuple<typeof(func)>.
f is typed as getitem(tuple<typeof(func)>) -> typeof(func).
So the compiler knows exactly what f is.
But mixing functions won't work, because what f is cannot be deduced: it could be any of the mixed functions.
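To make that rule concrete, here is a toy model in plain Python. It is not Numba's actual type inference: `element_type` is a hypothetical helper that treats each distinct function object as its own "type" and only succeeds when the tuple holds a single one.

```python
def element_type(functions):
    """Return the one function 'type' shared by every element of the tuple.

    Toy model only: each distinct function object stands in for its own type.
    """
    kinds = []
    for f in functions:
        if not any(f is k for k in kinds):
            kinds.append(f)
    if len(kinds) != 1:
        raise TypeError(
            "cannot deduce a single type for f: %d candidates" % len(kinds)
        )
    return kinds[0]


def func(a):
    return a


def gfunc(a):
    return a


# Homogeneous tuple: f can only be func.
print(element_type((func, func)) is func)

# Mixed tuple: f could be func or gfunc, so no single type exists.
try:
    element_type((func, func, func, gfunc))
except TypeError as exc:
    print("mixed tuple rejected:", exc)
```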
Thanks a lot, @sklam!
I'm guessing it's not simple to make function types first class. Hopefully that will happen sometime soon.
@seibert, is this something that could be solved on its own, or can it only be done once function types are supported? In either case, how long do you think it would take a knowledgeable dev to do it?