from __future__ import absolute_import from __future__ import print_function import numpy as np import numba from numba import int32, int64, float64 class NumbaData(object): """A base class to extract data from derived class instances and use it as argument to @jit functions. Derived classes must set the 'signature' attribute.""" def data(self): """Returns a jitclass instance containing attributes copied from self, using self.signature which is of the form type,names,type,names ... where names is a string 'attr1 attr2 attr3' containing space-separated names of attributes of self. Those attributes are declared to numba with the type preceding them. The result of data() can be used as argument to a @jit function.""" cls = self.__class__.__name__ spec = [] for name in dir(self): attr = getattr(self,name) if not name.startswith('__') and not callable(attr): tp = None if isinstance(attr, int): tp=int64 if isinstance(attr, float): tp=float64 if isinstance(attr, np.ndarray): dtype=attr.dtype if dtype == np.int32 : dtype=int32 elif dtype == np.float64 : dtype=float64 else: dtype=None if dtype is None: print('Unknown dtype ', attr.dtype) if len(attr.shape)==1 : tp = dtype[:] elif len(attr.shape)==2 : tp = dtype[:,:] elif len(attr.shape)==3 : tp = dtype[:,:,:] else: print('%s.%s is a numpy array with unsupported rank >3'%(cls,name)) if tp is None: print('Type of attribute %s.%s is not recognized'%(cls,name), type(attr)) else: spec.append( (name, tp) ) @numba.jitclass(spec) class JitClass(object): def __init__(self): pass data=JitClass() for name,thetype in spec: print( 'Making %s.%s available @jit functions : '%(self.__class__.__name__, name), type(getattr(self,name))) setattr(data, name, getattr(self,name)) return data jit=numba.jit(nopython=True, nogil=True, error_model='numpy', fastmath=True)