Skip to content

Commit

Permalink
compiler: fix alias dtype with complex numbers
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed May 30, 2024
1 parent 014ef2c commit 05f8528
Show file tree
Hide file tree
Showing 10 changed files with 91 additions and 51 deletions.
5 changes: 3 additions & 2 deletions devito/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def reinit_compiler(val):
"""
Re-initialize the Compiler.
"""
configuration['compiler'].__init__(suffix=configuration['compiler'].suffix,
configuration['compiler'].__init__(name=configuration['compiler'].name,
suffix=configuration['compiler'].suffix,
mpi=configuration['mpi'])
return val

Expand All @@ -61,7 +62,7 @@ def reinit_compiler(val):
configuration.add('platform', 'cpu64', list(platform_registry),
callback=lambda i: platform_registry[i]())
configuration.add('compiler', 'custom', list(compiler_registry),
callback=lambda i: compiler_registry[i]())
callback=lambda i: compiler_registry[i](name=i))

# Setup language for shared-memory parallelism
preprocessor = lambda i: {0: 'C', 1: 'openmp'}.get(i, i) # Handles DEVITO_OPENMP deprec
Expand Down
24 changes: 22 additions & 2 deletions devito/arch/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ def __init__(self):
_cpp = False

def __init__(self, **kwargs):
self._name = kwargs.pop('name', self.__class__.__name__)

super().__init__(**kwargs)

self.__lookup_cmds__()
Expand Down Expand Up @@ -221,13 +223,13 @@ def __new_with__(self, **kwargs):
Create a new Compiler from an existing one, inheriting from it
the flags that are not specified via ``kwargs``.
"""
return self.__class__(suffix=kwargs.pop('suffix', self.suffix),
return self.__class__(name=self.name, suffix=kwargs.pop('suffix', self.suffix),
mpi=kwargs.pop('mpi', configuration['mpi']),
**kwargs)

@property
def name(self):
return self.__class__.__name__
return self._name

@property
def version(self):
Expand All @@ -243,6 +245,20 @@ def version(self):

return version

@property
def _complex_ctype(self):
    """
    Callable mapping a real C type name to the matching complex type name.

    C++ compilers get ``std::complex<T>``; plain C compilers get
    ``T _Complex``. These two spellings cover most backends since
    - HIP is now using std::complex
      https://rocm.docs.amd.com/en/docs-5.1.3/CHANGELOG.html#hip-api-deprecations-and-warnings
    - SYCL supports std::complex
    - C's _Complex is part of C99
    """
    # Pick the spelling once, at property access, based on the C++ flag
    template = 'std::complex<%s>' if self._cpp else '%s _Complex'
    return lambda dtype: template % str(dtype)

def get_version(self):
result, stdout, stderr = call_capture_output((self.cc, "--version"))
if result != 0:
Expand Down Expand Up @@ -697,6 +713,10 @@ def __lookup_cmds__(self):
self.MPICC = 'nvcc'
self.MPICXX = 'nvcc'

@property
def _complex_ctype(self):
    """Callable wrapping a real C type name into ``thrust::complex<...>``."""
    def make(dtype):
        return 'thrust::complex<%s>' % str(dtype)

    return make


class HipCompiler(Compiler):

Expand Down
34 changes: 14 additions & 20 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
import ctypes

import cgen as c
import numpy as np
from sympy import IndexedBase
from sympy.core.function import Application

from devito.parameters import configuration
from devito.parameters import configuration, switchconfig
from devito.exceptions import VisitorException
from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle,
Call, Lambda, BlankLine, Section, ListMajor)
Expand Down Expand Up @@ -190,20 +189,15 @@ def __init__(self, *args, compiler=None, **kwargs):
}
_restrict_keyword = 'restrict'

def _complex_type(self, ctypestr, dtype):
    """
    Return the C type string to use for ``dtype``.

    ``ctypestr`` is returned untouched unless ``dtype`` is a complex
    floating dtype and ``ctypestr`` is ``'float'`` or ``'double'``; in that
    case it is wrapped as ``std::complex<...>`` for C++ compilers or as
    ``... _Complex`` otherwise.
    """
    try:
        is_complex = np.issubdtype(dtype, np.complexfloating)
    except TypeError:
        # Not something numpy can interpret as a dtype, hence not complex
        is_complex = False

    # Complex is only supported on top of float and double
    if not is_complex or ctypestr not in ('float', 'double'):
        return ctypestr

    template = 'std::complex<%s>' if self._compiler._cpp else '%s _Complex'
    return template % ctypestr
@property
def compiler(self):
    """The compiler object stored on this visitor as ``_compiler``."""
    generating_compiler = self._compiler
    return generating_compiler

def visit(self, o, *args, **kwargs):
    """
    Visit ``o`` inside a ``switchconfig`` block selecting this visitor's
    compiler (by name), so that any configuration lookup performed during
    the visit refers to the compiler the code is being generated for.
    """
    generating_compiler = switchconfig(compiler=self.compiler.name)
    with generating_compiler:
        return super().visit(o, *args, **kwargs)

def _gen_struct_decl(self, obj, masked=()):
"""
Expand Down Expand Up @@ -260,10 +254,10 @@ def _gen_value(self, obj, mode=1, masked=()):
if getattr(obj.function, k, False) and v not in masked]

