source: XMLIO_V2/external/include/blitz/array/eval.cc @ 80

Last change on this file since 80 was 80, checked in by ymipsl, 14 years ago

ajout lib externe

File size: 38.0 KB
Line 
1#ifndef BZ_ARRAYEVAL_CC
2#define BZ_ARRAYEVAL_CC
3
4#ifndef BZ_ARRAY_H
5 #error <blitz/array/eval.cc> must be included via <blitz/array.h>
6#endif
7
8BZ_NAMESPACE(blitz)
9
10/*
11 * Assign an expression to an array.  For performance reasons, there are
12 * several traversal mechanisms:
13 *
14 * - Index traversal scans through the destination array in storage order.
15 *   The expression is evaluated using a TinyVector<int,N> operand.  This
16 *   version is used only when there are index placeholders in the expression
17 *   (see <blitz/indexexpr.h>)
18 * - Stack traversal also scans through the destination array in storage
19 *   order.  However, push/pop stack iterators are used.
20 * - Fast traversal follows a Hilbert (or other) space-filling curve to
21 *   improve cache reuse for stencilling operations.  Currently, the
22 *   space filling curves must be generated by calling
23 *   generateFastTraversalOrder(TinyVector<int,N_dimensions>).
24 * - 2D tiled traversal follows a tiled traversal, to improve cache reuse
25 *   for 2D stencils.  Space filling curves have too much overhead to use
26 *   in two-dimensions.
27 *
28 * _bz_tryFastTraversal is a helper class.  Fast traversals are only
29 * attempted if the expression looks like a stencil -- it's at least
30 * three-dimensional, has at least six array operands, and there are
31 * no index placeholders in the expression.  These are all things which
32 * can be checked at compile time, so the if()/else() syntax has been
33 * replaced with this class template.
34 */
35
36// Fast traversals require <set> from the ISO/ANSI C++ standard library
37#ifdef BZ_HAVE_STD
38#ifdef BZ_ARRAY_SPACE_FILLING_TRAVERSAL
39
40template<bool canTryFastTraversal>
41struct _bz_tryFastTraversal {
42    template<typename T_numtype, int N_rank, typename T_expr, typename T_update>
43    static bool tryFast(Array<T_numtype,N_rank>& array, 
44        BZ_ETPARM(T_expr) expr, T_update)
45    {
46        return false;
47    }
48};
49
50template<>
51struct _bz_tryFastTraversal<true> {
52    template<typename T_numtype, int N_rank, typename T_expr, typename T_update>
53    static bool tryFast(Array<T_numtype,N_rank>& array, 
54        BZ_ETPARM(T_expr) expr, T_update)
55    {
56        // See if there's an appropriate space filling curve available.
57        // Currently fast traversals use an N-1 dimensional curve.  The
58        // Nth dimension column corresponding to each point on the curve
59        // is traversed in the normal fashion.
60        TraversalOrderCollection<N_rank-1> traversals;
61        TinyVector<int, N_rank - 1> traversalGridSize;
62
63        for (int i=0; i < N_rank - 1; ++i)
64            traversalGridSize[i] = array.length(array.ordering(i+1));
65
66#ifdef BZ_DEBUG_TRAVERSE
67cout << "traversalGridSize = " << traversalGridSize << endl;
68cout.flush();
69#endif
70
71        const TraversalOrder<N_rank-1>* order =
72            traversals.find(traversalGridSize);
73
74        if (order)
75        {
76#ifdef BZ_DEBUG_TRAVERSE
77    cerr << "Array<" << BZ_DEBUG_TEMPLATE_AS_STRING_LITERAL(T_numtype)
78         << ", " << N_rank << ">: Using stack traversal" << endl;
79#endif
80            // A curve was available -- use fast traversal.
81            array.evaluateWithFastTraversal(*order, expr, T_update());
82            return true;
83        }
84
85        return false;
86    }
87};
88
89#endif // BZ_ARRAY_SPACE_FILLING_TRAVERSAL
90#endif // BZ_HAVE_STD
91
92template<typename T_numtype, int N_rank> template<typename T_expr, typename T_update>
93inline Array<T_numtype, N_rank>& 
94Array<T_numtype, N_rank>::evaluate(T_expr expr, 
95    T_update)
96{
97    // Check that all arrays have the same shape
98#ifdef BZ_DEBUG
99    if (!expr.shapeCheck(shape()))
100    {
101      if (assertFailMode == false)
102      {
103        cerr << "[Blitz++] Shape check failed: Module " << __FILE__
104             << " line " << __LINE__ << endl
105             << "          Expression: ";
106        prettyPrintFormat format(true);   // Use terse formatting
107        BZ_STD_SCOPE(string) str;
108        expr.prettyPrint(str, format);
109        cerr << str << endl ;
110      }
111
112#if 0
113// Shape dumping is broken by change to using string for prettyPrint
114             << "          Shapes: " << shape() << " = ";
115        prettyPrintFormat format2;
116        format2.setDumpArrayShapesMode();
117        expr.prettyPrint(cerr, format2);
118        cerr << endl;
119#endif
120        BZ_PRE_FAIL;
121    }
122#endif
123
124    BZPRECHECK(expr.shapeCheck(shape()),
125        "Shape check failed." << endl << "Expression:");
126
127    BZPRECHECK((T_expr::rank == N_rank) || (T_expr::numArrayOperands == 0), 
128        "Assigned rank " << T_expr::rank << " expression to rank " 
129        << N_rank << " array.");
130
131    /*
132     * Check that the arrays are not empty (e.g. length 0 arrays)
133     * This fixes a bug found by Peter Bienstman, 6/16/99, where
134     * Array<double,2> A(0,0),B(0,0); B=A(tensor::j,tensor::i);
135     * went into an infinite loop.
136     */
137
138    if (numElements() == 0)
139        return *this;
140
141#ifdef BZ_DEBUG_TRAVERSE
142    cout << "T_expr::numIndexPlaceholders = " << T_expr::numIndexPlaceholders
143         << endl; 
144    cout.flush();
145#endif
146
147    // Tau profiling code.  Provide Tau with a pretty-printed version of
148    // the expression.
149    // NEEDS_WORK-- use a static initializer somehow.
150
151#ifdef BZ_TAU_PROFILING
152    static BZ_STD_SCOPE(string) exprDescription;
153    if (!exprDescription.length())   // faked static initializer
154    {
155        exprDescription = "A";
156        prettyPrintFormat format(true);   // Terse mode on
157        format.nextArrayOperandSymbol();
158        T_update::prettyPrint(exprDescription);
159        expr.prettyPrint(exprDescription, format);
160    }
161    TAU_PROFILE(" ", exprDescription, TAU_BLITZ);
162#endif
163
164    // Determine which evaluation mechanism to use
165    if (T_expr::numIndexPlaceholders > 0)
166    {
167        // The expression involves index placeholders, so have to
168        // use index traversal rather than stack traversal.
169
170        if (N_rank == 1)
171            return evaluateWithIndexTraversal1(expr, T_update());
172        else
173            return evaluateWithIndexTraversalN(expr, T_update());
174    }
175    else {
176
177        // If this expression looks like an array stencil, then attempt to
178        // use a fast traversal order.
179        // Fast traversals require <set> from the ISO/ANSI C++ standard
180        // library.
181
182#ifdef BZ_HAVE_STD
183#ifdef BZ_ARRAY_SPACE_FILLING_TRAVERSAL
184
185        enum { isStencil = (N_rank >= 3) && (T_expr::numArrayOperands > 6)
186            && (T_expr::numIndexPlaceholders == 0) };
187
188        if (_bz_tryFastTraversal<isStencil>::tryFast(*this, expr, T_update()))
189            return *this;
190
191#endif
192#endif
193
194#ifdef BZ_ARRAY_2D_STENCIL_TILING
195        // Does this look like a 2-dimensional stencil on a largeish
196        // array?
197
198        if ((N_rank == 2) && (T_expr::numArrayOperands >= 5))
199        {
200            // Use a heuristic to determine whether a tiled traversal
201            // is desirable.  First, estimate how much L1 cache is needed
202            // to achieve a high hit rate using the stack traversal.
203            // Try to err on the side of using tiled traversal even when
204            // it isn't strictly needed.
205
206            // Assumptions:
207            //    Stencil width 3
208            //    3 arrays involved in stencil
209            //    Uniform data type in arrays (all T_numtype)
210           
211            int cacheNeeded = 3 * 3 * sizeof(T_numtype) * length(ordering(0));
212            if (cacheNeeded > BZ_L1_CACHE_ESTIMATED_SIZE)
213                return evaluateWithTiled2DTraversal(expr, T_update());
214        }
215
216#endif
217
218        // If fast traversal isn't available or appropriate, then just
219        // do a stack traversal.
220        if (N_rank == 1)
221            return evaluateWithStackTraversal1(expr, T_update());
222        else
223            return evaluateWithStackTraversalN(expr, T_update());
224    }
225}
226
227template<typename T_numtype, int N_rank> template<typename T_expr, typename T_update>
228inline Array<T_numtype, N_rank>&
229Array<T_numtype, N_rank>::evaluateWithStackTraversal1(
230    T_expr expr, T_update)
231{
232#ifdef BZ_DEBUG_TRAVERSE
233    BZ_DEBUG_MESSAGE("Array<" << BZ_DEBUG_TEMPLATE_AS_STRING_LITERAL(T_numtype)
234         << ", " << N_rank << ">: Using stack traversal");
235#endif
236    FastArrayIterator<T_numtype, N_rank> iter(*this);
237    iter.loadStride(firstRank);
238    expr.loadStride(firstRank);
239
240    bool useUnitStride = iter.isUnitStride(firstRank)
241          && expr.isUnitStride(firstRank);
242
243#ifdef BZ_ARRAY_EXPR_USE_COMMON_STRIDE
244    int commonStride = expr.suggestStride(firstRank);
245    if (iter.suggestStride(firstRank) > commonStride)
246        commonStride = iter.suggestStride(firstRank);
247    bool useCommonStride = iter.isStride(firstRank,commonStride)
248        && expr.isStride(firstRank,commonStride);
249
250 #ifdef BZ_DEBUG_TRAVERSE
251    BZ_DEBUG_MESSAGE("BZ_ARRAY_EXPR_USE_COMMON_STRIDE:" << endl
252        << "    commonStride = " << commonStride << " useCommonStride = "
253        << useCommonStride);
254 #endif
255#else
256    int commonStride = 1;
257    bool useCommonStride = false;
258#endif
259
260    const T_numtype * last = iter.data() + length(firstRank) 
261        * stride(firstRank);
262
263    if (useUnitStride || useCommonStride)
264    {
265#ifdef BZ_USE_FAST_READ_ARRAY_EXPR
266
267#ifdef BZ_DEBUG_TRAVERSE
268    BZ_DEBUG_MESSAGE("BZ_USE_FAST_READ_ARRAY_EXPR with commonStride");
269#endif
270        int ubound = length(firstRank) * commonStride;
271        T_numtype* restrict data = const_cast<T_numtype*>(iter.data());
272
273        if (commonStride == 1)
274        {
275 #ifndef BZ_ARRAY_STACK_TRAVERSAL_UNROLL
276            for (int i=0; i < ubound; ++i)
277                T_update::update(*data++, expr.fastRead(i));
278 #else
279            int n1 = ubound & 3;
280            int i = 0;
281            for (; i < n1; ++i)
282                T_update::update(*data++, expr.fastRead(i));
283           
284            for (; i < ubound; i += 4)
285            {
286#ifndef BZ_ARRAY_STACK_TRAVERSAL_CSE_AND_ANTIALIAS
287                T_update::update(*data++, expr.fastRead(i));
288                T_update::update(*data++, expr.fastRead(i+1));
289                T_update::update(*data++, expr.fastRead(i+2));
290                T_update::update(*data++, expr.fastRead(i+3));
291#else
292                const int t1 = i+1;
293                const int t2 = i+2;
294                const int t3 = i+3;
295
296                _bz_typename T_expr::T_numtype tmp1, tmp2, tmp3, tmp4;
297
298                tmp1 = expr.fastRead(i);
299                tmp2 = expr.fastRead(BZ_NO_PROPAGATE(t1));
300                tmp3 = expr.fastRead(BZ_NO_PROPAGATE(t2));
301                tmp4 = expr.fastRead(BZ_NO_PROPAGATE(t3));
302
303                T_update::update(*data++, tmp1);
304                T_update::update(*data++, tmp2);
305                T_update::update(*data++, tmp3);
306                T_update::update(*data++, tmp4);
307#endif
308            }
309 #endif // BZ_ARRAY_STACK_TRAVERSAL_UNROLL
310
311        }
312 #ifdef BZ_ARRAY_EXPR_USE_COMMON_STRIDE
313        else {
314
315  #ifndef BZ_ARRAY_STACK_TRAVERSAL_UNROLL
316            for (int i=0; i != ubound; i += commonStride)
317                T_update::update(data[i], expr.fastRead(i));
318  #else
319            int n1 = (length(firstRank) & 3) * commonStride;
320
321            int i = 0;
322            for (; i != n1; i += commonStride)
323                T_update::update(data[i], expr.fastRead(i));
324
325            int strideInc = 4 * commonStride;
326            for (; i != ubound; i += strideInc)
327            {
328                T_update::update(data[i], expr.fastRead(i));
329                int i2 = i + commonStride;
330                T_update::update(data[i2], expr.fastRead(i2));
331                int i3 = i + 2 * commonStride;
332                T_update::update(data[i3], expr.fastRead(i3));
333                int i4 = i + 3 * commonStride;
334                T_update::update(data[i4], expr.fastRead(i4));
335            }
336  #endif  // BZ_ARRAY_STACK_TRAVERSAL_UNROLL
337        }
338 #endif  // BZ_ARRAY_EXPR_USE_COMMON_STRIDE
339
340#else   // ! BZ_USE_FAST_READ_ARRAY_EXPR
341
342#ifdef BZ_DEBUG_TRAVERSE
343    BZ_DEBUG_MESSAGE("Common stride, no fast read");
344#endif
345        while (iter.data() != last)
346        {
347            T_update::update(*const_cast<T_numtype*>(iter.data()), *expr);
348            iter.advance(commonStride);
349            expr.advance(commonStride);
350        }
351#endif
352    }
353    else {
354        while (iter.data() != last)
355        {
356            T_update::update(*const_cast<T_numtype*>(iter.data()), *expr);
357            iter.advance();
358            expr.advance();
359        }
360    }
361
362    return *this;
363}
364
365template<typename T_numtype, int N_rank> template<typename T_expr, typename T_update>
366inline Array<T_numtype, N_rank>&
367Array<T_numtype, N_rank>::evaluateWithStackTraversalN(
368    T_expr expr, T_update)
369{
370    /*
371     * A stack traversal replaces the usual nested loops:
372     *
373     * for (int i=A.lbound(firstDim); i <= A.ubound(firstDim); ++i)
374     *   for (int j=A.lbound(secondDim); j <= A.ubound(secondDim); ++j)
375     *     for (int k=A.lbound(thirdDim); k <= A.ubound(thirdDim); ++k)
376     *       A(i,j,k) = 0;
377     *
378     * with a stack data structure.  The stack allows this single
379     * routine to replace any number of nested loops.
380     *
381     * For each dimension (loop), these quantities are needed:
382     * - a pointer to the first element encountered in the loop
383     * - the stride associated with the dimension/loop
384     * - a pointer to the last element encountered in the loop
385     *
386     * The basic idea is that entering each loop is a "push" onto the
387     * stack, and exiting each loop is a "pop".  In practice, this
388     * routine treats accesses the stack in a random-access way,
389     * which confuses the picture a bit.  But conceptually, that's
390     * what is going on.
391     */
392
393    /*
394     * ordering(0) gives the dimension associated with the smallest
395     * stride (usually; the exceptions have to do with subarrays and
396     * are uninteresting).  We call this dimension maxRank; it will
397     * become the innermost "loop".
398     *
399     * Ordering the loops from ordering(N_rank-1) down to
400     * ordering(0) ensures that the largest stride is associated
401     * with the outermost loop, and the smallest stride with the
402     * innermost.  This is critical for good performance on
403     * cached machines.
404     */
405
406    const int maxRank = ordering(0);
407    // const int secondLastRank = ordering(1);
408
409    // Create an iterator for the array receiving the result
410    FastArrayIterator<T_numtype, N_rank> iter(*this);
411
412    // Set the initial stack configuration by pushing the pointer
413    // to the first element of the array onto the stack N times.
414
415    int i;
416    for (i=1; i < N_rank; ++i)
417    {
418        iter.push(i);
419        expr.push(i);
420    }
421
422    // Load the strides associated with the innermost loop.
423    iter.loadStride(maxRank);
424    expr.loadStride(maxRank);
425
426    /*
427     * Is the stride in the innermost loop equal to 1?  If so,
428     * we might take advantage of this and generate more
429     * efficient code.
430     */
431    bool useUnitStride = iter.isUnitStride(maxRank)
432                          && expr.isUnitStride(maxRank);
433
434    /*
435     * Do all array operands share a common stride in the innermost
436     * loop?  If so, we can generate more efficient code (but only
437     * if this optimization has been enabled).
438     */
439#ifdef BZ_ARRAY_EXPR_USE_COMMON_STRIDE
440    int commonStride = expr.suggestStride(maxRank);
441    if (iter.suggestStride(maxRank) > commonStride)
442        commonStride = iter.suggestStride(maxRank);
443    bool useCommonStride = iter.isStride(maxRank,commonStride)
444        && expr.isStride(maxRank,commonStride);
445
446#ifdef BZ_DEBUG_TRAVERSE
447    BZ_DEBUG_MESSAGE("BZ_ARRAY_EXPR_USE_COMMON_STRIDE" << endl
448        << "commonStride = " << commonStride << " useCommonStride = "
449        << useCommonStride);
450#endif
451
452#else
453    int commonStride = 1;
454    bool useCommonStride = false;
455#endif
456
457    /*
458     * The "last" array contains a pointer to the last element
459     * encountered in each "loop".
460     */
461    const T_numtype* last[N_rank];
462
463    // Set up the initial state of the "last" array
464    for (i=1; i < N_rank; ++i)
465        last[i] = iter.data() + length(ordering(i)) * stride(ordering(i));
466
467    int lastLength = length(maxRank);
468    int firstNoncollapsedLoop = 1;
469
470#ifdef BZ_COLLAPSE_LOOPS
471
472    /*
473     * This bit of code handles collapsing loops.  When possible,
474     * the N nested loops are converted into a single loop (basically,
475     * the N-dimensional array is treated as a long vector).
476     * This is important for cases where the length of the innermost
477     * loop is very small, for example a 100x100x3 array.
478     * If this code can't collapse all the loops into a single loop,
479     * it will collapse as many loops as possible starting from the
480     * innermost and working out.
481     */
482
483    // Collapse loops when possible
484    for (i=1; i < N_rank; ++i)
485    {
486        // Figure out which pair of loops we are considering combining.
487        int outerLoopRank = ordering(i);
488        int innerLoopRank = ordering(i-1);
489
490        /*
491         * The canCollapse() routines look at the strides and extents
492         * of the loops, and determine if they can be combined into
493         * one loop.
494         */
495
496        if (canCollapse(outerLoopRank,innerLoopRank) 
497          && expr.canCollapse(outerLoopRank,innerLoopRank))
498        {
499#ifdef BZ_DEBUG_TRAVERSE
500            cout << "Collapsing " << outerLoopRank << " and " 
501                 << innerLoopRank << endl;
502#endif
503            lastLength *= length(outerLoopRank);
504            firstNoncollapsedLoop = i+1;
505        }
506        else 
507            break;
508    }
509
510#endif // BZ_COLLAPSE_LOOPS
511
512    /*
513     * Now we actually perform the loops.  This while loop contains
514     * two parts: first, the innermost loop is performed.  Then we
515     * exit the loop, and pop our way down the stack until we find
516     * a loop that isn't completed.  We then restart the inner loops
517     * and push them onto the stack.
518     */
519
520    while (true) {
521
522        /*
523         * This bit of code handles the innermost loop.  It would look
524         * a lot simpler if it weren't for unit stride and common stride
525         * optimizations; these clutter up the code with multiple versions.
526         */
527
528        if ((useUnitStride) || (useCommonStride))
529        {
530#ifdef BZ_USE_FAST_READ_ARRAY_EXPR
531
532            /*
533             * The check for BZ_USE_FAST_READ_ARRAY_EXPR can probably
534             * be taken out.  This was put in place while the unit stride/
535             * common stride optimizations were being implemented and
536             * tested.
537             */
538
539            // Calculate the end of the innermost loop
540            int ubound = lastLength * commonStride;
541
542            /*
543             * This is a real kludge.  I didn't want to have to write
544             * a const and non-const version of FastArrayIterator, so I use a
545             * const iterator and cast away const.  This could
546             * probably be avoided with some trick, but the whole routine
547             * is ugly, so why bother.
548             */
549
550            T_numtype* restrict data = const_cast<T_numtype*>(iter.data());
551
552            /*
553             * BZ_NEEDS_WORK-- need to implement optional unrolling.
554             */
555            if (commonStride == 1)
556            {
557                for (int i=0; i < ubound; ++i)
558                    T_update::update(*data++, expr.fastRead(i));
559            }
560#ifdef BZ_ARRAY_EXPR_USE_COMMON_STRIDE
561            else {
562                for (int i=0; i != ubound; i += commonStride)
563                    T_update::update(data[i], expr.fastRead(i));
564            }
565#endif
566            /*
567             * Tidy up for the fact that we haven't actually been
568             * incrementing the iterators in the innermost loop, by
569             * faking it afterward.
570             */
571            iter.advance(lastLength * commonStride);
572            expr.advance(lastLength * commonStride);
573#else       
574            // !BZ_USE_FAST_READ_ARRAY_EXPR
575            // This bit of code not really needed; should remove at some
576            // point, along with the test for BZ_USE_FAST_READ_ARRAY_EXPR
577
578            T_numtype * restrict end = const_cast<T_numtype*>(iter.data()) 
579                + lastLength;
580
581            while (iter.data() != end) 
582            {
583                T_update::update(*const_cast<T_numtype*>(iter.data()), *expr);
584                iter.advance(commonStride);
585                expr.advance(commonStride);
586            }
587#endif
588        }
589        else {
590            /*
591             * We don't have a unit stride or common stride in the innermost
592             * loop.  This is going to hurt performance.  Luckily 95% of
593             * the time, we hit the cases above.
594             */
595            T_numtype * restrict end = const_cast<T_numtype*>(iter.data())
596                + lastLength * stride(maxRank);
597
598            while (iter.data() != end)
599            {
600                T_update::update(*const_cast<T_numtype*>(iter.data()), *expr);
601                iter.advance();
602                expr.advance();
603            }
604        }
605
606
607        /*
608         * We just finished the innermost loop.  Now we pop our way down
609         * the stack, until we hit a loop that hasn't completed yet.
610         */ 
611        int j = firstNoncollapsedLoop;
612        for (; j < N_rank; ++j)
613        {
614            // Get the next loop
615            int r = ordering(j);
616
617            // Pop-- this restores the data pointers to the first element
618            // encountered in the loop.
619            iter.pop(j);
620            expr.pop(j);
621
622            // Load the stride associated with this loop, and increment
623            // once.
624            iter.loadStride(r);
625            expr.loadStride(r);
626            iter.advance();
627            expr.advance();
628
629            // If we aren't at the end of this loop, then stop popping.
630            if (iter.data() != last[j])
631                break;
632        }
633
634        // Are we completely done?
635        if (j == N_rank)
636            break;
637
638        // No, so push all the inner loops back onto the stack.
639        for (; j >= firstNoncollapsedLoop; --j)
640        {
641            int r2 = ordering(j-1);
642            iter.push(j);
643            expr.push(j);
644            last[j-1] = iter.data() + length(r2) * stride(r2);
645        }
646
647        // Load the stride for the innermost loop again.
648        iter.loadStride(maxRank);
649        expr.loadStride(maxRank);
650    }
651
652    return *this;
653}
654
655template<typename T_numtype, int N_rank> template<typename T_expr, typename T_update>
656inline Array<T_numtype, N_rank>&
657Array<T_numtype, N_rank>::evaluateWithIndexTraversal1(
658    T_expr expr, T_update)
659{
660    TinyVector<int,N_rank> index;
661
662    if (stride(firstRank) == 1)
663    {
664        T_numtype * restrict iter = data_ + lbound(firstRank);
665        int last = ubound(firstRank);
666
667        for (index[0] = lbound(firstRank); index[0] <= last;
668            ++index[0])
669        {
670            T_update::update(*iter++, expr(index));
671        }
672    }
673    else {
674        FastArrayIterator<T_numtype, N_rank> iter(*this);
675        iter.loadStride(0);
676        int last = ubound(firstRank);
677
678        for (index[0] = lbound(firstRank); index[0] <= last;
679            ++index[0])
680        {
681            T_update::update(*const_cast<T_numtype*>(iter.data()), 
682                expr(index));
683            iter.advance();
684        }
685    }
686
687    return *this;
688}
689
690template<typename T_numtype, int N_rank> template<typename T_expr, typename T_update>
691inline Array<T_numtype, N_rank>&
692Array<T_numtype, N_rank>::evaluateWithIndexTraversalN(
693    T_expr expr, T_update)
694{
695    // Do a stack-type traversal for the destination array and use
696    // index traversal for the source expression
697   
698    const int maxRank = ordering(0);
699
700#ifdef BZ_DEBUG_TRAVERSE
701    const int secondLastRank = ordering(1);
702    cout << "Index traversal: N_rank = " << N_rank << endl;
703    cout << "maxRank = " << maxRank << " secondLastRank = " << secondLastRank
704         << endl;
705    cout.flush();
706#endif
707
708    FastArrayIterator<T_numtype, N_rank> iter(*this);
709    for (int i=1; i < N_rank; ++i)
710        iter.push(ordering(i));
711
712    iter.loadStride(maxRank);
713
714    TinyVector<int,N_rank> index, last;
715
716    index = storage_.base();
717
718    for (int i=0; i < N_rank; ++i)
719      last(i) = storage_.base(i) + length_(i);
720
721    // int lastLength = length(maxRank);
722
723    while (true) {
724
725        for (index[maxRank] = base(maxRank); 
726             index[maxRank] < last[maxRank]; 
727             ++index[maxRank])
728        {
729#ifdef BZ_DEBUG_TRAVERSE
730#if 0
731    cout << "(" << index[0] << "," << index[1] << ") " << endl;
732    cout.flush();
733#endif
734#endif
735
736            T_update::update(*const_cast<T_numtype*>(iter.data()), expr(index));
737            iter.advance();
738        }
739
740        int j = 1;
741        for (; j < N_rank; ++j)
742        {
743            iter.pop(ordering(j));
744            iter.loadStride(ordering(j));
745            iter.advance();
746
747            index[ordering(j-1)] = base(ordering(j-1));
748            ++index[ordering(j)];
749            if (index[ordering(j)] != last[ordering(j)])
750                break;
751        }
752
753        if (j == N_rank)
754            break;
755
756        for (; j > 0; --j)
757        {
758            iter.push(ordering(j));
759        }
760        iter.loadStride(maxRank);
761    }
762
763    return *this; 
764}
765
766// Fast traversals require <set> from the ISO/ANSI C++ standard library
767
768#ifdef BZ_HAVE_STD
769#ifdef BZ_ARRAY_SPACE_FILLING_TRAVERSAL
770
771template<typename T_numtype, int N_rank> template<typename T_expr, typename T_update>
772inline Array<T_numtype, N_rank>&
773Array<T_numtype, N_rank>::evaluateWithFastTraversal(
774    const TraversalOrder<N_rank - 1>& order, 
775    T_expr expr,
776    T_update)
777{
778    const int maxRank = ordering(0);
779
780#ifdef BZ_DEBUG_TRAVERSE
781    const int secondLastRank = ordering(1);
782    cerr << "maxRank = " << maxRank << " secondLastRank = " << secondLastRank
783         << endl;
784#endif
785
786    FastArrayIterator<T_numtype, N_rank> iter(*this);
787    iter.push(0);
788    expr.push(0);
789
790    bool useUnitStride = iter.isUnitStride(maxRank) 
791                          && expr.isUnitStride(maxRank);
792
793#ifdef BZ_ARRAY_EXPR_USE_COMMON_STRIDE
794    int commonStride = expr.suggestStride(maxRank);
795    if (iter.suggestStride(maxRank) > commonStride)
796        commonStride = iter.suggestStride(maxRank);
797    bool useCommonStride = iter.isStride(maxRank,commonStride)
798        && expr.isStride(maxRank,commonStride);
799#else
800    int commonStride = 1;
801    bool useCommonStride = false;
802#endif
803
804    int lastLength = length(maxRank);
805
806    for (int i=0; i < order.length(); ++i)
807    {
808        iter.pop(0);
809        expr.pop(0);
810
811#ifdef BZ_DEBUG_TRAVERSE
812    cerr << "Traversing: " << order[i] << endl;
813#endif
814        // Position the iterator at the start of the next column       
815        for (int j=1; j < N_rank; ++j)
816        {
817            iter.loadStride(ordering(j));
818            expr.loadStride(ordering(j));
819
820            int offset = order[i][j-1];
821            iter.advance(offset);
822            expr.advance(offset);
823        }
824
825        iter.loadStride(maxRank);
826        expr.loadStride(maxRank);
827
828        // Evaluate the expression along the column
829
830        if ((useUnitStride) || (useCommonStride))
831        {
832#ifdef BZ_USE_FAST_READ_ARRAY_EXPR
833            int ubound = lastLength * commonStride;
834            T_numtype* restrict data = const_cast<T_numtype*>(iter.data());
835
836            if (commonStride == 1)
837            {           
838 #ifndef BZ_ARRAY_FAST_TRAVERSAL_UNROLL
839                for (int i=0; i < ubound; ++i)
840                    T_update::update(*data++, expr.fastRead(i));
841 #else
842                int n1 = ubound & 3;
843                int i=0;
844                for (; i < n1; ++i)
845                    T_update::update(*data++, expr.fastRead(i));
846
847                for (; i < ubound; i += 4)
848                {
849                    T_update::update(*data++, expr.fastRead(i));
850                    T_update::update(*data++, expr.fastRead(i+1));
851                    T_update::update(*data++, expr.fastRead(i+2));
852                    T_update::update(*data++, expr.fastRead(i+3));
853                }
854 #endif  // BZ_ARRAY_FAST_TRAVERSAL_UNROLL
855            }
856 #ifdef BZ_ARRAY_EXPR_USE_COMMON_STRIDE
857            else {
858                for (int i=0; i < ubound; i += commonStride)
859                    T_update::update(data[i], expr.fastRead(i));
860            }
861 #endif // BZ_ARRAY_EXPR_USE_COMMON_STRIDE
862
863            iter.advance(lastLength * commonStride);
864            expr.advance(lastLength * commonStride);
865#else   // ! BZ_USE_FAST_READ_ARRAY_EXPR
866            T_numtype* restrict last = const_cast<T_numtype*>(iter.data()) 
867                + lastLength * commonStride;
868
869            while (iter.data() != last)
870            {
871                T_update::update(*const_cast<T_numtype*>(iter.data()), *expr);
872                iter.advance(commonStride);
873                expr.advance(commonStride);
874            }
875#endif  // BZ_USE_FAST_READ_ARRAY_EXPR
876
877        }
878        else {
879            // No common stride
880
881            T_numtype* restrict last = const_cast<T_numtype*>(iter.data()) 
882                + lastLength * stride(maxRank);
883
884            while (iter.data() != last)
885            {
886                T_update::update(*const_cast<T_numtype*>(iter.data()), *expr);
887                iter.advance();
888                expr.advance();
889            }
890        }
891    }
892
893    return *this;
894}
895
896#endif // BZ_ARRAY_SPACE_FILLING_TRAVERSAL
897#endif // BZ_HAVE_STD
898
899#ifdef BZ_ARRAY_2D_NEW_STENCIL_TILING
900
901#ifdef BZ_ARRAY_2D_STENCIL_TILING
902
903template<typename T_numtype, int N_rank> template<typename T_expr, typename T_update>
904inline Array<T_numtype, N_rank>& 
905Array<T_numtype, N_rank>::evaluateWithTiled2DTraversal(
906    T_expr expr, T_update)
907{
908    const int minorRank = ordering(0);
909    const int majorRank = ordering(1);
910
911    FastArrayIterator<T_numtype, N_rank> iter(*this);
912    iter.push(0);
913    expr.push(0);
914
915#ifdef BZ_2D_STENCIL_DEBUG
916    int count = 0;
917#endif
918
919    bool useUnitStride = iter.isUnitStride(minorRank)
920                          && expr.isUnitStride(minorRank);
921
922#ifdef BZ_ARRAY_EXPR_USE_COMMON_STRIDE
923    int commonStride = expr.suggestStride(minorRank);
924    if (iter.suggestStride(minorRank) > commonStride)
925        commonStride = iter.suggestStride(minorRank);
926    bool useCommonStride = iter.isStride(minorRank,commonStride)
927        && expr.isStride(minorRank,commonStride);
928#else
929    int commonStride = 1;
930    bool useCommonStride = false;
931#endif
932
933    // Determine if a common major stride exists
934    int commonMajorStride = expr.suggestStride(majorRank);
935    if (iter.suggestStride(majorRank) > commonMajorStride)
936        commonMajorStride = iter.suggestStride(majorRank);
937    bool haveCommonMajorStride = iter.isStride(majorRank,commonMajorStride)
938        && expr.isStride(majorRank,commonMajorStride);
939
940
941    int maxi = length(majorRank);
942    int maxj = length(minorRank);
943
944    const int tileHeight = 16, tileWidth = 3;
945
946    int bi, bj;
947    for (bi=0; bi < maxi; bi += tileHeight)
948    {
949        int ni = bi + tileHeight;
950        if (ni > maxi)
951            ni = maxi;
952
953        // Move back to the beginning of the array
954        iter.pop(0);
955        expr.pop(0);
956
957        // Move to the start of this tile row
958        iter.loadStride(majorRank);
959        iter.advance(bi);
960        expr.loadStride(majorRank);
961        expr.advance(bi);
962
963        // Save this position
964        iter.push(1);
965        expr.push(1);
966
967        for (bj=0; bj < maxj; bj += tileWidth)
968        {
969            // Move to the beginning of the tile row
970            iter.pop(1);
971            expr.pop(1);
972
973            // Move to the top of the current tile (bi,bj)
974            iter.loadStride(minorRank);
975            iter.advance(bj);
976            expr.loadStride(minorRank);
977            expr.advance(bj);
978
979            if (bj + tileWidth <= maxj)
980            {
981                // Strip mining
982
983                if ((useUnitStride) && (haveCommonMajorStride))
984                {
985                    int offset = 0;
986                    T_numtype* restrict data = const_cast<T_numtype*>
987                        (iter.data());
988
989                    for (int i=bi; i < ni; ++i)
990                    {
991                        _bz_typename T_expr::T_numtype tmp1, tmp2, tmp3;
992
993                        // Common subexpression elimination -- compilers
994                        // won't necessarily do this on their own.
995                        int t1 = offset+1;
996                        int t2 = offset+2;
997
998                        tmp1 = expr.fastRead(offset);
999                        tmp2 = expr.fastRead(t1);
1000                        tmp3 = expr.fastRead(t2);
1001
1002                        T_update::update(data[0], tmp1);
1003                        T_update::update(data[1], tmp2);
1004                        T_update::update(data[2], tmp3);
1005
1006                        offset += commonMajorStride;
1007                        data += commonMajorStride;
1008
1009#ifdef BZ_2D_STENCIL_DEBUG
1010    count += 3;
1011#endif
1012                    }
1013                }
1014                else {
1015
1016                    for (int i=bi; i < ni; ++i)
1017                    {
1018                        iter.loadStride(minorRank);
1019                        expr.loadStride(minorRank);
1020
1021                        // Loop through current row elements
1022                        T_update::update(*const_cast<T_numtype*>(iter.data()),
1023                            *expr);
1024                        iter.advance();
1025                        expr.advance();
1026
1027                        T_update::update(*const_cast<T_numtype*>(iter.data()),
1028                            *expr);
1029                        iter.advance();
1030                        expr.advance();
1031
1032                        T_update::update(*const_cast<T_numtype*>(iter.data()),
1033                            *expr);
1034                        iter.advance(-2);
1035                        expr.advance(-2);
1036
1037                        iter.loadStride(majorRank);
1038                        expr.loadStride(majorRank);
1039                        iter.advance();
1040                        expr.advance();
1041
1042#ifdef BZ_2D_STENCIL_DEBUG
1043    count += 3;
1044#endif
1045
1046                    }
1047                }
1048            }
1049            else {
1050
1051                // This code handles partial tiles at the bottom of the
1052                // array.
1053
1054                for (int j=bj; j < maxj; ++j)
1055                {
1056                    iter.loadStride(majorRank);
1057                    expr.loadStride(majorRank);
1058
1059                    for (int i=bi; i < ni; ++i)
1060                    {
1061                        T_update::update(*const_cast<T_numtype*>(iter.data()),
1062                            *expr);
1063                        iter.advance();
1064                        expr.advance();
1065#ifdef BZ_2D_STENCIL_DEBUG
1066    ++count;
1067#endif
1068
1069                    }
1070
1071                    // Move back to the top of this column
1072                    iter.advance(bi-ni);
1073                    expr.advance(bi-ni);
1074
1075                    // Move over to the next column
1076                    iter.loadStride(minorRank);
1077                    expr.loadStride(minorRank);
1078
1079                    iter.advance();
1080                    expr.advance();
1081                }
1082            }
1083        }
1084    }
1085
1086#ifdef BZ_2D_STENCIL_DEBUG
1087    cout << "BZ_2D_STENCIL_DEBUG: count = " << count << endl;
1088#endif
1089
1090    return *this;
1091}
1092
1093#endif // BZ_ARRAY_2D_STENCIL_TILING
1094#endif // BZ_ARRAY_2D_NEW_STENCIL_TILING
1095
1096
1097
1098#ifndef BZ_ARRAY_2D_NEW_STENCIL_TILING
1099
1100#ifdef BZ_ARRAY_2D_STENCIL_TILING
1101
1102template<typename T_numtype, int N_rank> template<typename T_expr, typename T_update>
1103inline Array<T_numtype, N_rank>& 
1104Array<T_numtype, N_rank>::evaluateWithTiled2DTraversal(
1105    T_expr expr, T_update)
1106{
1107    const int minorRank = ordering(0);
1108    const int majorRank = ordering(1);
1109
1110    const int blockSize = 16;
1111   
1112    FastArrayIterator<T_numtype, N_rank> iter(*this);
1113    iter.push(0);
1114    expr.push(0);
1115
1116    bool useUnitStride = iter.isUnitStride(minorRank)
1117                          && expr.isUnitStride(minorRank);
1118
1119#ifdef BZ_ARRAY_EXPR_USE_COMMON_STRIDE
1120    int commonStride = expr.suggestStride(minorRank);
1121    if (iter.suggestStride(minorRank) > commonStride)
1122        commonStride = iter.suggestStride(minorRank);
1123    bool useCommonStride = iter.isStride(minorRank,commonStride)
1124        && expr.isStride(minorRank,commonStride);
1125#else
1126    int commonStride = 1;
1127    bool useCommonStride = false;
1128#endif
1129
1130    int maxi = length(majorRank);
1131    int maxj = length(minorRank);
1132
1133    int bi, bj;
1134    for (bi=0; bi < maxi; bi += blockSize)
1135    {
1136        int ni = bi + blockSize;
1137        if (ni > maxi)
1138            ni = maxi;
1139
1140        for (bj=0; bj < maxj; bj += blockSize)
1141        {
1142            int nj = bj + blockSize;
1143            if (nj > maxj)
1144                nj = maxj;
1145
1146            // Move to the beginning of the array
1147            iter.pop(0);
1148            expr.pop(0);
1149
1150            // Move to the beginning of the tile (bi,bj)
1151            iter.loadStride(majorRank);
1152            iter.advance(bi);
1153            iter.loadStride(minorRank);
1154            iter.advance(bj);
1155
1156            expr.loadStride(majorRank);
1157            expr.advance(bi);
1158            expr.loadStride(minorRank);
1159            expr.advance(bj);
1160
1161            // Loop through tile rows
1162            for (int i=bi; i < ni; ++i)
1163            {
1164                // Save the beginning of this tile row
1165                iter.push(1);
1166                expr.push(1);
1167
1168                // Load the minor stride
1169                iter.loadStride(minorRank);
1170                expr.loadStride(minorRank);
1171
1172                if (useUnitStride)
1173                {
1174                    T_numtype* restrict data = const_cast<T_numtype*>
1175                        (iter.data());
1176
1177                    int ubound = (nj-bj);
1178                    for (int j=0; j < ubound; ++j)
1179                        T_update::update(*data++, expr.fastRead(j));
1180                }
1181#ifdef BZ_ARRAY_EXPR_USE_COMMON_STRIDE
1182                else if (useCommonStride)
1183                {
1184                    int ubound = (nj-bj) * commonStride;
1185                    T_numtype* restrict data = const_cast<T_numtype*>
1186                        (iter.data());
1187
1188                    for (int j=0; j < ubound; j += commonStride)
1189                        T_update::update(data[j], expr.fastRead(j));
1190                }
1191#endif
1192                else {
1193                    for (int j=bj; j < nj; ++j)
1194                    {
1195                        // Loop through current row elements
1196                        T_update::update(*const_cast<T_numtype*>(iter.data()), 
1197                            *expr);
1198                        iter.advance();
1199                        expr.advance();
1200                    }
1201                }
1202
1203                // Move back to the beginning of the tile row, then
1204                // move to the next row
1205                iter.pop(1);
1206                iter.loadStride(majorRank);
1207                iter.advance(1);
1208
1209                expr.pop(1);
1210                expr.loadStride(majorRank);
1211                expr.advance(1);
1212            }
1213        }
1214    }
1215
1216    return *this;
1217}
1218#endif // BZ_ARRAY_2D_STENCIL_TILING
1219#endif // BZ_ARRAY_2D_NEW_STENCIL_TILING
1220
1221BZ_NAMESPACE_END
1222
1223#endif // BZ_ARRAYEVAL_CC
1224
Note: See TracBrowser for help on using the repository browser.