diff --git a/pyfms/py_horiz_interp/interp.py b/pyfms/py_horiz_interp/interp.py index 1aa96f8..2145ef8 100644 --- a/pyfms/py_horiz_interp/interp.py +++ b/pyfms/py_horiz_interp/interp.py @@ -22,7 +22,7 @@ def __init__(self, interp_id: int = None, save_xgrid_area: bool = False): self.nlon_dst = horiz_interp.get_nlon_dst(interp_id) self.nlat_dst = horiz_interp.get_nlat_dst(interp_id) self.interp_method = horiz_interp.get_interp_method(interp_id) - self.get_area_frac_dst = horiz_interp.get_area_frac_dst(interp_id) + self.area_frac_dst = horiz_interp.get_area_frac_dst(interp_id) if save_xgrid_area: self.xgrid_area = horiz_interp.get_xgrid_area(interp_id) else: @@ -36,20 +36,23 @@ def __init__(self, interp_id: int = None, save_xgrid_area: bool = False): self.nlon_dst = None self.nlat_dst = None self.interp_method = None - self.get_area_frac_dst = None + self.area_frac_dst = None def __repr__(self): - description = "\n\nConserveInterp object\n\n" - description += "src_nx = {:>5} src_ny={:>5}\n".format( - self.nlon_src, self.nlat_src - ) - description += "tgt_nx = {:>5} tgt_ny={:>5}\n".format( - self.nlon_dst, self.nlat_dst - ) - description += f"nxgrid = {self.nxgrid}\n" - description += f"i_src = {self.i_src}\n" - description += f"j_src = {self.j_src}\n" - description += f"i_dst = {self.i_dst}\n" - description += f"j_dst = {self.j_dst}\n" - description += f"xgrid_area = {self.xgrid_area}\n" - return description + + repr_str = f""" + interp_id: {self.interp_id} + nxgrid: {self.nxgrid} + nlon_src: {self.nlon_src} + nlat_src: {self.nlat_src} + nlon_dst: {self.nlon_dst} + nlat_dst: {self.nlat_dst} + interp_method: {self.interp_method} + i_src_minmax: [{self.i_src.min()}, {self.i_src.max()}] + j_src_minmax [{self.j_src.min()}, {self.j_src.max()}] + i_dst_minmax: [{self.i_dst.min()}, {self.i_dst.max()}] + j_dst_minmax: [{self.j_dst.min()}, {self.j_dst.max()}] + area_frac_dst_minmax: [{self.area_frac_dst.min()}, {self.area_frac_dst.max()}] + """ + + return repr_str diff --git a/pyfms/py_mpp/_mpp_functions.py b/pyfms/py_mpp/_mpp_functions.py index a5468fd..5f9af9d 100644 --- a/pyfms/py_mpp/_mpp_functions.py +++ b/pyfms/py_mpp/_mpp_functions.py @@ -46,9 +46,9 @@ def define(lib): POINTER(c_int), # npes ndpointer(dtype=np.int32, ndim=1, flags=C), # pelist ndpointer(dtype=nptype, ndim=2, flags=C), # array_seg - NDPOINTER(dtype=np.int32, shape=(2,), flags=C), # gather_data_c_shape NDPOINTER(dtype=nptype, ndim=2, flags=C), # gather_data POINTER(c_bool), # is_root_pe + NDPOINTER(dtype=np.int32, shape=(2,), flags=C), # gather_data_c_shape POINTER(c_int), # ishift POINTER(c_int), # jshift POINTER(c_bool), # convert_cf_order @@ -64,10 +64,10 @@ def define(lib): cFMS_gather.restype = None cFMS_gather.argtypes = [ POINTER(c_int), # sbufsize - POINTER(c_int), # rbufsize ndpointer(dtype=nptype, ndim=1, flags=C), # sbuf - ndpointer(dtype=nptype, ndim=1, flags=C), # rbuf + NDPOINTER(dtype=nptype, ndim=1, flags=C), # rbuf NDPOINTER(dtype=np.int32, ndim=1, flags=C), # pelist + POINTER(c_int), # rbufsize POINTER(c_int), # npes ] @@ -79,14 +79,13 @@ def define(lib): for nptype, cFMS_gather in gatherdict.items(): cFMS_gather.restype = None cFMS_gather.argtypes = [ - POINTER(c_int), # npes POINTER(c_int), # sbuf_size - POINTER(c_int), # rbuf_size ndpointer(dtype=nptype, ndim=1, flags=C), # sbuf POINTER(c_int), # ssize - ndpointer(dtype=nptype, ndim=1, flags=C), # rbuf - ndpointer(dtype=np.int32, ndim=1, flags=C), # rsize + NDPOINTER(dtype=nptype, ndim=1, flags=C), # rbuf + NDPOINTER(dtype=np.int32, ndim=1, flags=C), # rsize NDPOINTER(dtype=np.int32, ndim=1, flags=C), # pelist + POINTER(c_int), # npes ] # cFMS_get_current_pelist diff --git a/pyfms/py_mpp/domain.py b/pyfms/py_mpp/domain.py index ec093f9..819dcc1 100644 --- a/pyfms/py_mpp/domain.py +++ b/pyfms/py_mpp/domain.py @@ -71,3 +71,28 @@ def update(self, domain_dict: dict): for key in domain_dict: setattr(self, key, domain_dict[key]) return self + + def __repr__(self): + + repr_str = f""" + domain_id: {self.domain_id}\n + ** compute domain ** + (isc, jsc): ({self.isc}, {self.jsc}) + (iec, jec): ({self.iec}, {self.jec}) + (xsize_c, ysize_c): ({self.xsize_c}, {self.ysize_c}) + (xmax_size_c, ymax_size_c): ({self.xmax_size_c}, {self.ymax_size_c}) + (x_is_global_c, y_is_global_c): ({self.x_is_global_c}, {self.y_is_global_c})\n + ** data domain ** + (isd, jsd) = ({self.isd}, {self.jsd}) + (ied, jed) = ({self.ied}, {self.jed}) + (xsize_d, ysize_d): ({self.xsize_d}, {self.ysize_d}) + (xmax_size_d, ymax_size_d): ({self.xmax_size_d}, {self.ymax_size_d}) + (x_is_global_d, y_is_global_d): ({self.x_is_global_d}, {self.y_is_global_d})\n + ** global domain ** + (isg, jsg) = ({self.isg}, {self.jsg}) + (ieg, jeg) = ({self.ieg}, {self.jeg}) + (xsize_g, ysize_g) = ({self.xsize_g}, {self.ysize_g}) + (x_is_global_g, y_is_global_g): ({self.x_is_global_g}, {self.y_is_global_g}) + """ + + return repr_str diff --git a/pyfms/py_mpp/mpp.py b/pyfms/py_mpp/mpp.py index 97ac130..564c999 100644 --- a/pyfms/py_mpp/mpp.py +++ b/pyfms/py_mpp/mpp.py @@ -41,18 +41,20 @@ def gather( sbuf: npt.NDArray, - ssize: int = None, # mpp_gatherv_1d argument - rsize: list[int] = None, # mpp_gatherv_1d argument + rbuf_size: int = None, # for 1d + rbuf_shape: list[int, int] = None, # for 2d domain: dict = None, # mpp_gather_2d argument pelist: list = None, + is_root_pe: bool = None, ishift: int = None, # mpp_gather_pelist_2d argument jshift: int = None, # mpp_gather_pelist_2d argument convert_cf_order: bool = True, -): +) -> npt.NDArray: datatype = sbuf.dtype - is_root_pe = pe() == root_pe() - (dim, do_vector) = (sbuf.ndim, False) if rsize is None else ("v", True) + if is_root_pe is None: + is_root_pe = pe() == root_pe() + dim = sbuf.ndim try: cFMS_gather = _cFMS_gathers[dim][datatype.name] @@ -61,56 +63,33 @@ def gather( arglist = [] - if do_vector: - - rsize = rsize if is_root_pe else [1] - rbuf_size = sum(rsize) - npes_here = len(rsize) - rbuf = np.zeros((rbuf_size), dtype=datatype) - - # The pelist does not matter for non-root pe's - # However, pelist is declared to be the size of rsize in cFMS - # for non root-pelist, len(rsize) = 1 so pelist has to be the len of [1] - if pelist is not None: - pelist = pelist[:npes_here] - - set_c_int(npes_here, arglist) - set_c_int(sbuf.shape[0], arglist) - set_c_int(rbuf_size, arglist) - set_array(sbuf, arglist) - set_c_int(ssize, arglist) - set_array(rbuf, arglist) - set_list(rsize, np.int32, arglist) - set_list(pelist, np.int32, arglist) + if dim == 1: - cFMS_gather(*arglist) if is_root_pe: - return rbuf - return None - - if dim == 1: + if rbuf_size is None: + raise RuntimeError("Must specify size of receiving array") + rbuf = np.zeros(rbuf_size, dtype=datatype) + else: + rbuf_size, rbuf = None, None sbuf_size = sbuf.shape[0] n_pes = None if pelist is None else len(pelist) - rbuf_size = sbuf_size * npes() - rbuf = np.zeros(rbuf_size, dtype=datatype) + set_c_int(sbuf_size, arglist) - set_c_int(rbuf_size, arglist) set_array(sbuf, arglist) set_array(rbuf, arglist) set_list(pelist, np.int32, arglist) + set_c_int(rbuf_size, arglist) set_c_int(n_pes, arglist) elif dim == 2: - nx = domain.xsize_g if is_root_pe else 1 - ny = domain.ysize_g if is_root_pe else 1 if is_root_pe: - rbuf_shape = (nx, ny) if convert_cf_order else (ny, nx) + if rbuf_shape is None: + raise RuntimeError("Must specify shape of receiving array") rbuf = np.zeros(rbuf_shape, dtype=datatype) else: - rbuf_shape = None - rbuf = None + rbuf_shape, rbuf = None, None pelist = get_current_pelist(npes()) if pelist is None else pelist @@ -121,9 +100,9 @@ def gather( set_c_int(len(pelist), arglist) set_list(pelist, np.int32, arglist) set_array(sbuf, arglist) - set_list(rbuf_shape, np.int32, arglist) set_array(rbuf, arglist) set_c_bool(is_root_pe, arglist) + set_list(rbuf_shape, np.int32, arglist) set_c_int(ishift, arglist) set_c_int(jshift, arglist) set_c_bool(convert_cf_order, arglist) @@ -135,6 +114,45 @@ def gather( return None +def gatherv( + sbuf: npt.NDArray, ssize: int, rsize: int = None, pelist: list[int] = None +) -> npt.NDArray: + + datatype = sbuf.dtype + + try: + cFMS_gather = _cFMS_gathers["v"][datatype.name] + except Exception: + error(FATAL, f"mpp.gather {datatype.name} not supported for gatherv") + + is_root_pe = pe() == root_pe() + + sbuf_size = sbuf.shape[0] + + if is_root_pe: + if rsize is None: + raise RuntimeError("must specify receiving sizes for root pe") + rbuf = np.zeros(np.sum(rsize), dtype=datatype) + npes = len(rsize) + else: + rbuf, rsize = None, None + npes = None if pelist is None else len(pelist) + + arglist = [] + set_c_int(sbuf_size, arglist) + set_array(sbuf, arglist) + set_c_int(ssize, arglist) + set_array(rbuf, arglist) + set_list(rsize, np.int32, arglist) + set_list(pelist, np.int32, arglist) + set_c_int(npes, arglist) + + cFMS_gather(*arglist) + if is_root_pe: + return rbuf + return None + + def declare_pelist( pelist: list[int], name: str = None, diff --git a/tests/py_mpp/test_gather.py b/tests/py_mpp/test_gather.py index 6382446..b3ff8cb 100644 --- a/tests/py_mpp/test_gather.py +++ b/tests/py_mpp/test_gather.py @@ -16,6 +16,8 @@ def test_gather_2d(): layout = pyfms.mpp_domains.define_layout(global_indices, pyfms.mpp.npes()) domain = pyfms.mpp_domains.define_domains(global_indices, layout) + is_root_pe = pyfms.mpp.pe() == pyfms.mpp.root_pe() + # data to send global_data = np.array( [[i * 100 + j for j in range(ny)] for i in range(nx)], dtype=np.float64 @@ -26,9 +28,20 @@ def test_gather_2d(): global_data = global_data.T send = send.T + rbuf_shape = None + if is_root_pe: + if convert: + rbuf_shape = [nx, ny] + else: + rbuf_shape = [ny, nx] + pelist = pyfms.mpp.get_current_pelist(pyfms.mpp.npes()) gathered = pyfms.mpp.gather( - send, domain=domain, pelist=pelist, convert_cf_order=convert + send, + rbuf_shape=rbuf_shape, + domain=domain, + pelist=pelist, + convert_cf_order=convert, ) if pyfms.mpp.pe() == pyfms.mpp.root_pe(): @@ -50,11 +63,18 @@ def buffer(ipe): pe = pyfms.mpp.pe() npes = pyfms.mpp.npes() + is_root_pe = pyfms.mpp.pe() == pyfms.mpp.root_pe() send = np.array(buffer(pe), dtype=np.float64) - receive = pyfms.mpp.gather(np.array(send)) - if pe == pyfms.mpp.root_pe(): + if is_root_pe: + rbuf_size = sbuf_size * npes + else: + rbuf_size = None + + receive = pyfms.mpp.gather(np.array(send), rbuf_size=rbuf_size) + + if is_root_pe: answers = [] for ipe in range(npes): answers += buffer(ipe) @@ -71,13 +91,18 @@ def buffer(ipe): pyfms.fms.init() pe = pyfms.mpp.pe() + is_root_pe = pe == pyfms.mpp.root_pe() sbuf = np.array(buffer(pe), dtype=np.float64) - rsize = [ipe + 2 for ipe in range(pyfms.mpp.npes())] - receive = pyfms.mpp.gather(sbuf, ssize=pe + 2, rsize=rsize) + if is_root_pe: + rsize = [ipe + 2 for ipe in range(pyfms.mpp.npes())] + else: + rsize = None + + receive = pyfms.mpp.gatherv(sbuf, ssize=pe + 2, rsize=rsize) - if pe == pyfms.mpp.root_pe(): + if is_root_pe: answers = [] for ipe in range(pyfms.mpp.npes()): answers += buffer(ipe)