Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion tests/test_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@
from graph_optimizer.core import GraphOptimizer, OptimizationContext
from graph_optimizer.utils import create_node, make_output_shapes_attr, create_const_node
from graph_optimizer.transforms.vectorize import PackVectorizePass
from graph_optimizer.transforms.scalar import AlgebraicSimplifyPass, ConstantFoldPass
from graph_optimizer.transforms.scalar import (
ConstantFoldPass,
SimplifyAddPass,
SimplifySubPass,
SimplifyMulPass,
SimplifyDivPass,
)
from tensorflow.core.framework import attr_value_pb2
from graph_optimizer.utils.logger import set_log_level
import logging
Expand Down
14 changes: 6 additions & 8 deletions tests/test_edge_cases.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import tensorflow.compat.v1 as tf
import numpy as np
from graph_optimizer.core import GraphOptimizer
from graph_optimizer.transforms.scalar.algebraic_simplify import AlgebraicSimplifyPass
from graph_optimizer.core import GraphOptimizer, PassRegistry
from graph_optimizer.utils.graph_utils import create_node, create_const_node

from graph_optimizer.transforms.scalar.constant_fold import ConstantFoldPass

def test_shape_breakage():
# Construct a graph:
Expand All @@ -27,7 +26,7 @@ def test_shape_breakage():
graph_def.node.extend([x, eq, sink])

optimizer = GraphOptimizer(graph_def)
pass_instance = AlgebraicSimplifyPass()
pass_instance = PassRegistry.get_pass("simplify_redundant_comparison")

print("Before optimization:")
print(f"eq_node: {optimizer.nodes['eq_node'].op}")
Expand All @@ -36,8 +35,8 @@ def test_shape_breakage():
pass_instance.transform(optimizer, protected_nodes=["sink"])

print("\nAfter optimization:")
if "eq_node_bool" in optimizer.nodes:
new_node = optimizer.nodes["eq_node_bool"]
if "eq_node" in optimizer.nodes and optimizer.nodes["eq_node"].op == "Const":
new_node = optimizer.nodes["eq_node"]
print(f"Replacement node op: {new_node.op}")

# Check shape in top-level or in tensor
Expand Down Expand Up @@ -68,7 +67,6 @@ def test_div_zero():
graph_def.node.extend([c1, c0, div])

optimizer = GraphOptimizer(graph_def)
from graph_optimizer.transforms.scalar.constant_fold import ConstantFoldPass

pass_instance = ConstantFoldPass()

Expand Down Expand Up @@ -97,7 +95,7 @@ def test_broadcasting_safety():
graph_def.node.extend([x, zero, add, sink])

optimizer = GraphOptimizer(graph_def)
pass_instance = AlgebraicSimplifyPass()
pass_instance = PassRegistry.get_pass("simplify_add")

print("\n--- Broadcasting Safety Test ---")
pass_instance.transform(optimizer, protected_nodes=["sink"])
Expand Down
Loading