Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions KLR/Core/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ partial def operatorBasicTensors : Operator → List TensorRef
| .devicePrint t => [t.src]
| .exponential e => [e.dst, e.src]
| .activate2 a => [a.dst, a.src]
| .dveReadAccumulator e => [e.dst]

partial def operatorAdditionalTensors : Operator → List TensorName
| .ncActivate d => (tensors d.scale) ++ (tensors d.bias) ++ (tensors d.reduceRes)
Expand Down
15 changes: 14 additions & 1 deletion KLR/Core/Operators.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1262,6 +1262,17 @@ instance : MapTensorRefs Activate2 where
}

@[serde tag = 218]
structure DveReadAccumulator where
dst : TensorRef
negated : Bool
deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp

instance : MapTensorRefs DveReadAccumulator where
mapM ft _ op := do pure { op with
dst := ← ft op.dst
}

@[serde tag = 219]
inductive Operator where
| activate (op : Activate)
| ncActivate (op : NcActivate)
Expand Down Expand Up @@ -1339,9 +1350,10 @@ inductive Operator where
| devicePrint (op: DevicePrint)
| exponential(op: Exponential)
| activate2 (op: Activate2)
| dveReadAccumulator (op: DveReadAccumulator)
deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp

@[serde tag = 219]
@[serde tag = 220]
inductive TGROperator where
| activate (op : Activate)
| affineSelect (op : AffineSelect)
Expand Down Expand Up @@ -1451,3 +1463,4 @@ instance : MapTensorRefs Operator where
| .devicePrint op => return .devicePrint (← MapTensorRefs.mapM ft fo op)
| .exponential op => return .exponential (← MapTensorRefs.mapM ft fo op)
| .activate2 op => return .activate2 (← MapTensorRefs.mapM ft fo op)
| .dveReadAccumulator op => return .dveReadAccumulator (← MapTensorRefs.mapM ft fo op)
1 change: 1 addition & 0 deletions KLR/Extract/Extract/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ def klrAST: MetaM (List LeanType) := do
`KLR.Core.DevicePrint,
`KLR.Core.Exponential,
`KLR.Core.Activate2,
`KLR.Core.DveReadAccumulator,
`KLR.Core.Operator,
-- Core.Basic
`KLR.Core.Stmt,
Expand Down
10 changes: 10 additions & 0 deletions KLR/Trace/ISA.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1315,3 +1315,13 @@ nki builtin.isa.activate2
dtype := dst.tensor.dtype
}) name
return .none

nki builtin.isa.dveReadAccumulator
(dst : Access)
(negated : Bool := false)
(name : Option String := none) := do
Trace.add_stmt $ .oper (.dveReadAccumulator {
dst := .abstract dst,
negated := negated,
}) name
return .none
3 changes: 3 additions & 0 deletions interop/klr/NKI.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,8 @@ DevicePrint = (TensorRef src, String printPrefix, PrintOutputBuffer buffer)
Exponential = (TensorRef dst, TensorRef src, Operand maxValue, TensorRef? reduceRes, AccumCmd reducecmd, Operand reduceInit)

Activate2 = (TensorRef dst, TensorRef src, AluOp op0, AluOp op1, Operand imm0, Operand imm1, ActivationFunc activationFunc, Operand reluParam, AluOp reduceOp, TensorRef? reduceRes, AccumCmd reduceCmd, Bool reverse0, Bool reverse1, Dtype? dtype)

DveReadAccumulator = (TensorRef dst, Bool negated)
Operator =
| activate(Activate op)
| ncActivate(NcActivate op)
Expand Down Expand Up @@ -443,6 +445,7 @@ Operator =
| devicePrint(DevicePrint op)
| exponential(Exponential op)
| activate2(Activate2 op)
| dveReadAccumulator(DveReadAccumulator op)

Stmt =
| oper(Operator op, String? name, Pos pos)
Expand Down
11 changes: 11 additions & 0 deletions interop/klr/klir_ast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1087,6 +1087,11 @@ struct Activate2 final {
Option<Dtype> dtype;
};

struct DveReadAccumulator final {
Ptr<TensorRef> dst;
Bool negated;
};

