
Make a wrapper to hide jaxdf computations #125

Open
@astanziola

Description

One feature that emerged from the chat with @jejjohnson is the ability to use jaxdf fields without exposing them to the user, or at least without the user having to work with them explicitly: a function takes and returns plain arrays, and uses fields only internally.

A common pattern for achieving this is given by the following code:

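from jax.typing import ArrayLike
from jaxdf.discretization import FourierSeries
from jaxdf.geometry import Domain

def my_awesome_func(u: ArrayLike):
  # Declare fields: wrap the raw array in a FourierSeries on a uniform grid
  N = u.shape
  dx = [0.1,] * len(N)
  u_field = FourierSeries(u, Domain(N, dx))
  
  # Perform the desired operation using jaxdf
  # (some_operator is a placeholder for any jaxdf operator)
  v_field = some_operator(u_field)
  
  # Return a simple jax array
  return v_field.on_grid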

To simplify the syntax and achieve a cleaner implementation, this pattern can be encapsulated in a decorator, as shown below:

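dx = [0.1, 0.1]  # grid spacing, one entry per dimension (2D example)

@use_discretization(FourierSeries, dx)
def my_awesome_func(u_field):
  # The decorator passes in a field and unwraps the returned field
  return some_operator(u_field)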

Here, the use_discretization decorator takes care of packing the input array into a field and unpacking the output field back into an array:

import functools
from jaxdf.geometry import Domain

def use_discretization(discr_class, dx):
  def _decorator(func):

    @functools.wraps(func)  # preserve the name and docstring of func
    def wrapper(u):
      # Declare fields, using the discretization class and dx given to the decorator
      N = u.shape
      u_field = discr_class(u, Domain(N, dx))
      
      # Perform the desired operation using jaxdf
      v_field = func(u_field)
  
      # Return a simple jax array
      return v_field.on_grid
    
    return wrapper
  return _decorator
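To check the packing/unpacking logic without depending on jaxdf, here is a minimal, self-contained sketch of the same decorator. `ToyDomain` and `ToyField` are hypothetical stand-ins for jaxdf's `Domain` and `FourierSeries` (they only store the grid and expose an `on_grid` attribute), `double` is a made-up operator, and NumPy replaces jax.numpy so the sketch runs anywhere. It also adds a shape check that the original snippet does not have.

```python
import functools
from dataclasses import dataclass
from typing import Callable, Sequence, Tuple

import numpy as np


@dataclass
class ToyDomain:
    # Hypothetical stand-in for jaxdf's Domain: grid size and spacing
    N: Tuple[int, ...]
    dx: Tuple[float, ...]


@dataclass
class ToyField:
    # Hypothetical stand-in for an on-grid field such as FourierSeries
    params: np.ndarray
    domain: ToyDomain

    @property
    def on_grid(self) -> np.ndarray:
        return self.params


def use_discretization(discr_class: Callable, dx: Sequence[float]):
    def _decorator(func):
        @functools.wraps(func)  # keep the wrapped function's name and docstring
        def wrapper(u):
            u = np.asarray(u)
            if len(dx) != u.ndim:
                raise ValueError(f"dx has {len(dx)} entries but u has {u.ndim} dims")
            # Pack: raw array -> field on a domain matching the array shape
            u_field = discr_class(u, ToyDomain(u.shape, tuple(dx)))
            # Operate on fields, then unpack the result back to a plain array
            return func(u_field).on_grid
        return wrapper
    return _decorator


@use_discretization(ToyField, dx=[0.1, 0.1])
def double(u_field: ToyField) -> ToyField:
    """Hypothetical operator: scales the field values by 2."""
    return ToyField(2 * u_field.params, u_field.domain)


out = double(np.ones((4, 3)))
print(type(out).__name__, out.shape)  # ndarray (4, 3)
```

The caller only ever sees arrays: `double` is called with an array and returns an array, while the function body works with fields.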

Potential issues and things to work out

  • How to deal with multiple input fields
  • How to pass generic parameters, e.g. generalize dx in this example
  • Does this only make sense for OnGrid fields?
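One possible direction for the first two points is sketched below, assuming the decorator receives a field factory instead of a fixed `(discr_class, dx)` pair. Every positional array is wrapped with that factory (multiple inputs), keyword arguments pass through untouched (non-field parameters), and outputs are unpacked with a user-supplied `unpack` hook that defaults to `on_grid`. The names `use_fields`, `make_field`, `unpack` and `ToyField` are hypothetical and not part of jaxdf; NumPy stands in for jax.numpy.

```python
import functools
from typing import Any, Callable

import numpy as np


class ToyField:
    """Hypothetical stand-in for an on-grid jaxdf field."""

    def __init__(self, params, dx):
        self.params = np.asarray(params)
        self.dx = tuple(dx)

    @property
    def on_grid(self):
        return self.params


def use_fields(make_field: Callable[[Any], Any],
               unpack: Callable[[Any], Any] = lambda f: f.on_grid):
    """Wrap every positional array argument with `make_field` and unpack
    the returned field (or each field in a returned tuple) with `unpack`.
    Keyword arguments are passed through untouched."""
    def _decorator(func):
        @functools.wraps(func)
        def wrapper(*arrays, **kwargs):
            fields = [make_field(a) for a in arrays]
            out = func(*fields, **kwargs)
            if isinstance(out, tuple):
                return tuple(unpack(o) for o in out)
            return unpack(out)
        return wrapper
    return _decorator


# dx is derived from each input's number of dimensions, not hard-coded
grid_field = lambda a: ToyField(a, dx=[0.1] * np.ndim(a))


@use_fields(grid_field)
def weighted_sum(u, v, alpha=1.0):
    # Two input fields plus a plain (non-field) keyword parameter
    return ToyField(u.params + alpha * v.params, u.dx)


@use_fields(grid_field)
def split(u):
    # Several outputs: each returned field is unpacked separately
    return ToyField(2 * u.params, u.dx), ToyField(3 * u.params, u.dx)


w = weighted_sum(np.ones((2, 2)), np.full((2, 2), 3.0), alpha=0.5)
print(w[0, 0])  # 2.5
```

Moving the choice of discretization and its parameters into `make_field` keeps the decorator itself generic. Because unpacking goes through the `unpack` hook rather than a hard-coded `on_grid`, the same wrapper could in principle serve non-OnGrid fields, for example by sampling a continuous field on a grid, although that case is not exercised here.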


Metadata

Assignees

Labels

enhancement (New feature or request)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
