Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion KLR/NKI/Simplify.lean
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ private def value : Python.Const -> Simplify Value
| .float f => return .float f
| .string s => return .string s
| .ellipsis => throw "invalid use of ellipsis"
| .tensor s dty => return .tensor s dty none
| .tensor s dty name => return .tensor s dty name

private def strValue (e : Python.Expr) : Simplify String :=
withPos e.pos do
Expand Down
3 changes: 2 additions & 1 deletion KLR/Python.lean
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ inductive Const where
| string (value : String)
| ellipsis
-- TODO handle tensor data as well
| tensor (shape : List Nat) (dtype : String)
-- name is the path to the tensor in the argument structure (e.g., "x.0.attr.1")
| tensor (shape : List Nat) (dtype : String) (name : Option String := none)
deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp

/-
Expand Down
6 changes: 4 additions & 2 deletions KLR/Trace/NKI.lean
Original file line number Diff line number Diff line change
Expand Up @@ -565,8 +565,10 @@ private def processArgs (args : List Arg) : Trace (List Value × List Keyword) :
let (inputs, kws) <- args.foldlM (init := ([], [])) fun (inputs, kws) ⟨name, e⟩ => do
modify fun s => {s with tensorNames := s.tensorNames.insert name}
match e with
| ⟨ .value (.tensor s d _), pos ⟩ =>
let t := .tensor s d name
| ⟨ .value (.tensor s d tensorName), pos ⟩ =>
-- Use the embedded tensor name if available, otherwise use the argument name
let finalName := tensorName.getD name
let t := .tensor s d finalName
let e' := ⟨ .value t, pos ⟩
return (t :: inputs, .mk name e' :: kws)
| _ => return (inputs, .mk name e :: kws)
Expand Down
310 changes: 307 additions & 3 deletions interop/klr/gather.c
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ static lean_object* tensor_const(struct state *st, PyObject *obj) {
dty = py_strdup(st, dstr);
Py_DECREF(dstr);

return Python_Const_tensor(sh, dty);
return Python_Const_tensor(sh, dty, mkNone());

error:
if (sh) lean_dec(sh);
Expand All @@ -387,6 +387,38 @@ static lean_object* tensor_const(struct state *st, PyObject *obj) {
return Python_Const_none;
}

// Handle tensor objects with a name (return Const)
static lean_object* tensor_const_named(struct state *st, PyObject *obj, lean_object *name) {
lean_object *sh = NULL;
lean_object *dty = NULL;

if (!obj) goto error;

PyObject *shape = PyObject_GetAttrString(obj, "shape");
if (!shape) goto error;
sh = nat_list(st, shape);
Py_DECREF(shape);

PyObject *dtype = PyObject_GetAttrString(obj, "dtype");
if (!dtype) goto error;

PyObject *dstr = PyObject_Str(dtype);
Py_DECREF(dtype);
if (!dstr) goto error;

dty = py_strdup(st, dstr);
Py_DECREF(dstr);

return Python_Const_tensor(sh, dty, mkSome(name));

error:
if (sh) lean_dec(sh);
if (dty) lean_dec(dty);
if (name) lean_dec(name);
error(st, "could not convert tensor value");
return Python_Const_none;
}

// This function never raises exceptions
// Returns a new reference
static PyObject* get_numpy_generic_dtype() {
Expand Down Expand Up @@ -627,6 +659,277 @@ static lean_object* const_exprs(struct state *st, PyObject *obj) {
return const_list(st, obj, const_expr);
}

// ============================================================================
// Path-aware constant expression conversion for tensor naming
// These functions track the path through nested structures to name tensors
// ============================================================================

// Forward declarations for path-aware functions
static lean_object* const_expr_with_path(struct state *st, PyObject *obj, lean_object *path);
static lean_object* const_exprs_with_path(struct state *st, PyObject *obj, lean_object *path);
static lean_object* const_dict_with_path(struct state *st, PyObject *obj, lean_object *path);
static lean_object* const_dict_values_with_path(struct state *st, PyObject *obj, lean_object *path);

// Helper to append index to path: "path.idx"
static lean_object* make_path_with_index(lean_object *base, Py_ssize_t idx) {
char buf[32];
snprintf(buf, sizeof(buf), ".%zd", idx);
lean_object *suffix = lean_mk_string(buf);
lean_object *res = lean_string_append(base, suffix);
lean_dec(suffix);
return res;
}

// Helper to append key to path: "path.key"
static lean_object* make_path_with_key(struct state *st, lean_object *base, PyObject *key) {
lean_object *dot = lean_mk_string(".");
lean_object *key_str = py_strdup(st, key);
lean_object *tmp = lean_string_append(base, dot);
lean_object *res = lean_string_append(tmp, key_str);
lean_dec(dot);
lean_dec(key_str);
lean_dec(tmp);
return res;
}

// returns Expr - path-aware version that names tensors with their path
static lean_object* const_expr_with_path(struct state *st, PyObject *obj, lean_object *path) {
lean_object *pos = curPos(st);
lean_object *e = NULL;
PyObject *numpy_dt = NULL;

if (!obj) {
error(st, "could not convert constant expression");
lean_dec(path);
}
else if (PyTuple_Check(obj)) {
lean_object *l = const_exprs_with_path(st, obj, path);
e = Python_Expr_mk(Python_Expr_tuple(l, Python_Ctx_load), pos);
}
else if (PyList_Check(obj)) {
lean_object *l = const_exprs_with_path(st, obj, path);
e = Python_Expr_mk(Python_Expr_list(l, Python_Ctx_load), pos);
}
else if (PyDict_Check(obj)) {
PyObject *keys = PyDict_Keys(obj);
PyObject *vals = PyDict_Values(obj);

// For dict values, we need to track keys in the path
lean_object *l_keys = const_exprs(st, keys); // keys don't need path tracking
lean_object *l_vals = const_dict_values_with_path(st, obj, path);

Py_XDECREF(keys);
Py_XDECREF(vals);

e = Python_Expr_mk(Python_Expr_dict(l_keys, l_vals), pos);
}
else if (PyModule_Check(obj)) {
PyObject *name = PyModule_GetNameObject(obj);
e = Python_Expr_mk(Python_Expr_name(py_strdup(st, name), Python_Ctx_load), pos);
lean_dec(path);
}
else if (is_tensor(obj)) {
// Use the path as the tensor name
lean_object *c = tensor_const_named(st, obj, path);
e = Python_Expr_mk(Python_Expr_const(c), pos);
// path ownership transferred to tensor_const_named
}
else if (is_numpy_dtype(obj)) {
const char* nki_dtype = suggest_nki_dtype(obj);
if (nki_dtype) {
error(st, "numpy dtypes are not supported as arguments. Use %s instead", nki_dtype);
} else {
error(st, "numpy dtypes are not supported as arguments");
}
lean_dec(path);
}
else if ((numpy_dt = numpy_dtype_instance(obj)) && numpy_dt) {
const char* nki_dtype = suggest_nki_dtype(numpy_dt);
if (nki_dtype) {
error(st, "numpy dtypes are not supported as arguments. Use %s instead", nki_dtype);
Py_DECREF(numpy_dt);
} else {
error(st, "numpy dtypes are not supported as arguments");
}
lean_dec(path);
}
else if (PyFunction_Check(obj)) {
lean_object *func_name = py_def_name(st, obj);
if (!st->ignore_refs) {
add_work(st, NULL, obj);
}
e = Python_Expr_mk(Python_Expr_name(func_name, Python_Ctx_load), pos);
lean_dec(path);
}
else if (PyObject_HasAttrString(obj, "__class__") &&
PyObject_HasAttrString(obj, "__dict__"))
{
// general object types - recurse into __dict__ with path
PyObject *cls = PyObject_GetAttrString(obj, "__class__");
PyObject *dict = PyObject_GetAttrString(obj, "__dict__");

lean_object *cls_name = py_def_name(st, cls);
lean_object *l_dict = const_dict_with_path(st, dict, path);
Py_XDECREF(dict);

add_work(st, NULL, cls);
e = Python_Expr_mk(Python_Expr_object(cls_name, l_dict), pos);
// path ownership transferred to const_dict_with_path
}
else {
e = Python_Expr_mk(Python_Expr_const(value(st, obj)), pos);
lean_dec(path);
}

if (!e)
e = Python_Expr_mk(Python_Expr_const(Python_Const_none), pos);
return e;
}

// returns List Expr - path-aware version for lists/tuples
static lean_object* const_exprs_with_path(struct state *st, PyObject *obj, lean_object *path) {
if (!obj) {
lean_dec(path);
return mkNil();
}

Py_ssize_t sz = PyObject_Length(obj);
if (sz <= 0) {
lean_dec(path);
return mkNil();
}

lean_object *arr = lean_alloc_array(0, sz);

for (Py_ssize_t i = 0; i < sz; i++) {
PyObject *key = PyLong_FromLong(i);
if (!key) {
error(st, "could not construct Long Object for key %ld", i);
break;
}

PyObject *item = PyObject_GetItem(obj, key);
Py_DECREF(key);
if (!item) {
error(st, "could not get sequence item number %ld", i);
break;
}

// Create path for this element: path.i
lean_inc(path);
lean_object *item_path = make_path_with_index(path, i);
lean_object *e = const_expr_with_path(st, item, item_path);
Py_DECREF(item);
if (e)
arr = lean_array_push(arr, e);
}
lean_dec(path);
return lean_array_to_list(arr);
}

// returns List Expr - path-aware version for dict values
static lean_object* const_dict_values_with_path(struct state *st, PyObject *obj, lean_object *path) {
if (!obj) {
lean_dec(path);
return mkNil();
}

lean_object *arr = lean_mk_empty_array();

Py_ssize_t pos = 0;
PyObject *key, *val;
while (PyDict_Next(obj, &pos, &key, &val)) {
// Create path for this value: path.key
lean_inc(path);
lean_object *val_path = make_path_with_key(st, path, key);
lean_object *e = const_expr_with_path(st, val, val_path);
arr = lean_array_push(arr, e);
}
lean_dec(path);
return lean_array_to_list(arr);
}

// Returns List Keyword - path-aware version for object __dict__
static lean_object* const_dict_with_path(struct state *st, PyObject *obj, lean_object *path) {
if (!obj) {
error(st, "could not convert dictionary");
lean_dec(path);
return mkNil();
}

lean_object *arr = lean_mk_empty_array();
lean_object *l_pos = curPos(st);

Py_ssize_t pos = 0;
PyObject *key, *val;
while (PyDict_Next(obj, &pos, &key, &val)) {
lean_object *s = py_strdup(st, key);
// Create path for this attribute: path.key
lean_inc(path);
lean_object *val_path = make_path_with_key(st, path, key);
lean_object *e = const_expr_with_path(st, val, val_path);
arr = lean_array_push(arr, Python_Keyword_mk(mkOption(s), e, l_pos));
}
lean_dec(path);
return lean_array_to_list(arr);
}

// Process positional args with path tracking
// Returns List Expr where tensors have names like "arg0", "arg1.0", "arg2.attr.0", etc.
static lean_object* const_args_with_paths(struct state *st, PyObject *args) {
if (!args || args == Py_None) return mkNil();

Py_ssize_t sz = PyObject_Length(args);
if (sz <= 0) return mkNil();

lean_object *arr = lean_alloc_array(0, sz);

for (Py_ssize_t i = 0; i < sz; i++) {
PyObject *key = PyLong_FromLong(i);
if (!key) {
error(st, "could not construct Long Object for key %ld", i);
break;
}

PyObject *item = PyObject_GetItem(args, key);
Py_DECREF(key);
if (!item) {
error(st, "could not get argument number %ld", i);
break;
}

// Create base path for this argument: "arg<i>"
char buf[32];
snprintf(buf, sizeof(buf), "arg%zd", i);
lean_object *path = lean_mk_string(buf);
lean_object *e = const_expr_with_path(st, item, path);
Py_DECREF(item);
if (e)
arr = lean_array_push(arr, e);
}
return lean_array_to_list(arr);
}

// Process keyword args with path tracking
// Returns List Keyword where tensors have names like "x", "x.0", "x.attr.0", etc.
static lean_object* const_kwargs_with_paths(struct state *st, PyObject *kws) {
if (!kws || kws == Py_None) return mkNil();

lean_object *arr = lean_mk_empty_array();
lean_object *l_pos = curPos(st);

Py_ssize_t pos = 0;
PyObject *key, *val;
while (PyDict_Next(kws, &pos, &key, &val)) {
lean_object *s = py_strdup(st, key);
// Use the keyword name as the base path
lean_object *path = py_strdup(st, key);
lean_object *e = const_expr_with_path(st, val, path);
arr = lean_array_push(arr, Python_Keyword_mk(mkOption(s), e, l_pos));
}
return lean_array_to_list(arr);
}

// returns Nat
static lean_object* const_nat(struct state *st, PyObject *obj) {
unsigned res = 0;
Expand Down Expand Up @@ -1718,9 +2021,10 @@ PyObject* specialize(

// add main function to work list, and process arguments
// potentially adding more dependencies to the work list
// Use path-aware functions to name tensors based on their location in the argument structure
add_work(&st, NULL, k->f);
lean_object *l_args = args == Py_None ? mkNil() : const_exprs(&st, args);
lean_object *l_kwargs = kws == Py_None ? mkNil() : const_dict(&st, kws);
lean_object *l_args = args == Py_None ? mkNil() : const_args_with_paths(&st, args);
lean_object *l_kwargs = kws == Py_None ? mkNil() : const_kwargs_with_paths(&st, kws);

while (true) {
struct worklist *work = st.work;
Expand Down
2 changes: 1 addition & 1 deletion interop/klr/lean_ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ lean_object* Python_Const_bool(uint8_t);
lean_object* Python_Const_int(lean_object*);
lean_object* Python_Const_float(double);
lean_object* Python_Const_string(lean_object*);
lean_object* Python_Const_tensor(lean_object*,lean_object*);
lean_object* Python_Const_tensor(lean_object*,lean_object*,lean_object*);
extern lean_object* Python_Const_none;
extern lean_object* Python_Const_ellipsis;
lean_object* Python_Keyword_mk(lean_object*,lean_object*,lean_object*);
Expand Down
Loading