source: codes/icosagcm/devel/Python/test/py/partition.py @ 692

Last change on this file since 692 was 692, checked in by dubos, 6 years ago

devel/Python : fix test/partition.py

File size: 4.2 KB
Line 
1print 'Starting'
2
3from mpi4py import MPI
4comm = MPI.COMM_WORLD
5mpi_rank, mpi_size = comm.Get_rank(), comm.Get_size()
6print '%d/%d starting'%(mpi_rank,mpi_size)
7
8#from dynamico import partition
9from dynamico import parallel, meshes
10from dynamico import unstructured as unst
11
12import math as math
13import numpy as np
14import netCDF4 as cdf
15
16import matplotlib.pyplot as plt
17from matplotlib.patches import Polygon
18from matplotlib.collections import PatchCollection
19
20print 'Done loading modules'
21
22#----------------- partition hand-written 15-cell mesh ------------------#
23
24if mpi_size<15:
25    send=np.random.randn(mpi_size)
26    recv=np.zeros(mpi_size)
27    comm.Alltoall(send,recv)
28    #time.sleep(mpi_rank)
29#    print mpi_rank, send, recv
30   
31    adjncy=[1, 5, 0, 2, 6, 1, 3, 7, 2, 4, 8, 3, 9, 
32              0, 6, 10, 1, 5, 7, 11, 2, 6, 8, 12, 3, 7, 9, 13, 4, 8, 14,
33              5, 11, 6, 10, 12, 7, 11, 13, 8, 12, 14, 9, 13 ]
34    xadj=[0, 2, 5, 8, 11, 13, 16, 20, 24, 28, 31, 33, 36, 39, 42, 44]
35   
36    nb_vert = len(xadj)-1
37    vtxdist = [i*nb_vert/mpi_size for i in range(mpi_size+1)]
38    xadj, adjncy, vtxdist = [np.asarray(x,np.int32) for x in xadj,adjncy,vtxdist]
39   
40    idx_start = vtxdist[mpi_rank]
41    idx_end = vtxdist[mpi_rank+1]
42    nb_vert = idx_end - idx_start
43   
44    xadj_loc = xadj[idx_start:idx_end+1]-xadj[idx_start]
45    adjncy_loc = adjncy[ xadj[idx_start]:xadj[idx_end] ]
46    part = 0*xadj_loc[0:-1];
47
48    unst.ker.dynamico_partition_graph(mpi_rank, mpi_size, vtxdist, xadj_loc, adjncy_loc, 4, part)
49
50#    for i in range(len(part)):
51#        print 'vertex', i+idx_start, 'proc', part[i]
52
53#-----------------------------------------------------------------------------#
54#---------------         partition and plot MPAS mesh       ------------------#
55#-----------------------------------------------------------------------------#
56
57# Helper functions to plot unstructured graph
58
59def local_mesh(get_mycells):
60    #    mydegree, mybounds = [get_mycells(x) for x in nEdgesOnCell, verticesOnCell]
61    mydegree, mybounds = [get_mycells(x) for x in primal_deg, primal_vertex]
62    print '%d : len(mydegree)=%d'%(mpi_rank, len(mydegree))
63    vertex_list = sorted(set(unst.list_stencil(mydegree,mybounds))) 
64    print '%d : len(vertex_list))=%d'%(mpi_rank, len(vertex_list))
65    get_myvertices = parallel.Get_Indices(dim_vertex, vertex_list)
66    mylon, mylat = [get_myvertices(x)*180./math.pi for x in lonVertex, latVertex]
67    vertex_dict = parallel.inverse_list(vertex_list)
68    meshes.reindex(vertex_dict, mydegree, mybounds)
69    return vertex_list, mydegree, mybounds, mylon, mylat
70
71def members(struct, *names): return [struct.__dict__ [name] for name in names]
72
73#--------------- read MPAS grid file ---------------#
74
75grid = 'x1.2562'
76#grid = 'x1.10242'
77#grid = 'x4.163842'
78print 'Reading MPAS file %s ...'%grid
79
80meshfile = meshes.MPAS_Format('grids/%s.grid.nc'%grid)
81pmesh = meshes.Unstructured_PMesh(comm, meshfile)
82
83def coriolis(lon,lat): return 0.*lat
84llm, nqdyn, radius = 1,1,1.
85lmesh = meshes.Local_Mesh(pmesh, llm, nqdyn, radius, coriolis)
86
87(primal_deg, primal_vertex, dim_vertex, dim_cell, cell_owner, 
88 lonVertex, latVertex, lonCell, latCell) = members(
89    pmesh, 'primal_deg', 'primal_vertex', 'dim_dual', 'dim_primal', 'primal_owner', 
90    'lon_v', 'lat_v', 'lon_i', 'lat_i')
91
92local_num, total_num, com_cells = np.zeros(1), np.zeros(1), lmesh.com_primal
93local_num[0]=com_cells.own_len
94comm.Reduce(local_num, total_num, op=MPI.SUM, root=0)
95if(mpi_rank==0): print 'total num :', total_num[0], dim_cell.n
96
97#---------------------------- plot -----------------------------#
98
99print 'Plotting ...'
100
101halo_vertex_list, mydegree, mybounds, mylon, mylat = local_mesh(com_cells.get_all)
102buf = parallel.LocalArray1(com_cells)
103
104fig, ax = plt.subplots()
105buf.read_own(latCell) # reads only own values
106buf.data = np.cos(10.*buf.data)
107buf.update() # updates halo
108lmesh.plot_patches(ax,[-math.pi/2,math.pi/2], mydegree, mybounds, mylon, mylat, buf.data)
109plt.xlim(-190.,190.)
110plt.ylim(-90.,90.)
111plt.savefig('fig_partition/A%03d.pdf'%mpi_rank, dpi=1600)
112
113fig, ax = plt.subplots()
114buf.read_own(cell_owner)
115buf.update()
116lmesh.plot_patches(ax,[0,mpi_rank+1], mydegree, mybounds, mylon, mylat, buf.data)
117plt.xlim(-190.,190.)
118plt.ylim(-90.,90.)
119plt.savefig('fig_partition/B%03d.pdf'%mpi_rank, dpi=1600)
Note: See TracBrowser for help on using the repository browser.