Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 15 additions & 27 deletions pyfv3/stencils/fillz.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import typing

import ndsl.dsl.gt4py_utils as utils
from ndsl import Quantity, QuantityFactory, StencilFactory, orchestrate
import dace

from ndsl import NDSLRuntime, QuantityFactory, StencilFactory
from ndsl.constants import I_DIM, J_DIM, K_DIM
from ndsl.dsl.gt4py import BACKWARD, FORWARD, PARALLEL, computation, interval, max, min
from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ, IntFieldIJ
from ndsl.dsl.typing import FloatField, FloatFieldIJ, Int, IntFieldIJ


@typing.no_type_check
Expand Down Expand Up @@ -97,7 +98,7 @@ def fix_tracer(
q = max(fac * dm / dp, 0.0)


class FillNegativeTracerValues:
class FillNegativeTracerValues(NDSLRuntime):
"""
Fix tracer values to prevent negative masses.

Expand All @@ -109,13 +110,9 @@ def __init__(
stencil_factory: StencilFactory,
quantity_factory: QuantityFactory,
nq: int,
tracers: dict[str, Quantity],
):
orchestrate(
obj=self,
config=stencil_factory.config.dace_config,
dace_compiletime_args=["tracers"],
)
super().__init__(stencil_factory)

self._nq = int(nq)
self._fix_tracer_stencil = stencil_factory.from_dims_halo(
fix_tracer,
Expand All @@ -124,33 +121,24 @@ def __init__(

# Setting initial value of upper_fix to zero is only needed for validation.
# The values in the compute domain are set to zero in the stencil.
self._zfix = quantity_factory.zeros([I_DIM, J_DIM], units="unknown", dtype=int)
self._sum0 = quantity_factory.zeros(
[I_DIM, J_DIM],
units="unknown",
dtype=Float,
)
self._sum1 = quantity_factory.zeros(
[I_DIM, J_DIM],
units="unknown",
dtype=Float,
)

self._filtered_tracer_dict = {
name: tracers[name] for name in utils.tracer_variables[0 : self._nq]
}
self._zfix = self.make_local(quantity_factory, [I_DIM, J_DIM], dtype=Int)
self._zfix.data[:] = 0
self._sum0 = self.make_local(quantity_factory, [I_DIM, J_DIM])
self._sum0.data[:] = 0
self._sum1 = self.make_local(quantity_factory, [I_DIM, J_DIM])
self._sum1.data[:] = 0

def __call__(
self,
dp2: FloatField,
tracers: dict[str, Quantity],
tracers: dace.compiletime, # dict[str, Quantity],
):
"""
Args:
dp2 (in): pressure thickness of atmospheric layer
tracers (inout): tracers to fix negative masses in
"""
for tracer_name in self._filtered_tracer_dict.keys():
for tracer_name in tracers.keys():
self._fix_tracer_stencil(
tracers[tracer_name],
dp2,
Expand Down
1 change: 0 additions & 1 deletion pyfv3/stencils/fv_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,6 @@ def __init__(
area_64=grid_data.area_64,
nq=NQ,
pfull=self._pfull,
tracers=self.tracers,
)

full_xyz_spec = quantity_factory.get_quantity_halo_spec(
Expand Down
37 changes: 17 additions & 20 deletions pyfv3/stencils/map_single.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from collections.abc import Sequence
from typing import Optional

from ndsl import QuantityFactory, StencilFactory, orchestrate
from ndsl import NDSLRuntime, QuantityFactory, StencilFactory
from ndsl.constants import I_DIM, J_DIM, K_DIM
from ndsl.dsl.gt4py import FORWARD, PARALLEL, computation, interval
from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ, IntFieldIJ
from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ, Int, IntFieldIJ
from ndsl.stencils.basic_operations import copy
from pyfv3.stencils.remap_profile import RemapProfile

Expand Down Expand Up @@ -79,7 +79,7 @@ def lagrangian_contributions(
lev = lev - 1


class MapSingle:
class MapSingle(NDSLRuntime):
"""
Fortran name is map_single, test classes are Map1_PPM_2d, Map_Scalar_2d
"""
Expand All @@ -92,10 +92,7 @@ def __init__(
mode: int,
dims: Sequence[str],
) -> None:
orchestrate(
obj=self,
config=stencil_factory.config.dace_config,
)
super().__init__(stencil_factory)

def make_quantity():
return quantity_factory.zeros(
Expand All @@ -104,17 +101,17 @@ def make_quantity():
dtype=Float,
)

self._dp1 = make_quantity()
self._q4_1 = make_quantity()
self._q4_2 = make_quantity()
self._q4_3 = make_quantity()
self._q4_4 = make_quantity()
self._tmp_qs = quantity_factory.zeros(
[I_DIM, J_DIM],
units="unknown",
dtype=Float,
)
self._lev = quantity_factory.zeros([I_DIM, J_DIM], units="", dtype=int)
# All locals will be initialized in code before being read
self._dp1 = self.make_local(quantity_factory, [I_DIM, J_DIM, K_DIM])
self._q4_1 = self.make_local(quantity_factory, [I_DIM, J_DIM, K_DIM])
self._q4_2 = self.make_local(quantity_factory, [I_DIM, J_DIM, K_DIM])
self._q4_3 = self.make_local(quantity_factory, [I_DIM, J_DIM, K_DIM])
self._q4_4 = self.make_local(quantity_factory, [I_DIM, J_DIM, K_DIM])
self._lev = self.make_local(quantity_factory, [I_DIM, J_DIM], dtype=Int)

# If the boundary condition is not given as an input, we use use a zero-reference
self._zero_qs = self.make_local(quantity_factory, [I_DIM, J_DIM])
self._zero_qs.data[:] = 0

self._copy_stencil = stencil_factory.from_dims_halo(
copy,
Expand Down Expand Up @@ -144,7 +141,7 @@ def __call__(
q1: FloatField,
pe1: FloatField,
pe2: FloatField,
qs: Optional["FloatFieldIJ"] = None,
qs: Optional[FloatFieldIJ] = None,
qmin: Float = 0.0,
) -> None:
"""
Expand All @@ -163,7 +160,7 @@ def __call__(

if qs is None:
self._remap_profile(
self._tmp_qs,
self._zero_qs,
self._q4_1,
self._q4_2,
self._q4_3,
Expand Down
60 changes: 30 additions & 30 deletions pyfv3/stencils/mapn_tracer.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import dace

import ndsl.dsl.gt4py_utils as utils
from ndsl import Quantity, QuantityFactory, StencilFactory, orchestrate
from ndsl import NDSLRuntime, QuantityFactory, StencilFactory
from ndsl.constants import I_DIM, J_DIM, K_DIM
from ndsl.dsl.typing import Float, FloatField
from ndsl.dsl.typing import FloatField
from pyfv3.stencils.fillz import FillNegativeTracerValues
from pyfv3.stencils.map_single import MapSingle


class MapNTracer:
class MapNTracer(NDSLRuntime):
"""
Fortran code is mapn_tracer, test class is MapN_Tracer_2d
"""
Expand All @@ -18,51 +20,46 @@ def __init__(
kord: int,
nq: int,
fill: bool,
tracers: dict[str, Quantity],
):
orchestrate(
obj=self,
config=stencil_factory.config.dace_config,
dace_compiletime_args=["tracers"],
)
super().__init__(stencil_factory)
self._nq = int(nq)
self._qs = quantity_factory.zeros(
[I_DIM, J_DIM, K_DIM],
units="unknown",
dtype=Float,
)
self._qs = self.make_local(quantity_factory, [I_DIM, J_DIM, K_DIM])
self._qs.data[:] = 0 # low boundary condition for RemapProfile

kord_tracer = [kord] * self._nq
kord_tracer[5] = 9 # qcld
self._map_single_parametrized_kord = MapSingle(
stencil_factory,
quantity_factory,
kord,
0,
dims=[I_DIM, J_DIM, K_DIM],
)

self._list_of_remap_objects = [
MapSingle(
stencil_factory,
quantity_factory,
kord_tracer[i],
0,
dims=[I_DIM, J_DIM, K_DIM],
)
for i in range(len(kord_tracer))
]
self._map_single_kord9 = MapSingle(
stencil_factory,
quantity_factory,
9,
0,
dims=[I_DIM, J_DIM, K_DIM],
)

if fill:
self._fill_negative_tracers = True
self._fillz = FillNegativeTracerValues(
stencil_factory,
quantity_factory,
self._nq,
tracers,
)
else:
self._fill_negative_tracers = False

self._index_graupel = utils.tracer_variables.index("qgraupel")

def __call__(
self,
pe1: FloatField,
pe2: FloatField,
dp2: FloatField,
tracers: dict[str, Quantity],
tracers: dace.compiletime, # dict[str, Quantity]
):
"""
Remaps the tracer species onto the Eulerian grid
Expand All @@ -75,8 +72,11 @@ def __call__(
dp2 (in): Difference in pressure between Eulerian levels
tracers (inout): tracers to be remapped
"""
for i, q in enumerate(utils.tracer_variables[0 : self._nq]):
self._list_of_remap_objects[i](tracers[q], pe1, pe2, self._qs)
for i, q in enumerate(tracers.keys()):
if i != self._index_graupel:
self._map_single_parametrized_kord(tracers[q], pe1, pe2, self._qs)

self._map_single_kord9(tracers["qgraupel"], pe1, pe2, self._qs)

if self._fill_negative_tracers:
self._fillz(dp2, tracers)
34 changes: 13 additions & 21 deletions pyfv3/stencils/remap_profile.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Sequence

from ndsl import QuantityFactory, StencilFactory, orchestrate
from ndsl import NDSLRuntime, QuantityFactory, StencilFactory
from ndsl.constants import I_DIM, J_DIM, K_DIM, K_INTERFACE_DIM
from ndsl.dsl.gt4py import BACKWARD, FORWARD, PARALLEL, computation
from ndsl.dsl.gt4py import function as gtfunction
Expand Down Expand Up @@ -535,7 +535,7 @@ def set_interpolation_coefficients(
a4_1, a4_2, a4_3, a4_4 = posdef_constraint_iv1(a4_1, a4_2, a4_3, a4_4)


class RemapProfile:
class RemapProfile(NDSLRuntime):
"""
This corresponds to the cs_profile routine in FV3
"""
Expand All @@ -558,10 +558,7 @@ def __init__(
iv: ???
dims: dimensions on which to operate on inputs
"""
orchestrate(
obj=self,
config=stencil_factory.config.dace_config,
)
super().__init__(stencil_factory)

if kord > 10:
raise NotImplementedError(
Expand All @@ -570,24 +567,19 @@ def __init__(

self._kord = kord

self._gam = quantity_factory.zeros(
[I_DIM, J_DIM, K_DIM],
units="unknown",
dtype=Float,
self._gam = self.make_local(quantity_factory, [I_DIM, J_DIM, K_DIM])
self._q = self.make_local(quantity_factory, [I_DIM, J_DIM, K_DIM])
self._q_bot = self.make_local(quantity_factory, [I_DIM, J_DIM, K_DIM])

self._extm = self.make_local(
quantity_factory, [I_DIM, J_DIM, K_DIM], dtype=bool
)
self._q = quantity_factory.zeros(
[I_DIM, J_DIM, K_DIM],
units="unknown",
dtype=Float,
self._ext5 = self.make_local(
quantity_factory, [I_DIM, J_DIM, K_DIM], dtype=bool
)
self._q_bot = quantity_factory.zeros(
[I_DIM, J_DIM, K_DIM],
units="unknown",
dtype=Float,
self._ext6 = self.make_local(
quantity_factory, [I_DIM, J_DIM, K_DIM], dtype=bool
)
self._extm = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], units="", dtype=bool)
self._ext5 = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], units="", dtype=bool)
self._ext6 = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], units="", dtype=bool)

self._set_initial_values = stencil_factory.from_dims_halo(
func=set_initial_vals,
Expand Down
Loading
Loading