Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions devito/ir/equations/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
)
from devito.symbolics import IntDiv, limits_mapper, uxreplace
from devito.tools import Pickable, Tag, frozendict
from devito.types import Eq, Inc, ReduceMax, ReduceMin, ReduceMinMax, relational_min
from devito.types import (
Eq, Inc, ReduceMax, ReduceMin, ReduceMinMax, relational_min, relational_shift
)

__all__ = [
'ClusterizedEq',
Expand Down Expand Up @@ -222,7 +224,7 @@ def __new__(cls, *args, **kwargs):
relations=ordering.relations, mode='partial')
ispace = IterationSpace(intervals, iterators)

# Construct the conditionals and replace the ConditionalDimensions in `expr`
# Construct the conditionals
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should place this whole block of code, which constructs/lowers the conditionals, into its own separate functions, and a docstring with some examples

conditionals = {}
for d in ordering:
if not d.is_Conditional:
Expand All @@ -234,13 +236,31 @@ def __new__(cls, *args, **kwargs):
if d._factor is not None:
cond = d.relation(cond, GuardFactor(d))
conditionals[d] = cond

# Replace the ConditionalDimensions in `expr`
for d, cond in conditionals.items():
# Replace dimension with index
index = d.index
if d.condition is not None and d in expr.free_symbols:
index = index - relational_min(d.condition, d.parent)
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor)})

conditionals = frozendict(conditionals)
index = index - relational_min(cond, d.parent)
shift = relational_shift(cond, d.parent)
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})

# Merge conditionals when possible. E.g if we have an implicit_dim
# and there is a dimension with the same parent, we ca merged
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dimension

"ca merged"

"their conditions"

you could also make the example a bit more practical

# its condition
for d in input_expr.implicit_dims:
if d not in conditionals:
continue
for cd in dict(conditionals):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

list(...) is fine

if cd.parent == d.parent and cd != d:
cond = conditionals.pop(d)
mode = cd.relation and d.relation
if issubclass(mode, sympy.Or):
conditionals[d] = cond
conditionals.pop(cd)
else:
conditionals[cd] = mode(cond, conditionals[cd])
break

# Lower all Differentiable operations into SymPy operations
rhs = diff2sympy(expr.rhs)
Expand Down
3 changes: 3 additions & 0 deletions devito/ir/support/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def __lt__(self, other):
return True
elif q_positive(i):
return False

raise TypeError("Non-comparable index functions") from e

return False
Expand Down Expand Up @@ -164,6 +165,7 @@ def __gt__(self, other):
return True
elif q_negative(i):
return False

raise TypeError("Non-comparable index functions") from e

return False
Expand Down Expand Up @@ -203,6 +205,7 @@ def __le__(self, other):
return True
elif q_positive(i):
return False

raise TypeError("Non-comparable index functions") from e

# Note: unlike `__lt__`, if we end up here, then *it is* <=. For example,
Expand Down
13 changes: 9 additions & 4 deletions devito/passes/clusters/buffering.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from itertools import chain

import numpy as np
from sympy import S
from sympy import Mod, S, simplify

