Commit 7df58e5

timsaucer and claude authored
feat: pass calling SessionContext to Python UDTF callbacks (#1555)
* feat: pass calling SessionContext to Python UDTF callbacks

  DataFusion 53 added `TableFunctionImpl::call_with_args(TableFunctionArgs)` where `TableFunctionArgs` carries both the positional expression arguments and the calling `&dyn Session`. The pure-Python UDTF path previously discarded everything but the exprs.

  Thread the session through when the user callback's signature opts in by declaring a `session` keyword parameter (or `**kwargs`). At call time we downcast the `&dyn Session` to its canonical `SessionState` impl and build a fresh `SessionContext` over the same Arc-shared state, exposed to Python as a `datafusion.SessionContext` wrapper.

  Existing callbacks whose signatures do not declare `session` continue to be called with the positional expression arguments only — no behavior change for current users.

  Note: a UDTF body cannot drive a fresh `ctx.sql(...).collect()` on the passed-in session because the outer SQL execution already holds the tokio runtime. Use the session for metadata access (catalogs, UDF lookups, config) rather than nested DataFrame collection.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* docs: clarify py_session_from_session downcast is defensive

  The doc comment implied a foreign FFI session was a real input. No current path reaches a pure-Python UDTF with a non-SessionState session: the SQL planner and __call__ both hand a SessionState, and a ForeignSession would only arrive via FFI-export of the UDTF, which datafusion-python does not do. Reword to state the guard is defensive and rewrap the error string.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* refactor: opt-in UDTF session injection via with_session flag

  Replaces signature sniffing with an explicit ``with_session=True`` kwarg on ``TableFunction`` / ``udtf``. Avoids name-based detection footguns (positional-only ``session`` params, accidental ``**kwargs`` opt-in, shadowing by unrelated params) and makes author intent visible at registration. Also documents the feature in the UDTF user guide.

  Rust field renamed ``accepts_session`` -> ``inject_session_on_call`` to match the Python-side opt-in semantics.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix: reject with_session=True for FFI UDTFs and qualify mutation docs

  Raise TypeError when with_session=True is combined with an FFI-exported table function (one exposing __datafusion_table_function__). The Rust FFI branch does not consult the flag, so it would silently be dropped; guard both TableFunction.__init__ and the udtf() convenience entry.

  Qualify the doc claim that mutations through the injected session propagate to the caller: registry mutations do (shared Arc registries), but config changes do not (SessionConfig is cloned). Mirror the caveat in TableFunction.__init__ per the user-guide caveats convention.

  Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
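The footguns named in the refactor bullet can be reproduced with the standard library alone. The sketch below is an illustration, not code from this commit: `sniffs_session` is a hypothetical name-based detector of the kind the message describes, and the three sample callbacks are made up for the demonstration.

```python
import inspect


def sniffs_session(func):
    """Hypothetical detector: opt in when the signature has a parameter
    named ``session`` or accepts ``**kwargs``."""
    params = inspect.signature(func).parameters.values()
    return any(
        p.name == "session" or p.kind is inspect.Parameter.VAR_KEYWORD
        for p in params
    )


def positional_only(session, /):
    # ``session`` cannot be supplied by keyword here.
    return session


def generic_wrapper(*args, **kwargs):
    # Takes **kwargs for unrelated reasons; never asked for a session.
    return args, kwargs


def explicit(*exprs, session):
    # The only callback that genuinely wants a ``session`` keyword.
    return exprs, session


# The detector reports an opt-in for all three callbacks.
detected = {
    f.__name__: sniffs_session(f)
    for f in (positional_only, generic_wrapper, explicit)
}

# Acting on that detection breaks the positional-only callback at call time.
try:
    positional_only(session="ctx")
    positional_only_ok = True
except TypeError:
    positional_only_ok = False
```

An explicit registration flag sidesteps both misfires, because the author states the intent instead of having it inferred from parameter names.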
1 parent 9872283 commit 7df58e5

4 files changed

Lines changed: 303 additions & 29 deletions

crates/core/src/udtf.rs

Lines changed: 78 additions & 16 deletions
@@ -18,20 +18,34 @@
 use std::ptr::NonNull;
 use std::sync::Arc;

-use datafusion::catalog::{TableFunctionArgs, TableFunctionImpl, TableProvider};
-use datafusion::error::Result as DataFusionResult;
+use datafusion::catalog::{Session, TableFunctionArgs, TableFunctionImpl, TableProvider};
+use datafusion::error::{DataFusionError, Result as DataFusionResult};
+use datafusion::execution::context::SessionContext;
+use datafusion::execution::session_state::SessionState;
 use datafusion::logical_expr::Expr;
 use datafusion_ffi::udtf::FFI_TableFunction;
 use pyo3::IntoPyObjectExt;
 use pyo3::exceptions::{PyImportError, PyTypeError};
 use pyo3::prelude::*;
-use pyo3::types::{PyCapsule, PyTuple, PyType};
+use pyo3::types::{PyCapsule, PyDict, PyTuple, PyType};

 use crate::context::PySessionContext;
 use crate::errors::{py_datafusion_err, to_datafusion_err};
 use crate::expr::PyExpr;
 use crate::table::PyTable;

+/// A pure-Python UDTF callable plus the metadata we discovered about it
+/// at registration time.
+#[derive(Debug, Clone)]
+pub(crate) struct PythonTableFunctionCallable {
+    pub(crate) callable: Arc<Py<PyAny>>,
+    /// When true, the calling :class:`SessionContext` is passed to the
+    /// callable as a ``session`` keyword argument on every invocation.
+    /// Opt-in at registration time via ``with_session=True`` on the
+    /// Python wrapper.
+    pub(crate) inject_session_on_call: bool,
+}
+
 /// Represents a user defined table function
 #[pyclass(from_py_object, frozen, name = "TableFunction", module = "datafusion")]
 #[derive(Debug, Clone)]
@@ -40,21 +54,21 @@ pub struct PyTableFunction {
     pub(crate) inner: PyTableFunctionInner,
 }

-// TODO: Implement pure python based user defined table functions
 #[derive(Debug, Clone)]
 pub(crate) enum PyTableFunctionInner {
-    PythonFunction(Arc<Py<PyAny>>),
+    PythonFunction(PythonTableFunctionCallable),
     FFIFunction(Arc<dyn TableFunctionImpl>),
 }

 #[pymethods]
 impl PyTableFunction {
     #[new]
-    #[pyo3(signature=(name, func, session))]
+    #[pyo3(signature=(name, func, session, inject_session_on_call=false))]
     pub fn new(
         name: &str,
         func: Bound<'_, PyAny>,
         session: Option<Bound<PyAny>>,
+        inject_session_on_call: bool,
     ) -> PyResult<Self> {
         let inner = if func.hasattr("__datafusion_table_function__")? {
             let py = func.py();
@@ -80,8 +94,10 @@ impl PyTableFunction {

             PyTableFunctionInner::FFIFunction(foreign_func)
         } else {
-            let py_obj = Arc::new(func.unbind());
-            PyTableFunctionInner::PythonFunction(py_obj)
+            PyTableFunctionInner::PythonFunction(PythonTableFunctionCallable {
+                callable: Arc::new(func.unbind()),
+                inject_session_on_call,
+            })
         };

         Ok(Self {
@@ -107,20 +123,66 @@ impl PyTableFunction {
     }
 }

+/// Materialize a fresh :class:`PySessionContext` from the borrowed
+/// ``&dyn Session`` handed in at call time.
+///
+/// Upstream invokes ``call_with_args`` with a trait-object reference
+/// rather than an owned context; we downcast it to the canonical
+/// :class:`SessionState` impl and rebuild a :class:`SessionContext`
+/// (sharing the same registries via the Arc-heavy interior of
+/// :class:`SessionState`).
+///
+/// The downcast is defensive. Every path that reaches a pure-Python
+/// UDTF today hands us a `SessionState`: the SQL planner builds the
+/// args from its own `SessionState`, and `PyTableFunction::__call__`
+/// uses the global context's state. A non-`SessionState` session
+/// (e.g. a `ForeignSession`) would only arrive if this UDTF were
+/// exported across the FFI boundary to a foreign-library consumer,
+/// which datafusion-python does not do. Should that change, this
+/// returns an error rather than silently misbehaving.
+fn py_session_from_session(session: &dyn Session) -> DataFusionResult<PySessionContext> {
+    let state = session
+        .as_any()
+        .downcast_ref::<SessionState>()
+        .ok_or_else(|| {
+            DataFusionError::Execution(
+                "Cannot expose this UDTF's calling session to Python: the \
+                 session is not a SessionState. Drop the `session` keyword \
+                 from the callback signature to fall back to the \
+                 expression-only call form."
+                    .to_string(),
+            )
+        })?;
+    Ok(PySessionContext::from(SessionContext::new_with_state(
+        state.clone(),
+    )))
+}
+
 #[allow(clippy::result_large_err)]
 fn call_python_table_function(
-    func: &Arc<Py<PyAny>>,
-    args: &[Expr],
+    func: &PythonTableFunctionCallable,
+    args: TableFunctionArgs,
 ) -> DataFusionResult<Arc<dyn TableProvider>> {
-    let args = args
+    let py_session = if func.inject_session_on_call {
+        Some(py_session_from_session(args.session())?)
+    } else {
+        None
+    };
+    let py_exprs = args
+        .exprs()
         .iter()
         .map(|arg| PyExpr::from(arg.clone()))
         .collect::<Vec<_>>();

-    // move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
     Python::attach(|py| {
-        let py_args = PyTuple::new(py, args)?;
-        let provider_obj = func.call1(py, py_args)?;
+        let py_args = PyTuple::new(py, py_exprs)?;
+        let provider_obj = if let Some(session) = py_session {
+            let kwargs = PyDict::new(py);
+            kwargs.set_item("session", session.into_pyobject(py)?)?;
+            func.callable.call(py, py_args, Some(&kwargs))?
+        } else {
+            func.callable.call1(py, py_args)?
+        };
         let provider = provider_obj.bind(py).clone();

         Ok::<Arc<dyn TableProvider>, PyErr>(PyTable::new(provider, None)?.table)
@@ -132,8 +194,8 @@ impl TableFunctionImpl for PyTableFunction {
     fn call_with_args(&self, args: TableFunctionArgs) -> DataFusionResult<Arc<dyn TableProvider>> {
         match &self.inner {
             PyTableFunctionInner::FFIFunction(func) => func.call_with_args(args),
-            PyTableFunctionInner::PythonFunction(obj) => {
-                call_python_table_function(obj, args.exprs())
+            PyTableFunctionInner::PythonFunction(callable) => {
+                call_python_table_function(callable, args)
             }
         }
     }
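
For readers more comfortable with Python than Rust, the branch in `call_python_table_function` can be sketched as follows. This is a simplified model under stated assumptions: `PythonCallable` and `call_table_function` are hypothetical Python stand-ins for the Rust struct and function, and plain strings stand in for `Expr` values and the session object.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class PythonCallable:
    """Stand-in for the Rust ``PythonTableFunctionCallable`` struct."""

    callable: Callable[..., Any]
    inject_session_on_call: bool


def call_table_function(func: PythonCallable, exprs: tuple, session: Any) -> Any:
    """Model of the dispatch performed at call time."""
    if func.inject_session_on_call:
        # Opted in: positional expressions plus a ``session`` keyword.
        return func.callable(*exprs, session=session)
    # Default: positional expression arguments only.
    return func.callable(*exprs)


def plain(*exprs):
    return ("plain", exprs)


def wants_session(*exprs, session):
    return ("wants_session", exprs, session)


result_plain = call_table_function(PythonCallable(plain, False), ("a", "b"), "ctx")
result_session = call_table_function(PythonCallable(wants_session, True), ("a",), "ctx")
```

The session is only materialized when the flag is set, so callbacks registered without it pay no extra cost and see no change in how they are called.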

docs/source/user-guide/common-operations/udf-and-udfa.rst

Lines changed: 39 additions & 0 deletions
@@ -431,3 +431,42 @@ that you wish to expose via PyO3, you need to expose it as a ``PyCapsule``.
             PyCapsule::new(py, provider, Some(name))
         }
     }
+
+Accessing the Calling Session
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Pure-Python UDTFs can opt into receiving the calling
+:py:class:`~datafusion.SessionContext` by registering with
+``with_session=True``. The context is passed as a ``session`` keyword
+argument on every invocation. Use it to look up registered tables,
+UDFs, or session configuration from inside the callback.
+
+.. code-block:: python
+
+    from datafusion import SessionContext, Table, udtf
+    from datafusion.context import TableProviderExportable
+    import pyarrow as pa
+    import pyarrow.dataset as ds
+
+    @udtf("list_tables", with_session=True)
+    def list_tables(*, session: SessionContext) -> TableProviderExportable:
+        names = sorted(session.catalog().schema().names())
+        batch = pa.RecordBatch.from_pydict({"name": names})
+        return Table(ds.dataset([batch]))
+
+    ctx = SessionContext()
+    ctx.register_batch("t1", pa.RecordBatch.from_pydict({"x": [1]}))
+    ctx.register_udtf(list_tables)
+    ctx.sql("SELECT * FROM list_tables()").show()
+
+Without ``with_session=True``, the callback receives only the positional
+expression arguments. The flag is opt-in so existing UDTFs keep working
+unchanged.
+
+The injected ``session`` is a fresh :py:class:`~datafusion.SessionContext`
+wrapper backed by the same underlying state as the caller, so registries
+(tables, UDFs, catalogs) are visible. Registry mutations (e.g. registering
+a new table or UDF) propagate to the live session because the registries
+are reference-counted and shared. Configuration changes made through the
+wrapper (e.g. setting session options) do **not** propagate — the wrapper
+holds its own clone of the session config.
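
The difference between shared registries and a cloned config can be modelled without DataFusion at all. The sketch below is an analogy only: `FakeState`, its `tables` dict, and the `batch_size` key are hypothetical stand-ins chosen for illustration and do not reflect the real `SessionState` layout.

```python
import copy


class FakeState:
    """Hypothetical stand-in for a session state with shared registries."""

    def __init__(self):
        self.tables = {}                    # registry: shared by reference
        self.config = {"batch_size": 8192}  # config: copied on clone

    def clone(self):
        cloned = FakeState.__new__(FakeState)
        cloned.tables = self.tables                 # same registry object
        cloned.config = copy.deepcopy(self.config)  # independent copy
        return cloned


caller = FakeState()
injected = caller.clone()  # roughly what the callback is handed

# A registry mutation through the injected handle is visible to the caller.
injected.tables["t2"] = "registered inside the UDTF"
# A config change through the injected handle stays local to it.
injected.config["batch_size"] = 1024

caller_sees_table = "t2" in caller.tables
caller_batch_size = caller.config["batch_size"]
```

In this model `caller_sees_table` ends up true while `caller_batch_size` keeps its original value, which is the same asymmetry the guide describes for the injected session.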

python/datafusion/user_defined.py

Lines changed: 87 additions & 13 deletions
@@ -1102,6 +1102,24 @@ def from_pycapsule(func: WindowUDFExportable) -> WindowUDF:
         )


+def _wrap_session_kwarg_for_udtf(func: Callable[..., Any]) -> Callable[..., Any]:
+    """Adapt the raw internal session pyo3 object back to a Python wrapper.
+
+    The Rust call site forwards a ``datafusion._internal.SessionContext``,
+    but UDTF authors expect to interact with the public
+    :class:`datafusion.SessionContext` wrapper. This closure wraps the
+    internal object once per call before delegating to ``func``.
+    """
+
+    @functools.wraps(func, updated=())
+    def adapter(*args: Any, session: Any, **kwargs: Any) -> Any:
+        wrapped = SessionContext.__new__(SessionContext)
+        wrapped.ctx = session
+        return func(*args, session=wrapped, **kwargs)
+
+    return adapter
+
+
 class TableFunction:
     """Class for performing user-defined table functions (UDTF).

@@ -1110,14 +1128,44 @@ class TableFunction:
     """

     def __init__(
-        self, name: str, func: Callable[[], any], ctx: SessionContext | None = None
+        self,
+        name: str,
+        func: Callable[..., Any],
+        ctx: SessionContext | None = None,
+        *,
+        with_session: bool = False,
     ) -> None:
         """Instantiate a user-defined table function (UDTF).

+        Set ``with_session=True`` to have the calling
+        :class:`SessionContext` passed as a ``session`` keyword argument
+        on each invocation. Use it inside the callback to look up
+        registered tables, UDFs, or session configuration. When
+        ``with_session`` is ``False`` (the default), ``func`` is invoked
+        with the positional expression arguments only.
+
+        ``with_session=True`` is only supported for pure-Python callables.
+        Passing it together with an FFI-exported table function (one
+        exposing ``__datafusion_table_function__``) raises
+        :class:`TypeError`.
+
+        Registry mutations performed through the injected session (such
+        as registering tables or UDFs) propagate to the caller's
+        :class:`SessionContext` because the registries are shared.
+        Configuration changes do **not** propagate; the wrapper holds
+        its own clone of the session config.
+
         See :py:func:`udtf` for a convenience function and argument
         descriptions.
         """
-        self._udtf = df_internal.TableFunction(name, func, ctx)
+        if with_session and hasattr(func, "__datafusion_table_function__"):
+            msg = (
+                "`with_session=True` is not supported for FFI-exported table "
+                "functions; session injection requires a pure-Python callable."
+            )
+            raise TypeError(msg)
+        registered = _wrap_session_kwarg_for_udtf(func) if with_session else func
+        self._udtf = df_internal.TableFunction(name, registered, ctx, with_session)

     def __call__(self, *args: Expr) -> Any:
         """Execute the UDTF and return a table provider."""
@@ -1128,47 +1176,73 @@ def __call__(self, *args: Expr) -> Any:
     @staticmethod
     def udtf(
         name: str,
+        *,
+        with_session: bool = False,
     ) -> Callable[..., Any]: ...

     @overload
     @staticmethod
     def udtf(
-        func: Callable[[], Any],
+        func: Callable[..., Any],
         name: str,
+        *,
+        with_session: bool = False,
     ) -> TableFunction: ...

     @staticmethod
-    def udtf(*args: Any, **kwargs: Any):
-        """Create a new User-Defined Table Function (UDTF)."""
+    def udtf(*args: Any, with_session: bool = False, **kwargs: Any):
+        """Create a new User-Defined Table Function (UDTF).
+
+        Pass ``with_session=True`` to have the calling
+        :class:`SessionContext` injected as a ``session`` keyword
+        argument on each invocation.
+        """
         if args and callable(args[0]):
             # Case 1: Used as a function, require the first parameter to be callable
-            return TableFunction._create_table_udf(*args, **kwargs)
+            return TableFunction._create_table_udf(
+                *args, with_session=with_session, **kwargs
+            )
         if args and hasattr(args[0], "__datafusion_table_function__"):
             # Case 2: We have a datafusion FFI provided function
+            if with_session:
+                msg = (
+                    "`with_session=True` is not supported for FFI-exported "
+                    "table functions; session injection requires a "
+                    "pure-Python callable."
+                )
+                raise TypeError(msg)
             return TableFunction(args[1], args[0])
         # Case 3: Used as a decorator with parameters
-        return TableFunction._create_table_udf_decorator(*args, **kwargs)
+        return TableFunction._create_table_udf_decorator(
+            *args, with_session=with_session, **kwargs
+        )

     @staticmethod
     def _create_table_udf(
         func: Callable[..., Any],
         name: str,
+        *,
+        with_session: bool = False,
     ) -> TableFunction:
         """Create a TableFunction instance from function arguments."""
         if not callable(func):
             msg = "`func` must be callable."
             raise TypeError(msg)

-        return TableFunction(name, func)
+        return TableFunction(name, func, with_session=with_session)

     @staticmethod
     def _create_table_udf_decorator(
         name: str | None = None,
-    ) -> Callable[[Callable[[], WindowEvaluator]], Callable[..., Expr]]:
-        """Create a decorator for a WindowUDF."""
-
-        def decorator(func: Callable[[], WindowEvaluator]) -> Callable[..., Expr]:
-            return TableFunction._create_table_udf(func, name)
+        *,
+        with_session: bool = False,
+    ) -> Callable[[Callable[..., Any]], TableFunction]:
+        """Create a decorator for a TableFunction."""
+
+        def decorator(func: Callable[..., Any]) -> TableFunction:
+            return TableFunction._create_table_udf(
+                func, name, with_session=with_session
+            )

         return decorator
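
The adapter and the FFI guard in this file can be exercised in isolation with stand-ins. The sketch below is a simplified model, not the library code: `PublicContext`, `wrap_session_kwarg`, `register`, and `FfiLike` are hypothetical names that imitate `SessionContext`, `_wrap_session_kwarg_for_udtf`, the check in `TableFunction.__init__`, and an FFI-exported object respectively.

```python
import functools
from typing import Any, Callable


class PublicContext:
    """Hypothetical stand-in for the public ``datafusion.SessionContext``."""

    ctx: Any


def wrap_session_kwarg(func: Callable[..., Any]) -> Callable[..., Any]:
    @functools.wraps(func, updated=())
    def adapter(*args: Any, session: Any, **kwargs: Any) -> Any:
        # Skip __init__ and attach the raw object to a bare wrapper instance.
        wrapped = PublicContext.__new__(PublicContext)
        wrapped.ctx = session
        return func(*args, session=wrapped, **kwargs)

    return adapter


def register(func: Callable[..., Any], *, with_session: bool = False) -> Callable[..., Any]:
    # Same shape as the guard above: the flag is rejected for FFI-style objects.
    if with_session and hasattr(func, "__datafusion_table_function__"):
        raise TypeError("`with_session=True` requires a pure-Python callable.")
    return wrap_session_kwarg(func) if with_session else func


def list_tables(*, session: PublicContext) -> Any:
    """Return the raw object held by the wrapper."""
    return session.ctx


class FfiLike:
    """Object that merely exposes the FFI marker attribute."""

    def __datafusion_table_function__(self): ...

    def __call__(self): ...


registered = register(list_tables, with_session=True)
raw = object()
returned = registered(session=raw)

try:
    register(FfiLike(), with_session=True)
    ffi_rejected = False
except TypeError:
    ffi_rejected = True
```

Wrapping once per call keeps the callback's view of the session on the public wrapper type, while the guard turns a flag that would otherwise be silently ignored into an immediate error.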
