source: codes/icosagcm/devel/Python/dev/numba.py @ 985

Last change on this file since 985 was 932, checked in by dubos, 5 years ago

devel/Python : automatic inference of Numba spec - signature is now obsolete

File size: 2.3 KB
Line 
1from __future__ import absolute_import     
2from __future__ import print_function
3
4import numpy as np
5import numba
6from numba import int32, int64, float64
7
class NumbaData(object):
    """Base class exposing the data attributes of derived-class instances to @jit functions.

    Derived classes do not need to declare anything: the numba spec is
    inferred automatically from the attributes found on the instance by
    data(). (The 'signature' attribute that used to be required is obsolete.)
    """

    def data(self):
        """Return a jitclass instance holding the data attributes of self.

        Every name in dir(self) that does not start with '__' and whose value
        is not callable is examined, and its numba type is inferred from the
        value:

        * int (including bool) and numpy integer scalars -> int64
        * float and numpy floating scalars -> float64
        * numpy arrays of rank 1 to 3 with dtype int32, int64 or float64 ->
          arrays of the corresponding numba element type

        Values of any other type are reported with a printed message and
        skipped. The returned object can be passed as an argument to a @jit
        function.
        """
        try:
            # Recent numba releases only provide jitclass under numba.experimental.
            from numba.experimental import jitclass
        except ImportError:
            jitclass = numba.jitclass

        cls = self.__class__.__name__
        spec = []
        values = {}  # attribute values, fetched once so type and value always agree
        for name in dir(self):
            attr = getattr(self, name)
            if name.startswith('__') or callable(attr):
                continue
            tp = NumbaData._numba_type(cls, name, attr)
            if tp is None:
                print('Type of attribute %s.%s is not recognized' % (cls, name), type(attr))
            else:
                spec.append((name, tp))
                values[name] = attr

        @jitclass(spec)
        class JitClass(object):
            def __init__(self):
                pass  # fields are filled in below, after construction

        data = JitClass()
        for name, _ in spec:
            print('Making %s.%s available @jit functions : ' % (cls, name), type(values[name]))
            setattr(data, name, values[name])
        return data

    @staticmethod
    def _numba_type(owner, name, attr):
        """Return the numba type to declare for value *attr*, or None if unsupported.

        *owner* (a class name) and *name* (the attribute name) are only used
        in diagnostic messages. Being a staticmethod, this helper is callable
        and is therefore ignored by data() when it scans dir(self).
        """
        if isinstance(attr, (int, np.integer)):
            return int64
        if isinstance(attr, (float, np.floating)):
            return float64
        if not isinstance(attr, np.ndarray):
            return None

        # Map the numpy element dtype to a numba scalar type (dtype == type compares like numpy does).
        for np_type, nb_type in ((np.int32, int32), (np.int64, int64), (np.float64, float64)):
            if attr.dtype == np_type:
                elem = nb_type
                break
        else:
            print('Unknown dtype ', attr.dtype)
            # No numba element type available: skip the attribute instead of indexing None.
            return None

        ndim = len(attr.shape)
        if ndim == 1:
            return elem[:]
        if ndim == 2:
            return elem[:, :]
        if ndim == 3:
            return elem[:, :, :]
        print('%s.%s is a numpy array with unsupported rank %d (only ranks 1 to 3 are supported)'
              % (owner, name, ndim))
        return None
46
# Shared decorator for @jit functions: nopython mode, GIL released while
# running, fast-math floating-point optimizations, and numba's 'numpy' error
# model (per the numba docs, e.g. division by zero yields inf/nan instead of
# raising ZeroDivisionError).
jit = numba.jit(
    nopython=True,
    nogil=True,
    fastmath=True,
    error_model='numpy',
)
Note: See TracBrowser for help on using the repository browser.