diff --git a/KLR/NKI/Simplify.lean b/KLR/NKI/Simplify.lean index c473fc4f..f32fc148 100644 --- a/KLR/NKI/Simplify.lean +++ b/KLR/NKI/Simplify.lean @@ -58,7 +58,7 @@ private def value : Python.Const -> Simplify Value | .float f => return .float f | .string s => return .string s | .ellipsis => throw "invalid use of ellipsis" - | .tensor s dty => return .tensor s dty none + | .tensor s dty name => return .tensor s dty name private def strValue (e : Python.Expr) : Simplify String := withPos e.pos do diff --git a/KLR/Python.lean b/KLR/Python.lean index 2776af51..c232e391 100644 --- a/KLR/Python.lean +++ b/KLR/Python.lean @@ -42,7 +42,8 @@ inductive Const where | string (value : String) | ellipsis -- TODO handle tensor data as well - | tensor (shape : List Nat) (dtype : String) + -- name is the path to the tensor in the argument structure (e.g., "x.0.attr.1") + | tensor (shape : List Nat) (dtype : String) (name : Option String := none) deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp /- diff --git a/KLR/Trace/NKI.lean b/KLR/Trace/NKI.lean index b5ccf01a..5020816d 100644 --- a/KLR/Trace/NKI.lean +++ b/KLR/Trace/NKI.lean @@ -565,8 +565,10 @@ private def processArgs (args : List Arg) : Trace (List Value × List Keyword) : let (inputs, kws) <- args.foldlM (init := ([], [])) fun (inputs, kws) ⟨name, e⟩ => do modify fun s => {s with tensorNames := s.tensorNames.insert name} match e with - | ⟨ .value (.tensor s d _), pos ⟩ => - let t := .tensor s d name + | ⟨ .value (.tensor s d tensorName), pos ⟩ => + -- Use the embedded tensor name if available, otherwise use the argument name + let finalName := tensorName.getD name + let t := .tensor s d finalName let e' := ⟨ .value t, pos ⟩ return (t :: inputs, .mk name e' :: kws) | _ => return (inputs, .mk name e :: kws) diff --git a/interop/klr/gather.c b/interop/klr/gather.c index c161449c..38106d06 100644 --- a/interop/klr/gather.c +++ b/interop/klr/gather.c @@ -378,7 +378,7 @@ static lean_object* tensor_const(struct state *st, PyObject *obj) { dty = py_strdup(st, dstr); Py_DECREF(dstr); - return Python_Const_tensor(sh, dty); + return Python_Const_tensor(sh, dty, mkNone()); error: if (sh) lean_dec(sh); @@ -387,6 +387,38 @@ static lean_object* tensor_const(struct state *st, PyObject *obj) { return Python_Const_none; } +// Handle tensor objects with a name (return Const) +static lean_object* tensor_const_named(struct state *st, PyObject *obj, lean_object *name) { + lean_object *sh = NULL; + lean_object *dty = NULL; + + if (!obj) goto error; + + PyObject *shape = PyObject_GetAttrString(obj, "shape"); + if (!shape) goto error; + sh = nat_list(st, shape); + Py_DECREF(shape); + + PyObject *dtype = PyObject_GetAttrString(obj, "dtype"); + if (!dtype) goto error; + + PyObject *dstr = PyObject_Str(dtype); + Py_DECREF(dtype); + if (!dstr) goto error; + + dty = py_strdup(st, dstr); + Py_DECREF(dstr); + + return Python_Const_tensor(sh, dty, mkSome(name)); + +error: + if (sh) lean_dec(sh); + if (dty) lean_dec(dty); + if (name) lean_dec(name); + error(st, "could not convert tensor value"); + return Python_Const_none; +} + // This function never raises exceptions // Returns a new reference static PyObject* get_numpy_generic_dtype() { @@ -627,6 +659,277 @@ static lean_object* const_exprs(struct state *st, PyObject *obj) { return const_list(st, obj, const_expr); } +// ============================================================================ +// Path-aware constant expression conversion for tensor naming +// These functions track the path through nested structures to name tensors +// ============================================================================ + +// Forward declarations for path-aware functions +static lean_object* const_expr_with_path(struct state *st, PyObject *obj, lean_object *path); +static lean_object* const_exprs_with_path(struct state *st, PyObject *obj, lean_object *path); +static lean_object* const_dict_with_path(struct state *st, PyObject *obj, lean_object *path); +static lean_object* const_dict_values_with_path(struct state *st, PyObject *obj, lean_object *path); + +// Helper to append index to path: "path.idx" +static lean_object* make_path_with_index(lean_object *base, Py_ssize_t idx) { + char buf[32]; + snprintf(buf, sizeof(buf), ".%zd", idx); + lean_object *suffix = lean_mk_string(buf); + lean_object *res = lean_string_append(base, suffix); + lean_dec(suffix); + return res; +} + +// Helper to append key to path: "path.key" +static lean_object* make_path_with_key(struct state *st, lean_object *base, PyObject *key) { + lean_object *dot = lean_mk_string("."); + lean_object *key_str = py_strdup(st, key); + lean_object *tmp = lean_string_append(base, dot); + lean_object *res = lean_string_append(tmp, key_str); + lean_dec(dot); + lean_dec(key_str); + lean_dec(tmp); + return res; +} + +// returns Expr - path-aware version that names tensors with their path +static lean_object* const_expr_with_path(struct state *st, PyObject *obj, lean_object *path) { + lean_object *pos = curPos(st); + lean_object *e = NULL; + PyObject *numpy_dt = NULL; + + if (!obj) { + error(st, "could not convert constant expression"); + lean_dec(path); + } + else if (PyTuple_Check(obj)) { + lean_object *l = const_exprs_with_path(st, obj, path); + e = Python_Expr_mk(Python_Expr_tuple(l, Python_Ctx_load), pos); + } + else if (PyList_Check(obj)) { + lean_object *l = const_exprs_with_path(st, obj, path); + e = Python_Expr_mk(Python_Expr_list(l, Python_Ctx_load), pos); + } + else if (PyDict_Check(obj)) { + PyObject *keys = PyDict_Keys(obj); + PyObject *vals = PyDict_Values(obj); + + // For dict values, we need to track keys in the path + lean_object *l_keys = const_exprs(st, keys); // keys don't need path tracking + lean_object *l_vals = const_dict_values_with_path(st, obj, path); + + Py_XDECREF(keys); + Py_XDECREF(vals); + + e = Python_Expr_mk(Python_Expr_dict(l_keys, l_vals), pos); + } + else if (PyModule_Check(obj)) { + PyObject *name = PyModule_GetNameObject(obj); + e = Python_Expr_mk(Python_Expr_name(py_strdup(st, name), Python_Ctx_load), pos); + lean_dec(path); + } + else if (is_tensor(obj)) { + // Use the path as the tensor name + lean_object *c = tensor_const_named(st, obj, path); + e = Python_Expr_mk(Python_Expr_const(c), pos); + // path ownership transferred to tensor_const_named + } + else if (is_numpy_dtype(obj)) { + const char* nki_dtype = suggest_nki_dtype(obj); + if (nki_dtype) { + error(st, "numpy dtypes are not supported as arguments. Use %s instead", nki_dtype); + } else { + error(st, "numpy dtypes are not supported as arguments"); + } + lean_dec(path); + } + else if ((numpy_dt = numpy_dtype_instance(obj)) && numpy_dt) { + const char* nki_dtype = suggest_nki_dtype(numpy_dt); + if (nki_dtype) { + error(st, "numpy dtypes are not supported as arguments. Use %s instead", nki_dtype); + Py_DECREF(numpy_dt); + } else { + error(st, "numpy dtypes are not supported as arguments"); + } + lean_dec(path); + } + else if (PyFunction_Check(obj)) { + lean_object *func_name = py_def_name(st, obj); + if (!st->ignore_refs) { + add_work(st, NULL, obj); + } + e = Python_Expr_mk(Python_Expr_name(func_name, Python_Ctx_load), pos); + lean_dec(path); + } + else if (PyObject_HasAttrString(obj, "__class__") && + PyObject_HasAttrString(obj, "__dict__")) + { + // general object types - recurse into __dict__ with path + PyObject *cls = PyObject_GetAttrString(obj, "__class__"); + PyObject *dict = PyObject_GetAttrString(obj, "__dict__"); + + lean_object *cls_name = py_def_name(st, cls); + lean_object *l_dict = const_dict_with_path(st, dict, path); + Py_XDECREF(dict); + + add_work(st, NULL, cls); + e = Python_Expr_mk(Python_Expr_object(cls_name, l_dict), pos); + // path ownership transferred to const_dict_with_path + } + else { + e = Python_Expr_mk(Python_Expr_const(value(st, obj)), pos); + lean_dec(path); + } + + if (!e) + e = Python_Expr_mk(Python_Expr_const(Python_Const_none), pos); + return e; +} + +// returns List Expr - path-aware version for lists/tuples +static lean_object* const_exprs_with_path(struct state *st, PyObject *obj, lean_object *path) { + if (!obj) { + lean_dec(path); + return mkNil(); + } + + Py_ssize_t sz = PyObject_Length(obj); + if (sz <= 0) { + lean_dec(path); + return mkNil(); + } + + lean_object *arr = lean_alloc_array(0, sz); + + for (Py_ssize_t i = 0; i < sz; i++) { + PyObject *key = PyLong_FromLong(i); + if (!key) { + error(st, "could not construct Long Object for key %ld", i); + break; + } + + PyObject *item = PyObject_GetItem(obj, key); + Py_DECREF(key); + if (!item) { + error(st, "could not get sequence item number %ld", i); + break; + } + + // Create path for this element: path.i + lean_inc(path); + lean_object *item_path = make_path_with_index(path, i); + lean_object *e = const_expr_with_path(st, item, item_path); + Py_DECREF(item); + if (e) + arr = lean_array_push(arr, e); + } + lean_dec(path); + return lean_array_to_list(arr); +} + +// returns List Expr - path-aware version for dict values +static lean_object* const_dict_values_with_path(struct state *st, PyObject *obj, lean_object *path) { + if (!obj) { + lean_dec(path); + return mkNil(); + } + + lean_object *arr = lean_mk_empty_array(); + + Py_ssize_t pos = 0; + PyObject *key, *val; + while (PyDict_Next(obj, &pos, &key, &val)) { + // Create path for this value: path.key + lean_inc(path); + lean_object *val_path = make_path_with_key(st, path, key); + lean_object *e = const_expr_with_path(st, val, val_path); + arr = lean_array_push(arr, e); + } + lean_dec(path); + return lean_array_to_list(arr); +} + +// Returns List Keyword - path-aware version for object __dict__ +static lean_object* const_dict_with_path(struct state *st, PyObject *obj, lean_object *path) { + if (!obj) { + error(st, "could not convert dictionary"); + lean_dec(path); + return mkNil(); + } + + lean_object *arr = lean_mk_empty_array(); + lean_object *l_pos = curPos(st); + + Py_ssize_t pos = 0; + PyObject *key, *val; + while (PyDict_Next(obj, &pos, &key, &val)) { + lean_object *s = py_strdup(st, key); + // Create path for this attribute: path.key + lean_inc(path); + lean_object *val_path = make_path_with_key(st, path, key); + lean_object *e = const_expr_with_path(st, val, val_path); + arr = lean_array_push(arr, Python_Keyword_mk(mkOption(s), e, l_pos)); + } + lean_dec(path); + return lean_array_to_list(arr); +} + +// Process positional args with path tracking +// Returns List Expr where tensors have names like "arg0", "arg1.0", "arg2.attr.0", etc. +static lean_object* const_args_with_paths(struct state *st, PyObject *args) { + if (!args || args == Py_None) return mkNil(); + + Py_ssize_t sz = PyObject_Length(args); + if (sz <= 0) return mkNil(); + + lean_object *arr = lean_alloc_array(0, sz); + + for (Py_ssize_t i = 0; i < sz; i++) { + PyObject *key = PyLong_FromLong(i); + if (!key) { + error(st, "could not construct Long Object for key %ld", i); + break; + } + + PyObject *item = PyObject_GetItem(args, key); + Py_DECREF(key); + if (!item) { + error(st, "could not get argument number %ld", i); + break; + } + + // Create base path for this argument: "arg" + char buf[32]; + snprintf(buf, sizeof(buf), "arg%zd", i); + lean_object *path = lean_mk_string(buf); + lean_object *e = const_expr_with_path(st, item, path); + Py_DECREF(item); + if (e) + arr = lean_array_push(arr, e); + } + return lean_array_to_list(arr); +} + +// Process keyword args with path tracking +// Returns List Keyword where tensors have names like "x", "x.0", "x.attr.0", etc. +static lean_object* const_kwargs_with_paths(struct state *st, PyObject *kws) { + if (!kws || kws == Py_None) return mkNil(); + + lean_object *arr = lean_mk_empty_array(); + lean_object *l_pos = curPos(st); + + Py_ssize_t pos = 0; + PyObject *key, *val; + while (PyDict_Next(kws, &pos, &key, &val)) { + lean_object *s = py_strdup(st, key); + // Use the keyword name as the base path + lean_object *path = py_strdup(st, key); + lean_object *e = const_expr_with_path(st, val, path); + arr = lean_array_push(arr, Python_Keyword_mk(mkOption(s), e, l_pos)); + } + return lean_array_to_list(arr); +} + // returns Nat static lean_object* const_nat(struct state *st, PyObject *obj) { unsigned res = 0; @@ -1718,9 +2021,10 @@ PyObject* specialize( // add main function to work list, and process arguments // potentially adding more dependencies to the work list + // Use path-aware functions to name tensors based on their location in the argument structure add_work(&st, NULL, k->f); - lean_object *l_args = args == Py_None ? mkNil() : const_exprs(&st, args); - lean_object *l_kwargs = kws == Py_None ? mkNil() : const_dict(&st, kws); + lean_object *l_args = args == Py_None ? mkNil() : const_args_with_paths(&st, args); + lean_object *l_kwargs = kws == Py_None ? mkNil() : const_kwargs_with_paths(&st, kws); while (true) { struct worklist *work = st.work; diff --git a/interop/klr/lean_ast.h b/interop/klr/lean_ast.h index 12686b10..c7f771c3 100644 --- a/interop/klr/lean_ast.h +++ b/interop/klr/lean_ast.h @@ -4,7 +4,7 @@ lean_object* Python_Const_bool(uint8_t); lean_object* Python_Const_int(lean_object*); lean_object* Python_Const_float(double); lean_object* Python_Const_string(lean_object*); -lean_object* Python_Const_tensor(lean_object*,lean_object*); +lean_object* Python_Const_tensor(lean_object*,lean_object*,lean_object*); extern lean_object* Python_Const_none; extern lean_object* Python_Const_ellipsis; lean_object* Python_Keyword_mk(lean_object*,lean_object*,lean_object*);