# Exercises dynamico's graph partitioning, first on a hand-written mesh and
# then on an MPAS grid, and saves per-rank plots of the resulting partition.
# Meant to run under MPI: every rank executes the whole script.
# NOTE(review): the original was Python 2 only (print statements); the
# single-argument print(...) form used here behaves the same on 2.7 and 3.

# Printed before the heavy imports so a hang during start-up can be located.
print('Starting')

from dynamico import meshes
from dynamico import parallel
from dynamico import maps
from dynamico import partition

from mpi4py import MPI
comm = MPI.COMM_WORLD
mpi_rank, mpi_size = comm.Get_rank(), comm.Get_size()
print('%d/%d starting' % (mpi_rank, mpi_size))

import math
import numpy as np
import netCDF4 as cdf

import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from matplotlib.collections import PatchCollection

print('Done loading modules')
#----------------- partition hand-written 15-cell mesh ------------------#

# NOTE(review): indentation was lost in the source listing. This assumes the
# whole 15-cell test is guarded by `mpi_size < 15`: with more ranks than
# vertices, some ranks would own an empty slice of the graph. Confirm.
if mpi_size < 15:
    # MPI sanity check: exchange one random double with every rank.
    # The received values are not used afterwards.
    send = np.random.randn(mpi_size)
    recv = np.zeros(mpi_size)
    comm.Alltoall(send, recv)

    # CSR adjacency of a 3x5 grid of cells numbered row by row
    # (4-neighbour connectivity): neighbours of v are adjncy[xadj[v]:xadj[v+1]].
    adjncy = [1, 5, 0, 2, 6, 1, 3, 7, 2, 4, 8, 3, 9,
              0, 6, 10, 1, 5, 7, 11, 2, 6, 8, 12, 3, 7, 9, 13, 4, 8, 14,
              5, 11, 6, 10, 12, 7, 11, 13, 8, 12, 14, 9, 13]
    xadj = [0, 2, 5, 8, 11, 13, 16, 20, 24, 28, 31, 33, 36, 39, 42, 44]

    # Contiguous, nearly equal slices: rank r owns vertices
    # vtxdist[r] .. vtxdist[r+1]-1. Floor division keeps the bounds integral
    # on Python 3 as well (plain '/' would produce floats there).
    nb_vert = len(xadj) - 1
    vtxdist = [i * nb_vert // mpi_size for i in range(mpi_size + 1)]
    xadj, adjncy, vtxdist = [np.asarray(x, np.int32)
                             for x in (xadj, adjncy, vtxdist)]

    idx_start = vtxdist[mpi_rank]
    idx_end = vtxdist[mpi_rank + 1]

    # This rank's CSR slice, offsets rebased so its first row starts at 0.
    xadj_loc = xadj[idx_start:idx_end + 1] - xadj[idx_start]
    adjncy_loc = adjncy[xadj[idx_start]:xadj[idx_end]]
    # One int32 slot per locally owned vertex; presumably filled in place
    # by partition_graph with the target part of each vertex.
    part = np.zeros_like(xadj_loc[:-1])

    partition.partition_graph(comm, vtxdist, xadj_loc, adjncy_loc, part, nparts=4)

    for i, proc in enumerate(part):
        print('vertex %d proc %d' % (i + idx_start, proc))
#-----------------------------------------------------------------------------#
#--------------- partition and plot MPAS mesh ------------------#
#-----------------------------------------------------------------------------#

# Helper functions to plot unstructured graph

def local_mesh(get_mycells):
    """Collect the cell/vertex geometry this rank needs to draw its cells.

    get_mycells : callable applied to a global per-cell array, returning the
        entries for the cells handled by this rank (the caller passes
        ``com_cells.get_all``).

    Returns (vertex_list, mydegree, mybounds, mylon, mylat):
        vertex_list  : sorted, de-duplicated vertex indices produced by
                       partition.list_stencil for the local cells
        mydegree     : local slice of primal_deg
        mybounds     : local slice of primal_vertex, passed through
                       meshes.reindex (presumably renumbered in place to
                       positions in vertex_list -- confirm)
        mylon, mylat : lonVertex/latVertex at vertex_list, scaled by 180/pi
                       (assumes the inputs are in radians)

    NOTE: reads the module-level globals primal_deg, primal_vertex,
    dim_vertex, lonVertex, latVertex, partition/parallel/meshes and
    mpi_rank. The mesh globals are assigned later in the script, so this
    function must only be called after they exist.
    """
    mydegree, mybounds = [get_mycells(x) for x in (primal_deg, primal_vertex)]
    print('%d : len(mydegree)=%d' % (mpi_rank, len(mydegree)))
    vertex_list = sorted(set(partition.list_stencil(mydegree, mybounds)))
    print('%d : len(vertex_list)=%d' % (mpi_rank, len(vertex_list)))
    get_myvertices = parallel.Get_Indices(dim_vertex, vertex_list)
    mylon, mylat = [get_myvertices(x) * 180. / math.pi
                    for x in (lonVertex, latVertex)]
    # Global vertex index -> local position, used to renumber mybounds.
    vertex_dict = parallel.inverse_list(vertex_list)
    meshes.reindex(vertex_dict, mydegree, mybounds)
    return vertex_list, mydegree, mybounds, mylon, mylat
def members(struct, *names):
    """Return the attributes ``names`` of ``struct`` as a list, in order.

    Uses getattr rather than ``struct.__dict__`` so that properties,
    class attributes and ``__slots__`` attributes work as well as plain
    instance attributes. A missing name raises AttributeError.
    """
    return [getattr(struct, name) for name in names]
#--------------- read MPAS grid file ---------------#

import sys

# MPAS grid to load, e.g. 'x1.2562', 'x1.10242' or 'x4.163842'.
# Defaults to the smallest one. An optional first command-line argument
# overrides it (under mpirun each rank normally receives the same argv).
grid = sys.argv[1] if len(sys.argv) > 1 else 'x1.2562'
print('Reading MPAS file %s ...' % grid)

meshfile = meshes.MPAS_Format('grids/%s.grid.nc' % grid)
pmesh = meshes.Unstructured_PMesh(comm, meshfile)
# Distribute the cells among the MPI ranks (METIS-based, per the method name).
pmesh.partition_metis()
def coriolis(lon,lat):
    """Coriolis parameter for this test: identically zero.

    ``lon`` is ignored. The result is ``0. * lat``, so an array argument
    yields a float array of zeros with the same shape.
    NOTE(review): not referenced elsewhere in the visible script.
    """
    zero = 0.
    return zero * lat
# Build this rank's local view of the partitioned mesh
# (1 level, 1 dynamical tracer, unit-radius sphere).
llm, nqdyn, radius = 1, 1, 1.
planet = maps.SphereMap(radius, 0.)
lmesh = meshes.Local_Mesh(pmesh, llm, nqdyn, planet)

# Pull the global mesh arrays used below out of pmesh under MPAS-style names.
(primal_deg, primal_vertex, dim_vertex, dim_cell, cell_owner,
 lonVertex, latVertex, lonCell, latCell) = members(
    pmesh, 'primal_deg', 'primal_vertex', 'dim_dual', 'dim_primal', 'primal_owner',
    'lon_v', 'lat_v', 'lon_i', 'lat_i')

# Consistency check: rank 0 prints the sum over ranks of owned cells next to
# dim_cell.n; the two are expected to agree.
com_cells = lmesh.com_primal
local_num, total_num = np.zeros(1), np.zeros(1)
local_num[0] = com_cells.own_len
comm.Reduce(local_num, total_num, op=MPI.SUM, root=0)
if mpi_rank == 0:
    print('total num : %s %s' % (total_num[0], dim_cell.n))
#---------------------------- plot -----------------------------#

import os

print('Plotting ...')

halo_vertex_list, mydegree, mybounds, mylon, mylat = local_mesh(com_cells.get_all)
buf = parallel.LocalArray1(com_cells)


def _save_rank_plot(values, vrange, prefix, outdir='fig_partition'):
    """Draw per-cell ``values`` on this rank's cells and save <outdir>/<prefix><rank>.png.

    values : cell data as held by ``buf.data`` after ``buf.update()``.
    vrange : two-element range forwarded unchanged to lmesh.plot_patches.
    prefix : file-name prefix ('A', 'B', ...).
    """
    # All ranks may reach this point together: tolerate another rank
    # creating the directory between the check and makedirs.
    if not os.path.isdir(outdir):
        try:
            os.makedirs(outdir)
        except OSError:
            if not os.path.isdir(outdir):
                raise
    fig, ax = plt.subplots()
    lmesh.plot_patches(ax, vrange, mydegree, mybounds, mylon, mylat, values)
    ax.set_xlim(-190., 190.)
    ax.set_ylim(-90., 90.)
    fig.savefig(os.path.join(outdir, '%s%03d.png' % (prefix, mpi_rank)), dpi=160)
    # Release the figure so pyplot does not keep every plot open.
    plt.close(fig)


# Plot A: cos(10*lat). Each rank fills only the cells it owns, then the halo
# exchange propagates the values to the halo cells.
buf.read_own(latCell)              # reads only own values
buf.data = np.cos(10. * buf.data)
buf.update()                       # updates halo
_save_rank_plot(buf.data, [-math.pi / 2, math.pi / 2], 'A')

# Plot B: owner rank of each cell, including the halo.
# NOTE(review): the range depends on mpi_rank, so the scaling probably
# differs between the ranks' images -- confirm this is intended.
buf.read_own(cell_owner)
buf.update()
_save_rank_plot(buf.data, [0, mpi_rank + 1], 'B')