1 | from __future__ import absolute_import |
---|
2 | from __future__ import print_function |
---|
3 | |
---|
4 | import numpy as np |
---|
5 | import numba |
---|
6 | from numba import int32, int64, float64 |
---|
7 | |
---|
class NumbaData(object):
    """Mixin that exposes an instance's plain data attributes to @jit functions.

    ``data()`` builds a numba jitclass whose fields mirror the public,
    non-callable attributes of ``self`` (numbers and numpy arrays of
    supported dtypes), copies their values across and returns the jitclass
    instance, which can be passed as an argument to a @jit function.
    Field types are inferred from the attribute values; ``data()`` does not
    read any ``signature`` attribute.
    """

    def data(self):
        """Return a jitclass instance holding copies of ``self``'s data attributes.

        Every name in ``dir(self)`` that does not start with ``'__'`` and whose
        value is not callable is considered:

        * ``int`` / numpy integer scalars are declared as ``int64``;
        * ``float`` / numpy floating scalars are declared as ``float64``;
        * numpy arrays of rank >= 1 with dtype int32, int64, float32, float64
          or bool are declared as arrays (layout ``'A'``) of the matching
          numba element type;
        * anything else is reported with a printed message and skipped
          (previously an unknown array dtype raised ``TypeError``).

        A new jitclass is defined on every call, so numba compiles a fresh
        class type each time.

        Returns:
            An instance of the generated jitclass whose fields hold the
            accepted attributes of ``self``.
        """
        # Newer numba releases provide jitclass in numba.experimental; older
        # ones only expose it as numba.jitclass.
        try:
            from numba.experimental import jitclass
        except ImportError:
            jitclass = numba.jitclass

        # (numpy scalar type, numba element type) pairs for array attributes.
        # Compared with ``==`` exactly like the original int32/float64 checks.
        array_dtypes = (
            (np.int32, int32),
            (np.int64, int64),
            (np.float32, numba.float32),
            (np.float64, float64),
            (np.bool_, numba.boolean),
        )

        cls = self.__class__.__name__
        spec = []
        values = {}
        for name in dir(self):
            # Test the name first so dunder attributes are never evaluated.
            if name.startswith('__'):
                continue
            attr = getattr(self, name)
            if callable(attr):
                continue

            tp = None
            if isinstance(attr, (int, np.integer)):
                # bool is an int subclass, so it also lands here (as before).
                tp = int64
            elif isinstance(attr, (float, np.floating)):
                tp = float64
            elif isinstance(attr, np.ndarray):
                dtype = next(
                    (nb_type for np_type, nb_type in array_dtypes
                     if attr.dtype == np_type),
                    None,
                )
                if dtype is None:
                    print('Unknown dtype ', attr.dtype)
                elif attr.ndim == 0:
                    print('%s.%s is a numpy array of rank 0, which is not supported'
                          % (cls, name))
                else:
                    # Same type as dtype[:], dtype[:, :], ... for any rank.
                    tp = numba.types.Array(dtype, attr.ndim, 'A')

            if tp is None:
                print('Type of attribute %s.%s is not recognized' % (cls, name), type(attr))
            else:
                spec.append((name, tp))
                # Keep the value so the copy loop does not call getattr again.
                values[name] = attr

        @jitclass(spec)
        class JitClass(object):
            def __init__(self):
                # Fields are assigned from Python after construction below.
                pass

        data = JitClass()
        for name, _ in spec:
            value = values[name]
            print('Making %s.%s available @jit functions : ' % (cls, name), type(value))
            setattr(data, name, value)
        return data
---|
46 | |
---|
# Preconfigured numba.jit decorator: nopython mode, GIL released,
# numpy-style error model and fast-math enabled.
jit = numba.jit(
    fastmath=True,
    error_model='numpy',
    nogil=True,
    nopython=True,
)
---|