66
77from devito .finite_differences import Max , Min
88from devito .finite_differences .differentiable import SafeInv
9- from devito .logger import warning
109from devito .ir import (Any , Forward , DummyExpr , Iteration , EmptyList , Prodder ,
1110 FindApplications , FindNodes , FindSymbols , Transformer ,
1211 Uxreplace , filter_iterations , retrieve_iteration_tree ,
@@ -155,7 +154,7 @@ def _generate_macros(iet, tracker=None, langbb=None, printer=CPrinter, **kwargs)
155154 for define , expr in headers )
156155
157156 # Generate Macros from higher-level SymPy objects
158- mheaders , includes = _generate_macros_math (iet , langbb = langbb )
157+ mheaders , includes = _generate_macros_math (iet , langbb = langbb , printer = printer )
159158 includes = sorted (includes , key = str )
160159 headers .extend (sorted (mheaders , key = str ))
161160
@@ -199,25 +198,25 @@ def _generate_macros_findexeds(iet, sregistry=None, tracker=None, **kwargs):
199198 return iet
200199
201200
202- def _generate_macros_math (iet , langbb = None ):
201+ def _generate_macros_math (iet , langbb = None , printer = CPrinter ):
203202 headers = []
204203 includes = []
205204 for i in FindApplications ().visit (iet ):
206- header , include = _lower_macro_math (i , langbb )
205+ header , include = _lower_macro_math (i , langbb , printer )
207206 headers .extend (header )
208207 includes .extend (include )
209208
210209 return headers , set (includes ) - {None }
211210
212211
213212@singledispatch
214- def _lower_macro_math (expr , langbb ):
213+ def _lower_macro_math (expr , langbb , printer ):
215214 return (), {}
216215
217216
218217@_lower_macro_math .register (Min )
219218@_lower_macro_math .register (sympy .Min )
220- def _ (expr , langbb ):
219+ def _ (expr , langbb , printer ):
221220 if has_integer_args (* expr .args ):
222221 return (('MIN(a,b)' , ('(((a) < (b)) ? (a) : (b))' )),), {}
223222 else :
@@ -226,23 +225,25 @@ def _(expr, langbb):
226225
227226@_lower_macro_math .register (Max )
228227@_lower_macro_math .register (sympy .Max )
229- def _ (expr , langbb ):
228+ def _ (expr , langbb , printer ):
230229 if has_integer_args (* expr .args ):
231230 return (('MAX(a,b)' , ('(((a) > (b)) ? (a) : (b))' )),), {}
232231 else :
233232 return (), as_tuple (langbb .get ('header-math' ))
234233
235234
236235@_lower_macro_math .register (SafeInv )
237- def _ (expr , langbb ):
236+ def _ (expr , langbb , printer ):
238237 try :
239- eps = np .finfo (expr .base .dtype ).resolution ** 2
240- except ValueError :
241- warning (f"dtype not recognized in SafeInv for { expr .base } , assuming float32" )
242- eps = np .finfo (np .float32 ).resolution ** 2
243- b = Cast ('b' , dtype = np .float32 )
238+ dtype = expr .base .dtype
239+ except (AttributeError , ValueError ):
240+ dtype = np .float32
241+ eps = np .finfo (dtype ).resolution ** 2
242+ b = printer ()._print (Cast ('b' , dtype = dtype ))
243+ ext = 'F' if dtype is np .float32 else ''
244244 return (('SAFEINV(a, b)' ,
245- f'(((a) < { eps } F || ({ b } ) < { eps } F) ? (0.0F) : ((1.0F) / (a)))' ),), {}
245+ f'(((a) < { eps } { ext } || ({ b } ) < { eps } { ext } ) ? '
246+ f'(0.0{ ext } ) : ((1.0{ ext } ) / (a)))' ),), {}
246247
247248
248249@iet_pass
0 commit comments