Skip to content

Commit 0823c5f

Browse files
committed
compiler: patch multi-subdim isapce handling
1 parent 23d7b99 commit 0823c5f

File tree

4 files changed

+30
-24
lines changed

4 files changed

+30
-24
lines changed

devito/finite_differences/tools.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import numpy as np
55
from sympy import S, finite_diff_weights, cacheit, sympify, Rational, Expr
66

7-
from devito.logger import warning
87
from devito.tools import Tag, as_tuple
98
from devito.types.dimension import StencilDimension
109

@@ -269,9 +268,7 @@ def generate_indices(expr, dim, order, side=None, matvec=None, x0=None, nweights
269268
f"stencil size ({order + 1}) for order {order} scheme")
270269
elif do > dw:
271270
order = nweights - nweights % 2
272-
warning(f"Less weights ({nweights}) provided than the stencil size"
273-
f"({order + 1}) for order {order} scheme."
274-
f" Reducing order to {order}")
271+
275272
# Evaluation point
276273
x0 = sympify(((x0 or {}).get(dim) or expr.indices_ref[dim]))
277274

devito/passes/clusters/implicit.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,13 +175,17 @@ def callback(self, clusters, prefix):
175175

176176
# Make sure the "implicit expressions" are scheduled in
177177
# the innermost loop such that the thicknesses can be computed
178-
edims = set(retrieve_dimensions(mapper.values(), deep=True))
179-
if dim not in edims or not edims.issubset(prefix.dimensions):
178+
def key(tkn):
179+
edims = set(retrieve_dimensions(tkn, deep=True))
180+
return dim._defines & edims and edims.issubset(prefix.dimensions)
181+
182+
mapper = {k: v for k, v in mapper.items() if key(v)}
183+
if not mapper:
180184
continue
181185

182186
found[d.functions].clusters.append(c)
183187
found[d.functions].mapper = reduce(found[d.functions].mapper,
184-
mapper, edims, prefix)
188+
mapper, {dim}, prefix)
185189

186190
# Turn the reduced mapper into a list of equations
187191
processed = []
@@ -262,7 +266,7 @@ def reduce(m0, m1, edims, prefix):
262266
def key(i):
263267
try:
264268
return i.indices[d]
265-
except AttributeError:
269+
except (KeyError, AttributeError):
266270
return i
267271

268272
mapper = {}

devito/passes/iet/misc.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
from devito.finite_differences import Max, Min
88
from devito.finite_differences.differentiable import SafeInv
9-
from devito.logger import warning
109
from devito.ir import (Any, Forward, DummyExpr, Iteration, EmptyList, Prodder,
1110
FindApplications, FindNodes, FindSymbols, Transformer,
1211
Uxreplace, filter_iterations, retrieve_iteration_tree,
@@ -155,7 +154,7 @@ def _generate_macros(iet, tracker=None, langbb=None, printer=CPrinter, **kwargs)
155154
for define, expr in headers)
156155

157156
# Generate Macros from higher-level SymPy objects
158-
mheaders, includes = _generate_macros_math(iet, langbb=langbb)
157+
mheaders, includes = _generate_macros_math(iet, langbb=langbb, printer=printer)
159158
includes = sorted(includes, key=str)
160159
headers.extend(sorted(mheaders, key=str))
161160

@@ -199,25 +198,25 @@ def _generate_macros_findexeds(iet, sregistry=None, tracker=None, **kwargs):
199198
return iet
200199

201200

202-
def _generate_macros_math(iet, langbb=None):
201+
def _generate_macros_math(iet, langbb=None, printer=CPrinter):
203202
headers = []
204203
includes = []
205204
for i in FindApplications().visit(iet):
206-
header, include = _lower_macro_math(i, langbb)
205+
header, include = _lower_macro_math(i, langbb, printer)
207206
headers.extend(header)
208207
includes.extend(include)
209208

210209
return headers, set(includes) - {None}
211210

212211

213212
@singledispatch
214-
def _lower_macro_math(expr, langbb):
213+
def _lower_macro_math(expr, langbb, printer):
215214
return (), {}
216215

217216

218217
@_lower_macro_math.register(Min)
219218
@_lower_macro_math.register(sympy.Min)
220-
def _(expr, langbb):
219+
def _(expr, langbb, printer):
221220
if has_integer_args(*expr.args):
222221
return (('MIN(a,b)', ('(((a) < (b)) ? (a) : (b))')),), {}
223222
else:
@@ -226,23 +225,25 @@ def _(expr, langbb):
226225

227226
@_lower_macro_math.register(Max)
228227
@_lower_macro_math.register(sympy.Max)
229-
def _(expr, langbb):
228+
def _(expr, langbb, printer):
230229
if has_integer_args(*expr.args):
231230
return (('MAX(a,b)', ('(((a) > (b)) ? (a) : (b))')),), {}
232231
else:
233232
return (), as_tuple(langbb.get('header-math'))
234233

235234

236235
@_lower_macro_math.register(SafeInv)
237-
def _(expr, langbb):
236+
def _(expr, langbb, printer):
238237
try:
239-
eps = np.finfo(expr.base.dtype).resolution**2
240-
except ValueError:
241-
warning(f"dtype not recognized in SafeInv for {expr.base}, assuming float32")
242-
eps = np.finfo(np.float32).resolution**2
243-
b = Cast('b', dtype=np.float32)
238+
dtype = expr.base.dtype
239+
except (AttributeError, ValueError):
240+
dtype = np.float32
241+
eps = np.finfo(dtype).resolution**2
242+
b = printer()._print(Cast('b', dtype=dtype))
243+
ext = 'F' if dtype is np.float32 else ''
244244
return (('SAFEINV(a, b)',
245-
f'(((a) < {eps}F || ({b}) < {eps}F) ? (0.0F) : ((1.0F) / (a)))'),), {}
245+
f'(((a) < {eps}{ext} || ({b}) < {eps}{ext}) ? '
246+
f'(0.0{ext}) : ((1.0{ext}) / (a)))'),), {}
246247

247248

248249
@iet_pass

devito/symbolics/inspection.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from devito.symbolics.extended_sympy import (CallFromPointer, Cast,
1414
DefFunction, ReservedWord)
1515
from devito.symbolics.queries import q_routine
16-
from devito.tools import as_tuple, prod
16+
from devito.tools import as_tuple, prod, is_integer
1717
from devito.tools.dtypes_lowering import infer_dtype
1818

1919
__all__ = ['compare_ops', 'estimate_cost', 'has_integer_args', 'sympy_dtype']
@@ -287,13 +287,17 @@ def has_integer_args(*args):
287287
try:
288288
return np.issubdtype(args[0].dtype, np.integer)
289289
except AttributeError:
290-
return args[0].is_integer
290+
return is_integer(args[0])
291291

292292
res = True
293293
for a in args:
294294
try:
295295
if isinstance(a, INT):
296296
res = res and True
297+
elif is_integer(a):
298+
res = res and True
299+
elif has_integer_args(a):
300+
res = res and True
297301
elif len(a.args) > 0:
298302
res = res and has_integer_args(*a.args)
299303
else:

0 commit comments

Comments
 (0)