diff --git a/driver/pace/driver/driver.py b/driver/pace/driver/driver.py index 7cd22dc5c..b4ad261af 100644 --- a/driver/pace/driver/driver.py +++ b/driver/pace/driver/driver.py @@ -280,7 +280,6 @@ def exit_function(*args, **kwargs): damping_coefficients=self.state.damping_coefficients, config=self.config.dycore_config, phis=self.state.dycore_state.phis, - state=self.state.dycore_state, ) self.dycore.update_state( @@ -313,11 +312,8 @@ def exit_function(*args, **kwargs): namelist=self.config.physics_config, comm=communicator, grid_info=self.state.driver_grid_data, - state=self.state.dycore_state, - quantity_factory=self.quantity_factory, dycore_only=self.config.dycore_only, apply_tendencies=self.config.apply_tendencies, - tendency_state=self.state.tendency_state, ) else: # Make sure those are set to None to raise any issues @@ -439,6 +435,7 @@ def _step_dynamics( ): self.dycore.step_dynamics( state=state, + tracers_dict=state.tracers_as_array(), timer=timer, ) @@ -454,9 +451,9 @@ def _step_physics(self, timestep: float): self.end_of_step_update( dycore_state=self.state.dycore_state, phy_state=self.state.physics_state, - u_dt=self.state.tendency_state.u_dt.storage, - v_dt=self.state.tendency_state.v_dt.storage, - pt_dt=self.state.tendency_state.pt_dt.storage, + u_dt=self.state.tendency_state.u_dt, + v_dt=self.state.tendency_state.v_dt, + pt_dt=self.state.tendency_state.pt_dt, dt=float(timestep), ) diff --git a/dsl/pace/dsl/dace/wrapped_halo_exchange.py b/dsl/pace/dsl/dace/wrapped_halo_exchange.py index ad88fb118..f094483c3 100644 --- a/dsl/pace/dsl/dace/wrapped_halo_exchange.py +++ b/dsl/pace/dsl/dace/wrapped_halo_exchange.py @@ -1,9 +1,9 @@ -import dataclasses -from typing import List, Optional +from typing import List, Optional, Union + +import numpy as np from pace.dsl.dace.orchestration import dace_inhibitor -from pace.util.communicator import CubedSphereCommunicator -from pace.util.halo_updater import HaloUpdater +from pace.util.halo_updater import HaloUpdater, VectorInterfaceHaloUpdater class WrappedHaloUpdater: @@ -17,57 +17,30 @@ class WrappedHaloUpdater: def __init__( self, - updater: HaloUpdater, - state, - qty_x_names: List[str], - qty_y_names: List[str] = None, - comm: Optional[CubedSphereCommunicator] = None, + updater: Union[HaloUpdater, VectorInterfaceHaloUpdater], ) -> None: self._updater = updater - self._state = state - self._qtx_x_names = qty_x_names - self._qtx_y_names = qty_y_names - self._comm = comm @dace_inhibitor - def start(self): - if self._qtx_y_names is None: - if dataclasses.is_dataclass(self._state): - self._updater.start( - [self._state.__getattribute__(x) for x in self._qtx_x_names] - ) - elif isinstance(self._state, dict): - self._updater.start([self._state[x] for x in self._qtx_x_names]) - else: - raise NotImplementedError - else: - if dataclasses.is_dataclass(self._state): - self._updater.start( - [self._state.__getattribute__(x) for x in self._qtx_x_names], - [self._state.__getattribute__(y) for y in self._qtx_y_names], - ) - elif isinstance(self._state, dict): - self._updater.start( - [self._state[x] for x in self._qtx_x_names], - [self._state[y] for y in self._qtx_y_names], - ) - else: - raise NotImplementedError + def start( + self, arrays_x: List[np.ndarray], arrays_y: Optional[List[np.ndarray]] = None + ): + assert isinstance(self._updater, HaloUpdater) + self._updater.start(arrays_x, arrays_y) @dace_inhibitor def wait(self): self._updater.wait() @dace_inhibitor - def update(self): - self.start() + def update( + self, arrays_x: List[np.ndarray], arrays_y: Optional[List[np.ndarray]] = None + ): + self.start(arrays_x, arrays_y) self.wait() @dace_inhibitor - def interface(self): - assert len(self._qtx_x_names) == 1 - assert len(self._qtx_y_names) == 1 - self._comm.synchronize_vector_interfaces( - self._state.__getattribute__(self._qtx_x_names[0]), - self._state.__getattribute__(self._qtx_y_names[0]), - ) + def interface(self, arrays_x: np.ndarray, arrays_y: np.ndarray): + assert isinstance(self._updater, VectorInterfaceHaloUpdater) + request = self._updater.start_synchronize_vector_interfaces(arrays_x, arrays_y) + request.wait() diff --git a/dsl/pace/dsl/gt4py_utils.py b/dsl/pace/dsl/gt4py_utils.py index f031a6fcb..6d98edb7f 100644 --- a/dsl/pace/dsl/gt4py_utils.py +++ b/dsl/pace/dsl/gt4py_utils.py @@ -21,6 +21,12 @@ halo = 3 origin = (halo, halo, 0) +# nq is actually given by ncnst - pnats, where those are given in atmosphere.F90 by: +# ncnst = Atm(mytile)%ncnst +# pnats = Atm(mytile)%flagstruct%pnats +# here we hard-coded it because 8 is the only supported value, refactor this later! +NQ = 8 # state.nq_tot - spec.namelist.dnats + # TODO get from field_table tracer_variables = [ "qvapor", diff --git a/examples/notebooks/functions.py b/examples/notebooks/functions.py index 756cf4d31..7b5e42700 100644 --- a/examples/notebooks/functions.py +++ b/examples/notebooks/functions.py @@ -917,9 +917,7 @@ def run_finite_volume_fluxprep( return flux_prep -def build_tracer_advection( - stencil_configuration: Dict[str, Any], tracers: Dict[str, Quantity] -) -> TracerAdvection: +def build_tracer_advection(stencil_configuration: Dict[str, Any]) -> TracerAdvection: """ Use: tracer_advection = build_tracer_advection(stencil_configuration, tracers) @@ -949,7 +947,6 @@ def build_tracer_advection( fvtp_2d, stencil_configuration["grid_data"], stencil_configuration["communicator"], - tracers, ) return tracer_advection @@ -993,7 +990,7 @@ def prepare_everything_for_advection( timestep, ) - tracer_advection = build_tracer_advection(stencil_configuration, tracers) + tracer_advection = build_tracer_advection(stencil_configuration) tracer_advection_data = { "tracers": tracers, diff --git a/fv3core/examples/standalone/runfile/acoustics.py b/fv3core/examples/standalone/runfile/acoustics.py index 360cfc9b0..a82018890 100755 --- a/fv3core/examples/standalone/runfile/acoustics.py +++ b/fv3core/examples/standalone/runfile/acoustics.py @@ -180,7 +180,6 @@ def driver( dycore_config.acoustic_dynamics, input_data["pfull"], input_data["phis"], - state, ) # warm-up timestep. diff --git a/fv3core/examples/standalone/runfile/dynamics.py b/fv3core/examples/standalone/runfile/dynamics.py index 0ca8bc4bd..1d1353056 100755 --- a/fv3core/examples/standalone/runfile/dynamics.py +++ b/fv3core/examples/standalone/runfile/dynamics.py @@ -263,7 +263,6 @@ def setup_dycore( damping_coefficients=DampingCoefficients.new_from_metric_terms(metric_terms), config=dycore_config, phis=state.phis, - state=state, ) dycore.update_state( conserve_total_energy=dycore_config.consv_te, @@ -310,7 +309,7 @@ def setup_dycore( # warmup/compilation from the internal timers if rank == 0: print("timestep 1") - dycore.step_dynamics(state, timer) + dycore.step_dynamics(state, state.tracers_as_array(), timer) if profiler is not None: profiler.enable() @@ -324,7 +323,7 @@ def setup_dycore( with timestep_timer.clock("mainloop"): if rank == 0: print(f"timestep {i+2}") - dycore.step_dynamics(state, timer=timestep_timer) + dycore.step_dynamics(state, state.tracers_as_array(), timer=timestep_timer) times_per_step.append(timestep_timer.times) hits_per_step.append(timestep_timer.hits) timestep_timer.reset() diff --git a/fv3core/fv3core/initialization/dycore_state.py b/fv3core/fv3core/initialization/dycore_state.py index 4baf2f90a..b846864ac 100644 --- a/fv3core/fv3core/initialization/dycore_state.py +++ b/fv3core/fv3core/initialization/dycore_state.py @@ -1,6 +1,7 @@ from dataclasses import dataclass, field, fields -from typing import Any, Mapping +from typing import Any, Dict, List, Mapping +import numpy as np import xarray as xr import pace.dsl.gt4py_utils as gt_utils @@ -372,5 +373,21 @@ def xr_dataset(self): ) return xr.Dataset(data_vars=data_vars) + @property + def tracers(self) -> List[pace.util.Quantity]: + return [self.__getattribute__(x) for x in DycoreState.tracer_names()] + + def tracers_as_array(self) -> Dict[str, np.ndarray]: + all_tracers = { + name: self.__getattribute__(name).data + for name in DycoreState.tracer_names() + } + all_tracers.pop("qcld") + return all_tracers + + @classmethod + def tracer_names(cls) -> List[str]: + return gt_utils.tracer_variables + def __getitem__(self, item): return getattr(self, item) diff --git a/fv3core/fv3core/stencils/dyn_core.py b/fv3core/fv3core/stencils/dyn_core.py index dd87caa6e..d33f973a7 100644 --- a/fv3core/fv3core/stencils/dyn_core.py +++ b/fv3core/fv3core/stencils/dyn_core.py @@ -245,7 +245,6 @@ def __init__( comm: pace.util.CubedSphereCommunicator, grid_indexing: GridIndexing, backend: str, - state, ): origin = grid_indexing.origin_compute() shape = grid_indexing.max_shape @@ -296,51 +295,32 @@ def __init__( # quantities at runtime paradigm self.q_con__cappa = WrappedHaloUpdater( comm.get_scalar_halo_updater([full_size_xyz_halo_spec] * 2), - state, - ["q_con", "cappa"], ) self.delp__pt = WrappedHaloUpdater( comm.get_scalar_halo_updater([full_size_xyz_halo_spec] * 2), - state, - ["delp", "pt"], ) self.u__v = WrappedHaloUpdater( comm.get_vector_halo_updater( [full_size_xyiz_halo_spec], [full_size_xiyz_halo_spec] ), - state, - ["u"], - ["v"], ) self.w = WrappedHaloUpdater( comm.get_scalar_halo_updater([full_size_xyz_halo_spec]), - state, - ["w"], ) self.gz = WrappedHaloUpdater( comm.get_scalar_halo_updater([full_size_xyzi_halo_spec]), - state, - ["gz"], ) self.delp__pt__q_con = WrappedHaloUpdater( comm.get_scalar_halo_updater([full_size_xyz_halo_spec] * 3), - state, - ["delp", "pt", "q_con"], ) self.zh = WrappedHaloUpdater( comm.get_scalar_halo_updater([full_size_xyzi_halo_spec]), - state, - ["zh"], ) self.divgd = WrappedHaloUpdater( comm.get_scalar_halo_updater([full_size_xiyiz_halo_spec]), - state, - ["divgd"], ) self.heat_source = WrappedHaloUpdater( comm.get_scalar_halo_updater([full_size_xyz_halo_spec]), - state, - ["heat_source"], ) if grid_indexing.domain[0] == grid_indexing.domain[1]: full_3Dfield_2pts_halo_spec = grid_indexing.get_quantity_halo_spec( @@ -352,21 +332,20 @@ def __init__( ) self.pkc = WrappedHaloUpdater( comm.get_scalar_halo_updater([full_3Dfield_2pts_halo_spec]), - state, - ["pkc"], ) else: self.pkc = comm.get_scalar_halo_updater([full_size_xyzi_halo_spec]) self.uc__vc = WrappedHaloUpdater( comm.get_vector_halo_updater( - [full_size_xiyz_halo_spec], [full_size_xyiz_halo_spec] + [full_size_xiyz_halo_spec], + [full_size_xyiz_halo_spec], ), - state, - ["uc"], - ["vc"], ) self.interface_uc__vc = WrappedHaloUpdater( - None, state, ["u"], ["v"], comm=comm + comm.get_vector_interface_halo_updater( + full_size_xyiz_halo_spec, + full_size_xiyz_halo_spec, + ) ) def __init__( @@ -381,7 +360,6 @@ def __init__( config: AcousticDynamicsConfig, pfull: FloatFieldK, phis: FloatFieldIJ, - state, # [DaCe] hack to get around quantity as parameters for halo updates checkpointer: Optional[pace.util.Checkpointer] = None, ): """ @@ -621,7 +599,7 @@ def __init__( # Halo updaters self._halo_updaters = AcousticDynamics._HaloUpdaters( - comm, grid_indexing, stencil_factory.backend, state + comm, grid_indexing, stencil_factory.backend ) def _checkpoint_csw(self, state, tag: str): @@ -704,9 +682,9 @@ def __call__( # m_split = 1. + abs(dt_atmos)/real(k_split*n_split*abs(p_split)) # n_split = nint( real(n0split)/real(k_split*abs(p_split)) * stretch_fac + 0.5 ) # NOTE: In Fortran model the halo update starts happens in fv_dynamics, not here - self._halo_updaters.q_con__cappa.start() - self._halo_updaters.delp__pt.start() - self._halo_updaters.u__v.start() + self._halo_updaters.q_con__cappa.start([state.q_con.data, state.cappa.data]) + self._halo_updaters.delp__pt.start([state.delp.data, state.pt.data]) + self._halo_updaters.u__v.start([state.u.data], [state.v.data]) self._halo_updaters.q_con__cappa.wait() if update_temporaries: @@ -741,14 +719,14 @@ def __call__( if self.config.breed_vortex_inline or (it == n_split - 1): remap_step = True if not self.config.hydrostatic: - self._halo_updaters.w.start() + self._halo_updaters.w.start([state.w.data]) if it == 0: self._set_gz( self._zs, state.delz, state.gz, ) - self._halo_updaters.gz.start() + self._halo_updaters.gz.start([state.gz.data]) if it == 0: self._halo_updaters.delp__pt.wait() @@ -785,7 +763,7 @@ def __call__( self._checkpoint_csw(state, tag="Out") if self.config.nord > 0: - self._halo_updaters.divgd.start() + self._halo_updaters.divgd.start([state.divgd.data]) if not self.config.hydrostatic: if it == 0: self._halo_updaters.gz.wait() @@ -826,7 +804,7 @@ def __call__( state.gz, dt2, ) - self._halo_updaters.uc__vc.start() + self._halo_updaters.uc__vc.start([state.uc.data], [state.vc.data]) if self.config.nord > 0: self._halo_updaters.divgd.wait() self._halo_updaters.uc__vc.wait() @@ -863,7 +841,9 @@ def __call__( # note that uc and vc are not needed at all past this point. # they will be re-computed from scratch on the next acoustic timestep. - self._halo_updaters.delp__pt__q_con.update() + self._halo_updaters.delp__pt__q_con.update( + [state.delp.data, state.pt.data, state.q_con.data] + ) # Not used unless we implement other betas and alternatives to nh_p_grad # if self.namelist.d_ext > 0: @@ -901,8 +881,8 @@ def __call__( state.w, ) - self._halo_updaters.zh.start() - self._halo_updaters.pkc.start() + self._halo_updaters.zh.start([state.zh.data]) + self._halo_updaters.pkc.start([state.pkc.data]) if remap_step: self._edge_pe_stencil(state.pe, state.delp, self._ptop) if self.config.use_logp: @@ -949,13 +929,15 @@ def __call__( # [DaCe] this should be a reuse of # self._halo_updaters.u__v but it creates # parameter generation issues, and therefore has been duplicated - self._halo_updaters.u__v.start() + self._halo_updaters.u__v.start([state.u.data], [state.v.data]) else: if self.config.grid_type < 4: - self._halo_updaters.interface_uc__vc.interface() + self._halo_updaters.interface_uc__vc.interface( + state.u.data, state.v.data + ) if self._do_del2cubed: - self._halo_updaters.heat_source.update() + self._halo_updaters.heat_source.update([state.heat_source.data]) # TODO: move dependence on da_min into init of hyperdiffusion class cd = constants.CNST_0P20 * self._da_min self._hyperdiffusion(state.heat_source, cd) diff --git a/fv3core/fv3core/stencils/fillz.py b/fv3core/fv3core/stencils/fillz.py index c2d64ad98..851aeab42 100644 --- a/fv3core/fv3core/stencils/fillz.py +++ b/fv3core/fv3core/stencils/fillz.py @@ -1,13 +1,13 @@ import typing from typing import Dict +import numpy as np from gt4py.gtscript import BACKWARD, FORWARD, PARALLEL, computation, interval import pace.dsl.gt4py_utils as utils from pace.dsl.dace import orchestrate from pace.dsl.stencil import StencilFactory from pace.dsl.typing import FloatField, FloatFieldIJ, IntFieldIJ -from pace.util import Quantity @typing.no_type_check @@ -125,7 +125,6 @@ def __init__( jm: int, km: int, nq: int, - tracers: Dict[str, Quantity], ): orchestrate( obj=self, @@ -155,21 +154,17 @@ def make_storage(*args, **kwargs): self._sum0 = make_storage(shape_ij, origin=(0, 0)) self._sum1 = make_storage(shape_ij, origin=(0, 0)) - self._filtered_tracer_dict = { - name: tracers[name] for name in utils.tracer_variables[0 : self._nq] - } - def __call__( self, dp2: FloatField, - tracers: Dict[str, Quantity], + tracers: Dict[str, np.ndarray], ): """ Args: dp2 (in): pressure thickness of atmospheric layer tracers (inout): tracers to fix negative masses in """ - for tracer_name in self._filtered_tracer_dict.keys(): + for tracer_name in utils.tracer_variables[0 : self._nq]: self._fix_tracer_stencil( tracers[tracer_name], dp2, diff --git a/fv3core/fv3core/stencils/fv_dynamics.py b/fv3core/fv3core/stencils/fv_dynamics.py index 7fc89aa5f..45a5ea9a6 100644 --- a/fv3core/fv3core/stencils/fv_dynamics.py +++ b/fv3core/fv3core/stencils/fv_dynamics.py @@ -1,5 +1,6 @@ from typing import Dict, Optional +import numpy as np from dace.frontend.python.interface import nounroll as dace_no_unroll from gt4py.gtscript import PARALLEL, computation, interval, log @@ -23,14 +24,6 @@ from pace.util import Timer from pace.util.grid import DampingCoefficients, GridData from pace.util.mpi import MPI -from pace.util.quantity import Quantity - - -# nq is actually given by ncnst - pnats, where those are given in atmosphere.F90 by: -# ncnst = Atm(mytile)%ncnst -# pnats = Atm(mytile)%flagstruct%pnats -# here we hard-coded it because 8 is the only supported value, refactor this later! -NQ = 8 # state.nq_tot - spec.namelist.dnats def pt_adjust(pkz: FloatField, dp1: FloatField, q_con: FloatField, pt: FloatField): @@ -105,7 +98,6 @@ def __init__( damping_coefficients: DampingCoefficients, config: DynamicalCoreConfig, phis: pace.util.Quantity, - state: DycoreState, checkpointer: Optional[pace.util.Checkpointer] = None, ): """ @@ -210,19 +202,14 @@ def __init__( hord=config.hord_tr, ) - self.tracers = {} - for name in utils.tracer_variables[0:NQ]: - self.tracers[name] = state.__dict__[name] - self.tracer_storages = { - name: quantity.storage for name, quantity in self.tracers.items() - } - self._temporaries = fvdyn_temporaries(quantity_factory) - state.__dict__.update(self._temporaries) # Build advection stencils self.tracer_advection = tracer_2d_1l.TracerAdvection( - stencil_factory, tracer_transport, self.grid_data, comm, self.tracers + stencil_factory, + tracer_transport, + self.grid_data, + comm, ) self._ak = grid_data.ak self._bk = grid_data.bk @@ -274,7 +261,6 @@ def __init__( self.config.acoustic_dynamics, self._pfull, self._phis, - state, checkpointer=checkpointer, ) self._hyperdiffusion = HyperdiffusionDamping( @@ -284,7 +270,7 @@ def __init__( self.config.nf_omega, ) self._cubed_to_latlon = CubedToLatLon( - state, stencil_factory, grid_data, config.c2l_ord, comm + stencil_factory, grid_data, config.c2l_ord, comm ) self._temporaries = fvdyn_temporaries(quantity_factory) @@ -293,7 +279,7 @@ def __init__( # if self._temporaries were a dataclass we can remove this for name, value in self._temporaries.items(): setattr(self, f"_tmp_{name}", value) - if not (not self.config.inline_q and NQ != 0): + if not (not self.config.inline_q and utils.NQ != 0): raise NotImplementedError("tracer_2d not implemented, turn on z_tracer") self._adjust_tracer_mixing_ratio = AdjustNegativeTracerMixingRatio( stencil_factory, @@ -305,9 +291,8 @@ def __init__( stencil_factory, config.remapping, grid_data.area_64, - NQ, + utils.NQ, self._pfull, - tracers=self.tracers, ) full_xyz_spec = grid_indexing.get_quantity_halo_spec( @@ -318,7 +303,7 @@ def __init__( backend=stencil_factory.backend, ) self._omega_halo_updater = WrappedHaloUpdater( - comm.get_scalar_halo_updater([full_xyz_spec]), state, ["omga"], comm=comm + comm.get_scalar_halo_updater([full_xyz_spec]) ) def _checkpoint_fvdynamics(self, state: DycoreState, tag: str): @@ -369,7 +354,12 @@ def update_state( state.__dict__.update(self._temporaries) state.__dict__.update(self.acoustic_dynamics._temporaries) - def step_dynamics(self, state: DycoreState, timer: Timer = pace.util.NullTimer()): + def step_dynamics( + self, + state: DycoreState, + tracers_dict: Dict[str, np.ndarray], + timer: Timer = pace.util.NullTimer(), + ): """ Step the model state forward by one timestep. @@ -378,7 +368,7 @@ def step_dynamics(self, state: DycoreState, timer: Timer = pace.util.NullTimer() state: model prognostic state and inputs """ self._checkpoint_fvdynamics(state=state, tag="In") - self._compute(state, timer) + self._compute(state, tracers_dict, timer) self._checkpoint_fvdynamics(state=state, tag="Out") def compute_preamble(self, state, is_root_rank: bool): @@ -430,17 +420,22 @@ def compute_preamble(self, state, is_root_rank: bool): def __call__(self, *args, **kwargs): return self.step_dynamics(*args, **kwargs) - def _compute(self, state, timer: pace.util.Timer): + def _compute( + self, + state: DycoreState, + tracers_dict: Dict[str, np.ndarray], + timer: pace.util.Timer, + ): last_step = False self.compute_preamble( state, is_root_rank=self.comm_rank == 0, ) - for k_split in dace_no_unroll(range(state.k_split)): + for k_split in dace_no_unroll(range(state.k_split)): # type: ignore n_map = k_split + 1 - last_step = k_split == state.k_split - 1 - self._dyn(state=state, tracers=self.tracers, n_map=n_map, timer=timer) + last_step = k_split == state.k_split - 1 # type: ignore + self._dyn(state=state, tracers=tracers_dict, n_map=n_map, timer=timer) if self.grid_indexing.domain[2] > 4: # nq is actually given by ncnst - pnats, @@ -458,7 +453,7 @@ def _compute(self, state, timer: pace.util.Timer): log_on_rank_0("Remapping") with timer.clock("Remapping"): self._lagrangian_to_eulerian_obj( - self.tracer_storages, + tracers_dict, state.pt, state.delp, state.delz, @@ -468,27 +463,27 @@ def _compute(self, state, timer: pace.util.Timer): state.w, state.ua, state.va, - state.cappa, + state.cappa, # type: ignore state.q_con, state.qcld, state.pkz, state.pk, state.pe, state.phis, - state.te0_2d, + state.te0_2d, # type: ignore state.ps, - state.wsd, + state.wsd, # type: ignore state.omga, self._ak, self._bk, self._pfull, - state.dp1, + state.dp1, # type: ignore self._ptop, constants.KAPPA, constants.ZVIR, last_step, - state.consv_te, - state.bdt / state.k_split, + state.consv_te, # type: ignore + state.bdt / state.k_split, # type: ignore state.bdt, state.do_adiabatic_init, ) @@ -506,7 +501,7 @@ def _compute(self, state, timer: pace.util.Timer): def _dyn( self, state, - tracers: Dict[str, Quantity], + tracers: Dict[str, np.ndarray], n_map, timer: pace.util.Timer, ): @@ -552,9 +547,8 @@ def post_remap( state.omga, ) if self.config.nf_omega > 0: - if __debug__: - log_on_rank_0("Del2Cubed") - self._omega_halo_updater.update() + log_on_rank_0("Del2Cubed") + self._omega_halo_updater.update([state.omga.data]) self._hyperdiffusion(state.omga, 0.18 * da_min) def wrapup( diff --git a/fv3core/fv3core/stencils/mapn_tracer.py b/fv3core/fv3core/stencils/mapn_tracer.py index 880633c7b..fa86a1d4a 100644 --- a/fv3core/fv3core/stencils/mapn_tracer.py +++ b/fv3core/fv3core/stencils/mapn_tracer.py @@ -1,12 +1,13 @@ from typing import Dict +import numpy as np + import pace.dsl.gt4py_utils as utils from fv3core.stencils.fillz import FillNegativeTracerValues from fv3core.stencils.map_single import MapSingle from pace.dsl.dace.orchestration import orchestrate from pace.dsl.stencil import StencilFactory from pace.dsl.typing import FloatField -from pace.util import Quantity class MapNTracer: @@ -24,7 +25,6 @@ def __init__( j1: int, j2: int, fill: bool, - tracers: Dict[str, Quantity], ): orchestrate( obj=self, @@ -62,7 +62,6 @@ def __init__( self._list_of_remap_objects[0].j_extent, self._nk, self._nq, - tracers, ) else: self._fill_negative_tracers = False @@ -72,7 +71,7 @@ def __call__( pe1: FloatField, pe2: FloatField, dp2: FloatField, - tracers: Dict[str, Quantity], + tracers: Dict[str, np.ndarray], ): """ Remaps the tracer species onto the Eulerian grid diff --git a/fv3core/fv3core/stencils/remapping.py b/fv3core/fv3core/stencils/remapping.py index c081f1a56..1ab349a9b 100644 --- a/fv3core/fv3core/stencils/remapping.py +++ b/fv3core/fv3core/stencils/remapping.py @@ -1,5 +1,6 @@ from typing import Dict +import numpy as np from gt4py.gtscript import ( __INLINED, BACKWARD, @@ -24,7 +25,6 @@ from pace.dsl.dace.orchestration import orchestrate from pace.dsl.stencil import StencilFactory from pace.dsl.typing import FloatField, FloatFieldIJ, FloatFieldK -from pace.util import Quantity # TODO: Should this be set here or in global_constants? @@ -285,7 +285,6 @@ def __init__( area_64, nq, pfull, - tracers: Dict[str, Quantity], ): orchestrate( obj=self, @@ -372,7 +371,6 @@ def __init__( grid_indexing.jsc, grid_indexing.jec, fill=config.fill, - tracers=tracers, ) self._map_single_w = MapSingle( @@ -486,7 +484,7 @@ def __init__( def __call__( self, - tracers: Dict[str, Quantity], + tracers: Dict[str, np.ndarray], pt: FloatField, delp: FloatField, delz: FloatField, diff --git a/fv3core/fv3core/stencils/tracer_2d_1l.py b/fv3core/fv3core/stencils/tracer_2d_1l.py index b86ce1db8..a3fde159a 100644 --- a/fv3core/fv3core/stencils/tracer_2d_1l.py +++ b/fv3core/fv3core/stencils/tracer_2d_1l.py @@ -2,6 +2,7 @@ from typing import Dict import gt4py.gtscript as gtscript +import numpy as np from gt4py.gtscript import PARALLEL, computation, horizontal, interval, region import pace.dsl.gt4py_utils as utils @@ -11,7 +12,6 @@ from pace.dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater from pace.dsl.stencil import StencilFactory from pace.dsl.typing import FloatField, FloatFieldIJ -from pace.util import Quantity @gtscript.function @@ -182,7 +182,6 @@ def __init__( transport: FiniteVolumeTransport, grid_data, comm: pace.util.CubedSphereCommunicator, - tracers: Dict[str, Quantity], ): orchestrate( obj=self, @@ -191,7 +190,6 @@ def __init__( ) grid_indexing = stencil_factory.grid_indexing self.grid_indexing = grid_indexing # needed for selective validation - self._tracer_count = len(tracers) self.grid_data = grid_data shape = grid_indexing.domain_full(add=(1, 1, 1)) origin = grid_indexing.origin_compute() @@ -271,12 +269,10 @@ def make_storage(): backend=stencil_factory.backend, ) self._tracers_halo_updater = WrappedHaloUpdater( - comm.get_scalar_halo_updater([tracer_halo_spec] * self._tracer_count), - tracers, - [t for t in tracers.keys()], + comm.get_scalar_halo_updater([tracer_halo_spec] * utils.NQ), ) - def __call__(self, tracers: Dict[str, Quantity], dp1, mfxd, mfyd, cxd, cyd, mdt): + def __call__(self, tracers: Dict[str, np.ndarray], dp1, mfxd, mfyd, cxd, cyd, mdt): """ Args: tracers (inout): @@ -352,7 +348,7 @@ def __call__(self, tracers: Dict[str, Quantity], dp1, mfxd, mfyd, cxd, cyd, mdt) n_split, ) - self._tracers_halo_updater.update() + self._tracers_halo_updater.update(tracers.values()) dp2 = self._tmp_dp @@ -386,6 +382,6 @@ def __call__(self, tracers: Dict[str, Quantity], dp1, mfxd, mfyd, cxd, cyd, mdt) dp2, ) if not last_call: - self._tracers_halo_updater.update() + self._tracers_halo_updater.update(tracers.values()) # use variable assignment to avoid a data copy self._swap_dp(dp1, dp2) diff --git a/fv3core/fv3core/testing/translate_dyncore.py b/fv3core/fv3core/testing/translate_dyncore.py index 27f18b18c..374dc4389 100644 --- a/fv3core/fv3core/testing/translate_dyncore.py +++ b/fv3core/fv3core/testing/translate_dyncore.py @@ -171,7 +171,6 @@ def compute_parallel(self, inputs, communicator): config=DynamicalCoreConfig.from_namelist(self.namelist).acoustic_dynamics, pfull=inputs["pfull"], phis=inputs["phis"], - state=state, ) state.__dict__.update(acoustic_dynamics._temporaries) acoustic_dynamics(state, n_map=state.n_map, update_temporaries=False) diff --git a/fv3core/fv3core/testing/translate_fvdynamics.py b/fv3core/fv3core/testing/translate_fvdynamics.py index 97fa1ece0..42e1facc9 100644 --- a/fv3core/fv3core/testing/translate_fvdynamics.py +++ b/fv3core/fv3core/testing/translate_fvdynamics.py @@ -14,7 +14,7 @@ from pace.util.grid import GridData -ADVECTED_TRACER_NAMES = utils.tracer_variables[: fv_dynamics.NQ] +ADVECTED_TRACER_NAMES = utils.tracer_variables[: utils.NQ] class TranslateDycoreFortranData2Py(TranslateFortranData2Py): @@ -335,7 +335,6 @@ def compute_parallel(self, inputs, communicator): damping_coefficients=self.grid.damping_coefficients, config=DynamicalCoreConfig.from_namelist(self.namelist), phis=state.phis, - state=state, ) self.dycore.update_state( self.namelist.consv_te, @@ -344,7 +343,9 @@ def compute_parallel(self, inputs, communicator): self.namelist.n_split, state, ) - self.dycore.step_dynamics(state, pace.util.NullTimer()) + self.dycore.step_dynamics( + state, state.tracers_as_array(), pace.util.NullTimer() + ) outputs = self.outputs_from_state(state) for name, value in outputs.items(): outputs[name] = self.subset_output(name, value) diff --git a/fv3core/tests/mpi/test_doubly_periodic.py b/fv3core/tests/mpi/test_doubly_periodic.py index 3949e3611..98a1a0a0b 100644 --- a/fv3core/tests/mpi/test_doubly_periodic.py +++ b/fv3core/tests/mpi/test_doubly_periodic.py @@ -109,18 +109,22 @@ def setup_dycore() -> Tuple[fv3core.DynamicalCore, List[Any]]: damping_coefficients=DampingCoefficients.new_from_metric_terms(metric_terms), config=config, phis=state.phis, - state=state, ) do_adiabatic_init = False # TODO compute from namelist bdt = config.dt_atmos - args = [ - state, + dycore.update_state( config.consv_te, do_adiabatic_init, bdt, config.n_split, + state, + ) + + args = [ + state, + state.tracers_as_array(), ] return dycore, args diff --git a/fv3core/tests/savepoint/translate/translate_cubedtolatlon.py b/fv3core/tests/savepoint/translate/translate_cubedtolatlon.py index c25167255..baf960e8a 100644 --- a/fv3core/tests/savepoint/translate/translate_cubedtolatlon.py +++ b/fv3core/tests/savepoint/translate/translate_cubedtolatlon.py @@ -47,7 +47,6 @@ def compute_parallel(self, inputs, communicator): state_dict = {"u": u_quantity, "v": v_quantity} self._cubed_to_latlon = CubedToLatLon( - state=state_dict, stencil_factory=self.stencil_factory, grid_data=self.grid.grid_data, order=self.namelist.c2l_ord, diff --git a/fv3core/tests/savepoint/translate/translate_fillz.py b/fv3core/tests/savepoint/translate/translate_fillz.py index 3093ea6a4..590f0d263 100644 --- a/fv3core/tests/savepoint/translate/translate_fillz.py +++ b/fv3core/tests/savepoint/translate/translate_fillz.py @@ -74,7 +74,6 @@ def compute(self, inputs): inputs.pop("jm"), inputs.pop("km"), inputs.pop("nq"), - inputs["tracers"], ) run_fillz(**inputs) ds = self.grid.default_domain_dict() diff --git a/fv3core/tests/savepoint/translate/translate_mapn_tracer_2d.py b/fv3core/tests/savepoint/translate/translate_mapn_tracer_2d.py index 27d885d4e..5f51257a2 100644 --- a/fv3core/tests/savepoint/translate/translate_mapn_tracer_2d.py +++ b/fv3core/tests/savepoint/translate/translate_mapn_tracer_2d.py @@ -66,7 +66,6 @@ def compute(self, inputs): inputs.pop("j1"), inputs.pop("j2"), fill=self.namelist.fill, - tracers=inputs["tracers"], ) self.compute_func(**inputs) return self.slice_output(inputs) diff --git a/fv3core/tests/savepoint/translate/translate_remapping.py b/fv3core/tests/savepoint/translate/translate_remapping.py index 01d3bce26..7a70655f7 100644 --- a/fv3core/tests/savepoint/translate/translate_remapping.py +++ b/fv3core/tests/savepoint/translate/translate_remapping.py @@ -130,7 +130,6 @@ def compute_from_storage(self, inputs): self.grid.area_64, inputs["nq"], inputs["pfull"], - inputs["tracers"], ) inputs.pop("nq") l_to_e_obj(**inputs) diff --git a/fv3core/tests/savepoint/translate/translate_tracer2d1l.py b/fv3core/tests/savepoint/translate/translate_tracer2d1l.py index b7856974f..9fc9d4bb5 100644 --- a/fv3core/tests/savepoint/translate/translate_tracer2d1l.py +++ b/fv3core/tests/savepoint/translate/translate_tracer2d1l.py @@ -51,8 +51,9 @@ def compute_parallel(self, inputs, communicator): self._base.make_storage_data_input_vars(inputs) all_tracers = inputs["tracers"] + tracer_count = int(inputs.pop("nq")) inputs["tracers"] = self.get_advected_tracer_dict( - inputs["tracers"], int(inputs.pop("nq")) + inputs["tracers"], tracer_count ) transport = fv3core.stencils.fvtp2d.FiniteVolumeTransport( stencil_factory=self.stencil_factory, @@ -67,7 +68,6 @@ def compute_parallel(self, inputs, communicator): transport, self.grid.grid_data, communicator, - inputs["tracers"], ) self.tracer_advection(**inputs) inputs[ @@ -89,7 +89,7 @@ def get_advected_tracer_dict(self, all_tracers, nq): units=properties["units"], ) tracer_names = utils.tracer_variables[:nq] - return {name: all_tracers[name + "_quantity"] for name in tracer_names} + return {name: all_tracers[name + "_quantity"].data for name in tracer_names} def compute_sequential(self, a, b): pytest.skip( diff --git a/fv3gfs-physics/tests/savepoint/translate/translate_fv_update_phys.py b/fv3gfs-physics/tests/savepoint/translate/translate_fv_update_phys.py index bc3b95f7d..e9e48a9f2 100644 --- a/fv3gfs-physics/tests/savepoint/translate/translate_fv_update_phys.py +++ b/fv3gfs-physics/tests/savepoint/translate/translate_fv_update_phys.py @@ -174,9 +174,6 @@ def compute_parallel(self, inputs, communicator): self.namelist, communicator, self.grid.driver_grid_data, - state, - tendencies["u_dt"], - tendencies["v_dt"], ) dims_u = [pace.util.X_DIM, pace.util.Y_INTERFACE_DIM, pace.util.Z_DIM] u_quantity = self.grid.make_quantity( diff --git a/pace-util/pace/util/communicator.py b/pace-util/pace/util/communicator.py index 709468a4e..3a8a96c8c 100644 --- a/pace-util/pace/util/communicator.py +++ b/pace-util/pace/util/communicator.py @@ -8,7 +8,7 @@ from ._timing import NullTimer, Timer from .boundary import Boundary from .buffer import array_buffer, recv_buffer, send_buffer -from .halo_data_transformer import QuantityHaloSpec +from .halo_quantity_specification import QuantityHaloSpec from .halo_updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater from .partitioner import CubedSpherePartitioner, Partitioner, TilePartitioner from .quantity import Quantity, QuantityMetadata @@ -325,8 +325,10 @@ def start_halo_update( """ if isinstance(quantity, Quantity): quantities = [quantity] + arrays = [quantity.data] else: quantities = quantity + arrays = [qty.data for qty in quantity] specifications = [] for quantity in quantities: @@ -345,7 +347,7 @@ def start_halo_update( halo_updater = self.get_scalar_halo_updater(specifications) halo_updater.force_finalize_on_wait() - halo_updater.start(quantities) + halo_updater.start(arrays) return halo_updater def vector_halo_update( @@ -397,12 +399,16 @@ def start_vector_halo_update( """ if isinstance(x_quantity, Quantity): x_quantities = [x_quantity] + x_arrays = [x_quantity.data] else: x_quantities = x_quantity + x_arrays = [qty.data for qty in x_quantity] if isinstance(y_quantity, Quantity): y_quantities = [y_quantity] + y_arrays = [y_quantity.data] else: y_quantities = y_quantity + y_arrays = [qty.data for qty in y_quantity] x_specifications = [] y_specifications = [] @@ -434,7 +440,7 @@ def start_vector_halo_update( halo_updater = self.get_vector_halo_updater(x_specifications, y_specifications) halo_updater.force_finalize_on_wait() - halo_updater.start(x_quantities, y_quantities) + halo_updater.start(x_arrays, y_arrays) return halo_updater def synchronize_vector_interfaces(self, x_quantity: Quantity, y_quantity: Quantity): @@ -479,14 +485,20 @@ def start_synchronize_vector_interfaces( """ halo_updater = VectorInterfaceHaloUpdater( comm=self.comm, + qty_x_spec=QuantityHaloSpec.from_quantity(x_quantity, -1), + qty_y_spec=QuantityHaloSpec.from_quantity(y_quantity, -1), boundaries=self.boundaries, force_cpu=self._force_cpu, timer=self.timer, ) - req = halo_updater.start_synchronize_vector_interfaces(x_quantity, y_quantity) + req = halo_updater.start_synchronize_vector_interfaces( + x_quantity.data, y_quantity.data + ) return req - def get_scalar_halo_updater(self, specifications: List[QuantityHaloSpec]): + def get_scalar_halo_updater( + self, specifications: List[QuantityHaloSpec] + ) -> HaloUpdater: if len(specifications) == 0: raise RuntimeError("Cannot create updater with specifications list") if specifications[0].n_points == 0: @@ -504,7 +516,7 @@ def get_vector_halo_updater( self, specifications_x: List[QuantityHaloSpec], specifications_y: List[QuantityHaloSpec], - ): + ) -> HaloUpdater: if len(specifications_x) == 0 and len(specifications_y) == 0: raise RuntimeError("Cannot create updater with empty specifications list") if specifications_x[0].n_points == 0 and specifications_y[0].n_points == 0: @@ -519,6 +531,20 @@ def get_vector_halo_updater( self.timer, ) + def get_vector_interface_halo_updater( + self, + specification_x: QuantityHaloSpec, + specification_y: QuantityHaloSpec, + ) -> VectorInterfaceHaloUpdater: + return VectorInterfaceHaloUpdater( + comm=self.comm, + qty_x_spec=specification_x, + qty_y_spec=specification_y, + boundaries=self.boundaries, + force_cpu=self._force_cpu, + timer=self.timer, + ) + def _get_halo_tag(self) -> int: self._last_halo_tag += 1 return self._last_halo_tag diff --git a/pace-util/pace/util/halo_data_transformer.py b/pace-util/pace/util/halo_data_transformer.py index 794714c86..c55d6e766 100644 --- a/pace-util/pace/util/halo_data_transformer.py +++ b/pace-util/pace/util/halo_data_transformer.py @@ -1,7 +1,7 @@ import abc from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, List, Optional, Sequence, Tuple +from typing import Dict, List, Optional, Sequence, Tuple from uuid import UUID, uuid1 import numpy as np @@ -14,27 +14,12 @@ unpack_scalar_f64_kernel, unpack_vector_f64_kernel, ) -from .quantity import Quantity +from .halo_quantity_specification import QuantityHaloSpec from .rotate import rotate_scalar_data, rotate_vector_data from .types import NumpyModule from .utils import device_synchronize -@dataclass -class QuantityHaloSpec: - """Describe the memory to be exchanged, including size of the halo.""" - - n_points: int - strides: Tuple[int] - itemsize: int - shape: Tuple[int] - origin: Tuple[int, ...] - extent: Tuple[int, ...] - dims: Tuple[str, ...] - numpy_module: NumpyModule - dtype: Any - - # ------------------------------------------------------------------------ # Simple pool of streams to lower the driver pressure # Use _pop/_push_stream to manipulate the pool @@ -274,7 +259,7 @@ def get( ) raise NotImplementedError( - f"Quantity module {np_module} has no HaloDataTransformer implemented" + f"Numpy-like module {np_module} has no HaloDataTransformer implemented" ) def get_unpack_buffer(self) -> Buffer: @@ -325,41 +310,41 @@ def ready(self) -> bool: @abc.abstractmethod def async_pack( self, - quantities_x: List[Quantity], - quantities_y: Optional[List[Quantity]] = None, + arrays_x: List[np.ndarray], + arrays_y: Optional[List[np.ndarray]] = None, ): - """Pack all given quantities into a single send Buffer. + """Pack all given arrays into a single send Buffer. Does not guarantee the buffer returned by `get_unpack_buffer` has received data, doing so requires calling `synchronize`. Reaching for the buffer via get_pack_buffer() will call synchronize(). Args: - quantities_x: scalar or vector x-component quantities to pack, + arrays_x: scalar or vector x-component data to pack, if one is vector they must all be vector - quantities_y: if quantities are vector, the y-component - quantities. + arrays_y: if data to exchange are vectors, the y-component + data. """ pass @abc.abstractmethod def async_unpack( self, - quantities_x: List[Quantity], - quantities_y: Optional[List[Quantity]] = None, + arrays_x: List[np.ndarray], + arrays_y: Optional[List[np.ndarray]] = None, ): - """Unpack the buffer into destination quantities. + """Unpack the buffer into destination arrays. Does not guarantee the buffer returned by `get_unpack_buffer` has received data, doing so requires calling `synchronize`. Reaching for the buffer via get_unpack_buffer() will call synchronize(). Args: - quantities_x: scalar or vector x-component quantities to be unpacked into, + arrays_x: scalar or vector x-component data to pack, if one is vector they must all be vector - quantities_y: if quantities are vector, the y-component - quantities. + arrays_y: if data to exchange are vectors, the y-component + data. """ pass @@ -386,41 +371,41 @@ def synchronize(self): def async_pack( self, - quantities_x: List[Quantity], - quantities_y: Optional[List[Quantity]] = None, + arrays_x: List[np.ndarray], + arrays_y: Optional[List[np.ndarray]] = None, ): # Unpack per type if self._type == _HaloDataTransformerType.SCALAR: - self._pack_scalar(quantities_x) + self._pack_scalar(arrays_x) elif self._type == _HaloDataTransformerType.VECTOR: - assert quantities_y is not None - self._pack_vector(quantities_x, quantities_y) + assert arrays_y is not None + self._pack_vector(arrays_x, arrays_y) else: raise RuntimeError(f"Unimplemented {self._type} pack") assert isinstance(self._pack_buffer, Buffer) # e.g. allocate happened - def _pack_scalar(self, quantities: List[Quantity]): + def _pack_scalar(self, arrays: List[np.ndarray]): if __debug__: - if len(quantities) != len(self._infos_x): + if len(arrays) != len(self._infos_x): raise RuntimeError( - f"Quantities count ({len(quantities)}" + f"Arrays count ({len(arrays)}" f" is different that edges count {len(self._infos_x)}" ) - # TODO Per quantity check + # TODO Per array check assert isinstance(self._pack_buffer, Buffer) # e.g. allocate happened offset = 0 - for quantity, info_x in zip(quantities, self._infos_x): + for array, info_x in zip(arrays, self._infos_x): data_size = _slices_size(info_x.pack_slices) # sending data across the boundary will rotate the data # n_clockwise_rotations times, due to the difference in axis orientation.\ # Thus we rotate that number of times counterclockwise before sending, # to get the right final orientation source_view = rotate_scalar_data( - quantity.data[info_x.pack_slices], - quantity.dims, - quantity.np, + array[info_x.pack_slices], + info_x.specification.dims, + info_x.specification.numpy_module, -info_x.pack_clockwise_rotation, ) self._pack_buffer.assign_from( @@ -429,38 +414,38 @@ def _pack_scalar(self, quantities: List[Quantity]): ) offset += data_size - def _pack_vector(self, quantities_x: List[Quantity], quantities_y: List[Quantity]): + def _pack_vector(self, arrays_x: List[np.ndarray], arrays_y: List[np.ndarray]): if __debug__: - if len(quantities_x) != len(self._infos_x) and len(quantities_y) != len( + if len(arrays_x) != len(self._infos_x) and len(arrays_y) != len( self._infos_y ): raise RuntimeError( - f"Quantities count (x: {len(quantities_x)}, y: {len(quantities_y)})" + f"Arrays count (x: {len(arrays_x)}, y: {len(arrays_y)})" " is different that specifications count " f"(x: {len(self._infos_x)}, y: {len(self._infos_y)}" ) - # TODO Per quantity check + # TODO Per array check assert isinstance(self._pack_buffer, Buffer) # e.g. allocate happened - assert len(quantities_y) == len(quantities_x) + assert len(arrays_y) == len(arrays_x) assert len(self._infos_x) == len(self._infos_y) offset = 0 for ( - quantity_x, - quantity_y, + array_x, + array_y, info_x, info_y, - ) in zip(quantities_x, quantities_y, self._infos_x, self._infos_y): + ) in zip(arrays_x, arrays_y, self._infos_x, self._infos_y): # sending data across the boundary will rotate the data # n_clockwise_rotations times, due to the difference in axis orientation # Thus we rotate that number of times counterclockwise before sending, # to get the right final orientation x_view, y_view = rotate_vector_data( - quantity_x.data[info_x.pack_slices], - quantity_y.data[info_y.pack_slices], + array_x[info_x.pack_slices], + array_y[info_y.pack_slices], -info_x.pack_clockwise_rotation, - quantity_x.dims, - quantity_x.np, + info_x.specification.dims, + info_x.specification.numpy_module, ) # Pack X/Y data slices in the buffer @@ -477,74 +462,72 @@ def _pack_vector(self, quantities_x: List[Quantity], quantities_y: List[Quantity def async_unpack( self, - quantities_x: List[Quantity], - quantities_y: Optional[List[Quantity]] = None, + arrays_x: List[np.ndarray], + arrays_y: Optional[List[np.ndarray]] = None, ): # Unpack per type if self._type == _HaloDataTransformerType.SCALAR: - self._unpack_scalar(quantities_x) + self._unpack_scalar(arrays_x) elif self._type == _HaloDataTransformerType.VECTOR: - assert quantities_y is not None - self._unpack_vector(quantities_x, quantities_y) + assert arrays_y is not None + self._unpack_vector(arrays_x, arrays_y) else: raise RuntimeError(f"Unimplemented {self._type} unpack") assert isinstance(self._unpack_buffer, Buffer) # e.g. allocate happened - def _unpack_scalar(self, quantities: List[Quantity]): + def _unpack_scalar(self, arrays: List[np.ndarray]): if __debug__: - if len(quantities) != len(self._infos_x): + if len(arrays) != len(self._infos_x): raise RuntimeError( - f"Quantities count ({len(quantities)}" + f"Arrays count ({len(arrays)}" f" is different that specifications count {len(self._infos_x)}" ) - # TODO Per quantity check + # TODO Per array check assert isinstance(self._unpack_buffer, Buffer) # e.g. allocate happened offset = 0 - for quantity, info_x in zip(quantities, self._infos_x): - quantity_view = quantity.data[info_x.unpack_slices] + for array, info_x in zip(arrays, self._infos_x): + array_view = array[info_x.unpack_slices] data_size = _slices_size(info_x.unpack_slices) self._unpack_buffer.assign_to( - quantity_view, + array_view, buffer_slice=np.index_exp[offset : offset + data_size], - buffer_reshape=quantity_view.shape, + buffer_reshape=array_view.shape, ) offset += data_size - def _unpack_vector( - self, quantities_x: List[Quantity], quantities_y: List[Quantity] - ): + def _unpack_vector(self, arrays_x: List[np.ndarray], arrays_y: List[np.ndarray]): if __debug__: - if len(quantities_x) != len(self._infos_x) and len(quantities_y) != len( + if len(arrays_x) != len(self._infos_x) and len(arrays_y) != len( self._infos_y ): raise RuntimeError( - f"Quantities count (x: {len(quantities_x)}, y: {len(quantities_y)})" + f"Arrays count (x: {len(arrays_x)}, y: {len(arrays_y)})" " is different that specifications count " f"(x: {len(self._infos_x)}, y: {len(self._infos_y)})" ) - # TODO Per quantity check + # TODO Per array check assert isinstance(self._unpack_buffer, Buffer) # e.g. allocate happened offset = 0 - for quantity_x, quantity_y, info_x, info_y in zip( - quantities_x, quantities_y, self._infos_x, self._infos_y + for array_x, array_y, info_x, info_y in zip( + arrays_x, arrays_y, self._infos_x, self._infos_y ): - quantity_view = quantity_x.data[info_x.unpack_slices] + array_view = array_x[info_x.unpack_slices] data_size = _slices_size(info_x.unpack_slices) self._unpack_buffer.assign_to( - quantity_view, + array_view, buffer_slice=np.index_exp[offset : offset + data_size], - buffer_reshape=quantity_view.shape, + buffer_reshape=array_view.shape, ) offset += data_size - quantity_view = quantity_y.data[info_y.unpack_slices] + array_view = array_y[info_y.unpack_slices] data_size = _slices_size(info_y.unpack_slices) self._unpack_buffer.assign_to( - quantity_view, + array_view, buffer_slice=np.index_exp[offset : offset + data_size], - buffer_reshape=quantity_view.shape, + buffer_reshape=array_view.shape, ) offset += data_size @@ -553,7 +536,7 @@ class HaloDataTransformerGPU(HaloDataTransformer): """Pack/unpack data in a single buffer using CUDA Kernels. In order to efficiently pack/unpack on the GPU to a single GPU buffer - we use streamed (e.g. async) kernels per quantity per edge to send. The + we use streamed (e.g. async) kernels per array per edge to send. The kernels are store in `cuda_kernels.py`, they both follow the same simple pattern by reading the indices to the device memory of the data to pack/unpack. `_flatten_indices` is the routine that take the layout of the memory and @@ -683,47 +666,47 @@ def _get_stream(self, stream) -> "cp.cuda.stream": def async_pack( self, - quantities_x: List[Quantity], - quantities_y: Optional[List[Quantity]] = None, + arrays_x: List["cp.ndarray"], + arrays_y: Optional[List["cp.ndarray"]] = None, ): - """Pack the quantities into a single buffer via streamed cuda kernels + """Pack the arrays into a single buffer via streamed cuda kernels Writes into self._pack_buffer using self._x_infos and self._y_infos - to read the offsets and sizes per quantity. + to read the offsets and sizes per array. Args: - quantities_x: list of quantities to pack. Must fit the specifications given + arrays_x: list of arrays to pack. Must fit the specifications given at init time. - quantities_y: Same as above but optional, used only for vector transfer. + arrays_y: Same as above but optional, used only for vector transfer. """ # Unpack per type if self._type == _HaloDataTransformerType.SCALAR: - self._opt_pack_scalar(quantities_x) + self._opt_pack_scalar(arrays_x) elif self._type == _HaloDataTransformerType.VECTOR: - assert quantities_y is not None - self._opt_pack_vector(quantities_x, quantities_y) + assert arrays_y is not None + self._opt_pack_vector(arrays_x, arrays_y) else: raise RuntimeError(f"Unimplemented {self._type} pack") - def _opt_pack_scalar(self, quantities: List[Quantity]): + def _opt_pack_scalar(self, arrays: List["cp.ndarray"]): """Specialized packing for scalar. See async_pack docs for usage.""" if __debug__: - if len(quantities) != len(self._infos_x): + if len(arrays) != len(self._infos_x): raise RuntimeError( - f"Quantities count ({len(quantities)}" + f"Quantities count ({len(arrays)}" f" is different that specifications count {len(self._infos_x)}" ) - # TODO Per quantity check + # TODO Per array check assert isinstance(self._pack_buffer, Buffer) # e.g. allocate happened offset = 0 - for info_x, quantity in zip(self._infos_x, quantities): + for info_x, array in zip(self._infos_x, arrays): cu_kernel_args = self._cu_kernel_args[info_x._id] # Use private stream with self._get_stream(cu_kernel_args.stream): - if quantity.metadata.dtype != np.float64: + if info_x.specification.dtype != np.float64: raise RuntimeError(f"Kernel requires f64 given {np.float64}") # Launch kernel @@ -736,7 +719,7 @@ def _opt_pack_scalar(self, quantities: List[Quantity]): (grid_x,), (blocks,), ( - quantity.data[:], # source_array + array[:], # source_array cu_kernel_args.x_send_indices, # indices info_x.pack_buffer_size, # nIndex offset, @@ -748,29 +731,29 @@ def _opt_pack_scalar(self, quantities: List[Quantity]): offset += info_x.pack_buffer_size def _opt_pack_vector( - self, quantities_x: List[Quantity], quantities_y: List[Quantity] + self, arrays_x: List["cp.ndarray"], arrays_y: List["cp.ndarray"] ): """Specialized packing for vectors. See async_pack docs for usage.""" if __debug__: - if len(quantities_x) != len(self._infos_x) and len(quantities_y) != len( + if len(arrays_x) != len(self._infos_x) and len(arrays_y) != len( self._infos_y ): raise RuntimeError( - f"Quantities count (x: {len(quantities_x)}, y: {len(quantities_y)}" + f"Arrays count (x: {len(arrays_x)}, y: {len(arrays_y)}" " is different that specifications count " f"(x: {len(self._infos_x)}, y: {len(self._infos_y)}" ) - # TODO Per quantity check + # TODO Per array check assert isinstance(self._pack_buffer, Buffer) # e.g. allocate happened assert len(self._infos_x) == len(self._infos_y) - assert len(quantities_x) == len(quantities_y) + assert len(arrays_x) == len(arrays_y) offset = 0 for ( - quantity_x, - quantity_y, + array_x, + array_y, info_x, info_y, - ) in zip(quantities_x, quantities_y, self._infos_x, self._infos_y): + ) in zip(arrays_x, arrays_y, self._infos_x, self._infos_y): cu_kernel_args = self._cu_kernel_args[info_x._id] # Use private stream @@ -779,7 +762,7 @@ def _opt_pack_vector( # Buffer sizes transformer_size = info_x.pack_buffer_size + info_y.pack_buffer_size - if quantity_x.metadata.dtype != np.float64: + if info_x.specification.dtype != np.float64: raise RuntimeError(f"Kernel requires f64 given {np.float64}") # Launch kernel @@ -792,8 +775,8 @@ def _opt_pack_vector( (grid_x,), (blocks,), ( - quantity_x.data[:], # source_array_x - quantity_y.data[:], # source_array_y + array_x[:], # source_array_x + array_y[:], # source_array_y cu_kernel_args.x_send_indices, # indices_x cu_kernel_args.y_send_indices, # indices_y info_x.pack_buffer_size, # nIndex_x @@ -809,40 +792,40 @@ def _opt_pack_vector( def async_unpack( self, - quantities_x: List[Quantity], - quantities_y: Optional[List[Quantity]] = None, + arrays_x: List["cp.ndarray"], + arrays_y: Optional[List["cp.ndarray"]] = None, ): - """Unpack the quantities from a single buffer via streamed cuda kernels + """Unpack the arrays from a single buffer via streamed cuda kernels Reads from self._unpack_buffer using self._x_infos and self._y_infos - to read the offsets and sizes per quantity. + to read the offsets and sizes per array. Args: - quantities_x: list of quantities to unpack. Must fit + arrays_x: list of arrays to unpack. Must fit the specifications given at init time. - quantities_y: Same as above but optional, used only for vector transfer. + arrays_y: Same as above but optional, used only for vector transfer. """ # Unpack per type if self._type == _HaloDataTransformerType.SCALAR: - self._opt_unpack_scalar(quantities_x) + self._opt_unpack_scalar(arrays_x) elif self._type == _HaloDataTransformerType.VECTOR: - assert quantities_y is not None - self._opt_unpack_vector(quantities_x, quantities_y) + assert arrays_y is not None + self._opt_unpack_vector(arrays_x, arrays_y) else: raise RuntimeError(f"Unimplemented {self._type} unpack") - def _opt_unpack_scalar(self, quantities: List[Quantity]): + def _opt_unpack_scalar(self, arrays: List["cp.ndarray"]): """Specialized unpacking for scalars. See async_unpack docs for usage.""" if __debug__: - if len(quantities) != len(self._infos_x): + if len(arrays) != len(self._infos_x): raise RuntimeError( - f"Quantities count ({len(quantities)})" + f"Arrays count ({len(arrays)})" f" is different that specifications count ({len(self._infos_x)})" ) - # TODO Per quantity check + # TODO Per array check assert isinstance(self._unpack_buffer, Buffer) # e.g. allocate happened offset = 0 - for quantity, info_x in zip(quantities, self._infos_x): + for array, info_x in zip(arrays, self._infos_x): cu_kernel_args = self._cu_kernel_args[info_x._id] # Use private stream @@ -862,7 +845,7 @@ def _opt_unpack_scalar(self, quantities: List[Quantity]): cu_kernel_args.x_recv_indices, # indices info_x._unpack_buffer_size, # nIndex offset, - quantity.data[:], # destination_array + array[:], # destination_array ), ) @@ -870,31 +853,31 @@ def _opt_unpack_scalar(self, quantities: List[Quantity]): offset += info_x._unpack_buffer_size def _opt_unpack_vector( - self, quantities_x: List[Quantity], quantities_y: List[Quantity] + self, arrays_x: List["cp.ndarray"], arrays_y: List["cp.ndarray"] ): """Specialized unpacking for vectors. See async_unpack docs for usage.""" if __debug__: - if len(quantities_x) != len(self._infos_x) and len(quantities_y) != len( + if len(arrays_x) != len(self._infos_x) and len(arrays_y) != len( self._infos_y ): raise RuntimeError( - f"Quantities count (x: {len(quantities_x)}, y: {len(quantities_y)}" + f"Arrays count (x: {len(arrays_x)}, y: {len(arrays_y)}" " is different that specifications count " f"(x: {len(self._infos_x)}, y: {len(self._infos_y)}" ) - # TODO Per quantity check + # TODO Per array check assert isinstance(self._unpack_buffer, Buffer) # e.g. allocate happened assert len(self._infos_x) == len(self._infos_y) - assert len(quantities_x) == len(quantities_y) + assert len(arrays_x) == len(arrays_y) offset = 0 for ( - quantity_x, - quantity_y, + array_x, + array_y, info_x, info_y, - ) in zip(quantities_x, quantities_y, self._infos_x, self._infos_y): + ) in zip(arrays_x, arrays_y, self._infos_x, self._infos_y): # We only have writte a f64 kernel - if quantity_x.metadata.dtype != np.float64: + if info_x.specification.dtype != np.float64: raise RuntimeError(f"Kernel requires f64 given {np.float64}") cu_kernel_args = self._cu_kernel_args[info_x._id] @@ -921,8 +904,8 @@ def _opt_unpack_vector( info_x._unpack_buffer_size, # nIndex_x info_y._unpack_buffer_size, # nIndex_y offset, - quantity_x.data[:], # destination_array_x - quantity_y.data[:], # destination_array_y + array_x[:], # destination_array_x + array_y[:], # destination_array_y ), ) diff --git a/pace-util/pace/util/halo_quantity_specification.py b/pace-util/pace/util/halo_quantity_specification.py new file mode 100644 index 000000000..a93ac2a35 --- /dev/null +++ b/pace-util/pace/util/halo_quantity_specification.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass +from typing import Any, Tuple + +from .quantity import Quantity +from .types import NumpyModule + + +@dataclass +class QuantityHaloSpec: + """Describe the memory to be exchanged. + + Specification needs to cover all aspect of the memory layout for + borth scalar, vector and interface fields, for their halo to be exchanged. + `numpy_module` carries a numpy-like (numpy, cupy) module that will be + used to direct the exchange on the right device. + """ + + n_points: int + strides: Tuple[int] + itemsize: int + shape: Tuple[int] + origin: Tuple[int, ...] + extent: Tuple[int, ...] + dims: Tuple[str, ...] + numpy_module: NumpyModule + dtype: Any + + @classmethod + def from_quantity(cls, quantity: Quantity, n_points: int) -> "QuantityHaloSpec": + return QuantityHaloSpec( + n_points=n_points, + strides=quantity.data.strides, + itemsize=quantity.data.itemsize, + shape=quantity.data.shape, + origin=quantity.origin, + extent=quantity.extent, + dims=quantity.dims, + numpy_module=quantity.np, + dtype=quantity.data.dtype, + ) diff --git a/pace-util/pace/util/halo_updater.py b/pace-util/pace/util/halo_updater.py index 3a98c5c74..4d3cee411 100644 --- a/pace-util/pace/util/halo_updater.py +++ b/pace-util/pace/util/halo_updater.py @@ -7,12 +7,9 @@ from ._timing import NullTimer, Timer from .boundary import Boundary from .buffer import Buffer -from .halo_data_transformer import ( - HaloDataTransformer, - HaloExchangeSpec, - QuantityHaloSpec, -) -from .quantity import Quantity +from .halo_data_transformer import HaloDataTransformer, HaloExchangeSpec +from .halo_quantity_specification import QuantityHaloSpec +from .quantity import BoundaryArrayView from .rotate import rotate_scalar_data from .types import AsyncRequest, NumpyModule from .utils import device_synchronize @@ -64,8 +61,8 @@ def __init__( self._timer = timer self._recv_requests: List[AsyncRequest] = [] self._send_requests: List[AsyncRequest] = [] - self._inflight_x_quantities: Optional[Tuple[Quantity, ...]] = None - self._inflight_y_quantities: Optional[Tuple[Quantity, ...]] = None + self._inflight_x_arrays: Optional[Tuple[np.ndarray, ...]] = None + self._inflight_y_arrays: Optional[Tuple[np.ndarray, ...]] = None self._finalize_on_wait = False def force_finalize_on_wait(self): @@ -77,10 +74,7 @@ def force_finalize_on_wait(self): def __del__(self): """Clean up all buffers on garbage collection""" - if ( - self._inflight_x_quantities is not None - or self._inflight_y_quantities is not None - ): + if self._inflight_x_arrays is not None or self._inflight_y_arrays is not None: raise RuntimeError( "An halo exchange wasn't completed and a wait() call was expected" ) @@ -208,25 +202,22 @@ def from_vector_specifications( def update( self, - quantities_x: List[Quantity], - quantities_y: Optional[List[Quantity]] = None, + arrays_x: List[np.ndarray], + arrays_y: Optional[List[np.ndarray]] = None, ): """Exhange the data and blocks until finished.""" - self.start(quantities_x, quantities_y) + self.start(arrays_x, arrays_y) self.wait() def start( self, - quantities_x: List[Quantity], - quantities_y: Optional[List[Quantity]] = None, + arrays_x: List[np.ndarray], + arrays_y: Optional[List[np.ndarray]] = None, ): """Start data exchange.""" self._comm._device_synchronize() - if ( - self._inflight_x_quantities is not None - or self._inflight_y_quantities is not None - ): + if self._inflight_x_arrays is not None or self._inflight_y_arrays is not None: raise RuntimeError( "Previous exchange hasn't been properly finished." "E.g. previous start() call didn't have a wait() call." @@ -244,15 +235,13 @@ def start( ) ) - # Pack quantities halo points data into buffers + # Pack arrays halo points data into buffers with self._timer.clock("pack"): for transformer in self._transformers.values(): - transformer.async_pack(quantities_x, quantities_y) + transformer.async_pack(arrays_x, arrays_y) - self._inflight_x_quantities = tuple(quantities_x) - self._inflight_y_quantities = ( - tuple(quantities_y) if quantities_y is not None else None - ) + self._inflight_x_arrays = tuple(arrays_x) + self._inflight_y_arrays = tuple(arrays_y) if arrays_y is not None else None # Post send MPI order with self._timer.clock("Isend"): @@ -268,7 +257,7 @@ def start( def wait(self): """Finalize data exchange.""" - if __debug__ and self._inflight_x_quantities is None: + if __debug__ and self._inflight_x_arrays is None: raise RuntimeError('Halo update "wait" call before "start"') # Wait message to be exchange with self._timer.clock("wait"): @@ -278,12 +267,10 @@ def wait(self): recv_req.wait() # Unpack buffers (updated by MPI with neighbouring halos) - # to proper quantities + # to proper arrays with self._timer.clock("unpack"): for buffer in self._transformers.values(): - buffer.async_unpack( - self._inflight_x_quantities, self._inflight_y_quantities - ) + buffer.async_unpack(self._inflight_x_arrays, self._inflight_y_arrays) if self._finalize_on_wait: for transformer in self._transformers.values(): transformer.finalize() @@ -291,8 +278,8 @@ def wait(self): for transformer in self._transformers.values(): transformer.synchronize() - self._inflight_x_quantities = None - self._inflight_y_quantities = None + self._inflight_x_arrays = None + self._inflight_y_arrays = None class HaloUpdateRequest: @@ -333,15 +320,15 @@ def wait(self): Buffer.push_to_cache(transfer_buffer) -def on_c_grid(x_quantity, y_quantity): +def on_c_grid(x_spec: QuantityHaloSpec, y_spec: QuantityHaloSpec): if ( - constants.X_DIM not in x_quantity.dims - or constants.Y_INTERFACE_DIM not in x_quantity.dims + constants.X_DIM not in x_spec.dims + or constants.Y_INTERFACE_DIM not in x_spec.dims ): return False if ( - constants.Y_DIM not in y_quantity.dims - or constants.X_INTERFACE_DIM not in y_quantity.dims + constants.Y_DIM not in y_spec.dims + or constants.X_INTERFACE_DIM not in y_spec.dims ): return False else: @@ -349,9 +336,19 @@ def on_c_grid(x_quantity, y_quantity): class VectorInterfaceHaloUpdater: + """Exchange halo on information between ranks for data living on the interface. + + This class reasons on QuantityHaloSpec for initialization and assumes + the arrays given to the start_synchronize_vector_interfaces adhere to those specs. + + See start_synchronize_vector_interfaces for details on interface exchange. + """ + def __init__( self, comm, + qty_x_spec: QuantityHaloSpec, + qty_y_spec: QuantityHaloSpec, boundaries: Mapping[int, Boundary], force_cpu: bool = False, timer: Optional[Timer] = None, @@ -360,6 +357,8 @@ def __init__( Args: comm: mpi4py.Comm object + qty_x_spec: halo specification for data to exchange on the X-axis + qty_y_spec: halo specification for data to exchange on the Y-axis partitioner: cubed sphere partitioner force_cpu: Force all communication to go through central memory. Optional. timer: Time communication operations. Optional. @@ -369,13 +368,15 @@ def __init__( self._force_cpu = force_cpu self.comm = comm self.boundaries = boundaries + self._qty_x_spec = qty_x_spec + self._qty_y_spec = qty_y_spec def _get_halo_tag(self) -> int: self._last_halo_tag += 1 return self._last_halo_tag def start_synchronize_vector_interfaces( - self, x_quantity: Quantity, y_quantity: Quantity + self, x_array: np.ndarray, y_array: np.ndarray ) -> HaloUpdateRequest: """ Synchronize shared points at the edges of a vector interface variable. @@ -386,73 +387,94 @@ def start_synchronize_vector_interfaces( For interface variables, the edges of the tile are computed on both ranks bordering that edge. This routine copies values across those shared edges so that both ranks have the same value for that edge. It also handles any - rotation of vector quantities needed to move data across the edge. + rotation of vector data needed to move data across the edge. Args: - x_quantity: the x-component quantity to be synchronized - y_quantity: the y-component quantity to be synchronized + x_array: the x-component data to be synchronized + y_array: the y-component data to be synchronized Returns: request: an asynchronous request object with a .wait() method """ - if not on_c_grid(x_quantity, y_quantity): + if not on_c_grid(self._qty_x_spec, self._qty_y_spec): raise ValueError("vector must be defined on Arakawa C-grid") device_synchronize() tag = self._get_halo_tag() - send_requests = self._Isend_vector_shared_boundary( - x_quantity, y_quantity, tag=tag - ) - recv_requests = self._Irecv_vector_shared_boundary( - x_quantity, y_quantity, tag=tag - ) + send_requests = self._Isend_vector_shared_boundary(x_array, y_array, tag=tag) + recv_requests = self._Irecv_vector_shared_boundary(x_array, y_array, tag=tag) return HaloUpdateRequest(send_requests, recv_requests, self.timer) def _Isend_vector_shared_boundary( - self, x_quantity, y_quantity, tag=0 + self, x_array: np.ndarray, y_array: np.ndarray, tag=0 ) -> _HaloRequestSendList: + # South boundary south_boundary = self.boundaries[constants.SOUTH] - west_boundary = self.boundaries[constants.WEST] - south_data = x_quantity.view.southwest.sel( - **{ + southwest_x_view = BoundaryArrayView( + x_array, + constants.SOUTHWEST, + self._qty_x_spec.dims, + self._qty_x_spec.origin, + self._qty_x_spec.extent, + ) + south_data = southwest_x_view.sel( + **{ # type: ignore constants.Y_INTERFACE_DIM: 0, constants.X_DIM: slice( - 0, x_quantity.extent[x_quantity.dims.index(constants.X_DIM)] + 0, + self._qty_x_spec.extent[ + self._qty_x_spec.dims.index(constants.X_DIM) + ], ), } ) south_data = rotate_scalar_data( south_data, [constants.X_DIM], - x_quantity.np, + self._qty_x_spec.numpy_module, -south_boundary.n_clockwise_rotations, ) if south_boundary.n_clockwise_rotations in (3, 2): south_data = -south_data - west_data = y_quantity.view.southwest.sel( - **{ + + # West boundary + west_boundary = self.boundaries[constants.WEST] + southwest_y_view = BoundaryArrayView( + y_array, + constants.SOUTHWEST, + self._qty_y_spec.dims, + self._qty_y_spec.origin, + self._qty_y_spec.extent, + ) + west_data = southwest_y_view.sel( + **{ # type: ignore constants.X_INTERFACE_DIM: 0, constants.Y_DIM: slice( - 0, y_quantity.extent[y_quantity.dims.index(constants.Y_DIM)] + 0, + self._qty_y_spec.extent[ + self._qty_y_spec.dims.index(constants.Y_DIM) + ], ), } ) west_data = rotate_scalar_data( west_data, [constants.Y_DIM], - y_quantity.np, + self._qty_y_spec.numpy_module, -west_boundary.n_clockwise_rotations, ) if west_boundary.n_clockwise_rotations in (1, 2): west_data = -west_data + + # Send requests send_requests = [ self._Isend( - self._maybe_force_cpu(x_quantity.np), + self._maybe_force_cpu(self._qty_x_spec.numpy_module), south_data, dest=south_boundary.to_rank, tag=tag, ), self._Isend( - self._maybe_force_cpu(y_quantity.np), + self._maybe_force_cpu(self._qty_y_spec.numpy_module), west_data, dest=west_boundary.to_rank, tag=tag, @@ -470,35 +492,61 @@ def _maybe_force_cpu(self, module: NumpyModule) -> NumpyModule: return module def _Irecv_vector_shared_boundary( - self, x_quantity, y_quantity, tag=0 + self, x_array: np.ndarray, y_array: np.ndarray, tag=0 ) -> _HaloRequestRecvList: + # North boundary north_rank = self.boundaries[constants.NORTH].to_rank - east_rank = self.boundaries[constants.EAST].to_rank - north_data = x_quantity.view.northwest.sel( - **{ + northwest_x_view = BoundaryArrayView( + x_array, + constants.NORTHWEST, + self._qty_x_spec.dims, + self._qty_x_spec.origin, + self._qty_x_spec.extent, + ) + + north_data = northwest_x_view.sel( + **{ # type: ignore constants.Y_INTERFACE_DIM: -1, constants.X_DIM: slice( - 0, x_quantity.extent[x_quantity.dims.index(constants.X_DIM)] + 0, + self._qty_x_spec.extent[ + self._qty_x_spec.dims.index(constants.X_DIM) + ], ), } ) - east_data = y_quantity.view.southeast.sel( - **{ + + # East boundary + east_rank = self.boundaries[constants.EAST].to_rank + southeast_y_view = BoundaryArrayView( + y_array, + constants.SOUTHEAST, + self._qty_y_spec.dims, + self._qty_y_spec.origin, + self._qty_y_spec.extent, + ) + east_data = southeast_y_view.sel( + **{ # type: ignore constants.X_INTERFACE_DIM: -1, constants.Y_DIM: slice( - 0, y_quantity.extent[y_quantity.dims.index(constants.Y_DIM)] + 0, + self._qty_y_spec.extent[ + self._qty_y_spec.dims.index(constants.Y_DIM) + ], ), } ) + + # Receive requests recv_requests = [ self._Irecv( - self._maybe_force_cpu(x_quantity.np), + self._maybe_force_cpu(self._qty_x_spec.numpy_module), north_data, source=north_rank, tag=tag, ), self._Irecv( - self._maybe_force_cpu(y_quantity.np), + self._maybe_force_cpu(self._qty_y_spec.numpy_module), east_data, source=east_rank, tag=tag, diff --git a/pace-util/pace/util/rotate.py b/pace-util/pace/util/rotate.py index 27ab1f252..dcbbd71f5 100644 --- a/pace-util/pace/util/rotate.py +++ b/pace-util/pace/util/rotate.py @@ -1,7 +1,11 @@ +from typing import List + +import numpy as np + from . import constants -def rotate_scalar_data(data, dims, numpy, n_clockwise_rotations): +def rotate_scalar_data(data, dims, numpy, n_clockwise_rotations) -> np.ndarray: n_clockwise_rotations = n_clockwise_rotations % 4 if n_clockwise_rotations == 0: pass @@ -34,7 +38,9 @@ def rotate_scalar_data(data, dims, numpy, n_clockwise_rotations): return data -def rotate_vector_data(x_data, y_data, n_clockwise_rotations, dims, numpy): +def rotate_vector_data( + x_data, y_data, n_clockwise_rotations, dims, numpy +) -> List[np.ndarray]: x_data = rotate_scalar_data(x_data, dims, numpy, n_clockwise_rotations) y_data = rotate_scalar_data(y_data, dims, numpy, n_clockwise_rotations) data = [x_data, y_data] diff --git a/pace-util/tests/test_halo_data_transformer.py b/pace-util/tests/test_halo_data_transformer.py index 8f512d08b..d6defc35b 100644 --- a/pace-util/tests/test_halo_data_transformer.py +++ b/pace-util/tests/test_halo_data_transformer.py @@ -326,12 +326,12 @@ def test_data_transformer_scalar_pack_unpack(quantity, rotation, n_halos): data_transformer = HaloDataTransformer.get(quantity.np, exchange_descriptors) - data_transformer.async_pack([quantity, quantity]) + data_transformer.async_pack([quantity.data, quantity.data]) # Simulate data transfer data_transformer.get_unpack_buffer().assign_from( data_transformer.get_pack_buffer().array ) - data_transformer.async_unpack([quantity, quantity]) + data_transformer.async_unpack([quantity.data, quantity.data]) data_transformer.synchronize() # From the copy of the original quantity we rotate data @@ -433,12 +433,16 @@ def test_data_transformer_vector_pack_unpack(quantity, rotation, n_halos): x_quantity.np, exchange_descriptors_x, exchange_descriptors_y ) - data_transformer.async_pack([x_quantity, x_quantity], [y_quantity, y_quantity]) + data_transformer.async_pack( + [x_quantity.data, x_quantity.data], [y_quantity.data, y_quantity.data] + ) # Simulate data transfer data_transformer.get_unpack_buffer().assign_from( data_transformer.get_pack_buffer().array ) - data_transformer.async_unpack([x_quantity, x_quantity], [y_quantity, y_quantity]) + data_transformer.async_unpack( + [x_quantity.data, x_quantity.data], [y_quantity.data, y_quantity.data] + ) data_transformer.synchronize() # From the copy of the original quantity we rotate data diff --git a/pace-util/tests/test_halo_update.py b/pace-util/tests/test_halo_update.py index 834c1beb1..8cdfefca7 100644 --- a/pace-util/tests/test_halo_update.py +++ b/pace-util/tests/test_halo_update.py @@ -866,7 +866,7 @@ def test_halo_updater_stability( # First run for halo_updater in halo_updaters: - halo_updater.start([quantity]) + halo_updater.start([quantity.data]) for halo_updater in halo_updaters: halo_updater.wait() @@ -874,11 +874,11 @@ def test_halo_updater_stability( # The buffer should stay stable since we are exchanging the same information exchanged_once_quantity = copy.deepcopy(quantity) for halo_updater in halo_updaters: - halo_updater.start([quantity]) + halo_updater.start([quantity.data]) for halo_updater in halo_updaters: halo_updater.wait() for halo_updater in halo_updaters: - halo_updater.start([quantity]) + halo_updater.start([quantity.data]) for halo_updater in halo_updaters: halo_updater.wait() assert (quantity.data == exchanged_once_quantity.data).all() diff --git a/stencils/pace/stencils/c2l_ord.py b/stencils/pace/stencils/c2l_ord.py index f5962ae20..0325bb82f 100644 --- a/stencils/pace/stencils/c2l_ord.py +++ b/stencils/pace/stencils/c2l_ord.py @@ -1,6 +1,5 @@ from gt4py.gtscript import PARALLEL, computation, horizontal, interval, region -import fv3core import pace.dsl.gt4py_utils as utils from pace.dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater from pace.dsl.stencil import StencilFactory @@ -104,7 +103,6 @@ class CubedToLatLon: def __init__( self, - state: fv3core.DycoreState, stencil_factory: StencilFactory, grid_data: GridData, order: int, @@ -162,11 +160,7 @@ def __init__( self.u__v = WrappedHaloUpdater( comm.get_vector_halo_updater( [full_size_xyiz_halo_spec], [full_size_xiyz_halo_spec] - ), - state, - ["u"], - ["v"], - comm=comm, + ) ) def __call__( @@ -186,7 +180,7 @@ def __call__( comm: Cubed-sphere communicator """ if self._do_ord4: - self.u__v.update() + self.u__v.update([u.data], [v.data]) self._compute_cubed_to_latlon( u, v, diff --git a/stencils/pace/stencils/fv_update_phys.py b/stencils/pace/stencils/fv_update_phys.py index 9b7196ba3..e67a750ca 100644 --- a/stencils/pace/stencils/fv_update_phys.py +++ b/stencils/pace/stencils/fv_update_phys.py @@ -1,7 +1,6 @@ import gt4py.gtscript as gtscript from gt4py.gtscript import FORWARD, PARALLEL, computation, exp, interval, log -import fv3core import pace.dsl.gt4py_utils as utils import pace.util import pace.util.constants as constants @@ -88,9 +87,6 @@ def __init__( namelist, comm: pace.util.CubedSphereCommunicator, grid_info: DriverGridData, - state: fv3core.DycoreState, - u_dt: pace.util.Quantity, - v_dt: pace.util.Quantity, ): orchestrate( obj=self, @@ -113,7 +109,7 @@ def __init__( stencil_factory, comm.partitioner, comm.rank, namelist, grid_info ) self._do_cubed_to_latlon = CubedToLatLon( - state, stencil_factory, grid_data, order=namelist.c2l_ord, comm=comm + stencil_factory, grid_data, order=namelist.c2l_ord, comm=comm ) self.origin = grid_indexing.origin_compute() self.extent = grid_indexing.domain_compute() @@ -127,13 +123,9 @@ def __init__( ) self._udt_halo_updater = WrappedHaloUpdater( self.comm.get_scalar_halo_updater([full_3Dfield_1pts_halo_spec]), - {"u_dt": u_dt}, - ["u_dt"], ) self._vdt_halo_updater = WrappedHaloUpdater( self.comm.get_scalar_halo_updater([full_3Dfield_1pts_halo_spec]), - {"v_dt": v_dt}, - ["v_dt"], ) # TODO: check if we actually need surface winds self._u_srf = utils.make_storage_from_shape( @@ -164,8 +156,8 @@ def __call__( dt, ) - self._udt_halo_updater.start() - self._vdt_halo_updater.start() + self._udt_halo_updater.start([u_dt.data]) + self._vdt_halo_updater.start([v_dt.data]) self._update_pressure_and_surface_winds( state.pe, state.delp, diff --git a/stencils/pace/stencils/update_atmos_state.py b/stencils/pace/stencils/update_atmos_state.py index 30f815be8..3e53825c6 100644 --- a/stencils/pace/stencils/update_atmos_state.py +++ b/stencils/pace/stencils/update_atmos_state.py @@ -242,11 +242,8 @@ def __init__( namelist, comm: pace.util.CubedSphereCommunicator, grid_info: DriverGridData, - state: fv3core.DycoreState, - quantity_factory: pace.util.QuantityFactory, dycore_only: bool, apply_tendencies: bool, - tendency_state, ): orchestrate( obj=self, @@ -259,8 +256,6 @@ def __init__( grid_indexing = stencil_factory.grid_indexing self.namelist = namelist - origin = grid_indexing.origin_compute() - shape = grid_indexing.domain_full(add=(1, 1, 1)) self._rdt = 1.0 / Float(self.namelist.dt_atmos) self._prepare_tendencies_and_update_tracers = ( @@ -271,7 +266,6 @@ def __init__( ) ) - dims = [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM] self._fill_GFS_delp = stencil_factory.from_origin_domain( fill_gfs_delp, origin=grid_indexing.origin_full(), @@ -284,9 +278,6 @@ def __init__( self.namelist, comm, grid_info, - state, - tendency_state.u_dt, - tendency_state.v_dt, ) self._dycore_only = dycore_only # apply_tendencies when we have run physics or fv_subgridz diff --git a/tests/main/fv3core/test_dycore_call.py b/tests/main/fv3core/test_dycore_call.py index 92bc020e2..e3d9f4178 100644 --- a/tests/main/fv3core/test_dycore_call.py +++ b/tests/main/fv3core/test_dycore_call.py @@ -123,7 +123,6 @@ def setup_dycore() -> Tuple[ damping_coefficients=DampingCoefficients.new_from_metric_terms(metric_terms), config=config, phis=state.phis, - state=state, ) do_adiabatic_init = False @@ -160,10 +159,10 @@ def test_temporaries_are_deterministic(): dycore1, state1, timer1 = setup_dycore() dycore2, state2, timer2 = setup_dycore() - dycore1.step_dynamics(state1, timer1) + dycore1.step_dynamics(state1, state1.tracers_as_array(), timer1) first_temporaries = copy_temporaries(dycore1, max_depth=10) assert len(first_temporaries) > 0 - dycore2.step_dynamics(state2, timer2) + dycore2.step_dynamics(state2, state2.tracers_as_array(), timer2) second_temporaries = copy_temporaries(dycore2, max_depth=10) assert_same_temporaries(second_temporaries, first_temporaries) @@ -180,14 +179,14 @@ def test_call_on_same_state_same_dycore_produces_same_temporaries(): # state_1 and state_2 are identical, if the dycore is stateless then they # should produce identical dycore final states when used to call - dycore.step_dynamics(state_1, timer_1) + dycore.step_dynamics(state_1, state_1.tracers_as_array(), timer_1) first_temporaries = copy_temporaries(dycore, max_depth=10) assert len(first_temporaries) > 0 # TODO: The orchestrated code pushed us to make the dycore stateful for halo # exchange, so we must copy into state_1 instead of using state_2. # We should call with state_2 directly when this is fixed. copy_state(state_2, state_1) - dycore.step_dynamics(state_1, timer_2) + dycore.step_dynamics(state_1, state_1.tracers_as_array(), timer_2) second_temporaries = copy_temporaries(dycore, max_depth=10) assert_same_temporaries(second_temporaries, first_temporaries) @@ -200,7 +199,7 @@ def error_func(*args, **kwargs): with unittest.mock.patch("gt4py.storage.storage.zeros", new=error_func): with unittest.mock.patch("gt4py.storage.storage.empty", new=error_func): - dycore.step_dynamics(state, timer) + dycore.step_dynamics(state, state.tracers_as_array(), timer) def test_call_does_not_define_stencils(): @@ -210,4 +209,4 @@ def error_func(*args, **kwargs): raise AssertionError("call not allowed") with unittest.mock.patch("gt4py.gtscript.stencil", new=error_func): - dycore.step_dynamics(state, timer) + dycore.step_dynamics(state, state.tracers_as_array(), timer) diff --git a/tests/mpi/test_checkpoints.py b/tests/mpi/test_checkpoints.py index af306b1d6..d9675f6a5 100644 --- a/tests/mpi/test_checkpoints.py +++ b/tests/mpi/test_checkpoints.py @@ -173,7 +173,7 @@ def test_fv_dynamics( checkpointer=validation, ) with validation.trial(): - dycore.step_dynamics(state) + dycore.step_dynamics(state, state.tracers_as_array()) def _calibrate_thresholds( @@ -202,7 +202,7 @@ def _calibrate_thresholds( trial_state, _ = initializer.new_state() perturb(dycore_state_to_dict(trial_state)) with calibration.trial(): - dycore.step_dynamics(trial_state) + dycore.step_dynamics(trial_state, trial_state.tracers_as_array()) all_thresholds = communicator.comm.allgather(calibration.thresholds) thresholds = merge_thresholds(all_thresholds) set_manual_thresholds(thresholds)