diff --git a/core.py b/core.py index 25d8232..362f051 100644 --- a/core.py +++ b/core.py @@ -321,6 +321,11 @@ def get_node_shape(self, node_or_name): if "shape" in node.attr: return [dim.size for dim in node.attr["shape"].shape.dim] + if node.op == "Const" and "value" in node.attr: + tensor = node.attr["value"].tensor + if tensor.HasField("tensor_shape"): + return [d.size for d in tensor.tensor_shape.dim] + return None def get_node_rank(self, node_or_name): @@ -1136,9 +1141,11 @@ def CommutativeOp( def ConstValue(value, alias=None): """Matches a Const node with a specific value.""" + import numpy as np def check_value(unwrapped_value): - return unwrapped_value == value + # Use np.all() for element-wise comparison on arrays + return np.all(np.equal(unwrapped_value, value)) return Op("Const", attrs={"value": check_value}, alias=alias) diff --git a/transforms/scalar/algebraic_simplify.py b/transforms/scalar/algebraic_simplify.py index 76c94af..154958d 100644 --- a/transforms/scalar/algebraic_simplify.py +++ b/transforms/scalar/algebraic_simplify.py @@ -67,357 +67,280 @@ """ from __future__ import annotations - +import numpy as np from graph_optimizer.core import ( + Any, + BasePass, + CommutativeOp, + ConstValue, Op, PassRegistry, - PatternRewritePass, - Any, RewriteResult, ) -from graph_optimizer.utils.graph_utils import create_node, create_const_node +from graph_optimizer.utils.graph_utils import create_const_node, create_node from graph_optimizer.utils.logger import logger as logging -import numpy as np - @PassRegistry.register("algebraic_simplify", opt_level=1, priority=7) -class AlgebraicSimplifyPass(PatternRewritePass): - """ - Applies algebraic identities to simplify expressions. - """ - +class AlgebraicSimplifyPass(BasePass): def __init__(self): - # We'll handle multiple patterns manually in _rewrite - pattern = Any(alias="op") # fallback, we check inside - super().__init__(pattern, self._rewrite, name="AlgebraicSimplify") - - def _rewrite(self, match, optimizer): - node = match.matched_nodes["op"] - op_type = node.op - inputs = list(node.input) - name = node.name - - def _mapped_result(target_name): - return RewriteResult(new_nodes=[], node_mapping={name: target_name}) - - def _new_node_result(new_node): - return RewriteResult( - new_nodes=[new_node], node_mapping={name: new_node.name} - ) - - # Helper to create True/False const - def _bool_const(val): - return _new_node_result( - create_const_node(name + "_bool", value=val, dtype="bool", shape=[]) - ) - - # Helper to get node object ignoring output index - def _get_node(name): - real_name = name.split(":")[0] - return optimizer.nodes.get(real_name) - - # Helper to check if a node is Const with given value (broadcast-safe) - def _is_const(node_name, value): - node = _get_node(node_name) - if node is None: - return False - if node.op != "Const": - return False - val = optimizer.get_node_attr(node, "value") - # Check if all elements are equal to the target value - return np.all(np.equal(val, value)) - - # Helper to get shape of a node - def _get_shape(node_name): - node = _get_node(node_name) - if node is None: - return None - # Check for shape attribute (Placeholder, etc.) - if "shape" in node.attr: - return [d.size for d in node.attr["shape"].shape.dim] - # Check for Const value shape - if node.op == "Const" and "value" in node.attr: - tensor = node.attr["value"].tensor - if tensor.HasField("tensor_shape"): - return [d.size for d in tensor.tensor_shape.dim] + super().__init__(name="AlgebraicSimplify", iterative=True) + self.rewrite_rules = [ + (CommutativeOp("Add", Any(alias="x"), ConstValue(0, alias="const"), alias="op"), self._rewrite_add_zero), + (CommutativeOp("Add", Any(alias="x"), Op("Neg", Any(alias="y")), alias="op"), self._rewrite_add_neg), + (CommutativeOp("Mul", Any(alias="x"), ConstValue(1, alias="const"), alias="op"), self._rewrite_mul_one), + (CommutativeOp("Mul", Any(alias="x"), ConstValue(0, alias="const"), alias="op"), self._rewrite_mul_zero), + (Op("Mul", Any(alias="x"), Any(alias="y"), alias="op"), self._rewrite_mul_self_to_square), + (CommutativeOp("LogicalAnd", Any(alias="x"), ConstValue(True, alias="const"), alias="op"), self._rewrite_logicaland_true), + (CommutativeOp("LogicalAnd", Any(alias="x"), ConstValue(False, alias="const"), alias="op"), self._rewrite_logicaland_false), + (Op("LogicalAnd", Any(alias="x"), Any(alias="y"), alias="op"), self._rewrite_logicaland_self), + (CommutativeOp("LogicalOr", Any(alias="x"), ConstValue(False, alias="const"), alias="op"), self._rewrite_logicalor_false), + (CommutativeOp("LogicalOr", Any(alias="x"), ConstValue(True, alias="const"), alias="op"), self._rewrite_logicalor_true), + (Op("LogicalOr", Any(alias="x"), Any(alias="y"), alias="op"), self._rewrite_logicalor_self), + (Op("Sub", Any(alias="x"), ConstValue(0, alias="const"), alias="op"), self._rewrite_sub_zero), + (Op("Sub", Any(alias="x"), Any(alias="y"), alias="op"), self._rewrite_sub_self), + (Op("Div", Any(alias="x"), ConstValue(1, alias="const"), alias="op"), self._rewrite_div_one), + (Op("Div", Any(alias="x"), Any(alias="y"), alias="op"), self._rewrite_div_self), + (Op("Neg", Op("Neg", Any(alias="x")), alias="op"), self._rewrite_double_neg), + (Op("LogicalNot", Op("LogicalNot", Any(alias="x")), alias="op"), self._rewrite_double_logical_not), + (Op("Abs", Op("Abs", Any(alias="x")), alias="op"), self._rewrite_double_abs), + (Op("Square", Op("Sqrt", Any(alias="x")), alias="op"), self._rewrite_square_sqrt), + (Op("Sqrt", Op("Square", Any(alias="x")), alias="op"), self._rewrite_sqrt_square), + (Op("Pow", Any(alias="x"), ConstValue(1, alias="const"), alias="op"), self._rewrite_pow_one), + (Op("Pow", Any(alias="x"), ConstValue(2, alias="const"), alias="op"), self._rewrite_pow_two), + (Op("Equal", Any(alias="x"), Any(alias="y"), alias="op"), self._rewrite_equal_self), + (Op("NotEqual", Any(alias="x"), Any(alias="y"), alias="op"), self._rewrite_not_equal_self), + (Op("Less", Any(alias="x"), Any(alias="y"), alias="op"), self._rewrite_less_self), + (Op("Greater", Any(alias="x"), Any(alias="y"), alias="op"), self._rewrite_greater_self), + (Op("LessEqual", Any(alias="x"), Any(alias="y"), alias="op"), self._rewrite_less_equal_self), + (Op("GreaterEqual", Any(alias="x"), Any(alias="y"), alias="op"), self._rewrite_greater_equal_self), + (Op("Select", Any(), Any(alias="x"), Any(alias="y"), alias="op"), self._rewrite_select_self), + (Op("Identity", Any(alias="x"), alias="op"), self._rewrite_identity), + ] + + def _get_shape_safe(self, optimizer, node_or_name): + return optimizer.get_node_shape(node_or_name) + + def _get_broadcast_shape(self, s1, s2): + if s1 is None or s2 is None: return None - - # Helper to check if a node is definitely scalar - def _is_scalar(node_name): - shape = _get_shape(node_name) - return shape == [] - - # Helper to compute broadcast shape of two shapes - def _get_broadcast_shape(s1, s2): - if s1 is None or s2 is None: - return None - if s1 == s2: - return s1 - if not s1: - return s2 - if not s2: - return s1 - - # Simple broadcasting logic - len1, len2 = len(s1), len(s2) - max_len = max(len1, len2) - result = [] - for i in range(max_len): - d1 = s1[len1 - 1 - i] if i < len1 else 1 - d2 = s2[len2 - 1 - i] if i < len2 else 1 - if d1 == d2: - result.append(d1) - elif d1 == 1: - result.append(d2) - elif d2 == 1: - result.append(d1) - else: - return None # Incompatible - return result[::-1] - - # Helper to check if simplification is shape-preserving - def _is_shape_preserving(source_shape, target_shape): - # If both are unknown, assume it's safe (common in simple tests) - if source_shape is None and target_shape is None: - return True - if source_shape is None or target_shape is None: - return False - return source_shape == target_shape - - # Rule: Add(x, 0) or Add(0, x) - if op_type == "Add": - left, right = inputs[0], inputs[1] - s_left, s_right = _get_shape(left), _get_shape(right) - s_res = _get_broadcast_shape(s_left, s_right) - - if _is_const(left, 0) and _is_shape_preserving(s_res, s_right): - return _mapped_result(right) - if _is_const(right, 0) and _is_shape_preserving(s_res, s_left): - return _mapped_result(left) - # Add(x, Neg(x)) -> 0 or Add(Neg(x), x) -> 0 - # Note: This is a simplified check for Neg(x) - for l, r in [(left, right), (right, left)]: - rn = _get_node(r) - if rn and rn.op == "Neg" and rn.input[0] == l: - s = _get_shape(l) - if s is not None: - source = _get_node(l) - dtype = source.attr.get("dtype", "float32") if source else "float32" - return _new_node_result( - create_const_node(name + "_zero", value=0, dtype=dtype, shape=s) - ) - - # Rule: Sub(x, 0) → x - if op_type == "Sub": - left, right = inputs[0], inputs[1] - if _is_const(right, 0) and ( - _is_scalar(right) or _get_shape(right) == _get_shape(left) - ): - return _mapped_result(left) - # Sub(x, x) → 0 - if left == right: - s = _get_shape(left) - if s is not None: - source = _get_node(left) - dtype = source.attr.get("dtype", "float32") if source else "float32" - return _new_node_result( - create_const_node(name + "_zero", value=0, dtype=dtype, shape=s) - ) - - # Rule: Mul(x, 1) or Mul(1, x) - if op_type == "Mul": - left, right = inputs[0], inputs[1] - s_left, s_right = _get_shape(left), _get_shape(right) - s_res = _get_broadcast_shape(s_left, s_right) - - if _is_const(left, 1) and _is_shape_preserving(s_res, s_right): - return _mapped_result(right) - if _is_const(right, 1) and _is_shape_preserving(s_res, s_left): - return _mapped_result(left) - # Mul(x, 0) → 0 - if _is_const(left, 0) or _is_const(right, 0): - if s_res is not None: - source_name = right if _is_const(left, 0) else left - source = _get_node(source_name) - dtype = source.attr.get("dtype", "float32") if source else "float32" - return _new_node_result( - create_const_node( - name + "_zero", value=0, dtype=dtype, shape=s_res - ) - ) - # Mul(x, x) -> Square(x) - if left == right: - return _new_node_result( - create_node("Square", name + "_sq", inputs=[left]) - ) - - # Rule: Div(x, 1) → x - if op_type == "Div": - left, right = inputs[0], inputs[1] - s_left, s_right = _get_shape(left), _get_shape(right) - s_res = _get_broadcast_shape(s_left, s_right) - if _is_const(right, 1) and _is_shape_preserving(s_res, s_left): - return _mapped_result(left) - # Div(x, x) -> 1 - if left == right: - s = _get_shape(left) - if s is not None: - source = _get_node(left) - dtype = source.attr.get("dtype", "float32") if source else "float32" - return _new_node_result( - create_const_node(name + "_one", value=1, dtype=dtype, shape=s) - ) - - # Rule: Neg(Neg(x)) → x - if op_type == "Neg": - inp = _get_node(inputs[0]) - if inp and inp.op == "Neg": - return _mapped_result(inp.input[0]) - - # Rule: LogicalNot(LogicalNot(x)) → x - if op_type == "LogicalNot": - inp = _get_node(inputs[0]) - if inp and inp.op == "LogicalNot": - return _mapped_result(inp.input[0]) - - # Rule: Abs(Abs(x)) → Abs(x) - if op_type == "Abs": - inp = _get_node(inputs[0]) - if inp and inp.op == "Abs": - orig = _get_node(inp.input[0]) - if orig: - return _new_node_result( - create_node("Abs", name + "_abs", inputs=[orig.name]) - ) - - # Rule: Square(Sqrt(x)) → x (domain assumed ok) - if op_type == "Square": - inp = _get_node(inputs[0]) - if inp and inp.op == "Sqrt": - return _mapped_result(inp.input[0]) - - # Rule: Sqrt(Square(x)) → Abs(x) - if op_type == "Sqrt": - inp = _get_node(inputs[0]) - if inp and inp.op == "Square": - orig = _get_node(inp.input[0]) - if orig: - return _new_node_result( - create_node("Abs", name + "_abs", inputs=[orig.name]) - ) - - # Rule: Pow(x, 1) -> x - if op_type == "Pow": - left, right = inputs[0], inputs[1] - s_left, s_right = _get_shape(left), _get_shape(right) - s_res = _get_broadcast_shape(s_left, s_right) - if _is_const(right, 1) and _is_shape_preserving(s_res, s_left): - return _mapped_result(left) - # Pow(x, 2) -> Square(x) - if _is_const(right, 2) and _is_shape_preserving(s_res, s_left): - return _new_node_result( - create_node("Square", name + "_sq", inputs=[left]) - ) - - # Helper for comparison results - def _comparison_const(val): - # Equal(x, x) -> True should have same shape as x (or broadcasted shape) - # If x is [2, 2], result is [2, 2] of True - s = _get_shape(inputs[0]) - if s is None: - return None # Safer to skip if shape unknown - return _new_node_result( - create_const_node(name + "_bool", value=val, dtype="bool", shape=s) - ) - - # Rule: Equal(x, x) → True - if op_type == "Equal": - left, right = inputs[0], inputs[1] - if left == right: - return _comparison_const(True) - - # Rule: NotEqual(x, x) → False - if op_type == "NotEqual": - left, right = inputs[0], inputs[1] - if left == right: - return _comparison_const(False) - - # Rule: Less(x, x) → False ; Greater(x, x) → False - if op_type in ("Less", "Greater") and inputs[0] == inputs[1]: - return _comparison_const(False) - - # Rule: LessEqual(x, x) → True ; GreaterEqual(x, x) → True - if op_type in ("LessEqual", "GreaterEqual") and inputs[0] == inputs[1]: - return _comparison_const(True) - - # Rule: And(x, True) → x ; And(True, x) → x - if op_type == "LogicalAnd": - left, right = inputs[0], inputs[1] - s_left, s_right = _get_shape(left), _get_shape(right) - s_res = _get_broadcast_shape(s_left, s_right) - - if _is_const(left, True) and _is_shape_preserving(s_res, s_right): - return _mapped_result(right) - if _is_const(right, True) and _is_shape_preserving(s_res, s_left): - return _mapped_result(left) - # LogicalAnd(x, x) -> x - if left == right: - return _mapped_result(left) - # LogicalAnd(x, False) -> False - if _is_const(left, False) or _is_const(right, False): - if s_res is not None: - return _new_node_result( - create_const_node(name + "_bool", value=False, dtype="bool", shape=s_res) - ) - - # Rule: Or(x, False) → x ; Or(False, x) → x - if op_type == "LogicalOr": - left, right = inputs[0], inputs[1] - s_left, s_right = _get_shape(left), _get_shape(right) - s_res = _get_broadcast_shape(s_left, s_right) - - if _is_const(left, False) and _is_shape_preserving(s_res, s_right): - return _mapped_result(right) - if _is_const(right, False) and _is_shape_preserving(s_res, s_left): - return _mapped_result(left) - # LogicalOr(x, x) -> x - if left == right: - return _mapped_result(left) - # LogicalOr(x, True) -> True - if _is_const(left, True) or _is_const(right, True): - if s_res is not None: - return _new_node_result( - create_const_node(name + "_bool", value=True, dtype="bool", shape=s_res) - ) - - # Rule: Select(cond, x, x) → x - if op_type == "Select": - if len(inputs) >= 3 and inputs[1] == inputs[2]: - return _mapped_result(inputs[1]) - - # Rule: Identity(x) -> x (bypass or collapse nested Identity) - if op_type == "Identity": - # Skip if protected/output node - if ( - hasattr(optimizer, "protected_nodes") - and name in optimizer.protected_nodes - ): - return None - # Skip ReadVariableOp - if "ReadVariableOp" in name: - return None - # Skip colocation constraint - if "_class" in node.attr: + if s1 == s2: + return s1 + if not s1: + return s2 + if not s2: + return s1 + len1, len2 = len(s1), len(s2) + max_len = max(len1, len2) + result = [] + for i in range(max_len): + d1 = s1[len1 - 1 - i] if i < len1 else 1 + d2 = s2[len2 - 1 - i] if i < len2 else 1 + if d1 == d2: + result.append(d1) + elif d1 == 1: + result.append(d2) + elif d2 == 1: + result.append(d1) + else: return None - # Collapse nested Identity - inp_node = _get_node(inputs[0]) - if inp_node and inp_node.op == "Identity": - inner_input = inp_node.input[0] - new_node = create_node( - "Identity", name + "_collapsed", inputs=[inner_input] - ) - return _new_node_result(new_node) - # Bypass single Identity - return _mapped_result(inputs[0]) - - return None + return result[::-1] + + def _rewrite_add_zero(self, match, optimizer): + op, x, const = match.matched_nodes["op"], match.matched_nodes["x"], match.matched_nodes["const"] + s_x = self._get_shape_safe(optimizer, x) + s_const = self._get_shape_safe(optimizer, const) + s_res = self._get_broadcast_shape(s_x, s_const) + if s_res == s_x: + return RewriteResult(new_nodes=[], node_mapping={op.name: x.name}) + + def _rewrite_add_neg(self, match, optimizer): + op, x, y = match.matched_nodes["op"], match.matched_nodes["x"], match.matched_nodes["y"] + if x.name == y.name: + s = self._get_shape_safe(optimizer, x) + dtype = x.attr.get("dtype", "float32") + new_node = create_const_node(op.name + "_zero", value=0, dtype=dtype, shape=s) + return RewriteResult(new_nodes=[new_node], node_mapping={op.name: new_node.name}) + + def _rewrite_mul_one(self, match, optimizer): + op, x, const = match.matched_nodes["op"], match.matched_nodes["x"], match.matched_nodes["const"] + s_x = self._get_shape_safe(optimizer, x) + s_const = self._get_shape_safe(optimizer, const) + s_res = self._get_broadcast_shape(s_x, s_const) + if s_res == s_x: + return RewriteResult(new_nodes=[], node_mapping={op.name: x.name}) + + def _rewrite_mul_zero(self, match, optimizer): + op, x, const = match.matched_nodes["op"], match.matched_nodes["x"], match.matched_nodes["const"] + s_x = self._get_shape_safe(optimizer, x) + s_const = self._get_shape_safe(optimizer, const) + s_res = self._get_broadcast_shape(s_x, s_const) + if s_res is not None: + dtype = x.attr.get("dtype", "float32") + new_node = create_const_node(op.name + "_zero", value=0, dtype=dtype, shape=s_res) + return RewriteResult(new_nodes=[new_node], node_mapping={op.name: new_node.name}) + + def _rewrite_mul_self_to_square(self, match, optimizer): + op, x, y = match.matched_nodes["op"], match.matched_nodes["x"], match.matched_nodes["y"] + if x.name == y.name: + new_node = create_node("Square", op.name + "_sq", inputs=[x.name]) + return RewriteResult(new_nodes=[new_node], node_mapping={op.name: new_node.name}) + + def _rewrite_logicaland_true(self, match, optimizer): + op, x, const = match.matched_nodes["op"], match.matched_nodes["x"], match.matched_nodes["const"] + s_x = self._get_shape_safe(optimizer, x) + s_const = self._get_shape_safe(optimizer, const) + s_res = self._get_broadcast_shape(s_x, s_const) + if s_res == s_x: + return RewriteResult(new_nodes=[], node_mapping={op.name: x.name}) + + def _rewrite_logicaland_false(self, match, optimizer): + op, x, const = match.matched_nodes["op"], match.matched_nodes["x"], match.matched_nodes["const"] + s_x = self._get_shape_safe(optimizer, x) + s_const = self._get_shape_safe(optimizer, const) + s_res = self._get_broadcast_shape(s_x, s_const) + if s_res is not None: + new_node = create_const_node(op.name + "_bool", value=False, dtype="bool", shape=s_res) + return RewriteResult(new_nodes=[new_node], node_mapping={op.name: new_node.name}) + + def _rewrite_logicaland_self(self, match, optimizer): + op, x, y = match.matched_nodes["op"], match.matched_nodes["x"], match.matched_nodes["y"] + if x.name == y.name: + return RewriteResult(new_nodes=[], node_mapping={op.name: x.name}) + + def _rewrite_logicalor_false(self, match, optimizer): + op, x = match.matched_nodes["op"], match.matched_nodes["x"] + s_op = self._get_shape_safe(optimizer, op) + s_x = self._get_shape_safe(optimizer, x) + if s_op == s_x: + return RewriteResult(new_nodes=[], node_mapping={op.name: x.name}) + + def _rewrite_logicalor_true(self, match, optimizer): + op = match.matched_nodes["op"] + s_op = self._get_shape_safe(optimizer, op) + new_node = create_const_node(op.name + "_bool", value=True, dtype="bool", shape=s_op) + return RewriteResult(new_nodes=[new_node], node_mapping={op.name: new_node.name}) + + def _rewrite_logicalor_self(self, match, optimizer): + op, x, y = match.matched_nodes["op"], match.matched_nodes["x"], match.matched_nodes["y"] + if x.name == y.name: + return RewriteResult(new_nodes=[], node_mapping={op.name: x.name}) + + def _rewrite_sub_zero(self, match, optimizer): + op, x = match.matched_nodes["op"], match.matched_nodes["x"] + return RewriteResult(new_nodes=[], node_mapping={op.name: x.name}) + + def _rewrite_sub_self(self, match, optimizer): + op, x, y = match.matched_nodes["op"], match.matched_nodes["x"], match.matched_nodes["y"] + if x.name == y.name: + s = self._get_shape_safe(optimizer, x) + dtype = x.attr.get("dtype", "float32") + new_node = create_const_node(op.name + "_zero", value=0, dtype=dtype, shape=s) + return RewriteResult(new_nodes=[new_node], node_mapping={op.name: new_node.name}) + + def _rewrite_div_one(self, match, optimizer): + op, x = match.matched_nodes["op"], match.matched_nodes["x"] + return RewriteResult(new_nodes=[], node_mapping={op.name: x.name}) + + def _rewrite_div_self(self, match, optimizer): + op, x, y = match.matched_nodes["op"], match.matched_nodes["x"], match.matched_nodes["y"] + if x.name == y.name: + s = self._get_shape_safe(optimizer, x) + dtype = x.attr.get("dtype", "float32") + new_node = create_const_node(op.name + "_one", value=1, dtype=dtype, shape=s) + return RewriteResult(new_nodes=[new_node], node_mapping={op.name: new_node.name}) + + def _rewrite_double_neg(self, match, optimizer): + op, x = match.matched_nodes["op"], match.matched_nodes["x"] + return RewriteResult(new_nodes=[], node_mapping={op.name: x.name}) + + def _rewrite_double_logical_not(self, match, optimizer): + op, x = match.matched_nodes["op"], match.matched_nodes["x"] + return RewriteResult(new_nodes=[], node_mapping={op.name: x.name}) + + def _rewrite_double_abs(self, match, optimizer): + op, x = match.matched_nodes["op"], match.matched_nodes["x"] + new_node = create_node("Abs", op.name + "_abs", inputs=[x.name]) + return RewriteResult(new_nodes=[new_node], node_mapping={op.name: new_node.name}) + + def _rewrite_square_sqrt(self, match, optimizer): + op, x = match.matched_nodes["op"], match.matched_nodes["x"] + return RewriteResult(new_nodes=[], node_mapping={op.name: x.name}) + + def _rewrite_sqrt_square(self, match, optimizer): + op, x = match.matched_nodes["op"], match.matched_nodes["x"] + new_node = create_node("Abs", op.name + "_abs", inputs=[x.name]) + return RewriteResult(new_nodes=[new_node], node_mapping={op.name: new_node.name}) + + def _rewrite_pow_one(self, match, optimizer): + op, x = match.matched_nodes["op"], match.matched_nodes["x"] + return RewriteResult(new_nodes=[], node_mapping={op.name: x.name}) + + def _rewrite_pow_two(self, match, optimizer): + op, x = match.matched_nodes["op"], match.matched_nodes["x"] + new_node = create_node("Square", op.name + "_sq", inputs=[x.name]) + return RewriteResult(new_nodes=[new_node], node_mapping={op.name: new_node.name}) + + def _comparison_const(self, op, x, value, optimizer): + s = self._get_shape_safe(optimizer, x) + new_node = create_const_node(op.name + "_bool", value=value, dtype="bool", shape=s) + return RewriteResult(new_nodes=[new_node], node_mapping={op.name: new_node.name}) + + def _rewrite_equal_self(self, match, optimizer): + op, x, y = match.matched_nodes["op"], match.matched_nodes["x"], match.matched_nodes["y"] + if x.name == y.name: + return self._comparison_const(op, x, True, optimizer) + + def _rewrite_not_equal_self(self, match, optimizer): + op, x, y = match.matched_nodes["op"], match.matched_nodes["x"], match.matched_nodes["y"] + if x.name == y.name: + return self._comparison_const(op, x, False, optimizer) + + def _rewrite_less_self(self, match, optimizer): + op, x, y = match.matched_nodes["op"], match.matched_nodes["x"], match.matched_nodes["y"] + if x.name == y.name: + return self._comparison_const(op, x, False, optimizer) + + def _rewrite_greater_self(self, match, optimizer): + op, x, y = match.matched_nodes["op"], match.matched_nodes["x"], match.matched_nodes["y"] + if x.name == y.name: + return self._comparison_const(op, x, False, optimizer) + + def _rewrite_less_equal_self(self, match, optimizer): + op, x, y = match.matched_nodes["op"], match.matched_nodes["x"], match.matched_nodes["y"] + if x.name == y.name: + return self._comparison_const(op, x, True, optimizer) + + def _rewrite_greater_equal_self(self, match, optimizer): + op, x, y = match.matched_nodes["op"], match.matched_nodes["x"], match.matched_nodes["y"] + if x.name == y.name: + return self._comparison_const(op, x, True, optimizer) + + def _rewrite_select_self(self, match, optimizer): + op, x, y = match.matched_nodes["op"], match.matched_nodes["x"], match.matched_nodes["y"] + if x.name == y.name: + return RewriteResult(new_nodes=[], node_mapping={op.name: x.name}) + + def _rewrite_identity(self, match, optimizer): + op, x = match.matched_nodes["op"], match.matched_nodes["x"] + if op.name in optimizer.protected_nodes or "ReadVariableOp" in op.name or "_class" in op.attr: + return + if x.op == "Identity": + new_node = create_node("Identity", op.name + "_collapsed", inputs=[x.input[0]]) + return RewriteResult(new_nodes=[new_node], node_mapping={op.name: new_node.name}) + return RewriteResult(new_nodes=[], node_mapping={op.name: x.name}) + + def transform_once(self, optimizer, auto_cleanup=True, protected_nodes=None): + optimizer.clear_transformations() + for pattern, rewriter in self.rewrite_rules: + optimizer.add_transformation(pattern, rewriter) + + new_graph_def, changes = optimizer.match_patterns_once( + pass_name=self.name, + auto_cleanup=auto_cleanup, + protected_nodes=protected_nodes, + ) + + if changes > 0: + optimizer.load_state(new_graph_def) + + return changes