diff --git a/KLR/Core/Basic.lean b/KLR/Core/Basic.lean index 3c67f158..d9938f99 100644 --- a/KLR/Core/Basic.lean +++ b/KLR/Core/Basic.lean @@ -189,6 +189,7 @@ partial def operatorBasicTensors : Operator → List TensorRef | .devicePrint t => [t.src] | .exponential e => [e.dst, e.src] | .activate2 a => [a.dst, a.src] + | .dveReadAccumulator e => [e.dst] partial def operatorAdditionalTensors : Operator → List TensorName | .ncActivate d => (tensors d.scale) ++ (tensors d.bias) ++ (tensors d.reduceRes) diff --git a/KLR/Core/Operators.lean b/KLR/Core/Operators.lean index 3c9ebcdc..5b80e9fc 100644 --- a/KLR/Core/Operators.lean +++ b/KLR/Core/Operators.lean @@ -1262,6 +1262,17 @@ instance : MapTensorRefs Activate2 where } @[serde tag = 218] +structure DveReadAccumulator where + dst : TensorRef + negated : Bool + deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp + +instance : MapTensorRefs DveReadAccumulator where + mapM ft _ op := do pure { op with + dst := ← ft op.dst + } + +@[serde tag = 219] inductive Operator where | activate (op : Activate) | ncActivate (op : NcActivate) @@ -1339,9 +1350,10 @@ inductive Operator where | devicePrint (op: DevicePrint) | exponential(op: Exponential) | activate2 (op: Activate2) + | dveReadAccumulator (op: DveReadAccumulator) deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp -@[serde tag = 219] +@[serde tag = 220] inductive TGROperator where | activate (op : Activate) | affineSelect (op : AffineSelect) @@ -1451,3 +1463,4 @@ instance : MapTensorRefs Operator where | .devicePrint op => return .devicePrint (← MapTensorRefs.mapM ft fo op) | .exponential op => return .exponential (← MapTensorRefs.mapM ft fo op) | .activate2 op => return .activate2 (← MapTensorRefs.mapM ft fo op) + | .dveReadAccumulator op => return .dveReadAccumulator (← MapTensorRefs.mapM ft fo op) diff --git a/KLR/Extract/Extract/Basic.lean b/KLR/Extract/Extract/Basic.lean index b0c6bc74..9a9155c4 100644 --- a/KLR/Extract/Extract/Basic.lean +++ b/KLR/Extract/Extract/Basic.lean @@ -376,6 +376,7 @@ def klrAST: MetaM (List LeanType) := do `KLR.Core.DevicePrint, `KLR.Core.Exponential, `KLR.Core.Activate2, + `KLR.Core.DveReadAccumulator, `KLR.Core.Operator, -- Core.Basic `KLR.Core.Stmt, diff --git a/KLR/Trace/ISA.lean b/KLR/Trace/ISA.lean index 3401ebcd..42b803fd 100644 --- a/KLR/Trace/ISA.lean +++ b/KLR/Trace/ISA.lean @@ -1315,3 +1315,13 @@ nki builtin.isa.activate2 dtype := dst.tensor.dtype }) name return .none + +nki builtin.isa.dveReadAccumulator + (dst : Access) + (negated : Bool := false) + (name : Option String := none) := do + Trace.add_stmt $ .oper (.dveReadAccumulator { + dst := .abstract dst, + negated := negated, + }) name + return .none diff --git a/interop/klr/NKI.asdl b/interop/klr/NKI.asdl index ef9da411..4b853496 100644 --- a/interop/klr/NKI.asdl +++ b/interop/klr/NKI.asdl @@ -366,6 +366,8 @@ DevicePrint = (TensorRef src, String printPrefix, PrintOutputBuffer buffer) Exponential = (TensorRef dst, TensorRef src, Operand maxValue, TensorRef? reduceRes, AccumCmd reducecmd, Operand reduceInit) Activate2 = (TensorRef dst, TensorRef src, AluOp op0, AluOp op1, Operand imm0, Operand imm1, ActivationFunc activationFunc, Operand reluParam, AluOp reduceOp, TensorRef? reduceRes, AccumCmd reduceCmd, Bool reverse0, Bool reverse1, Dtype? dtype) + +DveReadAccumulator = (TensorRef dst, Bool negated) Operator = | activate(Activate op) | ncActivate(NcActivate op) @@ -443,6 +445,7 @@ Operator = | devicePrint(DevicePrint op) | exponential(Exponential op) | activate2(Activate2 op) + | dveReadAccumulator(DveReadAccumulator op) Stmt = | oper(Operator op, String? name, Pos pos) diff --git a/interop/klr/klir_ast.hpp b/interop/klr/klir_ast.hpp index baf4be72..2ccb8296 100644 --- a/interop/klr/klir_ast.hpp +++ b/interop/klr/klir_ast.hpp @@ -1087,6 +1087,11 @@ struct Activate2 final { Option dtype; }; +struct DveReadAccumulator final { + Ptr dst; + Bool negated; +}; + struct Operator { enum class Tag { activate = 1, @@ -1165,6 +1170,7 @@ struct Operator { devicePrint, exponential, activate2, + dveReadAccumulator, }; Tag tag; Operator(Tag tag) : tag(tag) {} @@ -1555,6 +1561,11 @@ struct OperatorActivate2Wrapper final : Operator { OperatorActivate2Wrapper() : Operator(Tag::activate2) {} }; +struct OperatorDveReadAccumulatorWrapper final : Operator { + Ptr op; + OperatorDveReadAccumulatorWrapper() : Operator(Tag::dveReadAccumulator) {} +}; + struct Stmt { enum class Tag { oper = 1, diff --git a/interop/klr/klir_pretty_print.cpp b/interop/klr/klir_pretty_print.cpp index 42ce6d40..08e40c59 100644 --- a/interop/klr/klir_pretty_print.cpp +++ b/interop/klr/klir_pretty_print.cpp @@ -3211,6 +3211,18 @@ std::string to_string(Activate2 &Activate2Instance) { return result; }; +std::string to_string(DveReadAccumulator &DveReadAccumulatorInstance) { + std::string result; + result += "DveReadAccumulator("; + result += "dst="; + result += to_string(*(DveReadAccumulatorInstance.dst.get())); + result += ", "; + result += "negated="; + result += std::to_string(DveReadAccumulatorInstance.negated); + result += ")"; + return result; +}; + std::string to_string(OperatorActivateWrapper &OperatorActivateWrapperInstance) { std::string result; @@ -3886,6 +3898,15 @@ to_string(OperatorActivate2Wrapper &OperatorActivate2WrapperInstance) { result += ")"; return result; }; +std::string to_string(OperatorDveReadAccumulatorWrapper + &OperatorDveReadAccumulatorWrapperInstance) { + std::string result; + result += "OperatorDveReadAccumulatorWrapper("; + result += "op="; + result += to_string(*(OperatorDveReadAccumulatorWrapperInstance.op.get())); + result += ")"; + return result; +}; std::string to_string(Operator &OperatorInstance) { switch (OperatorInstance.tag) { case (Operator::Tag::activate): { @@ -4270,6 +4291,11 @@ std::string to_string(Operator &OperatorInstance) { static_cast(OperatorInstance); return to_string(derivedRef); } + case (Operator::Tag::dveReadAccumulator): { + OperatorDveReadAccumulatorWrapper &derivedRef = + static_cast(OperatorInstance); + return to_string(derivedRef); + } default: return "UNABLE TO PRINT"; } diff --git a/interop/klr/klir_pretty_print.hpp b/interop/klr/klir_pretty_print.hpp index 856ffed0..47217ab3 100644 --- a/interop/klr/klir_pretty_print.hpp +++ b/interop/klr/klir_pretty_print.hpp @@ -268,6 +268,8 @@ std::string to_string(Exponential &ExponentialInstance); std::string to_string(Activate2 &Activate2Instance); +std::string to_string(DveReadAccumulator &DveReadAccumulatorInstance); + std::string to_string(OperatorActivateWrapper &OperatorActivateWrapperInstance); std::string to_string(OperatorNcActivateWrapper &OperatorNcActivateWrapperInstance); @@ -403,6 +405,8 @@ std::string to_string(OperatorExponentialWrapper &OperatorExponentialWrapperInstance); std::string to_string(OperatorActivate2Wrapper &OperatorActivate2WrapperInstance); +std::string to_string(OperatorDveReadAccumulatorWrapper + &OperatorDveReadAccumulatorWrapperInstance); std::string to_string(Operator &OperatorInstance); std::string to_string(StmtOperWrapper &StmtOperWrapperInstance); diff --git a/interop/klr/klir_serde.cpp b/interop/klr/klir_serde.cpp index 52eb610c..76149b0d 100644 --- a/interop/klr/klir_serde.cpp +++ b/interop/klr/klir_serde.cpp @@ -2931,6 +2931,16 @@ bool Activate2_ser(FILE *out, const Ptr &value) { return true; } +bool DveReadAccumulator_ser(FILE *out, const Ptr &value) { + if (!serialize_tag(out, 218, 0, 2)) + return false; + if (!TensorRef_ser(out, value->dst)) + return false; + if (!Bool_ser(out, value->negated)) + return false; + return true; +} + bool Operator_ser(FILE *out, const Ptr &value) { u8 tag_val = 0; u8 field_count = 1; // All variants have exactly 1 field @@ -3241,13 +3251,17 @@ bool Operator_ser(FILE *out, const Ptr &value) { tag_val = 75; field_count = 1; break; + case Operator::Tag::dveReadAccumulator: + tag_val = 76; + field_count = 1; + break; default: throw std::runtime_error("Unknown Operator type in serialization"); return false; } // Serialize the tag - if (!serialize_tag(out, 218, tag_val, field_count)) + if (!serialize_tag(out, 219, tag_val, field_count)) return false; // Serialize the fields based on the specific variant @@ -3624,6 +3638,11 @@ bool Operator_ser(FILE *out, const Ptr &value) { static_cast(value.get()); return Activate2_ser(out, typed_value->op); } + case Operator::Tag::dveReadAccumulator: { + auto *typed_value = + static_cast(value.get()); + return DveReadAccumulator_ser(out, typed_value->op); + } default: throw std::runtime_error("Unknown Operator type in serialization"); return false; @@ -7066,11 +7085,30 @@ Ptr Activate2_des(FILE *in) { return x; } +Ptr DveReadAccumulator_des(FILE *in) { + u8 t, c, l; + if (!deserialize_tag(in, &t, &c, &l)) { + std::ostringstream msg; + msg << "Could not find tag, expecting DveReadAccumulator:218,0"; + throw std::runtime_error(msg.str()); + } + if (t != 218 || c != 0 || l != 2) { + std::ostringstream msg; + msg << "Expecting DveReadAccumulator:(218,0,2)"; + msg << " got:(" << (int)t << "," << (int)c << "," << (int)l << ")"; + throw std::runtime_error(msg.str()); + } + Ptr x = ptr(); + x->dst = TensorRef_des(in); + x->negated = Bool_des(in); + return x; +} + Ptr Operator_des(FILE *in) { u8 t, c, l; if (!deserialize_tag(in, &t, &c, &l)) throw std::runtime_error("Could not read tag"); - if (t != 218) + if (t != 219) throw std::runtime_error("Unexpected type tag"); switch (c) { case 0: { @@ -7696,6 +7734,15 @@ Ptr Operator_des(FILE *in) { return x; break; } + case 76: { + if (l != 1) + throw std::runtime_error("Wrong number of elements"); + Ptr x = + ptr(); + x->op = DveReadAccumulator_des(in); + return x; + break; + } default: throw std::runtime_error("Invalid value tag"); } diff --git a/interop/klr/klir_serde.hpp b/interop/klr/klir_serde.hpp index cbfc96f6..6bf94313 100644 --- a/interop/klr/klir_serde.hpp +++ b/interop/klr/klir_serde.hpp @@ -178,6 +178,7 @@ bool PrintOutputBuffer_ser(FILE *out, const PrintOutputBuffer &value); bool DevicePrint_ser(FILE *out, const Ptr &value); bool Exponential_ser(FILE *out, const Ptr &value); bool Activate2_ser(FILE *out, const Ptr &value); +bool DveReadAccumulator_ser(FILE *out, const Ptr &value); bool Operator_ser(FILE *out, const Ptr &value); bool Stmt_ser(FILE *out, const Ptr &value); bool Block_ser(FILE *out, const Ptr &value); @@ -310,6 +311,7 @@ PrintOutputBuffer PrintOutputBuffer_des(FILE *in); Ptr DevicePrint_des(FILE *in); Ptr Exponential_des(FILE *in); Ptr Activate2_des(FILE *in); +Ptr DveReadAccumulator_des(FILE *in); Ptr Operator_des(FILE *in); Ptr Stmt_des(FILE *in); Ptr Block_des(FILE *in);