from devito.exceptions import CompilationError
from devito.ir import (
Expand Down Expand Up @@ -203,7 +203,7 @@ def callback(self, clusters, prefix):
guards = c.guards

properties = c.properties.sequentialize(d)
if not isinstance(d, BufferDimension):
if not isinstance(d, BufferDimension) and c.guards[d].has(Mod):
properties = properties.prefetchable(d)
# `c` may be a HaloTouch Cluster, so with no vision of the `bdims`
properties = properties.parallelize(v.bdims).affine(v.bdims)
Expand Down Expand Up @@ -377,7 +377,12 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
buffer, = buffers
xd = buffer.indices[dim]
else:
size = infer_buffer_size(f, dim, clusters)
if len({c.guards[dim.root] for c in clusters}) > 1:
# Multiple clusters with different guards,
# will lead to conflicts in asynchrony with multiple (modulo) slots
size = 1
else:
size = infer_buffer_size(f, dim, clusters)

if async_degree is not None:
if async_degree < size:
Expand Down Expand Up @@ -775,7 +780,7 @@ def infer_buffer_size(f, dim, clusters):
slots = [Vector(i) for i in slots]
size = int((vmax(*slots) - vmin(*slots) + 1)[0])

return size
return simplify(size)


def offset_from_centre(d, indices):
Expand Down
33 changes: 32 additions & 1 deletion devito/types/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

import sympy

__all__ = ['Ge', 'Gt', 'Le', 'Lt', 'Ne', 'relational_max', 'relational_min']
__all__ = ['Ge', 'Gt', 'Le', 'Lt', 'Ne', 'relational_max', 'relational_min',
'relational_shift']


class AbstractRel:
Expand Down Expand Up @@ -291,3 +292,33 @@ def _(expr, s):
return expr.gts
else:
return sympy.S.Infinity


def relational_shift(expr, s):
"""
Infer shift incurred by the expression. Generally only
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could use an example here to quickly visualise what's it trying to do

applies when a CondEq is used as it adds a single value.
"""
if not expr.has(s):
return 0

return _relational_shift(expr, s)


@singledispatch
def _relational_shift(s, expr):
return 0


@_relational_shift.register(sympy.Or)
@_relational_shift.register(sympy.And)
def _(expr, s):
return sum([_relational_shift(e, s) for e in expr.args])


@_relational_shift.register(sympy.Eq)
def _(expr, s):
if isinstance(expr.lhs, sympy.Mod):
return 0
from devito.symbolics.extended_dtypes import INT
return INT(Ge(*expr.args))
85 changes: 80 additions & 5 deletions tests/test_buffering.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ def test_buffer_reuse():
assert all(np.all(vsave.data[i-1] == i + 1) for i in range(1, nt + 1))


def test_multi_cond():
def test_multi_cond_v0():
grid = Grid((3, 3))
nt = 5

Expand All @@ -774,14 +774,89 @@ def test_multi_cond():
T = TimeFunction(grid=grid, name='T', time_order=0, space_order=0)

eqs = [Eq(T, grid.time_dim)]
# this to save times from 0 to nt - 2
# This saves
# - All subsampled times since ct1 is the dimension of f
# - The last time step (ntmod - 2) through ctend (since it's set as ct1 or ctend)
eqs.append(Eq(f, T, implicit_dims=ctend))

# run operator with buffering
op = Operator(eqs, opt='buffering')
op.apply(time_m=0, time_M=ntmod-2)

for i in range(nt-1):
assert np.allclose(f.data[i], i*2)
assert np.allclose(f.data[nt-1], ntmod - 2)


def test_multi_cond_v1():
grid = Grid((3, 3))
nt = 5

x, y = grid.dimensions

factor = 2
ntmod = (nt - 1) * factor + 1

ct1 = ConditionalDimension(name="ct1", parent=grid.time_dim,
factor=factor, relation=Or,
condition=CondEq(grid.time_dim, ntmod - 2))

f = TimeFunction(grid=grid, name='f', time_order=0,
space_order=0, save=nt, time_dim=ct1)
T = TimeFunction(grid=grid, name='T', time_order=0, space_order=0)

eqs = [Eq(T, grid.time_dim)]
# This saves
# - All subsampled times since ct1 is the dimension of f with factor 2
# - The last time step (ntmod - 2) since ct1 also has the condition for ntmod - 2
eqs.append(Eq(f, T))
# this to save the last time sample nt - 1
eqs.append(Eq(f.forward, T+1, implicit_dims=ctend))

# run operator with buffering
op = Operator(eqs, opt='buffering')
op.apply(time_m=0, time_M=ntmod-2)

for i in range(nt):
for i in range(nt-1):
assert np.allclose(f.data[i], i*2)
assert np.allclose(f.data[nt-1], ntmod - 2)


@pytest.mark.parametrize("factor", [1, 2, 3])
def test_buffering_multi_cond(factor):
grid = Grid((16, 16))

nt = 5
ntmod = (nt - 1) * factor + 1

ct0 = ConditionalDimension(name="ct0", parent=grid.time_dim, factor=factor,
relation=Or)
f = TimeFunction(grid=grid, name='f', time_order=0, space_order=0,
time_dim=ct0, save=nt)
T = TimeFunction(grid=grid, name='T', time_order=0, space_order=0)

eqs = []
eqs.append(Eq(T, grid.time_dim))

# conditional dimension for the last sample in the operator
ctend = ConditionalDimension(name="ctend", parent=grid.time_dim,
condition=CondEq(grid.time_dim, ntmod - 2),
relation=Or)

eqs.append(Eq(f, T)) # this to save times from 0 to nt - 2
# this to save the last time sample nt - 1
eqs.append(Eq(f.forward, T+1, implicit_dims=ctend))

# run operator with serialization
op = Operator(eqs, opt='buffering')
op.apply(time_m=0, time_M=ntmod-2)

# Now run backward as well with buffering

f_all = TimeFunction(grid=grid, name='f_all', time_order=0,
space_order=0, time_dim=ct0, save=nt)

eq_all = [Eq(f_all, f)]
eq_all.append(Eq(f_all.forward, f.forward, implicit_dims=ctend))
op_all = Operator(eq_all, opt='buffering')
op_all.apply(time_m=0, time_M=ntmod-2)

assert np.allclose(f_all.data[:, 11, 11], factor * np.arange(nt))
Loading