struct Operator {
enum class Tag {
activate = 1,
Expand Down Expand Up @@ -1165,6 +1170,7 @@ struct Operator {
devicePrint,
exponential,
activate2,
dveReadAccumulator,
};
Tag tag;
Operator(Tag tag) : tag(tag) {}
Expand Down Expand Up @@ -1555,6 +1561,11 @@ struct OperatorActivate2Wrapper final : Operator {
OperatorActivate2Wrapper() : Operator(Tag::activate2) {}
};

struct OperatorDveReadAccumulatorWrapper final : Operator {
Ptr<DveReadAccumulator> op;
OperatorDveReadAccumulatorWrapper() : Operator(Tag::dveReadAccumulator) {}
};

struct Stmt {
enum class Tag {
oper = 1,
Expand Down
26 changes: 26 additions & 0 deletions interop/klr/klir_pretty_print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3211,6 +3211,18 @@ std::string to_string(Activate2 &Activate2Instance) {
return result;
};

std::string to_string(DveReadAccumulator &DveReadAccumulatorInstance) {
std::string result;
result += "DveReadAccumulator(";
result += "dst=";
result += to_string(*(DveReadAccumulatorInstance.dst.get()));
result += ", ";
result += "negated=";
result += std::to_string(DveReadAccumulatorInstance.negated);
result += ")";
return result;
};

std::string
to_string(OperatorActivateWrapper &OperatorActivateWrapperInstance) {
std::string result;
Expand Down Expand Up @@ -3886,6 +3898,15 @@ to_string(OperatorActivate2Wrapper &OperatorActivate2WrapperInstance) {
result += ")";
return result;
};
std::string to_string(OperatorDveReadAccumulatorWrapper
&OperatorDveReadAccumulatorWrapperInstance) {
std::string result;
result += "OperatorDveReadAccumulatorWrapper(";
result += "op=";
result += to_string(*(OperatorDveReadAccumulatorWrapperInstance.op.get()));
result += ")";
return result;
};
std::string to_string(Operator &OperatorInstance) {
switch (OperatorInstance.tag) {
case (Operator::Tag::activate): {
Expand Down Expand Up @@ -4270,6 +4291,11 @@ std::string to_string(Operator &OperatorInstance) {
static_cast<OperatorActivate2Wrapper &>(OperatorInstance);
return to_string(derivedRef);
}
case (Operator::Tag::dveReadAccumulator): {
OperatorDveReadAccumulatorWrapper &derivedRef =
static_cast<OperatorDveReadAccumulatorWrapper &>(OperatorInstance);
return to_string(derivedRef);
}
default:
return "UNABLE TO PRINT";
}
Expand Down
4 changes: 4 additions & 0 deletions interop/klr/klir_pretty_print.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,8 @@ std::string to_string(Exponential &ExponentialInstance);

std::string to_string(Activate2 &Activate2Instance);

std::string to_string(DveReadAccumulator &DveReadAccumulatorInstance);

std::string to_string(OperatorActivateWrapper &OperatorActivateWrapperInstance);
std::string
to_string(OperatorNcActivateWrapper &OperatorNcActivateWrapperInstance);
Expand Down Expand Up @@ -403,6 +405,8 @@ std::string
to_string(OperatorExponentialWrapper &OperatorExponentialWrapperInstance);
std::string
to_string(OperatorActivate2Wrapper &OperatorActivate2WrapperInstance);
std::string to_string(OperatorDveReadAccumulatorWrapper
&OperatorDveReadAccumulatorWrapperInstance);
std::string to_string(Operator &OperatorInstance);

std::string to_string(StmtOperWrapper &StmtOperWrapperInstance);
Expand Down
51 changes: 49 additions & 2 deletions interop/klr/klir_serde.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2931,6 +2931,16 @@ bool Activate2_ser(FILE *out, const Ptr<Activate2> &value) {
return true;
}

bool DveReadAccumulator_ser(FILE *out, const Ptr<DveReadAccumulator> &value) {
if (!serialize_tag(out, 218, 0, 2))
return false;
if (!TensorRef_ser(out, value->dst))
return false;
if (!Bool_ser(out, value->negated))
return false;
return true;
}

bool Operator_ser(FILE *out, const Ptr<Operator> &value) {
u8 tag_val = 0;
u8 field_count = 1; // All variants have exactly 1 field
Expand Down Expand Up @@ -3241,13 +3251,17 @@ bool Operator_ser(FILE *out, const Ptr<Operator> &value) {
tag_val = 75;
field_count = 1;
break;
case Operator::Tag::dveReadAccumulator:
tag_val = 76;
field_count = 1;
break;
default:
throw std::runtime_error("Unknown Operator type in serialization");
return false;
}

// Serialize the tag
if (!serialize_tag(out, 218, tag_val, field_count))
if (!serialize_tag(out, 219, tag_val, field_count))
return false;

// Serialize the fields based on the specific variant
Expand Down Expand Up @@ -3624,6 +3638,11 @@ bool Operator_ser(FILE *out, const Ptr<Operator> &value) {
static_cast<const OperatorActivate2Wrapper *>(value.get());
return Activate2_ser(out, typed_value->op);
}
case Operator::Tag::dveReadAccumulator: {
auto *typed_value =
static_cast<const OperatorDveReadAccumulatorWrapper *>(value.get());
return DveReadAccumulator_ser(out, typed_value->op);
}
default:
throw std::runtime_error("Unknown Operator type in serialization");
return false;
Expand Down Expand Up @@ -7066,11 +7085,30 @@ Ptr<Activate2> Activate2_des(FILE *in) {
return x;
}

Ptr<DveReadAccumulator> DveReadAccumulator_des(FILE *in) {
u8 t, c, l;
if (!deserialize_tag(in, &t, &c, &l)) {
std::ostringstream msg;
msg << "Could not find tag, expecting DveReadAccumulator:218,0";
throw std::runtime_error(msg.str());
}
if (t != 218 || c != 0 || l != 2) {
std::ostringstream msg;
msg << "Expecting DveReadAccumulator:(218,0,2)";
msg << " got:(" << (int)t << "," << (int)c << "," << (int)l << ")";
throw std::runtime_error(msg.str());
}
Ptr<DveReadAccumulator> x = ptr<DveReadAccumulator>();
x->dst = TensorRef_des(in);
x->negated = Bool_des(in);
return x;
}

Ptr<Operator> Operator_des(FILE *in) {
u8 t, c, l;
if (!deserialize_tag(in, &t, &c, &l))
throw std::runtime_error("Could not read tag");
if (t != 218)
if (t != 219)
throw std::runtime_error("Unexpected type tag");
switch (c) {
case 0: {
Expand Down Expand Up @@ -7696,6 +7734,15 @@ Ptr<Operator> Operator_des(FILE *in) {
return x;
break;
}
case 76: {
if (l != 1)
throw std::runtime_error("Wrong number of elements");
Ptr<OperatorDveReadAccumulatorWrapper> x =
ptr<OperatorDveReadAccumulatorWrapper>();
x->op = DveReadAccumulator_des(in);
return x;
break;
}
default:
throw std::runtime_error("Invalid value tag");
}
Expand Down
2 changes: 2 additions & 0 deletions interop/klr/klir_serde.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ bool PrintOutputBuffer_ser(FILE *out, const PrintOutputBuffer &value);
bool DevicePrint_ser(FILE *out, const Ptr<DevicePrint> &value);
bool Exponential_ser(FILE *out, const Ptr<Exponential> &value);
bool Activate2_ser(FILE *out, const Ptr<Activate2> &value);
bool DveReadAccumulator_ser(FILE *out, const Ptr<DveReadAccumulator> &value);
bool Operator_ser(FILE *out, const Ptr<Operator> &value);
bool Stmt_ser(FILE *out, const Ptr<Stmt> &value);
bool Block_ser(FILE *out, const Ptr<Block> &value);
Expand Down Expand Up @@ -310,6 +311,7 @@ PrintOutputBuffer PrintOutputBuffer_des(FILE *in);
Ptr<DevicePrint> DevicePrint_des(FILE *in);
Ptr<Exponential> Exponential_des(FILE *in);
Ptr<Activate2> Activate2_des(FILE *in);
Ptr<DveReadAccumulator> DveReadAccumulator_des(FILE *in);
Ptr<Operator> Operator_des(FILE *in);
Ptr<Stmt> Stmt_des(FILE *in);
Ptr<Block> Block_des(FILE *in);
Expand Down
Loading