if (obj._mem_stack or obj._mem_constant) and mode == 1:
strtype = self._complex_type(obj._C_typedata, obj.dtype)
strtype = obj._C_typedata
strshape = ''.join('[%s]' % ccode(i) for i in obj.symbolic_shape)
else:
strtype = self._complex_type(ctypes_to_cstr(obj._C_ctype), obj.dtype)
strtype = ctypes_to_cstr(obj._C_ctype)
strshape = ''
if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1:
if not obj._mem_stack:
Expand Down Expand Up @@ -393,7 +387,7 @@ def visit_tuple(self, o):
def visit_PointerCast(self, o):
f = o.function
i = f.indexed
cstr = self._complex_type(i._C_typedata, i.dtype)
cstr = i._C_typedata

if f.is_PointerArray:
# lvalue
Expand Down Expand Up @@ -448,7 +442,7 @@ def visit_Dereference(self, o):
a0, a1 = o.functions
if a1.is_PointerArray or a1.is_TempFunction:
i = a1.indexed
cstr = self._complex_type(i._C_typedata, i.dtype)
cstr = i._C_typedata
if o.flat is None:
shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:])
rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name,
Expand Down
7 changes: 4 additions & 3 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,8 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs):
# Lower IET to a target-specific IET
graph = Graph(iet, **kwargs)

# Complex header if needed. Needs to be done specialization
# as some specific cases requires complex to be loaded first
# Complex header if needed. Needs to be done before specialization
# as some specific cases require complex to be loaded first
complex_include(graph, language=kwargs['language'], compiler=kwargs['compiler'])

# Specialize
Expand Down Expand Up @@ -1220,7 +1220,8 @@ def parse_kwargs(**kwargs):
raise InvalidOperator("Illegal `compiler=%s`" % str(compiler))
kwargs['compiler'] = compiler_registry[compiler](platform=kwargs['platform'],
language=kwargs['language'],
mpi=configuration['mpi'])
mpi=configuration['mpi'],
name=compiler)
elif any([platform, language]):
kwargs['compiler'] =\
configuration['compiler'].__new_with__(platform=kwargs['platform'],
Expand Down
37 changes: 26 additions & 11 deletions devito/passes/iet/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from devito.passes.iet.engine import iet_pass
from devito.ir.iet.efunc import DeviceFunction, EntryFunction
from devito.symbolics import ValueLimit, evalrel, has_integer_args, limits_mapper
from devito.tools import as_mapper, filter_ordered, split
from devito.tools import as_mapper, filter_ordered, split, dtype_to_cstr

__all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions',
'generate_macros', 'minimize_symbols', 'complex_include']
Expand Down Expand Up @@ -192,22 +192,28 @@ def minimize_symbols(iet):
return iet, {}


_complex_lib = {'cuda': 'cuComplex.h', 'hip': 'hip/hip_complex.h'}
_complex_lib = {'cuda': 'thrust/complex.h'}


@iet_pass
def complex_include(iet, language, compiler):
"""
Add headers for complex arithmetic
"""
# Check whether any complex numbers are present; complex always takes dtype precedence
max_dtype = np.result_type(*[f.dtype for f in FindSymbols().visit(iet)])
if not np.issubdtype(max_dtype, np.complexfloating):
return iet, {}

lib = (_complex_lib.get(language, 'complex' if compiler._cpp else 'complex.h'),)

headers = {}

# For openacc (cpp) need to define constant _Complex_I that isn't found otherwise
if compiler._cpp:
c_str = dtype_to_cstr(max_dtype.type(0).real.dtype.type)
# Constant I
headers = {('_Complex_I', ('std::complex<float>(0.0f, 1.0f)'))}
headers = {('_Complex_I', ('std::complex<%s>(0.0, 1.0)' % c_str))}
# Mix arithmetic definitions
dest = compiler.get_jit_dir()
hfile = dest.joinpath('stdcomplex_arith.h')
Expand All @@ -216,14 +222,7 @@ def complex_include(iet, language, compiler):
ff.write(str(_stdcomplex_defs))
lib += (str(hfile),)

for f in FindSymbols().visit(iet):
try:
if np.issubdtype(f.dtype, np.complexfloating):
return iet, {'includes': lib, 'headers': headers}
except TypeError:
pass

return iet, {}
return iet, {'includes': lib, 'headers': headers}


def remove_redundant_moddims(iet):
Expand Down Expand Up @@ -314,13 +313,29 @@ def _rename_subdims(target, dimensions):
return std::complex<_Tp>(b.real() * a, b.imag() * a);
}
// Mixed-type arithmetic between std::complex<_Tp> and a scalar of a
// different type _Ti. The operators provided by <complex> require both
// operands to share the same value type.

