source: codes/icosagcm/devel/Python/test/py/partition.py @ 798

Last change on this file since 798 was 760, checked in by dubos, 6 years ago

devel/Python : block-wise partitioning

File size: 4.2 KB
Line 
1print 'Starting'
2
3from mpi4py import MPI
4comm = MPI.COMM_WORLD
5mpi_rank, mpi_size = comm.Get_rank(), comm.Get_size()
6print '%d/%d starting'%(mpi_rank,mpi_size)
7
8#from dynamico import partition
9from dynamico import parallel, meshes
10from dynamico import unstructured as unst
11
12import math as math
13import numpy as np
14import netCDF4 as cdf
15
16import matplotlib.pyplot as plt
17from matplotlib.patches import Polygon
18from matplotlib.collections import PatchCollection
19
20print 'Done loading modules'
21
22#----------------- partition hand-written 15-cell mesh ------------------#
23
24if mpi_size<15:
25    send=np.random.randn(mpi_size)
26    recv=np.zeros(mpi_size)
27    comm.Alltoall(send,recv)
28    #time.sleep(mpi_rank)
29#    print mpi_rank, send, recv
30   
31    adjncy=[1, 5, 0, 2, 6, 1, 3, 7, 2, 4, 8, 3, 9, 
32              0, 6, 10, 1, 5, 7, 11, 2, 6, 8, 12, 3, 7, 9, 13, 4, 8, 14,
33              5, 11, 6, 10, 12, 7, 11, 13, 8, 12, 14, 9, 13 ]
34    xadj=[0, 2, 5, 8, 11, 13, 16, 20, 24, 28, 31, 33, 36, 39, 42, 44]
35   
36    nb_vert = len(xadj)-1
37    vtxdist = [i*nb_vert/mpi_size for i in range(mpi_size+1)]
38    xadj, adjncy, vtxdist = [np.asarray(x,np.int32) for x in xadj,adjncy,vtxdist]
39   
40    idx_start = vtxdist[mpi_rank]
41    idx_end = vtxdist[mpi_rank+1]
42    nb_vert = idx_end - idx_start
43   
44    xadj_loc = xadj[idx_start:idx_end+1]-xadj[idx_start]
45    adjncy_loc = adjncy[ xadj[idx_start]:xadj[idx_end] ]
46    part = 0*xadj_loc[0:-1];
47
48    unst.ker.dynamico_partition_graph(mpi_rank, mpi_size, vtxdist, xadj_loc, adjncy_loc, 4, part)
49
50#    for i in range(len(part)):
51#        print 'vertex', i+idx_start, 'proc', part[i]
52
53#-----------------------------------------------------------------------------#
54#---------------         partition and plot MPAS mesh       ------------------#
55#-----------------------------------------------------------------------------#
56
57# Helper functions to plot unstructured graph
58
59def local_mesh(get_mycells):
60    #    mydegree, mybounds = [get_mycells(x) for x in nEdgesOnCell, verticesOnCell]
61    mydegree, mybounds = [get_mycells(x) for x in primal_deg, primal_vertex]
62    print '%d : len(mydegree)=%d'%(mpi_rank, len(mydegree))
63    vertex_list = sorted(set(unst.list_stencil(mydegree,mybounds))) 
64    print '%d : len(vertex_list))=%d'%(mpi_rank, len(vertex_list))
65    get_myvertices = parallel.Get_Indices(dim_vertex, vertex_list)
66    mylon, mylat = [get_myvertices(x)*180./math.pi for x in lonVertex, latVertex]
67    vertex_dict = parallel.inverse_list(vertex_list)
68    meshes.reindex(vertex_dict, mydegree, mybounds)
69    return vertex_list, mydegree, mybounds, mylon, mylat
70
71def members(struct, *names): return [struct.__dict__ [name] for name in names]
72
73#--------------- read MPAS grid file ---------------#
74
75grid = 'x1.2562'
76#grid = 'x1.10242'
77#grid = 'x4.163842'
78print 'Reading MPAS file %s ...'%grid
79
80meshfile = meshes.MPAS_Format('grids/%s.grid.nc'%grid)
81pmesh = meshes.Unstructured_PMesh(comm, meshfile)
82pmesh.partition_metis()
83
84def coriolis(lon,lat): return 0.*lat
85llm, nqdyn, radius = 1,1,1.
86lmesh = meshes.Local_Mesh(pmesh, llm, nqdyn, radius, coriolis)
87
88(primal_deg, primal_vertex, dim_vertex, dim_cell, cell_owner, 
89 lonVertex, latVertex, lonCell, latCell) = members(
90    pmesh, 'primal_deg', 'primal_vertex', 'dim_dual', 'dim_primal', 'primal_owner', 
91    'lon_v', 'lat_v', 'lon_i', 'lat_i')
92
93local_num, total_num, com_cells = np.zeros(1), np.zeros(1), lmesh.com_primal
94local_num[0]=com_cells.own_len
95comm.Reduce(local_num, total_num, op=MPI.SUM, root=0)
96if(mpi_rank==0): print 'total num :', total_num[0], dim_cell.n
97
98#---------------------------- plot -----------------------------#
99
100print 'Plotting ...'
101
102halo_vertex_list, mydegree, mybounds, mylon, mylat = local_mesh(com_cells.get_all)
103buf = parallel.LocalArray1(com_cells)
104
105fig, ax = plt.subplots()
106buf.read_own(latCell) # reads only own values
107buf.data = np.cos(10.*buf.data)
108buf.update() # updates halo
109lmesh.plot_patches(ax,[-math.pi/2,math.pi/2], mydegree, mybounds, mylon, mylat, buf.data)
110plt.xlim(-190.,190.)
111plt.ylim(-90.,90.)
112plt.savefig('fig_partition/A%03d.pdf'%mpi_rank, dpi=1600)
113
114fig, ax = plt.subplots()
115buf.read_own(cell_owner)
116buf.update()
117lmesh.plot_patches(ax,[0,mpi_rank+1], mydegree, mybounds, mylon, mylat, buf.data)
118plt.xlim(-190.,190.)
119plt.ylim(-90.,90.)
120plt.savefig('fig_partition/B%03d.pdf'%mpi_rank, dpi=1600)
Note: See TracBrowser for help on using the repository browser.