From c9c64324ac7655a72ee0034a41f4c380bef94f5c Mon Sep 17 00:00:00 2001 From: Victor Li Date: Wed, 12 Nov 2025 18:18:46 -0800 Subject: [PATCH 1/3] porting unity substitutions from old branch --- .../operator_attribute_constraint.h | 4 +- .../tensor_pattern/tensor_attribute_pattern.h | 4 +- .../tensor_attribute_value.variant.toml | 8 + .../substitutions/unity_substitution_set.h | 37 +- .../operator_pattern/get_attribute.cc | 22 +- .../operator_attribute_constraint.cc | 5 +- .../materialize_operator_from_attrs_map.cc | 63 +- .../tensor_attribute_pattern.cc | 2 +- .../substitutions/unity_substitution_set.cc | 628 ++++++- .../substitutions/unity_substitution_set.cc | 1530 ++++++++++++++++- 10 files changed, 2205 insertions(+), 98 deletions(-) diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.h index c2c11fac51..1985e5c03c 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.h +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.h @@ -9,8 +9,8 @@ OperatorAttributeConstraint op_type_equals_constraint(OperatorType); OperatorAttributeConstraint op_attr_key_equals(OperatorAttributeKey, OperatorAttributeValue const &); -OperatorAttributeConstraint - op_attr_key_divisible_by(OperatorAttributeKey, nonnegative_int denominator); +OperatorAttributeConstraint op_attr_key_divisible_by(OperatorAttributeKey, + positive_int denominator); OperatorAttributeConstraint make_equals_constraint(OperatorAttributeExpr const &, OperatorAttributeValue const &); diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.h index c1e28f8d8f..99e80eaa7f 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.h +++ 
b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.h @@ -2,13 +2,13 @@ #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_PATTERN_H #include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h" -#include "utils/nonnegative_int/nonnegative_int.h" +#include "utils/positive_int/positive_int.h" namespace FlexFlow { TensorAttributePattern tensor_attribute_pattern_match_all(); TensorAttributePattern - tensor_attr_pattern_require_num_dims(nonnegative_int num_dims); + tensor_attr_pattern_require_num_dims(positive_int num_dims); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml index d2b931fb2d..9d0a2d4bee 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml @@ -13,10 +13,18 @@ includes = [ "utils/hash/vector.h", "utils/fmt/vector.h", "utils/nonnegative_int/nonnegative_int.h", + "utils/positive_int/positive_int.h", ] [[values]] type = "::FlexFlow::nonnegative_int" +[[values]] +type = "::FlexFlow::positive_int" + [[values]] type = "std::vector<::FlexFlow::nonnegative_int>" + +[[values]] +type = "std::vector<::FlexFlow::positive_int>" + diff --git a/lib/substitutions/include/substitutions/unity_substitution_set.h b/lib/substitutions/include/substitutions/unity_substitution_set.h index 183f76ac8a..f8edbc6179 100644 --- a/lib/substitutions/include/substitutions/unity_substitution_set.h +++ b/lib/substitutions/include/substitutions/unity_substitution_set.h @@ -10,36 +10,25 @@ namespace FlexFlow { std::vector get_substitution_set(MachineSpecification const &resources); -Substitution create_combine_inception(nonnegative_int num_convs, - nonnegative_int num_dims, - nonnegative_int degree); 
-Substitution create_combine_concat(nonnegative_int num_inputs, - nonnegative_int num_dims, - nonnegative_int degree); -Substitution create_replicate_linear_combine(nonnegative_int num_dims, - nonnegative_int degree, +Substitution create_replicate_linear_combine(positive_int num_dims, + positive_int degree, bool use_bias); -Substitution create_partition_linear_combine(nonnegative_int num_dims, - nonnegative_int degree, - Activation activation, +Substitution create_partition_linear_combine(positive_int num_dims, + positive_int degree, bool use_bias); -Substitution create_partition_conv2d_combine(nonnegative_int num_dims, - nonnegative_int degree); -Substitution create_partition_attention_combine(nonnegative_int num_heads, - nonnegative_int degree); -Substitution create_replicate_attention_reduce(nonnegative_int num_heads, - nonnegative_int degree); +Substitution create_partition_conv2d_combine(positive_int num_dims, + positive_int degree); +Substitution create_partition_attention_combine(positive_int num_heads, + positive_int degree); +Substitution create_replicate_attention_reduce(positive_int num_heads, + positive_int degree); Substitution create_partition_add_combine(ff_dim_t parallel_dim, - nonnegative_int degree); + positive_int degree); Substitution create_partition_relu_combine(ff_dim_t parallel_dim, - nonnegative_int degree); -Substitution create_partition_concat_combine(nonnegative_int num_inputs, - ff_dim_t concat_dim, - ff_dim_t parallel_dim, - nonnegative_int degree); + positive_int degree); Substitution create_partition_softmax_combine(ff_dim_t softmax_dim, ff_dim_t partition_dim, - nonnegative_int degree); + positive_int degree); Substitution create_fuse_linear_activation(Activation activation); } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc index cb733e16ff..a7575ae837 100644 --- 
a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc @@ -83,6 +83,8 @@ std::optional get_attribute(ConcatAttrs const &p, std::optional get_attribute(Conv2DAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OUT_CHANNELS: + return OperatorAttributeValue{p.out_channels}; case OperatorAttributeKey::OP_TYPE: return OperatorAttributeValue{get_op_type(p)}; case OperatorAttributeKey::KERNEL_H: @@ -113,6 +115,12 @@ std::optional get_attribute(ElementBinaryAttrs const &p, switch (key) { case OperatorAttributeKey::OP_TYPE: return OperatorAttributeValue{get_op_type(p)}; + case OperatorAttributeKey::DATA_TYPE: + return OperatorAttributeValue{p.compute_type}; + case OperatorAttributeKey::SHOULD_BROADCAST_LHS: + return OperatorAttributeValue{p.should_broadcast_lhs}; + case OperatorAttributeKey::SHOULD_BROADCAST_RHS: + return OperatorAttributeValue{p.should_broadcast_rhs}; default: return std::nullopt; } @@ -123,6 +131,8 @@ std::optional get_attribute(ElementUnaryAttrs const &p, switch (key) { case OperatorAttributeKey::OP_TYPE: return OperatorAttributeValue{get_op_type(p)}; + case OperatorAttributeKey::SCALAR: + return OperatorAttributeValue{p.scalar}; default: return std::nullopt; } @@ -227,10 +237,20 @@ std::optional switch (key) { case OperatorAttributeKey::OP_TYPE: return OperatorAttributeValue{get_op_type(p)}; + case OperatorAttributeKey::EMBED_DIM: + return OperatorAttributeValue{p.embed_dim}; + case OperatorAttributeKey::KDIM: + return OperatorAttributeValue{p.kdim}; + case OperatorAttributeKey::VDIM: + return OperatorAttributeValue{p.vdim}; case OperatorAttributeKey::NUM_HEADS: return OperatorAttributeValue{p.num_heads}; - case OperatorAttributeKey::USE_BIAS: + case OperatorAttributeKey::BIAS: return OperatorAttributeValue{p.bias}; + case OperatorAttributeKey::ADD_BIAS_KV: + return OperatorAttributeValue{p.add_bias_kv}; + case 
OperatorAttributeKey::ADD_ZERO_ATTN: + return OperatorAttributeValue{p.add_zero_attn}; case OperatorAttributeKey::DROPOUT: return OperatorAttributeValue{p.dropout}; default: diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.cc index 29aef07e3a..a45af1e7d4 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.cc @@ -20,9 +20,8 @@ OperatorAttributeConstraint }; } -OperatorAttributeConstraint - op_attr_key_divisible_by(OperatorAttributeKey key, - nonnegative_int denominator) { +OperatorAttributeConstraint op_attr_key_divisible_by(OperatorAttributeKey key, + positive_int denominator) { return OperatorAttributeConstraint{ ConstraintType::DIVISIBLE_BY, OperatorAttributeExpr{key}, diff --git a/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc b/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc index cf5a1e17f9..fcc20b92a0 100644 --- a/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc +++ b/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc @@ -61,7 +61,6 @@ PCGOperatorAttrs materialize_operator_from_attrs_map( case OperatorType::NOOP: case OperatorType::INPUT: case OperatorType::WEIGHT: - case OperatorType::CONV2D: case OperatorType::DROPOUT: case OperatorType::LINEAR: return PCGOperatorAttrs{LinearAttrs{ @@ -75,19 +74,72 @@ PCGOperatorAttrs materialize_operator_from_attrs_map( acc.get>( OperatorAttributeKey::REGULARIZER), }}; + case OperatorType::CONV2D: + return PCGOperatorAttrs{Conv2DAttrs{ + /*out_channels=*/acc.get( + OperatorAttributeKey::OUT_CHANNELS), + /*kernel_h=*/acc.get(OperatorAttributeKey::KERNEL_H), + 
/*kernel_w=*/acc.get(OperatorAttributeKey::KERNEL_W), + /*stride_h=*/acc.get(OperatorAttributeKey::STRIDE_H), + /*stride_w=*/acc.get(OperatorAttributeKey::STRIDE_W), + /*padding_h=*/ + acc.get(OperatorAttributeKey::PADDING_H), + /*padding_w=*/ + acc.get(OperatorAttributeKey::PADDING_W), + /*groups=*/acc.get(OperatorAttributeKey::GROUPS), + /*activation=*/ + acc.get>(OperatorAttributeKey::ACTIVATION), + /*use_bias=*/acc.get(OperatorAttributeKey::USE_BIAS), + }}; + case OperatorType::RELU: + return PCGOperatorAttrs{ElementUnaryAttrs{ + acc.get(OperatorAttributeKey::OP_TYPE), + acc.get>(OperatorAttributeKey::SCALAR), + }}; + case OperatorType::SOFTMAX: + return PCGOperatorAttrs{SoftmaxAttrs{ + acc.get(OperatorAttributeKey::AXIS), + }}; + case OperatorType::EW_ADD: + return PCGOperatorAttrs{ElementBinaryAttrs{ + acc.get(OperatorAttributeKey::OP_TYPE), + acc.get(OperatorAttributeKey::DATA_TYPE), + acc.get(OperatorAttributeKey::SHOULD_BROADCAST_LHS), + acc.get(OperatorAttributeKey::SHOULD_BROADCAST_RHS), + }}; + case OperatorType::REPLICATE: + return PCGOperatorAttrs{ReplicateAttrs{ + /*replicate_degree=*/acc.get( + OperatorAttributeKey::PARALLEL_DEGREE), + }}; + case OperatorType::REPARTITION: + return PCGOperatorAttrs{RepartitionAttrs{ + /*repartition_dim=*/acc.get( + OperatorAttributeKey::PARALLEL_DIM), + /*repartition_degree=*/ + acc.get(OperatorAttributeKey::PARALLEL_DEGREE), + }}; + case OperatorType::COMBINE: + return PCGOperatorAttrs{CombineAttrs{ + /*combine_dim=*/acc.get(OperatorAttributeKey::PARALLEL_DIM), + /*combine_degree=*/ + acc.get(OperatorAttributeKey::PARALLEL_DEGREE), + }}; + case OperatorType::REDUCTION: + return PCGOperatorAttrs{ReductionAttrs{ + acc.get(OperatorAttributeKey::PARALLEL_DEGREE), + }}; + case OperatorType::BATCHMATMUL: case OperatorType::SCALAR_MULTIPLY: case OperatorType::SCALAR_ADD: case OperatorType::SCALAR_FLOOR_DIV: case OperatorType::SCALAR_TRUE_DIV: case OperatorType::SCALAR_SUB: - case OperatorType::RELU: case 
OperatorType::IDENTITY: case OperatorType::SIGMOID: case OperatorType::TANH: case OperatorType::ELU: case OperatorType::FLAT: - case OperatorType::SOFTMAX: case OperatorType::BATCHNORM: case OperatorType::CONCAT: case OperatorType::SPLIT: @@ -96,7 +148,6 @@ PCGOperatorAttrs materialize_operator_from_attrs_map( case OperatorType::RESHAPE: case OperatorType::REVERSE: case OperatorType::TRANSPOSE: - case OperatorType::EW_ADD: case OperatorType::EW_MUL: case OperatorType::MATMUL: case OperatorType::MUL: @@ -143,10 +194,6 @@ PCGOperatorAttrs materialize_operator_from_attrs_map( case OperatorType::LAYERNORM: case OperatorType::GATHER: case OperatorType::BROADCAST: - case OperatorType::REPARTITION: - case OperatorType::COMBINE: - case OperatorType::REPLICATE: - case OperatorType::REDUCTION: case OperatorType::BATCH: case OperatorType::PIPELINE: case OperatorType::FUSED_PARALLEL: diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.cc index e1c1fe7cf6..f224c6883d 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.cc @@ -8,7 +8,7 @@ TensorAttributePattern tensor_attribute_pattern_match_all() { } TensorAttributePattern - tensor_attr_pattern_require_num_dims(nonnegative_int num_dims) { + tensor_attr_pattern_require_num_dims(positive_int num_dims) { return TensorAttributePattern{{ TensorAttributeConstraint{ ConstraintType::EQUAL, diff --git a/lib/substitutions/src/substitutions/unity_substitution_set.cc b/lib/substitutions/src/substitutions/unity_substitution_set.cc index 4b00cdd95f..99e528cc9e 100644 --- a/lib/substitutions/src/substitutions/unity_substitution_set.cc +++ b/lib/substitutions/src/substitutions/unity_substitution_set.cc @@ -13,14 +13,45 @@ namespace FlexFlow { std::vector get_substitution_set(MachineSpecification const 
&resources) { std::vector substitutions; - for (nonnegative_int num_dims : - nonnegative_range(1_n, nonnegative_int{MAX_TENSOR_DIM})) { - for (nonnegative_int degree = 1_n; degree <= get_num_gpus(resources); - degree *= 2_n) { + for (positive_int dim = 1_p; dim <= positive_int{MAX_TENSOR_DIM}; dim++) { + for (positive_int degree = 1_p; degree <= get_num_gpus(resources); + degree *= 2_p) { substitutions.push_back( - create_replicate_linear_combine(num_dims, degree, true)); + create_replicate_linear_combine(dim, degree, true)); substitutions.push_back( - create_replicate_linear_combine(num_dims, degree, false)); + create_replicate_linear_combine(dim, degree, false)); + substitutions.push_back( + create_partition_linear_combine(dim, degree, true)); + substitutions.push_back( + create_partition_linear_combine(dim, degree, false)); + substitutions.push_back(create_partition_relu_combine( + ff_dim_t{dim.nonnegative_int_from_positive_int()}, degree)); + substitutions.push_back(create_partition_add_combine( + ff_dim_t{dim.nonnegative_int_from_positive_int()}, degree)); + substitutions.push_back(create_partition_attention_combine(dim, degree)); + substitutions.push_back(create_replicate_attention_reduce(dim, degree)); + } + } + for (positive_int degree = 1_p; degree <= get_num_gpus(resources); + degree *= 2_p) { + substitutions.push_back(create_partition_conv2d_combine(4_p, degree)); + } + + for (positive_int partition_dim = 1_p; + partition_dim <= positive_int{MAX_TENSOR_DIM}; + partition_dim++) { + for (positive_int softmax_dim = 1_p; + softmax_dim <= positive_int{MAX_TENSOR_DIM}; + softmax_dim++) { + for (positive_int degree = 1_p; degree <= get_num_gpus(resources); + degree *= 2_p) { + if (partition_dim != softmax_dim) { + substitutions.push_back(create_partition_softmax_combine( + ff_dim_t{softmax_dim.nonnegative_int_from_positive_int()}, + ff_dim_t{partition_dim.nonnegative_int_from_positive_int()}, + degree)); + } + } } } 
substitutions.push_back(create_fuse_linear_activation(Activation::RELU)); @@ -30,20 +61,8 @@ std::vector return substitutions; } -Substitution create_combine_inception(nonnegative_int num_convs, - nonnegative_int num_dims, - nonnegative_int degree) { - NOT_IMPLEMENTED(); -} - -Substitution create_combine_concat(nonnegative_int num_inputs, - nonnegative_int num_dims, - nonnegative_int degree) { - NOT_IMPLEMENTED(); -} - -Substitution create_replicate_linear_combine(nonnegative_int num_dims, - nonnegative_int degree, +Substitution create_replicate_linear_combine(positive_int num_dims, + positive_int degree, bool use_bias) { SubstitutionBuilder b; @@ -63,15 +82,14 @@ Substitution create_replicate_linear_combine(nonnegative_int num_dims, op_type_equals_constraint(OperatorType::LINEAR), op_attr_key_equals(OperatorAttributeKey::BIAS, OperatorAttributeValue{use_bias}), - op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, - nonnegative_int{degree}), + op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), }}; - PatternValue p_linear_output = get_only(b.add_pattern_node( - linear_pattern, - p_inputs, - {tensor_attr_pattern_require_num_dims(nonnegative_int{num_dims})}, - "linear")); + PatternValue p_linear_output = get_only( + b.add_pattern_node(linear_pattern, + p_inputs, + {tensor_attr_pattern_require_num_dims(num_dims)}, + "linear")); OutputOperatorAttrsAssignment replicate_input_expr = OutputOperatorAttrsAssignment{ @@ -132,7 +150,7 @@ Substitution create_replicate_linear_combine(nonnegative_int num_dims, set_attr_to_constant( OperatorAttributeKey::PARALLEL_DIM, OperatorAttributeValue{ff_dim_t{ - nonnegative_int{num_dims.unwrap_nonnegative() - 1}, + nonnegative_int{num_dims.int_from_positive_int() - 1}, }}), }, }; @@ -144,49 +162,547 @@ Substitution create_replicate_linear_combine(nonnegative_int num_dims, return b.get_substitution(); } -Substitution create_partition_linear_combine(nonnegative_int num_dims, - nonnegative_int degree, - Activation 
activation, +Substitution create_partition_linear_combine(positive_int num_dims, + positive_int degree, bool use_bias) { - NOT_IMPLEMENTED(); + SubstitutionBuilder b; + + auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all()); + auto [p_weight, o_weight] = b.add_input(tensor_attribute_pattern_match_all()); + std::vector p_inputs = {p_input, p_weight}; + + std::optional o_bias = std::nullopt; + if (use_bias) { + std::pair bias = + b.add_input(tensor_attribute_pattern_match_all()); + p_inputs.push_back(bias.first); + o_bias = bias.second; + } + + OperatorAttributePattern linear_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::LINEAR), + op_attr_key_equals(OperatorAttributeKey::BIAS, + OperatorAttributeValue{use_bias}), + op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), + }}; + + PatternValue p_linear_output = get_only( + b.add_pattern_node(linear_pattern, + p_inputs, + {tensor_attr_pattern_require_num_dims(num_dims)}, + "linear")); + + OutputOperatorAttrsAssignment partition_input_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REPARTITION), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{ff_dim_t{0_n}}), + }}; + OutputGraphExprValue o_partition_input_output = + get_only(b.add_output_graph_node(partition_input_expr, {o_input}, 1_n)); + + OutputOperatorAttrsAssignment replicate_weights_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REPLICATE), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + }}; + OutputGraphExprValue o_replicate_weights_output = get_only( + b.add_output_graph_node(replicate_weights_expr, {o_weight}, 1_n)); + + std::vector o_linear_inputs = { + o_partition_input_output, o_replicate_weights_output}; + + if 
(use_bias) { + OutputOperatorAttrsAssignment replicate_bias_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REPLICATE), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + }}; + OutputGraphExprValue o_replicate_bias_output = get_only( + b.add_output_graph_node(replicate_bias_expr, {o_bias.value()}, 1_n)); + o_linear_inputs.push_back(o_replicate_bias_output); + } + + OutputOperatorAttrsAssignment linear_expr = OutputOperatorAttrsAssignment{ + b.pattern_node_named("linear"), + {}, + }; + OutputGraphExprValue o_linear_output = + get_only(b.add_output_graph_node(linear_expr, o_linear_inputs, 1_n)); + + OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::COMBINE), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant( + OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{ff_dim_t{ + nonnegative_int{num_dims.int_from_positive_int() - 1}, + }}), + }, + }; + OutputGraphExprValue o_combine_output = + get_only(b.add_output_graph_node(combine_expr, {o_linear_output}, 1_n)); + + b.equate_outputs(p_linear_output, o_combine_output); + + return b.get_substitution(); } -Substitution create_partition_conv2d_combine(nonnegative_int num_dims, - nonnegative_int degree) { - NOT_IMPLEMENTED(); +Substitution create_partition_conv2d_combine(positive_int num_dims, + positive_int degree) { + if (num_dims != 4_p) { + throw mk_runtime_error(fmt::format("num_dims must be 4, not {}", num_dims)); + } + + SubstitutionBuilder b; + + auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all()); + auto [p_weight, o_weight] = b.add_input(tensor_attribute_pattern_match_all()); + std::vector p_inputs = {p_input, p_weight}; + + OperatorAttributePattern conv2d_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::CONV2D), + 
op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), + }}; + + PatternValue p_conv2d_output = get_only( + b.add_pattern_node(conv2d_pattern, + p_inputs, + {tensor_attr_pattern_require_num_dims(num_dims)}, + "conv2d")); + + OutputOperatorAttrsAssignment partition_input_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REPARTITION), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{ff_dim_t{0_n}}), + }}; + + OutputGraphExprValue o_partition_input_output = + get_only(b.add_output_graph_node(partition_input_expr, {o_input}, 1_n)); + + OutputOperatorAttrsAssignment replicate_weights_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REPLICATE), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + }}; + OutputGraphExprValue o_replicate_weights_output = get_only( + b.add_output_graph_node(replicate_weights_expr, {o_weight}, 1_n)); + + std::vector o_conv2d_inputs = { + o_partition_input_output, o_replicate_weights_output}; + + OutputOperatorAttrsAssignment conv2d_expr = OutputOperatorAttrsAssignment{ + b.pattern_node_named("conv2d"), + {}, + }; + OutputGraphExprValue o_conv2d_output = + get_only(b.add_output_graph_node(conv2d_expr, o_conv2d_inputs, 1_n)); + + OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::COMBINE), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant( + OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{ff_dim_t{ + nonnegative_int{num_dims.int_from_positive_int() - 1}, + }}), + }, + }; + OutputGraphExprValue o_combine_output = + get_only(b.add_output_graph_node(combine_expr, {o_conv2d_output}, 1_n)); + + 
b.equate_outputs(p_conv2d_output, o_combine_output); + + return b.get_substitution(); } -Substitution create_partition_attention_combine(nonnegative_int num_heads, - nonnegative_int degree) { - NOT_IMPLEMENTED(); +Substitution create_partition_attention_combine(positive_int num_heads, + positive_int degree) { + + SubstitutionBuilder b; + + auto [p_query_input, o_query_input] = + b.add_input(tensor_attribute_pattern_match_all()); + auto [p_key_input, o_key_input] = + b.add_input(tensor_attribute_pattern_match_all()); + auto [p_value_input, o_value_input] = + b.add_input(tensor_attribute_pattern_match_all()); + auto [p_weights, o_weights] = + b.add_input(tensor_attribute_pattern_match_all()); + std::vector p_inputs = { + p_query_input, p_key_input, p_value_input, p_weights}; + + OperatorAttributePattern attention_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::MULTIHEAD_ATTENTION), + op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), + op_attr_key_divisible_by(OperatorAttributeKey::NUM_HEADS, num_heads), + }}; + + PatternValue p_attention_output = + get_only(b.add_pattern_node(attention_pattern, + p_inputs, + {tensor_attr_pattern_require_num_dims(3_p)}, + "attention")); + + OutputOperatorAttrsAssignment partition_input_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REPARTITION), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{ff_dim_t{0_n}}), + }}; + + OutputGraphExprValue o_partition_query_input_output = get_only( + b.add_output_graph_node(partition_input_expr, {o_query_input}, 1_n)); + + OutputGraphExprValue o_partition_key_input_output = get_only( + b.add_output_graph_node(partition_input_expr, {o_key_input}, 1_n)); + + OutputGraphExprValue o_partition_value_input_output = get_only( + b.add_output_graph_node(partition_input_expr, 
{o_value_input}, 1_n)); + + OutputOperatorAttrsAssignment replicate_weight_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REPLICATE), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + }}; + + OutputGraphExprValue o_replicate_weight_output = get_only( + b.add_output_graph_node(replicate_weight_expr, {o_weights}, 1_n)); + + std::vector o_attention_inputs = { + o_partition_query_input_output, + o_partition_key_input_output, + o_partition_value_input_output, + o_replicate_weight_output}; + + OutputOperatorAttrsAssignment attention_expr = OutputOperatorAttrsAssignment{ + b.pattern_node_named("attention"), + {}, + }; + OutputGraphExprValue o_attention_output = get_only( + b.add_output_graph_node(attention_expr, o_attention_inputs, 1_n)); + + OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::COMBINE), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{ff_dim_t{ + 2_n, + }}), + }, + }; + OutputGraphExprValue o_combine_output = get_only( + b.add_output_graph_node(combine_expr, {o_attention_output}, 1_n)); + + b.equate_outputs(p_attention_output, o_combine_output); + + return b.get_substitution(); } -Substitution create_replicate_attention_reduce(nonnegative_int num_heads, - nonnegative_int degree) { - NOT_IMPLEMENTED(); +Substitution create_replicate_attention_reduce(positive_int num_heads, + positive_int degree) { + + SubstitutionBuilder b; + + auto [p_query_input, o_query_input] = + b.add_input(tensor_attribute_pattern_match_all()); + auto [p_key_input, o_key_input] = + b.add_input(tensor_attribute_pattern_match_all()); + auto [p_value_input, o_value_input] = + b.add_input(tensor_attribute_pattern_match_all()); + auto [p_weights, o_weights] = + 
b.add_input(tensor_attribute_pattern_match_all()); + std::vector p_inputs = { + p_query_input, p_key_input, p_value_input, p_weights}; + + OperatorAttributePattern attention_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::MULTIHEAD_ATTENTION), + op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), + op_attr_key_divisible_by(OperatorAttributeKey::NUM_HEADS, num_heads), + }}; + + PatternValue p_attention_output = + get_only(b.add_pattern_node(attention_pattern, + p_inputs, + {tensor_attr_pattern_require_num_dims(3_p)}, + "attention")); + + OutputOperatorAttrsAssignment replicate_input_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REPLICATE), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + }}; + + OutputGraphExprValue o_replicate_query_input_output = get_only( + b.add_output_graph_node(replicate_input_expr, {o_query_input}, 1_n)); + + OutputGraphExprValue o_replicate_key_input_output = get_only( + b.add_output_graph_node(replicate_input_expr, {o_key_input}, 1_n)); + + OutputGraphExprValue o_replicate_value_input_output = get_only( + b.add_output_graph_node(replicate_input_expr, {o_value_input}, 1_n)); + + OutputOperatorAttrsAssignment partition_weight_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REPARTITION), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{ff_dim_t{1_n}}), + }}; + + OutputGraphExprValue o_partition_weight_output = get_only( + b.add_output_graph_node(partition_weight_expr, {o_weights}, 1_n)); + + std::vector o_attention_inputs = { + o_replicate_query_input_output, + o_replicate_key_input_output, + o_replicate_value_input_output, + o_partition_weight_output}; + + OutputOperatorAttrsAssignment attention_expr = 
OutputOperatorAttrsAssignment{ + b.pattern_node_named("attention"), + {}, + }; + OutputGraphExprValue o_attention_output = get_only( + b.add_output_graph_node(attention_expr, o_attention_inputs, 1_n)); + + OutputOperatorAttrsAssignment reduce_expr = OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REDUCTION), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + }, + }; + OutputGraphExprValue o_reduce_output = + get_only(b.add_output_graph_node(reduce_expr, {o_attention_output}, 1_n)); + + b.equate_outputs(p_attention_output, o_reduce_output); + + return b.get_substitution(); +} + +Substitution create_partition_softmax_combine(ff_dim_t softmax_dim, + ff_dim_t partition_dim, + positive_int degree) { + if (partition_dim == softmax_dim) { + throw mk_runtime_error( + fmt::format("partition dim {} must not be equal to softmax dim {}", + partition_dim, + softmax_dim)); + } + SubstitutionBuilder b; + + auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all()); + std::vector p_inputs = {p_input}; + + OperatorAttributePattern softmax_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::SOFTMAX), + op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), + op_attr_key_divisible_by(OperatorAttributeKey::SOFTMAX_DIM, + positive_int{softmax_dim.value}), + }}; + + PatternValue p_softmax_output = + get_only(b.add_pattern_node(softmax_pattern, + p_inputs, + {tensor_attribute_pattern_match_all()}, + "softmax")); + + OutputOperatorAttrsAssignment partition_input_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REPARTITION), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{partition_dim}), + }}; + + OutputGraphExprValue o_partition_input_output = + 
get_only(b.add_output_graph_node(partition_input_expr, {o_input}, 1_n)); + + std::vector o_softmax_inputs = { + o_partition_input_output}; + + OutputOperatorAttrsAssignment softmax_expr = OutputOperatorAttrsAssignment{ + b.pattern_node_named("softmax"), + {}, + }; + OutputGraphExprValue o_softmax_output = + get_only(b.add_output_graph_node(softmax_expr, o_softmax_inputs, 1_n)); + + OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::COMBINE), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{partition_dim}), + }, + }; + OutputGraphExprValue o_combine_output = + get_only(b.add_output_graph_node(combine_expr, {o_softmax_output}, 1_n)); + + b.equate_outputs(p_softmax_output, o_combine_output); + + return b.get_substitution(); } Substitution create_partition_add_combine(ff_dim_t parallel_dim, - nonnegative_int degree) { - NOT_IMPLEMENTED(); + positive_int degree) { + SubstitutionBuilder b; + + auto [p_input1, o_input1] = b.add_input(tensor_attribute_pattern_match_all()); + auto [p_input2, o_input2] = b.add_input(tensor_attribute_pattern_match_all()); + std::vector p_inputs = {p_input1, p_input2}; + + OperatorAttributePattern add_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::EW_ADD), + op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), + }}; + + PatternValue p_add_output = get_only(b.add_pattern_node( + add_pattern, p_inputs, {tensor_attribute_pattern_match_all()}, "add")); + + OutputOperatorAttrsAssignment partition_input_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REPARTITION), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, + 
OperatorAttributeValue{parallel_dim}), + }}; + + OutputGraphExprValue o_partition_input1_output = + get_only(b.add_output_graph_node(partition_input_expr, {o_input1}, 1_n)); + + OutputGraphExprValue o_partition_input2_output = + get_only(b.add_output_graph_node(partition_input_expr, {o_input2}, 1_n)); + + std::vector o_add_inputs = {o_partition_input1_output, + o_partition_input2_output}; + + OutputOperatorAttrsAssignment add_expr = OutputOperatorAttrsAssignment{ + b.pattern_node_named("add"), + {}, + }; + OutputGraphExprValue o_add_output = + get_only(b.add_output_graph_node(add_expr, o_add_inputs, 1_n)); + + OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::COMBINE), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{parallel_dim}), + }, + }; + OutputGraphExprValue o_combine_output = + get_only(b.add_output_graph_node(combine_expr, {o_add_output}, 1_n)); + + b.equate_outputs(p_add_output, o_combine_output); + + return b.get_substitution(); } Substitution create_partition_relu_combine(ff_dim_t parallel_dim, - nonnegative_int degree) { - NOT_IMPLEMENTED(); -} + positive_int degree) { + SubstitutionBuilder b; -Substitution create_partition_concat_combine(nonnegative_int num_inputs, - ff_dim_t concat_dim, - ff_dim_t parallel_dim, - nonnegative_int degree) { - NOT_IMPLEMENTED(); -} + auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all()); -Substitution create_partition_softmax_combine(ff_dim_t softmax_dim, - ff_dim_t partition_dim, - nonnegative_int degree) { - NOT_IMPLEMENTED(); + OperatorAttributePattern relu_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::RELU), + op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), + }}; + + PatternValue p_relu_output = get_only(b.add_pattern_node( + 
relu_pattern, {p_input}, {tensor_attribute_pattern_match_all()}, "relu")); + + OutputOperatorAttrsAssignment partition_input_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::REPARTITION), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{parallel_dim}), + }}; + + OutputGraphExprValue o_partition_input_output = + get_only(b.add_output_graph_node(partition_input_expr, {o_input}, 1_n)); + + OutputOperatorAttrsAssignment relu_expr = OutputOperatorAttrsAssignment{ + b.pattern_node_named("relu"), + {}, + }; + OutputGraphExprValue o_relu_output = get_only( + b.add_output_graph_node(relu_expr, {o_partition_input_output}, 1_n)); + + OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(OperatorType::COMBINE), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{parallel_dim}), + }, + }; + OutputGraphExprValue o_combine_output = + get_only(b.add_output_graph_node(combine_expr, {o_relu_output}, 1_n)); + + b.equate_outputs(p_relu_output, o_combine_output); + + return b.get_substitution(); } Substitution create_fuse_linear_activation(Activation activation) { diff --git a/lib/substitutions/test/src/substitutions/unity_substitution_set.cc b/lib/substitutions/test/src/substitutions/unity_substitution_set.cc index c86cb7e51f..61e3c9b833 100644 --- a/lib/substitutions/test/src/substitutions/unity_substitution_set.cc +++ b/lib/substitutions/test/src/substitutions/unity_substitution_set.cc @@ -1,8 +1,37 @@ #include "substitutions/unity_substitution_set.h" +#include "op-attrs/computation_graph_op_attrs.h" +#include "op-attrs/operator_type.h" +#include "op-attrs/ops/attention.h" +#include "op-attrs/ops/combine.h" +#include 
"op-attrs/ops/conv_2d.h" +#include "op-attrs/ops/element_binary.h" +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/repartition.h" +#include "op-attrs/ops/replicate.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "substitutions/apply_substitution/apply_substitution.h" +#include "substitutions/open_parallel_tensor_guid_t.h" +#include "substitutions/pcg_pattern.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "substitutions/substitution_builder.h" +#include "utils/containers/get_only.h" #include using namespace ::FlexFlow; +template +static ParallelLayerAttrs make_layer_attrs( + T const &op_attrs, + std::optional const &maybe_name = std::nullopt) { + return ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{op_attrs}, + /*name=*/maybe_name, + }; +}; + TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_substitution_set") { MachineSpecification machine_spec = MachineSpecification{ @@ -15,6 +44,1505 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector result = get_substitution_set(machine_spec); - CHECK(result.size() == 36); + CHECK(result.size() == 248); + } + + TEST_CASE("create_replicate_linear_combine, use_bias = false") { + positive_int num_dims = 1_p; + positive_int degree = 1_p; + std::string linear_match = "linear_match"; + + Substitution sub = create_replicate_linear_combine(num_dims, degree, false); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 10_p, + 12_p, + }, + }, + DataType::FLOAT, + }; + + LinearAttrs linear_attrs = LinearAttrs{ + /*out_channels=*/12_p, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }; + + ReplicateAttrs replicate_input_attrs = ReplicateAttrs{ + /*replicate_degree=*/degree, + }; + + WeightAttrs projection_weight_attrs = 
WeightAttrs{ + /*tensor_shape=*/throw_if_unexpected( + get_projection_shape(linear_attrs, input_shape)), + /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, + }; + + RepartitionAttrs partition_projection_attrs = RepartitionAttrs{ + /*repartition_dim=*/ff_dim_t{1_n}, + /*repartition_degree=*/degree, + }; + + CombineAttrs combine_op_attrs = CombineAttrs{ + /*combine_dim=*/ff_dim_t{ + nonnegative_int{num_dims.int_from_positive_int() - 1}}, + /*combine_degree=*/degree, + }; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult projection_weight_added = add_parallel_layer( + pcg, make_layer_attrs(projection_weight_attrs), {}, {}); + parallel_tensor_guid_t t_projection_weight = + get_only(projection_weight_added.outputs); + + ParallelLayerAddedResult linear_added = + add_parallel_layer(pcg, + make_layer_attrs(linear_attrs, linear_match), + {t_input}, + {t_projection_weight}); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, linear_match); + open_parallel_tensor_guid_t match_layer_input_activations = + get_layer_inputs(original_pcg, match_layer).at(0); + open_parallel_tensor_guid_t match_layer_input_weights = + get_layer_inputs(original_pcg, match_layer).at(1); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{ + { + PatternInput{DataflowGraphInput{0}}, + match_layer_input_activations, + }, + { + PatternInput{DataflowGraphInput{2}}, + match_layer_input_weights, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = 
empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult replicate_input_added = add_parallel_layer( + pcg, make_layer_attrs(replicate_input_attrs), {t_input}, {}); + parallel_tensor_guid_t t_replicated_input = + get_only(replicate_input_added.outputs); + + ParallelLayerAddedResult projection_weight_added = add_parallel_layer( + pcg, make_layer_attrs(projection_weight_attrs), {}, {}); + parallel_tensor_guid_t t_projection_weight = + get_only(projection_weight_added.outputs); + + ParallelLayerAddedResult partition_projection_added = + add_parallel_layer(pcg, + make_layer_attrs(partition_projection_attrs), + {t_projection_weight}, + {}); + parallel_tensor_guid_t t_partitioned_projection_weight = + get_only(partition_projection_added.outputs); + + ParallelLayerAddedResult replicate_linear_added = + add_parallel_layer(pcg, + make_layer_attrs(linear_attrs), + {t_replicated_input}, + {t_partitioned_projection_weight}); + parallel_tensor_guid_t t_replicated_linear = + get_only(replicate_linear_added.outputs); + + ParallelLayerAddedResult combine_added = add_parallel_layer( + pcg, make_layer_attrs(combine_op_attrs), {t_replicated_linear}, {}); + parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_replicate_linear_combine, use_bias = true") { + positive_int num_dims = 1_p; + positive_int degree = 1_p; + std::string linear_match = "linear_match"; + + Substitution sub = create_replicate_linear_combine(num_dims, degree, true); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 10_p, + 12_p, + }, + }, + DataType::FLOAT, + }; + + LinearAttrs linear_attrs = LinearAttrs{ + /*out_channels=*/12_p, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + 
/*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }; + + ReplicateAttrs replicate_input_attrs = ReplicateAttrs{ + /*replicate_degree=*/degree, + }; + + WeightAttrs projection_weight_attrs = WeightAttrs{ + /*tensor_shape=*/throw_if_unexpected( + get_projection_shape(linear_attrs, input_shape)), + /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, + }; + + WeightAttrs bias_attrs = WeightAttrs{ + /*tensor_shape=*/throw_if_unexpected( + get_bias_shape(linear_attrs, input_shape)), + /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, + }; + + RepartitionAttrs partition_projection_attrs = RepartitionAttrs{ + /*repartition_dim=*/ff_dim_t{1_n}, + /*repartition_degree=*/degree, + }; + + CombineAttrs combine_op_attrs = CombineAttrs{ + /*combine_dim=*/ff_dim_t{ + nonnegative_int{num_dims.int_from_positive_int() - 1}}, + /*combine_degree=*/degree, + }; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult projection_weight_added = add_parallel_layer( + pcg, make_layer_attrs(projection_weight_attrs), {}, {}); + parallel_tensor_guid_t t_projection_weight = + get_only(projection_weight_added.outputs); + + ParallelLayerAddedResult bias_added = + add_parallel_layer(pcg, make_layer_attrs(bias_attrs), {}, {}); + parallel_tensor_guid_t t_bias = get_only(bias_added.outputs); + + ParallelLayerAddedResult linear_added = + add_parallel_layer(pcg, + make_layer_attrs(linear_attrs, linear_match), + {t_input}, + {t_projection_weight, t_bias}); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, linear_match); + open_parallel_tensor_guid_t match_layer_input_activations = + get_layer_inputs(original_pcg, 
match_layer).at(0); + open_parallel_tensor_guid_t match_layer_input_weights = + get_layer_inputs(original_pcg, match_layer).at(1); + open_parallel_tensor_guid_t match_layer_input_bias = + get_layer_inputs(original_pcg, match_layer).at(2); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{ + { + PatternInput{DataflowGraphInput{0}}, + match_layer_input_activations, + }, + { + PatternInput{DataflowGraphInput{2}}, + match_layer_input_weights, + }, + { + PatternInput{DataflowGraphInput{4}}, + match_layer_input_bias, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult replicate_input_added = add_parallel_layer( + pcg, make_layer_attrs(replicate_input_attrs), {t_input}, {}); + parallel_tensor_guid_t t_replicated_input = + get_only(replicate_input_added.outputs); + + ParallelLayerAddedResult projection_weight_added = add_parallel_layer( + pcg, make_layer_attrs(projection_weight_attrs), {}, {}); + parallel_tensor_guid_t t_projection_weight = + get_only(projection_weight_added.outputs); + + ParallelLayerAddedResult partition_projection_added = + add_parallel_layer(pcg, + make_layer_attrs(partition_projection_attrs), + {t_projection_weight}, + {}); + parallel_tensor_guid_t t_partitioned_projection_weight = + get_only(partition_projection_added.outputs); + + ParallelLayerAddedResult bias_added = + add_parallel_layer(pcg, make_layer_attrs(bias_attrs), {}, {}); + parallel_tensor_guid_t t_bias = get_only(bias_added.outputs); + + ParallelLayerAddedResult partition_bias_added = add_parallel_layer( + pcg, make_layer_attrs(partition_projection_attrs), {t_bias}, {}); + 
parallel_tensor_guid_t t_partitioned_bias = + get_only(partition_bias_added.outputs); + + ParallelLayerAddedResult replicate_linear_added = add_parallel_layer( + pcg, + make_layer_attrs(linear_attrs), + {t_replicated_input}, + {t_partitioned_projection_weight, t_partitioned_bias}); + parallel_tensor_guid_t t_replicated_linear = + get_only(replicate_linear_added.outputs); + + ParallelLayerAddedResult combine_added = add_parallel_layer( + pcg, make_layer_attrs(combine_op_attrs), {t_replicated_linear}, {}); + parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_partition_linear_combine, use_bias = false") { + positive_int num_dims = 1_p; + positive_int degree = 2_p; + std::string linear_match = "linear_match"; + + Substitution sub = create_partition_linear_combine(num_dims, degree, false); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 10_p, + 12_p, + }, + }, + DataType::FLOAT, + }; + + LinearAttrs linear_attrs = LinearAttrs{ + /*out_channels=*/12_p, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }; + + RepartitionAttrs partition_input_attrs = RepartitionAttrs{ + /*repartition_dim=*/ff_dim_t{0_n}, + /*repartition_degree=*/degree, + }; + + WeightAttrs projection_weight_attrs = WeightAttrs{ + /*tensor_shape=*/throw_if_unexpected( + get_projection_shape(linear_attrs, input_shape)), + /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, + }; + + ReplicateAttrs replicate_projection_attrs = ReplicateAttrs{ + /*replicate_degree=*/degree, + }; + + CombineAttrs combine_op_attrs = CombineAttrs{ + /*combine_dim=*/ff_dim_t{ + nonnegative_int{num_dims.int_from_positive_int() - 1}}, + /*combine_degree=*/degree, + }; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + 
ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult projection_weight_added = add_parallel_layer( + pcg, make_layer_attrs(projection_weight_attrs), {}, {}); + parallel_tensor_guid_t t_projection_weight = + get_only(projection_weight_added.outputs); + + ParallelLayerAddedResult linear_added = + add_parallel_layer(pcg, + make_layer_attrs(linear_attrs, linear_match), + {t_input}, + {t_projection_weight}); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, linear_match); + open_parallel_tensor_guid_t match_layer_input_activations = + get_layer_inputs(original_pcg, match_layer).at(0); + open_parallel_tensor_guid_t match_layer_input_weights = + get_layer_inputs(original_pcg, match_layer).at(1); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{ + { + PatternInput{DataflowGraphInput{0}}, + match_layer_input_activations, + }, + { + PatternInput{DataflowGraphInput{2}}, + match_layer_input_weights, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult partition_input_added = add_parallel_layer( + pcg, make_layer_attrs(partition_input_attrs), {t_input}, {}); + parallel_tensor_guid_t t_partitioned_input = + get_only(partition_input_added.outputs); + + ParallelLayerAddedResult projection_weight_added = add_parallel_layer( + pcg, make_layer_attrs(projection_weight_attrs), {}, {}); + parallel_tensor_guid_t t_projection_weight = + 
get_only(projection_weight_added.outputs); + + ParallelLayerAddedResult replicate_projection_added = + add_parallel_layer(pcg, + make_layer_attrs(replicate_projection_attrs), + {t_projection_weight}, + {}); + parallel_tensor_guid_t t_replicated_projection_weight = + get_only(replicate_projection_added.outputs); + + ParallelLayerAddedResult partition_linear_added = + add_parallel_layer(pcg, + make_layer_attrs(linear_attrs), + {t_partitioned_input}, + {t_replicated_projection_weight}); + parallel_tensor_guid_t t_partitioned_linear = + get_only(partition_linear_added.outputs); + + ParallelLayerAddedResult combine_added = add_parallel_layer( + pcg, make_layer_attrs(combine_op_attrs), {t_partitioned_linear}, {}); + parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_partition_linear_combine, use_bias = true") { + positive_int num_dims = 1_p; + positive_int degree = 2_p; + std::string linear_match = "linear_match"; + + Substitution sub = create_partition_linear_combine(num_dims, degree, true); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 10_p, + 12_p, + }, + }, + DataType::FLOAT, + }; + + LinearAttrs linear_attrs = LinearAttrs{ + /*out_channels=*/12_p, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }; + + RepartitionAttrs partition_input_attrs = RepartitionAttrs{ + /*repartition_dim=*/ff_dim_t{0_n}, + /*repartition_degree=*/degree, + }; + + WeightAttrs projection_weight_attrs = WeightAttrs{ + /*tensor_shape=*/throw_if_unexpected( + get_projection_shape(linear_attrs, input_shape)), + /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, + }; + + WeightAttrs bias_attrs = WeightAttrs{ + /*tensor_shape=*/throw_if_unexpected( + get_bias_shape(linear_attrs, input_shape)), + /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, + }; + 
+ ReplicateAttrs replicate_projection_attrs = ReplicateAttrs{ + /*replicate_degree=*/degree, + }; + + CombineAttrs combine_op_attrs = CombineAttrs{ + /*combine_dim=*/ff_dim_t{ + nonnegative_int{num_dims.int_from_positive_int() - 1}}, + /*combine_degree=*/degree, + }; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult projection_weight_added = add_parallel_layer( + pcg, make_layer_attrs(projection_weight_attrs), {}, {}); + parallel_tensor_guid_t t_projection_weight = + get_only(projection_weight_added.outputs); + + ParallelLayerAddedResult bias_added = + add_parallel_layer(pcg, make_layer_attrs(bias_attrs), {}, {}); + parallel_tensor_guid_t t_bias = get_only(bias_added.outputs); + + ParallelLayerAddedResult linear_added = + add_parallel_layer(pcg, + make_layer_attrs(linear_attrs, linear_match), + {t_input}, + {t_projection_weight, t_bias}); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, linear_match); + std::cout << get_layer_inputs(original_pcg, match_layer) << std::endl; + open_parallel_tensor_guid_t match_layer_input_activations = + get_layer_inputs(original_pcg, match_layer).at(0); + open_parallel_tensor_guid_t match_layer_input_weights = + get_layer_inputs(original_pcg, match_layer).at(1); + open_parallel_tensor_guid_t match_layer_input_bias = + get_layer_inputs(original_pcg, match_layer).at(2); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{ + { + PatternInput{DataflowGraphInput{0}}, + match_layer_input_activations, + }, + { + PatternInput{DataflowGraphInput{2}}, + match_layer_input_weights, + }, + { + PatternInput{DataflowGraphInput{4}}, + 
match_layer_input_bias, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult partition_input_added = add_parallel_layer( + pcg, make_layer_attrs(partition_input_attrs), {t_input}, {}); + parallel_tensor_guid_t t_partitioned_input = + get_only(partition_input_added.outputs); + + ParallelLayerAddedResult projection_weight_added = add_parallel_layer( + pcg, make_layer_attrs(projection_weight_attrs), {}, {}); + parallel_tensor_guid_t t_projection_weight = + get_only(projection_weight_added.outputs); + + ParallelLayerAddedResult replicate_projection_added = + add_parallel_layer(pcg, + make_layer_attrs(replicate_projection_attrs), + {t_projection_weight}, + {}); + parallel_tensor_guid_t t_replicated_projection_weight = + get_only(replicate_projection_added.outputs); + + ParallelLayerAddedResult bias_added = + add_parallel_layer(pcg, make_layer_attrs(bias_attrs), {}, {}); + parallel_tensor_guid_t t_bias = get_only(bias_added.outputs); + + ParallelLayerAddedResult replicate_bias_added = add_parallel_layer( + pcg, make_layer_attrs(replicate_projection_attrs), {t_bias}, {}); + parallel_tensor_guid_t t_replicated_bias = + get_only(replicate_bias_added.outputs); + + ParallelLayerAddedResult partition_linear_added = add_parallel_layer( + pcg, + make_layer_attrs(linear_attrs), + {t_partitioned_input}, + {t_replicated_projection_weight, t_replicated_bias}); + parallel_tensor_guid_t t_partitioned_linear = + get_only(partition_linear_added.outputs); + + ParallelLayerAddedResult combine_added = add_parallel_layer( + pcg, make_layer_attrs(combine_op_attrs), {t_partitioned_linear}, {}); + parallel_tensor_guid_t t_combine = 
get_only(combine_added.outputs); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_partition_conv2d_combine") { + positive_int outChannels = 6_p; + positive_int kernelH = 5_p; + positive_int kernelW = 4_p; + positive_int strideH = 3_p; + positive_int strideW = 2_p; + nonnegative_int paddingH = 1_n; + nonnegative_int paddingW = 0_n; + positive_int num_dims = 4_p; + positive_int degree = 1_p; + std::string conv2d_match = "conv2d_match"; + + Substitution sub = create_partition_conv2d_combine(num_dims, degree); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 12_p, + 3_p, + 10_p, + 10_p, + }, + }, + DataType::FLOAT, + }; + + Conv2DAttrs conv2d_attrs = Conv2DAttrs{/*outChannels=*/outChannels, + /*kernelH=*/kernelH, + /*kernelW=*/kernelW, + /*strideH=*/strideH, + /*strideW=*/strideW, + /*paddingH=*/paddingH, + /*paddingW=*/paddingW, + /*groups=*/1_p, + /*activation=*/std::nullopt, + /*use_bias=*/false}; + + RepartitionAttrs partition_input_attrs = RepartitionAttrs{ + /*repartition_dim=*/ff_dim_t{0_n}, + /*repartition_degree=*/degree, + }; + + ReplicateAttrs replicate_weight_attrs = ReplicateAttrs{ + /*replicate_degree=*/degree, + }; + + CombineAttrs combine_attrs = CombineAttrs{ + /*combine_dim=*/ff_dim_t{ + nonnegative_int{num_dims.int_from_positive_int() - 1}}, + /*combine_degree=*/degree, + }; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + TensorShape casted_input_shape = + get_reduced_shape(get_parallel_tensor_shape(pcg, t_input)); + + WeightAttrs projection_weight_attrs = WeightAttrs{ + /*tensor_shape=*/ + get_weight_shapes(conv2d_attrs, casted_input_shape).at(0), + /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, + }; + + 
ParallelLayerAddedResult projection_weight_added = add_parallel_layer( + pcg, make_layer_attrs(projection_weight_attrs), {}, {}); + parallel_tensor_guid_t t_projection_weight = + get_only(projection_weight_added.outputs); + + ParallelLayerAddedResult conv_2d_added = + add_parallel_layer(pcg, + make_layer_attrs(conv2d_attrs, conv2d_match), + {t_input}, + {t_projection_weight}); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, conv2d_match); + open_parallel_tensor_guid_t match_layer_input_activations = + get_layer_inputs(original_pcg, match_layer).at(0); + open_parallel_tensor_guid_t match_layer_input_weights = + get_layer_inputs(original_pcg, match_layer).at(1); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{ + { + PatternInput{DataflowGraphInput{0}}, + match_layer_input_activations, + }, + { + PatternInput{DataflowGraphInput{2}}, + match_layer_input_weights, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult partition_input_added = add_parallel_layer( + pcg, make_layer_attrs(partition_input_attrs), {t_input}, {}); + parallel_tensor_guid_t t_partitioned_input = + get_only(partition_input_added.outputs); + + TensorShape casted_input_shape = + get_reduced_shape(get_parallel_tensor_shape(pcg, t_input)); + + WeightAttrs weight_attrs = WeightAttrs{ + /*tensor_shape=*/ + get_weight_shapes(conv2d_attrs, casted_input_shape).at(0), + /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, + }; + + ParallelLayerAddedResult weight_added = + 
add_parallel_layer(pcg, make_layer_attrs(weight_attrs), {}, {}); + parallel_tensor_guid_t t_weight = get_only(weight_added.outputs); + + ParallelLayerAddedResult replicate_weight_added = add_parallel_layer( + pcg, make_layer_attrs(replicate_weight_attrs), {t_weight}, {}); + parallel_tensor_guid_t t_replicated_weight = + get_only(replicate_weight_added.outputs); + + ParallelLayerAddedResult partition_conv2d_added = + add_parallel_layer(pcg, + make_layer_attrs(conv2d_attrs), + {t_partitioned_input}, + {t_replicated_weight}); + parallel_tensor_guid_t t_partitioned_conv2d = + get_only(partition_conv2d_added.outputs); + + ParallelLayerAddedResult combine_added = add_parallel_layer( + pcg, make_layer_attrs(combine_attrs), {t_partitioned_conv2d}, {}); + parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_partition_attention_combine") { + positive_int embed_dim = 8_p; + positive_int num_heads = 6_p; + positive_int degree = 1_p; + std::string attention_match = "attention_match"; + + Substitution sub = create_partition_attention_combine(num_heads, degree); + + TensorShape query_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 12_p, + 16_p, + 10_p, + }, + }, + DataType::FLOAT, + }; + TensorShape key_shape = query_shape; + TensorShape value_shape = query_shape; + + MultiHeadAttentionAttrs attention_attrs = MultiHeadAttentionAttrs{ + /*embed_dim=*/embed_dim, + /*num_heads=*/num_heads, + /*kdim=*/embed_dim, + /*vdim=*/embed_dim, + /*dropout=*/0, + /*bias=*/false, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, + }; + + RepartitionAttrs partition_input_attrs = RepartitionAttrs{ + /*repartition_dim=*/ff_dim_t{0_n}, + /*repartition_degree=*/degree, + }; + + WeightAttrs weight_attrs = WeightAttrs{ + /*tensor_shape=*/ + throw_if_unexpected(get_weights_shape( + attention_attrs, query_shape, key_shape, value_shape)), + 
/*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, + }; + + ReplicateAttrs replicate_weight_attrs = ReplicateAttrs{ + /*replicate_degree=*/degree, + }; + + CombineAttrs combine_attrs = CombineAttrs{ + /*combine_dim=*/ff_dim_t{2_n}, + /*combine_degree=*/degree, + }; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult query_added = + pcg_add_input_layer(pcg, query_shape); + parallel_tensor_guid_t t_query = get_only(query_added.outputs); + + ParallelLayerAddedResult key_added = pcg_add_input_layer(pcg, key_shape); + parallel_tensor_guid_t t_key = get_only(key_added.outputs); + + ParallelLayerAddedResult value_added = + pcg_add_input_layer(pcg, value_shape); + parallel_tensor_guid_t t_value = get_only(value_added.outputs); + + ParallelLayerAddedResult weight_added = + add_parallel_layer(pcg, make_layer_attrs(weight_attrs), {}, {}); + parallel_tensor_guid_t t_weight = get_only(weight_added.outputs); + + ParallelLayerAddedResult attention_added = + add_parallel_layer(pcg, + make_layer_attrs(attention_attrs, attention_match), + {t_query, t_key, t_value}, + {t_weight}); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, attention_match); + open_parallel_tensor_guid_t match_layer_query = + get_layer_inputs(original_pcg, match_layer).at(0); + open_parallel_tensor_guid_t match_layer_key = + get_layer_inputs(original_pcg, match_layer).at(1); + open_parallel_tensor_guid_t match_layer_value = + get_layer_inputs(original_pcg, match_layer).at(2); + open_parallel_tensor_guid_t match_layer_input_weights = + get_layer_inputs(original_pcg, match_layer).at(3); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{ + { + PatternInput{DataflowGraphInput{0}}, + match_layer_query, + }, + { + PatternInput{DataflowGraphInput{2}}, + 
match_layer_key, + }, + { + PatternInput{DataflowGraphInput{4}}, + match_layer_value, + }, + { + PatternInput{DataflowGraphInput{6}}, + match_layer_input_weights, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult query_added = + pcg_add_input_layer(pcg, query_shape); + parallel_tensor_guid_t t_query = get_only(query_added.outputs); + + ParallelLayerAddedResult key_added = pcg_add_input_layer(pcg, key_shape); + parallel_tensor_guid_t t_key = get_only(key_added.outputs); + + ParallelLayerAddedResult value_added = + pcg_add_input_layer(pcg, value_shape); + parallel_tensor_guid_t t_value = get_only(value_added.outputs); + + ParallelLayerAddedResult weight_added = + add_parallel_layer(pcg, make_layer_attrs(weight_attrs), {}, {}); + parallel_tensor_guid_t t_weight = get_only(weight_added.outputs); + + ParallelLayerAddedResult partition_query_added = add_parallel_layer( + pcg, make_layer_attrs(partition_input_attrs), {t_query}, {}); + parallel_tensor_guid_t t_partitioned_query = + get_only(partition_query_added.outputs); + + ParallelLayerAddedResult partition_key_added = add_parallel_layer( + pcg, make_layer_attrs(partition_input_attrs), {t_key}, {}); + parallel_tensor_guid_t t_partitioned_key = + get_only(partition_key_added.outputs); + + ParallelLayerAddedResult partition_value_added = add_parallel_layer( + pcg, make_layer_attrs(partition_input_attrs), {t_value}, {}); + parallel_tensor_guid_t t_partitioned_value = + get_only(partition_value_added.outputs); + + ParallelLayerAddedResult replicate_weight_added = add_parallel_layer( + pcg, make_layer_attrs(replicate_weight_attrs), {t_weight}, {}); + parallel_tensor_guid_t t_replicated_weight = + get_only(replicate_weight_added.outputs); + + ParallelLayerAddedResult partition_attention_added = add_parallel_layer( + pcg, + 
make_layer_attrs(attention_attrs), + {t_partitioned_query, t_partitioned_key, t_partitioned_value}, + {t_replicated_weight}); + parallel_tensor_guid_t t_partitioned_attention = + get_only(partition_attention_added.outputs); + + ParallelLayerAddedResult combine_added = add_parallel_layer( + pcg, make_layer_attrs(combine_attrs), {t_partitioned_attention}, {}); + parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_replicate_attention_reduce") { + positive_int embed_dim = 8_p; + positive_int num_heads = 6_p; + positive_int degree = 1_p; + std::string attention_match = "attention_match"; + + Substitution sub = create_replicate_attention_reduce(num_heads, degree); + + TensorShape query_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 12_p, + 16_p, + 10_p, + }, + }, + DataType::FLOAT, + }; + TensorShape key_shape = query_shape; + TensorShape value_shape = query_shape; + + MultiHeadAttentionAttrs attention_attrs = MultiHeadAttentionAttrs{ + /*embed_dim=*/embed_dim, + /*num_heads=*/num_heads, + /*kdim=*/embed_dim, + /*vdim=*/embed_dim, + /*dropout=*/0, + /*bias=*/false, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, + }; + + ReplicateAttrs replicate_input_attrs = ReplicateAttrs{ + /*replicate_degree=*/degree, + }; + + WeightAttrs weight_attrs = WeightAttrs{ + /*tensor_shape=*/ + throw_if_unexpected(get_weights_shape( + attention_attrs, query_shape, key_shape, value_shape)), + /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, + }; + + RepartitionAttrs partition_weight_attrs = RepartitionAttrs{ + /*repartition_dim=*/ff_dim_t{1_n}, + /*repartition_degree=*/degree, + }; + + ReductionAttrs reduction_attrs = ReductionAttrs{ + /*reduction_degree=*/degree, + }; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult query_added = + 
pcg_add_input_layer(pcg, query_shape); + parallel_tensor_guid_t t_query = get_only(query_added.outputs); + + ParallelLayerAddedResult key_added = pcg_add_input_layer(pcg, key_shape); + parallel_tensor_guid_t t_key = get_only(key_added.outputs); + + ParallelLayerAddedResult value_added = + pcg_add_input_layer(pcg, value_shape); + parallel_tensor_guid_t t_value = get_only(value_added.outputs); + + ParallelLayerAddedResult weight_added = + add_parallel_layer(pcg, make_layer_attrs(weight_attrs), {}, {}); + parallel_tensor_guid_t t_weight = get_only(weight_added.outputs); + + ParallelLayerAddedResult attention_added = + add_parallel_layer(pcg, + make_layer_attrs(attention_attrs, attention_match), + {t_query, t_key, t_value}, + {t_weight}); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, attention_match); + open_parallel_tensor_guid_t match_layer_query = + get_layer_inputs(original_pcg, match_layer).at(0); + open_parallel_tensor_guid_t match_layer_key = + get_layer_inputs(original_pcg, match_layer).at(1); + open_parallel_tensor_guid_t match_layer_value = + get_layer_inputs(original_pcg, match_layer).at(2); + open_parallel_tensor_guid_t match_layer_input_weights = + get_layer_inputs(original_pcg, match_layer).at(3); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{ + { + PatternInput{DataflowGraphInput{0}}, + match_layer_query, + }, + { + PatternInput{DataflowGraphInput{2}}, + match_layer_key, + }, + { + PatternInput{DataflowGraphInput{4}}, + match_layer_value, + }, + { + PatternInput{DataflowGraphInput{6}}, + match_layer_input_weights, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult query_added = + 
pcg_add_input_layer(pcg, query_shape); + parallel_tensor_guid_t t_query = get_only(query_added.outputs); + + ParallelLayerAddedResult key_added = pcg_add_input_layer(pcg, key_shape); + parallel_tensor_guid_t t_key = get_only(key_added.outputs); + + ParallelLayerAddedResult value_added = + pcg_add_input_layer(pcg, value_shape); + parallel_tensor_guid_t t_value = get_only(value_added.outputs); + + ParallelLayerAddedResult weight_added = + add_parallel_layer(pcg, make_layer_attrs(weight_attrs), {}, {}); + parallel_tensor_guid_t t_weight = get_only(weight_added.outputs); + + ParallelLayerAddedResult replicate_query_added = add_parallel_layer( + pcg, make_layer_attrs(replicate_input_attrs), {t_query}, {}); + parallel_tensor_guid_t t_replicated_query = + get_only(replicate_query_added.outputs); + + ParallelLayerAddedResult replicate_key_added = add_parallel_layer( + pcg, make_layer_attrs(replicate_input_attrs), {t_key}, {}); + parallel_tensor_guid_t t_replicated_key = + get_only(replicate_key_added.outputs); + + ParallelLayerAddedResult replicate_value_added = add_parallel_layer( + pcg, make_layer_attrs(replicate_input_attrs), {t_value}, {}); + parallel_tensor_guid_t t_replicated_value = + get_only(replicate_value_added.outputs); + + ParallelLayerAddedResult partition_weight_added = add_parallel_layer( + pcg, make_layer_attrs(partition_weight_attrs), {t_weight}, {}); + parallel_tensor_guid_t t_partitioned_weight = + get_only(partition_weight_added.outputs); + + ParallelLayerAddedResult replicate_attention_added = add_parallel_layer( + pcg, + make_layer_attrs(attention_attrs), + {t_replicated_query, t_replicated_key, t_replicated_value}, + {t_partitioned_weight}); + parallel_tensor_guid_t t_replicated_attention = + get_only(replicate_attention_added.outputs); + + ParallelLayerAddedResult reduce_added = add_parallel_layer( + pcg, make_layer_attrs(reduction_attrs), {t_replicated_attention}, {}); + parallel_tensor_guid_t t_reduction = get_only(reduce_added.outputs); + + 
return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_partition_softmax_combine") { + positive_int degree = 1_p; + ff_dim_t softmax_dim = ff_dim_t{1_n}; + ff_dim_t partition_dim = ff_dim_t{0_n}; + std::string softmax_match = "softmax_match"; + + Substitution sub = + create_partition_softmax_combine(softmax_dim, partition_dim, degree); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 10_p, + 10_p, + }, + }, + DataType::FLOAT, + }; + + SoftmaxAttrs softmax_attrs = SoftmaxAttrs{ + /*softmax_dim=*/softmax_dim, + }; + + RepartitionAttrs partition_input_attrs = RepartitionAttrs{ + /*repartition_dim=*/partition_dim, + /*repartition_degree=*/degree, + }; + + CombineAttrs combine_attrs = CombineAttrs{ + /*combine_dim=*/ff_dim_t{partition_dim}, + /*combine_degree=*/degree, + }; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult softmax_added = add_parallel_layer( + pcg, make_layer_attrs(softmax_attrs, softmax_match), {t_input}, {}); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, softmax_match); + open_parallel_tensor_guid_t match_layer_input = + get_layer_inputs(original_pcg, match_layer).at(0); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{{ + PatternInput{DataflowGraphInput{0}}, + match_layer_input, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + 
pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult partition_input_added = add_parallel_layer( + pcg, make_layer_attrs(partition_input_attrs), {t_input}, {}); + parallel_tensor_guid_t t_partitioned_input = + get_only(partition_input_added.outputs); + + ParallelLayerAddedResult partition_softmax_added = add_parallel_layer( + pcg, make_layer_attrs(softmax_attrs), {t_partitioned_input}, {}); + parallel_tensor_guid_t t_partitioned_softmax = + get_only(partition_softmax_added.outputs); + + ParallelLayerAddedResult combine_added = add_parallel_layer( + pcg, make_layer_attrs(combine_attrs), {t_partitioned_softmax}, {}); + parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_partition_add_combine") { + positive_int degree = 1_p; + ff_dim_t parallel_dim = ff_dim_t{1_n}; + std::string add_match = "add_match"; + + Substitution sub = create_partition_add_combine(parallel_dim, degree); + + TensorShape lhs_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 10_p, + 15_p, + }, + }, + DataType::FLOAT, + }; + + TensorShape rhs_shape = lhs_shape; + + ElementBinaryAttrs add_attrs = ElementBinaryAttrs{ + OperatorType::EW_ADD, + DataType::FLOAT, + false, + false, + }; + + RepartitionAttrs partition_input_attrs = RepartitionAttrs{ + /*repartition_dim=*/parallel_dim, + /*repartition_degree=*/degree, + }; + + CombineAttrs combine_attrs = CombineAttrs{ + /*combine_dim=*/parallel_dim, + /*combine_degree=*/degree, + }; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult lhs_added = pcg_add_input_layer(pcg, lhs_shape); + parallel_tensor_guid_t t_lhs = get_only(lhs_added.outputs); + + ParallelLayerAddedResult rhs_added = pcg_add_input_layer(pcg, rhs_shape); + 
parallel_tensor_guid_t t_rhs = get_only(rhs_added.outputs); + + ParallelLayerAddedResult output_added = add_parallel_layer( + pcg, make_layer_attrs(add_attrs, add_match), {t_lhs, t_rhs}, {}); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, add_match); + open_parallel_tensor_guid_t add_match_layer_lhs = + get_layer_inputs(original_pcg, match_layer).at(0); + open_parallel_tensor_guid_t add_match_layer_rhs = + get_layer_inputs(original_pcg, match_layer).at(1); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{ + { + PatternInput{DataflowGraphInput{0}}, + add_match_layer_lhs, + }, + { + PatternInput{DataflowGraphInput{2}}, + add_match_layer_rhs, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult lhs_added = pcg_add_input_layer(pcg, lhs_shape); + parallel_tensor_guid_t t_lhs = get_only(lhs_added.outputs); + + ParallelLayerAddedResult rhs_added = pcg_add_input_layer(pcg, rhs_shape); + parallel_tensor_guid_t t_rhs = get_only(rhs_added.outputs); + + ParallelLayerAddedResult partition_lhs_added = add_parallel_layer( + pcg, make_layer_attrs(partition_input_attrs), {t_lhs}, {}); + parallel_tensor_guid_t t_partitioned_lhs = + get_only(partition_lhs_added.outputs); + + ParallelLayerAddedResult partition_rhs_added = add_parallel_layer( + pcg, make_layer_attrs(partition_input_attrs), {t_rhs}, {}); + parallel_tensor_guid_t t_partitioned_rhs = + get_only(partition_rhs_added.outputs); + + ParallelLayerAddedResult partition_add_added = + add_parallel_layer(pcg, + make_layer_attrs(add_attrs, add_match), + {t_partitioned_lhs, t_partitioned_rhs}, + {}); + parallel_tensor_guid_t t_partitioned_add = + 
get_only(partition_add_added.outputs); + + ParallelLayerAddedResult combine_added = add_parallel_layer( + pcg, make_layer_attrs(combine_attrs), {t_partitioned_add}, {}); + parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_partition_relu_combine") { + positive_int degree = 1_p; + ff_dim_t parallel_dim = ff_dim_t{1_n}; + std::string relu_match = "relu_match"; + + Substitution sub = create_partition_relu_combine(parallel_dim, degree); + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 10_p, + 10_p, + }, + }, + DataType::FLOAT, + }; + + ElementUnaryAttrs relu_attrs = ElementUnaryAttrs{ + OperatorType::RELU, + std::nullopt, + }; + + RepartitionAttrs partition_input_attrs = RepartitionAttrs{ + /*repartition_dim=*/parallel_dim, + /*repartition_degree=*/degree, + }; + + CombineAttrs combine_attrs = CombineAttrs{ + /*combine_dim=*/ff_dim_t{parallel_dim}, + /*combine_degree=*/degree, + }; + + SubParallelComputationGraph original_pcg = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult relu_added = add_parallel_layer( + pcg, make_layer_attrs(relu_attrs, relu_match), {t_input}, {}); + + return sub_pcg_from_full_pcg(pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t match_layer = + get_parallel_layer_by_name(original_pcg, relu_match); + open_parallel_tensor_guid_t match_layer_input = + get_layer_inputs(original_pcg, match_layer).at(0); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, match_layer}, + }, + std::unordered_map{{ + PatternInput{DataflowGraphInput{0}}, + match_layer_input, + }}, + }; + }(); + + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, 
match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAddedResult input_added = + pcg_add_input_layer(pcg, input_shape); + + parallel_tensor_guid_t t_input = get_only(input_added.outputs); + + ParallelLayerAddedResult partition_input_added = add_parallel_layer( + pcg, make_layer_attrs(partition_input_attrs), {t_input}, {}); + parallel_tensor_guid_t t_partitioned_input = + get_only(partition_input_added.outputs); + + ParallelLayerAddedResult partition_relu_added = add_parallel_layer( + pcg, make_layer_attrs(relu_attrs), {t_partitioned_input}, {}); + parallel_tensor_guid_t t_partitioned_relu = + get_only(partition_relu_added.outputs); + + ParallelLayerAddedResult combine_added = add_parallel_layer( + pcg, make_layer_attrs(combine_attrs), {t_partitioned_relu}, {}); + parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + + return sub_pcg_from_full_pcg(pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } + + TEST_CASE("create_fuse_linear_activation") { + Substitution sub = create_fuse_linear_activation(Activation::SIGMOID); + + std::string mm_match = "mm_match"; + std::string relu_match = "relu_match"; + + TensorShape input_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 4_p, + 10_p, + }, + }, + DataType::FLOAT, + }; + + SubParallelComputationGraph pcg = [&] { + ParallelComputationGraphBuilder b; + parallel_tensor_guid_t t = b.create_input_tensor(input_shape); + t = b.dense(t, + /*outDim=*/4_p, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/mm_match); + t = b.relu(t, + /*name=*/relu_match); + + return sub_pcg_from_full_pcg(b.pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t mm_match_layer = + get_parallel_layer_by_name(pcg, mm_match); + parallel_layer_guid_t relu_match_layer = + 
get_parallel_layer_by_name(pcg, relu_match); + open_parallel_tensor_guid_t mm_match_layer_input_activations = + get_layer_inputs(pcg, mm_match_layer).at(0); + open_parallel_tensor_guid_t mm_match_layer_input_weights = + get_layer_inputs(pcg, mm_match_layer).at(1); + + return PCGPatternMatch{ + bidict{ + {PatternNode{Node{0}}, mm_match_layer}, + {PatternNode{Node{1}}, relu_match_layer}, + }, + std::unordered_map{ + { + PatternInput{DataflowGraphInput{0}}, + mm_match_layer_input_activations, + }, + { + PatternInput{DataflowGraphInput{2}}, + mm_match_layer_input_weights, + }}, + }; + }(); + + SubParallelComputationGraph result = apply_substitution(pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraphBuilder b; + parallel_tensor_guid_t t = b.create_input_tensor(input_shape); + t = b.dense(t, + /*outDim=*/4_p, + /*activation=*/Activation::SIGMOID, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/std::nullopt); + + return sub_pcg_from_full_pcg(b.pcg); + }(); + + CHECK(sub_pcgs_are_isomorphic(result, correct)); } } From 599c70376c690c2fc0b2ace73177ab4d49a7f1a2 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sat, 10 Jan 2026 01:04:57 -0800 Subject: [PATCH 2/3] Fix and shorten unity_substitution_set tests --- .../perform_shape_inference.cc | 6 + .../substitutions/unity_substitution_set.cc | 963 +++++++-------- .../substitutions/unity_substitution_set.cc | 1059 ++++++----------- .../utils/positive_int/positive_range.h | 13 + .../src/utils/positive_int/positive_range.cc | 14 + 5 files changed, 847 insertions(+), 1208 deletions(-) create mode 100644 lib/utils/include/utils/positive_int/positive_range.h create mode 100644 lib/utils/src/utils/positive_int/positive_range.cc diff --git a/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc b/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc 
index e7dc926682..94c3bee5d2 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc @@ -18,6 +18,8 @@ #include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" #include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_values_for_node.h" #include "utils/nonnegative_int/num_elements.h" +#include "utils/containers/binary_merge_disjoint_maps.h" +#include "utils/containers/is_subseteq_of.h" namespace FlexFlow { @@ -54,6 +56,8 @@ LabelledOpenKwargDataflowGraphView incoming_tensor_roles = get_incoming_tensor_roles(n_attrs.op_attrs); + ASSERT(is_subseteq_of(keys(incoming_shapes), keys(incoming_tensor_roles))); + auto incoming_shapes_with_role = [&](IncomingTensorRole role) -> std::unordered_map { std::unordered_set slots_with_desired_role = @@ -68,6 +72,8 @@ LabelledOpenKwargDataflowGraphView weight_shapes = incoming_shapes_with_role(IncomingTensorRole::WEIGHT); + ASSERT(binary_merge_disjoint_maps(input_shapes, weight_shapes) == incoming_shapes); + std::unordered_map inferred_weight_shapes = get_weight_shapes(n_attrs.op_attrs, input_shapes); diff --git a/lib/substitutions/src/substitutions/unity_substitution_set.cc b/lib/substitutions/src/substitutions/unity_substitution_set.cc index f4034c9cf1..6940e86162 100644 --- a/lib/substitutions/src/substitutions/unity_substitution_set.cc +++ b/lib/substitutions/src/substitutions/unity_substitution_set.cc @@ -4,17 +4,20 @@ #include "substitutions/output_graph/output_operator_attrs_assignment.h" #include "substitutions/substitution_builder.h" #include "substitutions/tensor_pattern/tensor_attribute_pattern.h" -#include "utils/containers/get_only.h" #include "utils/containers/require_only_key.h" #include "utils/nonnegative_int/nonnegative_int.h" #include "utils/nonnegative_int/nonnegative_range.h" +#include "utils/positive_int/positive_range.h" namespace FlexFlow { 
std::vector get_substitution_set(MachineComputeSpecification const &resources) { std::vector substitutions; - for (positive_int dim = 1_p; dim <= positive_int{MAX_TENSOR_DIM}; dim++) { + + positive_int max_tensor_dim = positive_int{MAX_TENSOR_DIM}; + + for (positive_int dim : positive_range(1_p, max_tensor_dim + 1_p)) { for (positive_int degree = 1_p; degree <= get_num_gpus(resources); degree *= 2_p) { substitutions.push_back( @@ -33,19 +36,14 @@ std::vector substitutions.push_back(create_replicate_attention_reduce(dim, degree)); } } - for (positive_int degree = 1_p; degree <= get_num_gpus(resources); - degree *= 2_p) { + + for (positive_int degree = 1_p; degree <= get_num_gpus(resources); degree *= 2_p) { substitutions.push_back(create_partition_conv2d_combine(4_p, degree)); } - for (positive_int partition_dim = 1_p; - partition_dim <= positive_int{MAX_TENSOR_DIM}; - partition_dim++) { - for (positive_int softmax_dim = 1_p; - softmax_dim <= positive_int{MAX_TENSOR_DIM}; - softmax_dim++) { - for (positive_int degree = 1_p; degree <= get_num_gpus(resources); - degree *= 2_p) { + for (positive_int partition_dim : positive_range(1_p, max_tensor_dim + 1_p)) { + for (positive_int softmax_dim : positive_range(1_p, max_tensor_dim + 1_p)) { + for (positive_int degree = 1_p; degree <= get_num_gpus(resources); degree *= 2_p) { if (partition_dim != softmax_dim) { substitutions.push_back(create_partition_softmax_combine( ff_dim_t{partition_dim.nonnegative_int_from_positive_int()}, @@ -62,6 +60,114 @@ std::vector return substitutions; } +static PatternValue insert_single_output_pattern( + SubstitutionBuilder &b, + OperatorAttributePattern const &attribute_pattern, + std::unordered_map const &inputs, + TensorAttributePattern const &output_pattern, + std::string const &name) +{ + return require_only_key( + b.add_pattern_node(attribute_pattern, + inputs, + /*output_patterns=*/{ + { + TensorSlotName::OUTPUT, + output_pattern, + }, + }, + name), + TensorSlotName::OUTPUT); +} + + 
+static OutputGraphExprValue insert_single_output_op( + SubstitutionBuilder &b, + OutputOperatorAttrsAssignment const &expr, + std::unordered_map const &inputs) +{ + return require_only_key( + b.add_output_graph_node(expr, inputs, {TensorSlotName::OUTPUT}), + TensorSlotName::OUTPUT); +} + + +static OutputGraphExprValue insert_replicate_or_reduce(OperatorType op_type, + SubstitutionBuilder &b, + positive_int degree, + OutputGraphExprValue const &input) { + + ASSERT(op_type == OperatorType::REPLICATE || op_type == OperatorType::REDUCTION); + + OutputOperatorAttrsAssignment replicate_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(op_type), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + }}; + + return insert_single_output_op(b, replicate_expr, {{TensorSlotName::INPUT, input}}); +} + +static OutputGraphExprValue insert_replicate(SubstitutionBuilder &b, + positive_int degree, + OutputGraphExprValue const &input) { + return insert_replicate_or_reduce(OperatorType::REPLICATE, b, degree, input); +} + + +static OutputGraphExprValue insert_reduce(SubstitutionBuilder &b, + positive_int degree, + OutputGraphExprValue const &input) { + return insert_replicate_or_reduce(OperatorType::REDUCTION, b, degree, input); +} + +static OutputGraphExprValue insert_partition_or_combine( + OperatorType op_type, + SubstitutionBuilder &b, + positive_int degree, + ff_dim_t dim, + OutputGraphExprValue const &input) { + + ASSERT(op_type == OperatorType::REPARTITION || op_type == OperatorType::COMBINE); + + OutputOperatorAttrsAssignment partition_input_expr = + OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(op_type), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, + OperatorAttributeValue{dim}), + }}; + + OutputGraphExprValue o_partition_output = + insert_single_output_op(b, 
partition_input_expr, {{TensorSlotName::INPUT, input}}); + + return o_partition_output; +} + +static OutputGraphExprValue insert_partition(SubstitutionBuilder &b, + positive_int degree, + ff_dim_t dim, + OutputGraphExprValue const &input) { + + return insert_partition_or_combine(OperatorType::REPARTITION, b, degree, dim, input); +} + +static OutputGraphExprValue insert_combine(SubstitutionBuilder &b, + positive_int degree, + ff_dim_t dim, + OutputGraphExprValue const &input) { + + return insert_partition_or_combine(OperatorType::COMBINE, b, degree, dim, input); +} + + + Substitution create_replicate_linear_combine(positive_int num_dims, positive_int degree, bool use_bias) { @@ -92,68 +198,18 @@ Substitution create_replicate_linear_combine(positive_int num_dims, op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), }}; - PatternValue p_linear_output = require_only_key( - b.add_pattern_node(linear_pattern, - p_inputs, - { - { - TensorSlotName::OUTPUT, - tensor_attr_pattern_require_num_dims( - nonnegative_int{num_dims}), - }, - }, - "linear"), - TensorSlotName::OUTPUT); - - OutputOperatorAttrsAssignment replicate_input_expr = - OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::REPLICATE), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - }}; + std::string linear_name = "linear"; + PatternValue p_linear_output = insert_single_output_pattern( + b, + linear_pattern, + p_inputs, + /*output_pattern=*/tensor_attr_pattern_require_num_dims(num_dims), + linear_name); + OutputGraphExprValue o_replicate_input_output = - require_only_key(b.add_output_graph_node( - /*node_expr=*/replicate_input_expr, - /*inputs=*/ - { - { - TensorSlotName::INPUT, - o_input, - }, - }, - /*output_slots=*/ - { - TensorSlotName::OUTPUT, - }), - TensorSlotName::OUTPUT); - - OutputOperatorAttrsAssignment partition_weights_expr = - OutputOperatorAttrsAssignment{ - std::nullopt, - { - 
set_op_type_attr(OperatorType::REPARTITION), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, - OperatorAttributeValue{ff_dim_t{1_n}}), - }}; - OutputGraphExprValue o_partition_weights_output = - require_only_key(b.add_output_graph_node( - /*node_expr=*/partition_weights_expr, - /*inputs=*/ - { - { - TensorSlotName::INPUT, - o_weight, - }, - }, - /*output_slots=*/ - { - TensorSlotName::OUTPUT, - }), - TensorSlotName::OUTPUT); + insert_replicate(b, degree, o_input); + + OutputGraphExprValue o_partition_weights_output = insert_partition(b, degree, ff_dim_t{1_n}, o_weight); std::unordered_map o_linear_inputs = { { @@ -167,31 +223,8 @@ Substitution create_replicate_linear_combine(positive_int num_dims, }; if (use_bias) { - OutputOperatorAttrsAssignment partition_bias_expr = - OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::REPARTITION), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, - OperatorAttributeValue{ff_dim_t{1_n}}), - }}; - OutputGraphExprValue o_partition_bias_output = - require_only_key(b.add_output_graph_node( - /*node_expr=*/partition_bias_expr, - /*inputs=*/ - { - { - TensorSlotName::INPUT, - o_bias.value(), - }, - }, - /*output_slots=*/ - { - TensorSlotName::OUTPUT, - }), - TensorSlotName::OUTPUT); + OutputGraphExprValue o_partition_bias_output = insert_partition(b, degree, ff_dim_t{1_n}, o_bias.value()); + o_linear_inputs.insert({ TensorSlotName::BIAS, o_partition_bias_output, @@ -199,48 +232,15 @@ Substitution create_replicate_linear_combine(positive_int num_dims, } OutputOperatorAttrsAssignment linear_expr = OutputOperatorAttrsAssignment{ - b.pattern_node_named("linear"), + b.pattern_node_named(linear_name), {}, }; - OutputGraphExprValue o_linear_output = - require_only_key(b.add_output_graph_node( - 
/*node_expr=*/linear_expr, - /*inputs=*/o_linear_inputs, - /*output_slots=*/ - { - TensorSlotName::OUTPUT, - }), - TensorSlotName::OUTPUT); - - OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::COMBINE), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - set_attr_to_constant( - OperatorAttributeKey::PARALLEL_DIM, - OperatorAttributeValue{ff_dim_t{ - nonnegative_int{num_dims.int_from_positive_int() - 1}, - }}), - }, - }; + OutputGraphExprValue o_linear_output = insert_single_output_op(b, linear_expr, o_linear_inputs); - OutputGraphExprValue o_combine_output = - require_only_key(b.add_output_graph_node( - /*node_expr=*/combine_expr, - /*inputs=*/ - { - { - TensorSlotName::INPUT, - o_linear_output, - }, - }, - /*output_slots=*/ - { - TensorSlotName::OUTPUT, - }), - TensorSlotName::OUTPUT); + ff_dim_t combine_output_dim = ff_dim_t{ + nonnegative_int{num_dims.int_from_positive_int() - 1}, + }; + OutputGraphExprValue o_combine_output = insert_combine(b, degree, combine_output_dim, o_linear_output); b.equate_outputs(p_linear_output, o_combine_output); @@ -254,13 +254,25 @@ Substitution create_partition_linear_combine(positive_int num_dims, auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all()); auto [p_weight, o_weight] = b.add_input(tensor_attribute_pattern_match_all()); - std::vector p_inputs = {p_input, p_weight}; + std::unordered_map p_inputs = { + { + TensorSlotName::INPUT, + p_input, + }, + { + TensorSlotName::WEIGHT, + p_weight, + }, + }; std::optional o_bias = std::nullopt; if (use_bias) { std::pair bias = b.add_input(tensor_attribute_pattern_match_all()); - p_inputs.push_back(bias.first); + p_inputs.insert({ + TensorSlotName::BIAS, + bias.first, + }); o_bias = bias.second; } @@ -271,75 +283,48 @@ Substitution create_partition_linear_combine(positive_int num_dims, 
op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), }}; - PatternValue p_linear_output = get_only( - b.add_pattern_node(linear_pattern, - p_inputs, - {tensor_attr_pattern_require_num_dims(num_dims)}, - "linear")); + std::string linear_name = "linear"; + PatternValue p_linear_output = insert_single_output_pattern( + b, + linear_pattern, + p_inputs, + /*output_pattern=*/tensor_attr_pattern_require_num_dims(num_dims), + linear_name); - OutputOperatorAttrsAssignment partition_input_expr = - OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::REPARTITION), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, - OperatorAttributeValue{ff_dim_t{0_n}}), - }}; - OutputGraphExprValue o_partition_input_output = - get_only(b.add_output_graph_node(partition_input_expr, {o_input}, 1_n)); + OutputGraphExprValue o_partition_input_output = insert_partition(b, degree, ff_dim_t{0_n}, o_input); - OutputOperatorAttrsAssignment replicate_weights_expr = - OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::REPLICATE), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - }}; - OutputGraphExprValue o_replicate_weights_output = get_only( - b.add_output_graph_node(replicate_weights_expr, {o_weight}, 1_n)); + OutputGraphExprValue o_replicate_weights_output = insert_replicate(b, degree, o_weight); - std::vector o_linear_inputs = { - o_partition_input_output, o_replicate_weights_output}; + std::unordered_map o_linear_inputs = { + { + TensorSlotName::INPUT, + o_partition_input_output, + }, + { + TensorSlotName::WEIGHT, + o_replicate_weights_output, + }, + }; if (use_bias) { - OutputOperatorAttrsAssignment replicate_bias_expr = - OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::REPLICATE), - 
set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - }}; - OutputGraphExprValue o_replicate_bias_output = get_only( - b.add_output_graph_node(replicate_bias_expr, {o_bias.value()}, 1_n)); - o_linear_inputs.push_back(o_replicate_bias_output); + OutputGraphExprValue o_replicate_bias_output = insert_replicate(b, degree, o_bias.value()); + + o_linear_inputs.insert({ + TensorSlotName::BIAS, + o_replicate_bias_output, + }); } OutputOperatorAttrsAssignment linear_expr = OutputOperatorAttrsAssignment{ - b.pattern_node_named("linear"), + b.pattern_node_named(linear_name), {}, }; - OutputGraphExprValue o_linear_output = - get_only(b.add_output_graph_node(linear_expr, o_linear_inputs, 1_n)); + OutputGraphExprValue o_linear_output = insert_single_output_op(b, linear_expr, o_linear_inputs); - OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::COMBINE), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - set_attr_to_constant( - OperatorAttributeKey::PARALLEL_DIM, - OperatorAttributeValue{ff_dim_t{ - nonnegative_int{num_dims.int_from_positive_int() - 1}, - }}), - }, + ff_dim_t combine_output_dim = ff_dim_t{ + nonnegative_int{num_dims.int_from_positive_int() - 1}, }; - OutputGraphExprValue o_combine_output = - get_only(b.add_output_graph_node(combine_expr, {o_linear_output}, 1_n)); + OutputGraphExprValue o_combine_output = insert_combine(b, degree, combine_output_dim, o_linear_output); b.equate_outputs(p_linear_output, o_combine_output); @@ -348,77 +333,60 @@ Substitution create_partition_linear_combine(positive_int num_dims, Substitution create_partition_conv2d_combine(positive_int num_dims, positive_int degree) { - if (num_dims != 4_p) { - throw mk_runtime_error(fmt::format("num_dims must be 4, not {}", num_dims)); - } + ASSERT(num_dims == 4_p); SubstitutionBuilder b; auto [p_input, o_input] = 
b.add_input(tensor_attribute_pattern_match_all()); auto [p_weight, o_weight] = b.add_input(tensor_attribute_pattern_match_all()); - std::vector p_inputs = {p_input, p_weight}; + + std::unordered_map p_inputs = { + { + TensorSlotName::INPUT, + p_input, + }, + { + TensorSlotName::FILTER, + p_weight, + }, + }; OperatorAttributePattern conv2d_pattern = OperatorAttributePattern{{ op_type_equals_constraint(OperatorType::CONV2D), op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), }}; - PatternValue p_conv2d_output = get_only( - b.add_pattern_node(conv2d_pattern, - p_inputs, - {tensor_attr_pattern_require_num_dims(num_dims)}, - "conv2d")); + std::string conv2d_name = "conv2d"; + PatternValue p_conv2d_output = insert_single_output_pattern( + b, + conv2d_pattern, + p_inputs, + /*output_pattern=*/tensor_attr_pattern_require_num_dims(num_dims), + conv2d_name); - OutputOperatorAttrsAssignment partition_input_expr = - OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::REPARTITION), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, - OperatorAttributeValue{ff_dim_t{0_n}}), - }}; - OutputGraphExprValue o_partition_input_output = - get_only(b.add_output_graph_node(partition_input_expr, {o_input}, 1_n)); + OutputGraphExprValue o_partition_input_output = insert_partition(b, degree, ff_dim_t{0_n}, o_input); - OutputOperatorAttrsAssignment replicate_weights_expr = - OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::REPLICATE), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - }}; - OutputGraphExprValue o_replicate_weights_output = get_only( - b.add_output_graph_node(replicate_weights_expr, {o_weight}, 1_n)); + OutputGraphExprValue o_replicate_weights_output = insert_replicate(b, degree, o_weight); - std::vector o_conv2d_inputs = { - 
o_partition_input_output, o_replicate_weights_output}; + std::unordered_map o_conv2d_inputs = { + { + TensorSlotName::INPUT, + o_partition_input_output, + }, + { + TensorSlotName::FILTER, + o_replicate_weights_output + }, + }; OutputOperatorAttrsAssignment conv2d_expr = OutputOperatorAttrsAssignment{ - b.pattern_node_named("conv2d"), + b.pattern_node_named(conv2d_name), {}, }; - OutputGraphExprValue o_conv2d_output = - get_only(b.add_output_graph_node(conv2d_expr, o_conv2d_inputs, 1_n)); + OutputGraphExprValue o_conv2d_output = insert_single_output_op(b, conv2d_expr, o_conv2d_inputs); - OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::COMBINE), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - set_attr_to_constant( - OperatorAttributeKey::PARALLEL_DIM, - OperatorAttributeValue{ff_dim_t{ - nonnegative_int{num_dims.int_from_positive_int() - 1}, - }}), - }, - }; - OutputGraphExprValue o_combine_output = - get_only(b.add_output_graph_node(combine_expr, {o_conv2d_output}, 1_n)); + OutputGraphExprValue o_combine_output = insert_combine(b, degree, ff_dim_t{0_n}, o_conv2d_output); b.equate_outputs(p_conv2d_output, o_combine_output); @@ -438,8 +406,24 @@ Substitution create_partition_attention_combine(positive_int num_heads, b.add_input(tensor_attribute_pattern_match_all()); auto [p_weights, o_weights] = b.add_input(tensor_attribute_pattern_match_all()); - std::vector p_inputs = { - p_query_input, p_key_input, p_value_input, p_weights}; + std::unordered_map p_inputs = { + { + TensorSlotName::QUERY, + p_query_input, + }, + { + TensorSlotName::KEY, + p_key_input, + }, + { + TensorSlotName::VALUE, + p_value_input, + }, + { + TensorSlotName::WEIGHT, + p_weights, + }, + }; OperatorAttributePattern attention_pattern = OperatorAttributePattern{{ op_type_equals_constraint(OperatorType::MULTIHEAD_ATTENTION), @@ -447,71 +431,51 @@ Substitution 
create_partition_attention_combine(positive_int num_heads, op_attr_key_divisible_by(OperatorAttributeKey::NUM_HEADS, num_heads), }}; - PatternValue p_attention_output = - get_only(b.add_pattern_node(attention_pattern, - p_inputs, - {tensor_attr_pattern_require_num_dims(3_p)}, - "attention")); - - OutputOperatorAttrsAssignment partition_input_expr = - OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::REPARTITION), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, - OperatorAttributeValue{ff_dim_t{0_n}}), - }}; - - OutputGraphExprValue o_partition_query_input_output = get_only( - b.add_output_graph_node(partition_input_expr, {o_query_input}, 1_n)); - - OutputGraphExprValue o_partition_key_input_output = get_only( - b.add_output_graph_node(partition_input_expr, {o_key_input}, 1_n)); + std::string attention_name = "attention"; + PatternValue p_attention_output = insert_single_output_pattern( + b, + attention_pattern, + p_inputs, + /*output_pattern=*/tensor_attr_pattern_require_num_dims(3_p), + attention_name); - OutputGraphExprValue o_partition_value_input_output = get_only( - b.add_output_graph_node(partition_input_expr, {o_value_input}, 1_n)); + OutputGraphExprValue o_partition_query_input_output = + insert_partition(b, degree, ff_dim_t{0_n}, o_query_input); + + OutputGraphExprValue o_partition_key_input_output = + insert_partition(b, degree, ff_dim_t{0_n}, o_key_input); - OutputOperatorAttrsAssignment replicate_weight_expr = - OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::REPLICATE), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - }}; + OutputGraphExprValue o_partition_value_input_output = + insert_partition(b, degree, ff_dim_t{0_n}, o_value_input); - OutputGraphExprValue o_replicate_weight_output = get_only( - 
b.add_output_graph_node(replicate_weight_expr, {o_weights}, 1_n)); + OutputGraphExprValue o_replicate_weight_output = insert_replicate(b, degree, o_weights); - std::vector o_attention_inputs = { + std::unordered_map o_attention_inputs = { + { + TensorSlotName::QUERY, o_partition_query_input_output, + }, + { + TensorSlotName::KEY, o_partition_key_input_output, + }, + { + TensorSlotName::VALUE, o_partition_value_input_output, - o_replicate_weight_output}; + }, + { + TensorSlotName::WEIGHT, + o_replicate_weight_output, + }, + }; OutputOperatorAttrsAssignment attention_expr = OutputOperatorAttrsAssignment{ - b.pattern_node_named("attention"), + b.pattern_node_named(attention_name), {}, }; - OutputGraphExprValue o_attention_output = get_only( - b.add_output_graph_node(attention_expr, o_attention_inputs, 1_n)); + OutputGraphExprValue o_attention_output = insert_single_output_op(b, attention_expr, o_attention_inputs); - OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::COMBINE), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, - OperatorAttributeValue{ff_dim_t{ - 2_n, - }}), - }, - }; - OutputGraphExprValue o_combine_output = get_only( - b.add_output_graph_node(combine_expr, {o_attention_output}, 1_n)); + OutputGraphExprValue o_combine_output = insert_combine(b, degree, ff_dim_t{0_n}, o_attention_output); b.equate_outputs(p_attention_output, o_combine_output); @@ -531,8 +495,25 @@ Substitution create_replicate_attention_reduce(positive_int num_heads, b.add_input(tensor_attribute_pattern_match_all()); auto [p_weights, o_weights] = b.add_input(tensor_attribute_pattern_match_all()); - std::vector p_inputs = { - p_query_input, p_key_input, p_value_input, p_weights}; + + std::unordered_map p_inputs = { + { + TensorSlotName::QUERY, + p_query_input, + }, + { + TensorSlotName::KEY, + p_key_input, 
+ }, + { + TensorSlotName::VALUE, + p_value_input, + }, + { + TensorSlotName::WEIGHT, + p_weights, + }, + }; OperatorAttributePattern attention_pattern = OperatorAttributePattern{{ op_type_equals_constraint(OperatorType::MULTIHEAD_ATTENTION), @@ -540,67 +521,51 @@ Substitution create_replicate_attention_reduce(positive_int num_heads, op_attr_key_divisible_by(OperatorAttributeKey::NUM_HEADS, num_heads), }}; - PatternValue p_attention_output = - get_only(b.add_pattern_node(attention_pattern, - p_inputs, - {tensor_attr_pattern_require_num_dims(3_p)}, - "attention")); - - OutputOperatorAttrsAssignment replicate_input_expr = - OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::REPLICATE), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - }}; - - OutputGraphExprValue o_replicate_query_input_output = get_only( - b.add_output_graph_node(replicate_input_expr, {o_query_input}, 1_n)); + std::string attention_name = "attention"; + PatternValue p_attention_output = insert_single_output_pattern( + b, + attention_pattern, + p_inputs, + /*output_pattern=*/tensor_attr_pattern_require_num_dims(3_p), + attention_name); - OutputGraphExprValue o_replicate_key_input_output = get_only( - b.add_output_graph_node(replicate_input_expr, {o_key_input}, 1_n)); + OutputGraphExprValue o_replicate_query_input_output = + insert_replicate(b, degree, o_query_input); + + OutputGraphExprValue o_replicate_key_input_output = + insert_replicate(b, degree, o_key_input); - OutputGraphExprValue o_replicate_value_input_output = get_only( - b.add_output_graph_node(replicate_input_expr, {o_value_input}, 1_n)); + OutputGraphExprValue o_replicate_value_input_output = + insert_replicate(b, degree, o_value_input); - OutputOperatorAttrsAssignment partition_weight_expr = - OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::REPARTITION), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - 
OperatorAttributeValue{degree}), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, - OperatorAttributeValue{ff_dim_t{1_n}}), - }}; + OutputGraphExprValue o_partition_weight_output = insert_partition(b, degree, ff_dim_t{1_n}, o_weights); - OutputGraphExprValue o_partition_weight_output = get_only( - b.add_output_graph_node(partition_weight_expr, {o_weights}, 1_n)); - - std::vector o_attention_inputs = { + std::unordered_map o_attention_inputs = { + { + TensorSlotName::QUERY, o_replicate_query_input_output, + }, + { + TensorSlotName::KEY, o_replicate_key_input_output, + }, + { + TensorSlotName::VALUE, o_replicate_value_input_output, - o_partition_weight_output}; + }, + { + TensorSlotName::WEIGHT, + o_partition_weight_output, + }, + }; OutputOperatorAttrsAssignment attention_expr = OutputOperatorAttrsAssignment{ - b.pattern_node_named("attention"), + b.pattern_node_named(attention_name), {}, }; - OutputGraphExprValue o_attention_output = get_only( - b.add_output_graph_node(attention_expr, o_attention_inputs, 1_n)); + OutputGraphExprValue o_attention_output = insert_single_output_op(b, attention_expr, o_attention_inputs); - OutputOperatorAttrsAssignment reduce_expr = OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::REDUCTION), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - }, - }; - OutputGraphExprValue o_reduce_output = - get_only(b.add_output_graph_node(reduce_expr, {o_attention_output}, 1_n)); + OutputGraphExprValue o_reduce_output = insert_reduce(b, degree, o_attention_output); b.equate_outputs(p_attention_output, o_reduce_output); @@ -610,16 +575,17 @@ Substitution create_replicate_attention_reduce(positive_int num_heads, Substitution create_partition_softmax_combine(ff_dim_t softmax_dim, ff_dim_t partition_dim, positive_int degree) { - if (partition_dim == softmax_dim) { - throw mk_runtime_error( - fmt::format("partition dim {} must not be equal to softmax dim {}", - 
partition_dim, - softmax_dim)); - } + ASSERT(partition_dim != softmax_dim); + SubstitutionBuilder b; auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all()); - std::vector p_inputs = {p_input}; + std::unordered_map p_inputs = { + { + TensorSlotName::INPUT, + p_input, + }, + }; OperatorAttributePattern softmax_pattern = OperatorAttributePattern{{ op_type_equals_constraint(OperatorType::SOFTMAX), @@ -628,48 +594,30 @@ Substitution create_partition_softmax_combine(ff_dim_t softmax_dim, positive_int{softmax_dim.value}), }}; - PatternValue p_softmax_output = - get_only(b.add_pattern_node(softmax_pattern, - p_inputs, - {tensor_attribute_pattern_match_all()}, - "softmax")); - - OutputOperatorAttrsAssignment partition_input_expr = - OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::REPARTITION), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, - OperatorAttributeValue{partition_dim}), - }}; - - OutputGraphExprValue o_partition_input_output = - get_only(b.add_output_graph_node(partition_input_expr, {o_input}, 1_n)); - - std::vector o_softmax_inputs = { - o_partition_input_output}; + std::string softmax_name = "softmax"; + PatternValue p_softmax_output = insert_single_output_pattern( + b, + softmax_pattern, + p_inputs, + /*output_pattern=*/tensor_attribute_pattern_match_all(), + softmax_name); + + OutputGraphExprValue o_partition_input_output = insert_partition(b, degree, partition_dim, o_input); + + std::unordered_map o_softmax_inputs = { + { + TensorSlotName::INPUT, + o_partition_input_output, + }, + }; OutputOperatorAttrsAssignment softmax_expr = OutputOperatorAttrsAssignment{ - b.pattern_node_named("softmax"), + b.pattern_node_named(softmax_name), {}, }; - OutputGraphExprValue o_softmax_output = - get_only(b.add_output_graph_node(softmax_expr, o_softmax_inputs, 1_n)); + OutputGraphExprValue o_softmax_output = 
insert_single_output_op(b, softmax_expr, o_softmax_inputs); - OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::COMBINE), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, - OperatorAttributeValue{partition_dim}), - }, - }; - OutputGraphExprValue o_combine_output = - get_only(b.add_output_graph_node(combine_expr, {o_softmax_output}, 1_n)); + OutputGraphExprValue o_combine_output = insert_combine(b, degree, partition_dim, o_softmax_output); b.equate_outputs(p_softmax_output, o_combine_output); @@ -682,55 +630,52 @@ Substitution create_partition_add_combine(ff_dim_t parallel_dim, auto [p_input1, o_input1] = b.add_input(tensor_attribute_pattern_match_all()); auto [p_input2, o_input2] = b.add_input(tensor_attribute_pattern_match_all()); - std::vector p_inputs = {p_input1, p_input2}; + + std::unordered_map p_inputs = { + { + TensorSlotName::LHS_INPUT, + p_input1, + }, + { + TensorSlotName::RHS_INPUT, + p_input2, + }, + }; OperatorAttributePattern add_pattern = OperatorAttributePattern{{ op_type_equals_constraint(OperatorType::EW_ADD), op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), }}; - PatternValue p_add_output = get_only(b.add_pattern_node( - add_pattern, p_inputs, {tensor_attribute_pattern_match_all()}, "add")); - - OutputOperatorAttrsAssignment partition_input_expr = - OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::REPARTITION), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, - OperatorAttributeValue{parallel_dim}), - }}; - - OutputGraphExprValue o_partition_input1_output = - get_only(b.add_output_graph_node(partition_input_expr, {o_input1}, 1_n)); - - OutputGraphExprValue o_partition_input2_output = - 
get_only(b.add_output_graph_node(partition_input_expr, {o_input2}, 1_n)); - - std::vector o_add_inputs = {o_partition_input1_output, - o_partition_input2_output}; + std::string add_name = "add"; + PatternValue p_add_output = insert_single_output_pattern( + b, + add_pattern, + p_inputs, + /*output_pattern=*/tensor_attribute_pattern_match_all(), + add_name); + + OutputGraphExprValue o_partition_input1_output = insert_partition(b, degree, parallel_dim, o_input1); + OutputGraphExprValue o_partition_input2_output = insert_partition(b, degree, parallel_dim, o_input2); + + std::unordered_map o_add_inputs = { + { + TensorSlotName::LHS_INPUT, + o_partition_input1_output, + }, + { + TensorSlotName::RHS_INPUT, + o_partition_input2_output, + }, + }; OutputOperatorAttrsAssignment add_expr = OutputOperatorAttrsAssignment{ - b.pattern_node_named("add"), + b.pattern_node_named(add_name), {}, }; - OutputGraphExprValue o_add_output = - get_only(b.add_output_graph_node(add_expr, o_add_inputs, 1_n)); + OutputGraphExprValue o_add_output = insert_single_output_op(b, add_expr, o_add_inputs); - OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::COMBINE), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, - OperatorAttributeValue{parallel_dim}), - }, - }; - OutputGraphExprValue o_combine_output = - get_only(b.add_output_graph_node(combine_expr, {o_add_output}, 1_n)); + OutputGraphExprValue o_combine_output = insert_combine(b, degree, parallel_dim, o_add_output); b.equate_outputs(p_add_output, o_combine_output); @@ -748,42 +693,24 @@ Substitution create_partition_relu_combine(ff_dim_t parallel_dim, op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree), }}; - PatternValue p_relu_output = get_only(b.add_pattern_node( - relu_pattern, {p_input}, {tensor_attribute_pattern_match_all()}, "relu")); + 
std::string relu_name = "relu"; + PatternValue p_relu_output = insert_single_output_pattern( + b, + relu_pattern, + {{TensorSlotName::INPUT, p_input}}, + /*output_pattern=*/tensor_attribute_pattern_match_all(), + relu_name); - OutputOperatorAttrsAssignment partition_input_expr = - OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::REPARTITION), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, - OperatorAttributeValue{parallel_dim}), - }}; - - OutputGraphExprValue o_partition_input_output = - get_only(b.add_output_graph_node(partition_input_expr, {o_input}, 1_n)); + OutputGraphExprValue o_partition_input_output = insert_partition(b, degree, parallel_dim, o_input); OutputOperatorAttrsAssignment relu_expr = OutputOperatorAttrsAssignment{ - b.pattern_node_named("relu"), + b.pattern_node_named(relu_name), {}, }; - OutputGraphExprValue o_relu_output = get_only( - b.add_output_graph_node(relu_expr, {o_partition_input_output}, 1_n)); + OutputGraphExprValue o_relu_output + = insert_single_output_op(b, relu_expr, {{TensorSlotName::INPUT, o_partition_input_output}}); - OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(OperatorType::COMBINE), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, - OperatorAttributeValue{parallel_dim}), - }, - }; - OutputGraphExprValue o_combine_output = - get_only(b.add_output_graph_node(combine_expr, {o_relu_output}, 1_n)); + OutputGraphExprValue o_combine_output = insert_combine(b, degree, parallel_dim, o_relu_output); b.equate_outputs(p_relu_output, o_combine_output); @@ -804,78 +731,62 @@ Substitution create_fuse_linear_activation(Activation activation) { OperatorAttributeKey::ACTIVATION, 
OperatorAttributeValue{std::optional{std::nullopt}}), }}; - PatternValue p_mm_output = - require_only_key(b.add_pattern_node( - /*node_expr=*/mm_pattern, - /*inputs=*/ - { - { - TensorSlotName::INPUT, - p_input, - }, - { - TensorSlotName::WEIGHT, - p_weight, - }, - }, - /*output_patterns=*/ - { - { - TensorSlotName::OUTPUT, - tensor_attribute_pattern_match_all(), - }, - }, - /*name=*/"mm"), - TensorSlotName::OUTPUT); + + std::string mm_name = "mm"; + PatternValue p_mm_output = insert_single_output_pattern( + b, + mm_pattern, + /*inputs=*/{ + { + TensorSlotName::INPUT, + p_input, + }, + { + TensorSlotName::WEIGHT, + p_weight, + }, + }, + /*output_pattern=*/tensor_attribute_pattern_match_all(), + mm_name); OperatorAttributePattern relu_pattern = OperatorAttributePattern{{ op_type_equals_constraint(OperatorType::RELU), }}; - PatternValue p_relu_output = - require_only_key(b.add_pattern_node( - /*node_expr=*/relu_pattern, - /*inputs=*/ - { - { - TensorSlotName::INPUT, - p_mm_output, - }, - }, - /*output_patterns=*/ - { - { - TensorSlotName::OUTPUT, - tensor_attribute_pattern_match_all(), - }, - }, - /*name=*/"relu"), - TensorSlotName::OUTPUT); + + std::string relu_name = "relu"; + PatternValue p_relu_output = insert_single_output_pattern( + b, + relu_pattern, + /*inputs=*/{ + { + TensorSlotName::INPUT, + p_mm_output, + }, + }, + /*output_pattern=*/tensor_attribute_pattern_match_all(), + relu_name); OutputOperatorAttrsAssignment fused_node_expr = OutputOperatorAttrsAssignment{ - b.pattern_node_named("mm"), + b.pattern_node_named(mm_name), { set_attr_to_constant(OperatorAttributeKey::ACTIVATION, OperatorAttributeValue{activation}), }}; + OutputGraphExprValue o_fused_node_output = - require_only_key(b.add_output_graph_node( - /*node_expr=*/fused_node_expr, - /*inputs=*/ - { - { - TensorSlotName::INPUT, - o_input, - }, - { - TensorSlotName::WEIGHT, - o_weight, - }, - }, - /*output_slots=*/ - { - TensorSlotName::OUTPUT, - }), - TensorSlotName::OUTPUT); + 
insert_single_output_op( + b, + fused_node_expr, + /*inputs=*/{ + { + TensorSlotName::INPUT, + o_input, + }, + { + TensorSlotName::WEIGHT, + o_weight, + }, + }); b.equate_outputs(p_relu_output, o_fused_node_output); diff --git a/lib/substitutions/test/src/substitutions/unity_substitution_set.cc b/lib/substitutions/test/src/substitutions/unity_substitution_set.cc index 8806bdb60e..022b69b850 100644 --- a/lib/substitutions/test/src/substitutions/unity_substitution_set.cc +++ b/lib/substitutions/test/src/substitutions/unity_substitution_set.cc @@ -18,6 +18,7 @@ #include "substitutions/sub_parallel_computation_graph.h" #include "substitutions/substitution_builder.h" #include "utils/containers/get_only.h" +#include "utils/containers/require_only_key.h" #include using namespace ::FlexFlow; @@ -32,6 +33,176 @@ static ParallelLayerAttrs make_layer_attrs( }; }; +parallel_tensor_guid_t get_single_output(ParallelLayerAddedResult const &added) { + return require_only_key(added.outputs, TensorSlotName::OUTPUT); +} + +parallel_tensor_guid_t add_single_output_layer( + ParallelComputationGraph &pcg, + ParallelLayerAttrs const &layer_attrs, + std::unordered_map const &inputs, + std::unordered_map const &weights, + std::optional> const + &outputs = std::nullopt) { + + return get_single_output(add_parallel_layer(pcg, layer_attrs, inputs, weights, outputs)); +} + +parallel_tensor_guid_t add_input_layer( + ParallelComputationGraph &pcg, + TensorShape const &tensor_shape) { + + return get_single_output(pcg_add_input_layer(pcg, tensor_shape)); +} + +parallel_tensor_guid_t add_weight_layer( + ParallelComputationGraph &pcg, + TensorShape const &tensor_shape) { + + WeightAttrs weight_attrs = WeightAttrs{ + /*tensor_shape=*/tensor_shape, + /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, + }; + + return add_single_output_layer(pcg, make_layer_attrs(weight_attrs), {}, {}); +} + +parallel_tensor_guid_t add_replicate_layer( + ParallelComputationGraph &pcg, + positive_int degree, + 
parallel_tensor_guid_t const &t_input) { + + ReplicateAttrs replicate_attrs = ReplicateAttrs{ + /*replicate_degree=*/degree, + }; + + return add_single_output_layer(pcg, + make_layer_attrs(replicate_attrs), + {{TensorSlotName::INPUT, t_input}}, + {}); +} + +parallel_tensor_guid_t add_reduction_layer( + ParallelComputationGraph &pcg, + positive_int degree, + parallel_tensor_guid_t const &t_input) { + + ReductionAttrs reduction_attrs = ReductionAttrs{ + /*reduction_degree=*/degree, + }; + + return add_single_output_layer(pcg, + make_layer_attrs(reduction_attrs), + {{TensorSlotName::INPUT, t_input}}, + {}); +} + + +parallel_tensor_guid_t add_partition_layer( + ParallelComputationGraph &pcg, + ff_dim_t dim, + positive_int degree, + parallel_tensor_guid_t const &t_input) { + + RepartitionAttrs partition_attrs = RepartitionAttrs{ + /*repartition_dim=*/dim, + /*repartition_degree=*/degree, + }; + + return add_single_output_layer(pcg, + make_layer_attrs(partition_attrs), + {{TensorSlotName::INPUT, t_input}}, + {}); +} + +parallel_tensor_guid_t add_combine_layer( + ParallelComputationGraph &pcg, + ff_dim_t dim, + positive_int degree, + parallel_tensor_guid_t const &t_input) { + + CombineAttrs partition_attrs = CombineAttrs{ + /*combine_dim=*/dim, + /*combine_degree=*/degree, + }; + + return add_single_output_layer(pcg, + make_layer_attrs(partition_attrs), + {{TensorSlotName::INPUT, t_input}}, + {}); +} + +parallel_tensor_guid_t add_linear_layer( + ParallelComputationGraph &pcg, + LinearAttrs const &linear_attrs, + parallel_tensor_guid_t const &t_input, + parallel_tensor_guid_t const &t_weight, + std::optional const &t_bias = std::nullopt, + std::optional const &name = std::nullopt) { + + ASSERT(t_bias.has_value() == linear_attrs.use_bias); + + std::unordered_map weights = { + {TensorSlotName::WEIGHT, t_weight}, + }; + + if (t_bias.has_value()) { + weights.insert({TensorSlotName::BIAS, t_bias.value()}); + } + + return add_single_output_layer(pcg, + 
make_layer_attrs(linear_attrs, name), + {{TensorSlotName::INPUT, t_input}}, + weights); +} + +parallel_tensor_guid_t add_attention_layer( + ParallelComputationGraph &pcg, + MultiHeadAttentionAttrs const &attn_attrs, + parallel_tensor_guid_t const &t_query, + parallel_tensor_guid_t const &t_key, + parallel_tensor_guid_t const &t_value, + parallel_tensor_guid_t const &t_weights, + std::optional const &name = std::nullopt) { + + return add_single_output_layer(pcg, + make_layer_attrs(attn_attrs, name), + { + {TensorSlotName::QUERY, t_query}, + {TensorSlotName::KEY, t_key}, + {TensorSlotName::VALUE, t_value}, + }, + {{TensorSlotName::WEIGHT, t_weights}}); +} + + + +parallel_tensor_guid_t add_conv2d_layer( + ParallelComputationGraph &pcg, + Conv2DAttrs const &conv2d_attrs, + parallel_tensor_guid_t const &t_input, + parallel_tensor_guid_t const &t_filter, + std::optional const &bias = std::nullopt, + std::optional const &name = std::nullopt) { + + ASSERT(bias.has_value() == conv2d_attrs.use_bias); + + std::unordered_map weights = { + {TensorSlotName::FILTER, t_filter}, + }; + + if (bias.has_value()) { + weights.insert({TensorSlotName::BIAS, bias.value()}); + } + + return add_single_output_layer(pcg, + make_layer_attrs(conv2d_attrs, name), + {{TensorSlotName::INPUT, t_input}}, + weights); +} + + + TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_substitution_set") { MachineComputeSpecification machine_spec = MachineComputeSpecification{ @@ -47,7 +218,7 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("create_replicate_linear_combine, use_bias = false") { positive_int num_dims = 1_p; - positive_int degree = 1_p; + positive_int degree = 2_p; std::string linear_match = "linear_match"; Substitution sub = create_replicate_linear_combine(num_dims, degree, false); @@ -74,41 +245,26 @@ TEST_SUITE(FF_TEST_SUITE) { /*replicate_degree=*/degree, }; - WeightAttrs projection_weight_attrs = WeightAttrs{ - /*tensor_shape=*/throw_if_unexpected( - get_projection_shape(linear_attrs, input_shape)), - 
/*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, - }; + TensorShape projection_weight_shape = throw_if_unexpected( + get_projection_shape(linear_attrs, input_shape)); RepartitionAttrs partition_projection_attrs = RepartitionAttrs{ /*repartition_dim=*/ff_dim_t{1_n}, /*repartition_degree=*/degree, }; - CombineAttrs combine_op_attrs = CombineAttrs{ - /*combine_dim=*/ff_dim_t{ - nonnegative_int{num_dims.int_from_positive_int() - 1}}, - /*combine_degree=*/degree, - }; + ff_dim_t combine_dim = ff_dim_t{ + nonnegative_int{num_dims.int_from_positive_int() - 1}}; SubParallelComputationGraph original_pcg = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - ParallelLayerAddedResult input_added = - pcg_add_input_layer(pcg, input_shape); - - parallel_tensor_guid_t t_input = get_only(input_added.outputs); + parallel_tensor_guid_t t_input = + add_input_layer(pcg, input_shape); - ParallelLayerAddedResult projection_weight_added = add_parallel_layer( - pcg, make_layer_attrs(projection_weight_attrs), {}, {}); - parallel_tensor_guid_t t_projection_weight = - get_only(projection_weight_added.outputs); + parallel_tensor_guid_t t_projection_weight = add_weight_layer(pcg, projection_weight_shape); - ParallelLayerAddedResult linear_added = - add_parallel_layer(pcg, - make_layer_attrs(linear_attrs, linear_match), - {t_input}, - {t_projection_weight}); + parallel_tensor_guid_t t_linear = add_linear_layer(pcg, linear_attrs, t_input, t_projection_weight, /*bias=*/std::nullopt, linear_match); return sub_pcg_from_full_pcg(pcg); }(); @@ -117,9 +273,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t match_layer = get_parallel_layer_by_name(original_pcg, linear_match); open_parallel_tensor_guid_t match_layer_input_activations = - get_layer_inputs(original_pcg, match_layer).at(0); + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::INPUT); open_parallel_tensor_guid_t match_layer_input_weights = - get_layer_inputs(original_pcg, match_layer).at(1); + 
get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::WEIGHT); return PCGPatternMatch{ bidict{ @@ -127,11 +283,11 @@ TEST_SUITE(FF_TEST_SUITE) { }, std::unordered_map{ { - PatternInput{DataflowGraphInput{0}}, + PatternInput{KwargDataflowGraphInput{0}}, match_layer_input_activations, }, { - PatternInput{DataflowGraphInput{2}}, + PatternInput{KwargDataflowGraphInput{2}}, match_layer_input_weights, }}, }; @@ -143,40 +299,16 @@ TEST_SUITE(FF_TEST_SUITE) { SubParallelComputationGraph correct = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - ParallelLayerAddedResult input_added = - pcg_add_input_layer(pcg, input_shape); - - parallel_tensor_guid_t t_input = get_only(input_added.outputs); + parallel_tensor_guid_t t_replicated_input = + add_replicate_layer(pcg, degree, add_input_layer(pcg, input_shape)); - ParallelLayerAddedResult replicate_input_added = add_parallel_layer( - pcg, make_layer_attrs(replicate_input_attrs), {t_input}, {}); - parallel_tensor_guid_t t_replicated_input = - get_only(replicate_input_added.outputs); - - ParallelLayerAddedResult projection_weight_added = add_parallel_layer( - pcg, make_layer_attrs(projection_weight_attrs), {}, {}); - parallel_tensor_guid_t t_projection_weight = - get_only(projection_weight_added.outputs); + parallel_tensor_guid_t t_partitioned_projection_weight = + add_partition_layer(pcg, ff_dim_t{1_n}, degree, add_weight_layer(pcg, projection_weight_shape)); - ParallelLayerAddedResult partition_projection_added = - add_parallel_layer(pcg, - make_layer_attrs(partition_projection_attrs), - {t_projection_weight}, - {}); - parallel_tensor_guid_t t_partitioned_projection_weight = - get_only(partition_projection_added.outputs); - - ParallelLayerAddedResult replicate_linear_added = - add_parallel_layer(pcg, - make_layer_attrs(linear_attrs), - {t_replicated_input}, - {t_partitioned_projection_weight}); parallel_tensor_guid_t t_replicated_linear = - get_only(replicate_linear_added.outputs); + 
add_linear_layer(pcg, linear_attrs, t_replicated_input, t_partitioned_projection_weight); - ParallelLayerAddedResult combine_added = add_parallel_layer( - pcg, make_layer_attrs(combine_op_attrs), {t_replicated_linear}, {}); - parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + parallel_tensor_guid_t t_combine = add_combine_layer(pcg, combine_dim, degree, t_replicated_linear); return sub_pcg_from_full_pcg(pcg); }(); @@ -186,7 +318,7 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("create_replicate_linear_combine, use_bias = true") { positive_int num_dims = 1_p; - positive_int degree = 1_p; + positive_int degree = 2_p; std::string linear_match = "linear_match"; Substitution sub = create_replicate_linear_combine(num_dims, degree, true); @@ -209,55 +341,26 @@ TEST_SUITE(FF_TEST_SUITE) { /*regularizer=*/std::nullopt, }; - ReplicateAttrs replicate_input_attrs = ReplicateAttrs{ - /*replicate_degree=*/degree, - }; - - WeightAttrs projection_weight_attrs = WeightAttrs{ - /*tensor_shape=*/throw_if_unexpected( - get_projection_shape(linear_attrs, input_shape)), - /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, - }; + TensorShape projection_weight_shape = throw_if_unexpected(get_projection_shape(linear_attrs, input_shape)); - WeightAttrs bias_attrs = WeightAttrs{ - /*tensor_shape=*/throw_if_unexpected( - get_bias_shape(linear_attrs, input_shape)), - /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, - }; - - RepartitionAttrs partition_projection_attrs = RepartitionAttrs{ - /*repartition_dim=*/ff_dim_t{1_n}, - /*repartition_degree=*/degree, - }; + TensorShape bias_shape = throw_if_unexpected(get_bias_shape(linear_attrs, input_shape)); - CombineAttrs combine_op_attrs = CombineAttrs{ - /*combine_dim=*/ff_dim_t{ - nonnegative_int{num_dims.int_from_positive_int() - 1}}, - /*combine_degree=*/degree, - }; + ff_dim_t combine_dim = ff_dim_t{nonnegative_int{num_dims.int_from_positive_int() - 1}}; SubParallelComputationGraph original_pcg = [&] {
ParallelComputationGraph pcg = empty_parallel_computation_graph(); - ParallelLayerAddedResult input_added = - pcg_add_input_layer(pcg, input_shape); - - parallel_tensor_guid_t t_input = get_only(input_added.outputs); + parallel_tensor_guid_t t_input = + add_input_layer(pcg, input_shape); - ParallelLayerAddedResult projection_weight_added = add_parallel_layer( - pcg, make_layer_attrs(projection_weight_attrs), {}, {}); parallel_tensor_guid_t t_projection_weight = - get_only(projection_weight_added.outputs); + add_weight_layer(pcg, projection_weight_shape); - ParallelLayerAddedResult bias_added = - add_parallel_layer(pcg, make_layer_attrs(bias_attrs), {}, {}); - parallel_tensor_guid_t t_bias = get_only(bias_added.outputs); + parallel_tensor_guid_t t_bias = + add_weight_layer(pcg, bias_shape); - ParallelLayerAddedResult linear_added = - add_parallel_layer(pcg, - make_layer_attrs(linear_attrs, linear_match), - {t_input}, - {t_projection_weight, t_bias}); + parallel_tensor_guid_t t_linear = + add_linear_layer(pcg, linear_attrs, t_input, t_projection_weight, t_bias, linear_match); return sub_pcg_from_full_pcg(pcg); }(); @@ -266,11 +369,11 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t match_layer = get_parallel_layer_by_name(original_pcg, linear_match); open_parallel_tensor_guid_t match_layer_input_activations = - get_layer_inputs(original_pcg, match_layer).at(0); + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::INPUT); open_parallel_tensor_guid_t match_layer_input_weights = - get_layer_inputs(original_pcg, match_layer).at(1); + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::WEIGHT); open_parallel_tensor_guid_t match_layer_input_bias = - get_layer_inputs(original_pcg, match_layer).at(2); + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::BIAS); return PCGPatternMatch{ bidict{ @@ -278,15 +381,15 @@ TEST_SUITE(FF_TEST_SUITE) { }, std::unordered_map{ { - PatternInput{DataflowGraphInput{0}}, + PatternInput{KwargDataflowGraphInput{0}},
match_layer_input_activations, }, { - PatternInput{DataflowGraphInput{2}}, + PatternInput{KwargDataflowGraphInput{2}}, match_layer_input_weights, }, { - PatternInput{DataflowGraphInput{4}}, + PatternInput{KwargDataflowGraphInput{4}}, match_layer_input_bias, }}, }; @@ -298,49 +401,20 @@ TEST_SUITE(FF_TEST_SUITE) { SubParallelComputationGraph correct = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - ParallelLayerAddedResult input_added = - pcg_add_input_layer(pcg, input_shape); - - parallel_tensor_guid_t t_input = get_only(input_added.outputs); - - ParallelLayerAddedResult replicate_input_added = add_parallel_layer( - pcg, make_layer_attrs(replicate_input_attrs), {t_input}, {}); parallel_tensor_guid_t t_replicated_input = - get_only(replicate_input_added.outputs); - - ParallelLayerAddedResult projection_weight_added = add_parallel_layer( - pcg, make_layer_attrs(projection_weight_attrs), {}, {}); - parallel_tensor_guid_t t_projection_weight = - get_only(projection_weight_added.outputs); + add_replicate_layer(pcg, degree, add_input_layer(pcg, input_shape)); - ParallelLayerAddedResult partition_projection_added = - add_parallel_layer(pcg, - make_layer_attrs(partition_projection_attrs), - {t_projection_weight}, - {}); parallel_tensor_guid_t t_partitioned_projection_weight = - get_only(partition_projection_added.outputs); - - ParallelLayerAddedResult bias_added = - add_parallel_layer(pcg, make_layer_attrs(bias_attrs), {}, {}); - parallel_tensor_guid_t t_bias = get_only(bias_added.outputs); - - ParallelLayerAddedResult partition_bias_added = add_parallel_layer( - pcg, make_layer_attrs(partition_projection_attrs), {t_bias}, {}); - parallel_tensor_guid_t t_partitioned_bias = - get_only(partition_bias_added.outputs); - - ParallelLayerAddedResult replicate_linear_added = add_parallel_layer( - pcg, - make_layer_attrs(linear_attrs), - {t_replicated_input}, - {t_partitioned_projection_weight, t_partitioned_bias}); - parallel_tensor_guid_t 
t_replicated_linear = - get_only(replicate_linear_added.outputs); + add_partition_layer(pcg, ff_dim_t{1_n}, degree, add_weight_layer(pcg, projection_weight_shape)); + + parallel_tensor_guid_t t_partitioned_bias = + add_partition_layer(pcg, ff_dim_t{1_n}, degree, add_weight_layer(pcg, bias_shape)); + + parallel_tensor_guid_t t_replicated_linear = + add_linear_layer(pcg, linear_attrs, t_replicated_input, t_partitioned_projection_weight, t_partitioned_bias); - ParallelLayerAddedResult combine_added = add_parallel_layer( - pcg, make_layer_attrs(combine_op_attrs), {t_replicated_linear}, {}); - parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + parallel_tensor_guid_t t_combine = + add_combine_layer(pcg, combine_dim, degree, t_replicated_linear); return sub_pcg_from_full_pcg(pcg); }(); @@ -373,45 +447,21 @@ TEST_SUITE(FF_TEST_SUITE) { /*regularizer=*/std::nullopt, }; - RepartitionAttrs partition_input_attrs = RepartitionAttrs{ - /*repartition_dim=*/ff_dim_t{0_n}, - /*repartition_degree=*/degree, - }; - - WeightAttrs projection_weight_attrs = WeightAttrs{ - /*tensor_shape=*/throw_if_unexpected( - get_projection_shape(linear_attrs, input_shape)), - /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, - }; + TensorShape projection_weight_shape = throw_if_unexpected(get_projection_shape(linear_attrs, input_shape)); - ReplicateAttrs replicate_projection_attrs = ReplicateAttrs{ - /*replicate_degree=*/degree, - }; - - CombineAttrs combine_op_attrs = CombineAttrs{ - /*combine_dim=*/ff_dim_t{ - nonnegative_int{num_dims.int_from_positive_int() - 1}}, - /*combine_degree=*/degree, - }; + ff_dim_t combine_dim = ff_dim_t{nonnegative_int{num_dims.int_from_positive_int() - 1}}; SubParallelComputationGraph original_pcg = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - ParallelLayerAddedResult input_added = - pcg_add_input_layer(pcg, input_shape); - - parallel_tensor_guid_t t_input = get_only(input_added.outputs); + parallel_tensor_guid_t
t_input = + add_input_layer(pcg, input_shape); - ParallelLayerAddedResult projection_weight_added = add_parallel_layer( - pcg, make_layer_attrs(projection_weight_attrs), {}, {}); - parallel_tensor_guid_t t_projection_weight = - get_only(projection_weight_added.outputs); + parallel_tensor_guid_t t_projection_weight = + add_weight_layer(pcg, projection_weight_shape); - ParallelLayerAddedResult linear_added = - add_parallel_layer(pcg, - make_layer_attrs(linear_attrs, linear_match), - {t_input}, - {t_projection_weight}); + parallel_tensor_guid_t t_linear = + add_linear_layer(pcg, linear_attrs, t_input, t_projection_weight, /*bias=*/std::nullopt, linear_match); return sub_pcg_from_full_pcg(pcg); }(); @@ -420,9 +470,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t match_layer = get_parallel_layer_by_name(original_pcg, linear_match); open_parallel_tensor_guid_t match_layer_input_activations = - get_layer_inputs(original_pcg, match_layer).at(0); + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::INPUT); open_parallel_tensor_guid_t match_layer_input_weights = - get_layer_inputs(original_pcg, match_layer).at(1); + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::WEIGHT); return PCGPatternMatch{ bidict{ @@ -430,11 +480,11 @@ TEST_SUITE(FF_TEST_SUITE) { }, std::unordered_map{ { - PatternInput{DataflowGraphInput{0}}, + PatternInput{KwargDataflowGraphInput{0}}, match_layer_input_activations, }, { - PatternInput{DataflowGraphInput{2}}, + PatternInput{KwargDataflowGraphInput{2}}, match_layer_input_weights, }}, }; @@ -446,40 +496,18 @@ TEST_SUITE(FF_TEST_SUITE) { SubParallelComputationGraph correct = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - ParallelLayerAddedResult input_added = - pcg_add_input_layer(pcg, input_shape); - - parallel_tensor_guid_t t_input = get_only(input_added.outputs); + parallel_tensor_guid_t t_input = + add_input_layer(pcg, input_shape); - ParallelLayerAddedResult partition_input_added = 
add_parallel_layer( - pcg, make_layer_attrs(partition_input_attrs), {t_input}, {}); - parallel_tensor_guid_t t_partitioned_input = - get_only(partition_input_added.outputs); + parallel_tensor_guid_t t_partitioned_input = add_partition_layer(pcg, ff_dim_t{0_n}, degree, t_input); - ParallelLayerAddedResult projection_weight_added = add_parallel_layer( - pcg, make_layer_attrs(projection_weight_attrs), {}, {}); - parallel_tensor_guid_t t_projection_weight = - get_only(projection_weight_added.outputs); + parallel_tensor_guid_t t_replicated_projection_weight = + add_replicate_layer(pcg, degree, add_weight_layer(pcg, projection_weight_shape)); - ParallelLayerAddedResult replicate_projection_added = - add_parallel_layer(pcg, - make_layer_attrs(replicate_projection_attrs), - {t_projection_weight}, - {}); - parallel_tensor_guid_t t_replicated_projection_weight = - get_only(replicate_projection_added.outputs); - - ParallelLayerAddedResult partition_linear_added = - add_parallel_layer(pcg, - make_layer_attrs(linear_attrs), - {t_partitioned_input}, - {t_replicated_projection_weight}); parallel_tensor_guid_t t_partitioned_linear = - get_only(partition_linear_added.outputs); + add_linear_layer(pcg, linear_attrs, t_partitioned_input, t_replicated_projection_weight); - ParallelLayerAddedResult combine_added = add_parallel_layer( - pcg, make_layer_attrs(combine_op_attrs), {t_partitioned_linear}, {}); - parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + parallel_tensor_guid_t t_combine = add_combine_layer(pcg, combine_dim, degree, t_partitioned_linear); return sub_pcg_from_full_pcg(pcg); }(); @@ -512,55 +540,20 @@ TEST_SUITE(FF_TEST_SUITE) { /*regularizer=*/std::nullopt, }; - RepartitionAttrs partition_input_attrs = RepartitionAttrs{ - /*repartition_dim=*/ff_dim_t{0_n}, - /*repartition_degree=*/degree, - }; - - WeightAttrs projection_weight_attrs = WeightAttrs{ - /*tensor_shape=*/throw_if_unexpected( - get_projection_shape(linear_attrs,
input_shape)), - /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, - }; - - WeightAttrs bias_attrs = WeightAttrs{ - /*tensor_shape=*/throw_if_unexpected( - get_bias_shape(linear_attrs, input_shape)), - /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, - }; + TensorShape projection_weight_shape = throw_if_unexpected(get_projection_shape(linear_attrs, input_shape)); - ReplicateAttrs replicate_projection_attrs = ReplicateAttrs{ - /*replicate_degree=*/degree, - }; + TensorShape bias_shape = throw_if_unexpected(get_bias_shape(linear_attrs, input_shape)); - CombineAttrs combine_op_attrs = CombineAttrs{ - /*combine_dim=*/ff_dim_t{ - nonnegative_int{num_dims.int_from_positive_int() - 1}}, - /*combine_degree=*/degree, - }; + ff_dim_t combine_dim = ff_dim_t{nonnegative_int{num_dims.int_from_positive_int() - 1}}; SubParallelComputationGraph original_pcg = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - ParallelLayerAddedResult input_added = - pcg_add_input_layer(pcg, input_shape); - - parallel_tensor_guid_t t_input = get_only(input_added.outputs); + parallel_tensor_guid_t t_input = add_input_layer(pcg, input_shape); + parallel_tensor_guid_t t_projection_weight = add_weight_layer(pcg, projection_weight_shape); + parallel_tensor_guid_t t_bias = add_weight_layer(pcg, bias_shape); - ParallelLayerAddedResult projection_weight_added = add_parallel_layer( - pcg, make_layer_attrs(projection_weight_attrs), {}, {}); - parallel_tensor_guid_t t_projection_weight = - get_only(projection_weight_added.outputs); - - ParallelLayerAddedResult bias_added = - add_parallel_layer(pcg, make_layer_attrs(bias_attrs), {}, {}); - parallel_tensor_guid_t t_bias = get_only(bias_added.outputs); - - ParallelLayerAddedResult linear_added = - add_parallel_layer(pcg, - make_layer_attrs(linear_attrs, linear_match), - {t_input}, - {t_projection_weight, t_bias}); + parallel_tensor_guid_t t_linear = add_linear_layer(pcg, linear_attrs, t_input, t_projection_weight, t_bias, 
linear_match); return sub_pcg_from_full_pcg(pcg); }(); @@ -568,13 +561,13 @@ TEST_SUITE(FF_TEST_SUITE) { PCGPatternMatch match = [&] { parallel_layer_guid_t match_layer = get_parallel_layer_by_name(original_pcg, linear_match); - std::cout << get_layer_inputs(original_pcg, match_layer) << std::endl; + open_parallel_tensor_guid_t match_layer_input_activations = - get_layer_inputs(original_pcg, match_layer).at(0); + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::INPUT); open_parallel_tensor_guid_t match_layer_input_weights = - get_layer_inputs(original_pcg, match_layer).at(1); + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::WEIGHT); open_parallel_tensor_guid_t match_layer_input_bias = - get_layer_inputs(original_pcg, match_layer).at(2); + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::BIAS); return PCGPatternMatch{ bidict{ @@ -582,69 +575,35 @@ TEST_SUITE(FF_TEST_SUITE) { }, std::unordered_map{ { - PatternInput{DataflowGraphInput{0}}, + PatternInput{KwargDataflowGraphInput{0}}, match_layer_input_activations, }, { - PatternInput{DataflowGraphInput{2}}, + PatternInput{KwargDataflowGraphInput{2}}, match_layer_input_weights, }, { - PatternInput{DataflowGraphInput{4}}, + PatternInput{KwargDataflowGraphInput{4}}, match_layer_input_bias, }}, }; }(); - SubParallelComputationGraph result = - apply_substitution(original_pcg, sub, match); + SubParallelComputationGraph result = apply_substitution(original_pcg, sub, match); SubParallelComputationGraph correct = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - ParallelLayerAddedResult input_added = - pcg_add_input_layer(pcg, input_shape); - - parallel_tensor_guid_t t_input = get_only(input_added.outputs); - - ParallelLayerAddedResult partition_input_added = add_parallel_layer( - pcg, make_layer_attrs(partition_input_attrs), {t_input}, {}); parallel_tensor_guid_t t_partitioned_input = - get_only(partition_input_added.outputs); + add_partition_layer(pcg, 
ff_dim_t{0_n}, degree, add_input_layer(pcg, input_shape)); - ParallelLayerAddedResult projection_weight_added = add_parallel_layer( - pcg, make_layer_attrs(projection_weight_attrs), {}, {}); - parallel_tensor_guid_t t_projection_weight = - get_only(projection_weight_added.outputs); + parallel_tensor_guid_t t_replicated_projection_weight = add_replicate_layer(pcg, degree, add_weight_layer(pcg, projection_weight_shape)); - ParallelLayerAddedResult replicate_projection_added = - add_parallel_layer(pcg, - make_layer_attrs(replicate_projection_attrs), - {t_projection_weight}, - {}); - parallel_tensor_guid_t t_replicated_projection_weight = - get_only(replicate_projection_added.outputs); - - ParallelLayerAddedResult bias_added = - add_parallel_layer(pcg, make_layer_attrs(bias_attrs), {}, {}); - parallel_tensor_guid_t t_bias = get_only(bias_added.outputs); - - ParallelLayerAddedResult replicate_bias_added = add_parallel_layer( - pcg, make_layer_attrs(replicate_projection_attrs), {t_bias}, {}); - parallel_tensor_guid_t t_replicated_bias = - get_only(replicate_bias_added.outputs); - - ParallelLayerAddedResult partition_linear_added = add_parallel_layer( - pcg, - make_layer_attrs(linear_attrs), - {t_partitioned_input}, - {t_replicated_projection_weight, t_replicated_bias}); - parallel_tensor_guid_t t_partitioned_linear = - get_only(partition_linear_added.outputs); + parallel_tensor_guid_t t_replicated_bias = add_replicate_layer(pcg, degree, add_weight_layer(pcg, bias_shape)); + + parallel_tensor_guid_t t_partitioned_linear = add_linear_layer(pcg, linear_attrs, t_partitioned_input, t_replicated_projection_weight, t_replicated_bias); - ParallelLayerAddedResult combine_added = add_parallel_layer( - pcg, make_layer_attrs(combine_op_attrs), {t_partitioned_linear}, {}); - parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + parallel_tensor_guid_t t_combine = add_combine_layer(pcg, combine_dim, degree, t_partitioned_linear); return sub_pcg_from_full_pcg(pcg); }(); 
@@ -661,7 +620,7 @@ TEST_SUITE(FF_TEST_SUITE) { nonnegative_int paddingH = 1_n; nonnegative_int paddingW = 0_n; positive_int num_dims = 4_p; - positive_int degree = 1_p; + positive_int degree = 2_p; std::string conv2d_match = "conv2d_match"; Substitution sub = create_partition_conv2d_combine(num_dims, degree); @@ -671,66 +630,40 @@ TEST_SUITE(FF_TEST_SUITE) { FFOrdered{ 12_p, 3_p, - 10_p, + 12_p, 10_p, }, }, DataType::FLOAT, }; - Conv2DAttrs conv2d_attrs = Conv2DAttrs{/*outChannels=*/outChannels, - /*kernelH=*/kernelH, - /*kernelW=*/kernelW, - /*strideH=*/strideH, - /*strideW=*/strideW, - /*paddingH=*/paddingH, - /*paddingW=*/paddingW, - /*groups=*/1_p, - /*activation=*/std::nullopt, - /*use_bias=*/false}; - - RepartitionAttrs partition_input_attrs = RepartitionAttrs{ - /*repartition_dim=*/ff_dim_t{0_n}, - /*repartition_degree=*/degree, - }; - - ReplicateAttrs replicate_weight_attrs = ReplicateAttrs{ - /*replicate_degree=*/degree, - }; - - CombineAttrs combine_attrs = CombineAttrs{ - /*combine_dim=*/ff_dim_t{ - nonnegative_int{num_dims.int_from_positive_int() - 1}}, - /*combine_degree=*/degree, + Conv2DAttrs conv2d_attrs = Conv2DAttrs{ + /*outChannels=*/outChannels, + /*kernelH=*/kernelH, + /*kernelW=*/kernelW, + /*strideH=*/strideH, + /*strideW=*/strideW, + /*paddingH=*/paddingH, + /*paddingW=*/paddingW, + /*groups=*/1_p, + /*activation=*/std::nullopt, + /*use_bias=*/false, }; SubParallelComputationGraph original_pcg = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - ParallelLayerAddedResult input_added = - pcg_add_input_layer(pcg, input_shape); - - parallel_tensor_guid_t t_input = get_only(input_added.outputs); + parallel_tensor_guid_t t_input = + add_input_layer(pcg, input_shape); TensorShape casted_input_shape = get_reduced_shape(get_parallel_tensor_shape(pcg, t_input)); - WeightAttrs projection_weight_attrs = WeightAttrs{ - /*tensor_shape=*/ - get_weight_shapes(conv2d_attrs, casted_input_shape).at(0), - 
/*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, - }; + TensorShape projection_weight_shape = get_weight_shapes(conv2d_attrs, casted_input_shape).at(TensorSlotName::FILTER); - ParallelLayerAddedResult projection_weight_added = add_parallel_layer( - pcg, make_layer_attrs(projection_weight_attrs), {}, {}); - parallel_tensor_guid_t t_projection_weight = - get_only(projection_weight_added.outputs); + parallel_tensor_guid_t t_projection_weight = add_weight_layer(pcg, projection_weight_shape); - ParallelLayerAddedResult conv_2d_added = - add_parallel_layer(pcg, - make_layer_attrs(conv2d_attrs, conv2d_match), - {t_input}, - {t_projection_weight}); + parallel_tensor_guid_t t_conv = add_conv2d_layer(pcg, conv2d_attrs, t_input, t_projection_weight, /*bias=*/std::nullopt, conv2d_match); return sub_pcg_from_full_pcg(pcg); }(); @@ -739,9 +672,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t match_layer = get_parallel_layer_by_name(original_pcg, conv2d_match); open_parallel_tensor_guid_t match_layer_input_activations = - get_layer_inputs(original_pcg, match_layer).at(0); + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::INPUT); open_parallel_tensor_guid_t match_layer_input_weights = - get_layer_inputs(original_pcg, match_layer).at(1); + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::FILTER); return PCGPatternMatch{ bidict{ @@ -749,11 +682,11 @@ TEST_SUITE(FF_TEST_SUITE) { }, std::unordered_map{ { - PatternInput{DataflowGraphInput{0}}, + PatternInput{KwargDataflowGraphInput{0}}, match_layer_input_activations, }, { - PatternInput{DataflowGraphInput{2}}, + PatternInput{KwargDataflowGraphInput{2}}, match_layer_input_weights, }}, }; @@ -765,45 +698,19 @@ TEST_SUITE(FF_TEST_SUITE) { SubParallelComputationGraph correct = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - ParallelLayerAddedResult input_added = - pcg_add_input_layer(pcg, input_shape); - - parallel_tensor_guid_t t_input = get_only(input_added.outputs); 
- - ParallelLayerAddedResult partition_input_added = add_parallel_layer( - pcg, make_layer_attrs(partition_input_attrs), {t_input}, {}); - parallel_tensor_guid_t t_partitioned_input = - get_only(partition_input_added.outputs); + parallel_tensor_guid_t t_input = add_input_layer(pcg, input_shape); + parallel_tensor_guid_t t_partitioned_input = add_partition_layer(pcg, ff_dim_t{0_n}, degree, t_input); TensorShape casted_input_shape = get_reduced_shape(get_parallel_tensor_shape(pcg, t_input)); - WeightAttrs weight_attrs = WeightAttrs{ - /*tensor_shape=*/ - get_weight_shapes(conv2d_attrs, casted_input_shape).at(0), - /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, - }; - - ParallelLayerAddedResult weight_added = - add_parallel_layer(pcg, make_layer_attrs(weight_attrs), {}, {}); - parallel_tensor_guid_t t_weight = get_only(weight_added.outputs); + TensorShape weight_shape = get_weight_shapes(conv2d_attrs, casted_input_shape).at(TensorSlotName::FILTER); - ParallelLayerAddedResult replicate_weight_added = add_parallel_layer( - pcg, make_layer_attrs(replicate_weight_attrs), {t_weight}, {}); - parallel_tensor_guid_t t_replicated_weight = - get_only(replicate_weight_added.outputs); + parallel_tensor_guid_t t_replicated_weight = add_replicate_layer(pcg, degree, add_weight_layer(pcg, weight_shape)); - ParallelLayerAddedResult partition_conv2d_added = - add_parallel_layer(pcg, - make_layer_attrs(conv2d_attrs), - {t_partitioned_input}, - {t_replicated_weight}); - parallel_tensor_guid_t t_partitioned_conv2d = - get_only(partition_conv2d_added.outputs); + parallel_tensor_guid_t t_partitioned_conv2d = add_conv2d_layer(pcg, conv2d_attrs, t_partitioned_input, t_replicated_weight); - ParallelLayerAddedResult combine_added = add_parallel_layer( - pcg, make_layer_attrs(combine_attrs), {t_partitioned_conv2d}, {}); - parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + parallel_tensor_guid_t t_combine = add_combine_layer(pcg, ff_dim_t{0_n}, degree, 
t_partitioned_conv2d); return sub_pcg_from_full_pcg(pcg); }(); @@ -814,7 +721,7 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("create_partition_attention_combine") { positive_int embed_dim = 8_p; positive_int num_heads = 6_p; - positive_int degree = 1_p; + positive_int degree = 2_p; std::string attention_match = "attention_match"; Substitution sub = create_partition_attention_combine(num_heads, degree); @@ -843,50 +750,18 @@ TEST_SUITE(FF_TEST_SUITE) { /*add_zero_attn=*/false, }; - RepartitionAttrs partition_input_attrs = RepartitionAttrs{ - /*repartition_dim=*/ff_dim_t{0_n}, - /*repartition_degree=*/degree, - }; - - WeightAttrs weight_attrs = WeightAttrs{ - /*tensor_shape=*/ - throw_if_unexpected(get_weights_shape( - attention_attrs, query_shape, key_shape, value_shape)), - /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, - }; - - ReplicateAttrs replicate_weight_attrs = ReplicateAttrs{ - /*replicate_degree=*/degree, - }; - - CombineAttrs combine_attrs = CombineAttrs{ - /*combine_dim=*/ff_dim_t{2_n}, - /*combine_degree=*/degree, - }; + TensorShape weights_shape = throw_if_unexpected(get_weights_shape(attention_attrs, query_shape, key_shape, value_shape)); SubParallelComputationGraph original_pcg = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - ParallelLayerAddedResult query_added = - pcg_add_input_layer(pcg, query_shape); - parallel_tensor_guid_t t_query = get_only(query_added.outputs); - - ParallelLayerAddedResult key_added = pcg_add_input_layer(pcg, key_shape); - parallel_tensor_guid_t t_key = get_only(key_added.outputs); - - ParallelLayerAddedResult value_added = - pcg_add_input_layer(pcg, value_shape); - parallel_tensor_guid_t t_value = get_only(value_added.outputs); + parallel_tensor_guid_t t_query = add_input_layer(pcg, query_shape); + parallel_tensor_guid_t t_key = add_input_layer(pcg, key_shape); + parallel_tensor_guid_t t_value = add_input_layer(pcg, value_shape); - ParallelLayerAddedResult weight_added = - 
add_parallel_layer(pcg, make_layer_attrs(weight_attrs), {}, {}); - parallel_tensor_guid_t t_weight = get_only(weight_added.outputs); + parallel_tensor_guid_t t_weights = add_weight_layer(pcg, weights_shape); - ParallelLayerAddedResult attention_added = - add_parallel_layer(pcg, - make_layer_attrs(attention_attrs, attention_match), - {t_query, t_key, t_value}, - {t_weight}); + parallel_tensor_guid_t t_attention = add_attention_layer(pcg, attention_attrs, t_query, t_key, t_value, t_weights, attention_match); return sub_pcg_from_full_pcg(pcg); }(); @@ -895,13 +770,13 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t match_layer = get_parallel_layer_by_name(original_pcg, attention_match); open_parallel_tensor_guid_t match_layer_query = - get_layer_inputs(original_pcg, match_layer).at(0); + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::QUERY); open_parallel_tensor_guid_t match_layer_key = - get_layer_inputs(original_pcg, match_layer).at(1); + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::KEY); open_parallel_tensor_guid_t match_layer_value = - get_layer_inputs(original_pcg, match_layer).at(2); + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::VALUE); open_parallel_tensor_guid_t match_layer_input_weights = - get_layer_inputs(original_pcg, match_layer).at(3); + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::WEIGHT); return PCGPatternMatch{ bidict{ @@ -909,19 +784,19 @@ TEST_SUITE(FF_TEST_SUITE) { }, std::unordered_map{ { - PatternInput{DataflowGraphInput{0}}, + PatternInput{KwargDataflowGraphInput{0}}, match_layer_query, }, { - PatternInput{DataflowGraphInput{2}}, + PatternInput{KwargDataflowGraphInput{2}}, match_layer_key, }, { - PatternInput{DataflowGraphInput{4}}, + PatternInput{KwargDataflowGraphInput{4}}, match_layer_value, }, { - PatternInput{DataflowGraphInput{6}}, + PatternInput{KwargDataflowGraphInput{6}}, match_layer_input_weights, }}, }; @@ -933,52 +808,17 @@ TEST_SUITE(FF_TEST_SUITE) { 
SubParallelComputationGraph correct = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - ParallelLayerAddedResult query_added = - pcg_add_input_layer(pcg, query_shape); - parallel_tensor_guid_t t_query = get_only(query_added.outputs); - - ParallelLayerAddedResult key_added = pcg_add_input_layer(pcg, key_shape); - parallel_tensor_guid_t t_key = get_only(key_added.outputs); - - ParallelLayerAddedResult value_added = - pcg_add_input_layer(pcg, value_shape); - parallel_tensor_guid_t t_value = get_only(value_added.outputs); - - ParallelLayerAddedResult weight_added = - add_parallel_layer(pcg, make_layer_attrs(weight_attrs), {}, {}); - parallel_tensor_guid_t t_weight = get_only(weight_added.outputs); - - ParallelLayerAddedResult partition_query_added = add_parallel_layer( - pcg, make_layer_attrs(partition_input_attrs), {t_query}, {}); - parallel_tensor_guid_t t_partitioned_query = - get_only(partition_query_added.outputs); - - ParallelLayerAddedResult partition_key_added = add_parallel_layer( - pcg, make_layer_attrs(partition_input_attrs), {t_key}, {}); - parallel_tensor_guid_t t_partitioned_key = - get_only(partition_key_added.outputs); - - ParallelLayerAddedResult partition_value_added = add_parallel_layer( - pcg, make_layer_attrs(partition_input_attrs), {t_value}, {}); - parallel_tensor_guid_t t_partitioned_value = - get_only(partition_value_added.outputs); - - ParallelLayerAddedResult replicate_weight_added = add_parallel_layer( - pcg, make_layer_attrs(replicate_weight_attrs), {t_weight}, {}); - parallel_tensor_guid_t t_replicated_weight = - get_only(replicate_weight_added.outputs); - - ParallelLayerAddedResult partition_attention_added = add_parallel_layer( - pcg, - make_layer_attrs(attention_attrs), - {t_partitioned_query, t_partitioned_key, t_partitioned_value}, - {t_replicated_weight}); - parallel_tensor_guid_t t_partitioned_attention = - get_only(partition_attention_added.outputs); - - ParallelLayerAddedResult combine_added = 
add_parallel_layer( - pcg, make_layer_attrs(combine_attrs), {t_partitioned_attention}, {}); - parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + parallel_tensor_guid_t t_query = add_partition_layer(pcg, ff_dim_t{0_n}, degree, add_input_layer(pcg, query_shape)); + parallel_tensor_guid_t t_key = add_partition_layer(pcg, ff_dim_t{0_n}, degree, add_input_layer(pcg, key_shape)); + parallel_tensor_guid_t t_value = add_partition_layer(pcg, ff_dim_t{0_n}, degree, add_input_layer(pcg, value_shape)); + + parallel_tensor_guid_t t_weight = add_replicate_layer(pcg, degree, add_weight_layer(pcg, weights_shape)); + + + + parallel_tensor_guid_t t_partitioned_attention = add_attention_layer(pcg, attention_attrs, t_query, t_key, t_value, t_weight); + + parallel_tensor_guid_t t_combine = add_combine_layer(pcg, ff_dim_t{0_n}, degree, t_partitioned_attention); return sub_pcg_from_full_pcg(pcg); }(); @@ -989,7 +829,7 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("create_replicate_attention_reduce") { positive_int embed_dim = 8_p; positive_int num_heads = 6_p; - positive_int degree = 1_p; + positive_int degree = 2_p; std::string attention_match = "attention_match"; Substitution sub = create_replicate_attention_reduce(num_heads, degree); @@ -1018,49 +858,18 @@ TEST_SUITE(FF_TEST_SUITE) { /*add_zero_attn=*/false, }; - ReplicateAttrs replicate_input_attrs = ReplicateAttrs{ - /*replicate_degree=*/degree, - }; - - WeightAttrs weight_attrs = WeightAttrs{ - /*tensor_shape=*/ - throw_if_unexpected(get_weights_shape( - attention_attrs, query_shape, key_shape, value_shape)), - /*initializer=*/InitializerAttrs{ZeroInitializerAttrs{}}, - }; - - RepartitionAttrs partition_weight_attrs = RepartitionAttrs{ - /*repartition_dim=*/ff_dim_t{1_n}, - /*repartition_degree=*/degree, - }; - - ReductionAttrs reduction_attrs = ReductionAttrs{ - /*reduction_degree=*/degree, - }; + TensorShape weight_shape = throw_if_unexpected(get_weights_shape(attention_attrs, query_shape, key_shape, value_shape)); 
SubParallelComputationGraph original_pcg = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - ParallelLayerAddedResult query_added = - pcg_add_input_layer(pcg, query_shape); - parallel_tensor_guid_t t_query = get_only(query_added.outputs); + parallel_tensor_guid_t t_query = add_input_layer(pcg, query_shape); + parallel_tensor_guid_t t_key = add_input_layer(pcg, key_shape); + parallel_tensor_guid_t t_value = add_input_layer(pcg, value_shape); - ParallelLayerAddedResult key_added = pcg_add_input_layer(pcg, key_shape); - parallel_tensor_guid_t t_key = get_only(key_added.outputs); + parallel_tensor_guid_t t_weight = add_weight_layer(pcg, weight_shape); - ParallelLayerAddedResult value_added = - pcg_add_input_layer(pcg, value_shape); - parallel_tensor_guid_t t_value = get_only(value_added.outputs); - - ParallelLayerAddedResult weight_added = - add_parallel_layer(pcg, make_layer_attrs(weight_attrs), {}, {}); - parallel_tensor_guid_t t_weight = get_only(weight_added.outputs); - - ParallelLayerAddedResult attention_added = - add_parallel_layer(pcg, - make_layer_attrs(attention_attrs, attention_match), - {t_query, t_key, t_value}, - {t_weight}); + parallel_tensor_guid_t attention_added = add_attention_layer(pcg, attention_attrs, t_query, t_key, t_value, t_weight, attention_match); return sub_pcg_from_full_pcg(pcg); }(); @@ -1069,13 +878,13 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t match_layer = get_parallel_layer_by_name(original_pcg, attention_match); open_parallel_tensor_guid_t match_layer_query = - get_layer_inputs(original_pcg, match_layer).at(0); + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::QUERY); open_parallel_tensor_guid_t match_layer_key = - get_layer_inputs(original_pcg, match_layer).at(1); + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::KEY); open_parallel_tensor_guid_t match_layer_value = - get_layer_inputs(original_pcg, match_layer).at(2); + get_layer_inputs(original_pcg, 
match_layer).at(TensorSlotName::VALUE); open_parallel_tensor_guid_t match_layer_input_weights = - get_layer_inputs(original_pcg, match_layer).at(3); + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::WEIGHT); return PCGPatternMatch{ bidict{ @@ -1083,19 +892,19 @@ TEST_SUITE(FF_TEST_SUITE) { }, std::unordered_map{ { - PatternInput{DataflowGraphInput{0}}, + PatternInput{KwargDataflowGraphInput{0}}, match_layer_query, }, { - PatternInput{DataflowGraphInput{2}}, + PatternInput{KwargDataflowGraphInput{2}}, match_layer_key, }, { - PatternInput{DataflowGraphInput{4}}, + PatternInput{KwargDataflowGraphInput{4}}, match_layer_value, }, { - PatternInput{DataflowGraphInput{6}}, + PatternInput{KwargDataflowGraphInput{6}}, match_layer_input_weights, }}, }; @@ -1107,52 +916,15 @@ TEST_SUITE(FF_TEST_SUITE) { SubParallelComputationGraph correct = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - ParallelLayerAddedResult query_added = - pcg_add_input_layer(pcg, query_shape); - parallel_tensor_guid_t t_query = get_only(query_added.outputs); - - ParallelLayerAddedResult key_added = pcg_add_input_layer(pcg, key_shape); - parallel_tensor_guid_t t_key = get_only(key_added.outputs); - - ParallelLayerAddedResult value_added = - pcg_add_input_layer(pcg, value_shape); - parallel_tensor_guid_t t_value = get_only(value_added.outputs); - - ParallelLayerAddedResult weight_added = - add_parallel_layer(pcg, make_layer_attrs(weight_attrs), {}, {}); - parallel_tensor_guid_t t_weight = get_only(weight_added.outputs); - - ParallelLayerAddedResult replicate_query_added = add_parallel_layer( - pcg, make_layer_attrs(replicate_input_attrs), {t_query}, {}); - parallel_tensor_guid_t t_replicated_query = - get_only(replicate_query_added.outputs); - - ParallelLayerAddedResult replicate_key_added = add_parallel_layer( - pcg, make_layer_attrs(replicate_input_attrs), {t_key}, {}); - parallel_tensor_guid_t t_replicated_key = - get_only(replicate_key_added.outputs); - - 
ParallelLayerAddedResult replicate_value_added = add_parallel_layer( - pcg, make_layer_attrs(replicate_input_attrs), {t_value}, {}); - parallel_tensor_guid_t t_replicated_value = - get_only(replicate_value_added.outputs); - - ParallelLayerAddedResult partition_weight_added = add_parallel_layer( - pcg, make_layer_attrs(partition_weight_attrs), {t_weight}, {}); - parallel_tensor_guid_t t_partitioned_weight = - get_only(partition_weight_added.outputs); - - ParallelLayerAddedResult replicate_attention_added = add_parallel_layer( - pcg, - make_layer_attrs(attention_attrs), - {t_replicated_query, t_replicated_key, t_replicated_value}, - {t_partitioned_weight}); - parallel_tensor_guid_t t_replicated_attention = - get_only(replicate_attention_added.outputs); - - ParallelLayerAddedResult reduce_added = add_parallel_layer( - pcg, make_layer_attrs(reduction_attrs), {t_replicated_attention}, {}); - parallel_tensor_guid_t t_reduction = get_only(reduce_added.outputs); + parallel_tensor_guid_t t_query = add_replicate_layer(pcg, degree, add_input_layer(pcg, query_shape)); + parallel_tensor_guid_t t_key = add_replicate_layer(pcg, degree, add_input_layer(pcg, key_shape)); + parallel_tensor_guid_t t_value = add_replicate_layer(pcg, degree, add_input_layer(pcg, value_shape)); + + parallel_tensor_guid_t t_weight = add_partition_layer(pcg, ff_dim_t{1_n}, degree, add_weight_layer(pcg, weight_shape)); + + parallel_tensor_guid_t t_replicated_attention = add_attention_layer(pcg, attention_attrs, t_query, t_key, t_value, t_weight); + + parallel_tensor_guid_t t_reduction = add_reduction_layer(pcg, degree, t_replicated_attention); return sub_pcg_from_full_pcg(pcg); }(); @@ -1161,7 +933,7 @@ TEST_SUITE(FF_TEST_SUITE) { } TEST_CASE("create_partition_softmax_combine") { - positive_int degree = 1_p; + positive_int degree = 2_p; ff_dim_t softmax_dim = ff_dim_t{1_n}; ff_dim_t partition_dim = ff_dim_t{0_n}; std::string softmax_match = "softmax_match"; @@ -1183,26 +955,13 @@ TEST_SUITE(FF_TEST_SUITE) 
{ /*softmax_dim=*/softmax_dim, }; - RepartitionAttrs partition_input_attrs = RepartitionAttrs{ - /*repartition_dim=*/partition_dim, - /*repartition_degree=*/degree, - }; - - CombineAttrs combine_attrs = CombineAttrs{ - /*combine_dim=*/ff_dim_t{partition_dim}, - /*combine_degree=*/degree, - }; - SubParallelComputationGraph original_pcg = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - ParallelLayerAddedResult input_added = - pcg_add_input_layer(pcg, input_shape); - - parallel_tensor_guid_t t_input = get_only(input_added.outputs); + parallel_tensor_guid_t t_input = add_input_layer(pcg, input_shape); - ParallelLayerAddedResult softmax_added = add_parallel_layer( - pcg, make_layer_attrs(softmax_attrs, softmax_match), {t_input}, {}); + parallel_tensor_guid_t t_softmax = add_single_output_layer( + pcg, make_layer_attrs(softmax_attrs, softmax_match), {{TensorSlotName::INPUT, t_input}}, {}); return sub_pcg_from_full_pcg(pcg); }(); @@ -1211,14 +970,14 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t match_layer = get_parallel_layer_by_name(original_pcg, softmax_match); open_parallel_tensor_guid_t match_layer_input = - get_layer_inputs(original_pcg, match_layer).at(0); + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::INPUT); return PCGPatternMatch{ bidict{ {PatternNode{Node{0}}, match_layer}, }, std::unordered_map{{ - PatternInput{DataflowGraphInput{0}}, + PatternInput{KwargDataflowGraphInput{0}}, match_layer_input, }}, }; @@ -1230,24 +989,12 @@ TEST_SUITE(FF_TEST_SUITE) { SubParallelComputationGraph correct = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - ParallelLayerAddedResult input_added = - pcg_add_input_layer(pcg, input_shape); + parallel_tensor_guid_t t_partitioned_input = add_partition_layer(pcg, partition_dim, degree, add_input_layer(pcg, input_shape)); - parallel_tensor_guid_t t_input = get_only(input_added.outputs); + parallel_tensor_guid_t t_partitioned_softmax = add_single_output_layer( 
+ pcg, make_layer_attrs(softmax_attrs), {{TensorSlotName::INPUT, t_partitioned_input}}, {}); - ParallelLayerAddedResult partition_input_added = add_parallel_layer( - pcg, make_layer_attrs(partition_input_attrs), {t_input}, {}); - parallel_tensor_guid_t t_partitioned_input = - get_only(partition_input_added.outputs); - - ParallelLayerAddedResult partition_softmax_added = add_parallel_layer( - pcg, make_layer_attrs(softmax_attrs), {t_partitioned_input}, {}); - parallel_tensor_guid_t t_partitioned_softmax = - get_only(partition_softmax_added.outputs); - - ParallelLayerAddedResult combine_added = add_parallel_layer( - pcg, make_layer_attrs(combine_attrs), {t_partitioned_softmax}, {}); - parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + parallel_tensor_guid_t t_combine = add_combine_layer(pcg, partition_dim, degree, t_partitioned_softmax); return sub_pcg_from_full_pcg(pcg); }(); @@ -1256,7 +1003,7 @@ TEST_SUITE(FF_TEST_SUITE) { } TEST_CASE("create_partition_add_combine") { - positive_int degree = 1_p; + positive_int degree = 2_p; ff_dim_t parallel_dim = ff_dim_t{1_n}; std::string add_match = "add_match"; @@ -1281,27 +1028,14 @@ TEST_SUITE(FF_TEST_SUITE) { false, }; - RepartitionAttrs partition_input_attrs = RepartitionAttrs{ - /*repartition_dim=*/parallel_dim, - /*repartition_degree=*/degree, - }; - - CombineAttrs combine_attrs = CombineAttrs{ - /*combine_dim=*/parallel_dim, - /*combine_degree=*/degree, - }; - SubParallelComputationGraph original_pcg = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - ParallelLayerAddedResult lhs_added = pcg_add_input_layer(pcg, lhs_shape); - parallel_tensor_guid_t t_lhs = get_only(lhs_added.outputs); - - ParallelLayerAddedResult rhs_added = pcg_add_input_layer(pcg, rhs_shape); - parallel_tensor_guid_t t_rhs = get_only(rhs_added.outputs); + parallel_tensor_guid_t t_lhs = add_input_layer(pcg, lhs_shape); + parallel_tensor_guid_t t_rhs = add_input_layer(pcg, rhs_shape); - 
ParallelLayerAddedResult output_added = add_parallel_layer( - pcg, make_layer_attrs(add_attrs, add_match), {t_lhs, t_rhs}, {}); + parallel_tensor_guid_t t_add = add_single_output_layer( + pcg, make_layer_attrs(add_attrs, add_match), {{TensorSlotName::LHS_INPUT, t_lhs}, {TensorSlotName::RHS_INPUT, t_rhs},}, {}); return sub_pcg_from_full_pcg(pcg); }(); @@ -1310,9 +1044,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t match_layer = get_parallel_layer_by_name(original_pcg, add_match); open_parallel_tensor_guid_t add_match_layer_lhs = - get_layer_inputs(original_pcg, match_layer).at(0); + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::LHS_INPUT); open_parallel_tensor_guid_t add_match_layer_rhs = - get_layer_inputs(original_pcg, match_layer).at(1); + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::RHS_INPUT); return PCGPatternMatch{ bidict{ @@ -1320,11 +1054,11 @@ TEST_SUITE(FF_TEST_SUITE) { }, std::unordered_map{ { - PatternInput{DataflowGraphInput{0}}, + PatternInput{KwargDataflowGraphInput{0}}, add_match_layer_lhs, }, { - PatternInput{DataflowGraphInput{2}}, + PatternInput{KwargDataflowGraphInput{2}}, add_match_layer_rhs, }}, }; @@ -1336,33 +1070,19 @@ TEST_SUITE(FF_TEST_SUITE) { SubParallelComputationGraph correct = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - ParallelLayerAddedResult lhs_added = pcg_add_input_layer(pcg, lhs_shape); - parallel_tensor_guid_t t_lhs = get_only(lhs_added.outputs); - - ParallelLayerAddedResult rhs_added = pcg_add_input_layer(pcg, rhs_shape); - parallel_tensor_guid_t t_rhs = get_only(rhs_added.outputs); - - ParallelLayerAddedResult partition_lhs_added = add_parallel_layer( - pcg, make_layer_attrs(partition_input_attrs), {t_lhs}, {}); - parallel_tensor_guid_t t_partitioned_lhs = - get_only(partition_lhs_added.outputs); - - ParallelLayerAddedResult partition_rhs_added = add_parallel_layer( - pcg, make_layer_attrs(partition_input_attrs), {t_rhs}, {}); - parallel_tensor_guid_t 
t_partitioned_rhs = - get_only(partition_rhs_added.outputs); + parallel_tensor_guid_t t_lhs = add_partition_layer(pcg, parallel_dim, degree, add_input_layer(pcg, lhs_shape)); + parallel_tensor_guid_t t_rhs = add_partition_layer(pcg, parallel_dim, degree, add_input_layer(pcg, rhs_shape)); - ParallelLayerAddedResult partition_add_added = - add_parallel_layer(pcg, + parallel_tensor_guid_t t_partitioned_add = + add_single_output_layer(pcg, make_layer_attrs(add_attrs, add_match), - {t_partitioned_lhs, t_partitioned_rhs}, + { + {TensorSlotName::LHS_INPUT, t_lhs}, + {TensorSlotName::RHS_INPUT, t_rhs}, + }, {}); - parallel_tensor_guid_t t_partitioned_add = - get_only(partition_add_added.outputs); - ParallelLayerAddedResult combine_added = add_parallel_layer( - pcg, make_layer_attrs(combine_attrs), {t_partitioned_add}, {}); - parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + parallel_tensor_guid_t t_combine = add_combine_layer(pcg, parallel_dim, degree, t_partitioned_add); return sub_pcg_from_full_pcg(pcg); }(); @@ -1371,7 +1091,7 @@ TEST_SUITE(FF_TEST_SUITE) { } TEST_CASE("create_partition_relu_combine") { - positive_int degree = 1_p; + positive_int degree = 2_p; ff_dim_t parallel_dim = ff_dim_t{1_n}; std::string relu_match = "relu_match"; @@ -1392,26 +1112,13 @@ TEST_SUITE(FF_TEST_SUITE) { std::nullopt, }; - RepartitionAttrs partition_input_attrs = RepartitionAttrs{ - /*repartition_dim=*/parallel_dim, - /*repartition_degree=*/degree, - }; - - CombineAttrs combine_attrs = CombineAttrs{ - /*combine_dim=*/ff_dim_t{parallel_dim}, - /*combine_degree=*/degree, - }; - SubParallelComputationGraph original_pcg = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - ParallelLayerAddedResult input_added = - pcg_add_input_layer(pcg, input_shape); - - parallel_tensor_guid_t t_input = get_only(input_added.outputs); + parallel_tensor_guid_t t_input = add_input_layer(pcg, input_shape); - ParallelLayerAddedResult relu_added = add_parallel_layer( - 
pcg, make_layer_attrs(relu_attrs, relu_match), {t_input}, {}); + parallel_tensor_guid_t t_relu = add_single_output_layer( + pcg, make_layer_attrs(relu_attrs, relu_match), {{TensorSlotName::INPUT, t_input}}, {}); return sub_pcg_from_full_pcg(pcg); }(); @@ -1420,14 +1127,14 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t match_layer = get_parallel_layer_by_name(original_pcg, relu_match); open_parallel_tensor_guid_t match_layer_input = - get_layer_inputs(original_pcg, match_layer).at(0); + get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::INPUT); return PCGPatternMatch{ bidict{ {PatternNode{Node{0}}, match_layer}, }, std::unordered_map{{ - PatternInput{DataflowGraphInput{0}}, + PatternInput{KwargDataflowGraphInput{0}}, match_layer_input, }}, }; @@ -1439,24 +1146,12 @@ TEST_SUITE(FF_TEST_SUITE) { SubParallelComputationGraph correct = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - ParallelLayerAddedResult input_added = - pcg_add_input_layer(pcg, input_shape); - - parallel_tensor_guid_t t_input = get_only(input_added.outputs); - - ParallelLayerAddedResult partition_input_added = add_parallel_layer( - pcg, make_layer_attrs(partition_input_attrs), {t_input}, {}); - parallel_tensor_guid_t t_partitioned_input = - get_only(partition_input_added.outputs); + parallel_tensor_guid_t t_input = add_partition_layer(pcg, parallel_dim, degree, add_input_layer(pcg, input_shape)); - ParallelLayerAddedResult partition_relu_added = add_parallel_layer( - pcg, make_layer_attrs(relu_attrs), {t_partitioned_input}, {}); - parallel_tensor_guid_t t_partitioned_relu = - get_only(partition_relu_added.outputs); + parallel_tensor_guid_t t_relu = add_single_output_layer( + pcg, make_layer_attrs(relu_attrs), {{TensorSlotName::INPUT, t_input}}, {}); - ParallelLayerAddedResult combine_added = add_parallel_layer( - pcg, make_layer_attrs(combine_attrs), {t_partitioned_relu}, {}); - parallel_tensor_guid_t t_combine = get_only(combine_added.outputs); + 
parallel_tensor_guid_t t_combine = add_combine_layer(pcg, parallel_dim, degree, t_relu); return sub_pcg_from_full_pcg(pcg); }(); @@ -1503,9 +1198,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t relu_match_layer = get_parallel_layer_by_name(pcg, relu_match); open_parallel_tensor_guid_t mm_match_layer_input_activations = - get_layer_inputs(pcg, mm_match_layer).at(0); + get_layer_inputs(pcg, mm_match_layer).at(TensorSlotName::INPUT); open_parallel_tensor_guid_t mm_match_layer_input_weights = - get_layer_inputs(pcg, mm_match_layer).at(1); + get_layer_inputs(pcg, mm_match_layer).at(TensorSlotName::WEIGHT); return PCGPatternMatch{ bidict{ @@ -1514,11 +1209,11 @@ TEST_SUITE(FF_TEST_SUITE) { }, std::unordered_map{ { - PatternInput{DataflowGraphInput{0}}, + PatternInput{KwargDataflowGraphInput{0}}, mm_match_layer_input_activations, }, { - PatternInput{DataflowGraphInput{2}}, + PatternInput{KwargDataflowGraphInput{2}}, mm_match_layer_input_weights, }}, }; diff --git a/lib/utils/include/utils/positive_int/positive_range.h b/lib/utils/include/utils/positive_int/positive_range.h new file mode 100644 index 0000000000..f064f766c8 --- /dev/null +++ b/lib/utils/include/utils/positive_int/positive_range.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_POSITIVE_INT_POSITIVE_RANGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_POSITIVE_INT_POSITIVE_RANGE_H + +#include "utils/positive_int/positive_int.h" + +namespace FlexFlow { + +std::vector + positive_range(positive_int start, positive_int end, int step = 1); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/src/utils/positive_int/positive_range.cc b/lib/utils/src/utils/positive_int/positive_range.cc new file mode 100644 index 0000000000..8a31ea0505 --- /dev/null +++ b/lib/utils/src/utils/positive_int/positive_range.cc @@ -0,0 +1,14 @@ +#include "utils/positive_int/positive_range.h" +#include "utils/containers/transform.h" +#include "utils/containers/range.h" + +namespace FlexFlow { + +std::vector + 
positive_range(positive_int start, positive_int end, int step) { + return transform( + range(start.int_from_positive_int(), end.int_from_positive_int(), step), + [](int x) { return positive_int{x}; }); +} + +} // namespace FlexFlow From 611c9625140bf224ce0859993ca1c74f636e3e91 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sat, 10 Jan 2026 01:05:36 -0800 Subject: [PATCH 3/3] Format --- .../perform_shape_inference.cc | 7 +- .../substitutions/unity_substitution_set.cc | 666 +++++++++--------- .../substitutions/unity_substitution_set.cc | 517 ++++++++------ .../src/utils/positive_int/positive_range.cc | 2 +- 4 files changed, 670 insertions(+), 522 deletions(-) diff --git a/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc b/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc index 94c3bee5d2..8e1c06b9b5 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc @@ -1,8 +1,10 @@ #include "substitutions/apply_substitution/perform_shape_inference.h" #include "op-attrs/get_incoming_tensor_roles.h" #include "op-attrs/shape_inference.h" +#include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/containers/filter_values.h" #include "utils/containers/filtrans.h" +#include "utils/containers/is_subseteq_of.h" #include "utils/containers/map_keys.h" #include "utils/containers/map_values.h" #include "utils/containers/restrict_keys.h" @@ -18,8 +20,6 @@ #include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" #include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_values_for_node.h" #include "utils/nonnegative_int/num_elements.h" -#include "utils/containers/binary_merge_disjoint_maps.h" -#include "utils/containers/is_subseteq_of.h" namespace FlexFlow { @@ -72,7 +72,8 @@ LabelledOpenKwargDataflowGraphView weight_shapes = 
incoming_shapes_with_role(IncomingTensorRole::WEIGHT); - ASSERT(binary_merge_disjoint_maps(input_shapes, weight_shapes) == incoming_shapes); + ASSERT(binary_merge_disjoint_maps(input_shapes, weight_shapes) == + incoming_shapes); std::unordered_map inferred_weight_shapes = diff --git a/lib/substitutions/src/substitutions/unity_substitution_set.cc b/lib/substitutions/src/substitutions/unity_substitution_set.cc index 6940e86162..f1d808c5fd 100644 --- a/lib/substitutions/src/substitutions/unity_substitution_set.cc +++ b/lib/substitutions/src/substitutions/unity_substitution_set.cc @@ -37,13 +37,15 @@ std::vector } } - for (positive_int degree = 1_p; degree <= get_num_gpus(resources); degree *= 2_p) { + for (positive_int degree = 1_p; degree <= get_num_gpus(resources); + degree *= 2_p) { substitutions.push_back(create_partition_conv2d_combine(4_p, degree)); } for (positive_int partition_dim : positive_range(1_p, max_tensor_dim + 1_p)) { for (positive_int softmax_dim : positive_range(1_p, max_tensor_dim + 1_p)) { - for (positive_int degree = 1_p; degree <= get_num_gpus(resources); degree *= 2_p) { + for (positive_int degree = 1_p; degree <= get_num_gpus(resources); + degree *= 2_p) { if (partition_dim != softmax_dim) { substitutions.push_back(create_partition_softmax_combine( ff_dim_t{partition_dim.nonnegative_int_from_positive_int()}, @@ -61,77 +63,76 @@ std::vector } static PatternValue insert_single_output_pattern( - SubstitutionBuilder &b, - OperatorAttributePattern const &attribute_pattern, - std::unordered_map const &inputs, - TensorAttributePattern const &output_pattern, - std::string const &name) -{ - return require_only_key( - b.add_pattern_node(attribute_pattern, - inputs, - /*output_patterns=*/{ - { - TensorSlotName::OUTPUT, - output_pattern, - }, - }, - name), - TensorSlotName::OUTPUT); + SubstitutionBuilder &b, + OperatorAttributePattern const &attribute_pattern, + std::unordered_map const &inputs, + TensorAttributePattern const &output_pattern, + 
std::string const &name) { + return require_only_key(b.add_pattern_node(attribute_pattern, + inputs, + /*output_patterns=*/ + { + { + TensorSlotName::OUTPUT, + output_pattern, + }, + }, + name), + TensorSlotName::OUTPUT); } - static OutputGraphExprValue insert_single_output_op( - SubstitutionBuilder &b, - OutputOperatorAttrsAssignment const &expr, - std::unordered_map const &inputs) -{ + SubstitutionBuilder &b, + OutputOperatorAttrsAssignment const &expr, + std::unordered_map const &inputs) { return require_only_key( b.add_output_graph_node(expr, inputs, {TensorSlotName::OUTPUT}), - TensorSlotName::OUTPUT); + TensorSlotName::OUTPUT); } +static OutputGraphExprValue + insert_replicate_or_reduce(OperatorType op_type, + SubstitutionBuilder &b, + positive_int degree, + OutputGraphExprValue const &input) { -static OutputGraphExprValue insert_replicate_or_reduce(OperatorType op_type, - SubstitutionBuilder &b, - positive_int degree, - OutputGraphExprValue const &input) { - - ASSERT(op_type == OperatorType::REPLICATE || op_type == OperatorType::REDUCTION); + ASSERT(op_type == OperatorType::REPLICATE || + op_type == OperatorType::REDUCTION); - OutputOperatorAttrsAssignment replicate_expr = - OutputOperatorAttrsAssignment{ - std::nullopt, - { - set_op_type_attr(op_type), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), - }}; + OutputOperatorAttrsAssignment replicate_expr = OutputOperatorAttrsAssignment{ + std::nullopt, + { + set_op_type_attr(op_type), + set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{degree}), + }}; - return insert_single_output_op(b, replicate_expr, {{TensorSlotName::INPUT, input}}); + return insert_single_output_op( + b, replicate_expr, {{TensorSlotName::INPUT, input}}); } -static OutputGraphExprValue insert_replicate(SubstitutionBuilder &b, - positive_int degree, - OutputGraphExprValue const &input) { - return insert_replicate_or_reduce(OperatorType::REPLICATE, b, degree, 
input); +static OutputGraphExprValue + insert_replicate(SubstitutionBuilder &b, + positive_int degree, + OutputGraphExprValue const &input) { + return insert_replicate_or_reduce(OperatorType::REPLICATE, b, degree, input); } - -static OutputGraphExprValue insert_reduce(SubstitutionBuilder &b, - positive_int degree, - OutputGraphExprValue const &input) { - return insert_replicate_or_reduce(OperatorType::REDUCTION, b, degree, input); +static OutputGraphExprValue insert_reduce(SubstitutionBuilder &b, + positive_int degree, + OutputGraphExprValue const &input) { + return insert_replicate_or_reduce(OperatorType::REDUCTION, b, degree, input); } -static OutputGraphExprValue insert_partition_or_combine( - OperatorType op_type, - SubstitutionBuilder &b, - positive_int degree, - ff_dim_t dim, - OutputGraphExprValue const &input) { +static OutputGraphExprValue + insert_partition_or_combine(OperatorType op_type, + SubstitutionBuilder &b, + positive_int degree, + ff_dim_t dim, + OutputGraphExprValue const &input) { - ASSERT(op_type == OperatorType::REPARTITION || op_type == OperatorType::COMBINE); + ASSERT(op_type == OperatorType::REPARTITION || + op_type == OperatorType::COMBINE); OutputOperatorAttrsAssignment partition_input_expr = OutputOperatorAttrsAssignment{ @@ -144,30 +145,31 @@ static OutputGraphExprValue insert_partition_or_combine( OperatorAttributeValue{dim}), }}; - OutputGraphExprValue o_partition_output = - insert_single_output_op(b, partition_input_expr, {{TensorSlotName::INPUT, input}}); + OutputGraphExprValue o_partition_output = insert_single_output_op( + b, partition_input_expr, {{TensorSlotName::INPUT, input}}); return o_partition_output; } -static OutputGraphExprValue insert_partition(SubstitutionBuilder &b, - positive_int degree, - ff_dim_t dim, - OutputGraphExprValue const &input) { +static OutputGraphExprValue + insert_partition(SubstitutionBuilder &b, + positive_int degree, + ff_dim_t dim, + OutputGraphExprValue const &input) { - return 
insert_partition_or_combine(OperatorType::REPARTITION, b, degree, dim, input); + return insert_partition_or_combine( + OperatorType::REPARTITION, b, degree, dim, input); } -static OutputGraphExprValue insert_combine(SubstitutionBuilder &b, - positive_int degree, - ff_dim_t dim, - OutputGraphExprValue const &input) { +static OutputGraphExprValue insert_combine(SubstitutionBuilder &b, + positive_int degree, + ff_dim_t dim, + OutputGraphExprValue const &input) { - return insert_partition_or_combine(OperatorType::COMBINE, b, degree, dim, input); + return insert_partition_or_combine( + OperatorType::COMBINE, b, degree, dim, input); } - - Substitution create_replicate_linear_combine(positive_int num_dims, positive_int degree, bool use_bias) { @@ -200,16 +202,17 @@ Substitution create_replicate_linear_combine(positive_int num_dims, std::string linear_name = "linear"; PatternValue p_linear_output = insert_single_output_pattern( - b, - linear_pattern, - p_inputs, - /*output_pattern=*/tensor_attr_pattern_require_num_dims(num_dims), - linear_name); + b, + linear_pattern, + p_inputs, + /*output_pattern=*/tensor_attr_pattern_require_num_dims(num_dims), + linear_name); OutputGraphExprValue o_replicate_input_output = - insert_replicate(b, degree, o_input); + insert_replicate(b, degree, o_input); - OutputGraphExprValue o_partition_weights_output = insert_partition(b, degree, ff_dim_t{1_n}, o_weight); + OutputGraphExprValue o_partition_weights_output = + insert_partition(b, degree, ff_dim_t{1_n}, o_weight); std::unordered_map o_linear_inputs = { { @@ -223,7 +226,8 @@ Substitution create_replicate_linear_combine(positive_int num_dims, }; if (use_bias) { - OutputGraphExprValue o_partition_bias_output = insert_partition(b, degree, ff_dim_t{1_n}, o_bias.value()); + OutputGraphExprValue o_partition_bias_output = + insert_partition(b, degree, ff_dim_t{1_n}, o_bias.value()); o_linear_inputs.insert({ TensorSlotName::BIAS, @@ -235,12 +239,14 @@ Substitution 
create_replicate_linear_combine(positive_int num_dims, b.pattern_node_named(linear_name), {}, }; - OutputGraphExprValue o_linear_output = insert_single_output_op(b, linear_expr, o_linear_inputs); + OutputGraphExprValue o_linear_output = + insert_single_output_op(b, linear_expr, o_linear_inputs); ff_dim_t combine_output_dim = ff_dim_t{ nonnegative_int{num_dims.int_from_positive_int() - 1}, }; - OutputGraphExprValue o_combine_output = insert_combine(b, degree, combine_output_dim, o_linear_output); + OutputGraphExprValue o_combine_output = + insert_combine(b, degree, combine_output_dim, o_linear_output); b.equate_outputs(p_linear_output, o_combine_output); @@ -255,14 +261,14 @@ Substitution create_partition_linear_combine(positive_int num_dims, auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all()); auto [p_weight, o_weight] = b.add_input(tensor_attribute_pattern_match_all()); std::unordered_map p_inputs = { - { - TensorSlotName::INPUT, - p_input, - }, - { - TensorSlotName::WEIGHT, - p_weight, - }, + { + TensorSlotName::INPUT, + p_input, + }, + { + TensorSlotName::WEIGHT, + p_weight, + }, }; std::optional o_bias = std::nullopt; @@ -270,8 +276,8 @@ Substitution create_partition_linear_combine(positive_int num_dims, std::pair bias = b.add_input(tensor_attribute_pattern_match_all()); p_inputs.insert({ - TensorSlotName::BIAS, - bias.first, + TensorSlotName::BIAS, + bias.first, }); o_bias = bias.second; } @@ -285,33 +291,36 @@ Substitution create_partition_linear_combine(positive_int num_dims, std::string linear_name = "linear"; PatternValue p_linear_output = insert_single_output_pattern( - b, - linear_pattern, - p_inputs, - /*output_pattern=*/tensor_attr_pattern_require_num_dims(num_dims), - linear_name); + b, + linear_pattern, + p_inputs, + /*output_pattern=*/tensor_attr_pattern_require_num_dims(num_dims), + linear_name); - OutputGraphExprValue o_partition_input_output = insert_partition(b, degree, ff_dim_t{0_n}, o_input); + OutputGraphExprValue 
o_partition_input_output = + insert_partition(b, degree, ff_dim_t{0_n}, o_input); - OutputGraphExprValue o_replicate_weights_output = insert_replicate(b, degree, o_weight); + OutputGraphExprValue o_replicate_weights_output = + insert_replicate(b, degree, o_weight); std::unordered_map o_linear_inputs = { - { - TensorSlotName::INPUT, - o_partition_input_output, - }, - { - TensorSlotName::WEIGHT, - o_replicate_weights_output, - }, + { + TensorSlotName::INPUT, + o_partition_input_output, + }, + { + TensorSlotName::WEIGHT, + o_replicate_weights_output, + }, }; if (use_bias) { - OutputGraphExprValue o_replicate_bias_output = insert_replicate(b, degree, o_bias.value()); + OutputGraphExprValue o_replicate_bias_output = + insert_replicate(b, degree, o_bias.value()); o_linear_inputs.insert({ - TensorSlotName::BIAS, - o_replicate_bias_output, + TensorSlotName::BIAS, + o_replicate_bias_output, }); } @@ -319,12 +328,14 @@ Substitution create_partition_linear_combine(positive_int num_dims, b.pattern_node_named(linear_name), {}, }; - OutputGraphExprValue o_linear_output = insert_single_output_op(b, linear_expr, o_linear_inputs); + OutputGraphExprValue o_linear_output = + insert_single_output_op(b, linear_expr, o_linear_inputs); ff_dim_t combine_output_dim = ff_dim_t{ nonnegative_int{num_dims.int_from_positive_int() - 1}, }; - OutputGraphExprValue o_combine_output = insert_combine(b, degree, combine_output_dim, o_linear_output); + OutputGraphExprValue o_combine_output = + insert_combine(b, degree, combine_output_dim, o_linear_output); b.equate_outputs(p_linear_output, o_combine_output); @@ -341,14 +352,14 @@ Substitution create_partition_conv2d_combine(positive_int num_dims, auto [p_weight, o_weight] = b.add_input(tensor_attribute_pattern_match_all()); std::unordered_map p_inputs = { - { - TensorSlotName::INPUT, - p_input, - }, - { - TensorSlotName::FILTER, - p_weight, - }, + { + TensorSlotName::INPUT, + p_input, + }, + { + TensorSlotName::FILTER, + p_weight, + }, }; 
OperatorAttributePattern conv2d_pattern = OperatorAttributePattern{{ @@ -358,35 +369,35 @@ Substitution create_partition_conv2d_combine(positive_int num_dims, std::string conv2d_name = "conv2d"; PatternValue p_conv2d_output = insert_single_output_pattern( - b, - conv2d_pattern, - p_inputs, - /*output_pattern=*/tensor_attr_pattern_require_num_dims(num_dims), - conv2d_name); - + b, + conv2d_pattern, + p_inputs, + /*output_pattern=*/tensor_attr_pattern_require_num_dims(num_dims), + conv2d_name); - OutputGraphExprValue o_partition_input_output = insert_partition(b, degree, ff_dim_t{0_n}, o_input); + OutputGraphExprValue o_partition_input_output = + insert_partition(b, degree, ff_dim_t{0_n}, o_input); - OutputGraphExprValue o_replicate_weights_output = insert_replicate(b, degree, o_weight); + OutputGraphExprValue o_replicate_weights_output = + insert_replicate(b, degree, o_weight); std::unordered_map o_conv2d_inputs = { - { - TensorSlotName::INPUT, - o_partition_input_output, - }, - { - TensorSlotName::FILTER, - o_replicate_weights_output - }, + { + TensorSlotName::INPUT, + o_partition_input_output, + }, + {TensorSlotName::FILTER, o_replicate_weights_output}, }; OutputOperatorAttrsAssignment conv2d_expr = OutputOperatorAttrsAssignment{ b.pattern_node_named(conv2d_name), {}, }; - OutputGraphExprValue o_conv2d_output = insert_single_output_op(b, conv2d_expr, o_conv2d_inputs); + OutputGraphExprValue o_conv2d_output = + insert_single_output_op(b, conv2d_expr, o_conv2d_inputs); - OutputGraphExprValue o_combine_output = insert_combine(b, degree, ff_dim_t{0_n}, o_conv2d_output); + OutputGraphExprValue o_combine_output = + insert_combine(b, degree, ff_dim_t{0_n}, o_conv2d_output); b.equate_outputs(p_conv2d_output, o_combine_output); @@ -407,22 +418,22 @@ Substitution create_partition_attention_combine(positive_int num_heads, auto [p_weights, o_weights] = b.add_input(tensor_attribute_pattern_match_all()); std::unordered_map p_inputs = { - { - TensorSlotName::QUERY, - 
p_query_input, - }, - { - TensorSlotName::KEY, - p_key_input, - }, - { - TensorSlotName::VALUE, - p_value_input, - }, - { - TensorSlotName::WEIGHT, - p_weights, - }, + { + TensorSlotName::QUERY, + p_query_input, + }, + { + TensorSlotName::KEY, + p_key_input, + }, + { + TensorSlotName::VALUE, + p_value_input, + }, + { + TensorSlotName::WEIGHT, + p_weights, + }, }; OperatorAttributePattern attention_pattern = OperatorAttributePattern{{ @@ -433,49 +444,53 @@ Substitution create_partition_attention_combine(positive_int num_heads, std::string attention_name = "attention"; PatternValue p_attention_output = insert_single_output_pattern( - b, - attention_pattern, - p_inputs, - /*output_pattern=*/tensor_attr_pattern_require_num_dims(3_p), + b, + attention_pattern, + p_inputs, + /*output_pattern=*/tensor_attr_pattern_require_num_dims(3_p), attention_name); - OutputGraphExprValue o_partition_query_input_output = - insert_partition(b, degree, ff_dim_t{0_n}, o_query_input); - - OutputGraphExprValue o_partition_key_input_output = - insert_partition(b, degree, ff_dim_t{0_n}, o_key_input); - - OutputGraphExprValue o_partition_value_input_output = - insert_partition(b, degree, ff_dim_t{0_n}, o_value_input); - - OutputGraphExprValue o_replicate_weight_output = insert_replicate(b, degree, o_weights); - - std::unordered_map o_attention_inputs = { - { - TensorSlotName::QUERY, - o_partition_query_input_output, - }, - { - TensorSlotName::KEY, - o_partition_key_input_output, - }, - { - TensorSlotName::VALUE, - o_partition_value_input_output, - }, - { - TensorSlotName::WEIGHT, - o_replicate_weight_output, - }, - }; + OutputGraphExprValue o_partition_query_input_output = + insert_partition(b, degree, ff_dim_t{0_n}, o_query_input); + + OutputGraphExprValue o_partition_key_input_output = + insert_partition(b, degree, ff_dim_t{0_n}, o_key_input); + + OutputGraphExprValue o_partition_value_input_output = + insert_partition(b, degree, ff_dim_t{0_n}, o_value_input); + + OutputGraphExprValue 
o_replicate_weight_output = + insert_replicate(b, degree, o_weights); + + std::unordered_map o_attention_inputs = + { + { + TensorSlotName::QUERY, + o_partition_query_input_output, + }, + { + TensorSlotName::KEY, + o_partition_key_input_output, + }, + { + TensorSlotName::VALUE, + o_partition_value_input_output, + }, + { + TensorSlotName::WEIGHT, + o_replicate_weight_output, + }, + }; OutputOperatorAttrsAssignment attention_expr = OutputOperatorAttrsAssignment{ b.pattern_node_named(attention_name), {}, }; - OutputGraphExprValue o_attention_output = insert_single_output_op(b, attention_expr, o_attention_inputs); + OutputGraphExprValue o_attention_output = + insert_single_output_op(b, attention_expr, o_attention_inputs); - OutputGraphExprValue o_combine_output = insert_combine(b, degree, ff_dim_t{0_n}, o_attention_output); + OutputGraphExprValue o_combine_output = + insert_combine(b, degree, ff_dim_t{0_n}, o_attention_output); b.equate_outputs(p_attention_output, o_combine_output); @@ -497,22 +512,22 @@ Substitution create_replicate_attention_reduce(positive_int num_heads, b.add_input(tensor_attribute_pattern_match_all()); std::unordered_map p_inputs = { - { - TensorSlotName::QUERY, - p_query_input, - }, - { - TensorSlotName::KEY, - p_key_input, - }, - { - TensorSlotName::VALUE, - p_value_input, - }, - { - TensorSlotName::WEIGHT, - p_weights, - }, + { + TensorSlotName::QUERY, + p_query_input, + }, + { + TensorSlotName::KEY, + p_key_input, + }, + { + TensorSlotName::VALUE, + p_value_input, + }, + { + TensorSlotName::WEIGHT, + p_weights, + }, }; OperatorAttributePattern attention_pattern = OperatorAttributePattern{{ @@ -524,48 +539,52 @@ Substitution create_replicate_attention_reduce(positive_int num_heads, std::string attention_name = "attention"; PatternValue p_attention_output = insert_single_output_pattern( b, - attention_pattern, - p_inputs, - /*output_pattern=*/tensor_attr_pattern_require_num_dims(3_p), + attention_pattern, + p_inputs, + 
/*output_pattern=*/tensor_attr_pattern_require_num_dims(3_p), attention_name); - OutputGraphExprValue o_replicate_query_input_output = - insert_replicate(b, degree, o_query_input); - - OutputGraphExprValue o_replicate_key_input_output = - insert_replicate(b, degree, o_key_input); - - OutputGraphExprValue o_replicate_value_input_output = - insert_replicate(b, degree, o_value_input); - - OutputGraphExprValue o_partition_weight_output = insert_partition(b, degree, ff_dim_t{1_n}, o_weights); - - std::unordered_map o_attention_inputs = { - { - TensorSlotName::QUERY, - o_replicate_query_input_output, - }, - { - TensorSlotName::KEY, - o_replicate_key_input_output, - }, - { - TensorSlotName::VALUE, - o_replicate_value_input_output, - }, - { - TensorSlotName::WEIGHT, - o_partition_weight_output, - }, - }; + OutputGraphExprValue o_replicate_query_input_output = + insert_replicate(b, degree, o_query_input); + + OutputGraphExprValue o_replicate_key_input_output = + insert_replicate(b, degree, o_key_input); + + OutputGraphExprValue o_replicate_value_input_output = + insert_replicate(b, degree, o_value_input); + + OutputGraphExprValue o_partition_weight_output = + insert_partition(b, degree, ff_dim_t{1_n}, o_weights); + + std::unordered_map o_attention_inputs = + { + { + TensorSlotName::QUERY, + o_replicate_query_input_output, + }, + { + TensorSlotName::KEY, + o_replicate_key_input_output, + }, + { + TensorSlotName::VALUE, + o_replicate_value_input_output, + }, + { + TensorSlotName::WEIGHT, + o_partition_weight_output, + }, + }; OutputOperatorAttrsAssignment attention_expr = OutputOperatorAttrsAssignment{ b.pattern_node_named(attention_name), {}, }; - OutputGraphExprValue o_attention_output = insert_single_output_op(b, attention_expr, o_attention_inputs); + OutputGraphExprValue o_attention_output = + insert_single_output_op(b, attention_expr, o_attention_inputs); - OutputGraphExprValue o_reduce_output = insert_reduce(b, degree, o_attention_output); + OutputGraphExprValue 
o_reduce_output = + insert_reduce(b, degree, o_attention_output); b.equate_outputs(p_attention_output, o_reduce_output); @@ -581,10 +600,10 @@ Substitution create_partition_softmax_combine(ff_dim_t softmax_dim, auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all()); std::unordered_map p_inputs = { - { - TensorSlotName::INPUT, - p_input, - }, + { + TensorSlotName::INPUT, + p_input, + }, }; OperatorAttributePattern softmax_pattern = OperatorAttributePattern{{ @@ -601,23 +620,26 @@ Substitution create_partition_softmax_combine(ff_dim_t softmax_dim, p_inputs, /*output_pattern=*/tensor_attribute_pattern_match_all(), softmax_name); - - OutputGraphExprValue o_partition_input_output = insert_partition(b, degree, partition_dim, o_input); + + OutputGraphExprValue o_partition_input_output = + insert_partition(b, degree, partition_dim, o_input); std::unordered_map o_softmax_inputs = { - { - TensorSlotName::INPUT, - o_partition_input_output, - }, + { + TensorSlotName::INPUT, + o_partition_input_output, + }, }; OutputOperatorAttrsAssignment softmax_expr = OutputOperatorAttrsAssignment{ b.pattern_node_named(softmax_name), {}, }; - OutputGraphExprValue o_softmax_output = insert_single_output_op(b, softmax_expr, o_softmax_inputs); + OutputGraphExprValue o_softmax_output = + insert_single_output_op(b, softmax_expr, o_softmax_inputs); - OutputGraphExprValue o_combine_output = insert_combine(b, degree, partition_dim, o_softmax_output); + OutputGraphExprValue o_combine_output = + insert_combine(b, degree, partition_dim, o_softmax_output); b.equate_outputs(p_softmax_output, o_combine_output); @@ -632,14 +654,14 @@ Substitution create_partition_add_combine(ff_dim_t parallel_dim, auto [p_input2, o_input2] = b.add_input(tensor_attribute_pattern_match_all()); std::unordered_map p_inputs = { - { - TensorSlotName::LHS_INPUT, - p_input1, - }, - { - TensorSlotName::RHS_INPUT, - p_input2, - }, + { + TensorSlotName::LHS_INPUT, + p_input1, + }, + { + TensorSlotName::RHS_INPUT, 
+ p_input2, + }, }; OperatorAttributePattern add_pattern = OperatorAttributePattern{{ @@ -649,33 +671,37 @@ Substitution create_partition_add_combine(ff_dim_t parallel_dim, std::string add_name = "add"; PatternValue p_add_output = insert_single_output_pattern( - b, - add_pattern, - p_inputs, - /*output_pattern=*/tensor_attribute_pattern_match_all(), - add_name); + b, + add_pattern, + p_inputs, + /*output_pattern=*/tensor_attribute_pattern_match_all(), + add_name); - OutputGraphExprValue o_partition_input1_output = insert_partition(b, degree, parallel_dim, o_input1); - OutputGraphExprValue o_partition_input2_output = insert_partition(b, degree, parallel_dim, o_input2); + OutputGraphExprValue o_partition_input1_output = + insert_partition(b, degree, parallel_dim, o_input1); + OutputGraphExprValue o_partition_input2_output = + insert_partition(b, degree, parallel_dim, o_input2); std::unordered_map o_add_inputs = { - { - TensorSlotName::LHS_INPUT, - o_partition_input1_output, - }, - { - TensorSlotName::RHS_INPUT, - o_partition_input2_output, - }, + { + TensorSlotName::LHS_INPUT, + o_partition_input1_output, + }, + { + TensorSlotName::RHS_INPUT, + o_partition_input2_output, + }, }; OutputOperatorAttrsAssignment add_expr = OutputOperatorAttrsAssignment{ b.pattern_node_named(add_name), {}, }; - OutputGraphExprValue o_add_output = insert_single_output_op(b, add_expr, o_add_inputs); + OutputGraphExprValue o_add_output = + insert_single_output_op(b, add_expr, o_add_inputs); - OutputGraphExprValue o_combine_output = insert_combine(b, degree, parallel_dim, o_add_output); + OutputGraphExprValue o_combine_output = + insert_combine(b, degree, parallel_dim, o_add_output); b.equate_outputs(p_add_output, o_combine_output); @@ -695,22 +721,24 @@ Substitution create_partition_relu_combine(ff_dim_t parallel_dim, std::string relu_name = "relu"; PatternValue p_relu_output = insert_single_output_pattern( - b, - relu_pattern, - {{TensorSlotName::INPUT, p_input}}, - 
/*output_pattern=*/tensor_attribute_pattern_match_all(), - relu_name); + b, + relu_pattern, + {{TensorSlotName::INPUT, p_input}}, + /*output_pattern=*/tensor_attribute_pattern_match_all(), + relu_name); - OutputGraphExprValue o_partition_input_output = insert_partition(b, degree, parallel_dim, o_input); + OutputGraphExprValue o_partition_input_output = + insert_partition(b, degree, parallel_dim, o_input); OutputOperatorAttrsAssignment relu_expr = OutputOperatorAttrsAssignment{ b.pattern_node_named(relu_name), {}, }; - OutputGraphExprValue o_relu_output - = insert_single_output_op(b, relu_expr, {{TensorSlotName::INPUT, o_partition_input_output}}); + OutputGraphExprValue o_relu_output = insert_single_output_op( + b, relu_expr, {{TensorSlotName::INPUT, o_partition_input_output}}); - OutputGraphExprValue o_combine_output = insert_combine(b, degree, parallel_dim, o_relu_output); + OutputGraphExprValue o_combine_output = + insert_combine(b, degree, parallel_dim, o_relu_output); b.equate_outputs(p_relu_output, o_combine_output); @@ -734,20 +762,21 @@ Substitution create_fuse_linear_activation(Activation activation) { std::string mm_name = "mm"; PatternValue p_mm_output = insert_single_output_pattern( - b, - mm_pattern, - /*inputs=*/{ - { - TensorSlotName::INPUT, - p_input, - }, - { - TensorSlotName::WEIGHT, - p_weight, - }, - }, - /*output_pattern=*/tensor_attribute_pattern_match_all(), - mm_name); + b, + mm_pattern, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + p_input, + }, + { + TensorSlotName::WEIGHT, + p_weight, + }, + }, + /*output_pattern=*/tensor_attribute_pattern_match_all(), + mm_name); OperatorAttributePattern relu_pattern = OperatorAttributePattern{{ op_type_equals_constraint(OperatorType::RELU), @@ -755,16 +784,17 @@ Substitution create_fuse_linear_activation(Activation activation) { std::string relu_name = "relu"; PatternValue p_relu_output = insert_single_output_pattern( - b, - relu_pattern, - /*inputs=*/{ + b, + relu_pattern, + /*inputs=*/ { - 
TensorSlotName::INPUT, - p_mm_output, + { + TensorSlotName::INPUT, + p_mm_output, + }, }, - }, - /*output_pattern=*/tensor_attribute_pattern_match_all(), - relu_name); + /*output_pattern=*/tensor_attribute_pattern_match_all(), + relu_name); OutputOperatorAttrsAssignment fused_node_expr = OutputOperatorAttrsAssignment{ b.pattern_node_named(mm_name), @@ -774,19 +804,19 @@ Substitution create_fuse_linear_activation(Activation activation) { }}; OutputGraphExprValue o_fused_node_output = - insert_single_output_op( - b, - fused_node_expr, - /*inputs=*/{ - { - TensorSlotName::INPUT, - o_input, - }, - { - TensorSlotName::WEIGHT, - o_weight, - }, - }); + insert_single_output_op(b, + fused_node_expr, + /*inputs=*/ + { + { + TensorSlotName::INPUT, + o_input, + }, + { + TensorSlotName::WEIGHT, + o_weight, + }, + }); b.equate_outputs(p_relu_output, o_fused_node_output); diff --git a/lib/substitutions/test/src/substitutions/unity_substitution_set.cc b/lib/substitutions/test/src/substitutions/unity_substitution_set.cc index 022b69b850..df7f28538e 100644 --- a/lib/substitutions/test/src/substitutions/unity_substitution_set.cc +++ b/lib/substitutions/test/src/substitutions/unity_substitution_set.cc @@ -33,7 +33,8 @@ static ParallelLayerAttrs make_layer_attrs( }; }; -parallel_tensor_guid_t get_single_output(ParallelLayerAddedResult const &added) { +parallel_tensor_guid_t + get_single_output(ParallelLayerAddedResult const &added) { return require_only_key(added.outputs, TensorSlotName::OUTPUT); } @@ -43,21 +44,20 @@ parallel_tensor_guid_t add_single_output_layer( std::unordered_map const &inputs, std::unordered_map const &weights, std::optional> const - &outputs = std::nullopt) { - - return get_single_output(add_parallel_layer(pcg, layer_attrs, inputs, weights, outputs)); + &outputs = std::nullopt) { + + return get_single_output( + add_parallel_layer(pcg, layer_attrs, inputs, weights, outputs)); } -parallel_tensor_guid_t add_input_layer( - ParallelComputationGraph &pcg, - TensorShape 
const &tensor_shape) { +parallel_tensor_guid_t add_input_layer(ParallelComputationGraph &pcg, + TensorShape const &tensor_shape) { return get_single_output(pcg_add_input_layer(pcg, tensor_shape)); } -parallel_tensor_guid_t add_weight_layer( - ParallelComputationGraph &pcg, - TensorShape const &tensor_shape) { +parallel_tensor_guid_t add_weight_layer(ParallelComputationGraph &pcg, + TensorShape const &tensor_shape) { WeightAttrs weight_attrs = WeightAttrs{ /*tensor_shape=*/tensor_shape, @@ -67,42 +67,41 @@ parallel_tensor_guid_t add_weight_layer( return add_single_output_layer(pcg, make_layer_attrs(weight_attrs), {}, {}); } -parallel_tensor_guid_t add_replicate_layer( - ParallelComputationGraph &pcg, - positive_int degree, - parallel_tensor_guid_t const &t_input) { +parallel_tensor_guid_t + add_replicate_layer(ParallelComputationGraph &pcg, + positive_int degree, + parallel_tensor_guid_t const &t_input) { ReplicateAttrs replicate_attrs = ReplicateAttrs{ /*replicate_degree=*/degree, }; return add_single_output_layer(pcg, - make_layer_attrs(replicate_attrs), - {{TensorSlotName::INPUT, t_input}}, - {}); + make_layer_attrs(replicate_attrs), + {{TensorSlotName::INPUT, t_input}}, + {}); } -parallel_tensor_guid_t add_reduction_layer( - ParallelComputationGraph &pcg, - positive_int degree, - parallel_tensor_guid_t const &t_input) { +parallel_tensor_guid_t + add_reduction_layer(ParallelComputationGraph &pcg, + positive_int degree, + parallel_tensor_guid_t const &t_input) { ReductionAttrs reduction_attrs = ReductionAttrs{ /*reduction_degree=*/degree, }; return add_single_output_layer(pcg, - make_layer_attrs(reduction_attrs), - {{TensorSlotName::INPUT, t_input}}, - {}); + make_layer_attrs(reduction_attrs), + {{TensorSlotName::INPUT, t_input}}, + {}); } - - -parallel_tensor_guid_t add_partition_layer( - ParallelComputationGraph &pcg, - ff_dim_t dim, - positive_int degree, - parallel_tensor_guid_t const &t_input) { + +parallel_tensor_guid_t + 
add_partition_layer(ParallelComputationGraph &pcg, + ff_dim_t dim, + positive_int degree, + parallel_tensor_guid_t const &t_input) { RepartitionAttrs partition_attrs = RepartitionAttrs{ /*repartition_dim=*/dim, @@ -110,16 +109,16 @@ parallel_tensor_guid_t add_partition_layer( }; return add_single_output_layer(pcg, - make_layer_attrs(partition_attrs), - {{TensorSlotName::INPUT, t_input}}, - {}); + make_layer_attrs(partition_attrs), + {{TensorSlotName::INPUT, t_input}}, + {}); } - -parallel_tensor_guid_t add_combine_layer( - ParallelComputationGraph &pcg, - ff_dim_t dim, - positive_int degree, - parallel_tensor_guid_t const &t_input) { + +parallel_tensor_guid_t + add_combine_layer(ParallelComputationGraph &pcg, + ff_dim_t dim, + positive_int degree, + parallel_tensor_guid_t const &t_input) { CombineAttrs partition_attrs = CombineAttrs{ /*combine_dim=*/dim, @@ -127,13 +126,13 @@ parallel_tensor_guid_t add_combine_layer( }; return add_single_output_layer(pcg, - make_layer_attrs(partition_attrs), - {{TensorSlotName::INPUT, t_input}}, - {}); + make_layer_attrs(partition_attrs), + {{TensorSlotName::INPUT, t_input}}, + {}); } - + parallel_tensor_guid_t add_linear_layer( - ParallelComputationGraph &pcg, + ParallelComputationGraph &pcg, LinearAttrs const &linear_attrs, parallel_tensor_guid_t const &t_input, parallel_tensor_guid_t const &t_weight, @@ -143,7 +142,7 @@ parallel_tensor_guid_t add_linear_layer( ASSERT(t_bias.has_value() == linear_attrs.use_bias); std::unordered_map weights = { - {TensorSlotName::WEIGHT, t_weight}, + {TensorSlotName::WEIGHT, t_weight}, }; if (t_bias.has_value()) { @@ -151,34 +150,32 @@ parallel_tensor_guid_t add_linear_layer( } return add_single_output_layer(pcg, - make_layer_attrs(linear_attrs, name), - {{TensorSlotName::INPUT, t_input}}, - weights); + make_layer_attrs(linear_attrs, name), + {{TensorSlotName::INPUT, t_input}}, + weights); } -parallel_tensor_guid_t add_attention_layer( - ParallelComputationGraph &pcg, - MultiHeadAttentionAttrs 
const &attn_attrs, - parallel_tensor_guid_t const &t_query, - parallel_tensor_guid_t const &t_key, - parallel_tensor_guid_t const &t_value, - parallel_tensor_guid_t const &t_weights, - std::optional const &name = std::nullopt) { +parallel_tensor_guid_t + add_attention_layer(ParallelComputationGraph &pcg, + MultiHeadAttentionAttrs const &attn_attrs, + parallel_tensor_guid_t const &t_query, + parallel_tensor_guid_t const &t_key, + parallel_tensor_guid_t const &t_value, + parallel_tensor_guid_t const &t_weights, + std::optional const &name = std::nullopt) { return add_single_output_layer(pcg, - make_layer_attrs(attn_attrs, name), - { - {TensorSlotName::QUERY, t_query}, - {TensorSlotName::KEY, t_key}, - {TensorSlotName::VALUE, t_value}, - }, - {{TensorSlotName::WEIGHT, t_weights}}); + make_layer_attrs(attn_attrs, name), + { + {TensorSlotName::QUERY, t_query}, + {TensorSlotName::KEY, t_key}, + {TensorSlotName::VALUE, t_value}, + }, + {{TensorSlotName::WEIGHT, t_weights}}); } - - parallel_tensor_guid_t add_conv2d_layer( - ParallelComputationGraph &pcg, + ParallelComputationGraph &pcg, Conv2DAttrs const &conv2d_attrs, parallel_tensor_guid_t const &t_input, parallel_tensor_guid_t const &t_filter, @@ -188,7 +185,7 @@ parallel_tensor_guid_t add_conv2d_layer( ASSERT(bias.has_value() == conv2d_attrs.use_bias); std::unordered_map weights = { - {TensorSlotName::FILTER, t_filter}, + {TensorSlotName::FILTER, t_filter}, }; if (bias.has_value()) { @@ -196,12 +193,10 @@ parallel_tensor_guid_t add_conv2d_layer( } return add_single_output_layer(pcg, - make_layer_attrs(conv2d_attrs, name), - {{TensorSlotName::INPUT, t_input}}, - weights); + make_layer_attrs(conv2d_attrs, name), + {{TensorSlotName::INPUT, t_input}}, + weights); } - - TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_substitution_set") { @@ -245,26 +240,31 @@ TEST_SUITE(FF_TEST_SUITE) { /*replicate_degree=*/degree, }; - TensorShape projection_weight_shape = throw_if_unexpected( - get_projection_shape(linear_attrs, input_shape)); 
+ TensorShape projection_weight_shape = + throw_if_unexpected(get_projection_shape(linear_attrs, input_shape)); RepartitionAttrs partition_projection_attrs = RepartitionAttrs{ /*repartition_dim=*/ff_dim_t{1_n}, /*repartition_degree=*/degree, }; - ff_dim_t combine_dim = ff_dim_t{ - nonnegative_int{num_dims.int_from_positive_int() - 1}}; + ff_dim_t combine_dim = + ff_dim_t{nonnegative_int{num_dims.int_from_positive_int() - 1}}; SubParallelComputationGraph original_pcg = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - parallel_tensor_guid_t t_input = - add_input_layer(pcg, input_shape); + parallel_tensor_guid_t t_input = add_input_layer(pcg, input_shape); - parallel_tensor_guid_t t_projection_weight = add_weight_layer(pcg, projection_weight_shape); + parallel_tensor_guid_t t_projection_weight = + add_weight_layer(pcg, projection_weight_shape); - parallel_tensor_guid_t t_linear = add_linear_layer(pcg, linear_attrs, t_input, t_projection_weight, /*bias=*/std::nullopt, linear_match); + parallel_tensor_guid_t t_linear = add_linear_layer(pcg, + linear_attrs, + t_input, + t_projection_weight, + /*bias=*/std::nullopt, + linear_match); return sub_pcg_from_full_pcg(pcg); }(); @@ -275,7 +275,8 @@ TEST_SUITE(FF_TEST_SUITE) { open_parallel_tensor_guid_t match_layer_input_activations = get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::INPUT); open_parallel_tensor_guid_t match_layer_input_weights = - get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::WEIGHT); + get_layer_inputs(original_pcg, match_layer) + .at(TensorSlotName::WEIGHT); return PCGPatternMatch{ bidict{ @@ -299,16 +300,23 @@ TEST_SUITE(FF_TEST_SUITE) { SubParallelComputationGraph correct = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - parallel_tensor_guid_t t_replicated_input = - add_replicate_layer(pcg, degree, add_input_layer(pcg, input_shape)); + parallel_tensor_guid_t t_replicated_input = + add_replicate_layer(pcg, degree, 
add_input_layer(pcg, input_shape)); - parallel_tensor_guid_t t_partitioned_projection_weight = - add_partition_layer(pcg, ff_dim_t{1_n}, degree, add_weight_layer(pcg, projection_weight_shape)); + parallel_tensor_guid_t t_partitioned_projection_weight = + add_partition_layer(pcg, + ff_dim_t{1_n}, + degree, + add_weight_layer(pcg, projection_weight_shape)); parallel_tensor_guid_t t_replicated_linear = - add_linear_layer(pcg, linear_attrs, t_replicated_input, t_partitioned_projection_weight); + add_linear_layer(pcg, + linear_attrs, + t_replicated_input, + t_partitioned_projection_weight); - parallel_tensor_guid_t t_combine = add_combine_layer(pcg, combine_dim, degree, t_replicated_input); + parallel_tensor_guid_t t_combine = + add_combine_layer(pcg, combine_dim, degree, t_replicated_input); return sub_pcg_from_full_pcg(pcg); }(); @@ -341,26 +349,27 @@ TEST_SUITE(FF_TEST_SUITE) { /*regularizer=*/std::nullopt, }; - TensorShape projection_weight_shape = throw_if_unexpected(get_projection_shape(linear_attrs, input_shape)); + TensorShape projection_weight_shape = + throw_if_unexpected(get_projection_shape(linear_attrs, input_shape)); - TensorShape bias_shape = throw_if_unexpected(get_bias_shape(linear_attrs, input_shape)); + TensorShape bias_shape = + throw_if_unexpected(get_bias_shape(linear_attrs, input_shape)); - ff_dim_t combine_dim = ff_dim_t{nonnegative_int{num_dims.int_from_positive_int() - 1}}; + ff_dim_t combine_dim = + ff_dim_t{nonnegative_int{num_dims.int_from_positive_int() - 1}}; SubParallelComputationGraph original_pcg = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - parallel_tensor_guid_t t_input = - add_input_layer(pcg, input_shape); + parallel_tensor_guid_t t_input = add_input_layer(pcg, input_shape); parallel_tensor_guid_t t_projection_weight = - add_weight_layer(pcg, projection_weight_shape); + add_weight_layer(pcg, projection_weight_shape); - parallel_tensor_guid_t t_bias = - add_weight_layer(pcg, bias_shape); + 
parallel_tensor_guid_t t_bias = add_weight_layer(pcg, bias_shape); - parallel_tensor_guid_t t_linear = - add_linear_layer(pcg, linear_attrs, t_input, t_projection_weight, t_bias); + parallel_tensor_guid_t t_linear = add_linear_layer( + pcg, linear_attrs, t_input, t_projection_weight, t_bias); return sub_pcg_from_full_pcg(pcg); }(); @@ -371,9 +380,11 @@ TEST_SUITE(FF_TEST_SUITE) { open_parallel_tensor_guid_t match_layer_input_activations = get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::INPUT); open_parallel_tensor_guid_t match_layer_input_weights = - get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::WEIGHT); + get_layer_inputs(original_pcg, match_layer) + .at(TensorSlotName::WEIGHT); open_parallel_tensor_guid_t match_layer_input_bias = - get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::OUTPUT); + get_layer_inputs(original_pcg, match_layer) + .at(TensorSlotName::OUTPUT); return PCGPatternMatch{ bidict{ @@ -402,19 +413,26 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelComputationGraph pcg = empty_parallel_computation_graph(); parallel_tensor_guid_t t_replicated_input = - add_replicate_layer(pcg, degree, add_input_layer(pcg, input_shape)); + add_replicate_layer(pcg, degree, add_input_layer(pcg, input_shape)); parallel_tensor_guid_t t_partitioned_projection_weight = - add_partition_layer(pcg, ff_dim_t{1_n}, degree, add_weight_layer(pcg, projection_weight_shape)); + add_partition_layer(pcg, + ff_dim_t{1_n}, + degree, + add_weight_layer(pcg, projection_weight_shape)); - parallel_tensor_guid_t t_partitioned_bias = - add_partition_layer(pcg, ff_dim_t{1_n}, degree, add_weight_layer(pcg, bias_shape)); + parallel_tensor_guid_t t_partitioned_bias = add_partition_layer( + pcg, ff_dim_t{1_n}, degree, add_weight_layer(pcg, bias_shape)); - parallel_tensor_guid_t t_replicated_linear = - add_linear_layer(pcg, linear_attrs, t_replicated_linear, t_partitioned_projection_weight, t_partitioned_bias); + parallel_tensor_guid_t t_replicated_linear = + 
add_linear_layer(pcg, + linear_attrs, + t_replicated_linear, + t_partitioned_projection_weight, + t_partitioned_bias); - parallel_tensor_guid_t t_combine = - add_combine_layer(pcg, combine_dim, degree, t_replicated_linear); + parallel_tensor_guid_t t_combine = + add_combine_layer(pcg, combine_dim, degree, t_replicated_linear); return sub_pcg_from_full_pcg(pcg); }(); @@ -447,21 +465,26 @@ TEST_SUITE(FF_TEST_SUITE) { /*regularizer=*/std::nullopt, }; - TensorShape projection_weight_shape = throw_if_unexpected(get_projection_shape(linear_attrs, input_shape)); + TensorShape projection_weight_shape = + throw_if_unexpected(get_projection_shape(linear_attrs, input_shape)); - ff_dim_t combine_dim = ff_dim_t{nonnegative_int{num_dims.int_from_positive_int() - 1}}; + ff_dim_t combine_dim = + ff_dim_t{nonnegative_int{num_dims.int_from_positive_int() - 1}}; SubParallelComputationGraph original_pcg = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - parallel_tensor_guid_t t_input = - add_input_layer(pcg, input_shape); + parallel_tensor_guid_t t_input = add_input_layer(pcg, input_shape); - parallel_tensor_guid_t t_projection_weight = + parallel_tensor_guid_t t_projection_weight = add_weight_layer(pcg, projection_weight_shape); - parallel_tensor_guid_t t_linear = - add_linear_layer(pcg, linear_attrs, t_input, t_projection_weight, /*bias=*/std::nullopt, linear_match); + parallel_tensor_guid_t t_linear = add_linear_layer(pcg, + linear_attrs, + t_input, + t_projection_weight, + /*bias=*/std::nullopt, + linear_match); return sub_pcg_from_full_pcg(pcg); }(); @@ -472,7 +495,8 @@ TEST_SUITE(FF_TEST_SUITE) { open_parallel_tensor_guid_t match_layer_input_activations = get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::INPUT); open_parallel_tensor_guid_t match_layer_input_weights = - get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::WEIGHT); + get_layer_inputs(original_pcg, match_layer) + .at(TensorSlotName::WEIGHT); return PCGPatternMatch{ 
bidict{ @@ -496,18 +520,23 @@ TEST_SUITE(FF_TEST_SUITE) { SubParallelComputationGraph correct = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - parallel_tensor_guid_t t_input = - add_input_layer(pcg, input_shape); + parallel_tensor_guid_t t_input = add_input_layer(pcg, input_shape); - parallel_tensor_guid_t t_partitioned_input = add_partition_layer(pcg, ff_dim_t{0_n}, degree, add_input_layer(pcg, input_shape)); + parallel_tensor_guid_t t_partitioned_input = add_partition_layer( + pcg, ff_dim_t{0_n}, degree, add_input_layer(pcg, input_shape)); - parallel_tensor_guid_t t_replicated_projection_weight = - add_replicate_layer(pcg, degree, add_weight_layer(pcg, projection_weight_shape)); + parallel_tensor_guid_t t_replicated_projection_weight = + add_replicate_layer( + pcg, degree, add_weight_layer(pcg, projection_weight_shape)); parallel_tensor_guid_t t_partitioned_linear = - add_linear_layer(pcg, linear_attrs, t_partitioned_input, t_replicated_projection_weight); + add_linear_layer(pcg, + linear_attrs, + t_partitioned_input, + t_replicated_projection_weight); - parallel_tensor_guid_t t_combine = add_combine_layer(pcg, combine_dim, degree, t_partitioned_input); + parallel_tensor_guid_t t_combine = + add_combine_layer(pcg, combine_dim, degree, t_partitioned_input); return sub_pcg_from_full_pcg(pcg); }(); @@ -540,20 +569,29 @@ TEST_SUITE(FF_TEST_SUITE) { /*regularizer=*/std::nullopt, }; - TensorShape projection_weight_shape = throw_if_unexpected(get_projection_shape(linear_attrs, input_shape)); + TensorShape projection_weight_shape = + throw_if_unexpected(get_projection_shape(linear_attrs, input_shape)); - TensorShape bias_shape = throw_if_unexpected(get_bias_shape(linear_attrs, input_shape)); + TensorShape bias_shape = + throw_if_unexpected(get_bias_shape(linear_attrs, input_shape)); - ff_dim_t combine_dim = ff_dim_t{nonnegative_int{num_dims.int_from_positive_int() - 1}}; + ff_dim_t combine_dim = + 
ff_dim_t{nonnegative_int{num_dims.int_from_positive_int() - 1}}; SubParallelComputationGraph original_pcg = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); parallel_tensor_guid_t t_input = add_input_layer(pcg, input_shape); - parallel_tensor_guid_t t_projection_weight = add_weight_layer(pcg, projection_weight_shape); + parallel_tensor_guid_t t_projection_weight = + add_weight_layer(pcg, projection_weight_shape); parallel_tensor_guid_t t_bias = add_weight_layer(pcg, bias_shape); - parallel_tensor_guid_t t_linear = add_linear_layer(pcg, linear_attrs, t_input, t_projection_weight, t_bias, linear_match); + parallel_tensor_guid_t t_linear = add_linear_layer(pcg, + linear_attrs, + t_input, + t_projection_weight, + t_bias, + linear_match); return sub_pcg_from_full_pcg(pcg); }(); @@ -565,7 +603,8 @@ TEST_SUITE(FF_TEST_SUITE) { open_parallel_tensor_guid_t match_layer_input_activations = get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::INPUT); open_parallel_tensor_guid_t match_layer_input_weights = - get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::WEIGHT); + get_layer_inputs(original_pcg, match_layer) + .at(TensorSlotName::WEIGHT); open_parallel_tensor_guid_t match_layer_input_bias = get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::BIAS); @@ -589,21 +628,31 @@ TEST_SUITE(FF_TEST_SUITE) { }; }(); - SubParallelComputationGraph result = apply_substitution(original_pcg, sub, match); + SubParallelComputationGraph result = + apply_substitution(original_pcg, sub, match); SubParallelComputationGraph correct = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - parallel_tensor_guid_t t_partitioned_input = - add_partition_layer(pcg, ff_dim_t{0_n}, degree, add_input_layer(pcg, input_shape)); + parallel_tensor_guid_t t_partitioned_input = add_partition_layer( + pcg, ff_dim_t{0_n}, degree, add_input_layer(pcg, input_shape)); - parallel_tensor_guid_t t_replicated_projection_weight = 
add_replicate_layer(pcg, degree, add_weight_layer(pcg, projection_weight_shape)); + parallel_tensor_guid_t t_replicated_projection_weight = + add_replicate_layer( + pcg, degree, add_weight_layer(pcg, projection_weight_shape)); - parallel_tensor_guid_t t_replicated_bias = add_replicate_layer(pcg, degree, add_weight_layer(pcg, bias_shape)); + parallel_tensor_guid_t t_replicated_bias = + add_replicate_layer(pcg, degree, add_weight_layer(pcg, bias_shape)); - parallel_tensor_guid_t t_partitioned_linear = add_linear_layer(pcg, linear_attrs, t_partitioned_input, t_replicated_projection_weight, t_replicated_bias); + parallel_tensor_guid_t t_partitioned_linear = + add_linear_layer(pcg, + linear_attrs, + t_partitioned_input, + t_replicated_projection_weight, + t_replicated_bias); - parallel_tensor_guid_t t_combine = add_combine_layer(pcg, combine_dim, degree, t_partitioned_linear); + parallel_tensor_guid_t t_combine = + add_combine_layer(pcg, combine_dim, degree, t_partitioned_linear); return sub_pcg_from_full_pcg(pcg); }(); @@ -638,32 +687,39 @@ TEST_SUITE(FF_TEST_SUITE) { }; Conv2DAttrs conv2d_attrs = Conv2DAttrs{ - /*outChannels=*/outChannels, - /*kernelH=*/kernelH, - /*kernelW=*/kernelW, - /*strideH=*/strideH, - /*strideW=*/strideW, - /*paddingH=*/paddingH, - /*paddingW=*/paddingW, - /*groups=*/1_p, - /*activation=*/std::nullopt, - /*use_bias=*/false, + /*outChannels=*/outChannels, + /*kernelH=*/kernelH, + /*kernelW=*/kernelW, + /*strideH=*/strideH, + /*strideW=*/strideW, + /*paddingH=*/paddingH, + /*paddingW=*/paddingW, + /*groups=*/1_p, + /*activation=*/std::nullopt, + /*use_bias=*/false, }; SubParallelComputationGraph original_pcg = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - parallel_tensor_guid_t t_input = - add_input_layer(pcg, input_shape); + parallel_tensor_guid_t t_input = add_input_layer(pcg, input_shape); TensorShape casted_input_shape = get_reduced_shape(get_parallel_tensor_shape(pcg, t_input)); - TensorShape 
projection_weight_shape = get_weight_shapes(conv2d_attrs, casted_input_shape).at(TensorSlotName::FILTER); + TensorShape projection_weight_shape = + get_weight_shapes(conv2d_attrs, casted_input_shape) + .at(TensorSlotName::FILTER); - parallel_tensor_guid_t t_projection_weight = add_weight_layer(pcg, projection_weight_shape); + parallel_tensor_guid_t t_projection_weight = + add_weight_layer(pcg, projection_weight_shape); - parallel_tensor_guid_t t_conv = add_conv2d_layer(pcg, conv2d_attrs, t_input, t_projection_weight, /*bias=*/std::nullopt, conv2d_match); + parallel_tensor_guid_t t_conv = add_conv2d_layer(pcg, + conv2d_attrs, + t_input, + t_projection_weight, + /*bias=*/std::nullopt, + conv2d_match); return sub_pcg_from_full_pcg(pcg); }(); @@ -674,7 +730,8 @@ TEST_SUITE(FF_TEST_SUITE) { open_parallel_tensor_guid_t match_layer_input_activations = get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::INPUT); open_parallel_tensor_guid_t match_layer_input_weights = - get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::FILTER); + get_layer_inputs(original_pcg, match_layer) + .at(TensorSlotName::FILTER); return PCGPatternMatch{ bidict{ @@ -699,18 +756,24 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelComputationGraph pcg = empty_parallel_computation_graph(); parallel_tensor_guid_t t_input = add_input_layer(pcg, input_shape); - parallel_tensor_guid_t t_partitioned_input = add_partition_layer(pcg, ff_dim_t{0_n}, degree, t_input); + parallel_tensor_guid_t t_partitioned_input = + add_partition_layer(pcg, ff_dim_t{0_n}, degree, t_input); TensorShape casted_input_shape = get_reduced_shape(get_parallel_tensor_shape(pcg, t_input)); - TensorShape weight_shape = get_weight_shapes(conv2d_attrs, casted_input_shape).at(TensorSlotName::FILTER); + TensorShape weight_shape = + get_weight_shapes(conv2d_attrs, casted_input_shape) + .at(TensorSlotName::FILTER); - parallel_tensor_guid_t t_replicated_weight = add_replicate_layer(pcg, degree, add_weight_layer(pcg, weight_shape)); + 
parallel_tensor_guid_t t_replicated_weight = + add_replicate_layer(pcg, degree, add_weight_layer(pcg, weight_shape)); - parallel_tensor_guid_t t_partitioned_conv2d = add_conv2d_layer(pcg, conv2d_attrs, t_partitioned_input, t_replicated_weight); + parallel_tensor_guid_t t_partitioned_conv2d = add_conv2d_layer( + pcg, conv2d_attrs, t_partitioned_input, t_replicated_weight); - parallel_tensor_guid_t t_combine = add_combine_layer(pcg, ff_dim_t{0_n}, degree, t_partitioned_conv2d); + parallel_tensor_guid_t t_combine = + add_combine_layer(pcg, ff_dim_t{0_n}, degree, t_partitioned_conv2d); return sub_pcg_from_full_pcg(pcg); }(); @@ -750,7 +813,8 @@ TEST_SUITE(FF_TEST_SUITE) { /*add_zero_attn=*/false, }; - TensorShape weights_shape = throw_if_unexpected(get_weights_shape(attention_attrs, query_shape, key_shape, value_shape)); + TensorShape weights_shape = throw_if_unexpected(get_weights_shape( + attention_attrs, query_shape, key_shape, value_shape)); SubParallelComputationGraph original_pcg = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); @@ -761,7 +825,13 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t t_weights = add_weight_layer(pcg, weights_shape); - parallel_tensor_guid_t t_attention = add_attention_layer(pcg, attention_attrs, t_query, t_key, t_value, t_weights, attention_match); + parallel_tensor_guid_t t_attention = add_attention_layer(pcg, + attention_attrs, + t_query, + t_key, + t_value, + t_weights, + attention_match); return sub_pcg_from_full_pcg(pcg); }(); @@ -776,7 +846,8 @@ TEST_SUITE(FF_TEST_SUITE) { open_parallel_tensor_guid_t match_layer_value = get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::VALUE); open_parallel_tensor_guid_t match_layer_input_weights = - get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::WEIGHT); + get_layer_inputs(original_pcg, match_layer) + .at(TensorSlotName::WEIGHT); return PCGPatternMatch{ bidict{ @@ -808,17 +879,21 @@ TEST_SUITE(FF_TEST_SUITE) { SubParallelComputationGraph 
correct = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - parallel_tensor_guid_t t_query = add_partition_layer(pcg, ff_dim_t{0_n}, degree, add_input_layer(pcg, query_shape)); - parallel_tensor_guid_t t_key = add_partition_layer(pcg, ff_dim_t{0_n}, degree, add_input_layer(pcg, key_shape)); - parallel_tensor_guid_t t_value = add_partition_layer(pcg, ff_dim_t{0_n}, degree, add_input_layer(pcg, value_shape)); + parallel_tensor_guid_t t_query = add_partition_layer( + pcg, ff_dim_t{0_n}, degree, add_input_layer(pcg, query_shape)); + parallel_tensor_guid_t t_key = add_partition_layer( + pcg, ff_dim_t{0_n}, degree, add_input_layer(pcg, key_shape)); + parallel_tensor_guid_t t_value = add_partition_layer( + pcg, ff_dim_t{0_n}, degree, add_input_layer(pcg, value_shape)); - parallel_tensor_guid_t t_weight = add_replicate_layer(pcg, degree, add_weight_layer(pcg, weights_shape)); + parallel_tensor_guid_t t_weight = add_replicate_layer( + pcg, degree, add_weight_layer(pcg, weights_shape)); - + parallel_tensor_guid_t t_partitioned_attention = add_attention_layer( + pcg, attention_attrs, t_query, t_key, t_value, t_weight); - parallel_tensor_guid_t t_partitioned_attention = add_attention_layer(pcg, attention_attrs, t_query, t_key, t_value, t_weight); - - parallel_tensor_guid_t t_combine = add_combine_layer(pcg, ff_dim_t{0_n}, degree, t_partitioned_attention); + parallel_tensor_guid_t t_combine = add_combine_layer( + pcg, ff_dim_t{0_n}, degree, t_partitioned_attention); return sub_pcg_from_full_pcg(pcg); }(); @@ -858,7 +933,8 @@ TEST_SUITE(FF_TEST_SUITE) { /*add_zero_attn=*/false, }; - TensorShape weight_shape = throw_if_unexpected(get_weights_shape(attention_attrs, query_shape, key_shape, value_shape)); + TensorShape weight_shape = throw_if_unexpected(get_weights_shape( + attention_attrs, query_shape, key_shape, value_shape)); SubParallelComputationGraph original_pcg = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); @@ -869,7 
+945,14 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t t_weight = add_weight_layer(pcg, weight_shape); - parallel_tensor_guid_t attention_added = add_attention_layer(pcg, attention_attrs, t_query, t_key, t_value, t_weight, attention_match); + parallel_tensor_guid_t attention_added = + add_attention_layer(pcg, + attention_attrs, + t_query, + t_key, + t_value, + t_weight, + attention_match); return sub_pcg_from_full_pcg(pcg); }(); @@ -884,7 +967,8 @@ TEST_SUITE(FF_TEST_SUITE) { open_parallel_tensor_guid_t match_layer_value = get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::VALUE); open_parallel_tensor_guid_t match_layer_input_weights = - get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::WEIGHT); + get_layer_inputs(original_pcg, match_layer) + .at(TensorSlotName::WEIGHT); return PCGPatternMatch{ bidict{ @@ -916,15 +1000,21 @@ TEST_SUITE(FF_TEST_SUITE) { SubParallelComputationGraph correct = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - parallel_tensor_guid_t t_query = add_replicate_layer(pcg, degree, add_input_layer(pcg, query_shape)); - parallel_tensor_guid_t t_key = add_replicate_layer(pcg, degree, add_input_layer(pcg, key_shape)); - parallel_tensor_guid_t t_value = add_replicate_layer(pcg, degree, add_input_layer(pcg, value_shape)); + parallel_tensor_guid_t t_query = + add_replicate_layer(pcg, degree, add_input_layer(pcg, query_shape)); + parallel_tensor_guid_t t_key = + add_replicate_layer(pcg, degree, add_input_layer(pcg, key_shape)); + parallel_tensor_guid_t t_value = + add_replicate_layer(pcg, degree, add_input_layer(pcg, value_shape)); - parallel_tensor_guid_t t_weight = add_partition_layer(pcg, ff_dim_t{1_n}, degree, add_weight_layer(pcg, weight_shape)); + parallel_tensor_guid_t t_weight = add_partition_layer( + pcg, ff_dim_t{1_n}, degree, add_weight_layer(pcg, weight_shape)); - parallel_tensor_guid_t t_replicated_attention = add_attention_layer(pcg, attention_attrs, t_query, t_key, t_value, 
t_weight); + parallel_tensor_guid_t t_replicated_attention = add_attention_layer( + pcg, attention_attrs, t_query, t_key, t_value, t_weight); - parallel_tensor_guid_t t_reduction = add_reduction_layer(pcg, degree, t_replicated_attention); + parallel_tensor_guid_t t_reduction = + add_reduction_layer(pcg, degree, t_replicated_attention); return sub_pcg_from_full_pcg(pcg); }(); @@ -961,7 +1051,10 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t t_input = add_input_layer(pcg, input_shape); parallel_tensor_guid_t t_softmax = add_single_output_layer( - pcg, make_layer_attrs(softmax_attrs, softmax_match), {{TensorSlotName::INPUT, t_input}}, {}); + pcg, + make_layer_attrs(softmax_attrs, softmax_match), + {{TensorSlotName::INPUT, t_input}}, + {}); return sub_pcg_from_full_pcg(pcg); }(); @@ -989,12 +1082,17 @@ TEST_SUITE(FF_TEST_SUITE) { SubParallelComputationGraph correct = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - parallel_tensor_guid_t t_partitioned_input = add_partition_layer(pcg, partition_dim, degree, add_input_layer(pcg, input_shape)); + parallel_tensor_guid_t t_partitioned_input = add_partition_layer( + pcg, partition_dim, degree, add_input_layer(pcg, input_shape)); parallel_tensor_guid_t t_partitioned_softmax = add_single_output_layer( - pcg, make_layer_attrs(softmax_attrs), {{TensorSlotName::INPUT, t_partitioned_input}}, {}); + pcg, + make_layer_attrs(softmax_attrs), + {{TensorSlotName::INPUT, t_partitioned_input}}, + {}); - parallel_tensor_guid_t t_combine = add_combine_layer(pcg, partition_dim, degree, t_partitioned_softmax); + parallel_tensor_guid_t t_combine = + add_combine_layer(pcg, partition_dim, degree, t_partitioned_softmax); return sub_pcg_from_full_pcg(pcg); }(); @@ -1034,8 +1132,14 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t t_lhs = add_input_layer(pcg, lhs_shape); parallel_tensor_guid_t t_rhs = add_input_layer(pcg, rhs_shape); - parallel_tensor_guid_t t_add = add_single_output_layer( - pcg, 
make_layer_attrs(add_attrs, add_match), {{TensorSlotName::LHS_INPUT, t_lhs}, {TensorSlotName::RHS_INPUT, t_rhs},}, {}); + parallel_tensor_guid_t t_add = + add_single_output_layer(pcg, + make_layer_attrs(add_attrs, add_match), + { + {TensorSlotName::LHS_INPUT, t_lhs}, + {TensorSlotName::RHS_INPUT, t_rhs}, + }, + {}); return sub_pcg_from_full_pcg(pcg); }(); @@ -1044,9 +1148,11 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t match_layer = get_parallel_layer_by_name(original_pcg, add_match); open_parallel_tensor_guid_t add_match_layer_lhs = - get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::LHS_INPUT); + get_layer_inputs(original_pcg, match_layer) + .at(TensorSlotName::LHS_INPUT); open_parallel_tensor_guid_t add_match_layer_rhs = - get_layer_inputs(original_pcg, match_layer).at(TensorSlotName::RHS_INPUT); + get_layer_inputs(original_pcg, match_layer) + .at(TensorSlotName::RHS_INPUT); return PCGPatternMatch{ bidict{ @@ -1070,19 +1176,22 @@ TEST_SUITE(FF_TEST_SUITE) { SubParallelComputationGraph correct = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - parallel_tensor_guid_t t_lhs = add_partition_layer(pcg, parallel_dim, degree, add_input_layer(pcg, lhs_shape)); - parallel_tensor_guid_t t_rhs = add_partition_layer(pcg, parallel_dim, degree, add_input_layer(pcg, rhs_shape)); + parallel_tensor_guid_t t_lhs = add_partition_layer( + pcg, parallel_dim, degree, add_input_layer(pcg, lhs_shape)); + parallel_tensor_guid_t t_rhs = add_partition_layer( + pcg, parallel_dim, degree, add_input_layer(pcg, rhs_shape)); parallel_tensor_guid_t t_partitioned_add = add_single_output_layer(pcg, - make_layer_attrs(add_attrs, add_match), - { - {TensorSlotName::LHS_INPUT, t_lhs}, - {TensorSlotName::RHS_INPUT, t_rhs}, - }, - {}); + make_layer_attrs(add_attrs, add_match), + { + {TensorSlotName::LHS_INPUT, t_lhs}, + {TensorSlotName::RHS_INPUT, t_rhs}, + }, + {}); - parallel_tensor_guid_t t_combine = add_combine_layer(pcg, parallel_dim, degree, 
t_partitioned_add); + parallel_tensor_guid_t t_combine = + add_combine_layer(pcg, parallel_dim, degree, t_partitioned_add); return sub_pcg_from_full_pcg(pcg); }(); @@ -1117,8 +1226,11 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t t_input = add_input_layer(pcg, input_shape); - parallel_tensor_guid_t t_relu = add_single_output_layer( - pcg, make_layer_attrs(relu_attrs, relu_match), {{TensorSlotName::INPUT, t_input}}, {}); + parallel_tensor_guid_t t_relu = + add_single_output_layer(pcg, + make_layer_attrs(relu_attrs, relu_match), + {{TensorSlotName::INPUT, t_input}}, + {}); return sub_pcg_from_full_pcg(pcg); }(); @@ -1146,12 +1258,17 @@ TEST_SUITE(FF_TEST_SUITE) { SubParallelComputationGraph correct = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - parallel_tensor_guid_t t_input = add_partition_layer(pcg, parallel_dim, degree, add_input_layer(pcg, input_shape)); + parallel_tensor_guid_t t_input = add_partition_layer( + pcg, parallel_dim, degree, add_input_layer(pcg, input_shape)); - parallel_tensor_guid_t t_relu = add_single_output_layer( - pcg, make_layer_attrs(relu_attrs), {{TensorSlotName::INPUT, t_input}}, {}); + parallel_tensor_guid_t t_relu = + add_single_output_layer(pcg, + make_layer_attrs(relu_attrs), + {{TensorSlotName::INPUT, t_input}}, + {}); - parallel_tensor_guid_t t_combine = add_combine_layer(pcg, parallel_dim, degree, t_relu); + parallel_tensor_guid_t t_combine = + add_combine_layer(pcg, parallel_dim, degree, t_relu); return sub_pcg_from_full_pcg(pcg); }(); diff --git a/lib/utils/src/utils/positive_int/positive_range.cc b/lib/utils/src/utils/positive_int/positive_range.cc index 8a31ea0505..bb52f0b4d9 100644 --- a/lib/utils/src/utils/positive_int/positive_range.cc +++ b/lib/utils/src/utils/positive_int/positive_range.cc @@ -1,6 +1,6 @@ #include "utils/positive_int/positive_range.h" -#include "utils/containers/transform.h" #include "utils/containers/range.h" +#include "utils/containers/transform.h" namespace FlexFlow {