source: codes/icosagcm/devel/Python/test/py/partition.py @ 672

Last change on this file since 672 was 672, checked in by dubos, 6 years ago

devel/unstructured : Fix partitioning

File size: 7.4 KB
Line 
1print 'Starting'
2
3from mpi4py import MPI
4comm = MPI.COMM_WORLD
5mpi_rank, mpi_size = comm.Get_rank(), comm.Get_size()
6print '%d/%d starting'%(mpi_rank,mpi_size)
7
8#from dynamico import partition
9from dynamico import parallel, meshes
10from dynamico import unstructured as unst
11
12import math as math
13import numpy as np
14import netCDF4 as cdf
15
16import matplotlib.pyplot as plt
17from matplotlib.patches import Polygon
18from matplotlib.collections import PatchCollection
19
20print 'Done loading modules'
21
22#----------------- partition hand-written 15-cell mesh ------------------#
23
24if mpi_size<15:
25    send=np.random.randn(mpi_size)
26    recv=np.zeros(mpi_size)
27    comm.Alltoall(send,recv)
28    #time.sleep(mpi_rank)
29#    print mpi_rank, send, recv
30   
31    adjncy=[1, 5, 0, 2, 6, 1, 3, 7, 2, 4, 8, 3, 9, 
32              0, 6, 10, 1, 5, 7, 11, 2, 6, 8, 12, 3, 7, 9, 13, 4, 8, 14,
33              5, 11, 6, 10, 12, 7, 11, 13, 8, 12, 14, 9, 13 ]
34    xadj=[0, 2, 5, 8, 11, 13, 16, 20, 24, 28, 31, 33, 36, 39, 42, 44]
35   
36    nb_vert = len(xadj)-1
37    vtxdist = [i*nb_vert/mpi_size for i in range(mpi_size+1)]
38    xadj, adjncy, vtxdist = [np.asarray(x,np.int32) for x in xadj,adjncy,vtxdist]
39   
40    idx_start = vtxdist[mpi_rank]
41    idx_end = vtxdist[mpi_rank+1]
42    nb_vert = idx_end - idx_start
43   
44    xadj_loc = xadj[idx_start:idx_end+1]-xadj[idx_start]
45    adjncy_loc = adjncy[ xadj[idx_start]:xadj[idx_end] ]
46    part = 0*xadj_loc[0:-1];
47
48    unst.ker.dynamico_partition_graph(mpi_rank, mpi_size, vtxdist, xadj_loc, adjncy_loc, 4, part)
49
50#    for i in range(len(part)):
51#        print 'vertex', i+idx_start, 'proc', part[i]
52
53#-----------------------------------------------------------------------------#
54#---------------         partition and plot MPAS mesh       ------------------#
55#-----------------------------------------------------------------------------#
56
57# Helper functions to plot unstructured graph
58
59def patches(degree, bounds, lon, lat):
60    for i in range(degree.size):
61        nb_edge=degree[i]
62        bounds_cell = bounds[i,0:nb_edge]
63        lat_cell    = lat[bounds_cell]
64        lon_cell    = lon[bounds_cell]
65        orig=lon_cell[0]
66        lon_cell    = lon_cell-orig+180.
67        lon_cell    = np.mod(lon_cell,360.)
68        lon_cell    = lon_cell+orig-180.
69#        if np.abs(lon_cell-orig).max()>10. :
70#            print '%d patches :'%mpi_rank, lon_cell
71        lonlat_cell = np.zeros((nb_edge,2))
72        lonlat_cell[:,0],lonlat_cell[:,1] = lon_cell,lat_cell
73        polygon = Polygon(lonlat_cell, True)
74        yield polygon
75
76def plot_mesh(ax, clim, degree, bounds, lon, lat, data):
77    nb_vertex = lon.size # global
78    p = list(patches(degree, bounds, lon, lat))
79    print '%d : plot_mesh %d %d %d'%( mpi_rank, degree.size, len(p), len(data) ) 
80    p = PatchCollection(p, linewidth=0.01)
81    p.set_array(data) # set values at each polygon (cell)
82    p.set_clim(clim)
83    ax.add_collection(p)
84
85def local_mesh(get_mycells):
86    mydegree, mybounds = [get_mycells(x) for x in nEdgesOnCell, verticesOnCell]
87    print '%d : len(mydegree)=%d'%(mpi_rank, len(mydegree))
88    vertex_list = sorted(set(unst.list_stencil(mydegree,mybounds))) 
89    print '%d : len(vertex_list))=%d'%(mpi_rank, len(vertex_list))
90    get_myvertices = parallel.Get_Indices(dim_vertex, vertex_list)
91    mylon, mylat = [get_myvertices(x)*180./math.pi for x in lonVertex, latVertex]
92    vertex_dict = parallel.inverse_list(vertex_list)
93    meshes.reindex(vertex_dict, mydegree, mybounds)
94    return vertex_list, mydegree, mybounds, mylon, mylat
95
96#--------------- read MPAS grid file ---------------#
97
98#grid = 'x1.2562'
99grid = 'x1.10242'
100#grid = 'x4.163842'
101print 'Reading MPAS file %s ...'%grid
102
103nc = cdf.Dataset('grids/%s.grid.nc'%grid, "r")
104dim_cell, dim_edge, dim_vertex = [
105    parallel.PDim(nc.dimensions[name], comm) 
106    for name in 'nCells','nEdges','nVertices']
107edge_degree   = parallel.CstPArray1D(dim_edge, np.int32, 2)
108vertex_degree = parallel.CstPArray1D(dim_vertex, np.int32, 3)
109nEdgesOnCell, verticesOnCell, edgesOnCell, cellsOnCell, latCell = [
110    parallel.PArray(dim_cell, nc.variables[var])
111    for var in 'nEdgesOnCell', 'verticesOnCell', 'edgesOnCell', 'cellsOnCell', 'latCell' ]
112cellsOnVertex, edgesOnVertex, kiteAreasOnVertex, lonVertex, latVertex = [
113    parallel.PArray(dim_vertex, nc.variables[var])
114    for var in 'cellsOnVertex', 'edgesOnVertex', 'kiteAreasOnVertex', 'lonVertex', 'latVertex']
115nEdgesOnEdge, cellsOnEdge, edgesOnEdge, verticesOnEdge, weightsOnEdge = [
116    parallel.PArray(dim_edge, nc.variables[var])
117    for var in 'nEdgesOnEdge', 'cellsOnEdge', 'edgesOnEdge', 'verticesOnEdge', 'weightsOnEdge']
118
119# Indices start at 0 on the C/Python side and at 1 on the Fortran/MPAS side
120# hence an offset of 1 is added/substracted where needed.
121for x in (verticesOnCell, edgesOnCell, cellsOnCell, cellsOnVertex, edgesOnVertex,
122          cellsOnEdge, edgesOnEdge, verticesOnEdge) : x.data = x.data-1
123edge2cell, cell2edge, edge2vertex, vertex2edge, cell2cell, edge2edge = [
124    meshes.Stencil_glob(a,b) for a,b in 
125    (edge_degree, cellsOnEdge), (nEdgesOnCell, edgesOnCell),
126    (edge_degree, verticesOnEdge), (vertex_degree, edgesOnVertex),
127    (nEdgesOnCell, cellsOnCell), (nEdgesOnEdge, edgesOnEdge) ]
128
129#---------------- partition edges and cells ------------------#
130
131print 'Partitioning ...'
132
133edge_owner = unst.partition_mesh(nEdgesOnEdge, edgesOnEdge, mpi_size)
134edge_owner = parallel.LocPArray1D(dim_edge, edge_owner)
135cell_owner = meshes.partition_from_stencil(edge_owner, nEdgesOnCell, edgesOnCell)
136cell_owner = parallel.LocPArray1D(dim_cell, cell_owner)
137
138#--------------------- construct halos  -----------------------#
139
140print 'Constructing halos ...'
141
142def chain(start, links):
143    for link in links:
144        start = link(start).neigh_set
145        yield start
146
147edges_E0 = meshes.find_my_cells(edge_owner)
148cells_C0, edges_E1, vertices_V1, edges_E2, cells_C1 = chain(
149    edges_E0, ( edge2cell, cell2edge, edge2vertex, vertex2edge, edge2cell) )
150
151edges_E0, edges_E1, edges_E2 = meshes.progressive_list(edges_E0, edges_E1, edges_E2)
152cells_C0, cells_C1 = meshes.progressive_list(cells_C0, cells_C1)
153
154print 'E2,E1,E0 ; C1,C0 : ', map(len, (edges_E2, edges_E1, edges_E0, cells_C1, cells_C0))
155
156#com_edges = parallel.Halo_Xchange(24, dim_edge, edges_E2, dim_edge.get(edges_E2, edge_owner))
157
158mycells, halo_cells = cells_C0, cells_C1
159get_mycells, get_halo_cells = dim_cell.getter(mycells), dim_cell.getter(halo_cells)
160com_cells = parallel.Halo_Xchange(42, dim_cell, halo_cells, get_halo_cells(cell_owner))
161
162local_num, total_num = np.zeros(1), np.zeros(1)
163local_num[0]=com_cells.own_len
164comm.Reduce(local_num, total_num, op=MPI.SUM, root=0)
165if(mpi_rank==0): print 'total num :', total_num[0], dim_cell.n
166
167#---------------------------- plot -----------------------------#
168
169if True:
170    print 'Plotting ...'
171
172    halo_vertex_list, mydegree, mybounds, mylon, mylat = local_mesh(com_cells.get_all)
173    buf = parallel.LocalArray1(com_cells)
174
175    fig, ax = plt.subplots()
176    buf.read_own(latCell) # reads only own values
177    buf.data = np.cos(10.*buf.data)
178    buf.update() # updates halo
179    plot_mesh(ax,[-math.pi/2,math.pi/2], mydegree, mybounds, mylon, mylat, buf.data)
180    plt.xlim(-190.,190.)
181    plt.ylim(-90.,90.)
182    plt.savefig('fig_partition/A%03d.pdf'%mpi_rank, dpi=1600)
183
184    fig, ax = plt.subplots()
185    buf.read_own(cell_owner)
186    buf.update()
187    plot_mesh(ax,[0,mpi_rank+1], mydegree, mybounds, mylon, mylat, buf.data)
188    plt.xlim(-190.,190.)
189    plt.ylim(-90.,90.)
190    plt.savefig('fig_partition/B%03d.pdf'%mpi_rank, dpi=1600)
Note: See TracBrowser for help on using the repository browser.