source: codes/icosagcm/devel/Python/test/py/partition.py @ 841

Last change on this file since 841 was 825, checked in by dubos, 5 years ago

devel/Python : moved Fortran bindings and *.pyx to dynamico/dev module + necessary changes to test/py/*.py

File size: 4.4 KB
Line 
1print 'Starting'
2
3from dynamico import meshes
4from dynamico import parallel
5from dynamico.dev import unstructured as unst
6from dynamico import maps
7from dynamico.parmetis import partition_graph
8
9from mpi4py import MPI
10comm = MPI.COMM_WORLD
11mpi_rank, mpi_size = comm.Get_rank(), comm.Get_size()
12print '%d/%d starting'%(mpi_rank,mpi_size)
13
14#from dynamico import partition
15
16import math as math
17import numpy as np
18import netCDF4 as cdf
19
20import matplotlib.pyplot as plt
21from matplotlib.patches import Polygon
22from matplotlib.collections import PatchCollection
23
24print 'Done loading modules'
25
26#----------------- partition hand-written 15-cell mesh ------------------#
27
28if mpi_size<15:
29    send=np.random.randn(mpi_size)
30    recv=np.zeros(mpi_size)
31    comm.Alltoall(send,recv)
32    #time.sleep(mpi_rank)
33#    print mpi_rank, send, recv
34   
35    adjncy=[1, 5, 0, 2, 6, 1, 3, 7, 2, 4, 8, 3, 9, 
36              0, 6, 10, 1, 5, 7, 11, 2, 6, 8, 12, 3, 7, 9, 13, 4, 8, 14,
37              5, 11, 6, 10, 12, 7, 11, 13, 8, 12, 14, 9, 13 ]
38    xadj=[0, 2, 5, 8, 11, 13, 16, 20, 24, 28, 31, 33, 36, 39, 42, 44]
39   
40    nb_vert = len(xadj)-1
41    vtxdist = [i*nb_vert/mpi_size for i in range(mpi_size+1)]
42    xadj, adjncy, vtxdist = [np.asarray(x,np.int32) for x in xadj,adjncy,vtxdist]
43   
44    idx_start = vtxdist[mpi_rank]
45    idx_end = vtxdist[mpi_rank+1]
46    nb_vert = idx_end - idx_start
47   
48    xadj_loc = xadj[idx_start:idx_end+1]-xadj[idx_start]
49    adjncy_loc = adjncy[ xadj[idx_start]:xadj[idx_end] ]
50    part = 0*xadj_loc[0:-1];
51
52#    unst.ker.dynamico_partition_graph(mpi_rank, mpi_size, vtxdist, xadj_loc, adjncy_loc, 4, part)
53    partition_graph(comm, vtxdist, xadj_loc, adjncy_loc, part, nparts=4)
54
55    for i in range(len(part)):
56        print 'vertex', i+idx_start, 'proc', part[i]
57
58#-----------------------------------------------------------------------------#
59#---------------         partition and plot MPAS mesh       ------------------#
60#-----------------------------------------------------------------------------#
61
62# Helper functions to plot unstructured graph
63
64def local_mesh(get_mycells):
65    #    mydegree, mybounds = [get_mycells(x) for x in nEdgesOnCell, verticesOnCell]
66    mydegree, mybounds = [get_mycells(x) for x in primal_deg, primal_vertex]
67    print '%d : len(mydegree)=%d'%(mpi_rank, len(mydegree))
68    vertex_list = sorted(set(unst.list_stencil(mydegree,mybounds))) 
69    print '%d : len(vertex_list))=%d'%(mpi_rank, len(vertex_list))
70    get_myvertices = parallel.Get_Indices(dim_vertex, vertex_list)
71    mylon, mylat = [get_myvertices(x)*180./math.pi for x in lonVertex, latVertex]
72    vertex_dict = parallel.inverse_list(vertex_list)
73    meshes.reindex(vertex_dict, mydegree, mybounds)
74    return vertex_list, mydegree, mybounds, mylon, mylat
75
76def members(struct, *names): return [struct.__dict__ [name] for name in names]
77
78#--------------- read MPAS grid file ---------------#
79
80grid = 'x1.2562'
81#grid = 'x1.10242'
82#grid = 'x4.163842'
83print 'Reading MPAS file %s ...'%grid
84
85meshfile = meshes.MPAS_Format('grids/%s.grid.nc'%grid)
86pmesh = meshes.Unstructured_PMesh(comm, meshfile)
87pmesh.partition_metis()
88
89def coriolis(lon,lat): return 0.*lat
90llm, nqdyn, radius = 1,1,1.
91planet = maps.SphereMap(radius, 0.)
92lmesh = meshes.Local_Mesh(pmesh, llm, nqdyn, planet)
93
94(primal_deg, primal_vertex, dim_vertex, dim_cell, cell_owner, 
95 lonVertex, latVertex, lonCell, latCell) = members(
96    pmesh, 'primal_deg', 'primal_vertex', 'dim_dual', 'dim_primal', 'primal_owner', 
97    'lon_v', 'lat_v', 'lon_i', 'lat_i')
98
99local_num, total_num, com_cells = np.zeros(1), np.zeros(1), lmesh.com_primal
100local_num[0]=com_cells.own_len
101comm.Reduce(local_num, total_num, op=MPI.SUM, root=0)
102if(mpi_rank==0): print 'total num :', total_num[0], dim_cell.n
103
104#---------------------------- plot -----------------------------#
105
106print 'Plotting ...'
107
108halo_vertex_list, mydegree, mybounds, mylon, mylat = local_mesh(com_cells.get_all)
109buf = parallel.LocalArray1(com_cells)
110
111fig, ax = plt.subplots()
112buf.read_own(latCell) # reads only own values
113buf.data = np.cos(10.*buf.data)
114buf.update() # updates halo
115lmesh.plot_patches(ax,[-math.pi/2,math.pi/2], mydegree, mybounds, mylon, mylat, buf.data)
116plt.xlim(-190.,190.)
117plt.ylim(-90.,90.)
118plt.savefig('fig_partition/A%03d.pdf'%mpi_rank, dpi=1600)
119
120fig, ax = plt.subplots()
121buf.read_own(cell_owner)
122buf.update()
123lmesh.plot_patches(ax,[0,mpi_rank+1], mydegree, mybounds, mylon, mylat, buf.data)
124plt.xlim(-190.,190.)
125plt.ylim(-90.,90.)
126plt.savefig('fig_partition/B%03d.pdf'%mpi_rank, dpi=1600)
Note: See TracBrowser for help on using the repository browser.