Numba: scipy.linalg expm and solve_banded functions

Created on 28 Sep 2018 · 4 Comments · Source: numba/numba

Dear developers and users,
If possible, I'd like to see an example of using the scipy.linalg expm and solve_banded functions inside a Numba nopython-jitted function. (https://numba.pydata.org/numba-doc/latest/extending/high-level.html)

Labels: question, scipy

All 4 comments

Thanks for the query.
In general, SciPy functions are not supported by Numba. The supported features are listed in the Numba documentation.

__To end up with something with the highest performance__:
To obtain the functionality you require, scipy.linalg.expm would need to be entirely reimplemented in a form Numba can compile (perhaps using the Padé approximation). The banded solve is less effort, as it just requires a call to the LAPACK routine {s,d,c,z}gbsv, which is available from SciPy's Cython LAPACK exports:

In [12]: import scipy.linalg.cython_lapack as cla

In [13]: cla.__pyx_capi__['dgbsv']
Out[13]: <capsule object "void (int *, int *, int *, int *, __pyx_t_5scipy_6linalg_13cython_lapack_d *, int *, int *, __pyx_t_5scipy_6linalg_13cython_lapack_d *, int *, int *)" at 0x7fc8c54e9540>

From there, the numba.extending module provides a function for getting its address:

In [17]: import numba.extending as nbe

In [18]: nbe.get_cython_function_address('scipy.linalg.cython_lapack', 'dgbsv')
Out[18]: 140499892185344

This address can then be wrapped in an appropriately declared ctypes.CFUNCTYPE. An example is documented here: http://numba.pydata.org/numba-doc/latest/extending/high-level.html#importing-cython-functions

At this point, because Numba supports ctypes, this bound function can be used in a nopython jitted function (pass the pointers via the array .ctypes attribute).

__To end up with something that works well, with far less effort__:
Whilst reimplementing expm in Numba would probably result in a faster expm, SciPy's expm is probably quite fast already. Further, the performance of solve_banded is likely to be dominated by the LAPACK routine, so there may be little benefit in calling it from nopython mode. If you need these two functions as part of a more complicated algorithm, Numba 0.40.0 has a new feature, numba.objmode, a context manager that allows nopython mode to jump back into object mode (documentation is here). This lets the surrounding code run in nopython mode before and after those unsupported SciPy calls. Conceptually:


from numba import njit, objmode
import scipy.linalg

@njit
def func(matrix):
    # do manipulations on matrix, or whatever is needed here
    # then jump into object mode to run the SciPy functions
    with objmode(answer='float64[:]'):
        newmat1 = scipy.linalg.expm(matrix)
        # e.g. stage newmat1 for the call, supply RHS etc
        answer = scipy.linalg.solve_banded(...)
    # jump back into nopython mode
    # do something with answer etc
    return answer

Dear @stuartarchibald,
Thanks for explaining this to me. I just need to call SciPy functions inside a Numba-jitted function, and the code snippet you posted is exactly what I want. I will try this approach.

@sahaskn no problem, thanks for getting back to us, I'm glad this works for you.

@stuartarchibald, it works. Here is the snippet:

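from numba import njit, objmode
import numpy as np
import scipy.linalg as slg

@njit
def func(matrix):
    # run the unsupported SciPy call in object mode; y is declared with
    # its Numba type so nopython mode can use it after the block
    with objmode(y='float64[:,:]'):
        y = slg.expm(matrix)
    return y

a = np.asarray([[0.0, 1.0], [1.0, 0.0]])
np.allclose(func(a), slg.expm(a))  # True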
