source: codes/icosagcm/devel/Python/test/py/partition.py

Last change on this file was 977, checked in by dubos, 5 years ago

devel/Python : now only dynamico.dev modules require link to DYNAMICO/XIOS shared objects

File size: 4.2 KB
RevLine 
[620]1print 'Starting'
2
[825]3from dynamico import meshes
4from dynamico import parallel
5from dynamico import maps
[977]6from dynamico import partition
[825]7
[620]8from mpi4py import MPI
9comm = MPI.COMM_WORLD
10mpi_rank, mpi_size = comm.Get_rank(), comm.Get_size()
11print '%d/%d starting'%(mpi_rank,mpi_size)
12
13import math as math
14import numpy as np
15import netCDF4 as cdf
16
17import matplotlib.pyplot as plt
18from matplotlib.patches import Polygon
19from matplotlib.collections import PatchCollection
20
21print 'Done loading modules'
22
23#----------------- partition hand-written 15-cell mesh ------------------#
24
25if mpi_size<15:
26    send=np.random.randn(mpi_size)
27    recv=np.zeros(mpi_size)
28    comm.Alltoall(send,recv)
29    #time.sleep(mpi_rank)
30#    print mpi_rank, send, recv
31   
32    adjncy=[1, 5, 0, 2, 6, 1, 3, 7, 2, 4, 8, 3, 9, 
33              0, 6, 10, 1, 5, 7, 11, 2, 6, 8, 12, 3, 7, 9, 13, 4, 8, 14,
34              5, 11, 6, 10, 12, 7, 11, 13, 8, 12, 14, 9, 13 ]
35    xadj=[0, 2, 5, 8, 11, 13, 16, 20, 24, 28, 31, 33, 36, 39, 42, 44]
36   
37    nb_vert = len(xadj)-1
38    vtxdist = [i*nb_vert/mpi_size for i in range(mpi_size+1)]
39    xadj, adjncy, vtxdist = [np.asarray(x,np.int32) for x in xadj,adjncy,vtxdist]
40   
41    idx_start = vtxdist[mpi_rank]
42    idx_end = vtxdist[mpi_rank+1]
43    nb_vert = idx_end - idx_start
44   
45    xadj_loc = xadj[idx_start:idx_end+1]-xadj[idx_start]
46    adjncy_loc = adjncy[ xadj[idx_start]:xadj[idx_end] ]
47    part = 0*xadj_loc[0:-1];
48
[977]49    partition.partition_graph(comm, vtxdist, xadj_loc, adjncy_loc, part, nparts=4)
[620]50
[825]51    for i in range(len(part)):
52        print 'vertex', i+idx_start, 'proc', part[i]
[620]53
54#-----------------------------------------------------------------------------#
55#---------------         partition and plot MPAS mesh       ------------------#
56#-----------------------------------------------------------------------------#
57
58# Helper functions to plot unstructured graph
59
60def local_mesh(get_mycells):
[680]61    #    mydegree, mybounds = [get_mycells(x) for x in nEdgesOnCell, verticesOnCell]
62    mydegree, mybounds = [get_mycells(x) for x in primal_deg, primal_vertex]
[620]63    print '%d : len(mydegree)=%d'%(mpi_rank, len(mydegree))
[977]64    vertex_list = sorted(set(partition.list_stencil(mydegree,mybounds))) 
[620]65    print '%d : len(vertex_list))=%d'%(mpi_rank, len(vertex_list))
66    get_myvertices = parallel.Get_Indices(dim_vertex, vertex_list)
67    mylon, mylat = [get_myvertices(x)*180./math.pi for x in lonVertex, latVertex]
68    vertex_dict = parallel.inverse_list(vertex_list)
[672]69    meshes.reindex(vertex_dict, mydegree, mybounds)
[620]70    return vertex_list, mydegree, mybounds, mylon, mylat
71
[680]72def members(struct, *names): return [struct.__dict__ [name] for name in names]
73
[620]74#--------------- read MPAS grid file ---------------#
75
[680]76grid = 'x1.2562'
77#grid = 'x1.10242'
[620]78#grid = 'x4.163842'
79print 'Reading MPAS file %s ...'%grid
80
[680]81meshfile = meshes.MPAS_Format('grids/%s.grid.nc'%grid)
82pmesh = meshes.Unstructured_PMesh(comm, meshfile)
[760]83pmesh.partition_metis()
[620]84
[692]85def coriolis(lon,lat): return 0.*lat
86llm, nqdyn, radius = 1,1,1.
[825]87planet = maps.SphereMap(radius, 0.)
88lmesh = meshes.Local_Mesh(pmesh, llm, nqdyn, planet)
[692]89
[680]90(primal_deg, primal_vertex, dim_vertex, dim_cell, cell_owner, 
91 lonVertex, latVertex, lonCell, latCell) = members(
92    pmesh, 'primal_deg', 'primal_vertex', 'dim_dual', 'dim_primal', 'primal_owner', 
93    'lon_v', 'lat_v', 'lon_i', 'lat_i')
[620]94
[680]95local_num, total_num, com_cells = np.zeros(1), np.zeros(1), lmesh.com_primal
[620]96local_num[0]=com_cells.own_len
97comm.Reduce(local_num, total_num, op=MPI.SUM, root=0)
98if(mpi_rank==0): print 'total num :', total_num[0], dim_cell.n
99
100#---------------------------- plot -----------------------------#
101
[680]102print 'Plotting ...'
[620]103
[680]104halo_vertex_list, mydegree, mybounds, mylon, mylat = local_mesh(com_cells.get_all)
105buf = parallel.LocalArray1(com_cells)
[620]106
[680]107fig, ax = plt.subplots()
108buf.read_own(latCell) # reads only own values
109buf.data = np.cos(10.*buf.data)
110buf.update() # updates halo
[692]111lmesh.plot_patches(ax,[-math.pi/2,math.pi/2], mydegree, mybounds, mylon, mylat, buf.data)
[680]112plt.xlim(-190.,190.)
113plt.ylim(-90.,90.)
[977]114plt.savefig('fig_partition/A%03d.png'%mpi_rank, dpi=160)
[620]115
[680]116fig, ax = plt.subplots()
117buf.read_own(cell_owner)
118buf.update()
[692]119lmesh.plot_patches(ax,[0,mpi_rank+1], mydegree, mybounds, mylon, mylat, buf.data)
[680]120plt.xlim(-190.,190.)
121plt.ylim(-90.,90.)
[977]122plt.savefig('fig_partition/B%03d.png'%mpi_rank, dpi=160)
Note: See TracBrowser for help on using the repository browser.