"""Exercise DYNAMICO's MPI graph partitioning and plot a partitioned MPAS mesh.

Run under MPI. Every rank first partitions a small hand-written graph, then
reads grids/<grid>.grid.nc, partitions it and writes
fig_partition/A<rank>.pdf and fig_partition/B<rank>.pdf.
"""

# Printed before the heavy imports so a hang during MPI/module loading is
# visible in the output.
print('Starting')

from mpi4py import MPI
comm = MPI.COMM_WORLD
mpi_rank, mpi_size = comm.Get_rank(), comm.Get_size()
print('%d/%d starting' % (mpi_rank, mpi_size))

#from dynamico import partition
from dynamico import parallel, meshes
from dynamico import unstructured as unst

import math
import numpy as np
import netCDF4 as cdf

import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from matplotlib.collections import PatchCollection

print('Done loading modules')


#----------------- partition hand-written 15-cell mesh ------------------#
# NOTE(review): the original listing lost its indentation; everything below
# is assumed to sit under the `mpi_size < 15` guard (presumably so that every
# rank owns at least one of the 15 vertices) -- confirm against history.

if mpi_size < 15:
    # Smoke test of a collective: each rank sends one random value to every
    # rank and receives one value from every rank.
    send = np.random.randn(mpi_size)
    recv = np.zeros(mpi_size)
    comm.Alltoall(send, recv)

    # 15-vertex graph (a 3 x 5 grid with 4-neighbour connectivity) in CSR
    # form: the neighbours of vertex v are adjncy[xadj[v]:xadj[v+1]].
    adjncy = [1, 5, 0, 2, 6, 1, 3, 7, 2, 4, 8, 3, 9,
              0, 6, 10, 1, 5, 7, 11, 2, 6, 8, 12, 3, 7, 9, 13, 4, 8, 14,
              5, 11, 6, 10, 12, 7, 11, 13, 8, 12, 14, 9, 13]
    xadj = [0, 2, 5, 8, 11, 13, 16, 20, 24, 28, 31, 33, 36, 39, 42, 44]

    # Block distribution of the vertices: rank r owns global vertices
    # vtxdist[r] .. vtxdist[r+1]-1.  Floor division keeps the bounds integral
    # on Python 3 too (Python 2 `/` on ints already floored).
    nb_vert = len(xadj) - 1
    vtxdist = [i * nb_vert // mpi_size for i in range(mpi_size + 1)]
    xadj, adjncy, vtxdist = [np.asarray(x, np.int32)
                             for x in (xadj, adjncy, vtxdist)]

    idx_start = vtxdist[mpi_rank]
    idx_end = vtxdist[mpi_rank + 1]
    nb_vert = idx_end - idx_start  # number of vertices owned by this rank

    # Local slice of the CSR graph, offsets rebased so xadj_loc[0] == 0.
    xadj_loc = xadj[idx_start:idx_end + 1] - xadj[idx_start]
    adjncy_loc = adjncy[xadj[idx_start]:xadj[idx_end]]
    # Output buffer, one int32 per local vertex; presumably filled by the
    # kernel with the part/rank assigned to each vertex.
    part = np.zeros(nb_vert, dtype=np.int32)

    # NOTE(review): the literal 4 is presumably the number of parts requested
    # -- confirm against dynamico_partition_graph.
    unst.ker.dynamico_partition_graph(mpi_rank, mpi_size, vtxdist,
                                      xadj_loc, adjncy_loc, 4, part)

#-----------------------------------------------------------------------------#
#--------------- partition and plot MPAS mesh ------------------#
#-----------------------------------------------------------------------------#

# Helper functions to plot unstructured graph

def local_mesh(get_mycells):
    """Gather the vertices and coordinates needed to plot a set of cells.

    get_mycells -- callable applied to the global arrays primal_deg and
                   primal_vertex to extract the cells of interest (the
                   caller passes com_cells.get_all).

    Returns (vertex_list, mydegree, mybounds, mylon, mylat):
      vertex_list  -- sorted, de-duplicated global indices of the vertices
                      returned by unst.list_stencil for these cells
      mydegree     -- get_mycells(primal_deg), presumably the number of
                      vertices of each cell
      mybounds     -- get_mycells(primal_vertex), passed through
                      meshes.reindex (presumably renumbered in place to
                      positions in vertex_list -- confirm)
      mylon, mylat -- coordinates of vertex_list, converted from radians
                      to degrees

    Reads the module-level names primal_deg, primal_vertex, dim_vertex,
    lonVertex, latVertex and mpi_rank, which must be bound before the call.
    """
    mydegree, mybounds = [get_mycells(x) for x in (primal_deg, primal_vertex)]
    print('%d : len(mydegree)=%d' % (mpi_rank, len(mydegree)))
    vertex_list = sorted(set(unst.list_stencil(mydegree, mybounds)))
    print('%d : len(vertex_list)=%d' % (mpi_rank, len(vertex_list)))
    get_myvertices = parallel.Get_Indices(dim_vertex, vertex_list)
    mylon, mylat = [get_myvertices(x) * 180. / math.pi
                    for x in (lonVertex, latVertex)]
    # Translate global vertex indices in mybounds into local ones.
    vertex_dict = parallel.inverse_list(vertex_list)
    meshes.reindex(vertex_dict, mydegree, mybounds)
    return vertex_list, mydegree, mybounds, mylon, mylat

def members(struct, *names):
    """Fetch several instance attributes of *struct* at once.

    Each name is looked up directly in ``struct.__dict__`` (so class-level
    attributes and properties are not seen); a missing name raises KeyError.
    The values come back as a list, in the order the names were given.
    """
    picked = []
    for attr_name in names:
        picked.append(struct.__dict__[attr_name])
    return picked

#--------------- read MPAS grid file ---------------#

# MPAS grid to partition; other resolutions are kept commented out.
grid = 'x1.2562'
#grid = 'x1.10242'
#grid = 'x4.163842'
print('Reading MPAS file %s ...' % grid)

meshfile = meshes.MPAS_Format('grids/%s.grid.nc' % grid)
pmesh = meshes.Unstructured_PMesh(comm, meshfile)
pmesh.partition_metis()

def coriolis(lon, lat):
    """Return a zero Coriolis parameter shaped like *lat*; *lon* is unused."""
    return lat * 0.
# Single-level, single-quantity local mesh on the unit sphere.
llm, nqdyn, radius = 1, 1, 1.
lmesh = meshes.Local_Mesh(pmesh, llm, nqdyn, radius, coriolis)

# Pull the global mesh fields off pmesh under the names used by local_mesh
# and the plotting code ('dual' fields are vertex-based, 'primal' ones
# cell-based).
(primal_deg, primal_vertex, dim_vertex, dim_cell, cell_owner,
 lonVertex, latVertex, lonCell, latCell) = members(
    pmesh, 'primal_deg', 'primal_vertex', 'dim_dual', 'dim_primal', 'primal_owner',
    'lon_v', 'lat_v', 'lon_i', 'lat_i')

# Consistency check: sum the per-rank own_len of the cell communication
# pattern onto rank 0 and print it next to the global cell count (the two
# are presumably expected to match).
com_cells = lmesh.com_primal
local_num, total_num = np.zeros(1), np.zeros(1)
local_num[0] = com_cells.own_len
comm.Reduce(local_num, total_num, op=MPI.SUM, root=0)
if mpi_rank == 0:
    # One formatted string so Python 2 and 3 print the same line.
    print('total num : %s %s' % (total_num[0], dim_cell.n))

#---------------------------- plot -----------------------------#

print('Plotting ...')

halo_vertex_list, mydegree, mybounds, mylon, mylat = local_mesh(com_cells.get_all)
buf = parallel.LocalArray1(com_cells)


def _plot_and_save(prefix, vrange, data):
    """Plot *data* on this rank's cells and save fig_partition/<prefix><rank>.pdf.

    prefix -- one-letter figure prefix ('A', 'B', ...)
    vrange -- two-element range forwarded to lmesh.plot_patches
              (presumably the colour scale -- confirm)
    data   -- per-cell values, e.g. buf.data

    The figure is closed after saving so it is not kept alive by pyplot.
    """
    fig, ax = plt.subplots()
    lmesh.plot_patches(ax, vrange, mydegree, mybounds, mylon, mylat, data)
    plt.xlim(-190., 190.)
    plt.ylim(-90., 90.)
    plt.savefig('fig_partition/%s%03d.pdf' % (prefix, mpi_rank), dpi=1600)
    plt.close(fig)


# Figure A: cos(10 * latCell), read on owned cells then propagated to halo.
buf.read_own(latCell)  # reads only own values
buf.data = np.cos(10. * buf.data)
buf.update()  # updates halo
_plot_and_save('A', [-math.pi / 2, math.pi / 2], buf.data)

# Figure B: owner rank of every local cell (own + halo).
buf.read_own(cell_owner)
buf.update()
_plot_and_save('B', [0, mpi_rank + 1], buf.data)