source: codes/icosagcm/devel/Python/test/py/partition.py @ 692

Last change on this file since 692 was 692, checked in by dubos, 6 years ago

devel/Python : fix test/partition.py

File size: 4.2 KB
RevLine 
[620]1print 'Starting'
2
3from mpi4py import MPI
4comm = MPI.COMM_WORLD
5mpi_rank, mpi_size = comm.Get_rank(), comm.Get_size()
6print '%d/%d starting'%(mpi_rank,mpi_size)
7
[672]8#from dynamico import partition
9from dynamico import parallel, meshes
10from dynamico import unstructured as unst
11
[620]12import math as math
13import numpy as np
14import netCDF4 as cdf
15
16import matplotlib.pyplot as plt
17from matplotlib.patches import Polygon
18from matplotlib.collections import PatchCollection
19
20print 'Done loading modules'
21
22#----------------- partition hand-written 15-cell mesh ------------------#
23
24if mpi_size<15:
25    send=np.random.randn(mpi_size)
26    recv=np.zeros(mpi_size)
27    comm.Alltoall(send,recv)
28    #time.sleep(mpi_rank)
29#    print mpi_rank, send, recv
30   
31    adjncy=[1, 5, 0, 2, 6, 1, 3, 7, 2, 4, 8, 3, 9, 
32              0, 6, 10, 1, 5, 7, 11, 2, 6, 8, 12, 3, 7, 9, 13, 4, 8, 14,
33              5, 11, 6, 10, 12, 7, 11, 13, 8, 12, 14, 9, 13 ]
34    xadj=[0, 2, 5, 8, 11, 13, 16, 20, 24, 28, 31, 33, 36, 39, 42, 44]
35   
36    nb_vert = len(xadj)-1
37    vtxdist = [i*nb_vert/mpi_size for i in range(mpi_size+1)]
38    xadj, adjncy, vtxdist = [np.asarray(x,np.int32) for x in xadj,adjncy,vtxdist]
39   
40    idx_start = vtxdist[mpi_rank]
41    idx_end = vtxdist[mpi_rank+1]
42    nb_vert = idx_end - idx_start
43   
44    xadj_loc = xadj[idx_start:idx_end+1]-xadj[idx_start]
45    adjncy_loc = adjncy[ xadj[idx_start]:xadj[idx_end] ]
46    part = 0*xadj_loc[0:-1];
47
48    unst.ker.dynamico_partition_graph(mpi_rank, mpi_size, vtxdist, xadj_loc, adjncy_loc, 4, part)
49
50#    for i in range(len(part)):
51#        print 'vertex', i+idx_start, 'proc', part[i]
52
53#-----------------------------------------------------------------------------#
54#---------------         partition and plot MPAS mesh       ------------------#
55#-----------------------------------------------------------------------------#
56
57# Helper functions to plot unstructured graph
58
59def local_mesh(get_mycells):
[680]60    #    mydegree, mybounds = [get_mycells(x) for x in nEdgesOnCell, verticesOnCell]
61    mydegree, mybounds = [get_mycells(x) for x in primal_deg, primal_vertex]
[620]62    print '%d : len(mydegree)=%d'%(mpi_rank, len(mydegree))
[672]63    vertex_list = sorted(set(unst.list_stencil(mydegree,mybounds))) 
[620]64    print '%d : len(vertex_list))=%d'%(mpi_rank, len(vertex_list))
65    get_myvertices = parallel.Get_Indices(dim_vertex, vertex_list)
66    mylon, mylat = [get_myvertices(x)*180./math.pi for x in lonVertex, latVertex]
67    vertex_dict = parallel.inverse_list(vertex_list)
[672]68    meshes.reindex(vertex_dict, mydegree, mybounds)
[620]69    return vertex_list, mydegree, mybounds, mylon, mylat
70
[680]71def members(struct, *names): return [struct.__dict__ [name] for name in names]
72
[620]73#--------------- read MPAS grid file ---------------#
74
[680]75grid = 'x1.2562'
76#grid = 'x1.10242'
[620]77#grid = 'x4.163842'
78print 'Reading MPAS file %s ...'%grid
79
[680]80meshfile = meshes.MPAS_Format('grids/%s.grid.nc'%grid)
81pmesh = meshes.Unstructured_PMesh(comm, meshfile)
[620]82
[692]83def coriolis(lon,lat): return 0.*lat
84llm, nqdyn, radius = 1,1,1.
85lmesh = meshes.Local_Mesh(pmesh, llm, nqdyn, radius, coriolis)
86
[680]87(primal_deg, primal_vertex, dim_vertex, dim_cell, cell_owner, 
88 lonVertex, latVertex, lonCell, latCell) = members(
89    pmesh, 'primal_deg', 'primal_vertex', 'dim_dual', 'dim_primal', 'primal_owner', 
90    'lon_v', 'lat_v', 'lon_i', 'lat_i')
[620]91
[680]92local_num, total_num, com_cells = np.zeros(1), np.zeros(1), lmesh.com_primal
[620]93local_num[0]=com_cells.own_len
94comm.Reduce(local_num, total_num, op=MPI.SUM, root=0)
95if(mpi_rank==0): print 'total num :', total_num[0], dim_cell.n
96
97#---------------------------- plot -----------------------------#
98
[680]99print 'Plotting ...'
[620]100
[680]101halo_vertex_list, mydegree, mybounds, mylon, mylat = local_mesh(com_cells.get_all)
102buf = parallel.LocalArray1(com_cells)
[620]103
[680]104fig, ax = plt.subplots()
105buf.read_own(latCell) # reads only own values
106buf.data = np.cos(10.*buf.data)
107buf.update() # updates halo
[692]108lmesh.plot_patches(ax,[-math.pi/2,math.pi/2], mydegree, mybounds, mylon, mylat, buf.data)
[680]109plt.xlim(-190.,190.)
110plt.ylim(-90.,90.)
111plt.savefig('fig_partition/A%03d.pdf'%mpi_rank, dpi=1600)
[620]112
[680]113fig, ax = plt.subplots()
114buf.read_own(cell_owner)
115buf.update()
[692]116lmesh.plot_patches(ax,[0,mpi_rank+1], mydegree, mybounds, mylon, mylat, buf.data)
[680]117plt.xlim(-190.,190.)
118plt.ylim(-90.,90.)
119plt.savefig('fig_partition/B%03d.pdf'%mpi_rank, dpi=1600)
Note: See TracBrowser for help on using the repository browser.