source: XMLIO_V2/external/include/blitz/array/cgsolve.h @ 80

Last change on this file since 80 was 80, checked in by ymipsl, 14 years ago

ajout lib externe

  • Property svn:eol-style set to native
File size: 4.0 KB
Line 
1/***************************************************************************
2 * blitz/array/cgsolve.h  Basic conjugate gradient solver for linear systems
3 *
4 * Copyright (C) 1997-2001 Todd Veldhuizen <tveldhui@oonumerics.org>
5 *
6 * This program is free software; you can redistribute it and/or
7 * modify it under the terms of the GNU General Public License
8 * as published by the Free Software Foundation; either version 2
9 * of the License, or (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 * GNU General Public License for more details.
15 *
16 * Suggestions:          blitz-dev@oonumerics.org
17 * Bugs:                 blitz-bugs@oonumerics.org
18 *
19 * For more information, please see the Blitz++ Home Page:
20 *    http://oonumerics.org/blitz/
21 *
22 ****************************************************************************/
23#ifndef BZ_CGSOLVE_H
24#define BZ_CGSOLVE_H
25
26BZ_NAMESPACE(blitz)
27
28template<typename T_numtype>
29void dump(const char* name, Array<T_numtype,3>& A)
30{
31    T_numtype normA = 0;
32
33    for (int i=A.lbound(0); i <= A.ubound(0); ++i)
34    {
35      for (int j=A.lbound(1); j <= A.ubound(1); ++j)
36      {
37        for (int k=A.lbound(2); k <= A.ubound(2); ++k)
38        {
39            T_numtype tmp = A(i,j,k);
40            normA += ::fabs(tmp);
41        }
42      }
43    }
44
45    normA /= A.numElements();
46    cout << "Average magnitude of " << name << " is " << normA << endl;
47}
48
49template<typename T_stencil, typename T_numtype, int N_rank, typename T_BCs>
50int conjugateGradientSolver(T_stencil stencil,
51    Array<T_numtype,N_rank>& x,
52    Array<T_numtype,N_rank>& rhs, double haltrho, 
53    const T_BCs& boundaryConditions)
54{
55    // NEEDS_WORK: only apply CG updates over interior; need to handle
56    // BCs separately.
57
58    // x = unknowns being solved for (initial guess assumed)
59    // r = residual
60    // p = descent direction for x
61    // q = descent direction for r
62
63    RectDomain<N_rank> interior = interiorDomain(stencil, x, rhs);
64
65cout << "Interior: " << interior.lbound() << ", " << interior.ubound()
66     << endl;
67
68    // Calculate initial residual
69    Array<T_numtype,N_rank> r = rhs.copy();
70    r *= -1.0;
71
72    boundaryConditions.applyBCs(x);
73
74    applyStencil(stencil, r, x);
75
76 dump("r after stencil", r);
77 cout << "Slice through r: " << endl << r(23,17,Range::all()) << endl;
78 cout << "Slice through x: " << endl << x(23,17,Range::all()) << endl;
79 cout << "Slice through rhs: " << endl << rhs(23,17,Range::all()) << endl;
80
81    r *= -1.0;
82
83 dump("r", r);
84
85    // Allocate the descent direction arrays
86    Array<T_numtype,N_rank> p, q;
87    allocateArrays(x.shape(), p, q);
88
89    int iteration = 0;
90    int converged = 0;
91    T_numtype rho = 0.;
92    T_numtype oldrho = 0.;
93
94    const int maxIterations = 1000;
95
96    // Get views of interior of arrays (without boundaries)
97    Array<T_numtype,N_rank> rint = r(interior);
98    Array<T_numtype,N_rank> pint = p(interior);
99    Array<T_numtype,N_rank> qint = q(interior);
100    Array<T_numtype,N_rank> xint = x(interior);
101
102    while (iteration < maxIterations)
103    {
104        rho = sum(r * r);
105
106        if ((iteration % 20) == 0)
107            cout << "CG: Iter " << iteration << "\t rho = " << rho << endl;
108
109        // Check halting condition
110        if (rho < haltrho)
111        {
112            converged = 1;
113            break;
114        }
115
116        if (iteration == 0)
117        {
118            p = r;
119        }
120        else {
121            T_numtype beta = rho / oldrho;
122            p = beta * p + r;
123        }
124
125        q = 0.;
126//        boundaryConditions.applyBCs(p);
127        applyStencil(stencil, q, p);
128
129        T_numtype pq = sum(p*q);
130
131        T_numtype alpha = rho / pq;
132
133        x += alpha * p;
134        r -= alpha * q;
135
136        oldrho = rho;
137        ++iteration;
138    }
139
140    if (!converged)
141        cout << "Warning: CG solver did not converge" << endl;
142
143    return iteration;
144}
145
146BZ_NAMESPACE_END
147
148#endif // BZ_CGSOLVE_H
Note: See TracBrowser for help on using the repository browser.