diff --git a/docs/docs.json b/docs/docs.json
index 0198ecd..d369630 100644
--- a/docs/docs.json
+++ b/docs/docs.json
@@ -161,6 +161,7 @@
"geneva/udfs/scalar-udtfs",
"geneva/udfs/batch-udtfs",
"geneva/udfs/error_handling",
+ "geneva/udfs/profiling-memory",
"geneva/udfs/blobs"
]
},
diff --git a/docs/geneva/udfs/profiling-memory.mdx b/docs/geneva/udfs/profiling-memory.mdx
new file mode 100644
index 0000000..e63d1f8
--- /dev/null
+++ b/docs/geneva/udfs/profiling-memory.mdx
@@ -0,0 +1,233 @@
+---
+title: Profiling Stateful UDF Memory
+sidebarTitle: Profiling Memory
+description: Find memory leaks and runaway peak usage in stateful UDFs with memray, before they cause worker OOMs in production.
+icon: chart-line
+---
+
+import {
+ PyStatefulUdfClass,
+ PyMemrayTrackerUdf,
+ PyRayClusterProfile,
+ PyLogMemory,
+ PyLeakyCache,
+ PyBoundedCache,
+ PyLeakyAggregator,
+ PyLeakyClosure,
+ PyTorchInferenceMode,
+ PyConfidenceCheck,
+} from '/snippets/geneva_profiling_memory.mdx';
+
+Stateful UDFs are the most common source of worker memory pressure in Geneva. Unlike scalar UDFs, a stateful UDF instance lives for the **entire lifetime of a Ray actor**, processing many batches in sequence. Anything your `setup()` allocates is held for the duration of the job, and anything `__call__` retains accumulates batch after batch — sometimes silently, until a worker OOMs partway through a large backfill.
+
+This page shows you how to profile a stateful UDF locally with [memray](https://github.com/bloomberg/memray), what to look for, and the common patterns that leak.
+
+
+If your worker is being OOM-killed and you don't know why, **profile a single actor locally first**. A 5-minute memray run on your laptop is faster than another 45-minute distributed run that fails the same way.
+
+
+## Why stateful UDFs leak
+
+A stateful UDF in Geneva is a class:
+
+
+
+ {PyStatefulUdfClass}
+
+
+
+A few facts make memory behavior easy to get wrong:
+
+- **One instance per worker, many batches.** Geneva instantiates the class once per Ray actor. The same `self` processes every batch routed to that worker — potentially thousands.
+- **`setup()` runs once.** Whatever you allocate there stays in memory until the actor dies. That's intentional for things like ML models, but it's a footgun for "lazy" caches that grow.
+- **`self.<attr>` survives across calls.** Anything you attach to `self` inside `__call__` is retained for the rest of the actor's life.
+- **Workers don't restart between batches.** Unlike a serverless function, you don't get a fresh process per invocation. Memory accumulates linearly with batch count until the worker hits its memory cap.
+
+The result: a leak that looks tiny in unit tests (1 batch, 4 MiB) can blow up an 8-hour backfill (10 000 batches × 4 MiB ≈ 40 GiB).
+
+## When to profile
+
+Profile your UDF if **any** of these are true:
+
+- The UDF loads a model, builds an index, or otherwise allocates more than ~100 MiB in `setup()`.
+- The UDF maintains a cache, deduplication table, or running statistic in `self`.
+- A worker is being OOM-killed during backfill (look for `FatalWorkerOOMError`, see [Job troubleshooting](/geneva/jobs/troubleshooting)).
+- Worker RSS grows steadily during a backfill rather than staying flat after `setup()`.
+
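+You do **not** usually need to profile pure functional UDFs (no `self` state) or UDFs that only ever read from `self` — your own code can't accumulate state across calls in those. The exception is a library object held on `self` that retains per-call state internally (see "ML model state that grows" below).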
+
+## Profiling a UDF with memray
+
+memray ships as a `dev` dependency in Geneva, so it's already in your environment if you installed with `uv sync`.
+
+The trick to profiling under Ray is that **workers run in separate processes**. Wrapping `pytest` with `memray run` only sees the driver, not the actors that actually run your UDF. The cleanest pattern is to have **the UDF instrument itself**, controlled by an environment variable that's only set when you want a profile.
+
+### Step 1 — add an opt-in tracker to your UDF
+
+
+
+ {PyMemrayTrackerUdf}
+
+
+
+The tracker is a no-op when `MY_UDF_MEMRAY_OUT_DIR` isn't set, so leaving this code in your UDF is safe for production runs — provided `memray` is installed on the workers, because the module imports it unconditionally.
+
+### Step 2 — propagate the env var to Ray workers
+
+Ray workers don't inherit driver environment variables by default. When you start a local Ray cluster from Geneva, pass `extra_env` so the variable reaches each worker:
+
+
+
+ {PyRayClusterProfile}
+
+
+
+
+Set `concurrency=1` while profiling. One actor processing all batches sequentially produces a single clean trace; the default of 8 produces 8 noisier traces that you'd have to merge mentally.
+
+
+### Step 3 — read the trace
+
+When the backfill finishes, you'll have one or more `memray-<pid>-<uuid>.bin` files under `/tmp/my-udf-profile/` (one per worker process). Render and inspect them:
+
+```bash
+# Quick summary — peak heap, total allocations, what's leaked
+uv run -m memray summary /tmp/my-udf-profile/memray-*.bin
+
+# Interactive flamegraph in your browser
+uv run -m memray flamegraph /tmp/my-udf-profile/memray-*.bin
+open /tmp/my-udf-profile/memray-*.html
+
+# Top allocators by retained bytes
+uv run -m memray tree /tmp/my-udf-profile/memray-*.bin
+```
+
+## What the numbers mean
+
+memray reports two values you'll care about most:
+
+- **Peak heap** (`metadata.peak_memory`) — the high-water mark. This is what triggers OOMs. A peak well above your `setup()` allocations means a batch transiently doubles memory before freeing.
+- **Leaked allocations** (`get_leaked_allocation_records()`) — what was still allocated when the tracker ended. **This is not necessarily a bug** — your `setup()` model is "leaked" in this sense because it lives the actor's lifetime. The signal is *how much above expected baseline* is leaked.
+
+A healthy stateful UDF profile, after processing many batches, looks roughly like:
+
+```
+peak heap ≈ setup() allocations + 1 batch of working memory
+leaked ≈ setup() allocations (i.e. nothing extra retained from __call__)
+```
+
+An **unhealthy** profile looks like:
+
+```
+peak heap ≈ setup() + N × per-call allocation ← grows with batch count
+leaked ≈ setup() + N × per-call allocation ← per-call state never freed
+```
+
+The flamegraph will show a thick stack frame anchored in `__call__` rising as you scroll through time — that's the leak.
+
+## Going deeper: RSS vs Arrow allocations
+
+memray gives you the *Python-side* allocation story. For real diagnosis of "where is the worker's memory actually going?", it pays to watch **process RSS** and **Arrow's own allocator** side-by-side — together they tell you which subsystem owns the bytes, often faster than reading a flamegraph.
+
+Drop this into your UDF (or anywhere on the worker) to log a snapshot:
+
+
+
+ {PyLogMemory}
+
+
+
+The three numbers answer different questions:
+
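+- **`rss_mb`** — every byte the OS has handed this Python interpreter. Includes Python heap, Arrow, native libraries, and pages the C allocator (`glibc`/`jemalloc`) is holding even though Python freed them. This is what triggers cgroup OOM-kills. Note that the snippet reads `ru_maxrss`, which is the process's *peak* resident set size so far: the logged value never decreases, so it shows growth and high-water marks but not memory being returned to the OS.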
+- **`arrow_live_mb`** — bytes currently held by *live PyArrow buffers* (`RecordBatch`, `Array`, `ChunkedArray`, etc.). Goes up when you create Arrow data, down when those references are dropped.
+- **`gap_mb` = rss − arrow_live** — "everything else." This is the Python heap (your own `self.cache`, model weights, dicts, lists), native libraries (PyTorch, ONNX), and allocator retention.
+
+### Diagnostic patterns
+
+Log the breakdown every few batches and the *shape of growth over time* tells you which subsystem to fix:
+
+| Pattern | Diagnosis | First thing to try |
+|---|---|---|
+| `rss` climbs slowly, `arrow_live` flat near zero, big growing `gap` | Allocator retention — Python freed it but `glibc` is keeping the pages | `ctypes.CDLL("libc.so.6").malloc_trim(0)` periodically (Linux only); or set `MALLOC_TRIM_THRESHOLD_=131072` |
+| `rss` climbs, `arrow_live` climbs in lockstep | Real Arrow leak — your code is holding `RecordBatch` / `Array` references | Find where you're appending batches to `self`, or where checkpoint / error payloads aren't being released |
+| `rss` spikes hugely on a few calls then settles, eventually one spike OOMs | Peak is too big, not a leak — a single call allocates more than the worker has | Shrink `batch_size`, `blob_read_buffer_size`, or split the work |
+| `rss` flat for hours then sudden cliff upward | One pathological row — usually one huge blob (a 4K-resolution image, a 50 MB PDF) | Find the offending row by ID; add a size check at the top of `__call__` |
+| `rss` rises during `setup()`, then flat for the whole run, `gap` constant | Healthy — that's your model loaded once per actor | Nothing to do |
+
+
+The reference UDFs in Geneva's own integration test (`src/stress_tests/_memray_probe.py`) print exactly this breakdown every 32 calls. The workflow's stdout logs are a working example of the "clean" and "leaky" patterns. The leaky one looks like the second row above, except that `gap` (not `arrow_live`) climbs in lockstep with `rss` — because the leaked objects are `bytearray`s on the Python heap, not Arrow buffers.
+
+
+## Common leak patterns
+
+### 1. The growing cache
+
+
+
+ {PyLeakyCache}
+
+
+
+Looks harmless. Fine on a unit test with 10 inputs. **Catastrophic on a backfill of 10M rows**, where most inputs are unique and the cache grows to fill the worker.
+
+**Fix:** Use a bounded cache (`functools.lru_cache` with `maxsize`, or a manual size cap), or skip caching when you don't know the cardinality.
+
+
+
+ {PyBoundedCache}
+
+
+
+### 2. Accumulating per-call buffers
+
+
+
+ {PyLeakyAggregator}
+
+
+
+**Fix:** Don't hold references to inputs past the return of `__call__`. If you need rolling state, summarize into a small aggregate (counts, sums) instead of holding the raw batches.
+
+### 3. Closures capturing batch arrays
+
+
+
+ {PyLeakyClosure}
+
+
+
+**Fix:** Extract only the small values you actually need into the closure, or execute the work eagerly.
+
+### 4. ML model state that grows
+
+Some ML libraries retain per-call state internally (KV caches, gradient buffers, autograd graphs). If you're using PyTorch:
+
+
+
+ {PyTorchInferenceMode}
+
+
+
+For Hugging Face pipelines, ensure the model is in `eval()` mode and that you are not accumulating gradients. For long-running stateful UDFs on GPUs, also consider calling `torch.cuda.empty_cache()` between large batches.
+
+## A confidence check
+
+A useful "does my profiling actually work?" sanity check: temporarily introduce a deliberate leak and confirm memray catches it.
+
+
+
+ {PyConfidenceCheck}
+
+
+
+If `memray summary` doesn't show leaked bytes growing roughly with batch count after this change, your tracker isn't actually attached (most often: the env var isn't reaching workers — re-check `extra_env`).
+
+Geneva's own test suite ships a reference implementation of this pattern in `src/stress_tests/_memray_probe.py` and `src/stress_tests/test_memray_stateful_udf.py`, plus a GitHub Actions workflow (`memray-stateful-udf-profile.yml`) that uploads the per-actor `.bin` and rendered flamegraph as a CI artifact. Feel free to copy that scaffolding for your own project's UDFs.
+
+## Related
+
+- [UDFs](/geneva/udfs/udfs) — defining stateful UDFs
+- [Job troubleshooting](/geneva/jobs/troubleshooting) — diagnosing OOMs and other worker errors
+- [Advanced configuration](/geneva/udfs/advanced-configuration) — admission control and resource limits
+- [memray documentation](https://bloomberg.github.io/memray/) — flamegraph, summary, and tree report formats
diff --git a/docs/snippets/geneva_profiling_memory.mdx b/docs/snippets/geneva_profiling_memory.mdx
new file mode 100644
index 0000000..f6addc4
--- /dev/null
+++ b/docs/snippets/geneva_profiling_memory.mdx
@@ -0,0 +1,22 @@
+{/* Auto-generated by scripts/mdx_snippets_gen.py. Do not edit manually. */}
+
+export const PyBoundedCache = "from functools import lru_cache\n\nclass GoodEmbedding:\n def __init__(self):\n self._embed = None\n\n def setup(self):\n model = load_model()\n self._embed = lru_cache(maxsize=1024)(model.embed)\n\n def __call__(self, text: str) -> list[float]:\n if self._embed is None:\n self.setup()\n return self._embed(text)\n";
+
+export const PyConfidenceCheck = "def __call__(self, x):\n scratch = bytearray(8 * 1024 * 1024) # 8 MiB\n self._scratches.append(scratch) # <-- deliberate leak\n return ...\n";
+
+export const PyLeakyAggregator = "class BadAggregator:\n def __init__(self):\n self.history = []\n\n def __call__(self, batch: pa.RecordBatch) -> pa.Array:\n self.history.append(batch) # holds every batch ever processed\n ...\n";
+
+export const PyLeakyCache = "class BadEmbedding:\n def __init__(self):\n self.cache: dict[str, list[float]] = {}\n\n def __call__(self, text: str) -> list[float]:\n if text not in self.cache:\n self.cache[text] = self.model.embed(text)\n return self.cache[text]\n";
+
+export const PyLeakyClosure = "class BadDeferred:\n def __init__(self):\n self.work_queue = []\n\n def __call__(self, x: pa.Array) -> pa.Array:\n # Lambda captures `x` by reference — the whole Array stays alive\n self.work_queue.append(lambda: expensive(x))\n ...\n";
+
+export const PyLogMemory = "import resource, pyarrow as pa\n\ndef log_memory(seq: int) -> None:\n rss_bytes = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss\n # ru_maxrss is bytes on macOS, KiB on Linux:\n import sys\n if sys.platform != \"darwin\":\n rss_bytes *= 1024\n arrow_live = pa.total_allocated_bytes()\n print(\n f\"seq={seq} \"\n f\"rss_mb={rss_bytes // 1024**2} \"\n f\"arrow_live_mb={arrow_live // 1024**2} \"\n f\"gap_mb={(rss_bytes - arrow_live) // 1024**2}\",\n flush=True,\n )\n";
+
+export const PyMemrayTrackerUdf = "import os, pathlib, uuid\nfrom typing import Any\nimport memray\nimport geneva\nimport pyarrow as pa\n\n_MEMRAY_OUT_DIR_ENV = \"MY_UDF_MEMRAY_OUT_DIR\"\n\n\n@geneva.udf(data_type=pa.list_(pa.float32(), 512))\nclass MyEmbedding:\n def __init__(self):\n self.model = None\n self._tracker: Any = None # memray.Tracker, when profiling is on\n\n def setup(self):\n # Open a memray tracker per worker process, if requested. Each\n # worker writes its own .bin file so traces don't collide.\n out_dir = os.environ.get(_MEMRAY_OUT_DIR_ENV)\n if out_dir:\n pathlib.Path(out_dir).mkdir(parents=True, exist_ok=True)\n bin_path = pathlib.Path(out_dir) / (\n f\"memray-{os.getpid()}-{uuid.uuid4().hex}.bin\"\n )\n self._tracker = memray.Tracker(\n str(bin_path), native_traces=False, follow_fork=False\n )\n self._tracker.__enter__()\n self.model = load_model()\n\n def __call__(self, text: str) -> list[float]:\n if self.model is None:\n self.setup()\n return self.model.embed(text)\n";
+
+export const PyRayClusterProfile = "from geneva.runners.ray._mgr import ray_cluster\n\nwith ray_cluster(\n local=True,\n extra_env={\"MY_UDF_MEMRAY_OUT_DIR\": \"/tmp/my-udf-profile\"},\n):\n table.backfill(\"embedding\", concurrency=1)\n";
+
+export const PyStatefulUdfClass = "@geneva.udf(data_type=pa.list_(pa.float32(), 512))\nclass MyEmbedding:\n def __init__(self):\n self.model = None\n\n def setup(self):\n self.model = load_model() # allocated once per actor\n\n def __call__(self, text: str) -> list[float]:\n if self.model is None:\n self.setup()\n return self.model.embed(text)\n";
+
+export const PyTorchInferenceMode = "def __call__(self, text: str) -> list[float]:\n with torch.inference_mode(): # <-- prevents autograd graph retention\n return self.model.encode(text)\n";
+
diff --git a/docs/snippets/search.mdx b/docs/snippets/search.mdx
index 591ef0b..2c61812 100644
--- a/docs/snippets/search.mdx
+++ b/docs/snippets/search.mdx
@@ -8,10 +8,10 @@ export const PyBasicHybridSearch = "data = [\n {\"text\": \"rebel spaceships
export const PyBasicHybridSearchAsync = "uri = \"data/sample-lancedb\"\nasync_db = await lancedb.connect_async(uri)\ndata = [\n {\"text\": \"rebel spaceships striking from a hidden base\"},\n {\"text\": \"have won their first victory against the evil Galactic Empire\"},\n {\"text\": \"during the battle rebel spies managed to steal secret plans\"},\n {\"text\": \"to the Empire's ultimate weapon the Death Star\"},\n]\nasync_tbl = await async_db.create_table(\"documents_async\", schema=Documents)\n# ingest docs with auto-vectorization\nawait async_tbl.add(data)\n# Create a fts index before the hybrid search\nawait async_tbl.create_index(\"text\", config=FTS())\ntext_query = \"flower moon\"\n# hybrid search with default re-ranker\nawait (await async_tbl.search(\"flower moon\", query_type=\"hybrid\")).to_pandas()\n";
-export const PyClassDefinition = "class Metadata(BaseModel):\n source: str\n timestamp: datetime\n\n\nclass Document(BaseModel):\n content: str\n meta: Metadata\n\n\nclass LanceSchema(LanceModel):\n id: str\n vector: Vector(1536)\n payload: Document\n";
-
export const PyClassDocuments = "class Documents(LanceModel):\n vector: Vector(embeddings.ndims()) = embeddings.VectorField()\n text: str = embeddings.SourceField()\n";
+export const PyClassDefinition = "class Metadata(BaseModel):\n source: str\n timestamp: datetime\n\n\nclass Document(BaseModel):\n content: str\n meta: Metadata\n\n\nclass LanceSchema(LanceModel):\n id: str\n vector: Vector(1536)\n payload: Document\n";
+
export const PyCreateTableAsyncWithNestedSchema = "# Let's add 100 sample rows to our dataset\ndata = [\n LanceSchema(\n id=f\"id{i}\",\n vector=np.random.randn(1536),\n payload=Document(\n content=f\"document{i}\",\n meta=Metadata(source=f\"source{i % 10}\", timestamp=datetime.now()),\n ),\n )\n for i in range(100)\n]\n\nasync_tbl = await async_db.create_table(\n \"documents_async\", data=data, mode=\"overwrite\"\n)\n";
export const PyCreateTableWithNestedSchema = "# Let's add 100 sample rows to our dataset\ndata = [\n LanceSchema(\n id=f\"id{i}\",\n vector=np.random.randn(1536),\n payload=Document(\n content=f\"document{i}\",\n meta=Metadata(source=f\"source{i % 10}\", timestamp=datetime.now()),\n ),\n )\n for i in range(100)\n]\n\n# Synchronous client\ntbl = db.create_table(\"documents\", data=data, mode=\"overwrite\")\n";
diff --git a/tests/py/test_geneva_profiling_memory.py b/tests/py/test_geneva_profiling_memory.py
new file mode 100644
index 0000000..7ca45db
--- /dev/null
+++ b/tests/py/test_geneva_profiling_memory.py
@@ -0,0 +1,248 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+"""Snippets for docs/geneva/udfs/profiling-memory.mdx."""
+
+from unittest.mock import MagicMock
+
+
+def test_stateful_udf_class():
+ import geneva
+ import pyarrow as pa
+
+ load_model = MagicMock(return_value=MagicMock(embed=MagicMock(return_value=[0.1] * 512)))
+
+ # --8<-- [start:stateful_udf_class]
+ @geneva.udf(data_type=pa.list_(pa.float32(), 512))
+ class MyEmbedding:
+ def __init__(self):
+ self.model = None
+
+ def setup(self):
+ self.model = load_model() # allocated once per actor
+
+ def __call__(self, text: str) -> list[float]:
+ if self.model is None:
+ self.setup()
+ return self.model.embed(text)
+ # --8<-- [end:stateful_udf_class]
+
+ assert MyEmbedding is not None
+
+
+def test_memray_tracker_udf(monkeypatch):
+ import sys
+
+ load_model = MagicMock(return_value=MagicMock(embed=MagicMock(return_value=[0.1] * 512)))
+ monkeypatch.setitem(sys.modules, "memray", MagicMock())
+
+ # --8<-- [start:memray_tracker_udf]
+ import os, pathlib, uuid
+ from typing import Any
+ import memray
+ import geneva
+ import pyarrow as pa
+
+ _MEMRAY_OUT_DIR_ENV = "MY_UDF_MEMRAY_OUT_DIR"
+
+
+ @geneva.udf(data_type=pa.list_(pa.float32(), 512))
+ class MyEmbedding:
+ def __init__(self):
+ self.model = None
+ self._tracker: Any = None # memray.Tracker, when profiling is on
+
+ def setup(self):
+ # Open a memray tracker per worker process, if requested. Each
+ # worker writes its own .bin file so traces don't collide.
+ out_dir = os.environ.get(_MEMRAY_OUT_DIR_ENV)
+ if out_dir:
+ pathlib.Path(out_dir).mkdir(parents=True, exist_ok=True)
+ bin_path = pathlib.Path(out_dir) / (
+ f"memray-{os.getpid()}-{uuid.uuid4().hex}.bin"
+ )
+ self._tracker = memray.Tracker(
+ str(bin_path), native_traces=False, follow_fork=False
+ )
+ self._tracker.__enter__()
+ self.model = load_model()
+
+ def __call__(self, text: str) -> list[float]:
+ if self.model is None:
+ self.setup()
+ return self.model.embed(text)
+ # --8<-- [end:memray_tracker_udf]
+
+ assert MyEmbedding is not None
+
+
+def test_ray_cluster_profile(monkeypatch):
+ from contextlib import contextmanager
+
+ table = MagicMock()
+
+ @contextmanager
+ def _mock_cluster(*args, **kwargs):
+ yield
+
+ monkeypatch.setattr("geneva.runners.ray._mgr.ray_cluster", _mock_cluster)
+
+ # --8<-- [start:ray_cluster_profile]
+ from geneva.runners.ray._mgr import ray_cluster
+
+ with ray_cluster(
+ local=True,
+ extra_env={"MY_UDF_MEMRAY_OUT_DIR": "/tmp/my-udf-profile"},
+ ):
+ table.backfill("embedding", concurrency=1)
+ # --8<-- [end:ray_cluster_profile]
+
+ table.backfill.assert_called_once_with("embedding", concurrency=1)
+
+
+def test_log_memory(capsys):
+ # --8<-- [start:log_memory]
+ import resource, pyarrow as pa
+
+ def log_memory(seq: int) -> None:
+ rss_bytes = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
+ # ru_maxrss is bytes on macOS, KiB on Linux:
+ import sys
+ if sys.platform != "darwin":
+ rss_bytes *= 1024
+ arrow_live = pa.total_allocated_bytes()
+ print(
+ f"seq={seq} "
+ f"rss_mb={rss_bytes // 1024**2} "
+ f"arrow_live_mb={arrow_live // 1024**2} "
+ f"gap_mb={(rss_bytes - arrow_live) // 1024**2}",
+ flush=True,
+ )
+ # --8<-- [end:log_memory]
+
+ log_memory(0)
+ captured = capsys.readouterr()
+ assert "seq=0" in captured.out
+ assert "rss_mb=" in captured.out
+
+
+def test_leaky_cache():
+ # --8<-- [start:leaky_cache]
+ class BadEmbedding:
+ def __init__(self):
+ self.cache: dict[str, list[float]] = {}
+
+ def __call__(self, text: str) -> list[float]:
+ if text not in self.cache:
+ self.cache[text] = self.model.embed(text)
+ return self.cache[text]
+ # --8<-- [end:leaky_cache]
+
+ obj = BadEmbedding()
+ obj.model = MagicMock(embed=MagicMock(return_value=[0.1, 0.2]))
+ obj("hello")
+ obj("hello")
+ assert obj.model.embed.call_count == 1 # second call served from cache
+ assert len(obj.cache) == 1
+
+
+def test_bounded_cache():
+ load_model = MagicMock(return_value=MagicMock(embed=MagicMock(return_value=[0.1, 0.2])))
+
+ # --8<-- [start:bounded_cache]
+ from functools import lru_cache
+
+ class GoodEmbedding:
+ def __init__(self):
+ self._embed = None
+
+ def setup(self):
+ model = load_model()
+ self._embed = lru_cache(maxsize=1024)(model.embed)
+
+ def __call__(self, text: str) -> list[float]:
+ if self._embed is None:
+ self.setup()
+ return self._embed(text)
+ # --8<-- [end:bounded_cache]
+
+ obj = GoodEmbedding()
+ assert obj("hello") == [0.1, 0.2]
+ load_model.assert_called_once()
+
+
+def test_leaky_aggregator():
+ import pyarrow as pa
+
+ # --8<-- [start:leaky_aggregator]
+ class BadAggregator:
+ def __init__(self):
+ self.history = []
+
+ def __call__(self, batch: pa.RecordBatch) -> pa.Array:
+ self.history.append(batch) # holds every batch ever processed
+ ...
+ # --8<-- [end:leaky_aggregator]
+
+ obj = BadAggregator()
+ assert isinstance(obj.history, list)
+
+
+def test_leaky_closure():
+ import pyarrow as pa
+
+ def expensive(x):
+ return x
+
+ # --8<-- [start:leaky_closure]
+ class BadDeferred:
+ def __init__(self):
+ self.work_queue = []
+
+ def __call__(self, x: pa.Array) -> pa.Array:
+ # Lambda captures `x` by reference — the whole Array stays alive
+ self.work_queue.append(lambda: expensive(x))
+ ...
+ # --8<-- [end:leaky_closure]
+
+ obj = BadDeferred()
+ assert isinstance(obj.work_queue, list)
+
+
+def test_torch_inference_mode(monkeypatch):
+ import sys
+
+ mock_torch = MagicMock()
+ monkeypatch.setitem(sys.modules, "torch", mock_torch)
+
+ class TorchUDF:
+ def __init__(self):
+ self.model = MagicMock(encode=MagicMock(return_value=[0.1]))
+
+ # --8<-- [start:torch_inference_mode]
+ def __call__(self, text: str) -> list[float]:
+ with torch.inference_mode(): # <-- prevents autograd graph retention
+ return self.model.encode(text)
+ # --8<-- [end:torch_inference_mode]
+
+ import torch # resolves to monkeypatched mock above
+ obj = TorchUDF()
+ assert obj("hello") == [0.1]
+
+
+def test_confidence_check():
+ class DebuggingUDF:
+ def __init__(self):
+ self._scratches = []
+
+ # --8<-- [start:confidence_check]
+ def __call__(self, x):
+ scratch = bytearray(8 * 1024 * 1024) # 8 MiB
+ self._scratches.append(scratch) # <-- deliberate leak
+ return ...
+ # --8<-- [end:confidence_check]
+
+ obj = DebuggingUDF()
+ obj(None)
+ assert len(obj._scratches) == 1
+ assert len(obj._scratches[0]) == 8 * 1024 * 1024