// complex * scalar
template<typename _Tp, typename _Ti>
std::complex<_Tp> operator * (const std::complex<_Tp> & b, const _Ti & a){
  return std::complex<_Tp>(b.real() * a, b.imag() * a);
}

// scalar / complex, computed as a * conj(b) / |b|^2
template<typename _Tp, typename _Ti>
std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){
  _Tp denom = b.real() * b.real() + b.imag() * b.imag();
  return std::complex<_Tp>(b.real() * a / denom, - b.imag() * a / denom);
}

// complex / scalar
template<typename _Tp, typename _Ti>
std::complex<_Tp> operator / (const std::complex<_Tp> & b, const _Ti & a){
  return std::complex<_Tp>(b.real() / a, b.imag() / a);
}

// scalar + complex
template<typename _Tp, typename _Ti>
std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){
  return std::complex<_Tp>(b.real() + a, b.imag());
}

// complex + scalar
template<typename _Tp, typename _Ti>
std::complex<_Tp> operator + (const std::complex<_Tp> & b, const _Ti & a){
  return std::complex<_Tp>(b.real() + a, b.imag());
}
"""
6 changes: 4 additions & 2 deletions devito/symbolics/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,12 @@ def sympy_dtype(expr, base=None):
dtypes.add(i.dtype)
except AttributeError:
pass

dtype = infer_dtype(dtypes)

# Promote if complex
if expr.has(ImaginaryUnit):
# Promote if we missed a complex number, e.g. f + I
is_im = np.issubdtype(dtype, np.complexfloating)
if expr.has(ImaginaryUnit) and not is_im:
dtype = np.promote_types(dtype, np.complex64).type

return dtype
2 changes: 1 addition & 1 deletion devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def _print_Float(self, expr):
elif rv.startswith('.0'):
rv = '0.' + rv[2:]

if self.dtype == np.float32 or self.dtype == np.complex64:
if self.dtype == np.float32:
rv = rv + 'F'

return rv
Expand Down
12 changes: 9 additions & 3 deletions devito/tools/dtypes_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,12 @@ def dtype_to_ctype(dtype):
# Complex data
if np.issubdtype(dtype, np.complexfloating):
rtype = dtype(0).real.__class__
return dtype_to_ctype(rtype)
from devito import configuration
make = configuration['compiler']._complex_ctype
ctname = make(dtype_to_cstr(rtype))
ctype = dtype_to_ctype(rtype)
r = type(ctname, (ctype,), {})
return r

try:
return ctypes_vector_mapper[dtype]
Expand Down Expand Up @@ -214,7 +219,7 @@ class c_restrict_void_p(ctypes.c_void_p):
# *** ctypes lowering


def ctypes_to_cstr(ctype, toarray=None, cpp=False):
def ctypes_to_cstr(ctype, toarray=None):
"""Translate ctypes types into C strings."""
if ctype in ctypes_vector_mapper.values():
retval = ctype.__name__
Expand Down Expand Up @@ -308,7 +313,8 @@ def infer_dtype(dtypes):
# Resolve the vector types, if any
dtypes = {dtypes_vector_mapper.get_base_dtype(i, i) for i in dtypes}

fdtypes = {i for i in dtypes if np.issubdtype(i, np.floating)}
fdtypes = {i for i in dtypes if np.issubdtype(i, np.floating) or
np.issubdtype(i, np.complexfloating)}
if len(fdtypes) > 1:
return max(fdtypes, key=lambda i: np.dtype(i).itemsize)
elif len(fdtypes) == 1:
Expand Down
7 changes: 4 additions & 3 deletions tests/test_gpu_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,16 @@ def test_maxpar_option(self):
assert trees[0][0] is trees[1][0]
assert trees[0][1] is not trees[1][1]

def test_complex(self):
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
def test_complex(self, dtype):
grid = Grid((5, 5))
x, y = grid.dimensions
# Float32 complex is called complex64 in numpy
u = Function(name="u", grid=grid, dtype=np.complex64)
u = Function(name="u", grid=grid, dtype=dtype)

eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing))
# Currently wrong alias type
op = Operator(eq)
print(op)
op()

# Check against numpy
Expand Down
8 changes: 4 additions & 4 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,22 +640,22 @@ def test_tensor(self, func1):
op2 = Operator([Eq(f, f.dx) for f in f1.values()])
assert str(op1.ccode) == str(op2.ccode)

def test_complex(self):
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
def test_complex(self, dtype):
grid = Grid((5, 5))
x, y = grid.dimensions
# Float32 complex is called complex64 in numpy
u = Function(name="u", grid=grid, dtype=np.complex64)
u = Function(name="u", grid=grid, dtype=dtype)

eq = Eq(u, x + sympy.I*y + exp(sympy.I + x.spacing))
# Currently wrong alias type
op = Operator(eq)
# print(op)
op()

# Check against numpy
dx = grid.spacing_map[x.spacing]
xx, yy = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 5))
npres = xx + 1j*yy + np.exp(1j + dx)
print(op)

assert np.allclose(u.data, npres.T, rtol=1e-7, atol=0)

Expand Down

0 comments on commit 05f8528

Please sign in to comment.