diff --git a/.github/workflows/build-wheels.yml b/.github/workflows/build-wheels.yml new file mode 100644 index 00000000..a02c3da8 --- /dev/null +++ b/.github/workflows/build-wheels.yml @@ -0,0 +1,188 @@ +name: Build AppKit Python Wheels + +on: + push: + branches: + - main + paths: + - 'packages/appkit-rs/**' + pull_request: + branches: + - main + paths: + - 'packages/appkit-rs/**' + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + id-token: write + +jobs: + build-wheels: + name: Build wheels (${{ matrix.target }}) + runs-on: + group: databricks-protected-runner-group + labels: ${{ matrix.runner }} + strategy: + fail-fast: false + matrix: + include: + - target: x86_64-unknown-linux-gnu + runner: linux-ubuntu-latest + manylinux: "2_28" + - target: aarch64-unknown-linux-gnu + runner: linux-ubuntu-latest + manylinux: "2_28" + - target: x86_64-apple-darwin + runner: macos-latest + - target: aarch64-apple-darwin + runner: macos-latest + + steps: + - name: Checkout + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + + - name: Setup Python + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + with: + python-version: "3.11" + + - name: Build wheels + uses: PyO3/maturin-action@aef21716882fbb364b26db30fbbcbbbb533f60cc # v1.48.1 + with: + target: ${{ matrix.target }} + manylinux: ${{ matrix.manylinux || 'auto' }} + args: --release --out dist --manifest-path packages/appkit-rs/Cargo.toml + sccache: true + + - name: Upload wheels + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + with: + name: wheels-${{ matrix.target }} + retention-days: 7 + path: dist/*.whl + + smoke-test: + name: Smoke test + needs: build-wheels + runs-on: + group: databricks-protected-runner-group + labels: linux-ubuntu-latest + + steps: + - name: Download x86_64 Linux wheel + uses: 
actions/download-artifact@95815c38cf2ff2164869cbab79da8d1f422bc89e # v4.2.1 + with: + name: wheels-x86_64-unknown-linux-gnu + path: dist + + - name: Setup Python + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + with: + python-version: "3.11" + + - name: Install wheel and smoke-test import + run: | + pip install dist/*.whl + python -c "import appkit; print('appkit version:', appkit.__version__ if hasattr(appkit, '__version__') else 'ok')" + + rust-tests: + name: Rust tests + runs-on: + group: databricks-protected-runner-group + labels: linux-ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + + - name: Setup Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cargo test + working-directory: packages/appkit-rs + run: cargo test + + python-tests: + name: Python tests + runs-on: + group: databricks-protected-runner-group + labels: linux-ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + + - name: Setup Python + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + with: + python-version: "3.11" + + - name: Setup Rust + uses: dtolnay/rust-toolchain@stable + + - name: Install maturin and build + working-directory: packages/appkit-rs + run: | + pip install maturin + maturin develop + + - name: Install test dependencies and run pytest + working-directory: packages/appkit-rs + run: | + pip install pytest pytest-asyncio + pytest + + checksums: + name: Generate SHA256 + needs: build-wheels + runs-on: + group: databricks-protected-runner-group + labels: linux-ubuntu-latest + + steps: + - name: Download all wheels + uses: actions/download-artifact@95815c38cf2ff2164869cbab79da8d1f422bc89e # v4.2.1 + with: + pattern: wheels-* + merge-multiple: true + path: dist + + - name: Generate SHA256 digests + run: | + cd dist + sha256sum *.whl > SHA256SUMS + cat SHA256SUMS + + - name: Upload 
checksums + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + with: + name: appkit-py-checksums-${{ github.run_number }} + retention-days: 7 + path: dist/SHA256SUMS + + publish: + name: Publish to PyPI + needs: [build-wheels, smoke-test, rust-tests, python-tests, checksums] + if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/appkit-py-v') + runs-on: + group: databricks-protected-runner-group + labels: linux-ubuntu-latest + environment: pypi + + steps: + - name: Download all wheels + uses: actions/download-artifact@95815c38cf2ff2164869cbab79da8d1f422bc89e # v4.2.1 + with: + pattern: wheels-* + merge-multiple: true + path: dist + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@76f52bc884231f62b54f72e44af3222aff3ef9f6 # v1.12.4 + with: + packages-dir: dist/ diff --git a/.gitignore b/.gitignore index 5d417368..67f0b2bc 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,16 @@ coverage .turbo .databricks + +# Rust / maturin +target/ +.cargo/ +*.so +*.pyd +*.dylib + +# Python / uv +__pycache__/ +*.egg-info/ +.venv/ +*.pyc diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 00000000..753c31bd --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,3 @@ +[workspace] +resolver = "2" +members = ["packages/appkit-rs"] diff --git a/docs/docs/api/appkit-python/Class.AnalyticsPlugin.md b/docs/docs/api/appkit-python/Class.AnalyticsPlugin.md new file mode 100644 index 00000000..515f04f0 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.AnalyticsPlugin.md @@ -0,0 +1,49 @@ +--- +sidebar_label: AnalyticsPlugin +sidebar_position: 1 +--- + + + +# Class: AnalyticsPlugin + +SQL query execution plugin. + +Queries live on disk under ``queries_dir`` (default ``config/queries``) +and are referenced by key in the route path. 
+ +_Defined in `appkit/plugins/analytics.py`._ + +**Extends:** `Plugin` + +## Methods + +### `__init__` + +```python +def __init__(self, config: AnalyticsPluginConfig) -> None +``` + +### `queries_dir` (property) + +```python +def queries_dir(self) -> Path +``` + +### `warehouse_id` (property) + +```python +def warehouse_id(self) -> str +``` + +### `client_config` + +```python +def client_config(self) -> dict[str, str] +``` + +### `inject_routes` + +```python +def inject_routes(self, router: Any) -> None +``` diff --git a/docs/docs/api/appkit-python/Class.AnalyticsPluginConfig.md b/docs/docs/api/appkit-python/Class.AnalyticsPluginConfig.md new file mode 100644 index 00000000..37b79997 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.AnalyticsPluginConfig.md @@ -0,0 +1,24 @@ +--- +sidebar_label: AnalyticsPluginConfig +sidebar_position: 2 +--- + + + +# Class: AnalyticsPluginConfig + +Configuration for :class:`AnalyticsPlugin`. + +``warehouse_id`` routes queries to a Databricks SQL warehouse. +``queries_dir`` overrides the default ``config/queries`` path. +``host`` defaults to the ``DATABRICKS_HOST`` environment variable. + +_Defined in `appkit/plugins/analytics.py`._ + +## Methods + +### `__init__` + +```python +def __init__( self, *, warehouse_id: str | None = None, queries_dir: str | os.PathLike[str] | None = None, host: str | None = None, timeout_ms: int | None = None, ) -> None +``` diff --git a/docs/docs/api/appkit-python/Class.AppConfig.md b/docs/docs/api/appkit-python/Class.AppConfig.md new file mode 100644 index 00000000..6e43704b --- /dev/null +++ b/docs/docs/api/appkit-python/Class.AppConfig.md @@ -0,0 +1,38 @@ +--- +sidebar_label: AppConfig +sidebar_position: 3 +--- + + + +# Class: AppConfig + +Application configuration parsed from environment variables. 
+ +_Defined in `appkit.pyi`._ + +## Attributes + +| Name | Type | +| --- | --- | +| `databricks_host` | `str` | +| `client_id` | `Optional[str]` | +| `client_secret` | `Optional[str]` | +| `warehouse_id` | `Optional[str]` | +| `app_port` | `int` | +| `host` | `str` | +| `otel_endpoint` | `Optional[str]` | + +## Methods + +### `__init__` + +```python +def __init__( self, databricks_host: str, *, client_id: Optional[str] = None, client_secret: Optional[str] = None, warehouse_id: Optional[str] = None, app_port: int = 8000, host: str = "0.0.0.0", otel_endpoint: Optional[str] = None, ) -> None +``` + +### `from_env` (staticmethod) + +```python +def from_env() -> AppConfig +``` diff --git a/docs/docs/api/appkit-python/Class.AppKit.md b/docs/docs/api/appkit-python/Class.AppKit.md new file mode 100644 index 00000000..93fa65a6 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.AppKit.md @@ -0,0 +1,78 @@ +--- +sidebar_label: AppKit +sidebar_position: 4 +--- + + + +# Class: AppKit + +AppKit orchestrator — registers plugins and manages initialization. + +Most applications should use :func:`create_app` instead of driving +this class directly; ``create_app`` wires registration, initialization +and optional server startup in one call. + +_Defined in `appkit.pyi`._ + +## Methods + +### `__init__` + +```python +def __init__(self) -> None +``` + +### `register` + +```python +def register(self, plugin: Plugin) -> None +``` + +Register a plugin. Must be called before :meth:`initialize`. + +### `initialize` + +```python +def initialize( self, config: AppConfig, *, cache_config: Optional[CacheConfig] = None, ) -> Awaitable[None] +``` + +Initialize telemetry, cache, and run phase-ordered ``Plugin.setup``. + +After this returns, plugins are ready to serve requests. Calling +``initialize`` twice is an error. + +### `get_plugin` + +```python +def get_plugin(self, name: str) -> Optional[Plugin] +``` + +Look up a registered plugin by its manifest name. 
+ +### `plugin_names` + +```python +def plugin_names(self) -> list[str] +``` + +Return the manifest names of all registered plugins. + +### `start_server` + +```python +def start_server(self, server_config: ServerConfig) -> Awaitable[None] +``` + +Start the HTTP server and block until it exits. + +Routes previously injected via ``Plugin.inject_routes`` are mounted +under ``/api//...``. + +### `shutdown` + +```python +def shutdown(self) -> None +``` + +Stop the HTTP server and release resources. diff --git a/docs/docs/api/appkit-python/Class.CacheConfig.md b/docs/docs/api/appkit-python/Class.CacheConfig.md new file mode 100644 index 00000000..806682ee --- /dev/null +++ b/docs/docs/api/appkit-python/Class.CacheConfig.md @@ -0,0 +1,29 @@ +--- +sidebar_label: CacheConfig +sidebar_position: 5 +--- + + + +# Class: CacheConfig + +Cache configuration with defaults matching TypeScript cacheDefaults. + +_Defined in `appkit.pyi`._ + +## Attributes + +| Name | Type | +| --- | --- | +| `enabled` | `bool` | +| `ttl` | `int` | +| `max_size` | `int` | +| `cleanup_probability` | `float` | + +## Methods + +### `__init__` + +```python +def __init__( self, *, enabled: bool = True, ttl: int = 3600, max_size: int = 1000, cleanup_probability: float = 0.01, ) -> None +``` diff --git a/docs/docs/api/appkit-python/Class.CacheManager.md b/docs/docs/api/appkit-python/Class.CacheManager.md new file mode 100644 index 00000000..4750fb01 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.CacheManager.md @@ -0,0 +1,68 @@ +--- +sidebar_label: CacheManager +sidebar_position: 6 +--- + + + +# Class: CacheManager + +Cache manager with TTL, LRU eviction, and in-flight deduplication. 
+ +_Defined in `appkit.pyi`._ + +## Methods + +### `__init__` + +```python +def __init__(self, config: Optional[CacheConfig] = None) -> None +``` + +### `generate_key` (staticmethod) + +```python +def generate_key(parts: list[str], user_key: str) -> str +``` + +### `get` + +```python +def get(self, key: str) -> Awaitable[Optional[str]] +``` + +### `set` + +```python +def set(self, key: str, value: str, *, ttl: Optional[int] = None) -> Awaitable[None] +``` + +### `delete` + +```python +def delete(self, key: str) -> Awaitable[None] +``` + +### `has` + +```python +def has(self, key: str) -> Awaitable[bool] +``` + +### `clear` + +```python +def clear(self) -> Awaitable[None] +``` + +### `size` + +```python +def size(self) -> Awaitable[int] +``` + +### `get_or_execute` + +```python +def get_or_execute( self, key: str, func: Callable[[], Awaitable[str]], *, ttl: Optional[int] = None, ) -> Awaitable[str] +``` diff --git a/docs/docs/api/appkit-python/Class.DatabaseCredential.md b/docs/docs/api/appkit-python/Class.DatabaseCredential.md new file mode 100644 index 00000000..a4dc469d --- /dev/null +++ b/docs/docs/api/appkit-python/Class.DatabaseCredential.md @@ -0,0 +1,19 @@ +--- +sidebar_label: DatabaseCredential +sidebar_position: 7 +--- + + + +# Class: DatabaseCredential + +Generated database credential for Lakebase access. + +_Defined in `appkit.pyi`._ + +## Attributes + +| Name | Type | +| --- | --- | +| `token` | `str` | +| `expiration_time` | `str` | diff --git a/docs/docs/api/appkit-python/Class.ExecutionResult.md b/docs/docs/api/appkit-python/Class.ExecutionResult.md new file mode 100644 index 00000000..57f1ccbd --- /dev/null +++ b/docs/docs/api/appkit-python/Class.ExecutionResult.md @@ -0,0 +1,21 @@ +--- +sidebar_label: ExecutionResult +sidebar_position: 8 +--- + + + +# Class: ExecutionResult + +Python-facing execution result (frozen, immutable). 
+ +_Defined in `appkit.pyi`._ + +## Attributes + +| Name | Type | +| --- | --- | +| `ok` | `bool` | +| `data` | `Optional[str]` | +| `status` | `Optional[int]` | +| `message` | `Optional[str]` | diff --git a/docs/docs/api/appkit-python/Class.FileDirectoryEntry.md b/docs/docs/api/appkit-python/Class.FileDirectoryEntry.md new file mode 100644 index 00000000..8032431f --- /dev/null +++ b/docs/docs/api/appkit-python/Class.FileDirectoryEntry.md @@ -0,0 +1,22 @@ +--- +sidebar_label: FileDirectoryEntry +sidebar_position: 9 +--- + + + +# Class: FileDirectoryEntry + +A single entry in a directory listing. + +_Defined in `appkit.pyi`._ + +## Attributes + +| Name | Type | +| --- | --- | +| `path` | `str` | +| `name` | `str` | +| `is_directory` | `bool` | +| `file_size` | `Optional[int]` | +| `last_modified` | `Optional[int]` | diff --git a/docs/docs/api/appkit-python/Class.FileMetadata.md b/docs/docs/api/appkit-python/Class.FileMetadata.md new file mode 100644 index 00000000..b1e9312f --- /dev/null +++ b/docs/docs/api/appkit-python/Class.FileMetadata.md @@ -0,0 +1,20 @@ +--- +sidebar_label: FileMetadata +sidebar_position: 10 +--- + + + +# Class: FileMetadata + +File metadata from a HEAD request. + +_Defined in `appkit.pyi`._ + +## Attributes + +| Name | Type | +| --- | --- | +| `content_length` | `Optional[int]` | +| `content_type` | `Optional[str]` | +| `last_modified` | `Optional[str]` | diff --git a/docs/docs/api/appkit-python/Class.FilePreview.md b/docs/docs/api/appkit-python/Class.FilePreview.md new file mode 100644 index 00000000..b8979e27 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.FilePreview.md @@ -0,0 +1,23 @@ +--- +sidebar_label: FilePreview +sidebar_position: 11 +--- + + + +# Class: FilePreview + +File preview with optional text content. 
+ +_Defined in `appkit.pyi`._ + +## Attributes + +| Name | Type | +| --- | --- | +| `content_length` | `Optional[int]` | +| `content_type` | `Optional[str]` | +| `last_modified` | `Optional[str]` | +| `text_preview` | `Optional[str]` | +| `is_text` | `bool` | +| `is_image` | `bool` | diff --git a/docs/docs/api/appkit-python/Class.FilesConnector.md b/docs/docs/api/appkit-python/Class.FilesConnector.md new file mode 100644 index 00000000..378f18f8 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.FilesConnector.md @@ -0,0 +1,106 @@ +--- +sidebar_label: FilesConnector +sidebar_position: 12 +--- + + + +# Class: FilesConnector + +Databricks Files API connector. + +Operates against Unity Catalog Volume paths. The ``default_volume`` +constructor argument is used as a prefix when a ``file_path`` is not +already a ``/Volumes/...`` absolute path. + +_Defined in `appkit.pyi`._ + +## Methods + +### `__init__` + +```python +def __init__(self, host: str, *, default_volume: Optional[str] = None) -> None +``` + +### `resolve_path` + +```python +def resolve_path(self, file_path: str) -> str +``` + +Join ``file_path`` with the connector's default volume if the +path is not already a fully-qualified ``/Volumes/...`` path. + +### `list` + +```python +def list( self, token: str, *, directory_path: Optional[str] = None, ) -> Awaitable[list[FileDirectoryEntry]] +``` + +List entries under ``directory_path`` (or the default volume). + +### `read` + +```python +def read( self, token: str, file_path: str, *, max_size: Optional[int] = None, ) -> Awaitable[str] +``` + +Read a text file, optionally truncated to ``max_size`` bytes. + +### `download` + +```python +def download(self, token: str, file_path: str) -> Awaitable[bytes] +``` + +Download a file as raw bytes. + +### `exists` + +```python +def exists(self, token: str, file_path: str) -> Awaitable[bool] +``` + +Return whether the given file or directory exists. 
+ +### `metadata` + +```python +def metadata(self, token: str, file_path: str) -> Awaitable[FileMetadata] +``` + +Fetch metadata for a file via a HEAD request. + +### `upload` + +```python +def upload( self, token: str, file_path: str, contents: bytes, *, overwrite: bool = True, ) -> Awaitable[None] +``` + +Upload ``contents`` to ``file_path``. Overwrites by default. + +### `create_directory` + +```python +def create_directory(self, token: str, directory_path: str) -> Awaitable[None] +``` + +Create a directory, creating parents as needed. + +### `delete` + +```python +def delete(self, token: str, file_path: str) -> Awaitable[None] +``` + +Delete a file or (empty) directory. + +### `preview` + +```python +def preview( self, token: str, file_path: str, *, max_chars: int = 1024, ) -> Awaitable[FilePreview] +``` + +Fetch metadata and a capped-length text preview if the file is +recognized as text. diff --git a/docs/docs/api/appkit-python/Class.FilesPlugin.md b/docs/docs/api/appkit-python/Class.FilesPlugin.md new file mode 100644 index 00000000..dec67e17 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.FilesPlugin.md @@ -0,0 +1,44 @@ +--- +sidebar_label: FilesPlugin +sidebar_position: 13 +--- + + + +# Class: FilesPlugin + +Unity Catalog Volumes file operations plugin. + +_Defined in `appkit/plugins/files.py`._ + +**Extends:** `Plugin` + +## Methods + +### `__init__` + +```python +def __init__(self, config: FilesPluginConfig) -> None +``` + +### `client_config` + +```python +def client_config(self) -> dict[str, str] +``` + +### `inject_routes` + +```python +def inject_routes(self, router: Any) -> None +``` + +### `connector` + +```python +def connector(self, volume_key: str) -> FilesConnector +``` + +Return the :class:`FilesConnector` registered for ``volume_key``. + +Raises :class:`ValueError` when the alias is not configured. 
diff --git a/docs/docs/api/appkit-python/Class.FilesPluginConfig.md b/docs/docs/api/appkit-python/Class.FilesPluginConfig.md new file mode 100644 index 00000000..452dfab9 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.FilesPluginConfig.md @@ -0,0 +1,25 @@ +--- +sidebar_label: FilesPluginConfig +sidebar_position: 14 +--- + + + +# Class: FilesPluginConfig + +Configuration for :class:`FilesPlugin`. + +``volumes`` maps alias → :class:`VolumeConfig`. Aliases appear in route +URLs as the ``volume`` query parameter (for example +``/api/files/list?volume=uploads``). ``host`` defaults to the +``DATABRICKS_HOST`` environment variable. + +_Defined in `appkit/plugins/files.py`._ + +## Methods + +### `__init__` + +```python +def __init__( self, *, volumes: Mapping[str, VolumeConfig], host: str | None = None, timeout_ms: int | None = None, max_upload_size: int | None = None, ) -> None +``` diff --git a/docs/docs/api/appkit-python/Class.GenieAttachment.md b/docs/docs/api/appkit-python/Class.GenieAttachment.md new file mode 100644 index 00000000..79e42a75 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.GenieAttachment.md @@ -0,0 +1,24 @@ +--- +sidebar_label: GenieAttachment +sidebar_position: 15 +--- + + + +# Class: GenieAttachment + +Genie query attachment metadata. 
+ +_Defined in `appkit.pyi`._ + +## Attributes + +| Name | Type | +| --- | --- | +| `attachment_id` | `Optional[str]` | +| `query_title` | `Optional[str]` | +| `query_description` | `Optional[str]` | +| `query_sql` | `Optional[str]` | +| `query_statement_id` | `Optional[str]` | +| `text_content` | `Optional[str]` | +| `suggested_questions` | `Optional[list[str]]` | diff --git a/docs/docs/api/appkit-python/Class.GenieConnector.md b/docs/docs/api/appkit-python/Class.GenieConnector.md new file mode 100644 index 00000000..38cc8e7e --- /dev/null +++ b/docs/docs/api/appkit-python/Class.GenieConnector.md @@ -0,0 +1,56 @@ +--- +sidebar_label: GenieConnector +sidebar_position: 16 +--- + + + +# Class: GenieConnector + +Databricks Genie connector. + +_Defined in `appkit.pyi`._ + +## Methods + +### `__init__` + +```python +def __init__( self, host: str, *, timeout_ms: Optional[int] = None, max_messages: Optional[int] = None, ) -> None +``` + +### `start_message` + +```python +def start_message( self, token: str, space_id: str, content: str, *, conversation_id: Optional[str] = None, ) -> Awaitable[tuple[str, str]] +``` + +### `send_message` + +```python +def send_message( self, token: str, space_id: str, content: str, *, conversation_id: Optional[str] = None, timeout_ms: Optional[int] = None, ) -> Awaitable[GenieMessage] +``` + +### `get_message` + +```python +def get_message( self, token: str, space_id: str, conversation_id: str, message_id: str, *, timeout_ms: Optional[int] = None, ) -> Awaitable[GenieMessage] +``` + +### `list_messages` + +```python +def list_messages( self, token: str, space_id: str, conversation_id: str, *, page_size: Optional[int] = None, page_token: Optional[str] = None, ) -> Awaitable[tuple[list[GenieMessage], Optional[str]]] +``` + +### `get_query_result` + +```python +def get_query_result( self, token: str, space_id: str, conversation_id: str, message_id: str, attachment_id: str, ) -> Awaitable[GenieQueryResult] +``` + +### `get_conversation` + 
+```python +def get_conversation( self, token: str, space_id: str, conversation_id: str, ) -> Awaitable[GenieConversationHistory] +``` diff --git a/docs/docs/api/appkit-python/Class.GenieConversationHistory.md b/docs/docs/api/appkit-python/Class.GenieConversationHistory.md new file mode 100644 index 00000000..01a12186 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.GenieConversationHistory.md @@ -0,0 +1,20 @@ +--- +sidebar_label: GenieConversationHistory +sidebar_position: 17 +--- + + + +# Class: GenieConversationHistory + +Full conversation history. + +_Defined in `appkit.pyi`._ + +## Attributes + +| Name | Type | +| --- | --- | +| `conversation_id` | `str` | +| `space_id` | `str` | +| `messages` | `list[GenieMessage]` | diff --git a/docs/docs/api/appkit-python/Class.GenieMessage.md b/docs/docs/api/appkit-python/Class.GenieMessage.md new file mode 100644 index 00000000..45558249 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.GenieMessage.md @@ -0,0 +1,24 @@ +--- +sidebar_label: GenieMessage +sidebar_position: 18 +--- + + + +# Class: GenieMessage + +Genie message response. + +_Defined in `appkit.pyi`._ + +## Attributes + +| Name | Type | +| --- | --- | +| `message_id` | `str` | +| `conversation_id` | `str` | +| `space_id` | `str` | +| `status` | `str` | +| `content` | `str` | +| `attachments` | `list[GenieAttachment]` | +| `error` | `Optional[str]` | diff --git a/docs/docs/api/appkit-python/Class.GeniePlugin.md b/docs/docs/api/appkit-python/Class.GeniePlugin.md new file mode 100644 index 00000000..ef3b4f82 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.GeniePlugin.md @@ -0,0 +1,34 @@ +--- +sidebar_label: GeniePlugin +sidebar_position: 19 +--- + + + +# Class: GeniePlugin + +Genie conversational analytics plugin. 
+ +_Defined in `appkit/plugins/genie.py`._ + +**Extends:** `Plugin` + +## Methods + +### `__init__` + +```python +def __init__(self, config: GeniePluginConfig) -> None +``` + +### `client_config` + +```python +def client_config(self) -> dict[str, str] +``` + +### `inject_routes` + +```python +def inject_routes(self, router: Any) -> None +``` diff --git a/docs/docs/api/appkit-python/Class.GeniePluginConfig.md b/docs/docs/api/appkit-python/Class.GeniePluginConfig.md new file mode 100644 index 00000000..5869d582 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.GeniePluginConfig.md @@ -0,0 +1,24 @@ +--- +sidebar_label: GeniePluginConfig +sidebar_position: 20 +--- + + + +# Class: GeniePluginConfig + +Configuration for :class:`GeniePlugin`. + +``spaces`` maps alias → Genie ``space_id``. Route handlers accept +an alias (not the raw space id) so the client never sees the UC +resource identifier. + +_Defined in `appkit/plugins/genie.py`._ + +## Methods + +### `__init__` + +```python +def __init__( self, *, spaces: Mapping[str, str], host: str | None = None, timeout_ms: int | None = None, max_messages: int | None = None, ) -> None +``` diff --git a/docs/docs/api/appkit-python/Class.GenieQueryResult.md b/docs/docs/api/appkit-python/Class.GenieQueryResult.md new file mode 100644 index 00000000..6c3bfb28 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.GenieQueryResult.md @@ -0,0 +1,18 @@ +--- +sidebar_label: GenieQueryResult +sidebar_position: 21 +--- + + + +# Class: GenieQueryResult + +Query result from a Genie attachment. 
+ +_Defined in `appkit.pyi`._ + +## Attributes + +| Name | Type | +| --- | --- | +| `data` | `str` | diff --git a/docs/docs/api/appkit-python/Class.LakebaseConnector.md b/docs/docs/api/appkit-python/Class.LakebaseConnector.md new file mode 100644 index 00000000..f50a868c --- /dev/null +++ b/docs/docs/api/appkit-python/Class.LakebaseConnector.md @@ -0,0 +1,34 @@ +--- +sidebar_label: LakebaseConnector +sidebar_position: 22 +--- + + + +# Class: LakebaseConnector + +Databricks Lakebase connector. + +Generates short-lived PostgreSQL credentials for a set of Lakebase +instances via the database credentials API. + +_Defined in `appkit.pyi`._ + +## Methods + +### `__init__` + +```python +def __init__(self, host: str) -> None +``` + +### `generate_credential` + +```python +def generate_credential( self, token: str, instance_names: list[str], *, request_id: Optional[str] = None, ) -> Awaitable[DatabaseCredential] +``` + +Generate a credential good for one or more Lakebase instances. + +Use the returned token as the PostgreSQL password for the life of +``expiration_time`` — typically tens of minutes. diff --git a/docs/docs/api/appkit-python/Class.LakebasePgConfig.md b/docs/docs/api/appkit-python/Class.LakebasePgConfig.md new file mode 100644 index 00000000..419da7ec --- /dev/null +++ b/docs/docs/api/appkit-python/Class.LakebasePgConfig.md @@ -0,0 +1,36 @@ +--- +sidebar_label: LakebasePgConfig +sidebar_position: 23 +--- + + + +# Class: LakebasePgConfig + +PostgreSQL connection configuration for Lakebase. 
+ +_Defined in `appkit.pyi`._ + +## Attributes + +| Name | Type | +| --- | --- | +| `host` | `str` | +| `database` | `str` | +| `port` | `int` | +| `ssl_mode` | `str` | +| `app_name` | `Optional[str]` | + +## Methods + +### `__init__` + +```python +def __init__( self, *, host: Optional[str] = None, database: Optional[str] = None, port: Optional[int] = None, ssl_mode: Optional[str] = None, app_name: Optional[str] = None, ) -> None +``` + +### `from_env` (staticmethod) + +```python +def from_env() -> LakebasePgConfig +``` diff --git a/docs/docs/api/appkit-python/Class.LakebasePlugin.md b/docs/docs/api/appkit-python/Class.LakebasePlugin.md new file mode 100644 index 00000000..a8376223 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.LakebasePlugin.md @@ -0,0 +1,69 @@ +--- +sidebar_label: LakebasePlugin +sidebar_position: 24 +--- + + + +# Class: LakebasePlugin + +Lakebase PostgreSQL integration plugin. + +Exposes: + +- :attr:`pg_config` — resolved :class:`LakebasePgConfig` for pool setup. +- :attr:`connector` — the underlying :class:`LakebaseConnector`. +- :meth:`generate_credential` — wrapper around the REST call. + +_Defined in `appkit/plugins/lakebase.py`._ + +**Extends:** `Plugin` + +## Methods + +### `__init__` + +```python +def __init__(self, config: LakebasePluginConfig | None = None) -> None +``` + +### `pg_config` (property) + +```python +def pg_config(self) -> LakebasePgConfig +``` + +### `connector` (property) + +```python +def connector(self) -> LakebaseConnector +``` + +### `generate_credential` + +```python +async def generate_credential( self, token: str, instance_names: list[str] | None = None, *, request_id: str | None = None, ) -> DatabaseCredential +``` + +Generate a short-lived credential for Lakebase connection(s). + +When ``instance_names`` is omitted, the plugin's configured PG host +is used as the single instance name. 
+ +### `client_config` + +```python +def client_config(self) -> dict[str, str] +``` + +### `exports` + +```python +def exports(self) -> dict[str, str] +``` + +### `inject_routes` + +```python +def inject_routes(self, _router: Any) -> None +``` diff --git a/docs/docs/api/appkit-python/Class.LakebasePluginConfig.md b/docs/docs/api/appkit-python/Class.LakebasePluginConfig.md new file mode 100644 index 00000000..88b22f74 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.LakebasePluginConfig.md @@ -0,0 +1,25 @@ +--- +sidebar_label: LakebasePluginConfig +sidebar_position: 25 +--- + + + +# Class: LakebasePluginConfig + +Configuration for :class:`LakebasePlugin`. + +``pg_config`` overrides the default :class:`LakebasePgConfig` built +from ``PGHOST``/``PGDATABASE``/``LAKEBASE_ENDPOINT`` etc. ``host`` +defaults to ``DATABRICKS_HOST`` and is used to reach the Lakebase +credential-generation REST API (distinct from the PG host). + +_Defined in `appkit/plugins/lakebase.py`._ + +## Methods + +### `__init__` + +```python +def __init__( self, *, pg_config: LakebasePgConfig | None = None, host: str | None = None, ) -> None +``` diff --git a/docs/docs/api/appkit-python/Class.Plugin.md b/docs/docs/api/appkit-python/Class.Plugin.md new file mode 100644 index 00000000..a466ea8c --- /dev/null +++ b/docs/docs/api/appkit-python/Class.Plugin.md @@ -0,0 +1,99 @@ +--- +sidebar_label: Plugin +sidebar_position: 26 +--- + + + +# Class: Plugin + +Base class for Python plugins. Subclass to create custom plugins. 
Example:: + + class MyPlugin(Plugin): + def __init__(self): + super().__init__("my-plugin", manifest=PluginManifest("my-plugin")) + + async def setup(self): + pass + +_Defined in `appkit.pyi`._ + +## Attributes + +| Name | Type | +| --- | --- | +| `name` | `str` | +| `phase` | `str` | +| `manifest` | `PluginManifest` | +| `is_ready` | `bool` | + +## Methods + +### `__init__` + +```python +def __init__( self, name: str, *, phase: str = "normal", manifest: PluginManifest, ) -> None +``` + +### `setup` + +```python +def setup(self) -> Awaitable[None] +``` + +One-time initialization hook called once during +:meth:`AppKit.initialize` after the plugin's runtime is injected. + +### `exports` + +```python +def exports(self) -> dict[str, str] +``` + +Return string values exported to other plugins and the server. + +### `client_config` + +```python +def client_config(self) -> dict[str, str] +``` + +Return per-plugin config surfaced to clients via ``/api/config``. + +### `inject_routes` + +```python +def inject_routes(self, router: Router) -> None +``` + +Register HTTP routes with the server router. Called once per +plugin, mounted under ``/api/<plugin>/``. + +### `execute` + +```python +def execute( self, func: Callable[[], Awaitable[str]], *, user_key: str = "", timeout_ms: Optional[int] = None, retry_attempts: Optional[int] = None, cache_key: Optional[list[str]] = None, cache_ttl: Optional[int] = None, ) -> Awaitable[ExecutionResult] +``` + +Execute a coroutine through the plugin's interceptor chain +(telemetry, timeout, retry, cache). + +``user_key`` scopes caches to a user for OBO flows. ``cache_key`` +parts are hashed together with ``user_key`` to form a stable key. +Pass ``timeout_ms=None`` / ``retry_attempts=None`` to fall back to +the plugin defaults. 
+ +### `execute_stream` + +```python +def execute_stream( self, func: Callable[[], Any], *, user_key: str = "", timeout_ms: Optional[int] = None, ) -> Awaitable[StreamIterator] +``` + +Execute a streaming function through the interceptor chain. + +The callable should return a Python async generator that yields +JSON strings. Returns a StreamIterator for async iteration. + +Retry and cache are not supported for streams. diff --git a/docs/docs/api/appkit-python/Class.PluginManifest.md b/docs/docs/api/appkit-python/Class.PluginManifest.md new file mode 100644 index 00000000..d58d7c24 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.PluginManifest.md @@ -0,0 +1,28 @@ +--- +sidebar_label: PluginManifest +sidebar_position: 27 +--- + + + +# Class: PluginManifest + +Plugin manifest — metadata. + +_Defined in `appkit.pyi`._ + +## Attributes + +| Name | Type | +| --- | --- | +| `name` | `str` | +| `display_name` | `Optional[str]` | +| `description` | `Optional[str]` | + +## Methods + +### `__init__` + +```python +def __init__( self, name: str, *, display_name: Optional[str] = None, description: Optional[str] = None, ) -> None +``` diff --git a/docs/docs/api/appkit-python/Class.PluginPhase.md b/docs/docs/api/appkit-python/Class.PluginPhase.md new file mode 100644 index 00000000..3dd5a965 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.PluginPhase.md @@ -0,0 +1,20 @@ +--- +sidebar_label: PluginPhase +sidebar_position: 28 +--- + + + +# Class: PluginPhase + +Phase ordering constants for Python plugins. 
+ +_Defined in `appkit.pyi`._ + +## Attributes + +| Name | Type | +| --- | --- | +| `CORE` | `str` | +| `NORMAL` | `str` | +| `DEFERRED` | `str` | diff --git a/docs/docs/api/appkit-python/Class.Request.md b/docs/docs/api/appkit-python/Class.Request.md new file mode 100644 index 00000000..97684a10 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.Request.md @@ -0,0 +1,34 @@ +--- +sidebar_label: Request +sidebar_position: 29 +--- + + + +# Class: Request + +HTTP request data forwarded to Python route handlers. + +_Defined in `appkit.pyi`._ + +## Attributes + +| Name | Type | +| --- | --- | +| `method` | `str` | +| `path` | `str` | +| `headers` | `dict[str, str]` | +| `query` | `str` | +| `body` | `str` | + +## Methods + +### `json` + +```python +def json(self) -> Any +``` + +Parse the request body as JSON and return Python-native data. + +Raises ValueError if the body is not valid JSON. diff --git a/docs/docs/api/appkit-python/Class.Router.md b/docs/docs/api/appkit-python/Class.Router.md new file mode 100644 index 00000000..7df6c424 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.Router.md @@ -0,0 +1,50 @@ +--- +sidebar_label: Router +sidebar_position: 30 +--- + + + +# Class: Router + +Router passed to Plugin.inject_routes() for route registration. 
+ +_Defined in `appkit.pyi`._ + +## Attributes + +| Name | Type | +| --- | --- | +| `plugin_name` | `str` | + +## Methods + +### `get` + +```python +def get( self, path: str, handler: Callable[[Request], Awaitable[str]], *, stream: bool = False, ) -> None +``` + +### `post` + +```python +def post( self, path: str, handler: Callable[[Request], Awaitable[str]], *, stream: bool = False, ) -> None +``` + +### `put` + +```python +def put( self, path: str, handler: Callable[[Request], Awaitable[str]], *, stream: bool = False, ) -> None +``` + +### `delete` + +```python +def delete( self, path: str, handler: Callable[[Request], Awaitable[str]], *, stream: bool = False, ) -> None +``` + +### `patch` + +```python +def patch( self, path: str, handler: Callable[[Request], Awaitable[str]], *, stream: bool = False, ) -> None +``` diff --git a/docs/docs/api/appkit-python/Class.ServerConfig.md b/docs/docs/api/appkit-python/Class.ServerConfig.md new file mode 100644 index 00000000..1dfc5e0b --- /dev/null +++ b/docs/docs/api/appkit-python/Class.ServerConfig.md @@ -0,0 +1,29 @@ +--- +sidebar_label: ServerConfig +sidebar_position: 31 +--- + + + +# Class: ServerConfig + +Server configuration. + +_Defined in `appkit.pyi`._ + +## Attributes + +| Name | Type | +| --- | --- | +| `host` | `str` | +| `port` | `int` | +| `auto_start` | `bool` | +| `static_path` | `Optional[str]` | + +## Methods + +### `__init__` + +```python +def __init__( self, *, host: str = "0.0.0.0", port: int = 8000, auto_start: bool = True, static_path: Optional[str] = None, ) -> None +``` diff --git a/docs/docs/api/appkit-python/Class.ServerPlugin.md b/docs/docs/api/appkit-python/Class.ServerPlugin.md new file mode 100644 index 00000000..731276be --- /dev/null +++ b/docs/docs/api/appkit-python/Class.ServerPlugin.md @@ -0,0 +1,52 @@ +--- +sidebar_label: ServerPlugin +sidebar_position: 32 +--- + + + +# Class: ServerPlugin + +Core HTTP server plugin. 
+ +The server plugin runs in the Core phase so that the server is ready +before any Normal-phase plugin calls ``inject_routes``. It exposes no +routes of its own — route hosting is handled by ``AppKit.start_server``. + +_Defined in `appkit/plugins/server.py`._ + +**Extends:** `Plugin` + +## Methods + +### `__init__` + +```python +def __init__(self, config: ServerPluginConfig | None = None) -> None +``` + +### `config` (property) + +```python +def config(self) -> ServerPluginConfig +``` + +### `to_server_config` + +```python +def to_server_config(self) -> ServerConfig +``` + +Convert plugin config into an :class:`appkit.ServerConfig`. + +### `client_config` + +```python +def client_config(self) -> dict[str, str] +``` + +### `inject_routes` + +```python +def inject_routes(self, _router: Any) -> None +``` diff --git a/docs/docs/api/appkit-python/Class.ServerPluginConfig.md b/docs/docs/api/appkit-python/Class.ServerPluginConfig.md new file mode 100644 index 00000000..f8f82b84 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.ServerPluginConfig.md @@ -0,0 +1,23 @@ +--- +sidebar_label: ServerPluginConfig +sidebar_position: 33 +--- + + + +# Class: ServerPluginConfig + +Configuration for :class:`ServerPlugin`. + +Mirrors :class:`appkit.ServerConfig`. Defaults align with the Rust +``ServerPluginConfig`` (``0.0.0.0:8000``, auto-start enabled). + +_Defined in `appkit/plugins/server.py`._ + +## Methods + +### `__init__` + +```python +def __init__( self, *, host: str = "0.0.0.0", port: int = 8000, auto_start: bool = True, static_path: str | None = None, ) -> None +``` diff --git a/docs/docs/api/appkit-python/Class.ServiceContext.md b/docs/docs/api/appkit-python/Class.ServiceContext.md new file mode 100644 index 00000000..2048f49c --- /dev/null +++ b/docs/docs/api/appkit-python/Class.ServiceContext.md @@ -0,0 +1,32 @@ +--- +sidebar_label: ServiceContext +sidebar_position: 34 +--- + + + +# Class: ServiceContext + +Service-level authentication context (service principal). 
+
+_Defined in `appkit.pyi`._
+
+## Attributes
+
+| Name | Type |
+| --- | --- |
+| `config` | `AppConfig` |
+
+## Methods
+
+### `__init__`
+
+```python
+def __init__(self, config: AppConfig) -> None
+```
+
+### `get_token`
+
+```python
+def get_token(self) -> Awaitable[str]
+```
diff --git a/docs/docs/api/appkit-python/Class.ServingConnector.md b/docs/docs/api/appkit-python/Class.ServingConnector.md
new file mode 100644
index 00000000..dcb0d7bb
--- /dev/null
+++ b/docs/docs/api/appkit-python/Class.ServingConnector.md
@@ -0,0 +1,43 @@
+---
+sidebar_label: ServingConnector
+sidebar_position: 35
+---
+
+
+# Class: ServingConnector
+
+Databricks Serving Endpoints connector.
+
+Wraps ``/serving-endpoints/<endpoint_name>/invocations`` for synchronous calls
+and the SSE streaming variant for LLM-style endpoints.
+
+_Defined in `appkit.pyi`._
+
+## Methods
+
+### `__init__`
+
+```python
+def __init__(self, host: str) -> None
+```
+
+### `invoke`
+
+```python
+def invoke( self, token: str, endpoint_name: str, body: str, ) -> Awaitable[ServingResponse]
+```
+
+Invoke a serving endpoint with a JSON request body and return
+the raw response.
+
+### `stream`
+
+```python
+def stream( self, token: str, endpoint_name: str, body: str, ) -> Awaitable[StreamIterator]
+```
+
+Stream from a serving endpoint (SSE).
+
+Returns a StreamIterator that yields parsed SSE data payloads
+as they arrive. The stream ends on ``data: [DONE]`` or connection close.
diff --git a/docs/docs/api/appkit-python/Class.ServingEndpointConfig.md b/docs/docs/api/appkit-python/Class.ServingEndpointConfig.md
new file mode 100644
index 00000000..9e4b31d8
--- /dev/null
+++ b/docs/docs/api/appkit-python/Class.ServingEndpointConfig.md
@@ -0,0 +1,24 @@
+---
+sidebar_label: ServingEndpointConfig
+sidebar_position: 36
+---
+
+
+# Class: ServingEndpointConfig
+
+Per-alias serving endpoint configuration.
+
+``env`` is the environment variable that holds the actual endpoint
+name (for example ``CHAT_ENDPOINT``). ``served_model`` optionally
+pins the request to a specific served model inside the endpoint.
+
+_Defined in `appkit/plugins/serving.py`._
+
+## Methods
+
+### `__init__`
+
+```python
+def __init__(self, *, env: str, served_model: str | None = None) -> None
+```
diff --git a/docs/docs/api/appkit-python/Class.ServingPlugin.md b/docs/docs/api/appkit-python/Class.ServingPlugin.md
new file mode 100644
index 00000000..35bbcd0b
--- /dev/null
+++ b/docs/docs/api/appkit-python/Class.ServingPlugin.md
@@ -0,0 +1,45 @@
+---
+sidebar_label: ServingPlugin
+sidebar_position: 37
+---
+
+
+# Class: ServingPlugin
+
+Model Serving plugin — invoke and stream endpoints via alias.
+
+_Defined in `appkit/plugins/serving.py`._
+
+**Extends:** `Plugin`
+
+## Methods
+
+### `__init__`
+
+```python
+def __init__(self, config: ServingPluginConfig) -> None
+```
+
+### `client_config`
+
+```python
+def client_config(self) -> dict[str, str]
+```
+
+### `inject_routes`
+
+```python
+def inject_routes(self, router: Any) -> None
+```
+
+### `resolve_endpoint`
+
+```python
+def resolve_endpoint(self, alias: str) -> str
+```
+
+Return the endpoint name for ``alias`` from the configured env var.
+
+Raises :class:`ValidationError` if the alias is unknown or the
+environment variable is unset or empty.
diff --git a/docs/docs/api/appkit-python/Class.ServingPluginConfig.md b/docs/docs/api/appkit-python/Class.ServingPluginConfig.md
new file mode 100644
index 00000000..1041d37a
--- /dev/null
+++ b/docs/docs/api/appkit-python/Class.ServingPluginConfig.md
@@ -0,0 +1,23 @@
+---
+sidebar_label: ServingPluginConfig
+sidebar_position: 38
+---
+
+
+# Class: ServingPluginConfig
+
+Configuration for :class:`ServingPlugin`.
+
+``endpoints`` maps alias → :class:`ServingEndpointConfig`. At least one
+endpoint is required.
+ +_Defined in `appkit/plugins/serving.py`._ + +## Methods + +### `__init__` + +```python +def __init__( self, *, endpoints: Mapping[str, ServingEndpointConfig], host: str | None = None, timeout_ms: int | None = None, ) -> None +``` diff --git a/docs/docs/api/appkit-python/Class.ServingResponse.md b/docs/docs/api/appkit-python/Class.ServingResponse.md new file mode 100644 index 00000000..769473ae --- /dev/null +++ b/docs/docs/api/appkit-python/Class.ServingResponse.md @@ -0,0 +1,19 @@ +--- +sidebar_label: ServingResponse +sidebar_position: 39 +--- + + + +# Class: ServingResponse + +Response from a serving endpoint invocation. + +_Defined in `appkit.pyi`._ + +## Attributes + +| Name | Type | +| --- | --- | +| `data` | `str` | +| `status_code` | `int` | diff --git a/docs/docs/api/appkit-python/Class.SqlColumn.md b/docs/docs/api/appkit-python/Class.SqlColumn.md new file mode 100644 index 00000000..54601783 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.SqlColumn.md @@ -0,0 +1,19 @@ +--- +sidebar_label: SqlColumn +sidebar_position: 40 +--- + + + +# Class: SqlColumn + +Column schema information. + +_Defined in `appkit.pyi`._ + +## Attributes + +| Name | Type | +| --- | --- | +| `name` | `str` | +| `type_name` | `str` | diff --git a/docs/docs/api/appkit-python/Class.SqlStatementResult.md b/docs/docs/api/appkit-python/Class.SqlStatementResult.md new file mode 100644 index 00000000..c3f879c7 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.SqlStatementResult.md @@ -0,0 +1,22 @@ +--- +sidebar_label: SqlStatementResult +sidebar_position: 41 +--- + + + +# Class: SqlStatementResult + +Result of a SQL statement execution. 
+ +_Defined in `appkit.pyi`._ + +## Attributes + +| Name | Type | +| --- | --- | +| `statement_id` | `str` | +| `status` | `str` | +| `columns` | `list[SqlColumn]` | +| `data` | `str` | +| `row_count` | `int` | diff --git a/docs/docs/api/appkit-python/Class.SqlWarehouseConnector.md b/docs/docs/api/appkit-python/Class.SqlWarehouseConnector.md new file mode 100644 index 00000000..6f21bef5 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.SqlWarehouseConnector.md @@ -0,0 +1,36 @@ +--- +sidebar_label: SqlWarehouseConnector +sidebar_position: 42 +--- + + + +# Class: SqlWarehouseConnector + +Databricks SQL Warehouse connector. + +Runs parameterised SQL statements against a Serverless or Pro warehouse +via the ``/api/2.0/sql/statements`` endpoint and polls until the +statement reaches a terminal status. + +_Defined in `appkit.pyi`._ + +## Methods + +### `__init__` + +```python +def __init__(self, host: str, *, timeout_ms: Optional[int] = None) -> None +``` + +### `execute_statement` + +```python +def execute_statement( self, token: str, statement: str, warehouse_id: str, *, parameters: Optional[list[tuple[str, str]]] = None, catalog: Optional[str] = None, schema: Optional[str] = None, wait_timeout: Optional[str] = None, disposition: Optional[str] = None, format: Optional[str] = None, on_wait_timeout: Optional[str] = None, byte_limit: Optional[int] = None, row_limit: Optional[int] = None, timeout_ms: Optional[int] = None, ) -> Awaitable[SqlStatementResult] +``` + +Execute ``statement`` on ``warehouse_id`` and return the result. + +``parameters`` is a list of ``(name, value)`` pairs corresponding +to ``:name`` placeholders in the SQL; values are always passed as +strings and typed by the server via column metadata. 
diff --git a/docs/docs/api/appkit-python/Class.StreamIterator.md b/docs/docs/api/appkit-python/Class.StreamIterator.md new file mode 100644 index 00000000..ebf9c881 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.StreamIterator.md @@ -0,0 +1,20 @@ +--- +sidebar_label: StreamIterator +sidebar_position: 43 +--- + + + +# Class: StreamIterator + +Async iterator yielding JSON string items from a streaming execution. + +Used by ``Plugin.execute_stream()`` and ``ServingConnector.stream()``. + +Example:: + + stream = await plugin.execute_stream(my_gen_fn) + async for item in stream: + data = json.loads(item) + +_Defined in `appkit.pyi`._ diff --git a/docs/docs/api/appkit-python/Class.UserContext.md b/docs/docs/api/appkit-python/Class.UserContext.md new file mode 100644 index 00000000..b3bb74c8 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.UserContext.md @@ -0,0 +1,36 @@ +--- +sidebar_label: UserContext +sidebar_position: 44 +--- + + + +# Class: UserContext + +Per-request user context for OBO (On-Behalf-Of) flows. + +_Defined in `appkit.pyi`._ + +## Attributes + +| Name | Type | +| --- | --- | +| `token` | `str` | +| `user_id` | `str` | +| `user_name` | `Optional[str]` | +| `workspace_id` | `str` | +| `warehouse_id` | `Optional[str]` | + +## Methods + +### `is_user_context` (property) + +```python +def is_user_context(self) -> bool +``` + +### `__init__` + +```python +def __init__( self, token: str, user_id: str, *, user_name: Optional[str] = None, workspace_id: str, warehouse_id: Optional[str] = None, ) -> None +``` diff --git a/docs/docs/api/appkit-python/Class.VectorSearchConnector.md b/docs/docs/api/appkit-python/Class.VectorSearchConnector.md new file mode 100644 index 00000000..4c220da2 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.VectorSearchConnector.md @@ -0,0 +1,45 @@ +--- +sidebar_label: VectorSearchConnector +sidebar_position: 45 +--- + + + +# Class: VectorSearchConnector + +Databricks Vector Search REST connector. 
+ +Wraps the ``/api/2.0/vector-search`` endpoints; returns raw JSON response +bodies as strings so Python callers can reuse their existing shaping +logic without a second serde pass across the PyO3 boundary. + +_Defined in `appkit.pyi`._ + +## Methods + +### `__init__` + +```python +def __init__(self, host: str, *, timeout_ms: Optional[int] = None) -> None +``` + +### `query` + +```python +def query( self, token: str, index_name: str, *, columns: list[str], num_results: int = 20, query_type: str = "hybrid", query_text: Optional[str] = None, query_vector: Optional[list[float]] = None, filters_json: Optional[str] = None, reranker_columns: Optional[list[str]] = None, ) -> Awaitable[str] +``` + +Run a query against ``index_name`` and return the raw JSON +response body. + +Pass ``query_text`` for text/hybrid queries, ``query_vector`` for +ANN queries. ``filters_json`` is the JSON-serialised filter object. + +### `query_next_page` + +```python +def query_next_page( self, token: str, index_name: str, endpoint_name: str, page_token: str, ) -> Awaitable[str] +``` + +Fetch the next page of a paginated query using ``page_token`` +from a prior :meth:`query` response. diff --git a/docs/docs/api/appkit-python/Class.VectorSearchIndexConfig.md b/docs/docs/api/appkit-python/Class.VectorSearchIndexConfig.md new file mode 100644 index 00000000..21b77308 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.VectorSearchIndexConfig.md @@ -0,0 +1,26 @@ +--- +sidebar_label: VectorSearchIndexConfig +sidebar_position: 46 +--- + + + +# Class: VectorSearchIndexConfig + +Per-index alias configuration. + +``index_name`` is the fully-qualified ``catalog.schema.index`` name. +``endpoint_name`` is required when paginating. ``columns`` lists +the columns returned from the index; ``query_type`` picks the +default search mode. ``reranker_columns`` enables the Databricks +reranker when non-empty. 
+ +_Defined in `appkit/plugins/vector_search.py`._ + +## Methods + +### `__init__` + +```python +def __init__( self, *, index_name: str, endpoint_name: str | None = None, columns: list[str] | None = None, query_type: str = "hybrid", num_results: int = 20, reranker_columns: list[str] | None = None, ) -> None +``` diff --git a/docs/docs/api/appkit-python/Class.VectorSearchPlugin.md b/docs/docs/api/appkit-python/Class.VectorSearchPlugin.md new file mode 100644 index 00000000..f38b3352 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.VectorSearchPlugin.md @@ -0,0 +1,34 @@ +--- +sidebar_label: VectorSearchPlugin +sidebar_position: 47 +--- + + + +# Class: VectorSearchPlugin + +Vector Search plugin — hybrid, ANN, and full-text queries. + +_Defined in `appkit/plugins/vector_search.py`._ + +**Extends:** `Plugin` + +## Methods + +### `__init__` + +```python +def __init__(self, config: VectorSearchPluginConfig) -> None +``` + +### `client_config` + +```python +def client_config(self) -> dict[str, str] +``` + +### `inject_routes` + +```python +def inject_routes(self, router: Any) -> None +``` diff --git a/docs/docs/api/appkit-python/Class.VectorSearchPluginConfig.md b/docs/docs/api/appkit-python/Class.VectorSearchPluginConfig.md new file mode 100644 index 00000000..5b14c76c --- /dev/null +++ b/docs/docs/api/appkit-python/Class.VectorSearchPluginConfig.md @@ -0,0 +1,22 @@ +--- +sidebar_label: VectorSearchPluginConfig +sidebar_position: 48 +--- + + + +# Class: VectorSearchPluginConfig + +Configuration for :class:`VectorSearchPlugin`. + +``indexes`` maps alias → :class:`VectorSearchIndexConfig`. 
+ +_Defined in `appkit/plugins/vector_search.py`._ + +## Methods + +### `__init__` + +```python +def __init__( self, *, indexes: Mapping[str, VectorSearchIndexConfig], host: str | None = None, timeout_ms: int | None = None, ) -> None +``` diff --git a/docs/docs/api/appkit-python/Class.VolumeConfig.md b/docs/docs/api/appkit-python/Class.VolumeConfig.md new file mode 100644 index 00000000..97007851 --- /dev/null +++ b/docs/docs/api/appkit-python/Class.VolumeConfig.md @@ -0,0 +1,23 @@ +--- +sidebar_label: VolumeConfig +sidebar_position: 49 +--- + + + +# Class: VolumeConfig + +Per-volume alias configuration. + +``path`` is the fully-qualified Unity Catalog volume path, for example +``/Volumes/catalog/schema/volume``. + +_Defined in `appkit/plugins/files.py`._ + +## Methods + +### `__init__` + +```python +def __init__(self, *, path: str, max_upload_size: int | None = None) -> None +``` diff --git a/docs/docs/api/appkit-python/Class.VsSearchRequest.md b/docs/docs/api/appkit-python/Class.VsSearchRequest.md new file mode 100644 index 00000000..8f2ab56d --- /dev/null +++ b/docs/docs/api/appkit-python/Class.VsSearchRequest.md @@ -0,0 +1,35 @@ +--- +sidebar_label: VsSearchRequest +sidebar_position: 50 +--- + + + +# Class: VsSearchRequest + +Parsed Vector Search request matching the TS ``SearchRequest`` shape. + +Parameters are passed by keyword; ``filters_json`` is a JSON object +string of scalar-or-array filter values. 
+ +_Defined in `appkit.pyi`._ + +## Attributes + +| Name | Type | +| --- | --- | +| `query_text` | `Optional[str]` | +| `query_vector` | `Optional[list[float]]` | +| `columns` | `Optional[list[str]]` | +| `num_results` | `Optional[int]` | +| `query_type` | `Optional[str]` | +| `filters_json` | `Optional[str]` | +| `reranker_columns` | `Optional[list[str]]` | + +## Methods + +### `__init__` + +```python +def __init__( self, *, query_text: Optional[str] = None, query_vector: Optional[list[float]] = None, columns: Optional[list[str]] = None, num_results: Optional[int] = None, query_type: Optional[str] = None, filters_json: Optional[str] = None, reranker_columns: Optional[list[str]] = None, ) -> None +``` diff --git a/docs/docs/api/appkit-python/Function.analytics.md b/docs/docs/api/appkit-python/Function.analytics.md new file mode 100644 index 00000000..7a9fa551 --- /dev/null +++ b/docs/docs/api/appkit-python/Function.analytics.md @@ -0,0 +1,16 @@ +--- +sidebar_label: analytics +sidebar_position: 1000 +--- + + + +# Function: analytics + +Construct an :class:`AnalyticsPlugin` — the plugin entry point. + +```python +def analytics(config: AnalyticsPluginConfig) -> AnalyticsPlugin +``` + +_Defined in `appkit/plugins/__init__.py`._ diff --git a/docs/docs/api/appkit-python/Function.as_user.md b/docs/docs/api/appkit-python/Function.as_user.md new file mode 100644 index 00000000..5426b024 --- /dev/null +++ b/docs/docs/api/appkit-python/Function.as_user.md @@ -0,0 +1,18 @@ +--- +sidebar_label: as_user +sidebar_position: 1001 +--- + + + +# Function: as_user + +Run an async callable with the given UserContext set for the duration. + +Returns an awaitable coroutine. 
+ +```python +def as_user( user_context: UserContext, func: Callable[[], Awaitable[Any]], ) -> Awaitable[Any] +``` + +_Defined in `appkit.pyi`._ diff --git a/docs/docs/api/appkit-python/Function.create_app.md b/docs/docs/api/appkit-python/Function.create_app.md new file mode 100644 index 00000000..42df3a7a --- /dev/null +++ b/docs/docs/api/appkit-python/Function.create_app.md @@ -0,0 +1,26 @@ +--- +sidebar_label: create_app +sidebar_position: 1002 +--- + + + +# Function: create_app + +Create and initialize an AppKit instance in one call. + +This is the primary public API — mirrors TypeScript's ``createApp(...)``. + +Steps: + 1. Creates an AppKit instance + 2. Registers all provided plugins + 3. Initializes (telemetry, cache, phase-ordered plugin setup) + 4. Optionally starts the HTTP server (when ``auto_start=True``) + +Returns the initialized AppKit instance. + +```python +def create_app( *, config: AppConfig, plugins: list[Plugin] = ..., cache_config: Optional[CacheConfig] = None, server_config: Optional[ServerConfig] = None, auto_start: bool = True, ) -> Awaitable[AppKit] +``` + +_Defined in `appkit.pyi`._ diff --git a/docs/docs/api/appkit-python/Function.files.md b/docs/docs/api/appkit-python/Function.files.md new file mode 100644 index 00000000..dbc48ff7 --- /dev/null +++ b/docs/docs/api/appkit-python/Function.files.md @@ -0,0 +1,16 @@ +--- +sidebar_label: files +sidebar_position: 1003 +--- + + + +# Function: files + +Construct a :class:`FilesPlugin` — the plugin entry point. 
+ +```python +def files(config: FilesPluginConfig) -> FilesPlugin +``` + +_Defined in `appkit/plugins/__init__.py`._ diff --git a/docs/docs/api/appkit-python/Function.genie.md b/docs/docs/api/appkit-python/Function.genie.md new file mode 100644 index 00000000..f5198ff6 --- /dev/null +++ b/docs/docs/api/appkit-python/Function.genie.md @@ -0,0 +1,16 @@ +--- +sidebar_label: genie +sidebar_position: 1004 +--- + + + +# Function: genie + +Construct a :class:`GeniePlugin` — the plugin entry point. + +```python +def genie(config: GeniePluginConfig) -> GeniePlugin +``` + +_Defined in `appkit/plugins/__init__.py`._ diff --git a/docs/docs/api/appkit-python/Function.get_current_user.md b/docs/docs/api/appkit-python/Function.get_current_user.md new file mode 100644 index 00000000..fce0fcd8 --- /dev/null +++ b/docs/docs/api/appkit-python/Function.get_current_user.md @@ -0,0 +1,17 @@ +--- +sidebar_label: get_current_user +sidebar_position: 1005 +--- + + + +# Function: get_current_user + +Get the current UserContext from the execution context, or None +if running as service principal. + +```python +def get_current_user() -> Optional[UserContext] +``` + +_Defined in `appkit.pyi`._ diff --git a/docs/docs/api/appkit-python/Function.is_in_user_context.md b/docs/docs/api/appkit-python/Function.is_in_user_context.md new file mode 100644 index 00000000..af9f3c12 --- /dev/null +++ b/docs/docs/api/appkit-python/Function.is_in_user_context.md @@ -0,0 +1,16 @@ +--- +sidebar_label: is_in_user_context +sidebar_position: 1006 +--- + + + +# Function: is_in_user_context + +Check whether the current execution is running in a user context. 
+ +```python +def is_in_user_context() -> bool +``` + +_Defined in `appkit.pyi`._ diff --git a/docs/docs/api/appkit-python/Function.lakebase.md b/docs/docs/api/appkit-python/Function.lakebase.md new file mode 100644 index 00000000..da7ebae0 --- /dev/null +++ b/docs/docs/api/appkit-python/Function.lakebase.md @@ -0,0 +1,16 @@ +--- +sidebar_label: lakebase +sidebar_position: 1007 +--- + + + +# Function: lakebase + +Construct a :class:`LakebasePlugin` — the plugin entry point. + +```python +def lakebase(config: LakebasePluginConfig | None = None) -> LakebasePlugin +``` + +_Defined in `appkit/plugins/__init__.py`._ diff --git a/docs/docs/api/appkit-python/Function.run_in_user_context.md b/docs/docs/api/appkit-python/Function.run_in_user_context.md new file mode 100644 index 00000000..9999e1be --- /dev/null +++ b/docs/docs/api/appkit-python/Function.run_in_user_context.md @@ -0,0 +1,19 @@ +--- +sidebar_label: run_in_user_context +sidebar_position: 1008 +--- + + + +# Function: run_in_user_context + +Run a synchronous callable with the given UserContext set as the +current execution context for the duration of the call. + +Mirrors TypeScript's ``runInUserContext(userContext, fn)``. + +```python +def run_in_user_context(user_context: UserContext, func: Callable[[], Any]) -> Any +``` + +_Defined in `appkit.pyi`._ diff --git a/docs/docs/api/appkit-python/Function.server.md b/docs/docs/api/appkit-python/Function.server.md new file mode 100644 index 00000000..1f3e13d5 --- /dev/null +++ b/docs/docs/api/appkit-python/Function.server.md @@ -0,0 +1,16 @@ +--- +sidebar_label: server +sidebar_position: 1009 +--- + + + +# Function: server + +Construct a :class:`ServerPlugin` — the plugin entry point. 
+ +```python +def server(config: ServerPluginConfig | None = None) -> ServerPlugin +``` + +_Defined in `appkit/plugins/__init__.py`._ diff --git a/docs/docs/api/appkit-python/Function.serving.md b/docs/docs/api/appkit-python/Function.serving.md new file mode 100644 index 00000000..c60a4114 --- /dev/null +++ b/docs/docs/api/appkit-python/Function.serving.md @@ -0,0 +1,16 @@ +--- +sidebar_label: serving +sidebar_position: 1010 +--- + + + +# Function: serving + +Construct a :class:`ServingPlugin` — the plugin entry point. + +```python +def serving(config: ServingPluginConfig) -> ServingPlugin +``` + +_Defined in `appkit/plugins/__init__.py`._ diff --git a/docs/docs/api/appkit-python/Function.vector_search.md b/docs/docs/api/appkit-python/Function.vector_search.md new file mode 100644 index 00000000..2e7aa49d --- /dev/null +++ b/docs/docs/api/appkit-python/Function.vector_search.md @@ -0,0 +1,16 @@ +--- +sidebar_label: vector_search +sidebar_position: 1011 +--- + + + +# Function: vector_search + +Construct a :class:`VectorSearchPlugin` — the plugin entry point. + +```python +def vector_search(config: VectorSearchPluginConfig) -> VectorSearchPlugin +``` + +_Defined in `appkit/plugins/__init__.py`._ diff --git a/docs/docs/api/appkit-python/_category_.json b/docs/docs/api/appkit-python/_category_.json new file mode 100644 index 00000000..f3b78526 --- /dev/null +++ b/docs/docs/api/appkit-python/_category_.json @@ -0,0 +1,8 @@ +{ + "label": "appkit (Python)", + "position": 3, + "link": { + "type": "doc", + "id": "index" + } +} diff --git a/docs/docs/api/appkit-python/index.md b/docs/docs/api/appkit-python/index.md new file mode 100644 index 00000000..3b4310f1 --- /dev/null +++ b/docs/docs/api/appkit-python/index.md @@ -0,0 +1,80 @@ +--- +sidebar_label: "appkit (Python)" +sidebar_position: 0 +--- + + + +# appkit (Python) + +Python API reference for the `appkit` package shipped as part of `appkit-rs`. 
+ +Install with `uv add appkit-rs` or the equivalent in your Python tooling. + +## Classes + +- [`AnalyticsPlugin`](./Class.AnalyticsPlugin.md) — SQL query execution plugin. _(from `appkit/plugins/analytics.py`)_ +- [`AnalyticsPluginConfig`](./Class.AnalyticsPluginConfig.md) — Configuration for :class:`AnalyticsPlugin`. _(from `appkit/plugins/analytics.py`)_ +- [`AppConfig`](./Class.AppConfig.md) — Application configuration parsed from environment variables. _(from `appkit.pyi`)_ +- [`AppKit`](./Class.AppKit.md) — AppKit orchestrator — registers plugins and manages initialization. _(from `appkit.pyi`)_ +- [`CacheConfig`](./Class.CacheConfig.md) — Cache configuration with defaults matching TypeScript cacheDefaults. _(from `appkit.pyi`)_ +- [`CacheManager`](./Class.CacheManager.md) — Cache manager with TTL, LRU eviction, and in-flight deduplication. _(from `appkit.pyi`)_ +- [`DatabaseCredential`](./Class.DatabaseCredential.md) — Generated database credential for Lakebase access. _(from `appkit.pyi`)_ +- [`ExecutionResult`](./Class.ExecutionResult.md) — Python-facing execution result (frozen, immutable). _(from `appkit.pyi`)_ +- [`FileDirectoryEntry`](./Class.FileDirectoryEntry.md) — A single entry in a directory listing. _(from `appkit.pyi`)_ +- [`FileMetadata`](./Class.FileMetadata.md) — File metadata from a HEAD request. _(from `appkit.pyi`)_ +- [`FilePreview`](./Class.FilePreview.md) — File preview with optional text content. _(from `appkit.pyi`)_ +- [`FilesConnector`](./Class.FilesConnector.md) — Databricks Files API connector. _(from `appkit.pyi`)_ +- [`FilesPlugin`](./Class.FilesPlugin.md) — Unity Catalog Volumes file operations plugin. _(from `appkit/plugins/files.py`)_ +- [`FilesPluginConfig`](./Class.FilesPluginConfig.md) — Configuration for :class:`FilesPlugin`. _(from `appkit/plugins/files.py`)_ +- [`GenieAttachment`](./Class.GenieAttachment.md) — Genie query attachment metadata. 
_(from `appkit.pyi`)_ +- [`GenieConnector`](./Class.GenieConnector.md) — Databricks Genie connector. _(from `appkit.pyi`)_ +- [`GenieConversationHistory`](./Class.GenieConversationHistory.md) — Full conversation history. _(from `appkit.pyi`)_ +- [`GenieMessage`](./Class.GenieMessage.md) — Genie message response. _(from `appkit.pyi`)_ +- [`GeniePlugin`](./Class.GeniePlugin.md) — Genie conversational analytics plugin. _(from `appkit/plugins/genie.py`)_ +- [`GeniePluginConfig`](./Class.GeniePluginConfig.md) — Configuration for :class:`GeniePlugin`. _(from `appkit/plugins/genie.py`)_ +- [`GenieQueryResult`](./Class.GenieQueryResult.md) — Query result from a Genie attachment. _(from `appkit.pyi`)_ +- [`LakebaseConnector`](./Class.LakebaseConnector.md) — Databricks Lakebase connector. _(from `appkit.pyi`)_ +- [`LakebasePgConfig`](./Class.LakebasePgConfig.md) — PostgreSQL connection configuration for Lakebase. _(from `appkit.pyi`)_ +- [`LakebasePlugin`](./Class.LakebasePlugin.md) — Lakebase PostgreSQL integration plugin. _(from `appkit/plugins/lakebase.py`)_ +- [`LakebasePluginConfig`](./Class.LakebasePluginConfig.md) — Configuration for :class:`LakebasePlugin`. _(from `appkit/plugins/lakebase.py`)_ +- [`Plugin`](./Class.Plugin.md) — Base class for Python plugins. Subclass to create custom plugins. _(from `appkit.pyi`)_ +- [`PluginManifest`](./Class.PluginManifest.md) — Plugin manifest — metadata. _(from `appkit.pyi`)_ +- [`PluginPhase`](./Class.PluginPhase.md) — Phase ordering constants for Python plugins. _(from `appkit.pyi`)_ +- [`Request`](./Class.Request.md) — HTTP request data forwarded to Python route handlers. _(from `appkit.pyi`)_ +- [`Router`](./Class.Router.md) — Router passed to Plugin.inject_routes() for route registration. _(from `appkit.pyi`)_ +- [`ServerConfig`](./Class.ServerConfig.md) — Server configuration. _(from `appkit.pyi`)_ +- [`ServerPlugin`](./Class.ServerPlugin.md) — Core HTTP server plugin. 
_(from `appkit/plugins/server.py`)_ +- [`ServerPluginConfig`](./Class.ServerPluginConfig.md) — Configuration for :class:`ServerPlugin`. _(from `appkit/plugins/server.py`)_ +- [`ServiceContext`](./Class.ServiceContext.md) — Service-level authentication context (service principal). _(from `appkit.pyi`)_ +- [`ServingConnector`](./Class.ServingConnector.md) — Databricks Serving Endpoints connector. _(from `appkit.pyi`)_ +- [`ServingEndpointConfig`](./Class.ServingEndpointConfig.md) — Per-alias serving endpoint configuration. _(from `appkit/plugins/serving.py`)_ +- [`ServingPlugin`](./Class.ServingPlugin.md) — Model Serving plugin — invoke and stream endpoints via alias. _(from `appkit/plugins/serving.py`)_ +- [`ServingPluginConfig`](./Class.ServingPluginConfig.md) — Configuration for :class:`ServingPlugin`. _(from `appkit/plugins/serving.py`)_ +- [`ServingResponse`](./Class.ServingResponse.md) — Response from a serving endpoint invocation. _(from `appkit.pyi`)_ +- [`SqlColumn`](./Class.SqlColumn.md) — Column schema information. _(from `appkit.pyi`)_ +- [`SqlStatementResult`](./Class.SqlStatementResult.md) — Result of a SQL statement execution. _(from `appkit.pyi`)_ +- [`SqlWarehouseConnector`](./Class.SqlWarehouseConnector.md) — Databricks SQL Warehouse connector. _(from `appkit.pyi`)_ +- [`StreamIterator`](./Class.StreamIterator.md) — Async iterator yielding JSON string items from a streaming execution. _(from `appkit.pyi`)_ +- [`UserContext`](./Class.UserContext.md) — Per-request user context for OBO (On-Behalf-Of) flows. _(from `appkit.pyi`)_ +- [`VectorSearchConnector`](./Class.VectorSearchConnector.md) — Databricks Vector Search REST connector. _(from `appkit.pyi`)_ +- [`VectorSearchIndexConfig`](./Class.VectorSearchIndexConfig.md) — Per-index alias configuration. _(from `appkit/plugins/vector_search.py`)_ +- [`VectorSearchPlugin`](./Class.VectorSearchPlugin.md) — Vector Search plugin — hybrid, ANN, and full-text queries. 
_(from `appkit/plugins/vector_search.py`)_ +- [`VectorSearchPluginConfig`](./Class.VectorSearchPluginConfig.md) — Configuration for :class:`VectorSearchPlugin`. _(from `appkit/plugins/vector_search.py`)_ +- [`VolumeConfig`](./Class.VolumeConfig.md) — Per-volume alias configuration. _(from `appkit/plugins/files.py`)_ +- [`VsSearchRequest`](./Class.VsSearchRequest.md) — Parsed Vector Search request matching the TS ``SearchRequest`` shape. _(from `appkit.pyi`)_ + +## Functions + +- [`analytics`](./Function.analytics.md) — Construct an :class:`AnalyticsPlugin` — the plugin entry point. _(from `appkit/plugins/__init__.py`)_ +- [`as_user`](./Function.as_user.md) — Run an async callable with the given UserContext set for the duration. _(from `appkit.pyi`)_ +- [`create_app`](./Function.create_app.md) — Create and initialize an AppKit instance in one call. _(from `appkit.pyi`)_ +- [`files`](./Function.files.md) — Construct a :class:`FilesPlugin` — the plugin entry point. _(from `appkit/plugins/__init__.py`)_ +- [`genie`](./Function.genie.md) — Construct a :class:`GeniePlugin` — the plugin entry point. _(from `appkit/plugins/__init__.py`)_ +- [`get_current_user`](./Function.get_current_user.md) — Get the current UserContext from the execution context, or None _(from `appkit.pyi`)_ +- [`is_in_user_context`](./Function.is_in_user_context.md) — Check whether the current execution is running in a user context. _(from `appkit.pyi`)_ +- [`lakebase`](./Function.lakebase.md) — Construct a :class:`LakebasePlugin` — the plugin entry point. _(from `appkit/plugins/__init__.py`)_ +- [`run_in_user_context`](./Function.run_in_user_context.md) — Run a synchronous callable with the given UserContext set as the _(from `appkit.pyi`)_ +- [`server`](./Function.server.md) — Construct a :class:`ServerPlugin` — the plugin entry point. _(from `appkit/plugins/__init__.py`)_ +- [`serving`](./Function.serving.md) — Construct a :class:`ServingPlugin` — the plugin entry point. 
_(from `appkit/plugins/__init__.py`)_ +- [`vector_search`](./Function.vector_search.md) — Construct a :class:`VectorSearchPlugin` — the plugin entry point. _(from `appkit/plugins/__init__.py`)_ diff --git a/docs/docs/api/index.md b/docs/docs/api/index.md index d16354c7..98e8d55c 100644 --- a/docs/docs/api/index.md +++ b/docs/docs/api/index.md @@ -6,6 +6,7 @@ This section contains the API reference for the AppKit packages. - [`appkit`](appkit/index.md) - Core library. Provides the core functionality for building Databricks applications. - [`appkit-ui`](appkit-ui/index.md) - UI components library. Provides a set of UI primitives for building Databricks apps in [React](https://react.dev/). +- [`appkit` (Python)](appkit-python/index.md) - Python SDK for building Databricks apps that use the same plugin model as the Node.js library. Learn more about the architecture of AppKit in the [architecture](../architecture.md) document. @@ -21,3 +22,9 @@ To install the AppKit packages into your existing JavaScript/TypeScript project, npm install @databricks/appkit npm install @databricks/appkit-ui ``` + +For the Python SDK, install `appkit-rs` with [uv](https://docs.astral.sh/uv/): + +```bash +uv add appkit-rs +``` diff --git a/packages/appkit-rs/.release-it.json b/packages/appkit-rs/.release-it.json new file mode 100644 index 00000000..d91b656a --- /dev/null +++ b/packages/appkit-rs/.release-it.json @@ -0,0 +1,32 @@ +{ + "$schema": "https://unpkg.com/release-it@19/schema/release-it.json", + "git": { + "commit": false, + "tag": false, + "push": false, + "requireBranch": false, + "requireCleanWorkingDir": false, + "requireCommits": true, + "requireCommitsFail": false, + "tagMatch": "appkit-py-v*", + "tagName": "appkit-py-v${version}", + "getLatestTagFromAllRefs": true, + "commitsPath": "." 
+ }, + "github": { + "release": false + }, + "npm": false, + "hooks": {}, + "plugins": { + "@release-it/conventional-changelog": { + "preset": { + "name": "conventionalcommits", + "bumpStrict": true + }, + "infile": "changelog-diff.md", + "gitRawCommitsOpts": { "path": "." }, + "commitsOpts": { "path": "." } + } + } +} diff --git a/packages/appkit-rs/Cargo.toml b/packages/appkit-rs/Cargo.toml new file mode 100644 index 00000000..76179c89 --- /dev/null +++ b/packages/appkit-rs/Cargo.toml @@ -0,0 +1,34 @@ +[package] +name = "appkit-rs" +version = "0.1.0" +edition = "2021" + +[lib] +name = "appkit" +crate-type = ["cdylib", "rlib"] + +[features] +extension-module = ["pyo3/extension-module"] + +[dependencies] +pyo3 = "0.23" +pyo3-async-runtimes = { version = "0.23", features = ["tokio-runtime"] } +tokio = { version = "1", features = ["full"] } +reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +sha2 = "0.10" +opentelemetry = "0.27" +opentelemetry_sdk = { version = "0.27", features = ["rt-tokio"] } +opentelemetry-otlp = { version = "0.27" } +tracing = "0.1" +rand = "0.8" +urlencoding = "2" +axum = "0.7" +tower-http = { version = "0.6", features = ["cors", "fs"] } +uuid = { version = "1", features = ["v4"] } +tokio-stream = "0.1" +futures = "0.3" + +[dev-dependencies] +serial_test = "3" diff --git a/packages/appkit-rs/README.md b/packages/appkit-rs/README.md new file mode 100644 index 00000000..4fc5714d --- /dev/null +++ b/packages/appkit-rs/README.md @@ -0,0 +1,123 @@ +# appkit (Python SDK) + +Python SDK for Databricks AppKit. Provides the same plugin-based application framework as the TypeScript `@databricks/appkit` SDK, targeting Python backend applications. + +## Status + +**Prototype** — the API is functional but may change. Not yet published to PyPI. 
+ +## Prerequisites + +- **Rust toolchain** — install via [rustup](https://rustup.rs/) +- **Python 3.11+** +- **[maturin](https://www.maturin.rs/)** — `pip install maturin` + +## Local Build + +```bash +# Development (debug build, installed into current venv) +cd packages/appkit-rs +maturin develop + +# Production wheel +maturin build --release +``` + +The built wheel is written to `target/wheels/`. + +## Running Tests + +```bash +# Rust unit tests +cargo test + +# Python integration tests (requires maturin develop first) +pytest +``` + +## Wheel Bundling + +For prototyping or sharing pre-built wheels without PyPI: + +```bash +maturin build --release --manylinux 2_28 +``` + +Then reference the wheel from a `requirements.txt` (the requirement name must match the wheel's distribution name): + +``` +appkit @ file:///path/to/appkit-0.1.0-cp311-cp311-manylinux_2_28_x86_64.whl +``` + +## CI + +The [`build-wheels.yml`](../../.github/workflows/build-wheels.yml) workflow builds wheels for: + +- **Linux**: x86_64 and aarch64 (manylinux 2_28) +- **macOS**: x86_64 and aarch64 + +It runs on pushes to `main` and on pull requests that touch `packages/appkit-rs/**`. A publish step is gated on `appkit-py-v*` tags. + +## Deployment Gotchas + +### Cross-compilation for Databricks Apps + +Databricks Apps runs on x86_64 Linux with Python 3.11. If you're building on macOS (or a different arch), you need to cross-compile the native extension. Two approaches: + +1. **Build on a Linux x86_64 machine** (simplest): + ```bash + PYO3_PYTHON=python3.11 cargo build --release --lib + strip target/release/libappkit.so + cp target/release/libappkit.so appkit/appkit.cpython-311-x86_64-linux-gnu.so + ``` + +2. 
**Use maturin with zig** (from macOS): + ```bash + pip install 'maturin[zig]' + maturin build --release --target x86_64-unknown-linux-gnu --zig + ``` + +### Deploying without PyPI + +Since the SDK isn't published yet, deploy the native extension directly alongside your app code: + +``` +your-app/ +├── app.yaml +├── server/ +│ ├── __init__.py +│ └── app.py +└── appkit/ # SDK files copied into your app + ├── __init__.py + ├── _context.py + ├── appkit.cpython-311-x86_64-linux-gnu.so + └── plugins/ + ├── __init__.py + └── ... +``` + +Upload to workspace and deploy: +```bash +databricks workspace import-dir ./your-app /Workspace/Users/<username>/my-app --overwrite +databricks apps deploy my-app --source-code-path /Workspace/Users/<username>/my-app +``` + +### SIGTERM handling + +Databricks Apps sends SIGTERM with a 15-second grace period on redeployment. The default `asyncio.Event().wait()` pattern in app entry points doesn't handle SIGTERM within that window. Consider adding a signal handler: + +```python +import asyncio +import signal + +stop = asyncio.Event() +loop = asyncio.get_running_loop() +for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, stop.set) + +await stop.wait() +app.shutdown() +``` + +## Getting Started + +See the [Python template](../../template-python/) for a runnable scaffold that uses `appkit` to build a backend application with plugins, routes, caching, and streaming. 
diff --git a/packages/appkit-rs/appkit.pyi b/packages/appkit-rs/appkit.pyi new file mode 100644 index 00000000..8288ece2 --- /dev/null +++ b/packages/appkit-rs/appkit.pyi @@ -0,0 +1,899 @@ +"""Type stubs for appkit — Databricks AppKit Python SDK.""" + +from __future__ import annotations + +from typing import Any, AsyncIterator, Awaitable, Callable, Optional, Sequence + +# --------------------------------------------------------------------------- +# Module-level context variable (created at import time) +# --------------------------------------------------------------------------- + +_USER_CONTEXT_VAR: Any # contextvars.ContextVar[UserContext | None] + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +class AppConfig: + """Application configuration parsed from environment variables.""" + + databricks_host: str + client_id: Optional[str] + client_secret: Optional[str] + warehouse_id: Optional[str] + app_port: int + host: str + otel_endpoint: Optional[str] + + def __init__( + self, + databricks_host: str, + *, + client_id: Optional[str] = None, + client_secret: Optional[str] = None, + warehouse_id: Optional[str] = None, + app_port: int = 8000, + host: str = "0.0.0.0", + otel_endpoint: Optional[str] = None, + ) -> None: ... + @staticmethod + def from_env() -> AppConfig: ... + def __repr__(self) -> str: ... + def __eq__(self, other: object) -> bool: ... + def __hash__(self) -> int: ... + +# --------------------------------------------------------------------------- +# Auth +# --------------------------------------------------------------------------- + +class ServiceContext: + """Service-level authentication context (service principal).""" + + config: AppConfig + + def __init__(self, config: AppConfig) -> None: ... + def get_token(self) -> Awaitable[str]: ... + def __repr__(self) -> str: ... 
+ +class UserContext: + """Per-request user context for OBO (On-Behalf-Of) flows.""" + + token: str + user_id: str + user_name: Optional[str] + workspace_id: str + warehouse_id: Optional[str] + + @property + def is_user_context(self) -> bool: ... + def __init__( + self, + token: str, + user_id: str, + *, + user_name: Optional[str] = None, + workspace_id: str, + warehouse_id: Optional[str] = None, + ) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, other: object) -> bool: ... + def __hash__(self) -> int: ... + +# --------------------------------------------------------------------------- +# Cache +# --------------------------------------------------------------------------- + +class CacheConfig: + """Cache configuration with defaults matching TypeScript cacheDefaults.""" + + enabled: bool + ttl: int + max_size: int + cleanup_probability: float + + def __init__( + self, + *, + enabled: bool = True, + ttl: int = 3600, + max_size: int = 1000, + cleanup_probability: float = 0.01, + ) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, other: object) -> bool: ... + def __hash__(self) -> int: ... + +class CacheManager: + """Cache manager with TTL, LRU eviction, and in-flight deduplication.""" + + def __init__(self, config: Optional[CacheConfig] = None) -> None: ... + @staticmethod + def generate_key(parts: list[str], user_key: str) -> str: ... + def get(self, key: str) -> Awaitable[Optional[str]]: ... + def set(self, key: str, value: str, *, ttl: Optional[int] = None) -> Awaitable[None]: ... + def delete(self, key: str) -> Awaitable[None]: ... + def has(self, key: str) -> Awaitable[bool]: ... + def clear(self) -> Awaitable[None]: ... + def size(self) -> Awaitable[int]: ... + def get_or_execute( + self, + key: str, + func: Callable[[], Awaitable[str]], + *, + ttl: Optional[int] = None, + ) -> Awaitable[str]: ... + def __repr__(self) -> str: ... + def __bool__(self) -> bool: ... 
+ +# --------------------------------------------------------------------------- +# Plugin system +# --------------------------------------------------------------------------- + +class PluginPhase: + """Phase ordering constants for Python plugins.""" + + CORE: str + NORMAL: str + DEFERRED: str + +class PluginManifest: + """Plugin manifest — metadata.""" + + name: str + display_name: Optional[str] + description: Optional[str] + + def __init__( + self, + name: str, + *, + display_name: Optional[str] = None, + description: Optional[str] = None, + ) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, other: object) -> bool: ... + def __hash__(self) -> int: ... + +class StreamIterator: + """Async iterator yielding JSON string items from a streaming execution. + + Used by ``Plugin.execute_stream()`` and ``ServingConnector.stream()``. + + Example:: + + stream = await plugin.execute_stream(my_gen_fn) + async for item in stream: + data = json.loads(item) + """ + + def __aiter__(self) -> AsyncIterator[str]: ... + def __anext__(self) -> Awaitable[str]: ... + def __repr__(self) -> str: ... + +class ExecutionResult: + """Python-facing execution result (frozen, immutable).""" + + ok: bool + data: Optional[str] + status: Optional[int] + message: Optional[str] + + def __repr__(self) -> str: ... + def __bool__(self) -> bool: ... + def __eq__(self, other: object) -> bool: ... + def __hash__(self) -> int: ... + +class Plugin: + """Base class for Python plugins. Subclass to create custom plugins. + + Example:: + + class MyPlugin(Plugin): + def __init__(self): + super().__init__("my-plugin", manifest=PluginManifest("my-plugin")) + + async def setup(self): + pass + """ + + name: str + phase: str + manifest: PluginManifest + is_ready: bool + + def __init__( + self, + name: str, + *, + phase: str = "normal", + manifest: PluginManifest, + ) -> None: ... 
+ def setup(self) -> Awaitable[None]: + """One-time initialization hook called once during + :meth:`AppKit.initialize` after the plugin's runtime is injected. + """ + ... + def exports(self) -> dict[str, str]: + """Return string values exported to other plugins and the server.""" + ... + def client_config(self) -> dict[str, str]: + """Return per-plugin config surfaced to clients via ``/api/config``.""" + ... + def inject_routes(self, router: Router) -> None: + """Register HTTP routes with the server router. Called once per + plugin, mounted under ``/api//``. + """ + ... + def execute( + self, + func: Callable[[], Awaitable[str]], + *, + user_key: str = "", + timeout_ms: Optional[int] = None, + retry_attempts: Optional[int] = None, + cache_key: Optional[list[str]] = None, + cache_ttl: Optional[int] = None, + ) -> Awaitable[ExecutionResult]: + """Execute a coroutine through the plugin's interceptor chain + (telemetry, timeout, retry, cache). + + ``user_key`` scopes caches to a user for OBO flows. ``cache_key`` + parts are hashed together with ``user_key`` to form a stable key. + Pass ``timeout_ms=None`` / ``retry_attempts=None`` to fall back to + the plugin defaults. + """ + ... + def execute_stream( + self, + func: Callable[[], Any], + *, + user_key: str = "", + timeout_ms: Optional[int] = None, + ) -> Awaitable[StreamIterator]: + """Execute a streaming function through the interceptor chain. + + The callable should return a Python async generator that yields + JSON strings. Returns a StreamIterator for async iteration. + + Retry and cache are not supported for streams. + """ + ... + def __repr__(self) -> str: ... + +class AppKit: + """AppKit orchestrator — registers plugins and manages initialization. + + Most applications should use :func:`create_app` instead of driving + this class directly; ``create_app`` wires registration, initialization + and optional server startup in one call. + """ + + def __init__(self) -> None: ... 
+ def register(self, plugin: Plugin) -> None: + """Register a plugin. Must be called before :meth:`initialize`.""" + ... + def initialize( + self, + config: AppConfig, + *, + cache_config: Optional[CacheConfig] = None, + ) -> Awaitable[None]: + """Initialize telemetry, cache, and run phase-ordered ``Plugin.setup``. + + After this returns, plugins are ready to serve requests. Calling + ``initialize`` twice is an error. + """ + ... + def get_plugin(self, name: str) -> Optional[Plugin]: + """Look up a registered plugin by its manifest name.""" + ... + def plugin_names(self) -> list[str]: + """Return the manifest names of all registered plugins.""" + ... + def start_server(self, server_config: ServerConfig) -> Awaitable[None]: + """Start the HTTP server and block until it exits. + + Routes previously injected via ``Plugin.inject_routes`` are mounted + under ``/api//...``. + """ + ... + def shutdown(self) -> None: + """Stop the HTTP server and release resources.""" + ... + def __repr__(self) -> str: ... + def __len__(self) -> int: ... + def __bool__(self) -> bool: ... + def __contains__(self, name: str) -> bool: ... + +# --------------------------------------------------------------------------- +# Server / routing +# --------------------------------------------------------------------------- + +class Router: + """Router passed to Plugin.inject_routes() for route registration.""" + + plugin_name: str + + def get( + self, + path: str, + handler: Callable[[Request], Awaitable[str]], + *, + stream: bool = False, + ) -> None: ... + def post( + self, + path: str, + handler: Callable[[Request], Awaitable[str]], + *, + stream: bool = False, + ) -> None: ... + def put( + self, + path: str, + handler: Callable[[Request], Awaitable[str]], + *, + stream: bool = False, + ) -> None: ... + def delete( + self, + path: str, + handler: Callable[[Request], Awaitable[str]], + *, + stream: bool = False, + ) -> None: ... 
+ def patch( + self, + path: str, + handler: Callable[[Request], Awaitable[str]], + *, + stream: bool = False, + ) -> None: ... + def __repr__(self) -> str: ... + +class Request: + """HTTP request data forwarded to Python route handlers.""" + + method: str + path: str + headers: dict[str, str] + query: str + body: str + + def json(self) -> Any: + """Parse the request body as JSON and return Python-native data. + + Raises ValueError if the body is not valid JSON. + """ + ... + def __repr__(self) -> str: ... + +class ServerConfig: + """Server configuration.""" + + host: str + port: int + auto_start: bool + static_path: Optional[str] + + def __init__( + self, + *, + host: str = "0.0.0.0", + port: int = 8000, + auto_start: bool = True, + static_path: Optional[str] = None, + ) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, other: object) -> bool: ... + def __hash__(self) -> int: ... + +# --------------------------------------------------------------------------- +# Connectors — Files +# --------------------------------------------------------------------------- + +class FileDirectoryEntry: + """A single entry in a directory listing.""" + + path: str + name: str + is_directory: bool + file_size: Optional[int] + last_modified: Optional[int] + + def __repr__(self) -> str: ... + def __eq__(self, other: object) -> bool: ... + def __hash__(self) -> int: ... + +class FileMetadata: + """File metadata from a HEAD request.""" + + content_length: Optional[int] + content_type: Optional[str] + last_modified: Optional[str] + + def __repr__(self) -> str: ... + +class FilePreview: + """File preview with optional text content.""" + + content_length: Optional[int] + content_type: Optional[str] + last_modified: Optional[str] + text_preview: Optional[str] + is_text: bool + is_image: bool + + def __repr__(self) -> str: ... + +class FilesConnector: + """Databricks Files API connector. + + Operates against Unity Catalog Volume paths. 
The ``default_volume`` + constructor argument is used as a prefix when a ``file_path`` is not + already a ``/Volumes/...`` absolute path. + """ + + def __init__(self, host: str, *, default_volume: Optional[str] = None) -> None: ... + def resolve_path(self, file_path: str) -> str: + """Join ``file_path`` with the connector's default volume if the + path is not already a fully-qualified ``/Volumes/...`` path. + """ + ... + def list( + self, + token: str, + *, + directory_path: Optional[str] = None, + ) -> Awaitable[list[FileDirectoryEntry]]: + """List entries under ``directory_path`` (or the default volume).""" + ... + def read( + self, + token: str, + file_path: str, + *, + max_size: Optional[int] = None, + ) -> Awaitable[str]: + """Read a text file, optionally truncated to ``max_size`` bytes.""" + ... + def download(self, token: str, file_path: str) -> Awaitable[bytes]: + """Download a file as raw bytes.""" + ... + def exists(self, token: str, file_path: str) -> Awaitable[bool]: + """Return whether the given file or directory exists.""" + ... + def metadata(self, token: str, file_path: str) -> Awaitable[FileMetadata]: + """Fetch metadata for a file via a HEAD request.""" + ... + def upload( + self, + token: str, + file_path: str, + contents: bytes, + *, + overwrite: bool = True, + ) -> Awaitable[None]: + """Upload ``contents`` to ``file_path``. Overwrites by default.""" + ... + def create_directory(self, token: str, directory_path: str) -> Awaitable[None]: + """Create a directory, creating parents as needed.""" + ... + def delete(self, token: str, file_path: str) -> Awaitable[None]: + """Delete a file or (empty) directory.""" + ... + def preview( + self, + token: str, + file_path: str, + *, + max_chars: int = 1024, + ) -> Awaitable[FilePreview]: + """Fetch metadata and a capped-length text preview if the file is + recognized as text.""" + ... + def __repr__(self) -> str: ... 
+ +# --------------------------------------------------------------------------- +# Connectors — SQL Warehouse +# --------------------------------------------------------------------------- + +class SqlColumn: + """Column schema information.""" + + name: str + type_name: str + + def __repr__(self) -> str: ... + def __eq__(self, other: object) -> bool: ... + def __hash__(self) -> int: ... + +class SqlStatementResult: + """Result of a SQL statement execution.""" + + statement_id: str + status: str + columns: list[SqlColumn] + data: str + row_count: int + + def __repr__(self) -> str: ... + def __len__(self) -> int: ... + def __bool__(self) -> bool: ... + +class SqlWarehouseConnector: + """Databricks SQL Warehouse connector. + + Runs parameterised SQL statements against a Serverless or Pro warehouse + via the ``/api/2.0/sql/statements`` endpoint and polls until the + statement reaches a terminal status. + """ + + def __init__(self, host: str, *, timeout_ms: Optional[int] = None) -> None: ... + def execute_statement( + self, + token: str, + statement: str, + warehouse_id: str, + *, + parameters: Optional[list[tuple[str, str]]] = None, + catalog: Optional[str] = None, + schema: Optional[str] = None, + wait_timeout: Optional[str] = None, + disposition: Optional[str] = None, + format: Optional[str] = None, + on_wait_timeout: Optional[str] = None, + byte_limit: Optional[int] = None, + row_limit: Optional[int] = None, + timeout_ms: Optional[int] = None, + ) -> Awaitable[SqlStatementResult]: + """Execute ``statement`` on ``warehouse_id`` and return the result. + + ``parameters`` is a list of ``(name, value)`` pairs corresponding + to ``:name`` placeholders in the SQL; values are always passed as + strings and typed by the server via column metadata. + """ + ... + def __repr__(self) -> str: ... 
+ +# --------------------------------------------------------------------------- +# Connectors — Genie +# --------------------------------------------------------------------------- + +class GenieAttachment: + """Genie query attachment metadata.""" + + attachment_id: Optional[str] + query_title: Optional[str] + query_description: Optional[str] + query_sql: Optional[str] + query_statement_id: Optional[str] + text_content: Optional[str] + suggested_questions: Optional[list[str]] + + def __repr__(self) -> str: ... + +class GenieMessage: + """Genie message response.""" + + message_id: str + conversation_id: str + space_id: str + status: str + content: str + attachments: list[GenieAttachment] + error: Optional[str] + + def __repr__(self) -> str: ... + +class GenieConversationHistory: + """Full conversation history.""" + + conversation_id: str + space_id: str + messages: list[GenieMessage] + + def __repr__(self) -> str: ... + def __len__(self) -> int: ... + +class GenieQueryResult: + """Query result from a Genie attachment.""" + + data: str + + def __repr__(self) -> str: ... + +class GenieConnector: + """Databricks Genie connector.""" + + def __init__( + self, + host: str, + *, + timeout_ms: Optional[int] = None, + max_messages: Optional[int] = None, + ) -> None: ... + def start_message( + self, + token: str, + space_id: str, + content: str, + *, + conversation_id: Optional[str] = None, + ) -> Awaitable[tuple[str, str]]: ... + def send_message( + self, + token: str, + space_id: str, + content: str, + *, + conversation_id: Optional[str] = None, + timeout_ms: Optional[int] = None, + ) -> Awaitable[GenieMessage]: ... + def get_message( + self, + token: str, + space_id: str, + conversation_id: str, + message_id: str, + *, + timeout_ms: Optional[int] = None, + ) -> Awaitable[GenieMessage]: ... 
+ def list_messages( + self, + token: str, + space_id: str, + conversation_id: str, + *, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + ) -> Awaitable[tuple[list[GenieMessage], Optional[str]]]: ... + def get_query_result( + self, + token: str, + space_id: str, + conversation_id: str, + message_id: str, + attachment_id: str, + ) -> Awaitable[GenieQueryResult]: ... + def get_conversation( + self, + token: str, + space_id: str, + conversation_id: str, + ) -> Awaitable[GenieConversationHistory]: ... + def __repr__(self) -> str: ... + +# --------------------------------------------------------------------------- +# Connectors — Serving +# --------------------------------------------------------------------------- + +class ServingResponse: + """Response from a serving endpoint invocation.""" + + data: str + status_code: int + + def __repr__(self) -> str: ... + def __bool__(self) -> bool: ... + def __eq__(self, other: object) -> bool: ... + def __hash__(self) -> int: ... + +class ServingConnector: + """Databricks Serving Endpoints connector. + + Wraps ``/serving-endpoints//invocations`` for synchronous calls + and the SSE streaming variant for LLM-style endpoints. + """ + + def __init__(self, host: str) -> None: ... + def invoke( + self, + token: str, + endpoint_name: str, + body: str, + ) -> Awaitable[ServingResponse]: + """Invoke a serving endpoint with a JSON request body and return + the raw response.""" + ... + def stream( + self, + token: str, + endpoint_name: str, + body: str, + ) -> Awaitable[StreamIterator]: + """Stream from a serving endpoint (SSE). + + Returns a StreamIterator that yields parsed SSE data payloads + as they arrive. The stream ends on ``data: [DONE]`` or connection close. + """ + ... + def __repr__(self) -> str: ... 
+ +# --------------------------------------------------------------------------- +# Connectors — Lakebase +# --------------------------------------------------------------------------- + +class DatabaseCredential: + """Generated database credential for Lakebase access.""" + + token: str + expiration_time: str + + def __repr__(self) -> str: ... + def __eq__(self, other: object) -> bool: ... + def __hash__(self) -> int: ... + +class LakebasePgConfig: + """PostgreSQL connection configuration for Lakebase.""" + + host: str + database: str + port: int + ssl_mode: str + app_name: Optional[str] + + def __init__( + self, + *, + host: Optional[str] = None, + database: Optional[str] = None, + port: Optional[int] = None, + ssl_mode: Optional[str] = None, + app_name: Optional[str] = None, + ) -> None: ... + @staticmethod + def from_env() -> LakebasePgConfig: ... + def __repr__(self) -> str: ... + def __eq__(self, other: object) -> bool: ... + def __hash__(self) -> int: ... + +class LakebaseConnector: + """Databricks Lakebase connector. + + Generates short-lived PostgreSQL credentials for a set of Lakebase + instances via the database credentials API. + """ + + def __init__(self, host: str) -> None: ... + def generate_credential( + self, + token: str, + instance_names: list[str], + *, + request_id: Optional[str] = None, + ) -> Awaitable[DatabaseCredential]: + """Generate a credential good for one or more Lakebase instances. + + Use the returned token as the PostgreSQL password for the life of + ``expiration_time`` — typically tens of minutes. + """ + ... + def __repr__(self) -> str: ... + +# --------------------------------------------------------------------------- +# Connectors — Vector Search +# --------------------------------------------------------------------------- + +class VsSearchRequest: + """Parsed Vector Search request matching the TS ``SearchRequest`` shape. 
+ + Parameters are passed by keyword; ``filters_json`` is a JSON object + string of scalar-or-array filter values. + """ + + query_text: Optional[str] + query_vector: Optional[list[float]] + columns: Optional[list[str]] + num_results: Optional[int] + query_type: Optional[str] + filters_json: Optional[str] + reranker_columns: Optional[list[str]] + + def __init__( + self, + *, + query_text: Optional[str] = None, + query_vector: Optional[list[float]] = None, + columns: Optional[list[str]] = None, + num_results: Optional[int] = None, + query_type: Optional[str] = None, + filters_json: Optional[str] = None, + reranker_columns: Optional[list[str]] = None, + ) -> None: ... + def __repr__(self) -> str: ... + +class VectorSearchConnector: + """Databricks Vector Search REST connector. + + Wraps the ``/api/2.0/vector-search`` endpoints; returns raw JSON response + bodies as strings so Python callers can reuse their existing shaping + logic without a second serde pass across the PyO3 boundary. + """ + + def __init__(self, host: str, *, timeout_ms: Optional[int] = None) -> None: ... + def query( + self, + token: str, + index_name: str, + *, + columns: list[str], + num_results: int = 20, + query_type: str = "hybrid", + query_text: Optional[str] = None, + query_vector: Optional[list[float]] = None, + filters_json: Optional[str] = None, + reranker_columns: Optional[list[str]] = None, + ) -> Awaitable[str]: + """Run a query against ``index_name`` and return the raw JSON + response body. + + Pass ``query_text`` for text/hybrid queries, ``query_vector`` for + ANN queries. ``filters_json`` is the JSON-serialised filter object. + """ + ... + def query_next_page( + self, + token: str, + index_name: str, + endpoint_name: str, + page_token: str, + ) -> Awaitable[str]: + """Fetch the next page of a paginated query using ``page_token`` + from a prior :meth:`query` response.""" + ... + def __repr__(self) -> str: ... 
+ +# --------------------------------------------------------------------------- +# Top-level functions +# --------------------------------------------------------------------------- + +def create_app( + *, + config: AppConfig, + plugins: list[Plugin] = ..., + cache_config: Optional[CacheConfig] = None, + server_config: Optional[ServerConfig] = None, + auto_start: bool = True, +) -> Awaitable[AppKit]: + """Create and initialize an AppKit instance in one call. + + This is the primary public API — mirrors TypeScript's ``createApp(...)``. + + Steps: + 1. Creates an AppKit instance + 2. Registers all provided plugins + 3. Initializes (telemetry, cache, phase-ordered plugin setup) + 4. Optionally starts the HTTP server (when ``auto_start=True``) + + Returns the initialized AppKit instance. + """ + ... + +def run_in_user_context(user_context: UserContext, func: Callable[[], Any]) -> Any: + """Run a synchronous callable with the given UserContext set as the + current execution context for the duration of the call. + + Mirrors TypeScript's ``runInUserContext(userContext, fn)``. + """ + ... + +def as_user( + user_context: UserContext, + func: Callable[[], Awaitable[Any]], +) -> Awaitable[Any]: + """Run an async callable with the given UserContext set for the duration. + + Returns an awaitable coroutine. + """ + ... + +def get_current_user() -> Optional[UserContext]: + """Get the current UserContext from the execution context, or None + if running as service principal. + """ + ... + +def is_in_user_context() -> bool: + """Check whether the current execution is running in a user context.""" + ... 
diff --git a/packages/appkit-rs/appkit/__init__.py b/packages/appkit-rs/appkit/__init__.py new file mode 100644 index 00000000..031e2204 --- /dev/null +++ b/packages/appkit-rs/appkit/__init__.py @@ -0,0 +1,5 @@ +from .appkit import * # noqa: F401,F403 +from .appkit import __doc__ + +if hasattr(__import__("appkit.appkit", fromlist=["__all__"]), "__all__"): + from .appkit import __all__ diff --git a/packages/appkit-rs/appkit/_context.py b/packages/appkit-rs/appkit/_context.py new file mode 100644 index 00000000..b0b764a2 --- /dev/null +++ b/packages/appkit-rs/appkit/_context.py @@ -0,0 +1,13 @@ +"""Async context-var wrapper for as_user(). + +Keeps the context-var set/reset inside a native Python coroutine so the +value propagates correctly across the PyO3-tokio bridge. +""" + + +async def _as_user_wrapper(_cv, _ctx, _fn): + _tok = _cv.set(_ctx) + try: + return await _fn() + finally: + _cv.reset(_tok) diff --git a/packages/appkit-rs/appkit/plugins/__init__.py b/packages/appkit-rs/appkit/plugins/__init__.py new file mode 100644 index 00000000..e0939235 --- /dev/null +++ b/packages/appkit-rs/appkit/plugins/__init__.py @@ -0,0 +1,85 @@ +"""Shipped AppKit plugins — Python subclasses of ``appkit.Plugin``. + +These wrap the Rust connectors with route injection, per-plugin config, +and OBO (on-behalf-of) token handling so apps can register real plugins +instead of raw connectors. 
+""" + +from .analytics import AnalyticsPlugin, AnalyticsPluginConfig +from .files import FilesPlugin, FilesPluginConfig, VolumeConfig +from .genie import GeniePlugin, GeniePluginConfig +from .lakebase import LakebasePlugin, LakebasePluginConfig +from .server import ServerPlugin, ServerPluginConfig +from .serving import ( + ServingEndpointConfig, + ServingPlugin, + ServingPluginConfig, +) +from .vector_search import ( + VectorSearchIndexConfig, + VectorSearchPlugin, + VectorSearchPluginConfig, +) + + +def analytics(config: AnalyticsPluginConfig) -> AnalyticsPlugin: + """Construct an :class:`AnalyticsPlugin` — the plugin entry point.""" + return AnalyticsPlugin(config) + + +def vector_search(config: VectorSearchPluginConfig) -> VectorSearchPlugin: + """Construct a :class:`VectorSearchPlugin` — the plugin entry point.""" + return VectorSearchPlugin(config) + + +def server(config: ServerPluginConfig | None = None) -> ServerPlugin: + """Construct a :class:`ServerPlugin` — the plugin entry point.""" + return ServerPlugin(config) + + +def files(config: FilesPluginConfig) -> FilesPlugin: + """Construct a :class:`FilesPlugin` — the plugin entry point.""" + return FilesPlugin(config) + + +def genie(config: GeniePluginConfig) -> GeniePlugin: + """Construct a :class:`GeniePlugin` — the plugin entry point.""" + return GeniePlugin(config) + + +def serving(config: ServingPluginConfig) -> ServingPlugin: + """Construct a :class:`ServingPlugin` — the plugin entry point.""" + return ServingPlugin(config) + + +def lakebase(config: LakebasePluginConfig | None = None) -> LakebasePlugin: + """Construct a :class:`LakebasePlugin` — the plugin entry point.""" + return LakebasePlugin(config) + + +__all__ = [ + "AnalyticsPlugin", + "AnalyticsPluginConfig", + "FilesPlugin", + "FilesPluginConfig", + "VolumeConfig", + "GeniePlugin", + "GeniePluginConfig", + "ServingPlugin", + "ServingPluginConfig", + "ServingEndpointConfig", + "LakebasePlugin", + "LakebasePluginConfig", + "ServerPlugin", + 
"ServerPluginConfig", + "VectorSearchPlugin", + "VectorSearchPluginConfig", + "VectorSearchIndexConfig", + "analytics", + "vector_search", + "server", + "files", + "genie", + "serving", + "lakebase", +] diff --git a/packages/appkit-rs/appkit/plugins/_obo.py b/packages/appkit-rs/appkit/plugins/_obo.py new file mode 100644 index 00000000..7ee33f8c --- /dev/null +++ b/packages/appkit-rs/appkit/plugins/_obo.py @@ -0,0 +1,53 @@ +"""Shared OBO (On-Behalf-Of) token extraction helpers. + +Plugin route handlers call :func:`obo_token` to pull the per-user access +token that Databricks Apps inject as ``X-Forwarded-Access-Token`` on every +proxied request. All comparison is lowercase because the Rust server +lowercases HTTP header names when it forwards them to Python handlers. +""" + +from __future__ import annotations + +from collections.abc import Mapping + +from appkit import AuthenticationError + +OBO_HEADER = "x-forwarded-access-token" +USER_HEADER = "x-forwarded-user" +EMAIL_HEADER = "x-forwarded-email" + + +def _get_header(headers: Mapping[str, str], name: str) -> str | None: + target = name.lower() + for key, value in headers.items(): + if key.lower() == target: + return value + return None + + +def obo_token(headers: Mapping[str, str]) -> str: + """Extract the OBO bearer token, raising if absent. + + Raises :class:`appkit.AuthenticationError` when the header is missing so + the default interceptor chain maps it to a 401 response. + """ + token = _get_header(headers, OBO_HEADER) + if not token: + raise AuthenticationError( + f"Missing {OBO_HEADER} header — plugin route requires OBO access." + ) + return token + + +def obo_user_key(headers: Mapping[str, str]) -> str: + """Resolve a stable per-user cache key from forwarded identity headers. + + Falls back to an empty string if neither user nor email is present; the + cache interceptor treats that as a shared key, which is the intended + behavior for unauthenticated routes. 
+ """ + user = _get_header(headers, USER_HEADER) + if user: + return user + email = _get_header(headers, EMAIL_HEADER) + return email or "" diff --git a/packages/appkit-rs/appkit/plugins/analytics.py b/packages/appkit-rs/appkit/plugins/analytics.py new file mode 100644 index 00000000..515890f9 --- /dev/null +++ b/packages/appkit-rs/appkit/plugins/analytics.py @@ -0,0 +1,369 @@ +"""AnalyticsPlugin — SQL query execution against a Databricks SQL Warehouse. + +Loads parameterised SQL files from ``config/queries/`` and exposes a +``POST /api/analytics/query/:query_key`` route. Query files ending in +``.obo.sql`` execute as the calling user (OBO); plain ``.sql`` executes +as the configured service principal. + +The Rust side owns query-file discovery, ``:param`` extraction (literal- +and comment-aware), and cache-key composition — see +``packages/appkit-rs/src/plugins/analytics.rs``. This module provides +the Python plugin surface: route injection, request parsing, OBO +handling, and parameter validation against the query placeholders. +""" + +from __future__ import annotations + +import json +import os +import re +from pathlib import Path +from typing import Any + +from appkit import ( + Plugin, + PluginManifest, + SqlWarehouseConnector, + ValidationError, +) + +from ._obo import obo_token, obo_user_key + +# Mirror the Rust `is_valid_query_key` rule — no path traversal. +_QUERY_KEY_RE = re.compile(r"^[A-Za-z0-9_-]{1,128}$") + + +class AnalyticsPluginConfig: + """Configuration for :class:`AnalyticsPlugin`. + + ``warehouse_id`` routes queries to a Databricks SQL warehouse. + ``queries_dir`` overrides the default ``config/queries`` path. + ``host`` defaults to the ``DATABRICKS_HOST`` environment variable. 
+ """ + + __slots__ = ("warehouse_id", "queries_dir", "host", "timeout_ms") + + def __init__( + self, + *, + warehouse_id: str | None = None, + queries_dir: str | os.PathLike[str] | None = None, + host: str | None = None, + timeout_ms: int | None = None, + ) -> None: + self.warehouse_id = warehouse_id + self.queries_dir = Path(queries_dir) if queries_dir else None + self.host = host + self.timeout_ms = timeout_ms + + def __repr__(self) -> str: + return ( + f"AnalyticsPluginConfig(warehouse_id={self.warehouse_id!r}, " + f"queries_dir={self.queries_dir!r})" + ) + + +def _extract_param_names(query: str) -> list[str]: + """Extract `:param_name` placeholders, skipping SQL string/comment contexts. + + Mirrors ``QueryProcessor::extract_param_names`` in + ``packages/appkit-rs/src/plugins/analytics.rs`` — see that module for + the canonical specification. This Python port exists so the plugin + can validate extra-key errors locally without a Rust round-trip. + """ + out: list[str] = [] + seen: set[str] = set() + i = 0 + n = len(query) + while i < n: + c = query[i] + + # Line comment: -- ... to end of line. + if c == "-" and i + 1 < n and query[i + 1] == "-": + i += 2 + while i < n and query[i] != "\n": + i += 1 + continue + + # Block comment: /* ... */, nestable. + if c == "/" and i + 1 < n and query[i + 1] == "*": + i += 2 + depth = 1 + while i < n and depth > 0: + if i + 1 < n and query[i] == "/" and query[i + 1] == "*": + depth += 1 + i += 2 + elif i + 1 < n and query[i] == "*" and query[i + 1] == "/": + depth -= 1 + i += 2 + else: + i += 1 + continue + + # Single-quoted string literal: '...'. Doubled '' is an escape. + if c == "'": + i += 1 + while i < n: + if query[i] == "'": + if i + 1 < n and query[i + 1] == "'": + i += 2 + else: + i += 1 + break + else: + i += 1 + continue + + # Double-quoted identifier: "...". Doubled "" is an escape. 
+ if c == '"': + i += 1 + while i < n: + if query[i] == '"': + if i + 1 < n and query[i + 1] == '"': + i += 2 + else: + i += 1 + break + else: + i += 1 + continue + + # Dollar-quoted string: $tag$...$tag$ (tag may be empty). + if c == "$": + tag_end = i + 1 + while tag_end < n and _is_ident_continue(query[tag_end]): + tag_end += 1 + if tag_end < n and query[tag_end] == "$": + delim = query[i : tag_end + 1] + j = tag_end + 1 + hit = query.find(delim, j) + i = hit + len(delim) if hit != -1 else n + continue + + if c == ":": + # `::TYPE` cast — consume both colons plus the type identifier. + if i + 1 < n and query[i + 1] == ":": + i += 2 + while i < n and _is_ident_continue(query[i]): + i += 1 + continue + if i + 1 < n and _is_ident_start(query[i + 1]): + start = i + 1 + end = start + while end < n and _is_ident_continue(query[end]): + end += 1 + name = query[start:end] + if name and name not in seen: + seen.add(name) + out.append(name) + i = end + continue + i += 1 + return out + + +def _is_ident_start(c: str) -> bool: + return c.isalpha() or c == "_" + + +def _is_ident_continue(c: str) -> bool: + return c.isalnum() or c == "_" + + +class _LoadedQuery: + __slots__ = ("query_key", "query", "is_as_user") + + def __init__(self, query_key: str, query: str, *, is_as_user: bool) -> None: + self.query_key = query_key + self.query = query + self.is_as_user = is_as_user + + +def _load_query(queries_dir: Path, query_key: str) -> _LoadedQuery | None: + """Load `.obo.sql` (preferred) or `.sql`.""" + if not _QUERY_KEY_RE.match(query_key): + return None + obo = queries_dir / f"{query_key}.obo.sql" + sp = queries_dir / f"{query_key}.sql" + if obo.is_file(): + return _LoadedQuery(query_key, obo.read_text(), is_as_user=True) + if sp.is_file(): + return _LoadedQuery(query_key, sp.read_text(), is_as_user=False) + return None + + +class AnalyticsPlugin(Plugin): + """SQL query execution plugin. 
+ + Queries live on disk under ``queries_dir`` (default ``config/queries``) + and are referenced by key in the route path. + """ + + NAME = "analytics" + + def __init__(self, config: AnalyticsPluginConfig) -> None: + super().__init__( + self.NAME, + manifest=PluginManifest( + self.NAME, + display_name="Analytics Plugin", + description="SQL query execution against Databricks SQL Warehouses", + ), + ) + host = config.host or os.environ.get("DATABRICKS_HOST") + if not host: + raise ValueError( + "AnalyticsPlugin requires a Databricks host. Set DATABRICKS_HOST " + "or pass host= in AnalyticsPluginConfig." + ) + warehouse_id = config.warehouse_id or os.environ.get( + "DATABRICKS_WAREHOUSE_ID" + ) + if not warehouse_id: + raise ValueError( + "AnalyticsPlugin requires a warehouse_id. Set " + "DATABRICKS_WAREHOUSE_ID or pass warehouse_id= in " + "AnalyticsPluginConfig." + ) + self._config = config + self._host = host + self._warehouse_id = warehouse_id + self._queries_dir = config.queries_dir or Path("config") / "queries" + self._connector = SqlWarehouseConnector(host, timeout_ms=config.timeout_ms) + + @property + def queries_dir(self) -> Path: + return self._queries_dir + + @property + def warehouse_id(self) -> str: + return self._warehouse_id + + def client_config(self) -> dict[str, str]: + return {"warehouse_id": self._warehouse_id} + + def inject_routes(self, router: Any) -> None: + router.post("/query/:query_key", self._handle_query) + router.get("/queries", self._handle_list_queries) + + async def _handle_list_queries(self, _request: Any) -> str: + if not self._queries_dir.is_dir(): + return json.dumps({"queries": []}) + keys: set[str] = set() + for path in self._queries_dir.iterdir(): + if path.is_file() and path.suffix == ".sql": + name = path.name + if name.endswith(".obo.sql"): + keys.add(name[: -len(".obo.sql")]) + else: + keys.add(name[: -len(".sql")]) + return json.dumps({"queries": sorted(keys)}) + + async def _handle_query(self, request: Any) -> str: + 
query_key = self._extract_query_key(request.path) + loaded = _load_query(self._queries_dir, query_key) + if loaded is None: + raise ValidationError(f"Unknown query: {query_key!r}") + + body = request.json() if request.body else {} + if not isinstance(body, dict): + raise ValidationError("Request body must be a JSON object") + raw_params = body.get("parameters", {}) or {} + if not isinstance(raw_params, dict): + raise ValidationError("'parameters' must be a JSON object") + + param_names = _extract_param_names(loaded.query) + param_set = set(param_names) + for key in raw_params: + if key not in param_set: + valid = ", ".join(sorted(param_set)) if param_set else "none" + raise ValidationError( + f"Invalid value for {key!r}: expected a parameter defined " + f"in the query (valid: {valid})" + ) + + sql_parameters: list[tuple[str, str]] = [] + for name, value in raw_params.items(): + if value is None: + continue + sql_parameters.append((name, _coerce_sql_value(value))) + + token, user_key = self._resolve_auth(request, loaded.is_as_user) + + async def run() -> str: + result = await self._connector.execute_statement( + token, + loaded.query, + self._warehouse_id, + parameters=sql_parameters or None, + timeout_ms=self._config.timeout_ms, + ) + return json.dumps( + { + "statement_id": result.statement_id, + "status": result.status, + "columns": [ + {"name": c.name, "type": c.type_name} for c in result.columns + ], + "row_count": result.row_count, + "data": json.loads(result.data) if result.data else [], + } + ) + + execution = await self.execute( + run, + user_key=user_key, + cache_key=["analytics:query", query_key, json.dumps(raw_params, sort_keys=True)], + ) + if not execution.ok: + raise _status_to_error(execution.status or 500, execution.message or "") + return execution.data or "{}" + + def _extract_query_key(self, path: str) -> str: + tail = path.rsplit("/", 1)[-1] + if not _QUERY_KEY_RE.match(tail): + raise ValidationError(f"Invalid query key: {tail!r}") + return tail 
+ + def _resolve_auth(self, request: Any, is_as_user: bool) -> tuple[str, str]: + if is_as_user: + token = obo_token(request.headers) + return token, obo_user_key(request.headers) + env_token = os.environ.get("DATABRICKS_TOKEN", "") + return env_token, "" + + +def _coerce_sql_value(value: Any) -> str: + if isinstance(value, bool): + return "true" if value else "false" + if value is None: + return "" + return str(value) + + +def _status_to_error(status: int, message: str) -> Exception: + from appkit import ( + AppKitError, + AuthenticationError, + InternalError, + NotFoundError, + TimeoutError as AppkitTimeoutError, + UpstreamError, + ) + + if status == 400: + return ValidationError(message) + if status == 401: + return AuthenticationError(message) + if status == 404: + return NotFoundError(message) + if status == 408: + return AppkitTimeoutError(message) + if 500 <= status < 600 and status != 500: + return UpstreamError(message) + if status == 500: + return InternalError(message) + return AppKitError(message) + + +__all__ = ["AnalyticsPlugin", "AnalyticsPluginConfig"] diff --git a/packages/appkit-rs/appkit/plugins/files.py b/packages/appkit-rs/appkit/plugins/files.py new file mode 100644 index 00000000..6f8eb895 --- /dev/null +++ b/packages/appkit-rs/appkit/plugins/files.py @@ -0,0 +1,257 @@ +"""FilesPlugin — Python wrapper around the Rust ``FilesConnector``. + +Registers one :class:`FilesConnector` per configured volume alias and exposes +Unity Catalog Volumes operations over HTTP routes mounted at +``/api/files/...``. Every route requires an OBO token forwarded by Databricks +Apps (``X-Forwarded-Access-Token``). +""" + +from __future__ import annotations + +import base64 +import json +import os +from collections.abc import Mapping +from typing import Any +from urllib.parse import parse_qs + +from appkit import FilesConnector, Plugin, PluginManifest + +from ._obo import obo_token + + +class VolumeConfig: + """Per-volume alias configuration. 
+ + ``path`` is the fully-qualified Unity Catalog volume path, for example + ``/Volumes/catalog/schema/volume``. + """ + + __slots__ = ("path", "max_upload_size") + + def __init__(self, *, path: str, max_upload_size: int | None = None) -> None: + if not path: + raise ValueError("VolumeConfig.path is required") + self.path = path + self.max_upload_size = max_upload_size + + def __repr__(self) -> str: + return ( + f"VolumeConfig(path={self.path!r}, " + f"max_upload_size={self.max_upload_size!r})" + ) + + +class FilesPluginConfig: + """Configuration for :class:`FilesPlugin`. + + ``volumes`` maps alias → :class:`VolumeConfig`. Aliases appear in route + URLs as the ``volume`` query parameter (for example + ``/api/files/list?volume=uploads``). ``host`` defaults to the + ``DATABRICKS_HOST`` environment variable. + """ + + __slots__ = ("volumes", "host", "timeout_ms", "max_upload_size") + + def __init__( + self, + *, + volumes: Mapping[str, VolumeConfig], + host: str | None = None, + timeout_ms: int | None = None, + max_upload_size: int | None = None, + ) -> None: + if not volumes: + raise ValueError("FilesPluginConfig requires at least one volume") + self.volumes: dict[str, VolumeConfig] = dict(volumes) + self.host = host + self.timeout_ms = timeout_ms + self.max_upload_size = max_upload_size + + def __repr__(self) -> str: + return ( + f"FilesPluginConfig(volumes={sorted(self.volumes)!r}, " + f"host={self.host!r})" + ) + + +def _query(request: Any) -> dict[str, str]: + parsed = parse_qs(request.query, keep_blank_values=True) + return {k: v[0] for k, v in parsed.items() if v} + + +class FilesPlugin(Plugin): + """Unity Catalog Volumes file operations plugin.""" + + NAME = "files" + + def __init__(self, config: FilesPluginConfig) -> None: + super().__init__( + self.NAME, + manifest=PluginManifest( + self.NAME, + display_name="Files Plugin", + description="Unity Catalog Volumes file operations", + ), + ) + host = config.host or os.environ.get("DATABRICKS_HOST") + if not host: 
+ raise ValueError( + "FilesPlugin requires a Databricks host. Set DATABRICKS_HOST " + "or pass host= in FilesPluginConfig." + ) + self._config = config + self._host = host + self._connectors: dict[str, FilesConnector] = { + alias: FilesConnector(host, default_volume=vcfg.path) + for alias, vcfg in config.volumes.items() + } + + def client_config(self) -> dict[str, str]: + return {"volumes": ",".join(sorted(self._config.volumes))} + + def inject_routes(self, router: Any) -> None: + router.get("/volumes", self._handle_volumes) + router.get("/list", self._handle_list) + router.get("/read", self._handle_read) + router.get("/metadata", self._handle_metadata) + router.get("/exists", self._handle_exists) + router.get("/preview", self._handle_preview) + router.post("/mkdir", self._handle_mkdir) + router.delete("/delete", self._handle_delete) + router.post("/upload", self._handle_upload) + + def connector(self, volume_key: str) -> FilesConnector: + """Return the :class:`FilesConnector` registered for ``volume_key``. + + Raises :class:`ValueError` when the alias is not configured. + """ + try: + return self._connectors[volume_key] + except KeyError as exc: + raise ValueError( + f"Unknown volume {volume_key!r}. 
Configured: " + f"{sorted(self._connectors)!r}" + ) from exc + + def _resolve( + self, request: Any, *, require_path: bool = False + ) -> tuple[FilesConnector, str, str | None]: + token = obo_token(request.headers) + params = _query(request) + volume_key = params.get("volume") + if not volume_key: + raise ValueError("Missing required query parameter 'volume'") + connector = self.connector(volume_key) + path = params.get("path") + if require_path and not path: + raise ValueError("Missing required query parameter 'path'") + return connector, token, path + + async def _handle_volumes(self, _request: Any) -> str: + return json.dumps({"volumes": sorted(self._config.volumes)}) + + async def _handle_list(self, request: Any) -> str: + connector, token, path = self._resolve(request) + entries = await connector.list(token, directory_path=path) + payload = [ + { + "path": e.path, + "name": e.name, + "is_directory": e.is_directory, + "file_size": e.file_size, + "last_modified": e.last_modified, + } + for e in entries + ] + return json.dumps({"entries": payload}) + + async def _handle_read(self, request: Any) -> str: + connector, token, path = self._resolve(request, require_path=True) + content = await connector.read(token, path) + return json.dumps({"content": content}) + + async def _handle_metadata(self, request: Any) -> str: + connector, token, path = self._resolve(request, require_path=True) + meta = await connector.metadata(token, path) + return json.dumps( + { + "content_length": meta.content_length, + "content_type": meta.content_type, + "last_modified": meta.last_modified, + } + ) + + async def _handle_exists(self, request: Any) -> str: + connector, token, path = self._resolve(request, require_path=True) + exists = await connector.exists(token, path) + return json.dumps({"exists": exists}) + + async def _handle_preview(self, request: Any) -> str: + connector, token, path = self._resolve(request, require_path=True) + params = _query(request) + max_chars_raw = 
params.get("max_chars", "1024") + try: + max_chars = int(max_chars_raw) + except ValueError as exc: + raise ValueError(f"Invalid max_chars: {max_chars_raw!r}") from exc + preview = await connector.preview(token, path, max_chars=max_chars) + return json.dumps( + { + "content_length": preview.content_length, + "content_type": preview.content_type, + "last_modified": preview.last_modified, + "text_preview": preview.text_preview, + "is_text": preview.is_text, + "is_image": preview.is_image, + } + ) + + async def _handle_mkdir(self, request: Any) -> str: + token = obo_token(request.headers) + body = request.json() + if not isinstance(body, dict): + raise ValueError("mkdir body must be a JSON object") + volume_key = body.get("volume") + path = body.get("path") + if not volume_key or not path: + raise ValueError("mkdir requires 'volume' and 'path' fields") + await self.connector(volume_key).create_directory(token, path) + return json.dumps({"created": path}) + + async def _handle_delete(self, request: Any) -> str: + connector, token, path = self._resolve(request, require_path=True) + await connector.delete(token, path) + return json.dumps({"deleted": path}) + + async def _handle_upload(self, request: Any) -> str: + token = obo_token(request.headers) + body = request.json() + if not isinstance(body, dict): + raise ValueError("upload body must be a JSON object") + volume_key = body.get("volume") + path = body.get("path") + contents_b64 = body.get("contents_base64") + overwrite = bool(body.get("overwrite", True)) + if not volume_key or not path or contents_b64 is None: + raise ValueError( + "upload requires 'volume', 'path', and 'contents_base64' fields" + ) + try: + contents = base64.b64decode(contents_b64, validate=True) + except (ValueError, base64.binascii.Error) as exc: + raise ValueError(f"contents_base64 is not valid base64: {exc}") from exc + max_size = self._config.volumes[volume_key].max_upload_size or ( + self._config.max_upload_size + ) + if max_size is not None 
and len(contents) > max_size: + raise ValueError( + f"Upload size {len(contents)} exceeds max {max_size} bytes" + ) + await self.connector(volume_key).upload( + token, path, contents, overwrite=overwrite + ) + return json.dumps({"uploaded": path, "size": len(contents)}) + + +__all__ = ["FilesPlugin", "FilesPluginConfig", "VolumeConfig"] diff --git a/packages/appkit-rs/appkit/plugins/genie.py b/packages/appkit-rs/appkit/plugins/genie.py new file mode 100644 index 00000000..cc502194 --- /dev/null +++ b/packages/appkit-rs/appkit/plugins/genie.py @@ -0,0 +1,199 @@ +"""GeniePlugin — conversational analytics over Databricks Genie spaces. + +Exposes routes under ``/api/genie``: + +- ``POST /message`` — send a message (start a new conversation or reply + to an existing one) and wait for the completed response. +- ``GET /conversation?space=&conversation_id=`` — read the + full conversation history. +- ``GET /query-result?...`` — fetch the tabular result for an attachment. + +Auth is always OBO — Genie spaces are user-scoped. The Rust +``GenieConnector`` handles polling, retries, and response shaping. +""" + +from __future__ import annotations + +import json +import os +from collections.abc import Mapping +from typing import Any +from urllib.parse import parse_qs + +from appkit import GenieConnector, Plugin, PluginManifest, ValidationError + +from ._obo import obo_token, obo_user_key + + +class GeniePluginConfig: + """Configuration for :class:`GeniePlugin`. + + ``spaces`` maps alias → Genie ``space_id``. Route handlers accept + an alias (not the raw space id) so the client never sees the UC + resource identifier. 
+ """ + + __slots__ = ("spaces", "host", "timeout_ms", "max_messages") + + def __init__( + self, + *, + spaces: Mapping[str, str], + host: str | None = None, + timeout_ms: int | None = None, + max_messages: int | None = None, + ) -> None: + if not spaces: + raise ValueError("GeniePluginConfig requires at least one space") + self.spaces: dict[str, str] = dict(spaces) + self.host = host + self.timeout_ms = timeout_ms + self.max_messages = max_messages + + def __repr__(self) -> str: + return f"GeniePluginConfig(spaces={sorted(self.spaces)!r})" + + +def _query(request: Any) -> dict[str, str]: + parsed = parse_qs(request.query, keep_blank_values=True) + return {k: v[0] for k, v in parsed.items() if v} + + +class GeniePlugin(Plugin): + """Genie conversational analytics plugin.""" + + NAME = "genie" + + def __init__(self, config: GeniePluginConfig) -> None: + super().__init__( + self.NAME, + manifest=PluginManifest( + self.NAME, + display_name="Genie Plugin", + description="Databricks Genie conversational analytics", + ), + ) + host = config.host or os.environ.get("DATABRICKS_HOST") + if not host: + raise ValueError( + "GeniePlugin requires a Databricks host. Set DATABRICKS_HOST " + "or pass host= in GeniePluginConfig." + ) + self._config = config + self._host = host + self._connector = GenieConnector( + host, + timeout_ms=config.timeout_ms, + max_messages=config.max_messages, + ) + + def client_config(self) -> dict[str, str]: + return {"spaces": ",".join(sorted(self._config.spaces))} + + def inject_routes(self, router: Any) -> None: + router.post("/message", self._handle_message) + router.get("/conversation", self._handle_conversation) + router.get("/query-result", self._handle_query_result) + + def _resolve_space(self, alias: str) -> str: + try: + return self._config.spaces[alias] + except KeyError as exc: + raise ValidationError( + f"Unknown space alias {alias!r}. 
Configured: " + f"{sorted(self._config.spaces)!r}" + ) from exc + + async def _handle_message(self, request: Any) -> str: + token = obo_token(request.headers) + body = request.json() if request.body else {} + if not isinstance(body, dict): + raise ValidationError("Request body must be a JSON object") + alias = body.get("space") + content = body.get("content") + if not alias or not content: + raise ValidationError("'space' and 'content' are required") + conversation_id = body.get("conversation_id") + space_id = self._resolve_space(alias) + + msg = await self._connector.send_message( + token, + space_id, + content, + conversation_id=conversation_id, + ) + return _message_to_json(msg) + + async def _handle_conversation(self, request: Any) -> str: + token = obo_token(request.headers) + params = _query(request) + alias = params.get("space") + conv_id = params.get("conversation_id") + if not alias or not conv_id: + raise ValidationError( + "'space' and 'conversation_id' are required query parameters" + ) + space_id = self._resolve_space(alias) + history = await self._connector.get_conversation(token, space_id, conv_id) + return json.dumps( + { + "conversation_id": history.conversation_id, + "space_id": history.space_id, + "messages": [ + json.loads(_message_to_json(m)) for m in history.messages + ], + } + ) + + async def _handle_query_result(self, request: Any) -> str: + token = obo_token(request.headers) + params = _query(request) + alias = params.get("space") + conv_id = params.get("conversation_id") + msg_id = params.get("message_id") + att_id = params.get("attachment_id") + if not (alias and conv_id and msg_id and att_id): + raise ValidationError( + "'space', 'conversation_id', 'message_id', and 'attachment_id' " + "are required query parameters" + ) + space_id = self._resolve_space(alias) + result = await self._connector.get_query_result( + token, space_id, conv_id, msg_id, att_id + ) + # GenieQueryResult.data is already a JSON string; pass it through + # inside a 
stable envelope so clients always parse an object. + return json.dumps({"data": json.loads(result.data) if result.data else None}) + + def _user_key(self, request: Any) -> str: + return obo_user_key(request.headers) + + +def _message_to_json(msg: Any) -> str: + attachments = [] + for att in msg.attachments: + attachments.append( + { + "attachment_id": att.attachment_id, + "query_title": att.query_title, + "query_description": att.query_description, + "query_sql": att.query_sql, + "query_statement_id": att.query_statement_id, + "text_content": att.text_content, + "suggested_questions": att.suggested_questions, + } + ) + return json.dumps( + { + "message_id": msg.message_id, + "conversation_id": msg.conversation_id, + "space_id": msg.space_id, + "status": msg.status, + "content": msg.content, + "attachments": attachments, + "error": msg.error, + } + ) + + +__all__ = ["GeniePlugin", "GeniePluginConfig"] diff --git a/packages/appkit-rs/appkit/plugins/lakebase.py b/packages/appkit-rs/appkit/plugins/lakebase.py new file mode 100644 index 00000000..0c0fd832 --- /dev/null +++ b/packages/appkit-rs/appkit/plugins/lakebase.py @@ -0,0 +1,136 @@ +"""LakebasePlugin — Databricks Lakebase (PostgreSQL) integration. + +Lakebase exposes a programmatic pool API rather than HTTP routes, so this +plugin publishes connection helpers through plugin attributes and +``exports()`` rather than via ``inject_routes``. The corresponding Rust +:class:`crate::plugins::lakebase::LakebasePluginCore` declares the +``postgres`` resource requirement for manifest parity. + +Typical usage:: + + lakebase = LakebasePlugin(LakebasePluginConfig()) + app = await create_app(plugins=[lakebase, ...], config=AppConfig.from_env()) + + # Inside a route handler running under OBO: + credential = await lakebase.generate_credential( + obo_token(request.headers), + instance_names=[lakebase.pg_config.host], + ) + # Pass `credential.token` as the PostgreSQL password for this request. 
+""" + +from __future__ import annotations + +import os +from typing import Any + +from appkit import ( + DatabaseCredential, + LakebaseConnector, + LakebasePgConfig, + Plugin, + PluginManifest, +) + + +class LakebasePluginConfig: + """Configuration for :class:`LakebasePlugin`. + + ``pg_config`` overrides the default :class:`LakebasePgConfig` built + from ``PGHOST``/``PGDATABASE``/``LAKEBASE_ENDPOINT`` etc. ``host`` + defaults to ``DATABRICKS_HOST`` and is used to reach the Lakebase + credential-generation REST API (distinct from the PG host). + """ + + __slots__ = ("pg_config", "host") + + def __init__( + self, + *, + pg_config: LakebasePgConfig | None = None, + host: str | None = None, + ) -> None: + self.pg_config = pg_config + self.host = host + + def __repr__(self) -> str: + return f"LakebasePluginConfig(pg_config={self.pg_config!r})" + + +class LakebasePlugin(Plugin): + """Lakebase PostgreSQL integration plugin. + + Exposes: + + - :attr:`pg_config` — resolved :class:`LakebasePgConfig` for pool setup. + - :attr:`connector` — the underlying :class:`LakebaseConnector`. + - :meth:`generate_credential` — wrapper around the REST call. + """ + + NAME = "lakebase" + + def __init__(self, config: LakebasePluginConfig | None = None) -> None: + super().__init__( + self.NAME, + manifest=PluginManifest( + self.NAME, + display_name="Lakebase", + description="Databricks Lakebase PostgreSQL integration", + ), + ) + config = config or LakebasePluginConfig() + host = config.host or os.environ.get("DATABRICKS_HOST") + if not host: + raise ValueError( + "LakebasePlugin requires a Databricks host. Set DATABRICKS_HOST " + "or pass host= in LakebasePluginConfig." 
+ ) + self._config = config + self._host = host + self._pg_config = config.pg_config or LakebasePgConfig() + self._connector = LakebaseConnector(host) + + @property + def pg_config(self) -> LakebasePgConfig: + return self._pg_config + + @property + def connector(self) -> LakebaseConnector: + return self._connector + + async def generate_credential( + self, + token: str, + instance_names: list[str] | None = None, + *, + request_id: str | None = None, + ) -> DatabaseCredential: + """Generate a short-lived credential for Lakebase connection(s). + + When ``instance_names`` is omitted, the plugin's configured PG host + is used as the single instance name. + """ + names = list(instance_names) if instance_names else [self._pg_config.host] + return await self._connector.generate_credential( + token, names, request_id=request_id + ) + + def client_config(self) -> dict[str, str]: + return { + "database": self._pg_config.database, + "ssl_mode": self._pg_config.ssl_mode, + } + + def exports(self) -> dict[str, str]: + return { + "pg_host": self._pg_config.host, + "pg_database": self._pg_config.database, + "pg_port": str(self._pg_config.port), + "pg_ssl_mode": self._pg_config.ssl_mode, + } + + def inject_routes(self, _router: Any) -> None: + return None + + +__all__ = ["LakebasePlugin", "LakebasePluginConfig"] diff --git a/packages/appkit-rs/appkit/plugins/server.py b/packages/appkit-rs/appkit/plugins/server.py new file mode 100644 index 00000000..b8557ebc --- /dev/null +++ b/packages/appkit-rs/appkit/plugins/server.py @@ -0,0 +1,92 @@ +"""ServerPlugin — Python face of the core HTTP server. + +The actual axum-backed server lives in the Rust crate (``crate::server``) +and is started by :func:`appkit.create_app` / ``AppKit.start_server``. +This plugin exists so apps can declare the server in their plugin list, +tune host/port/static-file settings, and participate in phase ordering +(the server plugin runs in the Core phase — before any Normal plugin +injects routes). 
+ +``ServerPlugin.to_server_config()`` returns the :class:`appkit.ServerConfig` +that ``create_app`` should hand to ``start_server``. +""" + +from __future__ import annotations + +from typing import Any + +from appkit import Plugin, PluginManifest, PluginPhase, ServerConfig + + +class ServerPluginConfig: + """Configuration for :class:`ServerPlugin`. + + Mirrors :class:`appkit.ServerConfig`. Defaults align with the Rust + ``ServerPluginConfig`` (``0.0.0.0:8000``, auto-start enabled). + """ + + __slots__ = ("host", "port", "auto_start", "static_path") + + def __init__( + self, + *, + host: str = "0.0.0.0", + port: int = 8000, + auto_start: bool = True, + static_path: str | None = None, + ) -> None: + self.host = host + self.port = port + self.auto_start = auto_start + self.static_path = static_path + + def __repr__(self) -> str: + return f"ServerPluginConfig(host={self.host!r}, port={self.port})" + + +class ServerPlugin(Plugin): + """Core HTTP server plugin. + + The server plugin runs in the Core phase so that the server is ready + before any Normal-phase plugin calls ``inject_routes``. It exposes no + routes of its own — route hosting is handled by ``AppKit.start_server``. 
+ """ + + NAME = "server" + + def __init__(self, config: ServerPluginConfig | None = None) -> None: + super().__init__( + self.NAME, + phase=PluginPhase.CORE, + manifest=PluginManifest( + self.NAME, + display_name="Server Plugin", + description=( + "HTTP server with axum route hosting, SSE streaming, " + "and graceful shutdown" + ), + ), + ) + self._config = config or ServerPluginConfig() + + @property + def config(self) -> ServerPluginConfig: + return self._config + + def to_server_config(self) -> ServerConfig: + """Convert plugin config into an :class:`appkit.ServerConfig`.""" + return ServerConfig( + host=self._config.host, + port=self._config.port, + auto_start=self._config.auto_start, + static_path=self._config.static_path, + ) + + def client_config(self) -> dict[str, str]: + return {} + + def inject_routes(self, _router: Any) -> None: + return None + + +__all__ = ["ServerPlugin", "ServerPluginConfig"] diff --git a/packages/appkit-rs/appkit/plugins/serving.py b/packages/appkit-rs/appkit/plugins/serving.py new file mode 100644 index 00000000..f8654a83 --- /dev/null +++ b/packages/appkit-rs/appkit/plugins/serving.py @@ -0,0 +1,182 @@ +"""ServingPlugin — invoke and stream from Databricks model serving endpoints. + +Routes mounted under ``/api/serving``: + +- ``POST /invoke/:endpoint`` — invoke a configured endpoint (non-streaming). +- ``POST /stream/:endpoint`` — stream from a configured endpoint (SSE). + +Endpoint aliases are resolved against :class:`ServingEndpointConfig` at +configuration time; the real endpoint name is read from the environment +variable named in ``ServingEndpointConfig.env`` so deployment and the plugin +config never hard-code served-model identifiers. 
+""" + +from __future__ import annotations + +import json +import os +from collections.abc import Mapping +from typing import Any + +from appkit import ( + Plugin, + PluginManifest, + ServingConnector, + StreamIterator, + ValidationError, +) + +from ._obo import obo_token, obo_user_key + + +class ServingEndpointConfig: + """Per-alias serving endpoint configuration. + + ``env`` is the environment variable that holds the actual endpoint + name (for example ``CHAT_ENDPOINT``). ``served_model`` optionally + pins the request to a specific served model inside the endpoint. + """ + + __slots__ = ("env", "served_model") + + def __init__(self, *, env: str, served_model: str | None = None) -> None: + if not env: + raise ValueError("ServingEndpointConfig.env is required") + self.env = env + self.served_model = served_model + + def __repr__(self) -> str: + return ( + f"ServingEndpointConfig(env={self.env!r}, " + f"served_model={self.served_model!r})" + ) + + +class ServingPluginConfig: + """Configuration for :class:`ServingPlugin`. + + ``endpoints`` maps alias → :class:`ServingEndpointConfig`. At least one + endpoint is required. 
+ """ + + __slots__ = ("endpoints", "host", "timeout_ms") + + def __init__( + self, + *, + endpoints: Mapping[str, ServingEndpointConfig], + host: str | None = None, + timeout_ms: int | None = None, + ) -> None: + if not endpoints: + raise ValueError("ServingPluginConfig requires at least one endpoint") + self.endpoints: dict[str, ServingEndpointConfig] = dict(endpoints) + self.host = host + self.timeout_ms = timeout_ms + + def __repr__(self) -> str: + return f"ServingPluginConfig(endpoints={sorted(self.endpoints)!r})" + + +class ServingPlugin(Plugin): + """Model Serving plugin — invoke and stream endpoints via alias.""" + + NAME = "serving" + + def __init__(self, config: ServingPluginConfig) -> None: + super().__init__( + self.NAME, + manifest=PluginManifest( + self.NAME, + display_name="Model Serving Plugin", + description=( + "Invoke and stream from Databricks serving endpoints" + ), + ), + ) + host = config.host or os.environ.get("DATABRICKS_HOST") + if not host: + raise ValueError( + "ServingPlugin requires a Databricks host. Set DATABRICKS_HOST " + "or pass host= in ServingPluginConfig." + ) + self._config = config + self._host = host + self._connector = ServingConnector(host) + + def client_config(self) -> dict[str, str]: + return {"endpoints": ",".join(sorted(self._config.endpoints))} + + def inject_routes(self, router: Any) -> None: + router.post("/invoke/:endpoint", self._handle_invoke) + router.post("/stream/:endpoint", self._handle_stream, stream=True) + + def resolve_endpoint(self, alias: str) -> str: + """Return the endpoint name for ``alias`` from the configured env var. + + Raises :class:`ValidationError` if the alias is unknown or the + environment variable is unset or empty. + """ + try: + cfg = self._config.endpoints[alias] + except KeyError as exc: + raise ValidationError( + f"Unknown endpoint alias {alias!r}. 
Configured: " + f"{sorted(self._config.endpoints)!r}" + ) from exc + value = os.environ.get(cfg.env, "") + if not value: + raise ValidationError( + f"Serving endpoint alias {alias!r} is configured to read from " + f"environment variable {cfg.env!r}, but that variable is not set." + ) + return value + + def _extract_alias(self, path: str) -> str: + alias = path.rsplit("/", 1)[-1] + if not alias: + raise ValidationError("Missing endpoint alias in path") + return alias + + def _merge_served_model(self, alias: str, body: Any) -> dict[str, Any]: + if not isinstance(body, dict): + raise ValidationError("Request body must be a JSON object") + cfg = self._config.endpoints[alias] + if cfg.served_model and "served_model_name" not in body: + body = dict(body) + body["served_model_name"] = cfg.served_model + return body + + async def _handle_invoke(self, request: Any) -> str: + token = obo_token(request.headers) + alias = self._extract_alias(request.path) + endpoint = self.resolve_endpoint(alias) + body = request.json() if request.body else {} + body = self._merge_served_model(alias, body) + + async def run() -> str: + response = await self._connector.invoke( + token, endpoint, json.dumps(body) + ) + return response.data + + result = await self.execute( + run, + user_key=obo_user_key(request.headers), + timeout_ms=self._config.timeout_ms, + ) + if not result.ok: + raise RuntimeError(result.message or "Serving invocation failed") + return result.data or "{}" + + async def _handle_stream(self, request: Any) -> StreamIterator: + token = obo_token(request.headers) + alias = self._extract_alias(request.path) + endpoint = self.resolve_endpoint(alias) + body = request.json() if request.body else {} + body = self._merge_served_model(alias, body) + + return await self._connector.stream(token, endpoint, json.dumps(body)) + + +__all__ = ["ServingPlugin", "ServingPluginConfig", "ServingEndpointConfig"] diff --git a/packages/appkit-rs/appkit/plugins/vector_search.py 
b/packages/appkit-rs/appkit/plugins/vector_search.py new file mode 100644 index 00000000..c6cb0733 --- /dev/null +++ b/packages/appkit-rs/appkit/plugins/vector_search.py @@ -0,0 +1,233 @@ +"""VectorSearchPlugin — query Databricks Vector Search indexes. + +Exposes two routes mounted under ``/api/vector-search``: + +- ``POST /query`` — run a query against a configured index alias. +- ``POST /query-next-page`` — fetch the next page of a paginated query. + +The Rust ``VectorSearchConnector`` owns request-body construction (see +``packages/appkit-rs/src/connectors/vector_search.rs``). This module +handles request parsing, per-index defaults, and OBO token extraction. +""" + +from __future__ import annotations + +import json +import os +from collections.abc import Mapping +from typing import Any + +from appkit import ( + Plugin, + PluginManifest, + ValidationError, + VectorSearchConnector, +) + +from ._obo import obo_token, obo_user_key + +_VALID_QUERY_TYPES = ("ann", "hybrid", "full_text") + + +class VectorSearchIndexConfig: + """Per-index alias configuration. + + ``index_name`` is the fully-qualified ``catalog.schema.index`` name. + ``endpoint_name`` is required when paginating. ``columns`` lists + the columns returned from the index; ``query_type`` picks the + default search mode. ``reranker_columns`` enables the Databricks + reranker when non-empty. 
+ """ + + __slots__ = ( + "index_name", + "endpoint_name", + "columns", + "query_type", + "num_results", + "reranker_columns", + ) + + def __init__( + self, + *, + index_name: str, + endpoint_name: str | None = None, + columns: list[str] | None = None, + query_type: str = "hybrid", + num_results: int = 20, + reranker_columns: list[str] | None = None, + ) -> None: + if not index_name: + raise ValueError("VectorSearchIndexConfig.index_name is required") + if query_type not in _VALID_QUERY_TYPES: + raise ValueError( + f"Invalid query_type {query_type!r}; expected one of " + f"{_VALID_QUERY_TYPES}" + ) + self.index_name = index_name + self.endpoint_name = endpoint_name + self.columns = list(columns or []) + self.query_type = query_type + self.num_results = num_results + self.reranker_columns = list(reranker_columns) if reranker_columns else None + + def __repr__(self) -> str: + return ( + f"VectorSearchIndexConfig(index_name={self.index_name!r}, " + f"query_type={self.query_type!r})" + ) + + +class VectorSearchPluginConfig: + """Configuration for :class:`VectorSearchPlugin`. + + ``indexes`` maps alias → :class:`VectorSearchIndexConfig`. 
+ """ + + __slots__ = ("indexes", "host", "timeout_ms") + + def __init__( + self, + *, + indexes: Mapping[str, VectorSearchIndexConfig], + host: str | None = None, + timeout_ms: int | None = None, + ) -> None: + if not indexes: + raise ValueError( + "VectorSearchPluginConfig requires at least one index" + ) + self.indexes: dict[str, VectorSearchIndexConfig] = dict(indexes) + self.host = host + self.timeout_ms = timeout_ms + + def __repr__(self) -> str: + return ( + f"VectorSearchPluginConfig(indexes={sorted(self.indexes)!r})" + ) + + +class VectorSearchPlugin(Plugin): + """Vector Search plugin — hybrid, ANN, and full-text queries.""" + + NAME = "vector-search" + + def __init__(self, config: VectorSearchPluginConfig) -> None: + super().__init__( + self.NAME, + manifest=PluginManifest( + self.NAME, + display_name="Vector Search Plugin", + description=( + "Query Databricks Vector Search indexes with hybrid search, " + "reranking, and pagination" + ), + ), + ) + host = config.host or os.environ.get("DATABRICKS_HOST") + if not host: + raise ValueError( + "VectorSearchPlugin requires a Databricks host. Set " + "DATABRICKS_HOST or pass host= in VectorSearchPluginConfig." + ) + self._config = config + self._host = host + self._connector = VectorSearchConnector(host, timeout_ms=config.timeout_ms) + + def client_config(self) -> dict[str, str]: + return {"indexes": ",".join(sorted(self._config.indexes))} + + def inject_routes(self, router: Any) -> None: + router.post("/query", self._handle_query) + router.post("/query-next-page", self._handle_next_page) + + def _resolve_index(self, alias: str) -> VectorSearchIndexConfig: + try: + return self._config.indexes[alias] + except KeyError as exc: + raise ValidationError( + f"Unknown index alias {alias!r}. 
Configured: " + f"{sorted(self._config.indexes)!r}" + ) from exc + + async def _handle_query(self, request: Any) -> str: + token = obo_token(request.headers) + body = request.json() if request.body else {} + if not isinstance(body, dict): + raise ValidationError("Request body must be a JSON object") + alias = body.get("index") + if not alias: + raise ValidationError("Missing required field 'index'") + index_cfg = self._resolve_index(alias) + + query_text = body.get("query_text") + query_vector = body.get("query_vector") + if query_text is None and query_vector is None: + raise ValidationError( + "Request must include 'query_text' or 'query_vector'" + ) + columns = body.get("columns") or index_cfg.columns + if not columns: + raise ValidationError( + "'columns' must be set either on the request or in the index " + "configuration" + ) + query_type = body.get("query_type") or index_cfg.query_type + if query_type not in _VALID_QUERY_TYPES: + raise ValidationError( + f"Invalid query_type {query_type!r}; expected one of " + f"{_VALID_QUERY_TYPES}" + ) + num_results = int(body.get("num_results") or index_cfg.num_results) + filters = body.get("filters") + filters_json = json.dumps(filters) if filters else None + reranker_columns = body.get("reranker_columns") or index_cfg.reranker_columns + + async def run() -> str: + return await self._connector.query( + token, + index_cfg.index_name, + columns=list(columns), + num_results=num_results, + query_type=query_type, + query_text=query_text, + query_vector=query_vector, + filters_json=filters_json, + reranker_columns=reranker_columns, + ) + + result = await self.execute(run, user_key=obo_user_key(request.headers)) + if not result.ok: + raise RuntimeError(result.message or "Vector search failed") + return result.data or "{}" + + async def _handle_next_page(self, request: Any) -> str: + token = obo_token(request.headers) + body = request.json() if request.body else {} + if not isinstance(body, dict): + raise 
ValidationError("Request body must be a JSON object") + alias = body.get("index") + page_token = body.get("page_token") + if not alias or not page_token: + raise ValidationError( + "'index' and 'page_token' are required for query-next-page" + ) + index_cfg = self._resolve_index(alias) + endpoint_name = body.get("endpoint_name") or index_cfg.endpoint_name + if not endpoint_name: + raise ValidationError( + "'endpoint_name' is required (set it on the request or in " + "the index configuration)" + ) + raw = await self._connector.query_next_page( + token, index_cfg.index_name, endpoint_name, page_token + ) + return raw + + +__all__ = [ + "VectorSearchPlugin", + "VectorSearchPluginConfig", + "VectorSearchIndexConfig", +] diff --git a/packages/appkit-rs/pyproject.toml b/packages/appkit-rs/pyproject.toml new file mode 100644 index 00000000..ff302860 --- /dev/null +++ b/packages/appkit-rs/pyproject.toml @@ -0,0 +1,21 @@ +[build-system] +requires = ["maturin>=1.5,<2.0"] +build-backend = "maturin" + +[project] +name = "databricks-appkit" +version = "0.1.0" +requires-python = ">=3.11" +description = "Databricks AppKit Python SDK" + +[project.optional-dependencies] +test = ["pytest>=8.0", "pytest-asyncio>=0.23"] + +[tool.maturin] +features = ["extension-module"] +python-source = "." 
+module-name = "appkit.appkit" + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] diff --git a/packages/appkit-rs/src/auth.rs b/packages/appkit-rs/src/auth.rs new file mode 100644 index 00000000..c1e81f58 --- /dev/null +++ b/packages/appkit-rs/src/auth.rs @@ -0,0 +1,293 @@ +use pyo3::prelude::*; +use reqwest::Client; +use serde::Deserialize; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::Mutex; + +use crate::config::AppConfig; + +#[derive(Deserialize)] +struct TokenResponse { + access_token: String, + #[allow(dead_code)] + token_type: String, + expires_in: u64, +} + +struct CachedToken { + token: String, + expires_at: Instant, +} + +/// OAuth M2M token provider using client credentials flow. +/// Acquires tokens via POST to `{DATABRICKS_HOST}/oidc/v1/token` and caches +/// them with a 30-second safety buffer before expiry. +pub struct M2MTokenProvider { + host: String, + client_id: String, + client_secret: String, + http_client: Client, + cached: Mutex>, +} + +impl M2MTokenProvider { + pub fn new(host: String, client_id: String, client_secret: String) -> Self { + Self { + host: host.trim_end_matches('/').to_string(), + client_id, + client_secret, + http_client: Client::new(), + cached: Mutex::new(None), + } + } + + /// Get a valid token, refreshing if expired or not yet acquired. 
+ pub async fn get_token(&self) -> Result { + let mut guard = self.cached.lock().await; + + if let Some(ref cached) = *guard { + if Instant::now() < cached.expires_at { + return Ok(cached.token.clone()); + } + } + + let token_url = format!("{}/oidc/v1/token", self.host); + let resp = self + .http_client + .post(&token_url) + .form(&[ + ("grant_type", "client_credentials"), + ("client_id", self.client_id.as_str()), + ("client_secret", self.client_secret.as_str()), + ("scope", "all-apis"), + ]) + .send() + .await + .map_err(|e| format!("Token request failed: {e}"))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(format!("Token request returned {status}: {body}")); + } + + let token_resp: TokenResponse = resp + .json() + .await + .map_err(|e| format!("Failed to parse token response: {e}"))?; + + let expires_at = + Instant::now() + Duration::from_secs(token_resp.expires_in.saturating_sub(30)); + let token = token_resp.access_token.clone(); + + *guard = Some(CachedToken { + token: token_resp.access_token, + expires_at, + }); + + Ok(token) + } +} + +/// Service-level authentication context (service principal). +/// Analogous to TypeScript's ServiceContextState — holds M2M credentials +/// and provides token acquisition for service-principal API calls. 
+#[pyclass(frozen, module = "appkit")] +pub struct ServiceContext { + pub token_provider: Arc, + #[pyo3(get)] + pub config: AppConfig, +} + +#[pymethods] +impl ServiceContext { + #[new] + #[pyo3(signature = (config))] + pub fn new(config: AppConfig) -> PyResult { + let client_id = config.client_id.clone().ok_or_else(|| { + pyo3::exceptions::PyValueError::new_err( + "DATABRICKS_CLIENT_ID is required for ServiceContext", + ) + })?; + let client_secret = config.client_secret.clone().ok_or_else(|| { + pyo3::exceptions::PyValueError::new_err( + "DATABRICKS_CLIENT_SECRET is required for ServiceContext", + ) + })?; + + let provider = + M2MTokenProvider::new(config.databricks_host.clone(), client_id, client_secret); + + Ok(Self { + token_provider: Arc::new(provider), + config, + }) + } + + /// Acquire a valid service-principal access token (async). + fn get_token<'py>(&self, py: Python<'py>) -> PyResult> { + let provider = self.token_provider.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + provider + .get_token() + .await + .map_err(pyo3::exceptions::PyRuntimeError::new_err) + }) + } + + fn __repr__(&self) -> String { + format!("ServiceContext(host={:?})", self.config.databricks_host) + } +} + +/// Per-request user context for OBO (On-Behalf-Of) flows. +/// Analogous to TypeScript's UserContext — carries a forwarded user token, +/// identity headers, and inherited execution-scoped IDs (workspace, warehouse) +/// so connectors and route handlers can select service-principal vs per-user auth. +#[derive(Clone)] +#[pyclass(frozen, module = "appkit")] +pub struct UserContext { + #[pyo3(get)] + pub token: String, + #[pyo3(get)] + pub user_id: String, + #[pyo3(get)] + pub user_name: Option, + /// Inherited from ServiceContext — the workspace ID for this execution. + #[pyo3(get)] + pub workspace_id: String, + /// Inherited from ServiceContext — optional warehouse ID (only present when + /// a plugin requires the SQL_WAREHOUSE resource). 
+ #[pyo3(get)] + pub warehouse_id: Option, +} + +#[pymethods] +impl UserContext { + #[new] + #[pyo3(signature = (token, user_id, *, user_name = None, workspace_id, warehouse_id = None))] + pub fn new( + token: String, + user_id: String, + user_name: Option, + workspace_id: String, + warehouse_id: Option, + ) -> Self { + Self { + token, + user_id, + user_name, + workspace_id, + warehouse_id, + } + } + + /// Discriminator property — always `True` for UserContext. + /// Mirrors TypeScript's `isUserContext: true` field. + #[getter] + fn is_user_context(&self) -> bool { + true + } + + fn __repr__(&self) -> String { + format!( + "UserContext(user_id={:?}, user_name={:?}, workspace_id={:?})", + self.user_id, self.user_name, self.workspace_id + ) + } + + fn __eq__(&self, other: &Self) -> bool { + self.token == other.token + && self.user_id == other.user_id + && self.user_name == other.user_name + && self.workspace_id == other.workspace_id + && self.warehouse_id == other.warehouse_id + } + + fn __hash__(&self) -> u64 { + use std::hash::{Hash, Hasher}; + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + self.user_id.hash(&mut hasher); + self.workspace_id.hash(&mut hasher); + hasher.finish() + } +} + +/// Discriminated execution context — either a service principal or a per-request user. +/// Mirrors TypeScript's `ExecutionContext = ServiceContextState | UserContext`. +pub enum ExecutionContext { + Service(Arc), + User(UserContext), +} + +impl ExecutionContext { + /// Get the bearer token for the current context. 
+ pub async fn get_token(&self) -> Result { + match self { + Self::Service(provider) => provider.get_token().await, + Self::User(ctx) => Ok(ctx.token.clone()), + } + } + + pub fn user_id(&self) -> Option<&str> { + match self { + Self::Service(_) => None, + Self::User(ctx) => Some(&ctx.user_id), + } + } + + pub fn is_user_context(&self) -> bool { + matches!(self, Self::User(_)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_user_context() { + let ctx = UserContext::new( + "tok123".into(), + "user-1".into(), + Some("Alice".into()), + "ws-123".into(), + Some("wh-456".into()), + ); + assert_eq!(ctx.token, "tok123"); + assert_eq!(ctx.user_id, "user-1"); + assert_eq!(ctx.user_name.as_deref(), Some("Alice")); + assert_eq!(ctx.workspace_id, "ws-123"); + assert_eq!(ctx.warehouse_id.as_deref(), Some("wh-456")); + } + + #[test] + fn test_user_context_without_warehouse() { + let ctx = UserContext::new( + "tok".into(), + "u1".into(), + None, + "ws-1".into(), + None, + ); + assert!(ctx.warehouse_id.is_none()); + assert_eq!(ctx.workspace_id, "ws-1"); + } + + #[test] + fn test_execution_context_user() { + let user = UserContext::new("tok".into(), "u1".into(), None, "ws-1".into(), None); + let exec = ExecutionContext::User(user); + assert!(exec.is_user_context()); + assert_eq!(exec.user_id(), Some("u1")); + } + + #[tokio::test] + async fn test_execution_context_user_token() { + let user = UserContext::new("my-token".into(), "u1".into(), None, "ws-1".into(), None); + let exec = ExecutionContext::User(user); + let token = exec.get_token().await.unwrap(); + assert_eq!(token, "my-token"); + } +} diff --git a/packages/appkit-rs/src/cache.rs b/packages/appkit-rs/src/cache.rs new file mode 100644 index 00000000..19969b8a --- /dev/null +++ b/packages/appkit-rs/src/cache.rs @@ -0,0 +1,846 @@ +use pyo3::prelude::*; +use serde_json::Value as JsonValue; +use sha2::{Digest, Sha256}; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::{SystemTime, 
UNIX_EPOCH}; +use tokio::sync::{watch, Mutex}; + +// --------------------------------------------------------------------------- +// Configuration +// --------------------------------------------------------------------------- + +/// Cache configuration with defaults matching TypeScript `cacheDefaults`. +#[derive(Clone, Debug)] +#[pyclass(frozen, module = "appkit")] +pub struct CacheConfig { + #[pyo3(get)] + pub enabled: bool, + #[pyo3(get)] + pub ttl: u64, + #[pyo3(get)] + pub max_size: usize, + #[pyo3(get)] + pub cleanup_probability: f64, +} + +#[pymethods] +impl CacheConfig { + #[new] + #[pyo3(signature = (*, enabled = true, ttl = 3600, max_size = 1000, cleanup_probability = 0.01))] + pub fn new(enabled: bool, ttl: u64, max_size: usize, cleanup_probability: f64) -> Self { + Self { + enabled, + ttl, + max_size, + cleanup_probability, + } + } + + fn __repr__(&self) -> String { + format!( + "CacheConfig(enabled={}, ttl={}, max_size={}, cleanup_probability={})", + self.enabled, self.ttl, self.max_size, self.cleanup_probability + ) + } + + fn __eq__(&self, other: &Self) -> bool { + self.enabled == other.enabled + && self.ttl == other.ttl + && self.max_size == other.max_size + && self.cleanup_probability == other.cleanup_probability + } + + fn __hash__(&self) -> u64 { + use std::hash::{Hash, Hasher}; + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + self.enabled.hash(&mut hasher); + self.ttl.hash(&mut hasher); + self.max_size.hash(&mut hasher); + hasher.finish() + } +} + +impl Default for CacheConfig { + fn default() -> Self { + Self { + enabled: true, + ttl: 3600, + max_size: 1000, + cleanup_probability: 0.01, + } + } +} + +// --------------------------------------------------------------------------- +// Internal storage +// --------------------------------------------------------------------------- + +struct CacheEntry { + value: JsonValue, + expiry: u64, // milliseconds since epoch +} + +fn now_millis() -> u64 { + SystemTime::now() + 
.duration_since(UNIX_EPOCH) + .expect("system clock before unix epoch") + .as_millis() as u64 +} + +/// Probabilistic check using a randomly-seeded hasher. +fn rand_check(probability: f64) -> bool { + use std::collections::hash_map::RandomState; + use std::hash::{BuildHasher, Hasher}; + let mut hasher = RandomState::new().build_hasher(); + hasher.write_u8(0); + let hash = hasher.finish(); + (hash as f64 / u64::MAX as f64) < probability +} + +/// In-memory LRU cache storage with bounded capacity. +/// Matches the semantics of the TypeScript `InMemoryStorage`. +struct InMemoryStorage { + cache: HashMap, + access_order: HashMap, + access_counter: u64, + max_size: usize, +} + +impl InMemoryStorage { + fn new(max_size: usize) -> Self { + Self { + cache: HashMap::new(), + access_order: HashMap::new(), + access_counter: 0, + max_size, + } + } + + fn get(&mut self, key: &str) -> Option<&CacheEntry> { + let expired = self + .cache + .get(key) + .map(|e| now_millis() > e.expiry) + .unwrap_or(true); + + if expired { + self.cache.remove(key); + self.access_order.remove(key); + return None; + } + + self.access_counter += 1; + self.access_order + .insert(key.to_string(), self.access_counter); + self.cache.get(key) + } + + fn set(&mut self, key: String, entry: CacheEntry) { + if self.cache.len() >= self.max_size && !self.cache.contains_key(&key) { + self.evict_lru(); + } + self.access_counter += 1; + self.access_order.insert(key.clone(), self.access_counter); + self.cache.insert(key, entry); + } + + fn delete(&mut self, key: &str) { + self.cache.remove(key); + self.access_order.remove(key); + } + + fn has(&mut self, key: &str) -> bool { + if let Some(entry) = self.cache.get(key) { + if now_millis() > entry.expiry { + let key = key.to_string(); + self.cache.remove(&key); + self.access_order.remove(&key); + return false; + } + true + } else { + false + } + } + + fn clear(&mut self) { + self.cache.clear(); + self.access_order.clear(); + self.access_counter = 0; + } + + fn 
size(&self) -> usize { + self.cache.len() + } + + fn cleanup_expired(&mut self) { + let now = now_millis(); + let expired_keys: Vec = self + .cache + .iter() + .filter(|(_, entry)| now > entry.expiry) + .map(|(key, _)| key.clone()) + .collect(); + for key in expired_keys { + self.cache.remove(&key); + self.access_order.remove(&key); + } + } + + fn evict_lru(&mut self) { + if let Some((key, _)) = self + .access_order + .iter() + .min_by_key(|(_, &counter)| counter) + .map(|(k, v)| (k.clone(), *v)) + { + self.cache.remove(&key); + self.access_order.remove(&key); + } + } +} + +// --------------------------------------------------------------------------- +// Cache manager (Rust-internal + PyO3) +// --------------------------------------------------------------------------- + +type InFlightValue = Option>; + +/// Cache manager with TTL, LRU eviction, concurrent in-flight deduplication, +/// and probabilistic cleanup. +/// +/// Mirrors the TypeScript `CacheManager` with `InMemoryStorage`. +#[pyclass(module = "appkit")] +pub struct CacheManager { + storage: Arc>, + config: CacheConfig, + in_flight: Arc>>>, +} + +impl CacheManager { + /// Create a CacheManager from Rust code (not via Python). + pub fn new_internal(config: CacheConfig) -> Self { + let max_size = config.max_size; + Self { + storage: Arc::new(Mutex::new(InMemoryStorage::new(max_size))), + config, + in_flight: Arc::new(Mutex::new(HashMap::new())), + } + } + + /// Generate a SHA256 cache key from composite parts. + /// Matches the TypeScript `generateKey` which hashes `[userKey, ...parts]`. 
+ pub fn generate_key_from_parts(parts: &[&str], user_key: &str) -> String { + let mut components: Vec<&str> = vec![user_key]; + components.extend_from_slice(parts); + let serialized = serde_json::to_string(&components).unwrap_or_default(); + let hash = Sha256::digest(serialized.as_bytes()); + hash.iter().map(|b| format!("{b:02x}")).collect() + } + + async fn maybe_cleanup(&self) { + if rand_check(self.config.cleanup_probability) { + let mut storage = self.storage.lock().await; + storage.cleanup_expired(); + } + } + + /// Core get_or_execute with in-flight deduplication (Rust-internal API). + /// + /// If a value for `key` is cached, returns it immediately. + /// If another task is already computing the same key, waits for its result. + /// Otherwise executes `func` and caches the result. + pub async fn get_or_execute_internal( + &self, + key: String, + func: F, + ttl: Option, + ) -> Result + where + F: FnOnce() -> Fut + Send, + Fut: std::future::Future> + Send, + { + if !self.config.enabled { + return func().await; + } + + // Check cache + { + let mut storage = self.storage.lock().await; + if let Some(entry) = storage.get(&key) { + return Ok(entry.value.clone()); + } + } + + // Try to join an existing in-flight request, or become the executor. + enum Action { + Wait(watch::Receiver), + Execute(watch::Sender), + } + + let action = { + let mut in_flight = self.in_flight.lock().await; + if let Some(existing_tx) = in_flight.get(&key) { + Action::Wait(existing_tx.subscribe()) + } else { + let (tx, _rx) = watch::channel(None); + in_flight.insert(key.clone(), tx.clone()); + Action::Execute(tx) + } + }; + + match action { + Action::Wait(mut rx) => { + // Wait for the executor to broadcast its result. + loop { + { + let val = rx.borrow().clone(); + if let Some(result) = val { + return result; + } + } + if rx.changed().await.is_err() { + // Executor dropped without sending — execute ourselves as fallback. 
+ return func().await; + } + } + } + Action::Execute(tx) => { + let result = func().await; + + // Cache successful results + if let Ok(ref value) = result { + let ttl_secs = ttl.unwrap_or(self.config.ttl); + let expiry = now_millis() + ttl_secs * 1000; + let mut storage = self.storage.lock().await; + storage.set( + key.clone(), + CacheEntry { + value: value.clone(), + expiry, + }, + ); + } + + // Broadcast result to waiting tasks + let _ = tx.send(Some(result.clone())); + + // Remove from in-flight map + { + let mut in_flight = self.in_flight.lock().await; + in_flight.remove(&key); + } + + self.maybe_cleanup().await; + result + } + } + } + + // Rust-only convenience wrappers used by Rust callers (connectors, etc.) + + pub async fn get_internal(&self, key: &str) -> Option { + let mut storage = self.storage.lock().await; + storage.get(key).map(|e| e.value.clone()) + } + + pub async fn set_internal(&self, key: String, value: JsonValue, ttl: Option) { + let ttl_secs = ttl.unwrap_or(self.config.ttl); + let expiry = now_millis() + ttl_secs * 1000; + let mut storage = self.storage.lock().await; + storage.set(key, CacheEntry { value, expiry }); + } + + pub async fn delete_internal(&self, key: &str) { + let mut storage = self.storage.lock().await; + storage.delete(key); + } + + pub async fn has_internal(&self, key: &str) -> bool { + let mut storage = self.storage.lock().await; + storage.has(key) + } + + pub async fn clear_internal(&self) { + let mut storage = self.storage.lock().await; + storage.clear(); + } + + pub async fn size_internal(&self) -> usize { + let storage = self.storage.lock().await; + storage.size() + } +} + +// --------------------------------------------------------------------------- +// Python bindings +// --------------------------------------------------------------------------- + +#[pymethods] +impl CacheManager { + #[new] + #[pyo3(signature = (config = None))] + fn new(config: Option) -> Self { + let config = config.unwrap_or_default(); + let 
max_size = config.max_size; + Self { + storage: Arc::new(Mutex::new(InMemoryStorage::new(max_size))), + config, + in_flight: Arc::new(Mutex::new(HashMap::new())), + } + } + + /// Generate a SHA256 cache key from parts and user key. + #[staticmethod] + #[pyo3(signature = (parts, user_key))] + fn generate_key(parts: Vec, user_key: String) -> String { + let refs: Vec<&str> = parts.iter().map(|s| s.as_str()).collect(); + Self::generate_key_from_parts(&refs, &user_key) + } + + /// Get a cached value by key. Returns a JSON string or None. + fn get<'py>(&self, py: Python<'py>, key: String) -> PyResult> { + let storage = self.storage.clone(); + let cleanup_prob = self.config.cleanup_probability; + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let mut guard = storage.lock().await; + let result = guard.get(&key).map(|e| e.value.to_string()); + if rand_check(cleanup_prob) { + guard.cleanup_expired(); + } + Ok(result) + }) + } + + /// Store a value (JSON string) with optional TTL in seconds. + #[pyo3(signature = (key, value, *, ttl = None))] + fn set<'py>( + &self, + py: Python<'py>, + key: String, + value: String, + ttl: Option, + ) -> PyResult> { + let storage = self.storage.clone(); + let default_ttl = self.config.ttl; + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let ttl_secs = ttl.unwrap_or(default_ttl); + let expiry = now_millis() + ttl_secs * 1000; + let json_value: JsonValue = + serde_json::from_str(&value).unwrap_or(JsonValue::String(value)); + let mut guard = storage.lock().await; + guard.set(key, CacheEntry { value: json_value, expiry }); + Ok(()) + }) + } + + /// Delete a cached entry by key. + fn delete<'py>(&self, py: Python<'py>, key: String) -> PyResult> { + let storage = self.storage.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let mut guard = storage.lock().await; + guard.delete(&key); + Ok(()) + }) + } + + /// Check if a key exists and is not expired. 
+ fn has<'py>(&self, py: Python<'py>, key: String) -> PyResult> { + let storage = self.storage.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let mut guard = storage.lock().await; + Ok(guard.has(&key)) + }) + } + + /// Clear all cached entries. + fn clear<'py>(&self, py: Python<'py>) -> PyResult> { + let storage = self.storage.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let mut guard = storage.lock().await; + guard.clear(); + Ok(()) + }) + } + + /// Return the number of cached entries. + fn size<'py>(&self, py: Python<'py>) -> PyResult> { + let storage = self.storage.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let guard = storage.lock().await; + Ok(guard.size()) + }) + } + + /// Execute a Python async callable with caching. + /// + /// The callable must be an async function (coroutine function) that returns + /// a JSON string. On cache hit the callable is not invoked. + #[pyo3(signature = (key, func, *, ttl = None))] + fn get_or_execute<'py>( + &self, + py: Python<'py>, + key: String, + func: PyObject, + ttl: Option, + ) -> PyResult> { + let storage = self.storage.clone(); + let in_flight = self.in_flight.clone(); + let config = self.config.clone(); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + if !config.enabled { + let future = Python::with_gil(|py| { + let coroutine = func.call0(py)?; + pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py)) + })?; + let result = future.await?; + return Python::with_gil(|py| result.extract::(py)); + } + + // Check cache + { + let mut guard = storage.lock().await; + if let Some(entry) = guard.get(&key) { + return Ok(entry.value.to_string()); + } + } + + // Check for existing in-flight request or register as executor. 
+ enum Action { + Wait(watch::Receiver), + Execute(watch::Sender), + } + + let action = { + let mut in_flight_guard = in_flight.lock().await; + if let Some(existing_tx) = in_flight_guard.get(&key) { + Action::Wait(existing_tx.subscribe()) + } else { + let (tx, _rx) = watch::channel(None); + in_flight_guard.insert(key.clone(), tx.clone()); + Action::Execute(tx) + } + }; + + match action { + Action::Wait(mut rx) => { + loop { + { + let val = rx.borrow().clone(); + if let Some(result) = val { + return result + .map(|v| v.to_string()) + .map_err(pyo3::exceptions::PyRuntimeError::new_err); + } + } + if rx.changed().await.is_err() { + // Executor dropped — fall back to calling the function ourselves. + let future = Python::with_gil(|py| { + let coroutine = func.call0(py)?; + pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py)) + })?; + let result = future.await?; + return Python::with_gil(|py| result.extract::(py)); + } + } + } + Action::Execute(tx) => { + // Call the Python async function. + let py_result: PyResult = async { + let future = Python::with_gil(|py| { + let coroutine = func.call0(py)?; + pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py)) + })?; + let result = future.await?; + Python::with_gil(|py| result.extract::(py)) + } + .await; + + // Convert to cache-compatible result. + let cache_result: Result = match &py_result { + Ok(s) => Ok(serde_json::from_str(s) + .unwrap_or(JsonValue::String(s.clone()))), + Err(e) => Err(e.to_string()), + }; + + // Cache successful results. + if let Ok(ref value) = cache_result { + let ttl_secs = ttl.unwrap_or(config.ttl); + let expiry = now_millis() + ttl_secs * 1000; + let mut guard = storage.lock().await; + guard.set( + key.clone(), + CacheEntry { + value: value.clone(), + expiry, + }, + ); + } + + // Broadcast result and clean up. 
+ let _ = tx.send(Some(cache_result)); + { + let mut in_flight_guard = in_flight.lock().await; + in_flight_guard.remove(&key); + } + if rand_check(config.cleanup_probability) { + let mut guard = storage.lock().await; + guard.cleanup_expired(); + } + + py_result + } + } + }) + } + + fn __repr__(&self) -> String { + format!( + "CacheManager(enabled={}, ttl={}, max_size={})", + self.config.enabled, self.config.ttl, self.config.max_size + ) + } + + fn __bool__(&self) -> bool { + self.config.enabled + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_cache(max_size: usize) -> CacheManager { + CacheManager { + storage: Arc::new(Mutex::new(InMemoryStorage::new(max_size))), + config: CacheConfig { + enabled: true, + ttl: 60, + max_size, + cleanup_probability: 0.0, // deterministic tests + }, + in_flight: Arc::new(Mutex::new(HashMap::new())), + } + } + + #[test] + fn test_generate_key_deterministic() { + let a = CacheManager::generate_key_from_parts(&["q1", "p1"], "user-1"); + let b = CacheManager::generate_key_from_parts(&["q1", "p1"], "user-1"); + assert_eq!(a, b); + assert_eq!(a.len(), 64); // SHA256 hex + } + + #[test] + fn test_generate_key_varies_by_user() { + let a = CacheManager::generate_key_from_parts(&["q"], "alice"); + let b = CacheManager::generate_key_from_parts(&["q"], "bob"); + assert_ne!(a, b); + } + + #[tokio::test] + async fn test_set_and_get() { + let cache = make_cache(10); + cache + .set_internal("k1".into(), JsonValue::String("hello".into()), None) + .await; + let val = cache.get_internal("k1").await; + assert_eq!(val, Some(JsonValue::String("hello".into()))); + } + + #[tokio::test] + async fn test_get_miss() { + let cache = make_cache(10); + assert!(cache.get_internal("nope").await.is_none()); + } + + #[tokio::test] + async fn test_delete() { + let cache = 
make_cache(10);
        cache
            .set_internal("k".into(), JsonValue::Bool(true), None)
            .await;
        assert!(cache.has_internal("k").await);
        cache.delete_internal("k").await;
        assert!(!cache.has_internal("k").await);
    }

    #[tokio::test]
    async fn test_clear() {
        let cache = make_cache(10);
        for i in 0..5 {
            cache
                .set_internal(format!("k{i}"), JsonValue::Null, None)
                .await;
        }
        assert_eq!(cache.size_internal().await, 5);
        cache.clear_internal().await;
        assert_eq!(cache.size_internal().await, 0);
    }

    #[tokio::test]
    async fn test_ttl_expiry() {
        let cache = make_cache(10);
        // A 0-second TTL gives expiry == insertion time, so the entry is
        // already expired on the next read.
        cache
            .set_internal("k".into(), JsonValue::String("v".into()), Some(0))
            .await;
        tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        assert!(cache.get_internal("k").await.is_none());
    }

    #[tokio::test]
    async fn test_lru_eviction() {
        let cache = make_cache(3);
        cache
            .set_internal("a".into(), JsonValue::String("1".into()), None)
            .await;
        cache
            .set_internal("b".into(), JsonValue::String("2".into()), None)
            .await;
        cache
            .set_internal("c".into(), JsonValue::String("3".into()), None)
            .await;

        // Touch "a" so it becomes the most recently used entry.
        cache.get_internal("a").await;

        // Inserting "d" must evict "b", the least recently used entry.
        cache
            .set_internal("d".into(), JsonValue::String("4".into()), None)
            .await;

        assert!(cache.has_internal("a").await);
        assert!(!cache.has_internal("b").await);
        assert!(cache.has_internal("c").await);
        assert!(cache.has_internal("d").await);
    }

    #[tokio::test]
    async fn test_get_or_execute_caches() {
        let cache = make_cache(10);
        let call_count = Arc::new(std::sync::atomic::AtomicU32::new(0));

        let counter = call_count.clone();
        let first = cache
            .get_or_execute_internal(
                "key1".into(),
                move || {
                    counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    async { Ok(JsonValue::Number(42.into())) }
                },
                None,
            )
            .await
            .unwrap();
        assert_eq!(first, JsonValue::Number(42.into()));
        assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 1);

        // Second call with the same key must hit the cache — func not called.
        let counter = call_count.clone();
        let second = cache
            .get_or_execute_internal(
                "key1".into(),
                move || {
                    counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    async { Ok(JsonValue::Null) }
                },
                None,
            )
            .await
            .unwrap();
        assert_eq!(second, JsonValue::Number(42.into()));
        assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_get_or_execute_dedup() {
        let cache = Arc::new(make_cache(10));
        let call_count = Arc::new(std::sync::atomic::AtomicU32::new(0));

        let mut handles = Vec::new();
        for _ in 0..5 {
            let shared = cache.clone();
            let counter = call_count.clone();
            handles.push(tokio::spawn(async move {
                shared
                    .get_or_execute_internal(
                        "shared".into(),
                        move || {
                            counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                            async {
                                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                                Ok(JsonValue::String("done".into()))
                            }
                        },
                        None,
                    )
                    .await
            }));
        }

        for handle in handles {
            let outcome = handle.await.unwrap().unwrap();
            assert_eq!(outcome, JsonValue::String("done".into()));
        }

        // The function should have been called at most twice (first executor +
        // possible fallback race), but never 5 times.
        let count = call_count.load(std::sync::atomic::Ordering::SeqCst);
        assert!(count <= 2, "expected dedup, got {count} calls");
    }

    #[tokio::test]
    async fn test_disabled_cache_always_executes() {
        let cache = CacheManager {
            storage: Arc::new(Mutex::new(InMemoryStorage::new(10))),
            config: CacheConfig {
                enabled: false,
                ..CacheConfig::default()
            },
            in_flight: Arc::new(Mutex::new(HashMap::new())),
        };

        let value = cache
            .get_or_execute_internal(
                "k".into(),
                || async { Ok(JsonValue::String("computed".into())) },
                None,
            )
            .await
            .unwrap();
        assert_eq!(value, JsonValue::String("computed".into()));
        // Nothing is cached when the cache is disabled.
        assert!(cache.get_internal("k").await.is_none());
    }

    #[test]
    fn test_cleanup_expired() {
        let mut storage = InMemoryStorage::new(100);
        let past = now_millis().saturating_sub(1000);
        storage.set(
            "expired".into(),
            CacheEntry {
                value: JsonValue::Null,
                expiry: past,
            },
        );
        storage.set(
            "valid".into(),
            CacheEntry {
                value: JsonValue::Null,
                expiry: now_millis() + 60_000,
            },
        );
        assert_eq!(storage.size(), 2);
        storage.cleanup_expired();
        assert_eq!(storage.size(), 1);
        assert!(storage.cache.contains_key("valid"));
    }
}
diff --git a/packages/appkit-rs/src/config.rs b/packages/appkit-rs/src/config.rs
new file mode 100644
index 00000000..5f2432d4
--- /dev/null
+++ b/packages/appkit-rs/src/config.rs
@@ -0,0 +1,177 @@
use pyo3::prelude::*;
use std::env;

/// Read an environment variable, treating empty values as unset.
fn non_empty_env(key: &str) -> Option<String> {
    env::var(key).ok().filter(|v| !v.is_empty())
}

/// Application configuration parsed from environment variables.
/// Mirrors the TypeScript ServiceContext / execution-context environment expectations.
+#[derive(Clone, Debug)] +#[pyclass(frozen, module = "appkit")] +pub struct AppConfig { + #[pyo3(get)] + pub databricks_host: String, + #[pyo3(get)] + pub client_id: Option, + #[pyo3(get)] + pub client_secret: Option, + #[pyo3(get)] + pub warehouse_id: Option, + #[pyo3(get)] + pub app_port: u16, + #[pyo3(get)] + pub host: String, + #[pyo3(get)] + pub otel_endpoint: Option, +} + +#[pymethods] +impl AppConfig { + #[new] + #[pyo3(signature = ( + databricks_host, + *, + client_id = None, + client_secret = None, + warehouse_id = None, + app_port = 8000, + host = "0.0.0.0".to_string(), + otel_endpoint = None, + ))] + pub fn new( + databricks_host: String, + client_id: Option, + client_secret: Option, + warehouse_id: Option, + app_port: u16, + host: String, + otel_endpoint: Option, + ) -> Self { + Self { + databricks_host, + client_id, + client_secret, + warehouse_id, + app_port, + host, + otel_endpoint, + } + } + + /// Parse configuration from environment variables. + #[staticmethod] + pub fn from_env() -> PyResult { + let mut databricks_host = non_empty_env("DATABRICKS_HOST").ok_or_else(|| { + pyo3::exceptions::PyValueError::new_err( + "DATABRICKS_HOST environment variable is required", + ) + })?; + // Databricks Apps sets DATABRICKS_HOST without a scheme; normalise. 
+ if !databricks_host.starts_with("https://") && !databricks_host.starts_with("http://") { + databricks_host = format!("https://{databricks_host}"); + } + + let app_port = non_empty_env("DATABRICKS_APP_PORT") + .or_else(|| non_empty_env("PORT")) + .and_then(|v| v.parse().ok()) + .unwrap_or(8000); + + let host = non_empty_env("FLASK_RUN_HOST").unwrap_or_else(|| "0.0.0.0".to_string()); + + Ok(Self { + databricks_host, + client_id: non_empty_env("DATABRICKS_CLIENT_ID"), + client_secret: non_empty_env("DATABRICKS_CLIENT_SECRET"), + warehouse_id: non_empty_env("DATABRICKS_WAREHOUSE_ID"), + app_port, + host, + otel_endpoint: non_empty_env("OTEL_EXPORTER_OTLP_ENDPOINT"), + }) + } + + fn __repr__(&self) -> String { + format!( + "AppConfig(databricks_host={:?}, warehouse_id={:?}, app_port={})", + self.databricks_host, self.warehouse_id, self.app_port + ) + } + + fn __eq__(&self, other: &Self) -> bool { + self.databricks_host == other.databricks_host + && self.client_id == other.client_id + && self.client_secret == other.client_secret + && self.warehouse_id == other.warehouse_id + && self.app_port == other.app_port + && self.host == other.host + && self.otel_endpoint == other.otel_endpoint + } + + fn __hash__(&self) -> u64 { + use std::hash::{Hash, Hasher}; + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + self.databricks_host.hash(&mut hasher); + self.app_port.hash(&mut hasher); + self.host.hash(&mut hasher); + hasher.finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serial_test::serial; + + #[test] + fn test_app_config_new() { + let config = AppConfig::new( + "https://example.databricks.com".into(), + Some("client-id".into()), + Some("client-secret".into()), + Some("warehouse-123".into()), + 8080, + "0.0.0.0".into(), + None, + ); + assert_eq!(config.databricks_host, "https://example.databricks.com"); + assert_eq!(config.client_id.as_deref(), Some("client-id")); + assert_eq!(config.warehouse_id.as_deref(), Some("warehouse-123")); + 
assert_eq!(config.app_port, 8080); + assert!(config.otel_endpoint.is_none()); + } + + #[test] + #[serial] + fn test_app_config_from_env() { + // Snapshot original values so we can restore them afterward. + let orig_host = env::var("DATABRICKS_HOST").ok(); + let orig_client_id = env::var("DATABRICKS_CLIENT_ID").ok(); + let orig_app_port = env::var("DATABRICKS_APP_PORT").ok(); + + // Helper to restore or remove an env var. + fn restore_env(key: &str, original: Option) { + match original { + Some(val) => env::set_var(key, val), + None => env::remove_var(key), + } + } + + // Clear to test missing-var error path. + env::remove_var("DATABRICKS_HOST"); + let result = AppConfig::from_env(); + assert!(result.is_err()); + + env::set_var("DATABRICKS_HOST", "https://test.databricks.com"); + env::set_var("DATABRICKS_CLIENT_ID", "cid"); + env::set_var("DATABRICKS_APP_PORT", "9090"); + let config = AppConfig::from_env().unwrap(); + assert_eq!(config.databricks_host, "https://test.databricks.com"); + assert_eq!(config.client_id.as_deref(), Some("cid")); + assert_eq!(config.app_port, 9090); + + // Restore original env state. + restore_env("DATABRICKS_HOST", orig_host); + restore_env("DATABRICKS_CLIENT_ID", orig_client_id); + restore_env("DATABRICKS_APP_PORT", orig_app_port); + } +} diff --git a/packages/appkit-rs/src/connectors/files.rs b/packages/appkit-rs/src/connectors/files.rs new file mode 100644 index 00000000..5deeef95 --- /dev/null +++ b/packages/appkit-rs/src/connectors/files.rs @@ -0,0 +1,737 @@ +use pyo3::prelude::*; +use reqwest::Client; +use serde::Deserialize; + +/// Maximum file read size in bytes (10 MB), matching TS FILES_MAX_READ_SIZE. 
+const FILES_MAX_READ_SIZE: usize = 10 * 1024 * 1024; + +// --------------------------------------------------------------------------- +// Internal serde types for Databricks Files API responses +// --------------------------------------------------------------------------- + +#[derive(Deserialize)] +struct DirectoryListResponse { + #[serde(default)] + contents: Vec, + next_page_token: Option, +} + +#[derive(Deserialize)] +struct DirectoryEntryRaw { + path: Option, + name: Option, + is_directory: Option, + file_size: Option, + last_modified: Option, +} + +// --------------------------------------------------------------------------- +// Python-facing response types (frozen / immutable) +// --------------------------------------------------------------------------- + +/// A single entry in a directory listing. +#[pyclass(frozen, module = "appkit")] +#[derive(Clone)] +pub struct FileDirectoryEntry { + #[pyo3(get)] + pub path: String, + #[pyo3(get)] + pub name: String, + #[pyo3(get)] + pub is_directory: bool, + #[pyo3(get)] + pub file_size: Option, + #[pyo3(get)] + pub last_modified: Option, +} + +#[pymethods] +impl FileDirectoryEntry { + fn __repr__(&self) -> String { + format!( + "FileDirectoryEntry(name={:?}, is_directory={})", + self.name, self.is_directory + ) + } + + fn __eq__(&self, other: &Self) -> bool { + self.path == other.path && self.name == other.name && self.is_directory == other.is_directory + } + + fn __hash__(&self) -> u64 { + use std::hash::{Hash, Hasher}; + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + self.path.hash(&mut hasher); + self.name.hash(&mut hasher); + self.is_directory.hash(&mut hasher); + hasher.finish() + } +} + +/// File metadata from a HEAD request. 
+#[pyclass(frozen, module = "appkit")] +#[derive(Clone)] +pub struct FileMetadata { + #[pyo3(get)] + pub content_length: Option, + #[pyo3(get)] + pub content_type: Option, + #[pyo3(get)] + pub last_modified: Option, +} + +#[pymethods] +impl FileMetadata { + fn __repr__(&self) -> String { + format!( + "FileMetadata(content_type={:?}, content_length={:?})", + self.content_type, self.content_length + ) + } +} + +/// File preview with optional text content. +#[pyclass(frozen, module = "appkit")] +#[derive(Clone)] +pub struct FilePreview { + #[pyo3(get)] + pub content_length: Option, + #[pyo3(get)] + pub content_type: Option, + #[pyo3(get)] + pub last_modified: Option, + #[pyo3(get)] + pub text_preview: Option, + #[pyo3(get)] + pub is_text: bool, + #[pyo3(get)] + pub is_image: bool, +} + +#[pymethods] +impl FilePreview { + fn __repr__(&self) -> String { + format!( + "FilePreview(content_type={:?}, is_text={}, is_image={})", + self.content_type, self.is_text, self.is_image + ) + } +} + +// --------------------------------------------------------------------------- +// Path validation helpers +// --------------------------------------------------------------------------- + +fn validate_and_resolve_path( + file_path: &str, + default_volume: Option<&str>, +) -> Result { + if file_path.len() > 4096 { + return Err(format!( + "Path exceeds maximum length of 4096 characters (got {}).", + file_path.len() + )); + } + if file_path.contains('\0') { + return Err("Path must not contain null bytes.".into()); + } + if file_path.split('/').any(|s| s == "..") { + return Err("Path traversal (\"../\") is not allowed.".into()); + } + + if file_path.starts_with('/') { + if !file_path.starts_with("/Volumes/") { + return Err( + "Absolute paths must start with \"/Volumes/\". 
\ + Unity Catalog volume paths follow the format: /Volumes////" + .into(), + ); + } + return Ok(file_path.to_string()); + } + + match default_volume { + Some(vol) => Ok(format!("{}/{}", vol, file_path)), + None => Err( + "Cannot resolve relative path: no default volume set. \ + Use an absolute path or set a default volume." + .into(), + ), + } +} + +/// Returns true if the content type is text-like. +fn is_text_content_type(content_type: &str) -> bool { + if content_type.starts_with("text/") { + return true; + } + const TEXT_KEYWORDS: &[&str] = &["json", "xml", "yaml", "sql", "javascript"]; + TEXT_KEYWORDS.iter().any(|kw| content_type.contains(kw)) +} + +// --------------------------------------------------------------------------- +// FilesConnector +// --------------------------------------------------------------------------- + +/// Databricks Files API connector. +/// +/// Provides file/directory operations against Unity Catalog volumes using +/// the REST API at `/api/2.0/fs/`. Auth tokens are passed per-call so both +/// service-principal and OBO flows are supported. +#[pyclass(module = "appkit")] +pub struct FilesConnector { + host: String, + default_volume: Option, + http: Client, +} + +#[pymethods] +impl FilesConnector { + #[new] + #[pyo3(signature = (host, *, default_volume = None))] + fn new(host: String, default_volume: Option) -> Self { + Self { + host: host.trim_end_matches('/').to_string(), + default_volume, + http: Client::new(), + } + } + + /// Validate and resolve a file path, applying the default volume if needed. + #[pyo3(signature = (file_path))] + fn resolve_path(&self, file_path: &str) -> PyResult { + validate_and_resolve_path(file_path, self.default_volume.as_deref()) + .map_err(pyo3::exceptions::PyValueError::new_err) + } + + /// List contents of a directory. 
+ #[pyo3(signature = (token, *, directory_path = None))] + fn list<'py>( + &self, + py: Python<'py>, + token: String, + directory_path: Option, + ) -> PyResult> { + let resolved = match directory_path { + Some(ref p) => validate_and_resolve_path(p, self.default_volume.as_deref()), + None => self + .default_volume + .clone() + .ok_or_else(|| "No directory path provided and no default volume set.".to_string()), + } + .map_err(pyo3::exceptions::PyValueError::new_err)?; + + let http = self.http.clone(); + let host = self.host.clone(); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let mut all_entries = Vec::new(); + let mut page_token: Option = None; + + loop { + let mut url = format!("{}/api/2.0/fs/directories{}", host, resolved); + if let Some(ref tok) = page_token { + url = format!("{}?page_token={}", url, tok); + } + + let resp = http + .get(&url) + .bearer_auth(&token) + .send() + .await + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(pyo3::exceptions::PyRuntimeError::new_err(format!( + "List directory failed ({status}): {body}" + ))); + } + + let data: DirectoryListResponse = resp + .json() + .await + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + + for raw in data.contents { + all_entries.push(FileDirectoryEntry { + path: raw.path.unwrap_or_default(), + name: raw.name.unwrap_or_default(), + is_directory: raw.is_directory.unwrap_or(false), + file_size: raw.file_size, + last_modified: raw.last_modified, + }); + } + + match data.next_page_token { + Some(tok) if !tok.is_empty() => page_token = Some(tok), + _ => break, + } + } + + Ok(all_entries) + }) + } + + /// Read a file as a UTF-8 string. 
+ #[pyo3(signature = (token, file_path, *, max_size = None))] + fn read<'py>( + &self, + py: Python<'py>, + token: String, + file_path: String, + max_size: Option, + ) -> PyResult> { + let resolved = validate_and_resolve_path(&file_path, self.default_volume.as_deref()) + .map_err(pyo3::exceptions::PyValueError::new_err)?; + let max = max_size.unwrap_or(FILES_MAX_READ_SIZE); + let http = self.http.clone(); + let host = self.host.clone(); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let url = format!("{}/api/2.0/fs/files{}", host, resolved); + let resp = http + .get(&url) + .bearer_auth(&token) + .send() + .await + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(pyo3::exceptions::PyRuntimeError::new_err(format!( + "Read file failed ({status}): {body}" + ))); + } + + let bytes = resp + .bytes() + .await + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + + if bytes.len() > max { + return Err(pyo3::exceptions::PyRuntimeError::new_err(format!( + "File exceeds maximum read size ({max} bytes). Use download() for large files." + ))); + } + + String::from_utf8(bytes.to_vec()).map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!( + "File is not valid UTF-8: {e}" + )) + }) + }) + } + + /// Download a file as raw bytes. 
+ #[pyo3(signature = (token, file_path))] + fn download<'py>( + &self, + py: Python<'py>, + token: String, + file_path: String, + ) -> PyResult> { + let resolved = validate_and_resolve_path(&file_path, self.default_volume.as_deref()) + .map_err(pyo3::exceptions::PyValueError::new_err)?; + let http = self.http.clone(); + let host = self.host.clone(); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let url = format!("{}/api/2.0/fs/files{}", host, resolved); + let resp = http + .get(&url) + .bearer_auth(&token) + .send() + .await + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(pyo3::exceptions::PyRuntimeError::new_err(format!( + "Download failed ({status}): {body}" + ))); + } + + let bytes = resp + .bytes() + .await + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + + Ok(bytes.to_vec()) + }) + } + + /// Check if a file exists. + #[pyo3(signature = (token, file_path))] + fn exists<'py>( + &self, + py: Python<'py>, + token: String, + file_path: String, + ) -> PyResult> { + let resolved = validate_and_resolve_path(&file_path, self.default_volume.as_deref()) + .map_err(pyo3::exceptions::PyValueError::new_err)?; + let http = self.http.clone(); + let host = self.host.clone(); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let url = format!("{}/api/2.0/fs/files{}", host, resolved); + let resp = http + .head(&url) + .bearer_auth(&token) + .send() + .await + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + + Ok(resp.status().is_success()) + }) + } + + /// Get file metadata via HEAD request. 
+ #[pyo3(signature = (token, file_path))] + fn metadata<'py>( + &self, + py: Python<'py>, + token: String, + file_path: String, + ) -> PyResult> { + let resolved = validate_and_resolve_path(&file_path, self.default_volume.as_deref()) + .map_err(pyo3::exceptions::PyValueError::new_err)?; + let http = self.http.clone(); + let host = self.host.clone(); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let url = format!("{}/api/2.0/fs/files{}", host, resolved); + let resp = http + .head(&url) + .bearer_auth(&token) + .send() + .await + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + + if !resp.status().is_success() { + let status = resp.status(); + return Err(pyo3::exceptions::PyRuntimeError::new_err(format!( + "Metadata request failed ({status})" + ))); + } + + let headers = resp.headers(); + let content_length = headers + .get("content-length") + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.parse::().ok()); + let content_type = headers + .get("content-type") + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + let last_modified = headers + .get("last-modified") + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + + Ok(FileMetadata { + content_length, + content_type, + last_modified, + }) + }) + } + + /// Upload file contents. Defaults to overwrite=True. 
+ #[pyo3(signature = (token, file_path, contents, *, overwrite = true))] + fn upload<'py>( + &self, + py: Python<'py>, + token: String, + file_path: String, + contents: Vec, + overwrite: bool, + ) -> PyResult> { + let resolved = validate_and_resolve_path(&file_path, self.default_volume.as_deref()) + .map_err(pyo3::exceptions::PyValueError::new_err)?; + let http = self.http.clone(); + let host = self.host.clone(); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let url = format!( + "{}/api/2.0/fs/files{}?overwrite={}", + host, resolved, overwrite + ); + let resp = http + .put(&url) + .bearer_auth(&token) + .header("Content-Type", "application/octet-stream") + .body(contents) + .send() + .await + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(pyo3::exceptions::PyRuntimeError::new_err(format!( + "Upload failed ({status}): {body}" + ))); + } + + Ok(()) + }) + } + + /// Create a directory. 
+    #[pyo3(signature = (token, directory_path))]
+    fn create_directory<'py>(
+        &self,
+        py: Python<'py>,
+        token: String,
+        directory_path: String,
+    ) -> PyResult<Bound<'py, PyAny>> {
+        // Directories share the same path-validation rules as files.
+        let path = validate_and_resolve_path(&directory_path, self.default_volume.as_deref())
+            .map_err(pyo3::exceptions::PyValueError::new_err)?;
+        let client = self.http.clone();
+        let host = self.host.clone();
+
+        pyo3_async_runtimes::tokio::future_into_py(py, async move {
+            // Creation is a bodyless PUT to the directories endpoint.
+            let url = format!("{}/api/2.0/fs/directories{}", host, path);
+            let response = client
+                .put(&url)
+                .bearer_auth(&token)
+                .send()
+                .await
+                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
+
+            let status = response.status();
+            if status.is_success() {
+                Ok(())
+            } else {
+                let body = response.text().await.unwrap_or_default();
+                Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
+                    "Create directory failed ({status}): {body}"
+                )))
+            }
+        })
+    }
+
+    /// Delete a file.
+    #[pyo3(signature = (token, file_path))]
+    fn delete<'py>(
+        &self,
+        py: Python<'py>,
+        token: String,
+        file_path: String,
+    ) -> PyResult<Bound<'py, PyAny>> {
+        let path = validate_and_resolve_path(&file_path, self.default_volume.as_deref())
+            .map_err(pyo3::exceptions::PyValueError::new_err)?;
+        let client = self.http.clone();
+        let host = self.host.clone();
+
+        pyo3_async_runtimes::tokio::future_into_py(py, async move {
+            let url = format!("{}/api/2.0/fs/files{}", host, path);
+            let response = client
+                .delete(&url)
+                .bearer_auth(&token)
+                .send()
+                .await
+                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
+
+            let status = response.status();
+            if status.is_success() {
+                Ok(())
+            } else {
+                let body = response.text().await.unwrap_or_default();
+                Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
+                    "Delete failed ({status}): {body}"
+                )))
+            }
+        })
+    }
+
+    /// Get file preview with optional text content.
+    #[pyo3(signature = (token, file_path, *, max_chars = 1024))]
+    fn preview<'py>(
+        &self,
+        py: Python<'py>,
+        token: String,
+        file_path: String,
+        max_chars: usize,
+    ) -> PyResult<Bound<'py, PyAny>> {
+        let resolved = validate_and_resolve_path(&file_path, self.default_volume.as_deref())
+            .map_err(pyo3::exceptions::PyValueError::new_err)?;
+        let http = self.http.clone();
+        let host = self.host.clone();
+
+        pyo3_async_runtimes::tokio::future_into_py(py, async move {
+            // 1. HEAD for metadata
+            let meta_url = format!("{}/api/2.0/fs/files{}", host, resolved);
+            let head_resp = http
+                .head(&meta_url)
+                .bearer_auth(&token)
+                .send()
+                .await
+                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
+
+            if !head_resp.status().is_success() {
+                let status = head_resp.status();
+                return Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
+                    "Preview metadata failed ({status})"
+                )));
+            }
+
+            let headers = head_resp.headers();
+            let content_length = headers
+                .get("content-length")
+                .and_then(|v| v.to_str().ok())
+                .and_then(|v| v.parse().ok());
+            let content_type = headers
+                .get("content-type")
+                .and_then(|v| v.to_str().ok())
+                .map(|s| s.to_string());
+            let last_modified = headers
+                .get("last-modified")
+                .and_then(|v| v.to_str().ok())
+                .map(|s| s.to_string());
+
+            let ct = content_type.as_deref().unwrap_or("");
+            let is_text = is_text_content_type(ct);
+            let is_image = ct.starts_with("image/");
+
+            if !is_text {
+                // Binary/unknown content: return metadata only, no text body.
+                return Ok(FilePreview {
+                    content_length,
+                    content_type,
+                    last_modified,
+                    text_preview: None,
+                    is_text: false,
+                    is_image,
+                });
+            }
+
+            // 2. GET the file for text preview
+            let get_resp = http
+                .get(&meta_url)
+                .bearer_auth(&token)
+                .send()
+                .await
+                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
+
+            if !get_resp.status().is_success() {
+                // Best-effort: metadata already succeeded, so return an empty
+                // preview rather than failing the whole call.
+                return Ok(FilePreview {
+                    content_length,
+                    content_type,
+                    last_modified,
+                    text_preview: Some(String::new()),
+                    is_text: true,
+                    is_image: false,
+                });
+            }
+
+            let bytes = get_resp
+                .bytes()
+                .await
+                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
+
+            let full_text = String::from_utf8_lossy(&bytes);
+            // FIX: truncate by *characters*, not bytes. The previous
+            // `full_text[..max_chars]` byte-slice panics when max_chars lands
+            // inside a multi-byte UTF-8 character, and silently truncated by
+            // bytes even though the parameter is named max_chars.
+            let preview = if full_text.chars().count() > max_chars {
+                full_text.chars().take(max_chars).collect()
+            } else {
+                full_text.into_owned()
+            };
+
+            Ok(FilePreview {
+                content_length,
+                content_type,
+                last_modified,
+                text_preview: Some(preview),
+                is_text: true,
+                is_image: false,
+            })
+        })
+    }
+
+    fn __repr__(&self) -> String {
+        format!(
+            "FilesConnector(host={:?}, default_volume={:?})",
+            self.host, self.default_volume
+        )
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_resolve_absolute_volumes_path() {
+        let result =
+            validate_and_resolve_path("/Volumes/catalog/schema/vol/file.txt", None);
+        assert_eq!(result.unwrap(), "/Volumes/catalog/schema/vol/file.txt");
+    }
+
+    #[test]
+    fn test_resolve_absolute_non_volumes_rejected() {
+        let result = validate_and_resolve_path("/etc/passwd", None);
+        assert!(result.is_err());
+        assert!(result.unwrap_err().contains("/Volumes/"));
+    }
+
+    #[test]
+    fn test_resolve_relative_with_default_volume() {
+        let result = validate_and_resolve_path(
+            "subdir/file.txt",
+            Some("/Volumes/catalog/schema/vol"),
+        );
+        assert_eq!(
+            result.unwrap(),
+            "/Volumes/catalog/schema/vol/subdir/file.txt"
+        );
+    }
+
+    #[test]
+    fn test_resolve_relative_without_default_volume() {
+        let result = validate_and_resolve_path("file.txt", None);
+        assert!(result.is_err());
+        assert!(result.unwrap_err().contains("default volume"));
+    }
+
+    #[test]
+    fn test_path_traversal_rejected() {
+        let result = validate_and_resolve_path(
+            "/Volumes/catalog/schema/vol/../../../etc/passwd",
+            None,
+        );
+        assert!(result.is_err());
+        assert!(result.unwrap_err().contains("traversal"));
+    }
+
+    #[test]
+    fn test_null_byte_rejected() {
+        let result = validate_and_resolve_path("/Volumes/cat/sch/vol/f\0ile.txt", None);
+        assert!(result.is_err());
+        assert!(result.unwrap_err().contains("null"));
+    }
+
+    #[test]
+    fn test_path_too_long() {
+        let long_path = format!("/Volumes/{}", "a".repeat(4097));
+        let result = validate_and_resolve_path(&long_path, None);
+        assert!(result.is_err());
+        assert!(result.unwrap_err().contains("4096"));
+    }
+
+    #[test]
+    fn test_is_text_content_type() {
+        assert!(is_text_content_type("text/plain"));
+        assert!(is_text_content_type("text/html"));
+        assert!(is_text_content_type("application/json"));
+        assert!(is_text_content_type("application/xml"));
+        assert!(is_text_content_type("application/x-yaml"));
+        assert!(is_text_content_type("application/sql"));
+        assert!(is_text_content_type("application/javascript"));
+        assert!(!is_text_content_type("image/png"));
+        assert!(!is_text_content_type("application/octet-stream"));
+    }
+}
diff --git a/packages/appkit-rs/src/connectors/genie.rs b/packages/appkit-rs/src/connectors/genie.rs
new file mode 100644
index 00000000..2377e6fb
--- /dev/null
+++ b/packages/appkit-rs/src/connectors/genie.rs
@@ -0,0 +1,885 @@
+use pyo3::prelude::*;
+use reqwest::Client;
+use serde::{Deserialize, Serialize};
+
+// ---------------------------------------------------------------------------
+// Defaults matching TS genieConnectorDefaults
+// ---------------------------------------------------------------------------
+
+const DEFAULT_TIMEOUT_MS: u64 = 120_000;
+const DEFAULT_MAX_MESSAGES: usize = 200;
+const DEFAULT_INITIAL_PAGE_SIZE: u32 = 20;
+const DEFAULT_PAGE_SIZE: u32 = 100;
+const DEFAULT_POLL_INTERVAL_MS: u64 = 3_000;
+
+// ---------------------------------------------------------------------------
+// Internal serde 
types for Databricks Genie API
+// ---------------------------------------------------------------------------
+
+/// POST body for `start-conversation`.
+#[derive(Serialize)]
+struct StartConversationBody {
+    content: String,
+}
+
+/// POST body for adding a message to an existing conversation.
+#[derive(Serialize)]
+struct CreateMessageBody {
+    content: String,
+}
+
+#[derive(Deserialize, Debug)]
+struct StartConversationResponse {
+    conversation_id: Option<String>,
+    message_id: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+struct CreateMessageResponse {
+    message_id: Option<String>,
+}
+
+/// Raw Genie message as returned by the REST API; every field is optional
+/// so partial payloads never fail deserialization.
+#[derive(Deserialize, Debug, Clone)]
+struct GenieMessageRaw {
+    message_id: Option<String>,
+    conversation_id: Option<String>,
+    space_id: Option<String>,
+    status: Option<String>,
+    content: Option<String>,
+    #[serde(default)]
+    attachments: Option<Vec<GenieAttachmentRaw>>,
+    error: Option<GenieErrorRaw>,
+}
+
+#[derive(Deserialize, Debug, Clone)]
+struct GenieAttachmentRaw {
+    attachment_id: Option<String>,
+    query: Option<GenieQueryRaw>,
+    text: Option<GenieTextRaw>,
+    suggested_questions: Option<GenieQuestionsRaw>,
+}
+
+#[derive(Deserialize, Debug, Clone)]
+struct GenieQueryRaw {
+    title: Option<String>,
+    description: Option<String>,
+    query: Option<String>,
+    statement_id: Option<String>,
+}
+
+#[derive(Deserialize, Debug, Clone)]
+struct GenieTextRaw {
+    content: Option<String>,
+}
+
+#[derive(Deserialize, Debug, Clone)]
+struct GenieQuestionsRaw {
+    questions: Option<Vec<String>>,
+}
+
+#[derive(Deserialize, Debug, Clone)]
+struct GenieErrorRaw {
+    error: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+struct ListMessagesResponse {
+    #[serde(default)]
+    messages: Vec<GenieMessageRaw>,
+    next_page_token: Option<String>,
+}
+
+/// Wrapper for the attachment query-result endpoint; the inner payload is
+/// kept as opaque JSON and re-serialized for Python.
+#[derive(Deserialize, Debug)]
+struct QueryResultWrapper {
+    statement_response: Option<serde_json::Value>,
+}
+
+// ---------------------------------------------------------------------------
+// Python-facing response types (frozen / immutable)
+// ---------------------------------------------------------------------------
+
+/// Genie query attachment metadata. 
+#[pyclass(frozen, module = "appkit")]
+#[derive(Clone)]
+pub struct GenieAttachment {
+    #[pyo3(get)]
+    pub attachment_id: Option<String>,
+    #[pyo3(get)]
+    pub query_title: Option<String>,
+    #[pyo3(get)]
+    pub query_description: Option<String>,
+    #[pyo3(get)]
+    pub query_sql: Option<String>,
+    #[pyo3(get)]
+    pub query_statement_id: Option<String>,
+    #[pyo3(get)]
+    pub text_content: Option<String>,
+    #[pyo3(get)]
+    pub suggested_questions: Option<Vec<String>>,
+}
+
+#[pymethods]
+impl GenieAttachment {
+    fn __repr__(&self) -> String {
+        format!(
+            "GenieAttachment(attachment_id={:?}, statement_id={:?})",
+            self.attachment_id, self.query_statement_id
+        )
+    }
+}
+
+impl GenieAttachment {
+    /// Flatten the nested raw attachment into the Python-facing shape.
+    fn from_raw(raw: &GenieAttachmentRaw) -> Self {
+        let query = raw.query.as_ref();
+        Self {
+            attachment_id: raw.attachment_id.clone(),
+            query_title: query.and_then(|q| q.title.clone()),
+            query_description: query.and_then(|q| q.description.clone()),
+            query_sql: query.and_then(|q| q.query.clone()),
+            query_statement_id: query.and_then(|q| q.statement_id.clone()),
+            text_content: raw.text.as_ref().and_then(|t| t.content.clone()),
+            suggested_questions: raw
+                .suggested_questions
+                .as_ref()
+                .and_then(|sq| sq.questions.clone()),
+        }
+    }
+}
+
+/// Genie message response. 
+#[pyclass(frozen, module = "appkit")] +#[derive(Clone)] +pub struct GenieMessage { + #[pyo3(get)] + pub message_id: String, + #[pyo3(get)] + pub conversation_id: String, + #[pyo3(get)] + pub space_id: String, + #[pyo3(get)] + pub status: String, + #[pyo3(get)] + pub content: String, + #[pyo3(get)] + pub attachments: Vec, + #[pyo3(get)] + pub error: Option, +} + +#[pymethods] +impl GenieMessage { + fn __repr__(&self) -> String { + format!( + "GenieMessage(message_id={:?}, status={:?})", + self.message_id, self.status + ) + } +} + +impl GenieMessage { + fn from_raw(raw: &GenieMessageRaw) -> Self { + Self { + message_id: raw.message_id.clone().unwrap_or_default(), + conversation_id: raw.conversation_id.clone().unwrap_or_default(), + space_id: raw.space_id.clone().unwrap_or_default(), + status: raw.status.clone().unwrap_or_else(|| "COMPLETED".into()), + content: raw.content.clone().unwrap_or_default(), + attachments: raw + .attachments + .as_ref() + .map(|atts| atts.iter().map(GenieAttachment::from_raw).collect()) + .unwrap_or_default(), + error: raw.error.as_ref().and_then(|e| e.error.clone()), + } + } +} + +/// Full conversation history. +#[pyclass(frozen, module = "appkit")] +#[derive(Clone)] +pub struct GenieConversationHistory { + #[pyo3(get)] + pub conversation_id: String, + #[pyo3(get)] + pub space_id: String, + #[pyo3(get)] + pub messages: Vec, +} + +#[pymethods] +impl GenieConversationHistory { + fn __repr__(&self) -> String { + format!( + "GenieConversationHistory(conversation_id={:?}, message_count={})", + self.conversation_id, + self.messages.len() + ) + } + + fn __len__(&self) -> usize { + self.messages.len() + } +} + +/// Query result from a Genie attachment (statement_response). +#[pyclass(frozen, module = "appkit")] +#[derive(Clone)] +pub struct GenieQueryResult { + /// Raw JSON of the statement_response. 
+ #[pyo3(get)] + pub data: String, +} + +#[pymethods] +impl GenieQueryResult { + fn __repr__(&self) -> String { + let len = self.data.len().min(80); + format!("GenieQueryResult(data={:?}...)", &self.data[..len]) + } +} + +// --------------------------------------------------------------------------- +// Error classification — mirrors TS classifyGenieError +// --------------------------------------------------------------------------- + +fn classify_error(msg: &str) -> String { + if msg.contains("RESOURCE_DOES_NOT_EXIST") { + return "You don't have access to this Genie Space.".into(); + } + if msg.contains("failed to reach COMPLETED state") && msg.contains("FAILED") { + return "You may not have access to the data tables. Please verify your table permissions." + .into(); + } + if msg.is_empty() { + return "Genie request failed".into(); + } + msg.to_string() +} + +// --------------------------------------------------------------------------- +// GenieConnector +// --------------------------------------------------------------------------- + +/// Databricks Genie connector. +/// +/// Provides conversation/message operations, attachment fetching, and +/// polling for Genie AI query results via the REST API. +#[pyclass(module = "appkit")] +pub struct GenieConnector { + host: String, + timeout_ms: u64, + max_messages: usize, + http: Client, +} + +impl GenieConnector { + fn base_url(&self, space_id: &str) -> String { + format!("{}/api/2.0/genie/spaces/{}", self.host, space_id) + } + + /// Poll GET message until terminal state (COMPLETED/FAILED). 
+    async fn poll_message(
+        &self,
+        token: &str,
+        space_id: &str,
+        conversation_id: &str,
+        message_id: &str,
+        timeout_ms: u64,
+    ) -> Result<GenieMessage, String> {
+        let start = std::time::Instant::now();
+        let url = format!(
+            "{}/conversations/{}/messages/{}",
+            self.base_url(space_id),
+            conversation_id,
+            message_id,
+        );
+
+        loop {
+            // FIX: check the message state *before* sleeping. The previous
+            // sleep-first loop added a fixed poll-interval latency (3s) to
+            // every call, even when the message was already terminal.
+            let resp = self
+                .http
+                .get(&url)
+                .bearer_auth(token)
+                .send()
+                .await
+                .map_err(|e| format!("Poll message failed: {e}"))?;
+
+            if !resp.status().is_success() {
+                let status = resp.status();
+                let body = resp.text().await.unwrap_or_default();
+                return Err(format!("Poll message ({status}): {body}"));
+            }
+
+            let raw: GenieMessageRaw = resp
+                .json()
+                .await
+                .map_err(|e| format!("Parse message failed: {e}"))?;
+
+            if matches!(raw.status.as_deref(), Some("COMPLETED") | Some("FAILED")) {
+                return Ok(GenieMessage::from_raw(&raw));
+            }
+
+            let elapsed = start.elapsed().as_millis() as u64;
+            if elapsed > timeout_ms {
+                return Err(format!(
+                    "Genie message polling timed out after {timeout_ms}ms"
+                ));
+            }
+
+            tokio::time::sleep(std::time::Duration::from_millis(DEFAULT_POLL_INTERVAL_MS)).await;
+        }
+    }
+
+    /// Fetch a page of messages (internal). 
+ async fn list_messages_internal( + &self, + token: &str, + space_id: &str, + conversation_id: &str, + page_size: u32, + page_token: Option<&str>, + ) -> Result<(Vec, Option), String> { + let mut url = format!( + "{}/conversations/{}/messages?page_size={}", + self.base_url(space_id), + conversation_id, + page_size, + ); + if let Some(pt) = page_token { + url = format!("{}&page_token={}", url, pt); + } + + let resp = self + .http + .get(&url) + .bearer_auth(token) + .send() + .await + .map_err(|e| format!("List messages failed: {e}"))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(format!("List messages ({status}): {body}")); + } + + let data: ListMessagesResponse = resp + .json() + .await + .map_err(|e| format!("Parse list response: {e}"))?; + + // Reverse to chronological order (API returns newest first) + let messages: Vec = data + .messages + .iter() + .rev() + .map(GenieMessage::from_raw) + .collect(); + + let next = data + .next_page_token + .filter(|t| !t.is_empty()); + + Ok((messages, next)) + } +} + +#[pymethods] +impl GenieConnector { + #[new] + #[pyo3(signature = (host, *, timeout_ms = None, max_messages = None))] + fn new(host: String, timeout_ms: Option, max_messages: Option) -> Self { + Self { + host: host.trim_end_matches('/').to_string(), + timeout_ms: timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS), + max_messages: max_messages.unwrap_or(DEFAULT_MAX_MESSAGES), + http: Client::new(), + } + } + + /// Start a new conversation or add a message to an existing one. + /// Returns (conversation_id, message_id). 
+ #[pyo3(signature = (token, space_id, content, *, conversation_id = None))] + fn start_message<'py>( + &self, + py: Python<'py>, + token: String, + space_id: String, + content: String, + conversation_id: Option, + ) -> PyResult> { + let http = self.http.clone(); + let base = self.base_url(&space_id); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + if let Some(ref conv_id) = conversation_id { + // Add message to existing conversation + let url = format!("{}/conversations/{}/messages", base, conv_id); + let resp = http + .post(&url) + .bearer_auth(&token) + .json(&CreateMessageBody { + content: content.clone(), + }) + .send() + .await + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!( + "Create message failed: {e}" + )) + })?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(pyo3::exceptions::PyRuntimeError::new_err( + classify_error(&format!("Create message ({status}): {body}")), + )); + } + + let data: CreateMessageResponse = resp.json().await.map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("Parse response: {e}")) + })?; + + Ok(( + conv_id.clone(), + data.message_id.unwrap_or_default(), + )) + } else { + // Start new conversation + let url = format!("{}/start-conversation", base); + let resp = http + .post(&url) + .bearer_auth(&token) + .json(&StartConversationBody { + content: content.clone(), + }) + .send() + .await + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!( + "Start conversation failed: {e}" + )) + })?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(pyo3::exceptions::PyRuntimeError::new_err( + classify_error(&format!("Start conversation ({status}): {body}")), + )); + } + + let data: StartConversationResponse = resp.json().await.map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("Parse response: {e}")) 
+ })?; + + Ok(( + data.conversation_id.unwrap_or_default(), + data.message_id.unwrap_or_default(), + )) + } + }) + } + + /// Send a message and wait for the completed response. + #[pyo3(signature = (token, space_id, content, *, conversation_id = None, timeout_ms = None))] + fn send_message<'py>( + &self, + py: Python<'py>, + token: String, + space_id: String, + content: String, + conversation_id: Option, + timeout_ms: Option, + ) -> PyResult> { + let http = self.http.clone(); + let base = self.base_url(&space_id); + let timeout = timeout_ms.unwrap_or(self.timeout_ms); + let host = self.host.clone(); + let connector_timeout = self.timeout_ms; + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + // Start message + let (conv_id, msg_id) = if let Some(ref existing_conv) = conversation_id { + let url = format!("{}/conversations/{}/messages", base, existing_conv); + let resp = http + .post(&url) + .bearer_auth(&token) + .json(&CreateMessageBody { + content: content.clone(), + }) + .send() + .await + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(classify_error(&e.to_string())) + })?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(pyo3::exceptions::PyRuntimeError::new_err(classify_error( + &format!("{status}: {body}"), + ))); + } + + let data: CreateMessageResponse = resp.json().await.map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(e.to_string()) + })?; + (existing_conv.clone(), data.message_id.unwrap_or_default()) + } else { + let url = format!("{}/start-conversation", base); + let resp = http + .post(&url) + .bearer_auth(&token) + .json(&StartConversationBody { + content: content.clone(), + }) + .send() + .await + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(classify_error(&e.to_string())) + })?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return 
Err(pyo3::exceptions::PyRuntimeError::new_err(classify_error( + &format!("{status}: {body}"), + ))); + } + + let data: StartConversationResponse = resp.json().await.map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(e.to_string()) + })?; + ( + data.conversation_id.unwrap_or_default(), + data.message_id.unwrap_or_default(), + ) + }; + + // Create a temporary connector for polling + let tmp = GenieConnector { + host: host.clone(), + timeout_ms: connector_timeout, + max_messages: 0, + http: http.clone(), + }; + + // Poll until completed + tmp.poll_message(&token, &space_id, &conv_id, &msg_id, timeout) + .await + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(classify_error(&e))) + }) + } + + /// Get a single message by ID, polling until terminal state. + #[pyo3(signature = (token, space_id, conversation_id, message_id, *, timeout_ms = None))] + fn get_message<'py>( + &self, + py: Python<'py>, + token: String, + space_id: String, + conversation_id: String, + message_id: String, + timeout_ms: Option, + ) -> PyResult> { + let http = self.http.clone(); + let base = self.base_url(&space_id); + let timeout = timeout_ms.unwrap_or(self.timeout_ms); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + // First check current state + let url = format!( + "{}/conversations/{}/messages/{}", + base, conversation_id, message_id + ); + let resp = http + .get(&url) + .bearer_auth(&token) + .send() + .await + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(pyo3::exceptions::PyRuntimeError::new_err(classify_error( + &format!("{status}: {body}"), + ))); + } + + let raw: GenieMessageRaw = resp + .json() + .await + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + + let state = raw.status.as_deref().unwrap_or(""); + if state == "COMPLETED" || state == "FAILED" { + return 
Ok(GenieMessage::from_raw(&raw)); + } + + // Poll until terminal + let start = std::time::Instant::now(); + loop { + let elapsed = start.elapsed().as_millis() as u64; + if elapsed > timeout { + return Err(pyo3::exceptions::PyRuntimeError::new_err(format!( + "Message polling timed out after {timeout}ms" + ))); + } + + tokio::time::sleep(std::time::Duration::from_millis(DEFAULT_POLL_INTERVAL_MS)) + .await; + + let resp = http + .get(&url) + .bearer_auth(&token) + .send() + .await + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(pyo3::exceptions::PyRuntimeError::new_err(classify_error( + &format!("{status}: {body}"), + ))); + } + + let raw: GenieMessageRaw = resp + .json() + .await + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + + let state = raw.status.as_deref().unwrap_or(""); + if state == "COMPLETED" || state == "FAILED" { + return Ok(GenieMessage::from_raw(&raw)); + } + } + }) + } + + /// List messages in a conversation (paginated). 
+ #[pyo3(signature = (token, space_id, conversation_id, *, page_size = None, page_token = None))] + fn list_messages<'py>( + &self, + py: Python<'py>, + token: String, + space_id: String, + conversation_id: String, + page_size: Option, + page_token: Option, + ) -> PyResult> { + let http = self.http.clone(); + let host = self.host.clone(); + let timeout = self.timeout_ms; + let max_msgs = self.max_messages; + let ps = page_size.unwrap_or(DEFAULT_INITIAL_PAGE_SIZE); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let tmp = GenieConnector { + host, + timeout_ms: timeout, + max_messages: max_msgs, + http, + }; + + let (messages, next_token) = tmp + .list_messages_internal(&token, &space_id, &conversation_id, ps, page_token.as_deref()) + .await + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(classify_error(&e)))?; + + Ok((messages, next_token)) + }) + } + + /// Get the query result for a message attachment. + #[pyo3(signature = (token, space_id, conversation_id, message_id, attachment_id))] + fn get_query_result<'py>( + &self, + py: Python<'py>, + token: String, + space_id: String, + conversation_id: String, + message_id: String, + attachment_id: String, + ) -> PyResult> { + let http = self.http.clone(); + let base = self.base_url(&space_id); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let url = format!( + "{}/conversations/{}/messages/{}/attachments/{}/query-result", + base, conversation_id, message_id, attachment_id, + ); + + let resp = http + .get(&url) + .bearer_auth(&token) + .send() + .await + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(pyo3::exceptions::PyRuntimeError::new_err(format!( + "Failed to fetch query result for attachment {attachment_id} ({status}): {body}" + ))); + } + + let wrapper: QueryResultWrapper = resp + .json() + .await + .map_err(|e| 
pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + + let data = wrapper + .statement_response + .map(|v| serde_json::to_string(&v).unwrap_or_else(|_| "null".into())) + .unwrap_or_else(|| "null".into()); + + Ok(GenieQueryResult { data }) + }) + } + + /// Fetch full conversation history (all pages up to max_messages). + #[pyo3(signature = (token, space_id, conversation_id))] + fn get_conversation<'py>( + &self, + py: Python<'py>, + token: String, + space_id: String, + conversation_id: String, + ) -> PyResult> { + let http = self.http.clone(); + let host = self.host.clone(); + let timeout = self.timeout_ms; + let max_msgs = self.max_messages; + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let tmp = GenieConnector { + host, + timeout_ms: timeout, + max_messages: max_msgs, + http, + }; + + let mut all_messages = Vec::new(); + let mut page_token: Option = None; + + loop { + let (messages, next) = tmp + .list_messages_internal( + &token, + &space_id, + &conversation_id, + DEFAULT_PAGE_SIZE, + page_token.as_deref(), + ) + .await + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(classify_error(&e)))?; + + all_messages.extend(messages); + + match next { + Some(tok) if all_messages.len() < max_msgs => page_token = Some(tok), + _ => break, + } + } + + all_messages.truncate(max_msgs); + + Ok(GenieConversationHistory { + conversation_id, + space_id, + messages: all_messages, + }) + }) + } + + fn __repr__(&self) -> String { + format!( + "GenieConnector(host={:?}, timeout_ms={}, max_messages={})", + self.host, self.timeout_ms, self.max_messages + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_classify_error_resource_not_found() { + let msg = classify_error("Something RESOURCE_DOES_NOT_EXIST happened"); + assert!(msg.contains("access to this Genie Space")); + } + + #[test] + fn test_classify_error_table_permissions() { + let msg = classify_error("failed to reach COMPLETED state: FAILED"); + 
assert!(msg.contains("table permissions")); + } + + #[test] + fn test_classify_error_empty() { + let msg = classify_error(""); + assert_eq!(msg, "Genie request failed"); + } + + #[test] + fn test_classify_error_passthrough() { + let msg = classify_error("Some other error"); + assert_eq!(msg, "Some other error"); + } + + #[test] + fn test_message_from_raw() { + let raw = GenieMessageRaw { + message_id: Some("msg-1".into()), + conversation_id: Some("conv-1".into()), + space_id: Some("space-1".into()), + status: Some("COMPLETED".into()), + content: Some("Hello".into()), + attachments: Some(vec![GenieAttachmentRaw { + attachment_id: Some("att-1".into()), + query: Some(GenieQueryRaw { + title: Some("My Query".into()), + description: None, + query: Some("SELECT 1".into()), + statement_id: Some("stmt-1".into()), + }), + text: None, + suggested_questions: None, + }]), + error: None, + }; + + let msg = GenieMessage::from_raw(&raw); + assert_eq!(msg.message_id, "msg-1"); + assert_eq!(msg.status, "COMPLETED"); + assert_eq!(msg.attachments.len(), 1); + assert_eq!(msg.attachments[0].query_sql.as_deref(), Some("SELECT 1")); + assert_eq!( + msg.attachments[0].query_statement_id.as_deref(), + Some("stmt-1") + ); + } + + #[test] + fn test_message_from_raw_defaults() { + let raw = GenieMessageRaw { + message_id: None, + conversation_id: None, + space_id: None, + status: None, + content: None, + attachments: None, + error: None, + }; + + let msg = GenieMessage::from_raw(&raw); + assert_eq!(msg.message_id, ""); + assert_eq!(msg.status, "COMPLETED"); + assert!(msg.attachments.is_empty()); + } +} diff --git a/packages/appkit-rs/src/connectors/lakebase.rs b/packages/appkit-rs/src/connectors/lakebase.rs new file mode 100644 index 00000000..fbe5255e --- /dev/null +++ b/packages/appkit-rs/src/connectors/lakebase.rs @@ -0,0 +1,354 @@ +use pyo3::prelude::*; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use std::env; + +// 
--------------------------------------------------------------------------- +// Internal serde types for Databricks Lakebase API +// --------------------------------------------------------------------------- + +#[derive(Serialize)] +struct GenerateCredentialRequest { + instance_names: Vec, + request_id: String, +} + +#[derive(Deserialize, Debug)] +struct GenerateCredentialResponse { + token: Option, + expiration_time: Option, +} + +// --------------------------------------------------------------------------- +// Python-facing response types (frozen / immutable) +// --------------------------------------------------------------------------- + +/// Generated database credential for Lakebase access. +#[pyclass(frozen, module = "appkit")] +#[derive(Clone)] +pub struct DatabaseCredential { + /// OAuth token for database authentication. + #[pyo3(get)] + pub token: String, + /// ISO 8601 expiration time. + #[pyo3(get)] + pub expiration_time: String, +} + +#[pymethods] +impl DatabaseCredential { + fn __repr__(&self) -> String { + format!( + "DatabaseCredential(expiration_time={:?})", + self.expiration_time + ) + } + + fn __eq__(&self, other: &Self) -> bool { + self.token == other.token && self.expiration_time == other.expiration_time + } + + fn __hash__(&self) -> u64 { + use std::hash::{Hash, Hasher}; + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + self.token.hash(&mut hasher); + self.expiration_time.hash(&mut hasher); + hasher.finish() + } +} + +/// PostgreSQL connection configuration for Lakebase. +#[pyclass(frozen, module = "appkit")] +#[derive(Clone)] +pub struct LakebasePgConfig { + #[pyo3(get)] + pub host: String, + #[pyo3(get)] + pub database: String, + #[pyo3(get)] + pub port: u16, + #[pyo3(get)] + pub ssl_mode: String, + #[pyo3(get)] + pub app_name: Option, +} + +#[pymethods] +impl LakebasePgConfig { + /// Build a Lakebase PG config from environment variables. 
+    ///
+    /// Reads: `PGHOST`, `PGDATABASE`, `PGPORT`, `PGSSLMODE`, `PGAPPNAME`,
+    /// and `LAKEBASE_ENDPOINT` (alternative to PGHOST).
+    #[new]
+    #[pyo3(signature = (*, host = None, database = None, port = None, ssl_mode = None, app_name = None))]
+    fn new(
+        host: Option<String>,
+        database: Option<String>,
+        port: Option<u16>,
+        ssl_mode: Option<String>,
+        app_name: Option<String>,
+    ) -> PyResult<Self> {
+        // Explicit keyword arguments always win over environment variables.
+        let resolved_host = host
+            .or_else(|| non_empty_env("PGHOST"))
+            .or_else(|| non_empty_env("LAKEBASE_ENDPOINT"))
+            .ok_or_else(|| {
+                pyo3::exceptions::PyValueError::new_err(
+                    "Lakebase host is required. Set PGHOST or LAKEBASE_ENDPOINT, or pass host=.",
+                )
+            })?;
+
+        let resolved_db = database
+            .or_else(|| non_empty_env("PGDATABASE"))
+            .ok_or_else(|| {
+                pyo3::exceptions::PyValueError::new_err(
+                    "Lakebase database is required. Set PGDATABASE or pass database=.",
+                )
+            })?;
+
+        // Unset or unparseable PGPORT falls back to the PostgreSQL default.
+        let resolved_port = port.unwrap_or_else(|| {
+            non_empty_env("PGPORT")
+                .and_then(|v| v.parse().ok())
+                .unwrap_or(5432)
+        });
+
+        let resolved_ssl = ssl_mode
+            .or_else(|| non_empty_env("PGSSLMODE"))
+            .unwrap_or_else(|| "require".into());
+
+        let resolved_app = app_name.or_else(|| non_empty_env("PGAPPNAME"));
+
+        Ok(Self {
+            host: resolved_host,
+            database: resolved_db,
+            port: resolved_port,
+            ssl_mode: resolved_ssl,
+            app_name: resolved_app,
+        })
+    }
+
+    /// Build from environment variables only. 
+ #[staticmethod] + fn from_env() -> PyResult { + Self::new(None, None, None, None, None) + } + + fn __repr__(&self) -> String { + format!( + "LakebasePgConfig(host={:?}, database={:?}, port={})", + self.host, self.database, self.port + ) + } + + fn __eq__(&self, other: &Self) -> bool { + self.host == other.host + && self.database == other.database + && self.port == other.port + && self.ssl_mode == other.ssl_mode + } + + fn __hash__(&self) -> u64 { + use std::hash::{Hash, Hasher}; + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + self.host.hash(&mut hasher); + self.database.hash(&mut hasher); + self.port.hash(&mut hasher); + hasher.finish() + } +} + +fn non_empty_env(key: &str) -> Option { + env::var(key).ok().filter(|v| !v.is_empty()) +} + +// --------------------------------------------------------------------------- +// LakebaseConnector +// --------------------------------------------------------------------------- + +/// Databricks Lakebase connector. +/// +/// Provides credential generation for database access via the REST API +/// at `/api/2.0/database/credentials`, and pool-config retrieval from +/// environment variables. +#[pyclass(module = "appkit")] +pub struct LakebaseConnector { + host: String, + http: Client, +} + +#[pymethods] +impl LakebaseConnector { + #[new] + #[pyo3(signature = (host))] + fn new(host: String) -> Self { + Self { + host: host.trim_end_matches('/').to_string(), + http: Client::new(), + } + } + + /// Generate a database credential for the given Lakebase instance(s). + /// + /// Calls POST `/api/2.0/database/credentials` with the service-principal + /// token. Returns a `DatabaseCredential` containing the temporary + /// password token and its expiration time. 
+ #[pyo3(signature = (token, instance_names, *, request_id = None))] + fn generate_credential<'py>( + &self, + py: Python<'py>, + token: String, + instance_names: Vec, + request_id: Option, + ) -> PyResult> { + let http = self.http.clone(); + let host = self.host.clone(); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let rid = request_id.unwrap_or_else(|| { + // Simple UUID v4-like random ID + use rand::Rng; + let mut rng = rand::thread_rng(); + let bytes: [u8; 16] = rng.gen(); + format!( + "{:08x}-{:04x}-4{:03x}-{:04x}-{:012x}", + u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]), + u16::from_be_bytes([bytes[4], bytes[5]]), + u16::from_be_bytes([bytes[6], bytes[7]]) & 0x0FFF, + (u16::from_be_bytes([bytes[8], bytes[9]]) & 0x3FFF) | 0x8000, + u64::from_be_bytes([ + 0, 0, bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15] + ]), + ) + }); + + let url = format!("{}/api/2.0/database/credentials", host); + let body = GenerateCredentialRequest { + instance_names, + request_id: rid, + }; + + let resp = http + .post(&url) + .bearer_auth(&token) + .json(&body) + .send() + .await + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!( + "Credential generation request failed: {e}" + )) + })?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + return Err(pyo3::exceptions::PyRuntimeError::new_err(format!( + "Credential generation failed ({status}): {text}" + ))); + } + + let data: GenerateCredentialResponse = resp.json().await.map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!( + "Failed to parse credential response: {e}" + )) + })?; + + let token_val = data.token.filter(|t| !t.is_empty()).ok_or_else(|| { + pyo3::exceptions::PyRuntimeError::new_err( + "Credential response missing token", + ) + })?; + + let expiration = data + .expiration_time + .filter(|t| !t.is_empty()) + .ok_or_else(|| { + pyo3::exceptions::PyRuntimeError::new_err( + 
"Credential response missing expiration_time",
+                    )
+                })?;
+
+            Ok(DatabaseCredential {
+                token: token_val,
+                expiration_time: expiration,
+            })
+        })
+    }
+
+    fn __repr__(&self) -> String {
+        format!("LakebaseConnector(host={:?})", self.host)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use serial_test::serial;
+
+    #[test]
+    fn test_pg_config_explicit() {
+        let config = LakebasePgConfig::new(
+            Some("host.example.com".into()),
+            Some("mydb".into()),
+            Some(5433),
+            Some("prefer".into()),
+            Some("myapp".into()),
+        )
+        .unwrap();
+
+        assert_eq!(config.host, "host.example.com");
+        assert_eq!(config.database, "mydb");
+        assert_eq!(config.port, 5433);
+        assert_eq!(config.ssl_mode, "prefer");
+        assert_eq!(config.app_name.as_deref(), Some("myapp"));
+    }
+
+    // Env-mutating tests are #[serial] because process env is shared state.
+    #[test]
+    #[serial]
+    fn test_pg_config_defaults() {
+        // Set required env vars
+        env::set_var("PGHOST", "env-host.example.com");
+        env::set_var("PGDATABASE", "envdb");
+        env::remove_var("PGPORT");
+        env::remove_var("PGSSLMODE");
+        env::remove_var("PGAPPNAME");
+        env::remove_var("LAKEBASE_ENDPOINT");
+
+        let config = LakebasePgConfig::from_env().unwrap();
+        assert_eq!(config.host, "env-host.example.com");
+        assert_eq!(config.database, "envdb");
+        assert_eq!(config.port, 5432);
+        assert_eq!(config.ssl_mode, "require");
+        assert!(config.app_name.is_none());
+
+        // Cleanup
+        env::remove_var("PGHOST");
+        env::remove_var("PGDATABASE");
+    }
+
+    #[test]
+    #[serial]
+    fn test_pg_config_missing_host() {
+        env::remove_var("PGHOST");
+        env::remove_var("LAKEBASE_ENDPOINT");
+        env::set_var("PGDATABASE", "db");
+
+        let result = LakebasePgConfig::new(None, None, None, None, None);
+        assert!(result.is_err());
+
+        env::remove_var("PGDATABASE");
+    }
+
+    #[test]
+    #[serial]
+    fn test_pg_config_lakebase_endpoint_fallback() {
+        env::remove_var("PGHOST");
+        env::set_var("LAKEBASE_ENDPOINT", "lakebase.example.com");
+        env::set_var("PGDATABASE", "db");
+
+        let config = LakebasePgConfig::from_env().unwrap();
+        assert_eq!(config.host, "lakebase.example.com");
+
+        env::remove_var("LAKEBASE_ENDPOINT");
+        env::remove_var("PGDATABASE");
+    }
+}
diff --git a/packages/appkit-rs/src/connectors/mod.rs b/packages/appkit-rs/src/connectors/mod.rs
new file mode 100644
index 00000000..83af5ec3
--- /dev/null
+++ b/packages/appkit-rs/src/connectors/mod.rs
@@ -0,0 +1,6 @@
+pub mod files;
+pub mod genie;
+pub mod lakebase;
+pub mod serving;
+pub mod sql_warehouse;
+pub mod vector_search;
diff --git a/packages/appkit-rs/src/connectors/serving.rs b/packages/appkit-rs/src/connectors/serving.rs
new file mode 100644
index 00000000..e88e1482
--- /dev/null
+++ b/packages/appkit-rs/src/connectors/serving.rs
@@ -0,0 +1,347 @@
+use futures::StreamExt;
+use pyo3::prelude::*;
+use reqwest::Client;
+
+use crate::interceptor::{ExecutionError, StreamItem};
+use crate::plugin::PyStreamIterator;
+
+// ---------------------------------------------------------------------------
+// Python-facing response types (frozen / immutable)
+// ---------------------------------------------------------------------------
+
+/// Response from a serving endpoint invocation.
+#[pyclass(frozen, module = "appkit")]
+#[derive(Clone)]
+pub struct ServingResponse {
+    /// Response body as JSON string.
+    #[pyo3(get)]
+    pub data: String,
+    /// HTTP status code from the endpoint.
+    #[pyo3(get)]
+    pub status_code: u16,
+}
+
+#[pymethods]
+impl ServingResponse {
+    fn __repr__(&self) -> String {
+        format!(
+            "ServingResponse(status_code={}, data_len={})",
+            self.status_code,
+            self.data.len()
+        )
+    }
+
+    /// Truthy iff the HTTP status is in the 2xx success range.
+    fn __bool__(&self) -> bool {
+        self.status_code >= 200 && self.status_code < 300
+    }
+
+    fn __eq__(&self, other: &Self) -> bool {
+        self.data == other.data && self.status_code == other.status_code
+    }
+
+    fn __hash__(&self) -> u64 {
+        use std::hash::{Hash, Hasher};
+        let mut hasher = std::collections::hash_map::DefaultHasher::new();
+        self.data.hash(&mut hasher);
+        self.status_code.hash(&mut hasher);
+        hasher.finish()
+    }
+}
+
+// ---------------------------------------------------------------------------
+// SSE parser
+// ---------------------------------------------------------------------------
+
+/// Incremental SSE event parser. Buffers partial chunks and emits complete
+/// `data:` payloads when a blank-line event boundary is encountered.
+struct SseParser {
+    buffer: String,
+    data_lines: Vec<String>,
+}
+
+impl SseParser {
+    fn new() -> Self {
+        Self {
+            buffer: String::new(),
+            data_lines: Vec::new(),
+        }
+    }
+
+    /// Feed a chunk of bytes and return any complete event data payloads.
+    /// Multi-line `data:` fields within one event are joined with '\n'.
+    fn feed(&mut self, chunk: &[u8]) -> Vec<String> {
+        self.buffer.push_str(&String::from_utf8_lossy(chunk));
+        let mut events = Vec::new();
+
+        while let Some(pos) = self.buffer.find('\n') {
+            let line = self.buffer[..pos].trim_end_matches('\r').to_string();
+            self.buffer = self.buffer[pos + 1..].to_string();
+
+            if line.is_empty() {
+                // Empty line = event boundary.
+                if !self.data_lines.is_empty() {
+                    events.push(self.data_lines.join("\n"));
+                    self.data_lines.clear();
+                }
+            } else if let Some(data) = line.strip_prefix("data: ") {
+                self.data_lines.push(data.to_string());
+            } else if let Some(data) = line.strip_prefix("data:") {
+                self.data_lines.push(data.to_string());
+            }
+            // Ignore other SSE fields (event:, id:, retry:, comments).
+        }
+
+        events
+    }
+}
+
+// ---------------------------------------------------------------------------
+// ServingConnector
+// ---------------------------------------------------------------------------
+
+/// Databricks Serving Endpoints connector.
+///
+/// Provides invocation and SSE streaming against model serving endpoints
+/// via the REST API at `/serving-endpoints/{name}/invocations`.
+#[pyclass(module = "appkit")]
+pub struct ServingConnector {
+    host: String,
+    http: Client,
+}
+
+#[pymethods]
+impl ServingConnector {
+    #[new]
+    #[pyo3(signature = (host))]
+    fn new(host: String) -> Self {
+        Self {
+            // Normalize so URL joins below never produce "//".
+            host: host.trim_end_matches('/').to_string(),
+            http: Client::new(),
+        }
+    }
+
+    /// Invoke a serving endpoint (non-streaming).
+    ///
+    /// Strips any `stream` key from the body to prevent conflict with the
+    /// connector's control of streaming mode (mirrors TS behavior).
+    /// `body` is a JSON string of the request payload.
+    #[pyo3(signature = (token, endpoint_name, body))]
+    fn invoke<'py>(
+        &self,
+        py: Python<'py>,
+        token: String,
+        endpoint_name: String,
+        body: String,
+    ) -> PyResult<Bound<'py, PyAny>> {
+        let http = self.http.clone();
+        let host = self.host.clone();
+
+        pyo3_async_runtimes::tokio::future_into_py(py, async move {
+            // Parse, strip `stream` key, re-serialize
+            let mut payload: serde_json::Value = serde_json::from_str(&body).map_err(|e| {
+                pyo3::exceptions::PyValueError::new_err(format!("Invalid JSON body: {e}"))
+            })?;
+            if let Some(obj) = payload.as_object_mut() {
+                obj.remove("stream");
+            }
+
+            let url = format!(
+                "{}/serving-endpoints/{}/invocations",
+                host,
+                urlencoding::encode(&endpoint_name),
+            );
+
+            let resp = http
+                .post(&url)
+                .bearer_auth(&token)
+                .header("Content-Type", "application/json")
+                .body(serde_json::to_vec(&payload).unwrap_or_default())
+                .send()
+                .await
+                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
+
+            let status_code = resp.status().as_u16();
+            let text = resp
+                .text()
+                .await
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
+
+            if status_code >= 400 {
+                return Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
+                    "Serving endpoint invocation failed ({status_code}): {text}"
+                )));
+            }
+
+            Ok(ServingResponse {
+                data: text,
+                status_code,
+            })
+        })
+    }
+
+    /// Stream from a serving endpoint (SSE).
+    ///
+    /// Returns a `StreamIterator` that yields parsed SSE data payloads as
+    /// they arrive. Each item is the `data:` field content (typically JSON).
+    ///
+    /// Sets `stream: true` in the request body and `Accept: text/event-stream`.
+    /// The stream ends when the server sends `data: [DONE]` or closes the
+    /// connection.
+    #[pyo3(signature = (token, endpoint_name, body))]
+    fn stream<'py>(
+        &self,
+        py: Python<'py>,
+        token: String,
+        endpoint_name: String,
+        body: String,
+    ) -> PyResult<Bound<'py, PyAny>> {
+        let http = self.http.clone();
+        let host = self.host.clone();
+
+        pyo3_async_runtimes::tokio::future_into_py(py, async move {
+            // Parse, strip existing `stream`, set `stream: true`
+            let mut payload: serde_json::Value = serde_json::from_str(&body).map_err(|e| {
+                pyo3::exceptions::PyValueError::new_err(format!("Invalid JSON body: {e}"))
+            })?;
+            if let Some(obj) = payload.as_object_mut() {
+                obj.remove("stream");
+                obj.insert("stream".to_string(), serde_json::Value::Bool(true));
+            }
+
+            let url = format!(
+                "{}/serving-endpoints/{}/invocations",
+                host,
+                urlencoding::encode(&endpoint_name),
+            );
+
+            let resp = http
+                .post(&url)
+                .bearer_auth(&token)
+                .header("Content-Type", "application/json")
+                .header("Accept", "text/event-stream")
+                .body(serde_json::to_vec(&payload).unwrap_or_default())
+                .send()
+                .await
+                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
+
+            if !resp.status().is_success() {
+                let status = resp.status();
+                let body = resp.text().await.unwrap_or_default();
+                return Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
+                    "SSE stream request failed ({status}): {body}"
+                )));
+            }
+
+            let (tx, rx) = tokio::sync::mpsc::channel::<StreamItem>(32);
+
+            // Spawn task to incrementally parse SSE events from the byte stream.
+            let byte_stream = resp.bytes_stream();
+            tokio::spawn(async move {
+                let mut parser = SseParser::new();
+                tokio::pin!(byte_stream);
+
+                while let Some(chunk_result) = byte_stream.next().await {
+                    match chunk_result {
+                        Ok(bytes) => {
+                            for event_data in parser.feed(&bytes) {
+                                if event_data == "[DONE]" {
+                                    return;
+                                }
+                                if tx.send(Ok(event_data)).await.is_err() {
+                                    return; // Receiver dropped.
+                                }
+                            }
+                        }
+                        Err(e) => {
+                            let _ = tx
+                                .send(Err(ExecutionError {
+                                    status: 500,
+                                    message: e.to_string(),
+                                }))
+                                .await;
+                            return;
+                        }
+                    }
+                }
+            });
+
+            Ok(PyStreamIterator::new(rx))
+        })
+    }
+
+    fn __repr__(&self) -> String {
+        format!("ServingConnector(host={:?})", self.host)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::SseParser;
+
+    #[test]
+    fn test_strip_stream_from_body() {
+        let body = r#"{"inputs": "hello", "stream": false}"#;
+        let mut payload: serde_json::Value = serde_json::from_str(body).unwrap();
+        if let Some(obj) = payload.as_object_mut() {
+            obj.remove("stream");
+        }
+        assert!(!payload.as_object().unwrap().contains_key("stream"));
+        assert_eq!(payload["inputs"], "hello");
+    }
+
+    #[test]
+    fn test_set_stream_true() {
+        let body = r#"{"inputs": "hello"}"#;
+        let mut payload: serde_json::Value = serde_json::from_str(body).unwrap();
+        if let Some(obj) = payload.as_object_mut() {
+            obj.remove("stream");
+            obj.insert("stream".to_string(), serde_json::Value::Bool(true));
+        }
+        assert_eq!(payload["stream"], true);
+    }
+
+    // -- SSE parser --
+
+    #[test]
+    fn test_sse_parser_single_event() {
+        let mut parser = SseParser::new();
+        let events = parser.feed(b"data: {\"key\":\"val\"}\n\n");
+        assert_eq!(events, vec!["{\"key\":\"val\"}"]);
+    }
+
+    #[test]
+    fn test_sse_parser_multi_chunk() {
+        let mut parser = SseParser::new();
+        let events1 = parser.feed(b"data: hel");
+        assert!(events1.is_empty());
+        let events2 = parser.feed(b"lo\n\n");
+        assert_eq!(events2, vec!["hello"]);
+    }
+
+    #[test]
+    fn test_sse_parser_multiple_events() {
+        let mut parser = SseParser::new();
+        let events = parser.feed(b"data: first\n\ndata: second\n\n");
+        assert_eq!(events, vec!["first", "second"]);
+    }
+
+    #[test]
+    fn test_sse_parser_done_sentinel() {
+        let mut parser = SseParser::new();
+        let events = parser.feed(b"data: [DONE]\n\n");
+        assert_eq!(events, vec!["[DONE]"]);
+    }
+
+    #[test]
+    fn test_sse_parser_ignores_non_data_lines() {
+        let mut parser = SseParser::new();
+        let events = parser.feed(b"event: message\nid: 1\ndata: payload\n\n");
+        assert_eq!(events, vec!["payload"]);
+    }
+
+    #[test]
+    fn test_sse_parser_crlf() {
+        let mut parser = SseParser::new();
+        let events = parser.feed(b"data: hello\r\n\r\n");
+        assert_eq!(events, vec!["hello"]);
+    }
+}
diff --git a/packages/appkit-rs/src/connectors/sql_warehouse.rs b/packages/appkit-rs/src/connectors/sql_warehouse.rs
new file mode 100644
index 00000000..abbbe6de
--- /dev/null
+++ b/packages/appkit-rs/src/connectors/sql_warehouse.rs
@@ -0,0 +1,578 @@
+use pyo3::prelude::*;
+use reqwest::Client;
+use serde::{Deserialize, Serialize};
+use serde_json::Value as JsonValue;
+
+// ---------------------------------------------------------------------------
+// Defaults matching TS executeStatementDefaults
+// ---------------------------------------------------------------------------
+
+const DEFAULT_WAIT_TIMEOUT: &str = "30s";
+const DEFAULT_DISPOSITION: &str = "INLINE";
+const DEFAULT_FORMAT: &str = "JSON_ARRAY";
+const DEFAULT_ON_WAIT_TIMEOUT: &str = "CONTINUE";
+const DEFAULT_POLL_TIMEOUT_MS: u64 = 60_000;
+const INITIAL_POLL_DELAY_MS: u64 = 1_000;
+const MAX_POLL_DELAY_MS: u64 = 5_000;
+
+// ---------------------------------------------------------------------------
+// Internal serde types for Databricks SQL Statement Execution API
+// ---------------------------------------------------------------------------
+
+#[derive(Serialize)]
+struct ExecuteStatementBody {
+    statement: String,
+    warehouse_id: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    parameters: Option<Vec<StatementParam>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    catalog: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    schema: Option<String>,
+    wait_timeout: String,
+    disposition: String,
+    format: String,
+    on_wait_timeout: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    byte_limit: Option<u64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    row_limit: Option<u64>,
+}
+
+#[derive(Serialize, Deserialize, Clone)]
+struct StatementParam {
+    name: String,
+    value: String,
+    // "type" is a Rust keyword, hence the rename.
+    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
+    type_name: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+struct StatementApiResponse {
+    statement_id: Option<String>,
+    status: Option<StatementStatus>,
+    manifest: Option<StatementManifest>,
+    result: Option<StatementResult>,
+}
+
+#[derive(Deserialize, Debug)]
+struct StatementStatus {
+    state: Option<String>,
+    error: Option<StatementError>,
+}
+
+#[derive(Deserialize, Debug)]
+struct StatementError {
+    message: Option<String>,
+}
+
+#[derive(Deserialize, Debug, Clone)]
+struct StatementManifest {
+    #[allow(dead_code)]
+    format: Option<String>,
+    schema: Option<ManifestSchema>,
+}
+
+#[derive(Deserialize, Debug, Clone)]
+struct ManifestSchema {
+    columns: Option<Vec<ColumnInfo>>,
+}
+
+#[derive(Deserialize, Debug, Clone)]
+struct ColumnInfo {
+    name: Option<String>,
+    type_name: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+struct StatementResult {
+    data_array: Option<Vec<Vec<Option<String>>>>,
+}
+
+// ---------------------------------------------------------------------------
+// Python-facing response types (frozen / immutable)
+// ---------------------------------------------------------------------------
+
+/// Column schema information.
+#[pyclass(frozen, module = "appkit")] +#[derive(Clone)] +pub struct SqlColumn { + #[pyo3(get)] + pub name: String, + #[pyo3(get)] + pub type_name: String, +} + +#[pymethods] +impl SqlColumn { + fn __repr__(&self) -> String { + format!("SqlColumn(name={:?}, type_name={:?})", self.name, self.type_name) + } + + fn __eq__(&self, other: &Self) -> bool { + self.name == other.name && self.type_name == other.type_name + } + + fn __hash__(&self) -> u64 { + use std::hash::{Hash, Hasher}; + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + self.name.hash(&mut hasher); + self.type_name.hash(&mut hasher); + hasher.finish() + } +} + +/// Result of a SQL statement execution. +#[pyclass(frozen, module = "appkit")] +#[derive(Clone)] +pub struct SqlStatementResult { + #[pyo3(get)] + pub statement_id: String, + #[pyo3(get)] + pub status: String, + /// Column schema. + #[pyo3(get)] + pub columns: Vec, + /// Result rows as JSON string (array of objects). Empty string if no data. + #[pyo3(get)] + pub data: String, + /// Number of result rows. 
+ #[pyo3(get)] + pub row_count: usize, +} + +#[pymethods] +impl SqlStatementResult { + fn __repr__(&self) -> String { + format!( + "SqlStatementResult(statement_id={:?}, status={:?}, row_count={})", + self.statement_id, self.status, self.row_count + ) + } + + fn __len__(&self) -> usize { + self.row_count + } + + fn __bool__(&self) -> bool { + self.row_count > 0 + } +} + +// --------------------------------------------------------------------------- +// Data transformation — mirrors TS _transformDataArray +// --------------------------------------------------------------------------- + +fn transform_data_array( + columns: &[ColumnInfo], + data_array: &[Vec>], +) -> (Vec>, usize) { + let mut rows = Vec::with_capacity(data_array.len()); + + for row in data_array { + let mut obj = serde_json::Map::new(); + for (i, cell) in row.iter().enumerate() { + let col = columns.get(i); + let col_name = col + .and_then(|c| c.name.as_deref()) + .unwrap_or(&format!("column_{i}")) + .to_string(); + let col_type = col.and_then(|c| c.type_name.as_deref()).unwrap_or(""); + + let value = match cell { + None => JsonValue::Null, + Some(v) => { + // Attempt to parse JSON for STRING columns (matches TS behavior) + if col_type == "STRING" + && !v.is_empty() + && (v.starts_with('{') || v.starts_with('[')) + { + serde_json::from_str(v).unwrap_or_else(|_| JsonValue::String(v.clone())) + } else { + JsonValue::String(v.clone()) + } + } + }; + obj.insert(col_name, value); + } + rows.push(obj); + } + + let count = rows.len(); + (rows, count) +} + +// --------------------------------------------------------------------------- +// SqlWarehouseConnector +// --------------------------------------------------------------------------- + +/// Databricks SQL Warehouse connector. +/// +/// Executes SQL statements against a Databricks SQL Warehouse via the +/// REST API at `/api/2.0/sql/statements`. Supports parameterized queries, +/// automatic polling for async results, and JSON_ARRAY result transformation. 
+#[pyclass(module = "appkit")]
+pub struct SqlWarehouseConnector {
+    host: String,
+    timeout_ms: u64,
+    http: Client,
+}
+
+impl SqlWarehouseConnector {
+    /// Internal: execute and poll for statement result.
+    async fn execute_internal(
+        host: &str,
+        http: &Client,
+        token: &str,
+        body: &ExecuteStatementBody,
+        timeout_ms: u64,
+    ) -> Result<SqlStatementResult, String> {
+        // 1. Submit statement
+        let url = format!("{}/api/2.0/sql/statements", host);
+        let resp = http
+            .post(&url)
+            .bearer_auth(token)
+            .json(body)
+            .send()
+            .await
+            .map_err(|e| format!("Statement execution request failed: {e}"))?;
+
+        if !resp.status().is_success() {
+            let status = resp.status();
+            let text = resp.text().await.unwrap_or_default();
+            return Err(format!("Execute statement failed ({status}): {text}"));
+        }
+
+        let api_resp: StatementApiResponse = resp
+            .json()
+            .await
+            .map_err(|e| format!("Failed to parse statement response: {e}"))?;
+
+        let statement_id = api_resp
+            .statement_id
+            .clone()
+            .unwrap_or_default();
+        let state = api_resp
+            .status
+            .as_ref()
+            .and_then(|s| s.state.as_deref())
+            .unwrap_or("UNKNOWN");
+
+        match state {
+            "SUCCEEDED" => Self::build_result(&api_resp),
+            // Server returned before completion — fall through to polling.
+            "PENDING" | "RUNNING" => {
+                Self::poll_for_result(host, http, token, &statement_id, timeout_ms).await
+            }
+            "FAILED" => {
+                let msg = api_resp
+                    .status
+                    .as_ref()
+                    .and_then(|s| s.error.as_ref())
+                    .and_then(|e| e.message.as_deref())
+                    .unwrap_or("Statement failed");
+                Err(format!("Statement failed: {msg}"))
+            }
+            "CANCELED" => Err("Statement was canceled.".into()),
+            "CLOSED" => Err("Statement results have been closed.".into()),
+            other => Err(format!("Unknown statement state: {other}")),
+        }
+    }
+
+    /// Internal: poll GET /api/2.0/sql/statements/{id} until terminal state.
+    /// Uses exponential backoff from INITIAL_POLL_DELAY_MS capped at
+    /// MAX_POLL_DELAY_MS, failing once `timeout_ms` has elapsed.
+    async fn poll_for_result(
+        host: &str,
+        http: &Client,
+        token: &str,
+        statement_id: &str,
+        timeout_ms: u64,
+    ) -> Result<SqlStatementResult, String> {
+        let start = std::time::Instant::now();
+        let mut delay = INITIAL_POLL_DELAY_MS;
+
+        loop {
+            let elapsed = start.elapsed().as_millis() as u64;
+            if elapsed > timeout_ms {
+                return Err(format!(
+                    "Polling timeout exceeded after {timeout_ms}ms (elapsed: {elapsed}ms)"
+                ));
+            }
+
+            tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
+            delay = (delay * 2).min(MAX_POLL_DELAY_MS);
+
+            let url = format!("{}/api/2.0/sql/statements/{}", host, statement_id);
+            let resp = http
+                .get(&url)
+                .bearer_auth(token)
+                .send()
+                .await
+                .map_err(|e| format!("Poll request failed: {e}"))?;
+
+            if !resp.status().is_success() {
+                let status = resp.status();
+                let body = resp.text().await.unwrap_or_default();
+                return Err(format!("Poll failed ({status}): {body}"));
+            }
+
+            let api_resp: StatementApiResponse = resp
+                .json()
+                .await
+                .map_err(|e| format!("Failed to parse poll response: {e}"))?;
+
+            let state = api_resp
+                .status
+                .as_ref()
+                .and_then(|s| s.state.as_deref())
+                .unwrap_or("UNKNOWN");
+
+            match state {
+                "SUCCEEDED" => return Self::build_result(&api_resp),
+                "PENDING" | "RUNNING" => continue,
+                "FAILED" => {
+                    let msg = api_resp
+                        .status
+                        .as_ref()
+                        .and_then(|s| s.error.as_ref())
+                        .and_then(|e| e.message.as_deref())
+                        .unwrap_or("Statement failed");
+                    return Err(format!("Statement failed: {msg}"));
+                }
+                "CANCELED" => return Err("Statement was canceled.".into()),
+                "CLOSED" => return Err("Statement results have been closed.".into()),
+                other => return Err(format!("Unknown statement state: {other}")),
+            }
+        }
+    }
+
+    /// Build a SqlStatementResult from a SUCCEEDED API response.
+    fn build_result(api_resp: &StatementApiResponse) -> Result<SqlStatementResult, String> {
+        let statement_id = api_resp.statement_id.clone().unwrap_or_default();
+        let columns_raw = api_resp
+            .manifest
+            .as_ref()
+            .and_then(|m| m.schema.as_ref())
+            .and_then(|s| s.columns.as_ref());
+
+        let columns: Vec<SqlColumn> = columns_raw
+            .map(|cols| {
+                cols.iter()
+                    .map(|c| SqlColumn {
+                        name: c.name.clone().unwrap_or_default(),
+                        type_name: c.type_name.clone().unwrap_or_default(),
+                    })
+                    .collect()
+            })
+            .unwrap_or_default();
+
+        let data_array = api_resp
+            .result
+            .as_ref()
+            .and_then(|r| r.data_array.as_ref());
+
+        // Without both a schema and data there is nothing to transform.
+        let (data_json, row_count) = match (columns_raw, data_array) {
+            (Some(cols), Some(arr)) => {
+                let (rows, count) = transform_data_array(cols, arr);
+                let json =
+                    serde_json::to_string(&rows).unwrap_or_else(|_| "[]".to_string());
+                (json, count)
+            }
+            _ => ("[]".to_string(), 0),
+        };
+
+        Ok(SqlStatementResult {
+            statement_id,
+            status: "SUCCEEDED".to_string(),
+            columns,
+            data: data_json,
+            row_count,
+        })
+    }
+}
+
+#[pymethods]
+impl SqlWarehouseConnector {
+    #[new]
+    #[pyo3(signature = (host, *, timeout_ms = None))]
+    fn new(host: String, timeout_ms: Option<u64>) -> Self {
+        Self {
+            // Normalize so URL joins never produce "//".
+            host: host.trim_end_matches('/').to_string(),
+            timeout_ms: timeout_ms.unwrap_or(DEFAULT_POLL_TIMEOUT_MS),
+            http: Client::new(),
+        }
+    }
+
+    /// Execute a SQL statement against a warehouse.
+    ///
+    /// Polls automatically if the statement is PENDING/RUNNING. Returns
+    /// a `SqlStatementResult` with transformed JSON_ARRAY data (array of
+    /// column-named objects) matching the TypeScript SDK behavior.
+    #[pyo3(signature = (
+        token,
+        statement,
+        warehouse_id,
+        *,
+        parameters = None,
+        catalog = None,
+        schema = None,
+        wait_timeout = None,
+        disposition = None,
+        format = None,
+        on_wait_timeout = None,
+        byte_limit = None,
+        row_limit = None,
+        timeout_ms = None,
+    ))]
+    #[allow(clippy::too_many_arguments)]
+    fn execute_statement<'py>(
+        &self,
+        py: Python<'py>,
+        token: String,
+        statement: String,
+        warehouse_id: String,
+        parameters: Option<Vec<(String, String)>>,
+        catalog: Option<String>,
+        schema: Option<String>,
+        wait_timeout: Option<String>,
+        disposition: Option<String>,
+        format: Option<String>,
+        on_wait_timeout: Option<String>,
+        byte_limit: Option<u64>,
+        row_limit: Option<u64>,
+        timeout_ms: Option<u64>,
+    ) -> PyResult<Bound<'py, PyAny>> {
+        let host = self.host.clone();
+        let http = self.http.clone();
+        // Per-call timeout overrides the connector-level default.
+        let poll_timeout = timeout_ms.unwrap_or(self.timeout_ms);
+
+        // (name, value) pairs become untyped statement parameters.
+        let params: Option<Vec<StatementParam>> = parameters.map(|ps| {
+            ps.into_iter()
+                .map(|(name, value)| StatementParam {
+                    name,
+                    value,
+                    type_name: None,
+                })
+                .collect()
+        });
+
+        let body = ExecuteStatementBody {
+            statement,
+            warehouse_id,
+            parameters: params,
+            catalog,
+            schema,
+            wait_timeout: wait_timeout.unwrap_or_else(|| DEFAULT_WAIT_TIMEOUT.into()),
+            disposition: disposition.unwrap_or_else(|| DEFAULT_DISPOSITION.into()),
+            format: format.unwrap_or_else(|| DEFAULT_FORMAT.into()),
+            on_wait_timeout: on_wait_timeout.unwrap_or_else(|| DEFAULT_ON_WAIT_TIMEOUT.into()),
+            byte_limit,
+            row_limit,
+        };
+
+        pyo3_async_runtimes::tokio::future_into_py(py, async move {
+            Self::execute_internal(&host, &http, &token, &body, poll_timeout)
+                .await
+                .map_err(pyo3::exceptions::PyRuntimeError::new_err)
+        })
+    }
+
+    fn __repr__(&self) -> String {
+        format!(
+            "SqlWarehouseConnector(host={:?}, timeout_ms={})",
+            self.host, self.timeout_ms
+        )
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_transform_data_array_basic() {
+        let columns = vec![
+            ColumnInfo {
+                name: Some("id".into()),
+                type_name: Some("INT".into()),
+            },
+            ColumnInfo {
+                name: Some("name".into()),
+                type_name: Some("STRING".into()),
+            },
+        ];
+        let data = vec![
+            vec![Some("1".into()), Some("Alice".into())],
+            vec![Some("2".into()), Some("Bob".into())],
+        ];
+
+        let (rows, count) = transform_data_array(&columns, &data);
+        assert_eq!(count, 2);
+        assert_eq!(rows[0]["id"], JsonValue::String("1".into()));
+        assert_eq!(rows[0]["name"], JsonValue::String("Alice".into()));
+        assert_eq!(rows[1]["id"], JsonValue::String("2".into()));
+    }
+
+    #[test]
+    fn test_transform_null_values() {
+        let columns = vec![ColumnInfo {
+            name: Some("val".into()),
+            type_name: Some("STRING".into()),
+        }];
+        let data = vec![vec![None]];
+
+        let (rows, count) = transform_data_array(&columns, &data);
+        assert_eq!(count, 1);
+        assert_eq!(rows[0]["val"], JsonValue::Null);
+    }
+
+    #[test]
+    fn test_transform_json_string_parsing() {
+        let columns = vec![ColumnInfo {
+            name: Some("meta".into()),
+            type_name: Some("STRING".into()),
+        }];
+        let data = vec![
+            vec![Some(r#"{"key":"value"}"#.into())],
+            vec![Some("plain text".into())],
+            vec![Some(r#"[1,2,3]"#.into())],
+        ];
+
+        let (rows, _) = transform_data_array(&columns, &data);
+        // JSON object should be parsed
+        assert!(rows[0]["meta"].is_object());
+        // Plain string stays as string
+        assert_eq!(rows[1]["meta"], JsonValue::String("plain text".into()));
+        // JSON array should be parsed
+        assert!(rows[2]["meta"].is_array());
+    }
+
+    #[test]
+    fn test_transform_non_string_type_no_json_parse() {
+        let columns = vec![ColumnInfo {
+            name: Some("data".into()),
+            type_name: Some("INT".into()),
+        }];
+        let data = vec![vec![Some("{123}".into())]];
+
+        let (rows, _) = transform_data_array(&columns, &data);
+        // INT column should NOT attempt JSON parse even if value looks like JSON
+        assert_eq!(rows[0]["data"], JsonValue::String("{123}".into()));
+    }
+
+    #[test]
+    fn test_build_result_empty() {
+        let resp = StatementApiResponse {
+            statement_id: Some("stmt-1".into()),
+            status: Some(StatementStatus {
+                state: Some("SUCCEEDED".into()),
+                error: None,
+            }),
+            manifest: None,
+            result: None,
+        };
+        let result = SqlWarehouseConnector::build_result(&resp).unwrap();
+        assert_eq!(result.statement_id, "stmt-1");
+        assert_eq!(result.status, "SUCCEEDED");
+        assert_eq!(result.row_count, 0);
+        assert_eq!(result.data, "[]");
+    }
+}
diff --git a/packages/appkit-rs/src/connectors/vector_search.rs b/packages/appkit-rs/src/connectors/vector_search.rs
new file mode 100644
index 00000000..7b453cfc
--- /dev/null
+++ b/packages/appkit-rs/src/connectors/vector_search.rs
@@ -0,0 +1,713 @@
+//! Vector Search connector — thin HTTP wrapper around the Databricks
+//! Vector Search REST API. Mirrors the request/response shape expected by
+//! the TypeScript `VectorSearchConnector`.
+//!
+//! Exposed operations:
+//! - `query` → POST `/api/2.0/vector-search/indexes/:index/query`
+//! - `query_next_page` → POST `/api/2.0/vector-search/indexes/:index/query-next-page`
+//!
+//! Auth: bearer token is passed per call (service-principal or OBO), matching
+//! the pattern used by `SqlWarehouseConnector` and `FilesConnector`.
+//!
+//! The Python-facing class wraps a handful of strongly-typed accessors but
+//! returns raw JSON strings for response data so the Python layer can reuse
+//! its existing response-shaping logic without re-serialising across the PyO3
+//! boundary on every row.
+
+use pyo3::prelude::*;
+use reqwest::Client;
+use serde::{Deserialize, Serialize};
+use serde_json::{json, Map as JsonMap, Value as JsonValue};
+
+// ---------------------------------------------------------------------------
+// Request / response types (Rust-internal)
+// ---------------------------------------------------------------------------
+
+/// Single filter value — a scalar or a list of scalars.
+#[derive(Clone, Debug, PartialEq)] +pub enum FilterValue { + String(String), + Number(f64), + Boolean(bool), + Array(Vec), +} + +impl FilterValue { + fn to_json(&self) -> JsonValue { + match self { + Self::String(s) => JsonValue::String(s.clone()), + Self::Number(n) => serde_json::Number::from_f64(*n) + .map(JsonValue::Number) + .unwrap_or(JsonValue::Null), + Self::Boolean(b) => JsonValue::Bool(*b), + Self::Array(v) => JsonValue::Array(v.iter().map(FilterValue::to_json).collect()), + } + } +} + +/// Parameters for a query request. +#[derive(Clone, Debug)] +pub struct VsQueryParams { + pub index_name: String, + pub query_text: Option, + pub query_vector: Option>, + pub columns: Vec, + pub num_results: u32, + pub query_type: VsQueryType, + pub filters: Option>, + pub reranker_columns: Option>, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum VsQueryType { + Ann, + Hybrid, + FullText, +} + +impl VsQueryType { + pub fn as_wire(&self) -> &'static str { + match self { + Self::Ann => "ANN", + Self::Hybrid => "HYBRID", + Self::FullText => "FULL_TEXT", + } + } + + pub fn as_str(&self) -> &'static str { + match self { + Self::Ann => "ann", + Self::Hybrid => "hybrid", + Self::FullText => "full_text", + } + } + + pub fn parse(s: &str) -> Option { + match s.to_ascii_lowercase().as_str() { + "ann" => Some(Self::Ann), + "hybrid" => Some(Self::Hybrid), + "full_text" => Some(Self::FullText), + _ => None, + } + } +} + +#[derive(Clone, Debug)] +pub struct VsNextPageParams { + pub index_name: String, + pub endpoint_name: String, + pub page_token: String, +} + +// --------------------------------------------------------------------------- +// Body building — separated from HTTP for testability +// --------------------------------------------------------------------------- + +/// Build the JSON body for a `query` request. 
This is the core +/// request-builder behavior called out in the checklist — keep it pure so +/// tests can verify the wire shape without a live HTTP server. +pub fn build_query_body(p: &VsQueryParams) -> JsonValue { + let mut body = JsonMap::new(); + body.insert( + "columns".into(), + JsonValue::Array( + p.columns + .iter() + .map(|c| JsonValue::String(c.clone())) + .collect(), + ), + ); + body.insert("num_results".into(), JsonValue::Number(p.num_results.into())); + body.insert( + "query_type".into(), + JsonValue::String(p.query_type.as_wire().to_string()), + ); + body.insert("debug_level".into(), JsonValue::Number(1.into())); + + if let Some(ref q) = p.query_text { + body.insert("query_text".into(), JsonValue::String(q.clone())); + } + if let Some(ref v) = p.query_vector { + let arr: Vec = v + .iter() + .map(|n| { + serde_json::Number::from_f64(*n) + .map(JsonValue::Number) + .unwrap_or(JsonValue::Null) + }) + .collect(); + body.insert("query_vector".into(), JsonValue::Array(arr)); + } + if let Some(ref filters) = p.filters { + if !filters.is_empty() { + let mut map = JsonMap::new(); + for (k, v) in filters { + map.insert(k.clone(), v.to_json()); + } + body.insert("filters".into(), JsonValue::Object(map)); + } + } + if let Some(ref cols) = p.reranker_columns { + body.insert( + "reranker".into(), + json!({ + "model": "databricks_reranker", + "parameters": { "columns_to_rerank": cols } + }), + ); + } + + JsonValue::Object(body) +} + +/// Build the JSON body for a `query-next-page` request. +pub fn build_next_page_body(p: &VsNextPageParams) -> JsonValue { + json!({ + "endpoint_name": p.endpoint_name, + "page_token": p.page_token, + }) +} + +// --------------------------------------------------------------------------- +// HTTP connector +// --------------------------------------------------------------------------- + +const DEFAULT_TIMEOUT_MS: u64 = 30_000; + +/// Databricks Vector Search REST connector. 
HTTP-level only — the plugin +/// layer handles embedding-fn resolution, reranker resolution, and response +/// shaping into `SearchResponse`. +#[pyclass(module = "appkit")] +pub struct VectorSearchConnector { + host: String, + timeout_ms: u64, + http: Client, +} + +impl VectorSearchConnector { + async fn post_json( + &self, + token: &str, + path: &str, + body: &JsonValue, + ) -> Result { + let url = format!("{}{}", self.host, path); + let resp = self + .http + .post(&url) + .bearer_auth(token) + .timeout(std::time::Duration::from_millis(self.timeout_ms)) + .json(body) + .send() + .await + .map_err(|e| format!("Vector search request failed: {e}"))?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + return Err(format!("Vector search error ({status}): {text}")); + } + + resp.text() + .await + .map_err(|e| format!("Failed to read response body: {e}")) + } + + /// Low-level query — returns the raw JSON body string from the VS API. + pub async fn query_internal( + &self, + token: &str, + params: &VsQueryParams, + ) -> Result { + let path = format!( + "/api/2.0/vector-search/indexes/{}/query", + params.index_name + ); + let body = build_query_body(params); + self.post_json(token, &path, &body).await + } + + /// Low-level next-page query. + pub async fn query_next_page_internal( + &self, + token: &str, + params: &VsNextPageParams, + ) -> Result { + let path = format!( + "/api/2.0/vector-search/indexes/{}/query-next-page", + params.index_name + ); + let body = build_next_page_body(params); + self.post_json(token, &path, &body).await + } +} + +// --------------------------------------------------------------------------- +// Python-facing request structs +// --------------------------------------------------------------------------- + +/// Mirrors the TS `SearchRequest` input — parsed from a plain dict on the +/// Python side and passed here to keep request validation + body construction +/// in Rust. 
+#[pyclass(frozen, name = "VsSearchRequest", module = "appkit")] +#[derive(Clone)] +pub struct PyVsSearchRequest { + #[pyo3(get)] + pub query_text: Option, + #[pyo3(get)] + pub query_vector: Option>, + #[pyo3(get)] + pub columns: Option>, + #[pyo3(get)] + pub num_results: Option, + #[pyo3(get)] + pub query_type: Option, + /// Filters as JSON string (parsed/serialized by Python side). + #[pyo3(get)] + pub filters_json: Option, + #[pyo3(get)] + pub reranker_columns: Option>, +} + +#[pymethods] +impl PyVsSearchRequest { + #[new] + #[pyo3(signature = (*, query_text = None, query_vector = None, columns = None, num_results = None, query_type = None, filters_json = None, reranker_columns = None))] + #[allow(clippy::too_many_arguments)] + fn new( + query_text: Option, + query_vector: Option>, + columns: Option>, + num_results: Option, + query_type: Option, + filters_json: Option, + reranker_columns: Option>, + ) -> Self { + Self { + query_text, + query_vector, + columns, + num_results, + query_type, + filters_json, + reranker_columns, + } + } + + fn __repr__(&self) -> String { + format!( + "VsSearchRequest(query_type={:?}, num_results={:?})", + self.query_type, self.num_results + ) + } +} + +#[pymethods] +impl VectorSearchConnector { + #[new] + #[pyo3(signature = (host, *, timeout_ms = None))] + fn new(host: String, timeout_ms: Option) -> Self { + Self { + host: host.trim_end_matches('/').to_string(), + timeout_ms: timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS), + http: Client::new(), + } + } + + /// Execute a query. Returns the raw JSON body of the API response as a + /// string (matches `SqlWarehouseConnector.execute_statement` convention of + /// returning deserialized-once results). 
+ #[pyo3(signature = ( + token, + index_name, + *, + columns, + num_results = 20, + query_type = "hybrid".to_string(), + query_text = None, + query_vector = None, + filters_json = None, + reranker_columns = None, + ))] + #[allow(clippy::too_many_arguments)] + fn query<'py>( + &self, + py: Python<'py>, + token: String, + index_name: String, + columns: Vec, + num_results: u32, + query_type: String, + query_text: Option, + query_vector: Option>, + filters_json: Option, + reranker_columns: Option>, + ) -> PyResult> { + let qt = VsQueryType::parse(&query_type).ok_or_else(|| { + pyo3::exceptions::PyValueError::new_err(format!( + "Invalid query_type: {query_type}. Expected ann | hybrid | full_text" + )) + })?; + + let filters = match filters_json { + Some(s) if !s.is_empty() => Some(parse_filters_json(&s).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "Invalid filters_json: {e}" + )) + })?), + _ => None, + }; + + let params = VsQueryParams { + index_name, + query_text, + query_vector, + columns, + num_results, + query_type: qt, + filters, + reranker_columns, + }; + + let host = self.host.clone(); + let http = self.http.clone(); + let timeout_ms = self.timeout_ms; + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let connector = Self { host, timeout_ms, http }; + connector + .query_internal(&token, ¶ms) + .await + .map_err(pyo3::exceptions::PyRuntimeError::new_err) + }) + } + + /// Fetch the next page of results for a paginated query. 
+ #[pyo3(signature = (token, index_name, endpoint_name, page_token))] + fn query_next_page<'py>( + &self, + py: Python<'py>, + token: String, + index_name: String, + endpoint_name: String, + page_token: String, + ) -> PyResult> { + let host = self.host.clone(); + let http = self.http.clone(); + let timeout_ms = self.timeout_ms; + + let params = VsNextPageParams { index_name, endpoint_name, page_token }; + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let connector = Self { host, timeout_ms, http }; + connector + .query_next_page_internal(&token, ¶ms) + .await + .map_err(pyo3::exceptions::PyRuntimeError::new_err) + }) + } + + fn __repr__(&self) -> String { + format!( + "VectorSearchConnector(host={:?}, timeout_ms={})", + self.host, self.timeout_ms + ) + } +} + +// Parse `filters_json` — a JSON object of scalar or array values. +fn parse_filters_json(s: &str) -> Result, String> { + let raw: JsonValue = + serde_json::from_str(s).map_err(|e| format!("not valid JSON: {e}"))?; + let JsonValue::Object(map) = raw else { + return Err("expected a JSON object".into()); + }; + let mut out = Vec::with_capacity(map.len()); + for (k, v) in map { + out.push((k, json_to_filter_value(v)?)); + } + Ok(out) +} + +fn json_to_filter_value(v: JsonValue) -> Result { + match v { + JsonValue::String(s) => Ok(FilterValue::String(s)), + JsonValue::Number(n) => n + .as_f64() + .ok_or_else(|| "non-finite number".to_string()) + .map(FilterValue::Number), + JsonValue::Bool(b) => Ok(FilterValue::Boolean(b)), + JsonValue::Array(arr) => { + let mut items = Vec::with_capacity(arr.len()); + for item in arr { + items.push(json_to_filter_value(item)?); + } + Ok(FilterValue::Array(items)) + } + JsonValue::Null => Err("null filter values are not supported".into()), + JsonValue::Object(_) => Err("nested object filters are not supported".into()), + } +} + +// --------------------------------------------------------------------------- +// Response shaping — ported from the TS 
`_parseResponse` +// --------------------------------------------------------------------------- + +#[derive(Deserialize)] +struct RawResponse { + manifest: Option, + result: Option, + next_page_token: Option, + debug_info: Option, +} + +#[derive(Deserialize)] +struct RawManifest { + #[serde(default)] + columns: Vec, +} + +#[derive(Deserialize)] +struct RawColumn { + name: String, +} + +#[derive(Deserialize)] +struct RawResult { + #[serde(default)] + row_count: u64, + #[serde(default)] + data_array: Vec>, +} + +#[derive(Deserialize)] +struct RawDebugInfo { + #[serde(default)] + response_time: Option, + #[serde(default)] + latency_ms: Option, +} + +/// Ported from `VectorSearchPlugin._parseResponse` — shape raw VS output +/// into `{ results: [{ score, data }], totalCount, queryTimeMs, queryType, +/// nextPageToken }`. Keeping this on the Rust side lets Python consumers +/// render hits directly without re-implementing the row-to-object transform. +#[derive(Serialize)] +pub struct SearchResponse { + pub results: Vec, + #[serde(rename = "totalCount")] + pub total_count: u64, + #[serde(rename = "queryTimeMs")] + pub query_time_ms: u64, + #[serde(rename = "queryType")] + pub query_type: String, + #[serde(rename = "nextPageToken")] + pub next_page_token: Option, +} + +#[derive(Serialize)] +pub struct SearchHit { + pub score: f64, + pub data: JsonMap, +} + +pub fn parse_vs_response(raw_body: &str, query_type: VsQueryType) -> Result { + let raw: RawResponse = + serde_json::from_str(raw_body).map_err(|e| format!("Invalid VS response: {e}"))?; + let manifest = raw.manifest.unwrap_or(RawManifest { columns: vec![] }); + let result = raw.result.unwrap_or(RawResult { + row_count: 0, + data_array: vec![], + }); + + let col_names: Vec = manifest.columns.iter().map(|c| c.name.clone()).collect(); + let score_idx = col_names.iter().position(|n| n == "score"); + + let mut hits = Vec::with_capacity(result.data_array.len()); + for row in &result.data_array { + let mut data = 
JsonMap::new(); + for (i, name) in col_names.iter().enumerate() { + if name == "score" { + continue; + } + let value = row.get(i).cloned().unwrap_or(JsonValue::Null); + data.insert(name.clone(), value); + } + let score = score_idx + .and_then(|idx| row.get(idx)) + .and_then(|v| v.as_f64()) + .unwrap_or(0.0); + hits.push(SearchHit { score, data }); + } + + let query_time_ms = raw + .debug_info + .as_ref() + .and_then(|d| d.response_time.or(d.latency_ms)) + .unwrap_or(0); + + Ok(SearchResponse { + results: hits, + total_count: result.row_count, + query_time_ms, + query_type: query_type.as_str().to_string(), + next_page_token: raw.next_page_token, + }) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_build_query_body_minimal() { + let body = build_query_body(&VsQueryParams { + index_name: "cat.sch.idx".into(), + query_text: Some("hello".into()), + query_vector: None, + columns: vec!["id".into(), "title".into()], + num_results: 10, + query_type: VsQueryType::Hybrid, + filters: None, + reranker_columns: None, + }); + assert_eq!(body["query_text"], JsonValue::String("hello".into())); + assert_eq!(body["num_results"], JsonValue::Number(10.into())); + assert_eq!(body["query_type"], JsonValue::String("HYBRID".into())); + assert_eq!(body["debug_level"], JsonValue::Number(1.into())); + assert_eq!( + body["columns"], + JsonValue::Array(vec![ + JsonValue::String("id".into()), + JsonValue::String("title".into()) + ]) + ); + // No query_vector, filters, or reranker when not provided. 
+ assert!(body.get("query_vector").is_none()); + assert!(body.get("filters").is_none()); + assert!(body.get("reranker").is_none()); + } + + #[test] + fn test_build_query_body_with_reranker_and_filters() { + let body = build_query_body(&VsQueryParams { + index_name: "x".into(), + query_text: None, + query_vector: Some(vec![0.1, 0.2, 0.3]), + columns: vec!["id".into()], + num_results: 5, + query_type: VsQueryType::Ann, + filters: Some(vec![ + ("region".into(), FilterValue::String("us-west".into())), + ( + "tags".into(), + FilterValue::Array(vec![ + FilterValue::String("a".into()), + FilterValue::String("b".into()), + ]), + ), + ]), + reranker_columns: Some(vec!["title".into(), "body".into()]), + }); + assert_eq!(body["query_type"], JsonValue::String("ANN".into())); + let vec_arr = body["query_vector"].as_array().unwrap(); + assert_eq!(vec_arr.len(), 3); + assert_eq!(body["filters"]["region"], JsonValue::String("us-west".into())); + assert_eq!( + body["filters"]["tags"], + JsonValue::Array(vec![ + JsonValue::String("a".into()), + JsonValue::String("b".into()) + ]) + ); + assert_eq!( + body["reranker"]["model"], + JsonValue::String("databricks_reranker".into()) + ); + assert_eq!( + body["reranker"]["parameters"]["columns_to_rerank"], + JsonValue::Array(vec![ + JsonValue::String("title".into()), + JsonValue::String("body".into()) + ]) + ); + } + + #[test] + fn test_build_next_page_body() { + let body = build_next_page_body(&VsNextPageParams { + index_name: "x".into(), + endpoint_name: "ep".into(), + page_token: "tok".into(), + }); + assert_eq!(body["endpoint_name"], JsonValue::String("ep".into())); + assert_eq!(body["page_token"], JsonValue::String("tok".into())); + } + + #[test] + fn test_query_type_parse_and_roundtrip() { + assert_eq!(VsQueryType::parse("ann"), Some(VsQueryType::Ann)); + assert_eq!(VsQueryType::parse("Hybrid"), Some(VsQueryType::Hybrid)); + assert_eq!(VsQueryType::parse("FULL_TEXT"), Some(VsQueryType::FullText)); + 
assert_eq!(VsQueryType::parse("bogus"), None); + + assert_eq!(VsQueryType::Ann.as_wire(), "ANN"); + assert_eq!(VsQueryType::Hybrid.as_str(), "hybrid"); + } + + #[test] + fn test_parse_filters_json() { + let f = parse_filters_json(r#"{"a":"x","b":1,"c":true,"d":["p","q"]}"#).unwrap(); + let map: std::collections::HashMap = f.into_iter().collect(); + assert_eq!(map["a"], FilterValue::String("x".into())); + assert_eq!(map["b"], FilterValue::Number(1.0)); + assert_eq!(map["c"], FilterValue::Boolean(true)); + match &map["d"] { + FilterValue::Array(items) => assert_eq!(items.len(), 2), + _ => panic!("expected array"), + } + } + + #[test] + fn test_parse_filters_rejects_null_and_nested() { + assert!(parse_filters_json(r#"{"a":null}"#).is_err()); + assert!(parse_filters_json(r#"{"a":{"nested":"x"}}"#).is_err()); + assert!(parse_filters_json(r#"[1,2]"#).is_err()); // non-object + } + + #[test] + fn test_parse_vs_response_shapes_hits() { + let raw = r#"{ + "manifest": {"columns": [{"name":"id"},{"name":"title"},{"name":"score"}]}, + "result": {"row_count": 2, "data_array": [["1","hello",0.9],["2","world",0.8]]}, + "next_page_token": "tok", + "debug_info": {"response_time": 42} + }"#; + let resp = parse_vs_response(raw, VsQueryType::Hybrid).unwrap(); + assert_eq!(resp.total_count, 2); + assert_eq!(resp.query_type, "hybrid"); + assert_eq!(resp.query_time_ms, 42); + assert_eq!(resp.next_page_token.as_deref(), Some("tok")); + assert_eq!(resp.results.len(), 2); + assert_eq!(resp.results[0].score, 0.9); + assert_eq!(resp.results[0].data["id"], JsonValue::String("1".into())); + assert_eq!( + resp.results[0].data["title"], + JsonValue::String("hello".into()) + ); + assert!(resp.results[0].data.get("score").is_none()); + } + + #[test] + fn test_parse_vs_response_missing_fields() { + let raw = r#"{}"#; + let resp = parse_vs_response(raw, VsQueryType::Ann).unwrap(); + assert_eq!(resp.total_count, 0); + assert!(resp.results.is_empty()); + assert_eq!(resp.query_time_ms, 0); + } +} diff 
--git a/packages/appkit-rs/src/context.rs b/packages/appkit-rs/src/context.rs new file mode 100644 index 00000000..eb9ec333 --- /dev/null +++ b/packages/appkit-rs/src/context.rs @@ -0,0 +1,157 @@ +//! Python `contextvars`-based execution context. +//! +//! Mirrors the TypeScript `AsyncLocalStorage`-based execution context in +//! `packages/appkit/src/context/execution-context.ts`. +//! +//! Provides: +//! - `_USER_CONTEXT_VAR`: a module-level `contextvars.ContextVar` holding the +//! current `UserContext` (or `None` when running as service principal). +//! - `run_in_user_context(user_ctx, fn)`: run a sync callable with user context. +//! - `as_user(user_ctx, async_fn)`: run an async callable with user context. +//! - `get_current_user()`: retrieve the current `UserContext`, or `None`. +//! - `is_in_user_context()`: check whether a user context is active. + +use pyo3::prelude::*; + +use crate::auth::UserContext; + +// --------------------------------------------------------------------------- +// Module-level ContextVar helpers +// --------------------------------------------------------------------------- + +/// Create the `_USER_CONTEXT_VAR` ContextVar and register it on a module. +pub fn create_context_var(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { + let contextvars = py.import("contextvars")?; + let cv = + contextvars.call_method1("ContextVar", ("appkit_user_context",))?; + m.setattr("_USER_CONTEXT_VAR", cv)?; + Ok(()) +} + +/// Retrieve the `_USER_CONTEXT_VAR` from the native appkit module. 
+fn get_context_var(py: Python<'_>) -> PyResult { + let module = py.import("appkit.appkit")?; + let cv = module.getattr("_USER_CONTEXT_VAR")?; + Ok(cv.into()) +} + +// --------------------------------------------------------------------------- +// Public Python functions +// --------------------------------------------------------------------------- + +/// Run a synchronous callable with the given `UserContext` set as the current +/// execution context for the duration of the call. +/// +/// Mirrors TypeScript's `runInUserContext(userContext, fn)`. +/// +/// ```python +/// result = run_in_user_context(user_ctx, lambda: do_work()) +/// ``` +#[pyfunction] +#[pyo3(signature = (user_context, func))] +pub fn run_in_user_context( + py: Python<'_>, + user_context: UserContext, + func: PyObject, +) -> PyResult { + let cv = get_context_var(py)?; + let token = cv.call_method1(py, "set", (user_context,))?; + let result = func.call0(py); + // Always reset, even on error. + let _ = cv.call_method1(py, "reset", (token,)); + result +} + +/// Run an async callable with the given `UserContext` set for the duration. +/// +/// Returns an awaitable coroutine. The context variable is set before calling +/// `async_fn()` and reset after it completes (or raises). +/// +/// ```python +/// result = await as_user(user_ctx, my_async_fn) +/// ``` +#[pyfunction] +#[pyo3(signature = (user_context, func))] +pub fn as_user<'py>( + py: Python<'py>, + user_context: UserContext, + func: PyObject, +) -> PyResult> { + let user_ctx = Py::new(py, user_context)?; + + // Import the Python-side wrapper that sets the context var *inside* + // a native coroutine, ensuring the value propagates correctly across + // the PyO3-tokio bridge. 
+ let context_mod = py.import("appkit._context")?; + let wrapper_fn = context_mod.getattr("_as_user_wrapper")?; + + let cv = get_context_var(py)?; + let coroutine = wrapper_fn.call1((cv, user_ctx, func))?.unbind(); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let future = Python::with_gil(|py| { + pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py)) + })?; + future.await + }) +} + +/// Get the current `UserContext` from the execution context, or `None` if +/// running as service principal (no user context set). +/// +/// ```python +/// user = get_current_user() +/// if user is not None: +/// print(user.user_id) +/// ``` +#[pyfunction] +pub fn get_current_user(py: Python<'_>) -> PyResult> { + let cv = get_context_var(py)?; + // Use sentinel to detect unset: cv.get() + let none = py.None(); + let val = cv.call_method1(py, "get", (none.clone_ref(py),))?; + + if val.is_none(py) { + return Ok(None); + } + + match val.extract::(py) { + Ok(ctx) => Ok(Some(ctx)), + Err(_) => Ok(None), + } +} + +/// Check whether the current execution is running in a user context. +/// +/// ```python +/// if is_in_user_context(): +/// user = get_current_user() +/// ``` +#[pyfunction] +pub fn is_in_user_context(py: Python<'_>) -> PyResult { + Ok(get_current_user(py)?.is_some()) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_user_context_round_trip() { + // Verify UserContext can be created in Rust — Python interop is + // tested in the Python integration tests. 
+ let ctx = UserContext::new( + "tok".into(), + "u1".into(), + Some("Alice".into()), + "ws-1".into(), + None, + ); + assert_eq!(ctx.user_id, "u1"); + assert_eq!(ctx.workspace_id, "ws-1"); + } +} diff --git a/packages/appkit-rs/src/errors.rs b/packages/appkit-rs/src/errors.rs new file mode 100644 index 00000000..a1b2dbd5 --- /dev/null +++ b/packages/appkit-rs/src/errors.rs @@ -0,0 +1,272 @@ +//! Shared error hierarchy — TS-style typed errors exposed as Python exception +//! classes and a Rust `AppKitErrorKind` enum for internal classification. +//! +//! Mirrors `packages/appkit/src/errors/*.ts`: +//! `AppKitError` (base) → `ValidationError`, `AuthenticationError`, +//! `NotFoundError`, `PayloadTooLargeError`, `UpstreamError`, `TimeoutError`, +//! `ConnectionError`, `ConfigurationError`, `InternalError`. +//! +//! Each Python class inherits from `appkit.AppKitError` (which inherits from +//! `Exception`). The Rust `AppKitErrorKind` enum maps them to HTTP status codes +//! and a stable string code. Plugin execute() and route handlers classify any +//! raised exception into this hierarchy so callers see consistent HTTP +//! responses and `ExecutionResult.status` values. + +use pyo3::create_exception; +use pyo3::exceptions::{PyException, PyRuntimeError, PyValueError}; +use pyo3::prelude::*; +use pyo3::types::PyType; +use pyo3::PyTypeInfo; + +// --------------------------------------------------------------------------- +// Python exception classes +// --------------------------------------------------------------------------- + +// Base: appkit.AppKitError(Exception) +create_exception!(appkit, AppKitError, PyException); + +// Subclasses of AppKitError — kept in registration order so later lookups +// (is_instance_of) are stable across calls. 
+create_exception!(appkit, ValidationError, AppKitError); +create_exception!(appkit, AuthenticationError, AppKitError); +create_exception!(appkit, NotFoundError, AppKitError); +create_exception!(appkit, PayloadTooLargeError, AppKitError); +create_exception!(appkit, UpstreamError, AppKitError); +create_exception!(appkit, TimeoutError, AppKitError); +create_exception!(appkit, ConnectionError, AppKitError); +create_exception!(appkit, ConfigurationError, AppKitError); +create_exception!(appkit, InternalError, AppKitError); + +// --------------------------------------------------------------------------- +// AppKitErrorKind — Rust-side classification +// --------------------------------------------------------------------------- + +/// Typed error kind matching the TS error hierarchy. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum AppKitErrorKind { + Validation, + Authentication, + NotFound, + PayloadTooLarge, + Upstream, + Timeout, + Connection, + Configuration, + Internal, +} + +impl AppKitErrorKind { + /// HTTP status code matching the TS `statusCode` field. + pub fn status(&self) -> u16 { + match self { + Self::Validation => 400, + Self::Authentication => 401, + Self::NotFound => 404, + Self::PayloadTooLarge => 413, + Self::Timeout => 408, + Self::Upstream => 502, + Self::Connection => 503, + Self::Configuration => 500, + Self::Internal => 500, + } + } + + /// Stable string code matching the TS `code` field. + pub fn code(&self) -> &'static str { + match self { + Self::Validation => "VALIDATION_ERROR", + Self::Authentication => "AUTHENTICATION_ERROR", + Self::NotFound => "NOT_FOUND", + Self::PayloadTooLarge => "PAYLOAD_TOO_LARGE", + Self::Timeout => "TIMEOUT", + Self::Upstream => "UPSTREAM_ERROR", + Self::Connection => "CONNECTION_ERROR", + Self::Configuration => "CONFIGURATION_ERROR", + Self::Internal => "INTERNAL_ERROR", + } + } + + /// Classify an HTTP status code from an upstream response. 
+ pub fn from_http_status(status: u16) -> Self { + match status { + 400 => Self::Validation, + 401 | 403 => Self::Authentication, + 404 => Self::NotFound, + 408 => Self::Timeout, + 413 => Self::PayloadTooLarge, + 500..=599 => Self::Upstream, + _ => Self::Internal, + } + } + + /// Build a Python exception of the matching class with `message`. + pub fn to_py_err(&self, message: impl Into) -> PyErr { + let msg = message.into(); + match self { + Self::Validation => ValidationError::new_err(msg), + Self::Authentication => AuthenticationError::new_err(msg), + Self::NotFound => NotFoundError::new_err(msg), + Self::PayloadTooLarge => PayloadTooLargeError::new_err(msg), + Self::Timeout => TimeoutError::new_err(msg), + Self::Upstream => UpstreamError::new_err(msg), + Self::Connection => ConnectionError::new_err(msg), + Self::Configuration => ConfigurationError::new_err(msg), + Self::Internal => InternalError::new_err(msg), + } + } +} + +// --------------------------------------------------------------------------- +// Classification helpers +// --------------------------------------------------------------------------- + +/// Classify an arbitrary `PyErr` into the AppKit error hierarchy. +/// Returns `(status, code, message)` suitable for interceptor/route responses. +pub fn classify_pyerr(py: Python<'_>, err: &PyErr) -> (u16, &'static str, String) { + let kind = classify_kind(py, err); + (kind.status(), kind.code(), err.to_string()) +} + +fn classify_kind(py: Python<'_>, err: &PyErr) -> AppKitErrorKind { + // Order matters: check most specific (subclasses) first. 
+ let cases: &[(fn(Python<'_>) -> Bound<'_, PyType>, AppKitErrorKind)] = &[ + (ValidationError::type_object, AppKitErrorKind::Validation), + ( + AuthenticationError::type_object, + AppKitErrorKind::Authentication, + ), + (NotFoundError::type_object, AppKitErrorKind::NotFound), + ( + PayloadTooLargeError::type_object, + AppKitErrorKind::PayloadTooLarge, + ), + (UpstreamError::type_object, AppKitErrorKind::Upstream), + (TimeoutError::type_object, AppKitErrorKind::Timeout), + (ConnectionError::type_object, AppKitErrorKind::Connection), + ( + ConfigurationError::type_object, + AppKitErrorKind::Configuration, + ), + (InternalError::type_object, AppKitErrorKind::Internal), + (AppKitError::type_object, AppKitErrorKind::Internal), + ]; + + for (type_fn, kind) in cases { + let ty = type_fn(py); + if err.matches(py, &ty).unwrap_or(false) { + return *kind; + } + } + + // Common stdlib exceptions → sensible fallbacks. + if err.is_instance_of::(py) { + return AppKitErrorKind::Validation; + } + if err.is_instance_of::(py) { + return AppKitErrorKind::Internal; + } + AppKitErrorKind::Internal +} + +// --------------------------------------------------------------------------- +// Module registration +// --------------------------------------------------------------------------- + +pub fn register(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add("AppKitError", py.get_type::())?; + m.add("ValidationError", py.get_type::())?; + m.add("AuthenticationError", py.get_type::())?; + m.add("NotFoundError", py.get_type::())?; + m.add("PayloadTooLargeError", py.get_type::())?; + m.add("UpstreamError", py.get_type::())?; + m.add("TimeoutError", py.get_type::())?; + m.add("ConnectionError", py.get_type::())?; + m.add("ConfigurationError", py.get_type::())?; + m.add("InternalError", py.get_type::())?; + Ok(()) +} + +// --------------------------------------------------------------------------- +// Tests +// 
+// ---------------------------------------------------------------------------
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // Status codes are part of the public HTTP contract.
+    #[test]
+    fn test_status_codes() {
+        assert_eq!(AppKitErrorKind::Validation.status(), 400);
+        assert_eq!(AppKitErrorKind::Authentication.status(), 401);
+        assert_eq!(AppKitErrorKind::NotFound.status(), 404);
+        assert_eq!(AppKitErrorKind::Timeout.status(), 408);
+        assert_eq!(AppKitErrorKind::PayloadTooLarge.status(), 413);
+        assert_eq!(AppKitErrorKind::Upstream.status(), 502);
+        assert_eq!(AppKitErrorKind::Connection.status(), 503);
+        assert_eq!(AppKitErrorKind::Configuration.status(), 500);
+        assert_eq!(AppKitErrorKind::Internal.status(), 500);
+    }
+
+    #[test]
+    fn test_codes_are_stable_strings() {
+        // Downstream telemetry relies on these; if they change, update docs.
+        assert_eq!(AppKitErrorKind::Validation.code(), "VALIDATION_ERROR");
+        assert_eq!(AppKitErrorKind::NotFound.code(), "NOT_FOUND");
+        assert_eq!(AppKitErrorKind::Timeout.code(), "TIMEOUT");
+    }
+
+    #[test]
+    fn test_from_http_status() {
+        assert_eq!(AppKitErrorKind::from_http_status(400), AppKitErrorKind::Validation);
+        assert_eq!(AppKitErrorKind::from_http_status(401), AppKitErrorKind::Authentication);
+        assert_eq!(AppKitErrorKind::from_http_status(403), AppKitErrorKind::Authentication);
+        assert_eq!(AppKitErrorKind::from_http_status(404), AppKitErrorKind::NotFound);
+        assert_eq!(AppKitErrorKind::from_http_status(408), AppKitErrorKind::Timeout);
+        assert_eq!(AppKitErrorKind::from_http_status(413), AppKitErrorKind::PayloadTooLarge);
+        assert_eq!(AppKitErrorKind::from_http_status(500), AppKitErrorKind::Upstream);
+        assert_eq!(AppKitErrorKind::from_http_status(502), AppKitErrorKind::Upstream);
+        // Unmapped statuses fall back to Internal.
+        assert_eq!(AppKitErrorKind::from_http_status(418), AppKitErrorKind::Internal);
+    }
+
+    #[test]
+    fn test_classify_pyerr_validation() {
+        pyo3::prepare_freethreaded_python();
+        Python::with_gil(|py| {
+            let err = ValidationError::new_err("bad input");
+            let (status, code, msg) = classify_pyerr(py, &err);
+            assert_eq!(status, 400);
+            assert_eq!(code, "VALIDATION_ERROR");
+            assert!(msg.contains("bad input"));
+        });
+    }
+
+    #[test]
+    fn test_classify_pyerr_authentication() {
+        pyo3::prepare_freethreaded_python();
+        Python::with_gil(|py| {
+            let err = AuthenticationError::new_err("bad token");
+            let (status, _code, _msg) = classify_pyerr(py, &err);
+            assert_eq!(status, 401);
+        });
+    }
+
+    #[test]
+    fn test_classify_pyerr_value_error_falls_back_to_validation() {
+        pyo3::prepare_freethreaded_python();
+        Python::with_gil(|py| {
+            let err = PyValueError::new_err("bad");
+            let (status, _code, _msg) = classify_pyerr(py, &err);
+            assert_eq!(status, 400);
+        });
+    }
+
+    #[test]
+    fn test_classify_pyerr_generic_runtime_falls_back_to_internal() {
+        pyo3::prepare_freethreaded_python();
+        Python::with_gil(|py| {
+            let err = PyRuntimeError::new_err("unexpected");
+            let (status, _code, _msg) = classify_pyerr(py, &err);
+            assert_eq!(status, 500);
+        });
+    }
+}
diff --git a/packages/appkit-rs/src/interceptor.rs b/packages/appkit-rs/src/interceptor.rs
new file mode 100644
index 00000000..00f09774
--- /dev/null
+++ b/packages/appkit-rs/src/interceptor.rs
@@ -0,0 +1,751 @@
+//! Async middleware chain implementing the AppKit interceptor pattern.
+//!
+//! Execution order (outermost to innermost):
+//! **Telemetry → Timeout → Retry → Cache → user function**
+//!
+//! Each interceptor is a wrapping function that takes the "next" callable and
+//! returns a new callable. The chain is built bottom-up so that the outermost
+//! interceptor runs first.
+
+use crate::cache::CacheManager;
+use crate::telemetry::TelemetryProvider;
+use serde_json::Value as JsonValue;
+use std::collections::HashMap;
+use std::future::Future;
+use std::pin::Pin;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
+use std::time::Duration;
+use tokio::sync::mpsc;
+
+// ---------------------------------------------------------------------------
+// Core types
+// ---------------------------------------------------------------------------
+
+/// Error produced by the interceptor chain or user function.
+#[derive(Clone, Debug)]
+pub struct ExecutionError {
+    // HTTP-style status code carried to the caller (e.g. 408 for timeout).
+    pub status: u16,
+    pub message: String,
+}
+
+impl std::fmt::Display for ExecutionError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "[{}] {}", self.status, self.message)
+    }
+}
+
+impl std::error::Error for ExecutionError {}
+
+/// Context passed through the interceptor chain.
+pub struct InterceptorContext {
+    pub user_key: String,
+    // Shared cancellation flag: set by the timeout layer, read by retry/telemetry.
+    pub cancelled: Arc<AtomicBool>,
+    // NOTE(review): the map's value type was stripped during extraction;
+    // `String` is assumed — confirm against VCS.
+    pub metadata: HashMap<String, String>,
+}
+
+/// A callable that can be invoked multiple times (needed for retry).
+/// Each invocation produces a new future that resolves to the execution result.
+/// NOTE(review): generic parameters of both aliases were stripped during
+/// extraction and reconstructed from call sites (tests return `JsonValue`,
+/// errors are `ExecutionError`) — confirm against VCS.
+pub type BoxFuture = Pin<Box<dyn Future<Output = Result<JsonValue, ExecutionError>> + Send>>;
+pub type ExecuteFn = Arc<dyn Fn() -> BoxFuture + Send + Sync>;
+
+// ---------------------------------------------------------------------------
+// Interceptor config types
+// ---------------------------------------------------------------------------
+
+/// Retry configuration matching TS `RetryConfig`.
+#[derive(Clone, Debug)]
+pub struct RetryConfig {
+    pub enabled: bool,
+    // Total number of attempts, including the first call.
+    pub attempts: u32,
+    pub initial_delay_ms: u64,
+    pub max_delay_ms: u64,
+}
+
+impl Default for RetryConfig {
+    fn default() -> Self {
+        Self {
+            enabled: false,
+            attempts: 3,
+            initial_delay_ms: 1000,
+            max_delay_ms: 30_000,
+        }
+    }
+}
+
+/// Cache interceptor configuration.
+#[derive(Clone, Debug)]
+pub struct CacheInterceptorConfig {
+    pub enabled: bool,
+    // Parts hashed (with the user key) into the deterministic cache key.
+    // NOTE(review): element/TTL generics reconstructed from test usage
+    // (`vec!["query".into()]`, `Some(60)`) — confirm against VCS.
+    pub cache_key: Vec<String>,
+    pub ttl: Option<u64>,
+}
+
+/// Telemetry interceptor configuration.
+#[derive(Clone, Debug)]
+pub struct TelemetryInterceptorConfig {
+    pub enabled: bool,
+    // Span name; defaults to "plugin.execute" when absent.
+    pub span_name: Option<String>,
+    pub attributes: Vec<(String, String)>,
+}
+
+/// Combined configuration for a single `execute()` call.
+#[derive(Clone, Debug, Default)]
+pub struct PluginExecuteConfig {
+    pub cache: Option<CacheInterceptorConfig>,
+    pub retry: Option<RetryConfig>,
+    pub telemetry: Option<TelemetryInterceptorConfig>,
+    pub timeout_ms: Option<u64>,
+}
+
+// ---------------------------------------------------------------------------
+// Interceptor wrappers
+// ---------------------------------------------------------------------------
+
+/// Wrap with telemetry span (outermost interceptor).
+///
+/// Skips span creation entirely when traces are disabled; also short-circuits
+/// with a 499 error when the shared `cancelled` flag is already set.
+pub fn wrap_with_telemetry(
+    next: ExecuteFn,
+    telemetry: Arc<TelemetryProvider>,
+    span_name: String,
+    attributes: Vec<(String, String)>,
+    cancelled: Arc<AtomicBool>,
+) -> ExecuteFn {
+    Arc::new(move || {
+        let next = next.clone();
+        let telemetry = telemetry.clone();
+        let span_name = span_name.clone();
+        let attributes = attributes.clone();
+        let cancelled = cancelled.clone();
+
+        Box::pin(async move {
+            if cancelled.load(Ordering::Relaxed) {
+                return Err(ExecutionError {
+                    status: 499,
+                    message: "Operation aborted before execution".into(),
+                });
+            }
+
+            if !telemetry.traces_enabled() {
+                return next().await;
+            }
+
+            use opentelemetry::trace::{Span, Status, Tracer};
+            let tracer = telemetry.tracer();
+            let mut span = tracer.start(span_name);
+
+            for (k, v) in &attributes {
+                span.set_attribute(opentelemetry::KeyValue::new(k.clone(), v.clone()));
+            }
+
+            let result = next().await;
+
+            match &result {
+                Ok(_) => {
+                    span.set_status(Status::Ok);
+                }
+                Err(e) => {
+                    span.set_status(Status::error(e.message.clone()));
+                }
+            }
+            span.end();
+
+            result
+        })
+    })
+}
+
+/// Wrap with timeout (second outermost).
+pub fn wrap_with_timeout(
+    next: ExecuteFn,
+    timeout_ms: u64,
+    cancelled: Arc<AtomicBool>,
+) -> ExecuteFn {
+    Arc::new(move || {
+        let next = next.clone();
+        let cancelled = cancelled.clone();
+
+        Box::pin(async move {
+            let timeout = Duration::from_millis(timeout_ms);
+            match tokio::time::timeout(timeout, next()).await {
+                Ok(result) => result,
+                Err(_) => {
+                    // Signal inner layers (retry) to stop scheduling work.
+                    cancelled.store(true, Ordering::Relaxed);
+                    Err(ExecutionError {
+                        status: 408,
+                        message: format!("Operation timed out after {timeout_ms} ms"),
+                    })
+                }
+            }
+        })
+    })
+}
+
+/// Wrap with retry + exponential backoff with full jitter (third layer).
+///
+/// Retries up to `config.attempts` total attempts; aborts early if the shared
+/// `cancelled` flag is set (e.g. by the outer timeout).
+pub fn wrap_with_retry(
+    next: ExecuteFn,
+    config: RetryConfig,
+    cancelled: Arc<AtomicBool>,
+) -> ExecuteFn {
+    Arc::new(move || {
+        let next = next.clone();
+        let cancelled = cancelled.clone();
+        let attempts = config.attempts;
+        let initial_delay = config.initial_delay_ms;
+        let max_delay = config.max_delay_ms;
+
+        Box::pin(async move {
+            let mut last_error = None;
+
+            for attempt in 1..=attempts {
+                match next().await {
+                    Ok(value) => return Ok(value),
+                    Err(e) => {
+                        // Last attempt (or cancelled): surface the error as-is.
+                        if attempt == attempts || cancelled.load(Ordering::Relaxed) {
+                            return Err(e);
+                        }
+                        last_error = Some(e);
+                        let delay = calculate_delay(attempt, initial_delay, max_delay);
+                        tokio::time::sleep(Duration::from_millis(delay)).await;
+                    }
+                }
+            }
+
+            // Unreachable when attempts >= 1; kept as a defensive fallback.
+            Err(last_error.unwrap_or_else(|| ExecutionError {
+                status: 500,
+                message: "Retry exhausted with no error".into(),
+            }))
+        })
+    })
+}
+
+/// Exponential backoff with full jitter: `min(initial * 2^(attempt-1), max) * rand()`.
+fn calculate_delay(attempt: u32, initial_delay_ms: u64, max_delay_ms: u64) -> u64 {
+    use rand::Rng;
+    // Shift capped at 30 to avoid overflowing the u64 multiplier.
+    let exp = initial_delay_ms.saturating_mul(1u64 << (attempt - 1).min(30));
+    let capped = exp.min(max_delay_ms);
+    // Full jitter in [0, 1): result may be 0, which is acceptable.
+    let jitter: f64 = rand::thread_rng().gen();
+    (capped as f64 * jitter) as u64
+}
+
+/// Wrap with cache (innermost interceptor).
+pub fn wrap_with_cache(
+    next: ExecuteFn,
+    cache: Arc<CacheManager>,
+    cache_key: Vec<String>,
+    user_key: String,
+    ttl: Option<u64>,
+    enabled: bool,
+) -> ExecuteFn {
+    // No caching requested: return the inner callable untouched.
+    if !enabled || cache_key.is_empty() {
+        return next;
+    }
+
+    // Pre-compute the deterministic cache key.
+    let refs: Vec<&str> = cache_key.iter().map(|s| s.as_str()).collect();
+    let key = CacheManager::generate_key_from_parts(&refs, &user_key);
+
+    Arc::new(move || {
+        let next = next.clone();
+        let cache = cache.clone();
+        let key = key.clone();
+
+        Box::pin(async move {
+            // NOTE(review): errors crossing the cache boundary are flattened
+            // to strings, so the original status is lost and reported as 500 —
+            // confirm this is intentional.
+            cache
+                .get_or_execute_internal(
+                    key,
+                    move || {
+                        let fut = next();
+                        async move { fut.await.map_err(|e| e.message) }
+                    },
+                    ttl,
+                )
+                .await
+                .map_err(|msg| ExecutionError {
+                    status: 500,
+                    message: msg,
+                })
+        })
+    })
+}
+
+// ---------------------------------------------------------------------------
+// Chain builder
+// ---------------------------------------------------------------------------
+
+/// Build the full interceptor chain around `base_fn`.
+///
+/// Wrapping order (innermost first):
+///   Cache → Retry → Timeout → Telemetry
+///
+/// Each layer is conditionally applied based on the config.
+pub fn build_interceptor_chain(
+    base_fn: ExecuteFn,
+    config: &PluginExecuteConfig,
+    context: &InterceptorContext,
+    cache: Option<Arc<CacheManager>>,
+    telemetry: Option<Arc<TelemetryProvider>>,
+) -> ExecuteFn {
+    let mut current = base_fn;
+
+    // 1. Innermost: Cache
+    if let (Some(cache), Some(ref cc)) = (cache, &config.cache) {
+        current = wrap_with_cache(
+            current,
+            cache,
+            cc.cache_key.clone(),
+            context.user_key.clone(),
+            cc.ttl,
+            cc.enabled,
+        );
+    }
+
+    // 2. Retry
+    if let Some(ref rc) = config.retry {
+        if rc.enabled {
+            current = wrap_with_retry(current, rc.clone(), context.cancelled.clone());
+        }
+    }
+
+    // 3. Timeout
+    if let Some(timeout_ms) = config.timeout_ms {
+        current = wrap_with_timeout(current, timeout_ms, context.cancelled.clone());
+    }
+
+    // 4. Outermost: Telemetry
+    if let Some(telemetry) = telemetry {
+        if let Some(ref tc) = config.telemetry {
+            if tc.enabled {
+                current = wrap_with_telemetry(
+                    current,
+                    telemetry,
+                    tc.span_name
+                        .clone()
+                        .unwrap_or_else(|| "plugin.execute".into()),
+                    tc.attributes.clone(),
+                    context.cancelled.clone(),
+                );
+            }
+        }
+    }
+
+    current
+}
+
+// ---------------------------------------------------------------------------
+// Stream interceptors
+// ---------------------------------------------------------------------------
+
+/// Item type for streaming through interceptors.
+/// Each item is either a JSON string payload or an execution error.
+/// NOTE(review): generic parameters reconstructed from test usage
+/// (`tx.send(Ok("hello".into()))`) — confirm against VCS.
+pub type StreamItem = Result<String, ExecutionError>;
+
+/// Wrap a stream with a timeout on the full stream lifetime.
+/// Sends a final timeout error if the deadline is exceeded.
+pub fn wrap_stream_with_timeout(
+    mut rx: mpsc::Receiver<StreamItem>,
+    timeout_ms: u64,
+) -> mpsc::Receiver<StreamItem> {
+    let (tx, out_rx) = mpsc::channel(32);
+    tokio::spawn(async move {
+        // Fixed deadline covering the whole stream, not per-item.
+        let deadline = tokio::time::Instant::now() + Duration::from_millis(timeout_ms);
+        loop {
+            tokio::select! {
+                item = rx.recv() => {
+                    match item {
+                        Some(val) => {
+                            // Receiver dropped: stop forwarding.
+                            if tx.send(val).await.is_err() { break; }
+                        }
+                        None => break,
+                    }
+                }
+                _ = tokio::time::sleep_until(deadline) => {
+                    let _ = tx.send(Err(ExecutionError {
+                        status: 408,
+                        message: format!("Stream timed out after {timeout_ms} ms"),
+                    })).await;
+                    break;
+                }
+            }
+        }
+    });
+    out_rx
+}
+
+/// Wrap a stream with telemetry: a span covers the full stream lifetime.
+/// The span is started immediately and ended when the stream completes or errors.
+pub fn wrap_stream_with_telemetry(
+    mut rx: mpsc::Receiver<StreamItem>,
+    telemetry: Arc<TelemetryProvider>,
+    span_name: String,
+    attributes: Vec<(String, String)>,
+) -> mpsc::Receiver<StreamItem> {
+    // Tracing disabled: pass the receiver through untouched.
+    if !telemetry.traces_enabled() {
+        return rx;
+    }
+
+    let (tx, out_rx) = mpsc::channel(32);
+    tokio::spawn(async move {
+        use opentelemetry::trace::{Span, Status, Tracer};
+        let tracer = telemetry.tracer();
+        let mut span = tracer.start(span_name);
+        for (k, v) in &attributes {
+            span.set_attribute(opentelemetry::KeyValue::new(k.clone(), v.clone()));
+        }
+
+        let mut had_error = false;
+        while let Some(item) = rx.recv().await {
+            if let Err(ref e) = item {
+                // Record the (last) error on the span; keep forwarding items.
+                span.set_status(Status::error(e.message.clone()));
+                had_error = true;
+            }
+            if tx.send(item).await.is_err() {
+                break;
+            }
+        }
+
+        if !had_error {
+            span.set_status(Status::Ok);
+        }
+        span.end();
+    });
+    out_rx
+}
+
+/// Build the stream interceptor chain.
+///
+/// For streams, only telemetry and timeout are applied:
+/// - **Timeout** (inner) — caps the total stream lifetime
+/// - **Telemetry** (outer) — spans the full stream lifetime
+/// - **Retry/Cache** are intentionally skipped (streams are non-repeatable)
+pub fn build_stream_interceptor_chain(
+    rx: mpsc::Receiver<StreamItem>,
+    config: &PluginExecuteConfig,
+    telemetry: Option<Arc<TelemetryProvider>>,
+) -> mpsc::Receiver<StreamItem> {
+    let mut current = rx;
+
+    // 1. Timeout (inner — fires first)
+    if let Some(timeout_ms) = config.timeout_ms {
+        current = wrap_stream_with_timeout(current, timeout_ms);
+    }
+
+    // 2. Telemetry (outer — spans the full lifetime including timeout)
+    if let Some(telemetry) = telemetry {
+        if let Some(ref tc) = config.telemetry {
+            if tc.enabled {
+                current = wrap_stream_with_telemetry(
+                    current,
+                    telemetry,
+                    tc.span_name
+                        .clone()
+                        .unwrap_or_else(|| "plugin.execute_stream".into()),
+                    tc.attributes.clone(),
+                );
+            }
+        }
+    }
+
+    current
+}
+
+// ---------------------------------------------------------------------------
+// Tests
+// ---------------------------------------------------------------------------
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // Execute-fn that always succeeds with `value`.
+    fn ok_fn(value: JsonValue) -> ExecuteFn {
+        Arc::new(move || {
+            let v = value.clone();
+            Box::pin(async move { Ok(v) })
+        })
+    }
+
+    // Execute-fn that always fails with the given status/message.
+    fn err_fn(status: u16, msg: &str) -> ExecuteFn {
+        let msg = msg.to_string();
+        Arc::new(move || {
+            let msg = msg.clone();
+            Box::pin(async move {
+                Err(ExecutionError {
+                    status,
+                    message: msg,
+                })
+            })
+        })
+    }
+
+    // Execute-fn that counts invocations, then succeeds with `value`.
+    fn counting_fn(
+        counter: Arc<std::sync::atomic::AtomicU32>,
+        value: JsonValue,
+    ) -> ExecuteFn {
+        Arc::new(move || {
+            let counter = counter.clone();
+            let v = value.clone();
+            Box::pin(async move {
+                counter.fetch_add(1, Ordering::SeqCst);
+                Ok(v)
+            })
+        })
+    }
+
+    fn make_context(user_key: &str) -> InterceptorContext {
+        InterceptorContext {
+            user_key: user_key.to_string(),
+            cancelled: Arc::new(AtomicBool::new(false)),
+            metadata: HashMap::new(),
+        }
+    }
+
+    #[tokio::test]
+    async fn test_no_interceptors_passthrough() {
+        let f = ok_fn(JsonValue::String("hello".into()));
+        let ctx = make_context("u1");
+        let config = PluginExecuteConfig::default();
+        let chain = build_interceptor_chain(f, &config, &ctx, None, None);
+        let result = chain().await.unwrap();
+        assert_eq!(result, JsonValue::String("hello".into()));
+    }
+
+    #[tokio::test]
+    async fn test_timeout_passes_when_fast() {
+        let f = ok_fn(JsonValue::Bool(true));
+        let cancelled = Arc::new(AtomicBool::new(false));
+        let wrapped = wrap_with_timeout(f, 5000, cancelled);
+        let result = wrapped().await.unwrap();
+        assert_eq!(result, JsonValue::Bool(true));
+    }
+
+    #[tokio::test]
+    async fn test_timeout_fires() {
+        let slow_fn: ExecuteFn = Arc::new(|| {
+            Box::pin(async {
+                tokio::time::sleep(Duration::from_secs(10)).await;
+                Ok(JsonValue::Null)
+            })
+        });
+        let cancelled = Arc::new(AtomicBool::new(false));
+        let wrapped = wrap_with_timeout(slow_fn, 50, cancelled.clone());
+        let result = wrapped().await;
+        assert!(result.is_err());
+        let err = result.unwrap_err();
+        assert_eq!(err.status, 408);
+        assert!(cancelled.load(Ordering::Relaxed));
+    }
+
+    #[tokio::test]
+    async fn test_retry_succeeds_on_first_attempt() {
+        let counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
+        let f = counting_fn(counter.clone(), JsonValue::Number(42.into()));
+        let cancelled = Arc::new(AtomicBool::new(false));
+        let config = RetryConfig {
+            enabled: true,
+            attempts: 3,
+            initial_delay_ms: 10,
+            max_delay_ms: 100,
+        };
+        let wrapped = wrap_with_retry(f, config, cancelled);
+        let result = wrapped().await.unwrap();
+        assert_eq!(result, JsonValue::Number(42.into()));
+        assert_eq!(counter.load(Ordering::SeqCst), 1);
+    }
+
+    #[tokio::test]
+    async fn test_retry_retries_on_failure() {
+        let attempt = Arc::new(std::sync::atomic::AtomicU32::new(0));
+        let attempt_c = attempt.clone();
+        let f: ExecuteFn = Arc::new(move || {
+            let attempt = attempt_c.clone();
+            Box::pin(async move {
+                let n = attempt.fetch_add(1, Ordering::SeqCst);
+                if n < 2 {
+                    Err(ExecutionError {
+                        status: 500,
+                        message: "transient".into(),
+                    })
+                } else {
+                    Ok(JsonValue::String("recovered".into()))
+                }
+            })
+        });
+        let config = RetryConfig {
+            enabled: true,
+            attempts: 5,
+            initial_delay_ms: 1,
+            max_delay_ms: 10,
+        };
+        let cancelled = Arc::new(AtomicBool::new(false));
+        let wrapped = wrap_with_retry(f, config, cancelled);
+        let result = wrapped().await.unwrap();
+        assert_eq!(result, JsonValue::String("recovered".into()));
+        assert_eq!(attempt.load(Ordering::SeqCst), 3);
+    }
+
+    #[tokio::test]
+    async fn test_retry_exhausted() {
+        let f = err_fn(503, "down");
+        let config = RetryConfig {
+            enabled: true,
+            attempts: 2,
+            initial_delay_ms: 1,
+            max_delay_ms: 10,
+        };
+        let cancelled = Arc::new(AtomicBool::new(false));
+        let wrapped = wrap_with_retry(f, config, cancelled);
+        let result = wrapped().await;
+        assert!(result.is_err());
+        assert_eq!(result.unwrap_err().status, 503);
+    }
+
+    #[tokio::test]
+    async fn test_retry_skips_when_cancelled() {
+        let counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
+        let counter_c = counter.clone();
+        let f: ExecuteFn = Arc::new(move || {
+            let counter = counter_c.clone();
+            Box::pin(async move {
+                counter.fetch_add(1, Ordering::SeqCst);
+                Err(ExecutionError {
+                    status: 500,
+                    message: "fail".into(),
+                })
+            })
+        });
+        let cancelled = Arc::new(AtomicBool::new(true)); // pre-cancelled
+        let config = RetryConfig {
+            enabled: true,
+            attempts: 5,
+            initial_delay_ms: 1,
+            max_delay_ms: 10,
+        };
+        let wrapped = wrap_with_retry(f, config, cancelled);
+        let result = wrapped().await;
+        assert!(result.is_err());
+        // Should stop after first attempt because cancelled
+        assert_eq!(counter.load(Ordering::SeqCst), 1);
+    }
+
+    #[tokio::test]
+    async fn test_full_chain_cache_then_retry() {
+        let cache = Arc::new(CacheManager::new_internal(
+            crate::cache::CacheConfig::default(),
+        ));
+        let counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
+        let counter_c = counter.clone();
+        let f: ExecuteFn = Arc::new(move || {
+            let counter = counter_c.clone();
+            Box::pin(async move {
+                counter.fetch_add(1, Ordering::SeqCst);
+                Ok(JsonValue::String("computed".into()))
+            })
+        });
+
+        let ctx = make_context("user-1");
+        let config = PluginExecuteConfig {
+            cache: Some(CacheInterceptorConfig {
+                enabled: true,
+                cache_key: vec!["query".into()],
+                ttl: Some(60),
+            }),
+            retry: Some(RetryConfig {
+                enabled: true,
+                attempts: 3,
+                initial_delay_ms: 1,
+                max_delay_ms: 10,
+            }),
+            ..Default::default()
+        };
+
+        let chain = build_interceptor_chain(f, &config, &ctx, Some(cache), None);
+
+        // First call computes
+        let r1 = chain().await.unwrap();
+        assert_eq!(r1, JsonValue::String("computed".into()));
+        assert_eq!(counter.load(Ordering::SeqCst), 1);
+
+        // Second call hits cache — function not called again
+        let r2 = chain().await.unwrap();
+        assert_eq!(r2, JsonValue::String("computed".into()));
+        assert_eq!(counter.load(Ordering::SeqCst), 1);
+    }
+
+    #[test]
+    fn test_calculate_delay_bounded() {
+        for _ in 0..100 {
+            let d = calculate_delay(1, 1000, 30_000);
+            assert!(d <= 1000);
+            let d = calculate_delay(5, 1000, 30_000);
+            assert!(d <= 30_000);
+        }
+    }
+
+    #[tokio::test]
+    async fn test_telemetry_disabled_passthrough() {
+        let f = ok_fn(JsonValue::Number(7.into()));
+        let provider = Arc::new(TelemetryProvider::new_disabled("test"));
+        let cancelled = Arc::new(AtomicBool::new(false));
+        let wrapped = wrap_with_telemetry(f, provider, "span".into(), vec![], cancelled);
+        let result = wrapped().await.unwrap();
+        assert_eq!(result, JsonValue::Number(7.into()));
+    }
+
+    // -- Stream interceptors --
+
+    #[tokio::test]
+    async fn test_stream_timeout_fires() {
+        let (tx, rx) = mpsc::channel(32);
+        let mut wrapped = wrap_stream_with_timeout(rx, 50);
+
+        // Keep tx alive but don't send — timeout should fire.
+        tokio::spawn(async move {
+            tokio::time::sleep(Duration::from_secs(10)).await;
+            drop(tx);
+        });
+
+        let item = wrapped.recv().await.unwrap();
+        assert!(item.is_err());
+        assert_eq!(item.unwrap_err().status, 408);
+    }
+
+    #[tokio::test]
+    async fn test_stream_timeout_passes_when_fast() {
+        let (tx, rx) = mpsc::channel(32);
+        let mut wrapped = wrap_stream_with_timeout(rx, 5000);
+
+        tx.send(Ok("hello".into())).await.unwrap();
+        drop(tx);
+
+        let item = wrapped.recv().await.unwrap();
+        assert_eq!(item.unwrap(), "hello");
+        assert!(wrapped.recv().await.is_none());
+    }
+
+    #[tokio::test]
+    async fn test_stream_telemetry_disabled_passthrough() {
+        let provider = Arc::new(TelemetryProvider::new_disabled("test"));
+        let (tx, rx) = mpsc::channel(32);
+        let mut wrapped = wrap_stream_with_telemetry(rx, provider, "test".into(), vec![]);
+
+        tx.send(Ok("data".into())).await.unwrap();
+        drop(tx);
+
+        let item = wrapped.recv().await.unwrap();
+        assert_eq!(item.unwrap(), "data");
+        assert!(wrapped.recv().await.is_none());
+    }
+
+    #[tokio::test]
+    async fn test_build_stream_chain_no_config() {
+        let (tx, rx) = mpsc::channel(32);
+        let config = PluginExecuteConfig::default();
+        let mut wrapped = build_stream_interceptor_chain(rx, &config, None);
+
+        tx.send(Ok("item".into())).await.unwrap();
+        drop(tx);
+
+        let item = wrapped.recv().await.unwrap();
+        assert_eq!(item.unwrap(), "item");
+    }
+}
diff --git a/packages/appkit-rs/src/lib.rs b/packages/appkit-rs/src/lib.rs
new file mode 100644
index 00000000..08d610d4
--- /dev/null
+++ b/packages/appkit-rs/src/lib.rs
@@ -0,0 +1,173 @@
+use pyo3::prelude::*;
+
+pub mod auth;
+pub mod cache;
+pub mod config;
+pub mod connectors;
+pub mod context;
+pub mod errors;
+pub mod interceptor;
+pub mod plugin;
+pub mod plugins;
+pub mod server;
+pub mod stream;
+pub mod telemetry;
+
+// ---------------------------------------------------------------------------
+// Top-level create_app() convenience function
+//
---------------------------------------------------------------------------
+
+/// Create and initialize an AppKit instance in one call.
+///
+/// This is the primary public API — mirrors TypeScript's `createApp(...)`.
+///
+/// ```python
+/// app = await create_app(
+/// config=AppConfig.from_env(),
+/// plugins=[my_plugin],
+/// cache_config=CacheConfig(ttl=600),
+/// auto_start=True,
+/// )
+/// ```
+///
+/// Steps:
+/// 1. Creates an `AppKit` instance
+/// 2. Registers all provided plugins
+/// 3. Initializes (telemetry, cache, phase-ordered plugin setup)
+/// 4. Optionally starts the HTTP server (when `auto_start=True`)
+///
+/// Returns the initialized `AppKit` instance.
+// NOTE(review): every generic parameter in this function (element type of
+// `Vec`, payloads of `Option`, `PyResult`, `extract::()`, `add_class::()`
+// below) was stripped by text extraction — `Vec,` / `Option,` / `PyResult> {`
+// are not valid Rust. The missing type names are not recoverable from this
+// chunk; restore them from VCS before applying this patch.
+#[pyfunction]
+#[pyo3(signature = (*, config, plugins = vec![], cache_config = None, server_config = None, auto_start = true))]
+fn create_app<'py>(
+    py: Python<'py>,
+    config: config::AppConfig,
+    plugins: Vec,
+    cache_config: Option,
+    server_config: Option,
+    auto_start: bool,
+) -> PyResult> {
+    // Explicit server config wins; otherwise derive one from the app config.
+    let server_cfg = server_config.unwrap_or_else(|| server::PyServerConfig {
+        host: config.host.clone(),
+        port: config.app_port,
+        auto_start,
+        static_path: None,
+    });
+    let should_start = server_cfg.auto_start;
+
+    // Build the AppKit, register plugins synchronously.
+    let mut app = plugin::PyAppKit::new();
+    for p in plugins {
+        app.register(p)?;
+    }
+
+    let app_obj = Py::new(py, app)?;
+
+    // Call initialize (returns an awaitable coroutine).
+    let init_coro: PyObject = {
+        let mut app_mut = app_obj.borrow_mut(py);
+        let coro = app_mut.initialize(py, config, cache_config)?;
+        coro.unbind()
+    };
+
+    let app_clone = app_obj.clone_ref(py);
+    let server_cfg = Py::new(py, server_cfg)?;
+
+    pyo3_async_runtimes::tokio::future_into_py(py, async move {
+        // Await initialization.
+        let init_future = Python::with_gil(|py| {
+            pyo3_async_runtimes::tokio::into_future(init_coro.into_bound(py))
+        })?;
+        init_future.await?;
+
+        // Start server if auto_start is enabled.
+        if should_start {
+            let server_future = Python::with_gil(|py| -> PyResult<_> {
+                let cfg = server_cfg.extract::(py)?;
+                let app = app_clone.borrow(py);
+                let coro = app.start_server(py, cfg)?;
+                pyo3_async_runtimes::tokio::into_future(coro)
+            })?;
+            server_future.await?;
+        }
+
+        Ok(app_clone)
+    })
+}
+
+/// Python module entry point for `appkit`.
+///
+/// Exposes config, auth, cache, telemetry, plugin, server, streaming,
+/// connector, and context types to Python. Async methods are bridged via
+/// pyo3-async-runtimes + tokio so they can be awaited from Python's asyncio
+/// event loop.
+#[pymodule]
+fn appkit(m: &Bound<'_, PyModule>) -> PyResult<()> {
+    // Initialize the Tokio runtime for pyo3-async-runtimes so that
+    // future_into_py-backed async methods work when called from Python.
+    let mut builder = tokio::runtime::Builder::new_multi_thread();
+    builder.enable_all();
+    pyo3_async_runtimes::tokio::init(builder);
+
+    // NOTE(review): all `add_class` turbofish type parameters below were lost
+    // in extraction; only the section comments remain as a guide. Restore
+    // each concrete class name from VCS.
+    // Config
+    m.add_class::()?;
+
+    // Auth / context
+    m.add_class::()?;
+    m.add_class::()?;
+
+    // Cache
+    m.add_class::()?;
+    m.add_class::()?;
+
+    // Plugin system
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+
+    // Server / routing
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+
+    // Connectors
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+
+    // Top-level create_app function
+    m.add_function(wrap_pyfunction!(create_app, m)?)?;
+
+    // Context helpers
+    m.add_function(wrap_pyfunction!(context::run_in_user_context, m)?)?;
+    m.add_function(wrap_pyfunction!(context::as_user, m)?)?;
+    m.add_function(wrap_pyfunction!(context::get_current_user, m)?)?;
+    m.add_function(wrap_pyfunction!(context::is_in_user_context, m)?)?;
+
+    // Error hierarchy — Python exception classes (AppKitError + subclasses).
+    errors::register(m.py(), m)?;
+
+    // Initialize the contextvars.ContextVar on the module.
+    context::create_context_var(m.py(), m)?;
+
+    Ok(())
+}
diff --git a/packages/appkit-rs/src/plugin.rs b/packages/appkit-rs/src/plugin.rs
new file mode 100644
index 00000000..e4e11f6f
--- /dev/null
+++ b/packages/appkit-rs/src/plugin.rs
@@ -0,0 +1,1332 @@
+//! Plugin system — trait, manifest, phase ordering, execution runtime, and
+//! Python base class.
+//!
+//! Mirrors the TypeScript plugin architecture:
+//! - Three-phase init: Core → Normal → Deferred
+//! - Plugin trait with `setup()`, `exports()`, `client_config()`
+//! - `PluginRuntime` providing `execute()` / `execute_stream()` through the
+//! interceptor chain
+//! - `PyPlugin` subclassable Python base class
+//! - `PyAppKit` orchestrator for plugin registration and initialization
+
+use pyo3::prelude::*;
+use serde_json::Value as JsonValue;
+use std::collections::HashMap;
+use std::sync::atomic::AtomicBool;
+use std::sync::Arc;
+
+use crate::cache::{CacheConfig, CacheManager};
+use crate::errors::classify_pyerr;
+use crate::interceptor::{
+    build_interceptor_chain, build_stream_interceptor_chain, ExecuteFn, ExecutionError,
+    InterceptorContext, PluginExecuteConfig, StreamItem,
+};
+use tokio::sync::mpsc;
+use crate::telemetry::{TelemetryManager, TelemetryOptions, TelemetryProvider};
+
+// ---------------------------------------------------------------------------
+// Plugin types
+// ---------------------------------------------------------------------------
+
+/// Plugin initialization phase ordering.
+/// Core plugins initialize first, then Normal, then Deferred.
+#[derive(Clone, Debug, Default, PartialEq, Eq)]
+pub enum PluginPhase {
+    Core,
+    #[default]
+    Normal,
+    Deferred,
+}
+
+impl std::str::FromStr for PluginPhase {
+    type Err = ();
+
+    /// Parse the lowercase string form; any other input is an error.
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s {
+            "core" => Ok(Self::Core),
+            "normal" => Ok(Self::Normal),
+            "deferred" => Ok(Self::Deferred),
+            _ => Err(()),
+        }
+    }
+}
+
+impl PluginPhase {
+    /// Stable string form (inverse of `FromStr`).
+    pub fn as_str(&self) -> &'static str {
+        match self {
+            Self::Core => "core",
+            Self::Normal => "normal",
+            Self::Deferred => "deferred",
+        }
+    }
+
+    /// Numeric ordering used to sort plugins before initialization.
+    fn order(&self) -> u8 {
+        match self {
+            Self::Core => 0,
+            Self::Normal => 1,
+            Self::Deferred => 2,
+        }
+    }
+}
+
+/// Resource requirement declared in a plugin manifest.
+#[derive(Clone, Debug)]
+pub struct ResourceRequirement {
+    pub resource_type: String,
+    pub required: bool,
+}
+
+/// Plugin manifest — metadata and resource declarations.
+/// NOTE(review): the `Option`/`Vec` payload types were stripped during
+/// extraction; reconstructed as the obvious `String`/`ResourceRequirement` —
+/// confirm against VCS.
+#[derive(Clone, Debug)]
+pub struct PluginManifest {
+    pub name: String,
+    pub display_name: Option<String>,
+    pub description: Option<String>,
+    pub required_resources: Vec<ResourceRequirement>,
+    pub optional_resources: Vec<ResourceRequirement>,
+}
+
+// ---------------------------------------------------------------------------
+// ExecutionResult
+// ---------------------------------------------------------------------------
+
+/// Discriminated result type matching TypeScript's `ExecutionResult`.
+/// Plugin execute() never throws; it returns `Ok` or `Err` variant.
+#[derive(Clone, Debug)]
+pub enum ExecutionResult {
+    Ok { data: JsonValue },
+    Err { status: u16, message: String },
+}
+
+impl ExecutionResult {
+    /// True for the success variant.
+    pub fn is_ok(&self) -> bool {
+        matches!(self, Self::Ok { .. })
+    }
+
+    /// Success payload, or `None` for the error variant.
+    pub fn data(&self) -> Option<&JsonValue> {
+        match self {
+            Self::Ok { data } => Some(data),
+            _ => None,
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Plugin trait (Rust-internal)
+// ---------------------------------------------------------------------------
+
+/// Trait for Rust-implemented plugins.
+pub trait Plugin: Send + Sync {
+    /// Unique plugin name (used for telemetry and registration).
+    fn name(&self) -> &str;
+
+    /// Initialization phase; Normal unless overridden.
+    fn phase(&self) -> PluginPhase {
+        PluginPhase::Normal
+    }
+
+    fn manifest(&self) -> &PluginManifest;
+
+    /// Called during AppKit initialization after plugin construction.
+    /// NOTE(review): the boxed-future generics were stripped during
+    /// extraction; `PyResult<()>` is assumed as the output type (the crate's
+    /// prevailing error type) — confirm against VCS.
+    fn setup(
+        &self,
+    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = PyResult<()>> + Send + '_>> {
+        Box::pin(async { Ok(()) })
+    }
+
+    /// Return the public API surface of this plugin.
+    /// NOTE(review): map generics reconstructed as `String → JsonValue` —
+    /// confirm against VCS.
+    fn exports(&self) -> HashMap<String, JsonValue> {
+        HashMap::new()
+    }
+
+    /// Return startup config that is sent to the client.
+    fn client_config(&self) -> HashMap<String, JsonValue> {
+        HashMap::new()
+    }
+}
+
+// ---------------------------------------------------------------------------
+// PluginRuntime — shared execution infrastructure
+// ---------------------------------------------------------------------------
+
+/// Shared execution infrastructure available to every plugin.
+/// Provides `execute()` that runs user functions through the interceptor chain.
+pub struct PluginRuntime {
+    pub name: String,
+    pub cache: Arc<CacheManager>,
+    pub telemetry: Arc<TelemetryProvider>,
+}
+
+impl PluginRuntime {
+    /// Build a runtime for plugin `name`, sharing the app-wide cache and a
+    /// per-plugin telemetry provider.
+    pub fn new(
+        name: &str,
+        cache: Arc<CacheManager>,
+        telemetry_options: Option<TelemetryOptions>,
+    ) -> Self {
+        Self {
+            name: name.to_string(),
+            cache,
+            telemetry: Arc::new(TelemetryManager::get_provider(name, telemetry_options)),
+        }
+    }
+
+    /// Execute a function through the full interceptor chain.
+    ///
+    /// Never panics — all errors are captured into `ExecutionResult::Err`.
+    pub async fn execute<F, Fut>(
+        &self,
+        f: F,
+        config: PluginExecuteConfig,
+        user_key: &str,
+    ) -> ExecutionResult
+    where
+        F: Fn() -> Fut + Send + Sync + 'static,
+        Fut: std::future::Future<Output = Result<JsonValue, ExecutionError>> + Send + 'static,
+    {
+        let context = InterceptorContext {
+            user_key: user_key.to_string(),
+            cancelled: Arc::new(AtomicBool::new(false)),
+            metadata: HashMap::new(),
+        };
+
+        // Adapt the user closure into the chain's `ExecuteFn` shape.
+        let base_fn: ExecuteFn = Arc::new(move || Box::pin(f()));
+
+        let chain = build_interceptor_chain(
+            base_fn,
+            &config,
+            &context,
+            Some(self.cache.clone()),
+            Some(self.telemetry.clone()),
+        );
+
+        match chain().await {
+            Ok(data) => ExecutionResult::Ok { data },
+            Err(e) => ExecutionResult::Err {
+                status: e.status,
+                message: e.message,
+            },
+        }
+    }
+
+    /// Apply stream interceptors to an item receiver.
+    ///
+    /// For streams, only telemetry and timeout are applied:
+    /// - Telemetry spans the full stream lifetime
+    /// - Timeout applies to the total stream duration
+    /// - Retry and cache are intentionally skipped (streams are non-repeatable)
+    pub fn wrap_stream(
+        &self,
+        input: mpsc::Receiver<StreamItem>,
+        config: &PluginExecuteConfig,
+    ) -> mpsc::Receiver<StreamItem> {
+        build_stream_interceptor_chain(input, config, Some(self.telemetry.clone()))
+    }
+}
+
+// ===========================================================================
+// Python bindings
+// ===========================================================================
+
+// ---------------------------------------------------------------------------
+// PyPluginPhase — class attributes for phase constants
+// ---------------------------------------------------------------------------
+
+/// Phase ordering constants for Python plugins.
+#[pyclass(frozen, name = "PluginPhase", module = "appkit")] +pub struct PyPluginPhase; + +#[pymethods] +impl PyPluginPhase { + #[classattr] + const CORE: &'static str = "core"; + #[classattr] + const NORMAL: &'static str = "normal"; + #[classattr] + const DEFERRED: &'static str = "deferred"; +} + +// --------------------------------------------------------------------------- +// PyExecutionResult +// --------------------------------------------------------------------------- + +/// Python-facing execution result (frozen, immutable). +#[pyclass(frozen, name = "ExecutionResult", module = "appkit")] +#[derive(Clone)] +pub struct PyExecutionResult { + #[pyo3(get)] + pub ok: bool, + /// JSON string of the result data (only set when ok=True). + #[pyo3(get)] + pub data: Option, + /// HTTP status code (only set when ok=False). + #[pyo3(get)] + pub status: Option, + /// Error message (only set when ok=False). + #[pyo3(get)] + pub message: Option, +} + +#[pymethods] +impl PyExecutionResult { + fn __repr__(&self) -> String { + if self.ok { + format!("ExecutionResult(ok=True, data={:?})", self.data) + } else { + format!( + "ExecutionResult(ok=False, status={:?}, message={:?})", + self.status, self.message + ) + } + } + + fn __bool__(&self) -> bool { + self.ok + } + + fn __eq__(&self, other: &Self) -> bool { + self.ok == other.ok + && self.data == other.data + && self.status == other.status + && self.message == other.message + } + + fn __hash__(&self) -> u64 { + use std::hash::{Hash, Hasher}; + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + self.ok.hash(&mut hasher); + self.data.hash(&mut hasher); + self.status.hash(&mut hasher); + hasher.finish() + } +} + +impl From for PyExecutionResult { + fn from(result: ExecutionResult) -> Self { + match result { + ExecutionResult::Ok { data } => PyExecutionResult { + ok: true, + data: Some(data.to_string()), + status: None, + message: None, + }, + ExecutionResult::Err { status, message } => PyExecutionResult { + ok: 
false, + data: None, + status: Some(status), + message: Some(message), + }, + } + } +} + +// --------------------------------------------------------------------------- +// PyPluginManifest +// --------------------------------------------------------------------------- + +/// Plugin manifest exposed to Python. +#[pyclass(frozen, name = "PluginManifest", module = "appkit")] +#[derive(Clone)] +pub struct PyPluginManifest { + #[pyo3(get)] + pub name: String, + #[pyo3(get)] + pub display_name: Option, + #[pyo3(get)] + pub description: Option, +} + +#[pymethods] +impl PyPluginManifest { + #[new] + #[pyo3(signature = (name, *, display_name = None, description = None))] + fn new(name: String, display_name: Option, description: Option) -> Self { + Self { + name, + display_name, + description, + } + } + + fn __repr__(&self) -> String { + format!("PluginManifest(name={:?})", self.name) + } + + fn __eq__(&self, other: &Self) -> bool { + self.name == other.name + && self.display_name == other.display_name + && self.description == other.description + } + + fn __hash__(&self) -> u64 { + use std::hash::{Hash, Hasher}; + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + self.name.hash(&mut hasher); + hasher.finish() + } +} + +// --------------------------------------------------------------------------- +// PyPlugin — subclassable base class +// --------------------------------------------------------------------------- + +/// Base class for Python plugins. Subclass this to create custom plugins. 
+/// +/// ```python +/// from appkit import Plugin, PluginManifest +/// +/// class MyPlugin(Plugin): +/// def __init__(self): +/// super().__init__("my-plugin", manifest=PluginManifest("my-plugin")) +/// +/// async def setup(self): +/// pass # initialization logic +/// ``` +#[pyclass(subclass, name = "Plugin", module = "appkit")] +pub struct PyPlugin { + #[pyo3(get)] + name: String, + #[pyo3(get)] + phase: String, + #[pyo3(get)] + manifest: PyPluginManifest, + #[pyo3(get)] + is_ready: bool, + /// Rust-internal runtime — set by AppKit during initialization. + runtime: Option, +} + +#[pymethods] +impl PyPlugin { + /// Construct a `PyPlugin`. Accepts any `*args`/`**kwargs` so that + /// Python subclasses with arbitrary constructor signatures + /// (e.g. `AnalyticsPlugin(config)`) can inherit `__new__` without + /// type errors. Best-effort extracts `name`/`phase` from positional + /// args and `manifest` from kwargs so direct construction + /// (`appkit.Plugin("name", manifest=m)`) still fully initializes + /// fields without requiring a separate `__init__` call. + #[new] + #[pyo3(signature = (*args, **kwargs))] + fn new( + args: &Bound<'_, pyo3::types::PyTuple>, + kwargs: Option<&Bound<'_, pyo3::types::PyDict>>, + ) -> PyResult { + let name = args + .get_item(0) + .ok() + .and_then(|a| a.extract::().ok()) + .unwrap_or_default(); + let phase = kwargs + .and_then(|k| k.get_item("phase").ok().flatten()) + .and_then(|v| v.extract::().ok()) + .or_else(|| { + args.get_item(1) + .ok() + .and_then(|a| a.extract::().ok()) + }) + .unwrap_or_else(|| "normal".to_string()); + let manifest = kwargs + .and_then(|k| k.get_item("manifest").ok().flatten()) + .and_then(|m| m.extract::().ok()) + .unwrap_or_else(|| PyPluginManifest { + name: name.clone(), + display_name: None, + description: None, + }); + if !name.is_empty() && phase.parse::().is_err() { + return Err(pyo3::exceptions::PyValueError::new_err(format!( + "Invalid phase: {phase}. 
Must be 'core', 'normal', or 'deferred'" + ))); + } + Ok(Self { + name, + phase, + manifest, + is_ready: false, + runtime: None, + }) + } + + /// Re-initialize fields from Python `super().__init__(...)`. + /// This enables the standard Python subclassing pattern: + /// ```python + /// class MyPlugin(Plugin): + /// def __init__(self): + /// super().__init__("my-plugin", manifest=PluginManifest("my-plugin")) + /// ``` + #[pyo3(signature = (name, *, phase = "normal".to_string(), manifest))] + fn __init__(&mut self, name: String, phase: String, manifest: PyPluginManifest) -> PyResult<()> { + if phase.parse::().is_err() { + return Err(pyo3::exceptions::PyValueError::new_err(format!( + "Invalid phase: {phase}. Must be 'core', 'normal', or 'deferred'" + ))); + } + self.name = name; + self.phase = phase; + self.manifest = manifest; + Ok(()) + } + + /// Called by AppKit during initialization. Override in subclass. + fn setup<'py>(&self, py: Python<'py>) -> PyResult> { + // Default: return a coroutine that resolves immediately. + pyo3_async_runtimes::tokio::future_into_py(py, async { Ok(()) }) + } + + /// Return export dict for this plugin. Override in subclass. + fn exports(&self) -> HashMap { + HashMap::new() + } + + /// Return client config dict. Override in subclass. + fn client_config(&self) -> HashMap { + HashMap::new() + } + + /// Override in subclass to register HTTP routes with the server. + /// + /// ```python + /// def inject_routes(self, router): + /// router.get("/items", self.get_items) + /// router.post("/items", self.create_item) + /// ``` + fn inject_routes(&self, _router: PyObject) -> PyResult<()> { + Ok(()) + } + + /// Execute an async Python callable through the interceptor chain. 
+ /// + /// ```python + /// result = await plugin.execute(my_async_fn, user_key="user-1") + /// ``` + #[allow(clippy::too_many_arguments)] + #[pyo3(signature = (func, *, user_key = "".to_string(), timeout_ms = None, retry_attempts = None, cache_key = None, cache_ttl = None))] + fn execute<'py>( + &self, + py: Python<'py>, + func: PyObject, + user_key: String, + timeout_ms: Option, + retry_attempts: Option, + cache_key: Option>, + cache_ttl: Option, + ) -> PyResult> { + let runtime = self.runtime.as_ref().ok_or_else(|| { + pyo3::exceptions::PyRuntimeError::new_err( + "Plugin not initialized — register with AppKit first", + ) + })?; + + let cache = runtime.cache.clone(); + let telemetry = runtime.telemetry.clone(); + let name = runtime.name.clone(); + + // Build execution config from keyword arguments. + let config = PluginExecuteConfig { + timeout_ms, + retry: retry_attempts.map(|attempts| crate::interceptor::RetryConfig { + enabled: true, + attempts, + ..Default::default() + }), + cache: cache_key.map(|keys| crate::interceptor::CacheInterceptorConfig { + enabled: true, + cache_key: keys, + ttl: cache_ttl, + }), + telemetry: Some(crate::interceptor::TelemetryInterceptorConfig { + enabled: true, + span_name: Some(format!("{name}.execute")), + attributes: vec![], + }), + }; + + let context = InterceptorContext { + user_key: user_key.clone(), + cancelled: Arc::new(AtomicBool::new(false)), + metadata: HashMap::new(), + }; + + // Wrap the Python callable as an ExecuteFn. Any Python exception is + // classified through the AppKit error hierarchy so `ExecutionResult` + // carries a meaningful HTTP status instead of always 500. 
+ let py_fn = Arc::new(func); + let base_fn: ExecuteFn = Arc::new(move || { + let py_fn = py_fn.clone(); + Box::pin(async move { + let future = Python::with_gil(|py| { + let coroutine = py_fn.call0(py).map_err(|e| { + let (status, _code, msg) = classify_pyerr(py, &e); + ExecutionError { status, message: msg } + })?; + pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py)) + .map_err(|e| { + let (status, _code, msg) = classify_pyerr(py, &e); + ExecutionError { status, message: msg } + }) + })?; + + let result = future.await.map_err(|e| { + Python::with_gil(|py| { + let (status, _code, msg) = classify_pyerr(py, &e); + ExecutionError { status, message: msg } + }) + })?; + + let json_str: String = Python::with_gil(|py| { + result.extract::(py).map_err(|e| ExecutionError { + status: 500, + message: format!("Execute callable must return a JSON string: {e}"), + }) + })?; + + serde_json::from_str(&json_str).map_err(|e| ExecutionError { + status: 500, + message: format!("Invalid JSON from execute callable: {e}"), + }) + }) + }); + + let chain = build_interceptor_chain( + base_fn, + &config, + &context, + Some(cache), + Some(telemetry), + ); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let result = match chain().await { + Ok(data) => ExecutionResult::Ok { data }, + Err(e) => ExecutionResult::Err { + status: e.status, + message: e.message, + }, + }; + Ok(PyExecutionResult::from(result)) + }) + } + + /// Execute a streaming function through the interceptor chain. + /// + /// The callable should return a Python async generator that yields + /// JSON strings. Returns a `StreamIterator` for async iteration. + /// + /// Retry and cache are not supported for streams. 
+ /// + /// ```python + /// stream = await plugin.execute_stream(my_async_gen_fn) + /// async for item in stream: + /// data = json.loads(item) + /// ``` + #[pyo3(signature = (func, *, user_key = "".to_string(), timeout_ms = None))] + fn execute_stream<'py>( + &self, + py: Python<'py>, + func: PyObject, + user_key: String, + timeout_ms: Option, + ) -> PyResult> { + let _ = user_key; + let runtime = self.runtime.as_ref().ok_or_else(|| { + pyo3::exceptions::PyRuntimeError::new_err( + "Plugin not initialized — register with AppKit first", + ) + })?; + + let telemetry = runtime.telemetry.clone(); + let name = runtime.name.clone(); + + // Call the Python callable to get the async generator (requires GIL). + let py_gen = func.call0(py).map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!( + "Failed to call stream function: {e}" + )) + })?; + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let config = PluginExecuteConfig { + timeout_ms, + retry: None, + cache: None, + telemetry: Some(crate::interceptor::TelemetryInterceptorConfig { + enabled: true, + span_name: Some(format!("{name}.execute_stream")), + attributes: vec![], + }), + }; + + let (tx, rx) = mpsc::channel::(32); + + // Spawn task to drive the Python async generator. + spawn_stream_generator(py_gen, tx); + + // Apply stream interceptors (telemetry + timeout). + let output_rx = build_stream_interceptor_chain(rx, &config, Some(telemetry)); + + Ok(PyStreamIterator::new(output_rx)) + }) + } + + fn __repr__(&self) -> String { + format!( + "Plugin(name={:?}, phase={:?}, ready={})", + self.name, self.phase, self.is_ready + ) + } +} + +impl PyPlugin { + /// Inject the shared runtime (called by PyAppKit during initialization). 
+ fn inject_runtime(&mut self, runtime: PluginRuntime) { + self.runtime = Some(runtime); + self.is_ready = true; + } +} + +// --------------------------------------------------------------------------- +// PyStreamIterator — async iterator for streaming execution results +// --------------------------------------------------------------------------- + +/// Python async iterator backed by a tokio mpsc channel. +/// +/// Yields JSON string items from a streaming execution. Used by +/// `Plugin.execute_stream()` and `ServingConnector.stream()`. +/// +/// ```python +/// stream = await plugin.execute_stream(my_async_gen_fn) +/// async for item in stream: +/// data = json.loads(item) +/// ``` +#[pyclass(name = "StreamIterator", module = "appkit")] +pub struct PyStreamIterator { + rx: Arc>>, +} + +impl PyStreamIterator { + pub fn new(rx: mpsc::Receiver) -> Self { + Self { + rx: Arc::new(tokio::sync::Mutex::new(rx)), + } + } +} + +#[pymethods] +impl PyStreamIterator { + fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __anext__<'py>(&self, py: Python<'py>) -> PyResult> { + let rx = self.rx.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let mut guard = rx.lock().await; + match guard.recv().await { + Some(Ok(data)) => Ok(data), + Some(Err(e)) => Err(pyo3::exceptions::PyRuntimeError::new_err(format!( + "[{}] {}", + e.status, e.message + ))), + None => Err(pyo3::exceptions::PyStopAsyncIteration::new_err(())), + } + }) + } + + fn __repr__(&self) -> String { + "StreamIterator(...)".to_string() + } +} + +// --------------------------------------------------------------------------- +// Stream generator driver +// --------------------------------------------------------------------------- + +/// Spawn a task that drives a Python async generator, forwarding items +/// as `StreamItem` values to the given sender. 
+fn spawn_stream_generator(py_gen: PyObject, tx: mpsc::Sender) { + tokio::spawn(async move { + loop { + // Step 1: acquire GIL, call __anext__, get a future, release GIL. + let future_result: Result, ExecutionError> = Python::with_gil(|py| { + match py_gen.call_method0(py, "__anext__") { + Ok(coroutine) => pyo3_async_runtimes::tokio::into_future( + coroutine.into_bound(py), + ) + .map(Some) + .map_err(|e| ExecutionError { + status: 500, + message: e.to_string(), + }), + Err(e) => { + if e.is_instance_of::(py) { + Ok(None) + } else { + Err(ExecutionError { + status: 500, + message: e.to_string(), + }) + } + } + } + }); + + match future_result { + Ok(None) => break, // Generator exhausted. + Err(e) => { + let _ = tx.send(Err(e)).await; + break; + } + Ok(Some(future)) => { + match future.await { + Ok(value) => { + let data = Python::with_gil(|py| { + value + .extract::(py) + .unwrap_or_else(|_| "null".to_string()) + }); + if tx.send(Ok(data)).await.is_err() { + break; // Receiver dropped. + } + } + Err(e) => { + let is_stop = Python::with_gil(|py| { + e.is_instance_of::(py) + }); + if is_stop { + break; + } + let _ = tx + .send(Err(ExecutionError { + status: 500, + message: e.to_string(), + })) + .await; + break; + } + } + } + } + } + }); +} + +// --------------------------------------------------------------------------- +// PyAppKit — plugin orchestrator +// --------------------------------------------------------------------------- + +/// AppKit orchestrator — registers plugins, manages phase-ordered +/// initialization, and provides access to registered plugins. +/// +/// ```python +/// from appkit import AppKit, AppConfig +/// +/// app = AppKit() +/// app.register(my_plugin) +/// await app.initialize(config) +/// ``` +#[pyclass(name = "AppKit", module = "appkit")] +pub struct PyAppKit { + plugins: Vec, + initialized: bool, + /// Shutdown sender for the running server (set by `start_server`). 
+ shutdown_tx: Arc>>>, +} + +impl Default for PyAppKit { + fn default() -> Self { + Self { + plugins: Vec::new(), + initialized: false, + shutdown_tx: Arc::new(std::sync::Mutex::new(None)), + } + } +} + +#[pymethods] +impl PyAppKit { + #[new] + pub fn new() -> Self { + Self::default() + } + + /// Register a plugin instance. Must be called before `initialize()`. + pub fn register(&mut self, plugin: PyObject) -> PyResult<()> { + if self.initialized { + return Err(pyo3::exceptions::PyRuntimeError::new_err( + "Cannot register plugins after initialization", + )); + } + self.plugins.push(plugin); + Ok(()) + } + + /// Initialize all registered plugins in phase order (core → normal → deferred). + /// + /// This: + /// 1. Initializes telemetry from the AppConfig + /// 2. Creates the shared CacheManager + /// 3. Injects PluginRuntime into each plugin + /// 4. Calls `setup()` on each plugin in phase order + #[pyo3(signature = (config, *, cache_config = None))] + pub fn initialize<'py>( + &mut self, + py: Python<'py>, + config: crate::config::AppConfig, + cache_config: Option, + ) -> PyResult> { + if self.initialized { + return Err(pyo3::exceptions::PyRuntimeError::new_err( + "AppKit already initialized", + )); + } + + // Initialize telemetry. + let telem_config = crate::telemetry::TelemetryConfig::from_app_config(&config); + TelemetryManager::initialize(&telem_config); + + // Create shared cache. + let cache = Arc::new(CacheManager::new_internal( + cache_config.unwrap_or_default(), + )); + + // Sort plugins by phase and inject runtime. + let mut indexed: Vec<(u8, usize)> = Vec::new(); + for (i, plugin_obj) in self.plugins.iter().enumerate() { + let phase_str: String = plugin_obj + .getattr(py, "phase") + .and_then(|a| a.extract(py)) + .unwrap_or_else(|_| "normal".to_string()); + let phase = phase_str.parse::().unwrap_or_default(); + indexed.push((phase.order(), i)); + } + indexed.sort_by_key(|(order, _)| *order); + + // Inject runtime into each plugin. 
Python subclasses of `Plugin` + // (PyPlugin) inherit the parent's storage, so `PyRefMut<'_, PyPlugin>` + // is obtainable from subclass instances. Fail loudly if extraction + // fails — otherwise the plugin would silently stay uninitialized and + // later `execute()` / `execute_stream()` calls would error with a + // cryptic "Plugin not initialized" message. + for &(_, i) in &indexed { + let plugin_obj = &self.plugins[i]; + let name: String = plugin_obj + .getattr(py, "name") + .and_then(|a| a.extract(py)) + .unwrap_or_else(|_| format!("plugin-{i}")); + let runtime = PluginRuntime::new(&name, cache.clone(), None); + + let mut py_plugin = plugin_obj + .extract::>(py) + .map_err(|e| { + pyo3::exceptions::PyTypeError::new_err(format!( + "Plugin '{name}' is not a subclass of appkit.Plugin (cannot inject runtime): {e}" + )) + })?; + py_plugin.inject_runtime(runtime); + } + + // Call setup() on each plugin in phase order. + let ordered_plugins: Vec = + indexed.iter().map(|&(_, i)| self.plugins[i].clone_ref(py)).collect(); + + self.initialized = true; + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + for plugin_obj in ordered_plugins { + let future = Python::with_gil(|py| { + let setup_result = plugin_obj.call_method0(py, "setup")?; + pyo3_async_runtimes::tokio::into_future(setup_result.into_bound(py)) + })?; + future.await?; + } + Ok(()) + }) + } + + /// Get a registered plugin by name. + fn get_plugin(&self, py: Python<'_>, name: &str) -> PyResult> { + for plugin_obj in &self.plugins { + let plugin_name: String = plugin_obj + .getattr(py, "name") + .and_then(|a| a.extract(py)) + .unwrap_or_default(); + if plugin_name == name { + return Ok(Some(plugin_obj.clone_ref(py))); + } + } + Ok(None) + } + + /// List all registered plugin names. + fn plugin_names(&self, py: Python<'_>) -> Vec { + self.plugins + .iter() + .filter_map(|obj| { + obj.getattr(py, "name") + .and_then(|a| a.extract(py)) + .ok() + }) + .collect() + } + + /// Start the HTTP server. 
Collects routes from all registered plugins via + /// `inject_routes()`, aggregates client configs, and starts an axum server. + /// + /// ```python + /// await app.start_server(ServerConfig(host="0.0.0.0", port=8000)) + /// ``` + #[pyo3(signature = (server_config))] + pub fn start_server<'py>( + &self, + py: Python<'py>, + server_config: crate::server::PyServerConfig, + ) -> PyResult> { + if !self.initialized { + return Err(pyo3::exceptions::PyRuntimeError::new_err( + "AppKit must be initialized before starting the server", + )); + } + { + let guard = self.shutdown_tx.lock().unwrap(); + if guard.is_some() { + return Err(pyo3::exceptions::PyRuntimeError::new_err( + "Server already started", + )); + } + } + + // Collect routes and client configs from all plugins. + let mut all_routes: Vec<(String, Vec)> = Vec::new(); + let mut plugin_configs: HashMap = HashMap::new(); + + for plugin_obj in &self.plugins { + let name: String = plugin_obj + .getattr(py, "name") + .and_then(|a| a.extract(py)) + .unwrap_or_default(); + + // Call inject_routes if the plugin defines it. + let router = crate::server::PyRouter::new(&name); + let router_obj = Py::new(py, router)?; + let _ = plugin_obj.call_method1(py, "inject_routes", (router_obj.clone_ref(py),)); + let routes = router_obj.borrow(py).take_routes(); + if !routes.is_empty() { + all_routes.push((name.clone(), routes)); + } + + // Collect client_config (expects JSON string or dict). 
+ if let Ok(config_result) = plugin_obj.call_method0(py, "client_config") { + if let Ok(config_dict) = config_result.extract::>(py) { + if !config_dict.is_empty() { + let json_val = serde_json::to_value(&config_dict).unwrap_or_default(); + plugin_configs.insert(name.clone(), json_val); + } + } + } + } + + let task_locals = pyo3_async_runtimes::tokio::get_current_locals(py)?; + let stream_manager = crate::stream::StreamManager::new(crate::stream::StreamConfig::default()); + let static_path = crate::server::detect_static_path(server_config.static_path.as_deref()); + let router = crate::server::build_router( + all_routes, + plugin_configs, + stream_manager.clone(), + static_path, + task_locals, + ); + let host = server_config.host.clone(); + let port = server_config.port; + let shutdown_slot = self.shutdown_tx.clone(); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let handle = + crate::server::start_server(router, &host, port, stream_manager) + .await + .map_err(pyo3::exceptions::PyRuntimeError::new_err)?; + *shutdown_slot.lock().unwrap() = Some(handle.shutdown_tx); + Ok(()) + }) + } + + /// Trigger graceful server shutdown. 
+ fn shutdown(&self) -> PyResult<()> { + let guard = self.shutdown_tx.lock().unwrap(); + match guard.as_ref() { + Some(tx) => { + let _ = tx.send(true); + Ok(()) + } + None => Err(pyo3::exceptions::PyRuntimeError::new_err( + "Server is not running", + )), + } + } + + fn __repr__(&self) -> String { + format!( + "AppKit(plugins={}, initialized={})", + self.plugins.len(), + self.initialized + ) + } + + fn __len__(&self) -> usize { + self.plugins.len() + } + + fn __bool__(&self) -> bool { + self.initialized + } + + fn __contains__(&self, py: Python<'_>, name: &str) -> bool { + self.plugins.iter().any(|obj| { + obj.getattr(py, "name") + .and_then(|a| a.extract::(py)) + .map(|n| n == name) + .unwrap_or(false) + }) + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::Ordering; + + fn test_manifest(name: &str) -> PluginManifest { + PluginManifest { + name: name.to_string(), + display_name: None, + description: None, + required_resources: vec![], + optional_resources: vec![], + } + } + + // -- PluginPhase -- + + #[test] + fn test_phase_ordering() { + assert!(PluginPhase::Core.order() < PluginPhase::Normal.order()); + assert!(PluginPhase::Normal.order() < PluginPhase::Deferred.order()); + } + + #[test] + fn test_phase_from_str() { + assert_eq!("core".parse(), Ok(PluginPhase::Core)); + assert_eq!("normal".parse(), Ok(PluginPhase::Normal)); + assert_eq!("deferred".parse(), Ok(PluginPhase::Deferred)); + assert!("invalid".parse::().is_err()); + } + + #[test] + fn test_phase_roundtrip() { + for phase in [PluginPhase::Core, PluginPhase::Normal, PluginPhase::Deferred] { + assert_eq!(phase.as_str().parse(), Ok(phase)); + } + } + + // -- ExecutionResult -- + + #[test] + fn test_execution_result_ok() { + let r = ExecutionResult::Ok { + data: JsonValue::String("hello".into()), + }; + 
assert!(r.is_ok()); + assert_eq!(r.data(), Some(&JsonValue::String("hello".into()))); + } + + #[test] + fn test_execution_result_err() { + let r = ExecutionResult::Err { + status: 404, + message: "not found".into(), + }; + assert!(!r.is_ok()); + assert!(r.data().is_none()); + } + + // -- PluginRuntime -- + + #[tokio::test] + async fn test_runtime_execute_success() { + let cache = Arc::new(CacheManager::new_internal(CacheConfig::default())); + let runtime = PluginRuntime { + name: "test".into(), + cache, + telemetry: Arc::new(TelemetryProvider::new_disabled("test")), + }; + + let result = runtime + .execute( + || async { Ok(JsonValue::Number(42.into())) }, + PluginExecuteConfig::default(), + "user-1", + ) + .await; + + assert!(result.is_ok()); + assert_eq!(result.data(), Some(&JsonValue::Number(42.into()))); + } + + #[tokio::test] + async fn test_runtime_execute_error() { + let cache = Arc::new(CacheManager::new_internal(CacheConfig::default())); + let runtime = PluginRuntime { + name: "test".into(), + cache, + telemetry: Arc::new(TelemetryProvider::new_disabled("test")), + }; + + let result = runtime + .execute( + || async { + Err(ExecutionError { + status: 503, + message: "service unavailable".into(), + }) + }, + PluginExecuteConfig::default(), + "user-1", + ) + .await; + + assert!(!result.is_ok()); + match result { + ExecutionResult::Err { status, message } => { + assert_eq!(status, 503); + assert_eq!(message, "service unavailable"); + } + _ => panic!("expected Err"), + } + } + + #[tokio::test] + async fn test_runtime_execute_with_timeout() { + let cache = Arc::new(CacheManager::new_internal(CacheConfig::default())); + let runtime = PluginRuntime { + name: "test".into(), + cache, + telemetry: Arc::new(TelemetryProvider::new_disabled("test")), + }; + + let result = runtime + .execute( + || async { + tokio::time::sleep(std::time::Duration::from_secs(10)).await; + Ok(JsonValue::Null) + }, + PluginExecuteConfig { + timeout_ms: Some(50), + ..Default::default() + }, + 
"user-1", + ) + .await; + + assert!(!result.is_ok()); + match result { + ExecutionResult::Err { status, .. } => assert_eq!(status, 408), + _ => panic!("expected timeout error"), + } + } + + #[tokio::test] + async fn test_runtime_execute_with_cache() { + let cache = Arc::new(CacheManager::new_internal(CacheConfig::default())); + let runtime = PluginRuntime { + name: "test".into(), + cache, + telemetry: Arc::new(TelemetryProvider::new_disabled("test")), + }; + + let counter = Arc::new(std::sync::atomic::AtomicU32::new(0)); + let counter_c = counter.clone(); + let f = move || { + let counter = counter_c.clone(); + async move { + counter.fetch_add(1, Ordering::SeqCst); + Ok(JsonValue::String("value".into())) + } + }; + + let config = PluginExecuteConfig { + cache: Some(crate::interceptor::CacheInterceptorConfig { + enabled: true, + cache_key: vec!["test-key".into()], + ttl: Some(60), + }), + ..Default::default() + }; + + // First call computes + let r1 = runtime.execute(f.clone(), config.clone(), "user-1").await; + assert!(r1.is_ok()); + assert_eq!(counter.load(Ordering::SeqCst), 1); + + // Second call hits cache + let r2 = runtime.execute(f, config, "user-1").await; + assert!(r2.is_ok()); + assert_eq!(counter.load(Ordering::SeqCst), 1); + } + + // -- Plugin trait -- + + struct TestPlugin { + manifest: PluginManifest, + } + + impl Plugin for TestPlugin { + fn name(&self) -> &str { + &self.manifest.name + } + + fn manifest(&self) -> &PluginManifest { + &self.manifest + } + + fn phase(&self) -> PluginPhase { + PluginPhase::Normal + } + } + + #[test] + fn test_plugin_trait_defaults() { + let p = TestPlugin { + manifest: test_manifest("my-plugin"), + }; + assert_eq!(p.name(), "my-plugin"); + assert_eq!(p.phase(), PluginPhase::Normal); + assert!(p.exports().is_empty()); + assert!(p.client_config().is_empty()); + } + + // -- PyExecutionResult -- + + #[test] + fn test_py_execution_result_from_ok() { + let r = ExecutionResult::Ok { + data: JsonValue::Bool(true), + }; + let 
py_r = PyExecutionResult::from(r); + assert!(py_r.ok); + assert_eq!(py_r.data, Some("true".to_string())); + assert!(py_r.status.is_none()); + } + + #[test] + fn test_py_execution_result_from_err() { + let r = ExecutionResult::Err { + status: 400, + message: "bad request".into(), + }; + let py_r = PyExecutionResult::from(r); + assert!(!py_r.ok); + assert!(py_r.data.is_none()); + assert_eq!(py_r.status, Some(400)); + assert_eq!(py_r.message, Some("bad request".into())); + } +} diff --git a/packages/appkit-rs/src/plugins/analytics.rs b/packages/appkit-rs/src/plugins/analytics.rs new file mode 100644 index 00000000..3a68abf9 --- /dev/null +++ b/packages/appkit-rs/src/plugins/analytics.rs @@ -0,0 +1,702 @@ +//! Analytics plugin — Rust `Plugin` trait implementation with the core +//! query-processing behaviors that the TypeScript analytics plugin relies on. +//! +//! The Rust side owns: +//! - Plugin manifest (`sql_warehouse` required resource) +//! - Query parameter extraction + validation +//! - SQL parameter list construction from user-supplied typed values +//! - Query file discovery from `config/queries/` (`.sql` vs `.obo.sql`) +//! - Stable query hashing for cache-key disambiguation +//! - Cache-key parts helper matching TS `["analytics:query", query_key, ...]` +//! +//! HTTP route wiring lives on the Python subclass in +//! `appkit/plugins/analytics.py`. This module is the source of truth for +//! parsing, validation, and cache-key behavior so the Python layer stays +//! thin and Rust/Python cannot drift on these rules. 
+ +use sha2::{Digest, Sha256}; +use std::collections::{BTreeMap, HashMap}; +use std::fs; +use std::path::{Path, PathBuf}; + +use crate::plugin::{Plugin, PluginManifest, PluginPhase, ResourceRequirement}; + +// --------------------------------------------------------------------------- +// Config +// --------------------------------------------------------------------------- + +#[derive(Clone, Debug, Default)] +pub struct AnalyticsPluginConfig { + /// Query execution timeout in milliseconds (matches TS `IAnalyticsConfig.timeout`). + pub timeout_ms: Option, + /// Override directory for query file loading. Defaults to `config/queries`. + pub queries_dir: Option, +} + +// --------------------------------------------------------------------------- +// QueryProcessor — TS `QueryProcessor` parity +// --------------------------------------------------------------------------- + +/// Typed SQL parameter matching the TS `sql.*` helpers — the TS code tags +/// values with `__sql_type` (STRING / BIGINT / DATE / TIMESTAMP / BOOLEAN). +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum SqlType { + String, + Number, + Date, + Timestamp, + Boolean, +} + +impl SqlType { + pub fn as_str(&self) -> &'static str { + match self { + Self::String => "STRING", + Self::Number => "BIGINT", + Self::Date => "DATE", + Self::Timestamp => "TIMESTAMP", + Self::Boolean => "BOOLEAN", + } + } +} + +/// Typed value passed to `QueryProcessor::convert_to_sql_parameters`. 
+#[derive(Clone, Debug, PartialEq, Eq)] +pub struct SqlValue { + pub value: String, + pub sql_type: SqlType, +} + +impl SqlValue { + pub fn string>(v: S) -> Self { + Self { value: v.into(), sql_type: SqlType::String } + } + + pub fn number(v: N) -> Self { + Self { value: v.to_string(), sql_type: SqlType::Number } + } + + pub fn date>(v: S) -> Self { + Self { value: v.into(), sql_type: SqlType::Date } + } + + pub fn timestamp>(v: S) -> Self { + Self { value: v.into(), sql_type: SqlType::Timestamp } + } + + pub fn boolean(v: bool) -> Self { + Self { value: v.to_string(), sql_type: SqlType::Boolean } + } +} + +/// Outgoing SQL statement parameter sent to the SQL Statement Execution API. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct StatementParameter { + pub name: String, + pub value: String, + pub type_name: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ValidationError { + pub field: String, + pub message: String, +} + +impl std::fmt::Display for ValidationError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Invalid value for '{}': {}", self.field, self.message) + } +} + +impl std::error::Error for ValidationError {} + +/// QueryProcessor — mirrors `packages/appkit/src/plugins/analytics/query.ts`. +#[derive(Default, Clone, Debug)] +pub struct QueryProcessor; + +impl QueryProcessor { + pub fn new() -> Self { + Self + } + + /// Extract all `:param_name` placeholders from a query (`/:([a-zA-Z_]\w*)/g`). + /// Returns a set-like ordered, deduplicated list in first-seen order. + /// + /// Skips colons that appear inside SQL string literals (`'...'`), + /// quoted identifiers (`"..."`), line comments (`-- ...`), block comments + /// (`/* ... */`, nestable), and dollar-quoted strings (`$tag$...$tag$`). + /// Also skips `::TYPE` cast operators. 
+    pub fn extract_param_names(&self, query: &str) -> Vec<String> {
+        let bytes = query.as_bytes();
+        let mut out: Vec<String> = Vec::new();
+        let mut seen = std::collections::HashSet::new();
+        let mut i = 0;
+        while i < bytes.len() {
+            let b = bytes[i];
+
+            // Line comment: -- ... until end of line.
+            if b == b'-' && i + 1 < bytes.len() && bytes[i + 1] == b'-' {
+                i += 2;
+                while i < bytes.len() && bytes[i] != b'\n' {
+                    i += 1;
+                }
+                continue;
+            }
+
+            // Block comment: /* ... */ — PostgreSQL allows nesting.
+            if b == b'/' && i + 1 < bytes.len() && bytes[i + 1] == b'*' {
+                i += 2;
+                let mut depth: u32 = 1;
+                while i < bytes.len() && depth > 0 {
+                    if i + 1 < bytes.len() && bytes[i] == b'/' && bytes[i + 1] == b'*' {
+                        depth += 1;
+                        i += 2;
+                    } else if i + 1 < bytes.len() && bytes[i] == b'*' && bytes[i + 1] == b'/' {
+                        depth -= 1;
+                        i += 2;
+                    } else {
+                        i += 1;
+                    }
+                }
+                continue;
+            }
+
+            // Single-quoted string literal: '...'. Doubled '' is an escape.
+            if b == b'\'' {
+                i += 1;
+                while i < bytes.len() {
+                    if bytes[i] == b'\'' {
+                        if i + 1 < bytes.len() && bytes[i + 1] == b'\'' {
+                            i += 2;
+                        } else {
+                            i += 1;
+                            break;
+                        }
+                    } else {
+                        i += 1;
+                    }
+                }
+                continue;
+            }
+
+            // Double-quoted identifier: "...". Doubled "" is an escape.
+            if b == b'"' {
+                i += 1;
+                while i < bytes.len() {
+                    if bytes[i] == b'"' {
+                        if i + 1 < bytes.len() && bytes[i + 1] == b'"' {
+                            i += 2;
+                        } else {
+                            i += 1;
+                            break;
+                        }
+                    } else {
+                        i += 1;
+                    }
+                }
+                continue;
+            }
+
+            // Dollar-quoted string: $tag$...$tag$ (tag may be empty: $$...$$).
+            if b == b'$' {
+                let mut tag_end = i + 1;
+                while tag_end < bytes.len() && is_ident_continue(bytes[tag_end]) {
+                    tag_end += 1;
+                }
+                if tag_end < bytes.len() && bytes[tag_end] == b'$' {
+                    // The delimiter is `bytes[i..=tag_end]` — both `$` included.
+ let delim_len = tag_end - i + 1; + let mut j = tag_end + 1; + let mut closed = false; + while j + delim_len <= bytes.len() { + if bytes[j..j + delim_len] == bytes[i..i + delim_len] { + j += delim_len; + closed = true; + break; + } + j += 1; + } + i = if closed { j } else { bytes.len() }; + continue; + } + // Not a dollar-quote opener — fall through. + } + + if b == b':' { + // Skip `::TYPE` casts — the PostgreSQL cast operator consumes + // two colons and the identifier that follows, so neither colon + // starts a named parameter. + if i + 1 < bytes.len() && bytes[i + 1] == b':' { + i += 2; + while i < bytes.len() && is_ident_continue(bytes[i]) { + i += 1; + } + continue; + } + if i + 1 < bytes.len() && is_ident_start(bytes[i + 1]) { + let start = i + 1; + let mut end = start; + while end < bytes.len() && is_ident_continue(bytes[end]) { + end += 1; + } + let name = std::str::from_utf8(&bytes[start..end]) + .unwrap_or("") + .to_string(); + if !name.is_empty() && seen.insert(name.clone()) { + out.push(name); + } + i = end; + continue; + } + } + i += 1; + } + out + } + + /// Stable hash of the query text — used for cache-key disambiguation. + /// + /// TS uses MD5; we use SHA-256 hex here because `sha2` is already a + /// workspace dependency and there is no shared cache namespace between + /// TS and Rust. The contract is "stable per query string", which both + /// satisfy. + pub fn hash_query(&self, query: &str) -> String { + let digest = Sha256::digest(query.as_bytes()); + digest.iter().map(|b| format!("{b:02x}")).collect() + } + + /// Validate + transform user-supplied parameters into the wire format + /// consumed by the SQL Statement Execution API. + /// + /// Rules (mirroring TS): + /// - Every key in `parameters` MUST appear as a `:name` placeholder + /// in `query`. Extraneous keys → `ValidationError`. + /// - `None`/missing values are dropped (not sent). 
+    pub fn convert_to_sql_parameters(
+        &self,
+        query: &str,
+        parameters: &BTreeMap<String, Option<SqlValue>>,
+    ) -> Result<Vec<StatementParameter>, ValidationError> {
+        let query_params: std::collections::HashSet<String> =
+            self.extract_param_names(query).into_iter().collect();
+
+        for key in parameters.keys() {
+            if !query_params.contains(key) {
+                let valid = {
+                    let mut v: Vec<&str> = query_params.iter().map(|s| s.as_str()).collect();
+                    v.sort();
+                    if v.is_empty() { "none".to_string() } else { v.join(", ") }
+                };
+                return Err(ValidationError {
+                    field: key.clone(),
+                    message: format!(
+                        "expected a parameter defined in the query (valid: {valid})",
+                    ),
+                });
+            }
+        }
+
+        let mut out = Vec::new();
+        for (name, value) in parameters.iter() {
+            if let Some(v) = value {
+                out.push(StatementParameter {
+                    name: name.clone(),
+                    value: v.value.clone(),
+                    type_name: v.sql_type.as_str().to_string(),
+                });
+            }
+        }
+        Ok(out)
+    }
+
+    /// Compute the TS-parity cache-key parts:
+    /// `["analytics:query", query_key, JSON.stringify(parameters),
+    /// JSON.stringify(format), hashed_query, executor_key]`.
+    pub fn cache_key_parts(
+        &self,
+        query_key: &str,
+        parameters_json: &str,
+        format: &str,
+        hashed_query: &str,
+        executor_key: &str,
+    ) -> Vec<String> {
+        vec![
+            "analytics:query".to_string(),
+            query_key.to_string(),
+            parameters_json.to_string(),
+            serde_json::to_string(format).unwrap_or_else(|_| format!("\"{format}\"")),
+            hashed_query.to_string(),
+            executor_key.to_string(),
+        ]
+    }
+}
+
+fn is_ident_start(b: u8) -> bool {
+    b.is_ascii_alphabetic() || b == b'_'
+}
+
+fn is_ident_continue(b: u8) -> bool {
+    b.is_ascii_alphanumeric() || b == b'_'
+}
+
+// ---------------------------------------------------------------------------
+// Query file loading — `config/queries/*.sql` vs `*.obo.sql`
+// ---------------------------------------------------------------------------
+
+/// A query loaded from disk.
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct LoadedQuery {
+    pub query_key: String,
+    pub query: String,
+    /// True when the source file was `.obo.sql` (executes
+    /// on-behalf-of the user). False for service-principal-scoped files.
+    pub is_as_user: bool,
+}
+
+/// Load a query by key from a queries directory. Prefers `.obo.sql` (user
+/// context) over `.sql` (service principal) when both exist — matches the TS
+/// precedence where OBO queries are picked up first by `app.getAppQuery()`.
+pub fn load_query(queries_dir: &Path, query_key: &str) -> Option<LoadedQuery> {
+    if !is_valid_query_key(query_key) {
+        return None;
+    }
+    let obo = queries_dir.join(format!("{query_key}.obo.sql"));
+    let sp = queries_dir.join(format!("{query_key}.sql"));
+    if let Ok(text) = fs::read_to_string(&obo) {
+        return Some(LoadedQuery {
+            query_key: query_key.to_string(),
+            query: text,
+            is_as_user: true,
+        });
+    }
+    if let Ok(text) = fs::read_to_string(&sp) {
+        return Some(LoadedQuery {
+            query_key: query_key.to_string(),
+            query: text,
+            is_as_user: false,
+        });
+    }
+    None
+}
+
+/// Allow only safe query keys — alphanumeric, underscore, dash. Prevents
+/// path traversal like `../secrets`.
+pub fn is_valid_query_key(key: &str) -> bool {
+    !key.is_empty()
+        && key.len() <= 128
+        && key
+            .chars()
+            .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
+}
+
+// ---------------------------------------------------------------------------
+// AnalyticsPluginCore — Plugin trait impl
+// ---------------------------------------------------------------------------
+
+pub struct AnalyticsPluginCore {
+    manifest: PluginManifest,
+    #[allow(dead_code)]
+    config: AnalyticsPluginConfig,
+    processor: QueryProcessor,
+}
+
+impl AnalyticsPluginCore {
+    pub const NAME: &'static str = "analytics";
+
+    pub fn new(config: AnalyticsPluginConfig) -> Self {
+        Self {
+            manifest: PluginManifest {
+                name: Self::NAME.into(),
+                display_name: Some("Analytics Plugin".into()),
+                description: Some(
+                    "SQL query execution against Databricks SQL Warehouses".into(),
+                ),
+                required_resources: vec![ResourceRequirement {
+                    resource_type: "sql_warehouse".into(),
+                    required: true,
+                }],
+                optional_resources: vec![],
+            },
+            config,
+            processor: QueryProcessor::new(),
+        }
+    }
+
+    pub fn processor(&self) -> &QueryProcessor {
+        &self.processor
+    }
+
+    pub fn queries_dir(&self) -> PathBuf {
+        self.config
+            .queries_dir
+            .clone()
+            .unwrap_or_else(|| PathBuf::from("config").join("queries"))
+    }
+
+    pub fn exports_map(&self) -> HashMap<String, String> {
+        let mut out = HashMap::new();
+        out.insert("query".into(), "analytics.query".into());
+        out
+    }
+}
+
+impl Plugin for AnalyticsPluginCore {
+    fn name(&self) -> &str {
+        &self.manifest.name
+    }
+
+    fn phase(&self) -> PluginPhase {
+        PluginPhase::Normal
+    }
+
+    fn manifest(&self) -> &PluginManifest {
+        &self.manifest
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Tests
+// ---------------------------------------------------------------------------
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::io::Write;
+
+    #[test]
+    fn test_extract_param_names_basic() {
+        let qp = QueryProcessor::new();
+        let params =
qp.extract_param_names( + "SELECT * FROM t WHERE id = :user_id AND region = :region", + ); + assert_eq!(params, vec!["user_id".to_string(), "region".to_string()]); + } + + #[test] + fn test_extract_param_names_dedup_preserves_first_order() { + let qp = QueryProcessor::new(); + let params = qp.extract_param_names( + "SELECT :a, :b, :a FROM t WHERE x = :c OR y = :b", + ); + assert_eq!(params, vec!["a".to_string(), "b".to_string(), "c".to_string()]); + } + + #[test] + fn test_extract_param_names_ignores_non_params() { + let qp = QueryProcessor::new(); + // `::TYPE` casts, standalone colons, and digits-first are all rejected. + let params = qp.extract_param_names( + "SELECT x::BIGINT, :user, :: FROM t WHERE n = :1bad OR m = :_ok", + ); + assert!(params.contains(&"user".to_string())); + assert!(params.contains(&"_ok".to_string())); + assert!(!params.contains(&"BIGINT".to_string())); + assert!(!params.contains(&"1bad".to_string())); + } + + #[test] + fn test_extract_param_names_skips_string_literals() { + let qp = QueryProcessor::new(); + // Colons inside single-quoted literals must not be treated as params. 
+ let params = qp.extract_param_names( + "SELECT ':not_a_param', 'foo''s :also_skipped' FROM t WHERE id = :real_id", + ); + assert_eq!(params, vec!["real_id".to_string()]); + } + + #[test] + fn test_extract_param_names_skips_quoted_identifiers() { + let qp = QueryProcessor::new(); + let params = qp.extract_param_names( + "SELECT \"col:not_param\", \"esc\"\":also_skipped\" FROM t WHERE id = :real", + ); + assert_eq!(params, vec!["real".to_string()]); + } + + #[test] + fn test_extract_param_names_skips_line_comments() { + let qp = QueryProcessor::new(); + let params = qp.extract_param_names( + "SELECT 1 -- :fake_param\nFROM t WHERE id = :real_id", + ); + assert_eq!(params, vec!["real_id".to_string()]); + } + + #[test] + fn test_extract_param_names_skips_block_comments() { + let qp = QueryProcessor::new(); + let params = qp.extract_param_names( + "SELECT /* :fake1 /* nested :fake2 */ :fake3 */ :real FROM t", + ); + assert_eq!(params, vec!["real".to_string()]); + } + + #[test] + fn test_extract_param_names_skips_dollar_quoted_strings() { + let qp = QueryProcessor::new(); + let params = qp.extract_param_names( + "SELECT $$:not_a_param$$, $tag$:also_not$tag$, :real FROM t", + ); + assert_eq!(params, vec!["real".to_string()]); + } + + #[test] + fn test_convert_parameters_rejects_extra_when_colon_is_in_literal() { + let qp = QueryProcessor::new(); + // A raw colon inside a literal must not make `fake` look "defined" — + // the only real param is `id`, so `fake` must be rejected. 
+        let mut params = BTreeMap::new();
+        params.insert("fake".to_string(), Some(SqlValue::string("x")));
+        let err = qp
+            .convert_to_sql_parameters(
+                "SELECT ':fake' FROM t WHERE id = :id",
+                &params,
+            )
+            .unwrap_err();
+        assert_eq!(err.field, "fake");
+        assert!(err.message.contains("valid: id"));
+    }
+
+    #[test]
+    fn test_hash_query_is_stable_and_differs() {
+        let qp = QueryProcessor::new();
+        let h1 = qp.hash_query("SELECT 1");
+        let h2 = qp.hash_query("SELECT 1");
+        let h3 = qp.hash_query("SELECT 2");
+        assert_eq!(h1, h2);
+        assert_ne!(h1, h3);
+        assert_eq!(h1.len(), 64);
+    }
+
+    #[test]
+    fn test_convert_parameters_success() {
+        let qp = QueryProcessor::new();
+        let mut params = BTreeMap::new();
+        params.insert("id".to_string(), Some(SqlValue::number(42)));
+        params.insert("name".to_string(), Some(SqlValue::string("alice")));
+        let out = qp
+            .convert_to_sql_parameters(
+                "SELECT * FROM t WHERE id = :id AND name = :name",
+                &params,
+            )
+            .unwrap();
+        assert_eq!(out.len(), 2);
+        let by_name: HashMap<&str, &StatementParameter> =
+            out.iter().map(|p| (p.name.as_str(), p)).collect();
+        assert_eq!(by_name["id"].value, "42");
+        assert_eq!(by_name["id"].type_name, "BIGINT");
+        assert_eq!(by_name["name"].value, "alice");
+        assert_eq!(by_name["name"].type_name, "STRING");
+    }
+
+    #[test]
+    fn test_convert_parameters_none_dropped() {
+        let qp = QueryProcessor::new();
+        let mut params = BTreeMap::new();
+        params.insert("id".to_string(), None);
+        let out = qp
+            .convert_to_sql_parameters("SELECT * FROM t WHERE id = :id", &params)
+            .unwrap();
+        assert!(out.is_empty());
+    }
+
+    #[test]
+    fn test_convert_parameters_rejects_extra_keys() {
+        let qp = QueryProcessor::new();
+        let mut params = BTreeMap::new();
+        params.insert("missing".to_string(), Some(SqlValue::string("x")));
+        let err = qp
+            .convert_to_sql_parameters("SELECT * FROM t WHERE id = :id", &params)
+            .unwrap_err();
+        assert_eq!(err.field, "missing");
+        assert!(err.message.contains("valid: id"));
+    }
+
+    #[test]
+    fn
test_cache_key_parts_shape() { + let qp = QueryProcessor::new(); + let parts = qp.cache_key_parts( + "trips_by_zone", + "{\"zone\":1}", + "JSON", + "deadbeef", + "user-42", + ); + assert_eq!(parts[0], "analytics:query"); + assert_eq!(parts[1], "trips_by_zone"); + assert_eq!(parts[2], "{\"zone\":1}"); + assert_eq!(parts[3], "\"JSON\""); + assert_eq!(parts[4], "deadbeef"); + assert_eq!(parts[5], "user-42"); + } + + #[test] + fn test_manifest_declares_sql_warehouse() { + let core = AnalyticsPluginCore::new(AnalyticsPluginConfig::default()); + assert_eq!(core.name(), "analytics"); + assert_eq!(core.manifest().required_resources.len(), 1); + assert_eq!( + core.manifest().required_resources[0].resource_type, + "sql_warehouse" + ); + assert_eq!(core.phase(), PluginPhase::Normal); + } + + #[test] + fn test_is_valid_query_key() { + assert!(is_valid_query_key("trips_by_zone")); + assert!(is_valid_query_key("abc-123_xyz")); + assert!(!is_valid_query_key("")); + assert!(!is_valid_query_key("../secrets")); + assert!(!is_valid_query_key("with space")); + assert!(!is_valid_query_key("weird$name")); + } + + #[test] + fn test_load_query_prefers_obo() { + let tmp = tempdir_in_target(); + let q_dir = tmp.join("queries"); + fs::create_dir_all(&q_dir).unwrap(); + let mut sp = fs::File::create(q_dir.join("foo.sql")).unwrap(); + sp.write_all(b"SELECT 1 AS service").unwrap(); + let mut obo = fs::File::create(q_dir.join("foo.obo.sql")).unwrap(); + obo.write_all(b"SELECT 1 AS user").unwrap(); + + let q = load_query(&q_dir, "foo").expect("load"); + assert!(q.is_as_user); + assert!(q.query.contains("user")); + } + + #[test] + fn test_load_query_falls_back_to_sp() { + let tmp = tempdir_in_target(); + let q_dir = tmp.join("queries_sp_only"); + fs::create_dir_all(&q_dir).unwrap(); + let mut sp = fs::File::create(q_dir.join("bar.sql")).unwrap(); + sp.write_all(b"SELECT 'sp'").unwrap(); + + let q = load_query(&q_dir, "bar").expect("load"); + assert!(!q.is_as_user); + 
        assert!(q.query.contains("sp"));
+    }
+
+    #[test]
+    fn test_load_query_rejects_bad_key() {
+        let tmp = tempdir_in_target();
+        assert!(load_query(&tmp, "../etc/passwd").is_none());
+    }
+
+    /// Create a unique temp dir under the current target/ so tests don't pollute
+    /// the repo and clean up isn't needed between runs.
+    fn tempdir_in_target() -> PathBuf {
+        use std::sync::atomic::{AtomicU64, Ordering};
+        static COUNTER: AtomicU64 = AtomicU64::new(0);
+        let base = std::env::temp_dir().join("appkit-rs-analytics-tests");
+        let n = COUNTER.fetch_add(1, Ordering::SeqCst);
+        let pid = std::process::id();
+        let dir = base.join(format!("{pid}-{n}"));
+        fs::create_dir_all(&dir).unwrap();
+        dir
+    }
+}
diff --git a/packages/appkit-rs/src/plugins/files.rs b/packages/appkit-rs/src/plugins/files.rs
new file mode 100644
index 00000000..79dc71dc
--- /dev/null
+++ b/packages/appkit-rs/src/plugins/files.rs
@@ -0,0 +1,94 @@
+//! Files plugin — Rust `Plugin` trait implementation. Wraps `FilesConnector`
+//! with per-volume alias configuration and declares the `volume` resource
+//! requirement.
+//!
+//! HTTP routes and OBO token extraction are owned by the Python subclass in
+//! `appkit/plugins/files.py`. This struct supplies the plugin name, phase,
+//! and manifest used by future Rust callers and the (upcoming) appkit CLI.
+
+use std::collections::HashMap;
+
+use crate::plugin::{Plugin, PluginManifest, PluginPhase, ResourceRequirement};
+
+/// Per-volume configuration — alias → fully-qualified volume path.
+/// Example: `{"files": "/Volumes/catalog/schema/volume"}`.
+#[derive(Clone, Debug, Default)]
+pub struct FilesPluginConfig {
+    pub volumes: HashMap<String, String>,
+    pub max_upload_size_bytes: Option<u64>,
+    pub timeout_ms: Option<u64>,
+}
+
+/// Rust wrapper around `FilesConnector` exposing the `Plugin` trait.
+pub struct FilesPluginCore { + manifest: PluginManifest, + #[allow(dead_code)] + config: FilesPluginConfig, +} + +impl FilesPluginCore { + pub const NAME: &'static str = "files"; + + pub fn new(config: FilesPluginConfig) -> Self { + Self { + manifest: PluginManifest { + name: Self::NAME.into(), + display_name: Some("Files Plugin".into()), + description: Some("Unity Catalog Volumes file operations".into()), + required_resources: vec![ResourceRequirement { + resource_type: "volume".into(), + required: true, + }], + optional_resources: vec![], + }, + config, + } + } +} + +impl Plugin for FilesPluginCore { + fn name(&self) -> &str { + &self.manifest.name + } + + fn phase(&self) -> PluginPhase { + PluginPhase::Normal + } + + fn manifest(&self) -> &PluginManifest { + &self.manifest + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_manifest_declares_volume_resource() { + let core = FilesPluginCore::new(FilesPluginConfig::default()); + assert_eq!(core.name(), "files"); + assert_eq!(core.manifest().required_resources.len(), 1); + assert_eq!(core.manifest().required_resources[0].resource_type, "volume"); + } + + #[test] + fn test_phase_is_normal() { + let core = FilesPluginCore::new(FilesPluginConfig::default()); + assert_eq!(core.phase(), PluginPhase::Normal); + } + + #[test] + fn test_config_round_trip() { + let mut volumes = HashMap::new(); + volumes.insert("files".to_string(), "/Volumes/a/b/c".to_string()); + let config = FilesPluginConfig { + volumes: volumes.clone(), + max_upload_size_bytes: Some(5_000_000_000), + timeout_ms: Some(30_000), + }; + let core = FilesPluginCore::new(config.clone()); + assert_eq!(core.config.volumes, volumes); + assert_eq!(core.config.max_upload_size_bytes, Some(5_000_000_000)); + } +} diff --git a/packages/appkit-rs/src/plugins/genie.rs b/packages/appkit-rs/src/plugins/genie.rs new file mode 100644 index 00000000..f599177c --- /dev/null +++ b/packages/appkit-rs/src/plugins/genie.rs @@ -0,0 +1,70 @@ +//! 
 Genie plugin — Rust `Plugin` trait implementation.
+//!
+//! Declares the `genie_space` resource requirement. HTTP routes live in
+//! `appkit/plugins/genie.py`.
+
+use std::collections::HashMap;
+
+use crate::plugin::{Plugin, PluginManifest, PluginPhase, ResourceRequirement};
+
+/// Genie plugin configuration — alias → space_id map.
+#[derive(Clone, Debug, Default)]
+pub struct GeniePluginConfig {
+    pub spaces: HashMap<String, String>,
+    pub timeout_ms: Option<u64>,
+}
+
+pub struct GeniePluginCore {
+    manifest: PluginManifest,
+    #[allow(dead_code)]
+    config: GeniePluginConfig,
+}
+
+impl GeniePluginCore {
+    pub const NAME: &'static str = "genie";
+
+    pub fn new(config: GeniePluginConfig) -> Self {
+        Self {
+            manifest: PluginManifest {
+                name: Self::NAME.into(),
+                display_name: Some("Genie Plugin".into()),
+                description: Some("Databricks Genie conversational analytics".into()),
+                required_resources: vec![ResourceRequirement {
+                    resource_type: "genie_space".into(),
+                    required: true,
+                }],
+                optional_resources: vec![],
+            },
+            config,
+        }
+    }
+}
+
+impl Plugin for GeniePluginCore {
+    fn name(&self) -> &str {
+        &self.manifest.name
+    }
+
+    fn phase(&self) -> PluginPhase {
+        PluginPhase::Normal
+    }
+
+    fn manifest(&self) -> &PluginManifest {
+        &self.manifest
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_manifest_declares_genie_space() {
+        let core = GeniePluginCore::new(GeniePluginConfig::default());
+        assert_eq!(core.name(), "genie");
+        assert_eq!(
+            core.manifest().required_resources[0].resource_type,
+            "genie_space"
+        );
+    }
+}
diff --git a/packages/appkit-rs/src/plugins/lakebase.rs b/packages/appkit-rs/src/plugins/lakebase.rs
new file mode 100644
index 00000000..c735a7b3
--- /dev/null
+++ b/packages/appkit-rs/src/plugins/lakebase.rs
@@ -0,0 +1,70 @@
+//! Lakebase plugin — Rust `Plugin` trait implementation.
+//!
+//! Declares the `postgres` resource requirement. Lakebase exposes a
+//!
 programmatic pool API rather than HTTP routes, so the Python subclass in
+//! `appkit/plugins/lakebase.py` publishes connection helpers via exports()
+//! and keeps `inject_routes()` empty.
+
+use crate::plugin::{Plugin, PluginManifest, PluginPhase, ResourceRequirement};
+
+#[derive(Clone, Debug, Default)]
+pub struct LakebasePluginConfig {
+    pub database: Option<String>,
+    pub host: Option<String>,
+    pub ssl_mode: Option<String>,
+}
+
+pub struct LakebasePluginCore {
+    manifest: PluginManifest,
+    #[allow(dead_code)]
+    config: LakebasePluginConfig,
+}
+
+impl LakebasePluginCore {
+    pub const NAME: &'static str = "lakebase";
+
+    pub fn new(config: LakebasePluginConfig) -> Self {
+        Self {
+            manifest: PluginManifest {
+                name: Self::NAME.into(),
+                display_name: Some("Lakebase".into()),
+                description: Some("Databricks Lakebase PostgreSQL integration".into()),
+                required_resources: vec![ResourceRequirement {
+                    resource_type: "postgres".into(),
+                    required: true,
+                }],
+                optional_resources: vec![],
+            },
+            config,
+        }
+    }
+}
+
+impl Plugin for LakebasePluginCore {
+    fn name(&self) -> &str {
+        &self.manifest.name
+    }
+
+    fn phase(&self) -> PluginPhase {
+        PluginPhase::Normal
+    }
+
+    fn manifest(&self) -> &PluginManifest {
+        &self.manifest
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_manifest_declares_postgres() {
+        let core = LakebasePluginCore::new(LakebasePluginConfig::default());
+        assert_eq!(core.name(), "lakebase");
+        assert_eq!(
+            core.manifest().required_resources[0].resource_type,
+            "postgres"
+        );
+    }
+}
diff --git a/packages/appkit-rs/src/plugins/mod.rs b/packages/appkit-rs/src/plugins/mod.rs
new file mode 100644
index 00000000..1feeec2e
--- /dev/null
+++ b/packages/appkit-rs/src/plugins/mod.rs
@@ -0,0 +1,24 @@
+//! Shipped plugin wrappers — Rust `Plugin` trait implementations for the core
+//! connectors. Each module here provides:
+//!
+//! - A per-plugin config struct
+//! - A Rust `Plugin` trait implementation with a proper `PluginManifest`
+//!
(name, required/optional resources) +//! - A Python-facing subclass of `appkit.Plugin` lives in +//! `appkit/plugins/*.py`, which is where route injection and OBO-aware +//! handler logic lives. The Rust side is the source of truth for manifest +//! + resource declarations so a future `appkit plugin sync` CLI can read +//! them directly without parsing Python. +//! +//! Keeping the Rust trait implementations slim avoids duplicating the +//! connector's HTTP logic and prevents Rust/Python drift: both sides agree on +//! manifests and plugin names, while route injection is owned by a single +//! Python class per plugin. + +pub mod analytics; +pub mod files; +pub mod genie; +pub mod lakebase; +pub mod server; +pub mod serving; +pub mod vector_search; diff --git a/packages/appkit-rs/src/plugins/server.rs b/packages/appkit-rs/src/plugins/server.rs new file mode 100644 index 00000000..f60647c2 --- /dev/null +++ b/packages/appkit-rs/src/plugins/server.rs @@ -0,0 +1,187 @@ +//! Server plugin — Rust `Plugin` trait implementation. +//! +//! Promotes the existing axum-backed server into a first-class plugin with a +//! real manifest. Unlike other plugin cores here, the server plugin hosts +//! routes for every other registered plugin rather than owning routes of its +//! own, so `required_resources` is empty. +//! +//! Actual route hosting, SSE streaming, and graceful shutdown live in the +//! top-level `crate::server` module — this struct is the plugin-registry face +//! of that infrastructure. Python code registers a `ServerPlugin` via +//! `appkit/plugins/server.py`; the Rust side is the source of truth for the +//! plugin name (`"server"`), manifest, and phase (Core — it must be ready +//! before Normal plugins inject routes). 
+
+use crate::plugin::{Plugin, PluginManifest, PluginPhase};
+use crate::server::PyServerConfig;
+
+// ---------------------------------------------------------------------------
+// Config
+// ---------------------------------------------------------------------------
+
+/// Server plugin configuration — mirrors the TS `IServerConfig` knobs that
+/// actually alter Rust-side server behavior. For route-hosting-only fields we
+/// reuse `PyServerConfig` to avoid duplicating the existing binding.
+#[derive(Clone, Debug)]
+pub struct ServerPluginConfig {
+    pub host: String,
+    pub port: u16,
+    pub auto_start: bool,
+    pub static_path: Option<String>,
+}
+
+impl Default for ServerPluginConfig {
+    fn default() -> Self {
+        Self {
+            host: "0.0.0.0".into(),
+            port: 8000,
+            auto_start: true,
+            static_path: None,
+        }
+    }
+}
+
+impl From<PyServerConfig> for ServerPluginConfig {
+    fn from(value: PyServerConfig) -> Self {
+        Self {
+            host: value.host,
+            port: value.port,
+            auto_start: value.auto_start,
+            static_path: value.static_path,
+        }
+    }
+}
+
+impl From<&PyServerConfig> for ServerPluginConfig {
+    fn from(value: &PyServerConfig) -> Self {
+        Self {
+            host: value.host.clone(),
+            port: value.port,
+            auto_start: value.auto_start,
+            static_path: value.static_path.clone(),
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// ServerPluginCore — Plugin trait impl.
+// --------------------------------------------------------------------------- + +pub struct ServerPluginCore { + manifest: PluginManifest, + config: ServerPluginConfig, +} + +impl ServerPluginCore { + pub const NAME: &'static str = "server"; + + pub fn new(config: ServerPluginConfig) -> Self { + Self { + manifest: PluginManifest { + name: Self::NAME.into(), + display_name: Some("Server Plugin".into()), + description: Some( + "HTTP server with axum route hosting, SSE streaming, and graceful shutdown" + .into(), + ), + required_resources: vec![], + optional_resources: vec![], + }, + config, + } + } + + pub fn config(&self) -> &ServerPluginConfig { + &self.config + } + + /// Convert the plugin config into the `PyServerConfig` consumed by + /// `PyAppKit::start_server`. This keeps the Rust/Python bindings in sync + /// without duplicating the server-config field set. + pub fn to_py_config(&self) -> PyServerConfig { + PyServerConfig { + host: self.config.host.clone(), + port: self.config.port, + auto_start: self.config.auto_start, + static_path: self.config.static_path.clone(), + } + } +} + +impl Plugin for ServerPluginCore { + fn name(&self) -> &str { + &self.manifest.name + } + + /// The server plugin initializes in the Core phase — it must be ready + /// before Normal-phase plugins inject their routes. 
+ fn phase(&self) -> PluginPhase { + PluginPhase::Core + } + + fn manifest(&self) -> &PluginManifest { + &self.manifest + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_manifest_has_no_required_resources() { + let core = ServerPluginCore::new(ServerPluginConfig::default()); + assert_eq!(core.name(), "server"); + assert!(core.manifest().required_resources.is_empty()); + assert!(core.manifest().optional_resources.is_empty()); + } + + #[test] + fn test_server_plugin_runs_in_core_phase() { + let core = ServerPluginCore::new(ServerPluginConfig::default()); + assert_eq!(core.phase(), PluginPhase::Core); + } + + #[test] + fn test_default_config_matches_ts_defaults() { + let cfg = ServerPluginConfig::default(); + assert_eq!(cfg.host, "0.0.0.0"); + assert_eq!(cfg.port, 8000); + assert!(cfg.auto_start); + assert!(cfg.static_path.is_none()); + } + + #[test] + fn test_to_py_config_roundtrip() { + let cfg = ServerPluginConfig { + host: "127.0.0.1".into(), + port: 9090, + auto_start: false, + static_path: Some("dist".into()), + }; + let core = ServerPluginCore::new(cfg.clone()); + let py_cfg = core.to_py_config(); + assert_eq!(py_cfg.host, cfg.host); + assert_eq!(py_cfg.port, cfg.port); + assert_eq!(py_cfg.auto_start, cfg.auto_start); + assert_eq!(py_cfg.static_path, cfg.static_path); + } + + #[test] + fn test_from_py_server_config() { + let py_cfg = PyServerConfig { + host: "example.com".into(), + port: 3000, + auto_start: true, + static_path: Some("public".into()), + }; + let plugin_cfg: ServerPluginConfig = (&py_cfg).into(); + assert_eq!(plugin_cfg.host, "example.com"); + assert_eq!(plugin_cfg.port, 3000); + assert_eq!(plugin_cfg.static_path.as_deref(), Some("public")); + } +} diff --git a/packages/appkit-rs/src/plugins/serving.rs b/packages/appkit-rs/src/plugins/serving.rs new file 
mode 100644
index 00000000..ad297ea7
--- /dev/null
+++ b/packages/appkit-rs/src/plugins/serving.rs
@@ -0,0 +1,107 @@
+//! Serving plugin — Rust `Plugin` trait implementation.
+//!
+//! Declares the `serving_endpoint` resource requirement. HTTP routes live in
+//! `appkit/plugins/serving.py`.
+
+use std::collections::HashMap;
+
+use crate::plugin::{Plugin, PluginManifest, PluginPhase, ResourceRequirement};
+
+#[derive(Clone, Debug, Default)]
+pub struct ServingEndpointConfig {
+    /// Environment variable name that resolves to the endpoint name.
+    pub env: String,
+    pub served_model: Option<String>,
+}
+
+/// Serving plugin configuration — alias → endpoint config. When `endpoints`
+/// is empty, the plugin operates in "simple mode" and exposes a single
+/// `invoke` / `stream` route under the alias "default".
+#[derive(Clone, Debug, Default)]
+pub struct ServingPluginConfig {
+    pub endpoints: HashMap<String, ServingEndpointConfig>,
+    pub timeout_ms: Option<u64>,
+}
+
+impl ServingPluginConfig {
+    pub fn is_named_mode(&self) -> bool {
+        !self.endpoints.is_empty()
+    }
+}
+
+pub struct ServingPluginCore {
+    manifest: PluginManifest,
+    #[allow(dead_code)]
+    config: ServingPluginConfig,
+}
+
+impl ServingPluginCore {
+    pub const NAME: &'static str = "serving";
+
+    pub fn new(config: ServingPluginConfig) -> Self {
+        Self {
+            manifest: PluginManifest {
+                name: Self::NAME.into(),
+                display_name: Some("Model Serving Plugin".into()),
+                description: Some("Invoke and stream from Databricks serving endpoints".into()),
+                required_resources: vec![ResourceRequirement {
+                    resource_type: "serving_endpoint".into(),
+                    required: true,
+                }],
+                optional_resources: vec![],
+            },
+            config,
+        }
+    }
+}
+
+impl Plugin for ServingPluginCore {
+    fn name(&self) -> &str {
+        &self.manifest.name
+    }
+
+    fn phase(&self) -> PluginPhase {
+        PluginPhase::Normal
+    }
+
+    fn manifest(&self) -> &PluginManifest {
+        &self.manifest
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_simple_mode_is_default() {
+        let core =
ServingPluginCore::new(ServingPluginConfig::default()); + assert!(!core.config.is_named_mode()); + } + + #[test] + fn test_named_mode_when_endpoints_present() { + let mut endpoints = HashMap::new(); + endpoints.insert( + "chat".into(), + ServingEndpointConfig { + env: "CHAT_ENDPOINT".into(), + served_model: None, + }, + ); + let config = ServingPluginConfig { + endpoints, + timeout_ms: None, + }; + assert!(config.is_named_mode()); + } + + #[test] + fn test_manifest_declares_serving_endpoint() { + let core = ServingPluginCore::new(ServingPluginConfig::default()); + assert_eq!( + core.manifest().required_resources[0].resource_type, + "serving_endpoint" + ); + } +} diff --git a/packages/appkit-rs/src/plugins/vector_search.rs b/packages/appkit-rs/src/plugins/vector_search.rs new file mode 100644 index 00000000..b72142a0 --- /dev/null +++ b/packages/appkit-rs/src/plugins/vector_search.rs @@ -0,0 +1,182 @@ +//! Vector Search plugin — Rust `Plugin` trait implementation. +//! +//! Declares the `vector_search_index` resource requirement. The HTTP routes +//! (`/api/vector-search/query`, `/api/vector-search/query-next-page`) live on +//! the Python subclass in `appkit/plugins/vector_search.py`; the Rust side +//! owns the manifest, per-index config, and request-shape helpers used by +//! `VectorSearchConnector`. +//! +//! The request/response builders live in +//! `crate::connectors::vector_search`. This module is deliberately thin: it +//! exists so Python apps can declare a `VectorSearchPluginCore` with the right +//! manifest and resource requirement, mirroring the TS `VectorSearchPlugin`. + +use std::collections::HashMap; + +use crate::connectors::vector_search::VsQueryType; +use crate::plugin::{Plugin, PluginManifest, PluginPhase, ResourceRequirement}; + +// --------------------------------------------------------------------------- +// Per-index configuration — alias → index settings. 
+// --------------------------------------------------------------------------- + +/// Per-index configuration matching the TS `VectorSearchPluginConfig.indexes` +/// shape. +#[derive(Clone, Debug)] +pub struct VectorSearchIndexConfig { + /// Three-level UC name: `catalog.schema.index_name`. + pub index_name: String, + /// Endpoint name — required for pagination. + pub endpoint_name: Option, + /// Columns to return in results. + pub columns: Vec, + /// Default query type — ann / hybrid / full_text. + pub query_type: VsQueryType, + /// Default max number of results. + pub num_results: u32, + /// Reranker columns (enables the `databricks_reranker` when non-empty). + pub reranker_columns: Option>, +} + +impl Default for VectorSearchIndexConfig { + fn default() -> Self { + Self { + index_name: String::new(), + endpoint_name: None, + columns: Vec::new(), + query_type: VsQueryType::Hybrid, + num_results: 20, + reranker_columns: None, + } + } +} + +/// Vector Search plugin configuration — alias → index config plus a default +/// timeout. +#[derive(Clone, Debug, Default)] +pub struct VectorSearchPluginConfig { + pub indexes: HashMap, + /// Per-query timeout in milliseconds. + pub timeout_ms: Option, +} + +impl VectorSearchPluginConfig { + /// Return the configured aliases in stable (sorted) order. Used by the + /// Python layer to advertise known indexes via `client_config`. + pub fn aliases(&self) -> Vec { + let mut keys: Vec = self.indexes.keys().cloned().collect(); + keys.sort(); + keys + } +} + +// --------------------------------------------------------------------------- +// VectorSearchPluginCore — Plugin trait impl. 
+// --------------------------------------------------------------------------- + +pub struct VectorSearchPluginCore { + manifest: PluginManifest, + #[allow(dead_code)] + config: VectorSearchPluginConfig, +} + +impl VectorSearchPluginCore { + pub const NAME: &'static str = "vector-search"; + + pub fn new(config: VectorSearchPluginConfig) -> Self { + Self { + manifest: PluginManifest { + name: Self::NAME.into(), + display_name: Some("Vector Search Plugin".into()), + description: Some( + "Query Databricks Vector Search indexes with hybrid search, reranking, and pagination" + .into(), + ), + required_resources: vec![ResourceRequirement { + resource_type: "vector_search_index".into(), + required: true, + }], + optional_resources: vec![], + }, + config, + } + } + + pub fn config(&self) -> &VectorSearchPluginConfig { + &self.config + } +} + +impl Plugin for VectorSearchPluginCore { + fn name(&self) -> &str { + &self.manifest.name + } + + fn phase(&self) -> PluginPhase { + PluginPhase::Normal + } + + fn manifest(&self) -> &PluginManifest { + &self.manifest + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_manifest_declares_vector_search_index() { + let core = VectorSearchPluginCore::new(VectorSearchPluginConfig::default()); + assert_eq!(core.name(), "vector-search"); + assert_eq!(core.manifest().required_resources.len(), 1); + assert_eq!( + core.manifest().required_resources[0].resource_type, + "vector_search_index" + ); + assert_eq!(core.phase(), PluginPhase::Normal); + } + + #[test] + fn test_default_index_config_matches_ts_defaults() { + let cfg = VectorSearchIndexConfig::default(); + assert_eq!(cfg.query_type, VsQueryType::Hybrid); + assert_eq!(cfg.num_results, 20); + assert!(cfg.columns.is_empty()); + assert!(cfg.endpoint_name.is_none()); + 
assert!(cfg.reranker_columns.is_none()); + } + + #[test] + fn test_aliases_are_sorted_and_stable() { + let mut indexes = HashMap::new(); + indexes.insert( + "docs".into(), + VectorSearchIndexConfig { + index_name: "cat.sch.docs".into(), + columns: vec!["id".into()], + ..Default::default() + }, + ); + indexes.insert( + "articles".into(), + VectorSearchIndexConfig { + index_name: "cat.sch.articles".into(), + columns: vec!["id".into()], + ..Default::default() + }, + ); + let config = VectorSearchPluginConfig { + indexes, + timeout_ms: Some(10_000), + }; + assert_eq!( + config.aliases(), + vec!["articles".to_string(), "docs".to_string()] + ); + } +} diff --git a/packages/appkit-rs/src/server.rs b/packages/appkit-rs/src/server.rs new file mode 100644 index 00000000..47b7b3be --- /dev/null +++ b/packages/appkit-rs/src/server.rs @@ -0,0 +1,894 @@ +//! Axum HTTP server with Python route injection, SSE streaming, static file +//! serving, and graceful shutdown. +//! +//! Ports `packages/appkit/src/plugins/server/index.ts`. 
+ +use std::collections::HashMap; +use std::convert::Infallible; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; + +use axum::body::Bytes; +use axum::http::{header, HeaderMap, Method, StatusCode, Uri}; +use axum::response::sse::{Event as AxumSseEvent, KeepAlive, Sse}; +use axum::response::{IntoResponse, Json, Response}; +use axum::routing::{delete, get, patch, post, put}; +use axum::Router; +use futures::StreamExt; +use pyo3::prelude::*; +use serde_json::Value as JsonValue; +use tokio::net::TcpListener; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; + +use tower_http::cors::CorsLayer; + +use pyo3_async_runtimes::TaskLocals; + +use crate::stream::{SseEvent, StreamConfig, StreamManager}; + +// --------------------------------------------------------------------------- +// Clonable PyObject wrapper (acquires GIL for Clone) +// --------------------------------------------------------------------------- + +/// Wrapper around `PyObject` that implements `Clone` by briefly acquiring the +/// GIL. This is needed because axum requires handler closures to be `Clone`. +/// +/// # GIL acquisition on clone +/// +/// Every `Clone::clone` call acquires the GIL via `Python::with_gil`. Under the +/// current GIL-based CPython runtime this is safe but adds contention when many +/// handler clones happen concurrently (e.g. during router initialization). If +/// PyO3 moves to a free-threaded build (`--disable-gil` / PEP 703), this +/// implementation must be revisited because `clone_ref` semantics may change. 
+#[derive(Debug)] +struct GilPyObject(PyObject); + +impl Clone for GilPyObject { + fn clone(&self) -> Self { + Python::with_gil(|py| Self(self.0.clone_ref(py))) + } +} + +impl GilPyObject { + fn new(obj: PyObject) -> Self { + Self(obj) + } + + fn into_inner(self) -> PyObject { + self.0 + } +} + +// --------------------------------------------------------------------------- +// Route types (crate-internal) +// --------------------------------------------------------------------------- + +#[derive(Clone, Debug)] +pub enum HttpMethod { + Get, + Post, + Put, + Delete, + Patch, +} + +impl HttpMethod { + fn as_str(&self) -> &'static str { + match self { + Self::Get => "GET", + Self::Post => "POST", + Self::Put => "PUT", + Self::Delete => "DELETE", + Self::Patch => "PATCH", + } + } +} + +pub struct RouteDefinition { + pub method: HttpMethod, + pub path: String, + pub handler: PyObject, + pub is_stream: bool, +} + +// --------------------------------------------------------------------------- +// PyRouter — collects route registrations from Python plugins +// --------------------------------------------------------------------------- + +/// Router passed to `Plugin.inject_routes()` for route registration. +/// +/// ```python +/// def inject_routes(self, router): +/// router.get("/items", self.get_items) +/// router.post("/items", self.create_item) +/// router.get("/stream", self.handle_stream, stream=True) +/// ``` +#[pyclass(name = "Router", module = "appkit")] +pub struct PyRouter { + routes: std::sync::Mutex>, + #[pyo3(get)] + plugin_name: String, +} + +impl PyRouter { + pub fn new(plugin_name: &str) -> Self { + Self { + routes: std::sync::Mutex::new(Vec::new()), + plugin_name: plugin_name.to_string(), + } + } + + /// Take the collected routes out of the router (consumes them). 
+ pub fn take_routes(&self) -> Vec { + std::mem::take(&mut *self.routes.lock().unwrap()) + } +} + +#[pymethods] +impl PyRouter { + #[pyo3(signature = (path, handler, *, stream = false))] + fn get(&self, path: String, handler: PyObject, stream: bool) { + self.routes.lock().unwrap().push(RouteDefinition { + method: HttpMethod::Get, + path, + handler, + is_stream: stream, + }); + } + + #[pyo3(signature = (path, handler, *, stream = false))] + fn post(&self, path: String, handler: PyObject, stream: bool) { + self.routes.lock().unwrap().push(RouteDefinition { + method: HttpMethod::Post, + path, + handler, + is_stream: stream, + }); + } + + #[pyo3(signature = (path, handler, *, stream = false))] + fn put(&self, path: String, handler: PyObject, stream: bool) { + self.routes.lock().unwrap().push(RouteDefinition { + method: HttpMethod::Put, + path, + handler, + is_stream: stream, + }); + } + + #[pyo3(signature = (path, handler, *, stream = false))] + fn delete(&self, path: String, handler: PyObject, stream: bool) { + self.routes.lock().unwrap().push(RouteDefinition { + method: HttpMethod::Delete, + path, + handler, + is_stream: stream, + }); + } + + #[pyo3(signature = (path, handler, *, stream = false))] + fn patch(&self, path: String, handler: PyObject, stream: bool) { + self.routes.lock().unwrap().push(RouteDefinition { + method: HttpMethod::Patch, + path, + handler, + is_stream: stream, + }); + } + + fn __repr__(&self) -> String { + let count = self.routes.lock().unwrap().len(); + format!("Router(plugin={:?}, routes={})", self.plugin_name, count) + } +} + +// --------------------------------------------------------------------------- +// PyRequest — immutable request object passed to Python handlers +// --------------------------------------------------------------------------- + +/// HTTP request data forwarded to Python route handlers. 
+/// +/// ```python +/// async def handle(self, request): +/// print(request.method, request.path) +/// body = json.loads(request.body) if request.body else {} +/// return json.dumps({"ok": True}) +/// ``` +#[pyclass(frozen, name = "Request", module = "appkit")] +#[derive(Clone)] +pub struct PyRequest { + #[pyo3(get)] + pub method: String, + #[pyo3(get)] + pub path: String, + #[pyo3(get)] + pub headers: HashMap, + #[pyo3(get)] + pub query: String, + #[pyo3(get)] + pub body: String, +} + +#[pymethods] +impl PyRequest { + /// Parse the request body as JSON and return Python-native data + /// (dict, list, str, int, float, bool, or None). + /// + /// Raises `ValueError` if the body is not valid JSON. + fn json(&self, py: Python<'_>) -> PyResult { + let json_mod = py.import("json")?; + match json_mod.call_method1("loads", (self.body.as_str(),)) { + Ok(result) => Ok(result.unbind()), + Err(e) => { + let msg = e.to_string(); + Err(pyo3::exceptions::PyValueError::new_err(msg)) + } + } + } + + fn __repr__(&self) -> String { + format!("Request({} {})", self.method, self.path) + } +} + +// --------------------------------------------------------------------------- +// PyServerConfig +// --------------------------------------------------------------------------- + +/// Server configuration. 
+/// +/// ```python +/// config = ServerConfig(host="0.0.0.0", port=8000, static_path="dist") +/// ``` +#[pyclass(frozen, name = "ServerConfig", module = "appkit")] +#[derive(Clone)] +pub struct PyServerConfig { + #[pyo3(get)] + pub host: String, + #[pyo3(get)] + pub port: u16, + #[pyo3(get)] + pub auto_start: bool, + #[pyo3(get)] + pub static_path: Option, +} + +#[pymethods] +impl PyServerConfig { + #[new] + #[pyo3(signature = (*, host = "0.0.0.0".to_string(), port = 8000, auto_start = true, static_path = None))] + fn new(host: String, port: u16, auto_start: bool, static_path: Option) -> Self { + Self { + host, + port, + auto_start, + static_path, + } + } + + fn __repr__(&self) -> String { + format!("ServerConfig(host={:?}, port={})", self.host, self.port) + } + + fn __eq__(&self, other: &Self) -> bool { + self.host == other.host + && self.port == other.port + && self.auto_start == other.auto_start + && self.static_path == other.static_path + } + + fn __hash__(&self) -> u64 { + use std::hash::{Hash, Hasher}; + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + self.host.hash(&mut hasher); + self.port.hash(&mut hasher); + self.auto_start.hash(&mut hasher); + hasher.finish() + } +} + +// --------------------------------------------------------------------------- +// Endpoint info (returned by /api/__config) +// --------------------------------------------------------------------------- + +#[derive(Clone, serde::Serialize)] +struct EndpointInfo { + method: String, + path: String, + plugin: String, +} + +// --------------------------------------------------------------------------- +// Python handler invocation helpers +// --------------------------------------------------------------------------- + +/// Call a Python async handler with a `PyRequest`, returning the JSON string. 
+/// +/// `task_locals` carries the Python asyncio event loop reference so that +/// `into_future` can bridge the coroutine even though we're on a bare +/// tokio task (no running Python event loop). +async fn call_python_handler( + handler: PyObject, + request: PyRequest, + task_locals: &TaskLocals, +) -> Result { + let future = Python::with_gil(|py| { + let req_obj = Py::new(py, request) + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + let coroutine = handler + .call1(py, (req_obj,)) + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + pyo3_async_runtimes::into_future_with_locals( + &task_locals.clone_ref(py), + coroutine.into_bound(py), + ) + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string())) + })?; + + let result = future + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + Python::with_gil(|py| { + result.extract::(py).map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Handler must return a JSON string: {e}"), + ) + }) + }) +} + +/// Spawn a task that drives a Python async generator, forwarding items to `tx`. +fn spawn_python_generator(py_gen: PyObject, tx: mpsc::Sender<(Option, String)>) { + tokio::spawn(async move { + loop { + // Step 1: acquire GIL, call __anext__, get a future, release GIL. + let future_result: Result, String> = Python::with_gil(|py| { + match py_gen.call_method0(py, "__anext__") { + Ok(coroutine) => pyo3_async_runtimes::tokio::into_future( + coroutine.into_bound(py), + ) + .map(Some) + .map_err(|e| e.to_string()), + Err(e) => { + if e.is_instance_of::(py) { + Ok(None) + } else { + Err(e.to_string()) + } + } + } + }); + + match future_result { + Ok(None) => break, // Generator exhausted. 
+ Err(e) => { + let _ = tx.send((Some("error".to_string()), e)).await; + break; + } + Ok(Some(future)) => { + match future.await { + Ok(value) => { + let data = Python::with_gil(|py| { + value + .extract::(py) + .unwrap_or_else(|_| "null".to_string()) + }); + if tx.send((None, data)).await.is_err() { + break; // Receiver dropped. + } + } + Err(e) => { + let is_stop = Python::with_gil(|py| { + e.is_instance_of::(py) + }); + if is_stop { + break; + } + let _ = tx.send((Some("error".to_string()), e.to_string())).await; + break; + } + } + } + } + } + }); +} + +// --------------------------------------------------------------------------- +// Request data extraction +// --------------------------------------------------------------------------- + +fn extract_request_data(method: &Method, uri: &Uri, headers: &HeaderMap, body: &[u8]) -> PyRequest { + let header_map: HashMap = headers + .iter() + .filter_map(|(k, v)| v.to_str().ok().map(|val| (k.as_str().to_string(), val.to_string()))) + .collect(); + + PyRequest { + method: method.to_string(), + path: uri.path().to_string(), + headers: header_map, + query: uri.query().unwrap_or("").to_string(), + body: String::from_utf8_lossy(body).to_string(), + } +} + +/// Parse a query string into key-value pairs. +fn parse_query(query: &str) -> HashMap { + query + .split('&') + .filter(|s| !s.is_empty()) + .filter_map(|pair| { + let mut parts = pair.splitn(2, '='); + let key = parts.next()?; + let value = parts.next().unwrap_or(""); + Some((key.to_string(), value.to_string())) + }) + .collect() +} + +// --------------------------------------------------------------------------- +// Router construction +// --------------------------------------------------------------------------- + +/// Build the full axum `Router` from plugin routes, health check, client config, +/// and optional static file serving. 
+pub fn build_router( + plugin_routes: Vec<(String, Vec)>, + plugin_configs: HashMap, + stream_manager: Arc, + static_path: Option, + task_locals: TaskLocals, +) -> Router { + let task_locals = Arc::new(task_locals); + let mut app = Router::new(); + + // GET /health + app = app.route( + "/health", + get(|| async { Json(serde_json::json!({"status": "ok"})) }), + ); + + // GET /api/__config — aggregated endpoint map + plugin client configs. + let endpoint_info = collect_endpoint_info(&plugin_routes); + let config_payload = Arc::new(serde_json::json!({ + "plugins": plugin_configs, + "endpoints": endpoint_info, + })); + let config_payload_clone = config_payload.clone(); + app = app.route( + "/api/__config", + get(move || { + let payload = config_payload_clone.clone(); + async move { Json((*payload).clone()) } + }), + ); + + // Plugin routes — each mounted under /api/{plugin_name}. + for (plugin_name, routes) in plugin_routes { + let mut plugin_router = Router::new(); + + for route in routes { + if route.is_stream { + plugin_router = + mount_stream_route(plugin_router, route, stream_manager.clone(), task_locals.clone()); + } else { + plugin_router = mount_handler_route(plugin_router, route, task_locals.clone()); + } + } + + app = app.nest(&format!("/api/{plugin_name}"), plugin_router); + } + + // Static file serving (fallback). + if let Some(ref static_dir) = static_path { + let serve = tower_http::services::ServeDir::new(static_dir) + .append_index_html_on_directories(true); + app = app.fallback_service(serve); + } + + // CORS — permissive dev-friendly configuration. + app.layer(CorsLayer::permissive()) +} + +/// Pick a response Content-Type by peeking at the handler's output. +/// +/// Handlers return strings. JSON is the default — but handlers that render +/// HTML (e.g. server-rendered pages) would otherwise be served as +/// `application/json` and not render in browsers. 
A leading `<` after +/// optional whitespace unambiguously signals markup because valid JSON +/// cannot start with `<`. +fn detect_content_type(body: &str) -> &'static str { + match body.trim_start().as_bytes().first() { + Some(b'<') => "text/html; charset=utf-8", + _ => "application/json", + } +} + +fn mount_handler_route(router: Router, route: RouteDefinition, task_locals: Arc) -> Router { + let py_handler = GilPyObject::new(route.handler); + + let handler_fn = move |method: Method, uri: Uri, headers: HeaderMap, body: Bytes| { + let py_handler = py_handler.clone().into_inner(); + let locals = Python::with_gil(|py| (*task_locals).clone_ref(py)); + async move { + let request = extract_request_data(&method, &uri, &headers, &body); + match call_python_handler(py_handler, request, &locals).await { + Ok(body_str) => Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, detect_content_type(&body_str)) + .body(axum::body::Body::from(body_str)) + .unwrap_or_else(|_| { + StatusCode::INTERNAL_SERVER_ERROR.into_response().into_response() + }), + Err((status, msg)) => { + let err = serde_json::json!({"error": msg}).to_string(); + Response::builder() + .status(status) + .header(header::CONTENT_TYPE, "application/json") + .body(axum::body::Body::from(err)) + .unwrap_or_else(|_| { + StatusCode::INTERNAL_SERVER_ERROR.into_response().into_response() + }) + } + } + } + }; + + match route.method { + HttpMethod::Get => router.route(&route.path, get(handler_fn)), + HttpMethod::Post => router.route(&route.path, post(handler_fn)), + HttpMethod::Put => router.route(&route.path, put(handler_fn)), + HttpMethod::Delete => router.route(&route.path, delete(handler_fn)), + HttpMethod::Patch => router.route(&route.path, patch(handler_fn)), + } +} + +fn mount_stream_route( + router: Router, + route: RouteDefinition, + stream_manager: Arc, + _task_locals: Arc, +) -> Router { + let py_handler = GilPyObject::new(route.handler); + let sm = stream_manager; + + let handler_fn 
= move |method: Method, uri: Uri, headers: HeaderMap, body: Bytes| { + let py_handler = py_handler.clone().into_inner(); + let sm = sm.clone(); + async move { + let request = extract_request_data(&method, &uri, &headers, &body); + + // Stream ID from query param or generate. + let query_params = parse_query(&request.query); + let stream_id = query_params + .get("stream_id") + .cloned() + .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); + + // Reconnection: Last-Event-ID header. + let last_event_id = headers + .get("last-event-id") + .and_then(|v| v.to_str().ok()) + .map(String::from); + + // Try to reconnect to an existing stream. + if let Some(ref last_id) = last_event_id { + if let Ok(rx) = sm.subscribe(&stream_id, Some(last_id)).await { + return sse_response(rx, &sm.config); + } + } + + // Create new stream: call Python handler to get async generator. + let py_gen = match Python::with_gil(|py| -> Result { + let req_obj = + Py::new(py, request).map_err(|e: PyErr| e.to_string())?; + py_handler + .call1(py, (req_obj,)) + .map_err(|e: PyErr| e.to_string()) + }) { + Ok(gen) => gen, + Err(msg) => return error_response(StatusCode::INTERNAL_SERVER_ERROR, &msg), + }; + + let (item_tx, item_rx) = mpsc::channel(32); + spawn_python_generator(py_gen, item_tx); + + if let Err(e) = sm.create_stream(stream_id.clone(), item_rx).await { + return error_response(StatusCode::INTERNAL_SERVER_ERROR, &e); + } + + match sm.subscribe(&stream_id, None).await { + Ok(rx) => sse_response(rx, &sm.config), + Err(e) => error_response(StatusCode::INTERNAL_SERVER_ERROR, &e), + } + } + }; + + match route.method { + HttpMethod::Get => router.route(&route.path, get(handler_fn)), + HttpMethod::Post => router.route(&route.path, post(handler_fn)), + HttpMethod::Put => router.route(&route.path, put(handler_fn)), + HttpMethod::Delete => router.route(&route.path, delete(handler_fn)), + HttpMethod::Patch => router.route(&route.path, patch(handler_fn)), + } +} + +/// Convert an SSE event receiver to an 
axum `Sse` response with keep-alive. +fn sse_response(rx: mpsc::Receiver, config: &StreamConfig) -> Response { + let stream = ReceiverStream::new(rx).map(|event| { + let mut e = AxumSseEvent::default().id(event.id).data(event.data); + if let Some(t) = event.event_type { + e = e.event(t); + } + Ok::<_, Infallible>(e) + }); + + Sse::new(stream) + .keep_alive( + KeepAlive::new() + .interval(config.heartbeat_interval) + .text("heartbeat"), + ) + .into_response() +} + +fn error_response(status: StatusCode, msg: &str) -> Response { + Response::builder() + .status(status) + .header(header::CONTENT_TYPE, "application/json") + .body(axum::body::Body::from( + serde_json::json!({"error": msg}).to_string(), + )) + .unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response()) +} + +fn collect_endpoint_info(routes: &[(String, Vec)]) -> Vec { + routes + .iter() + .flat_map(|(plugin_name, routes)| { + routes.iter().map(move |route| EndpointInfo { + plugin: plugin_name.clone(), + method: route.method.as_str().to_string(), + path: format!("/api/{}{}", plugin_name, route.path), + }) + }) + .collect() +} + +// --------------------------------------------------------------------------- +// Static path detection +// --------------------------------------------------------------------------- + +/// Detect the static file directory. If `explicit` is provided and exists, use +/// it. Otherwise search the standard paths (matches TS +/// `["dist", "client/dist", "build", "public", "out"]`). 
+pub fn detect_static_path(explicit: Option<&str>) -> Option { + if let Some(p) = explicit { + let path = PathBuf::from(p); + if path.exists() { + return Some(path); + } + return None; + } + + let candidates = ["dist", "client/dist", "build", "public", "out"]; + for dir in &candidates { + let path = PathBuf::from(dir); + if path.join("index.html").exists() { + return Some(path); + } + } + None +} + +// --------------------------------------------------------------------------- +// Server lifecycle +// --------------------------------------------------------------------------- + +/// Handle to a running server. Stores the shutdown sender so that +/// `PyAppKit.shutdown()` can trigger graceful shutdown. +pub struct ServerHandle { + pub shutdown_tx: tokio::sync::watch::Sender, + pub task: tokio::task::JoinHandle<()>, +} + +/// Graceful shutdown timeout (matches TS `15000ms`). +const SHUTDOWN_TIMEOUT_SECS: u64 = 15; + +/// Start the axum HTTP server. Returns a `ServerHandle` once the listener is +/// bound and the background server task is spawned. +pub async fn start_server( + router: Router, + host: &str, + port: u16, + stream_manager: Arc, +) -> Result { + let addr = format!("{host}:{port}"); + let listener = TcpListener::bind(&addr) + .await + .map_err(|e| format!("Failed to bind {addr}: {e}"))?; + + eprintln!("AppKit server listening on {addr}"); + + let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(false); + let sm = stream_manager; + + let task = tokio::spawn(async move { + let (shutdown_started_tx, shutdown_started_rx) = + tokio::sync::oneshot::channel::<()>(); + + let shutdown = async move { + // Wait for SIGTERM, SIGINT, or manual shutdown. + let ctrl_c = tokio::signal::ctrl_c(); + + #[cfg(unix)] + { + let mut sigterm = tokio::signal::unix::signal( + tokio::signal::unix::SignalKind::terminate(), + ) + .expect("Failed to install SIGTERM handler"); + + tokio::select! 
{ + _ = ctrl_c => {}, + _ = sigterm.recv() => {}, + _ = async { + loop { + if shutdown_rx.changed().await.is_err() { break; } + if *shutdown_rx.borrow() { break; } + } + } => {}, + } + } + + #[cfg(not(unix))] + { + tokio::select! { + _ = ctrl_c => {}, + _ = async { + loop { + if shutdown_rx.changed().await.is_err() { break; } + if *shutdown_rx.borrow() { break; } + } + } => {}, + } + } + + eprintln!("Shutdown signal received, shutting down gracefully..."); + + // Abort all active streams. + sm.abort_all().await; + + // Signal that shutdown has started so the timeout can begin. + let _ = shutdown_started_tx.send(()); + }; + + let server = axum::serve(listener, router).with_graceful_shutdown(shutdown); + + // Race: graceful server drain vs. forced shutdown timeout. + // When the timeout fires, dropping the server future aborts in-flight + // connections — no std::process::exit needed, preserving normal + // Rust cleanup/destructor semantics. + tokio::select! { + result = server => { result.ok(); }, + _ = async { + let _ = shutdown_started_rx.await; + tokio::time::sleep(Duration::from_secs(SHUTDOWN_TIMEOUT_SECS)).await; + eprintln!("Force shutdown after {SHUTDOWN_TIMEOUT_SECS}s timeout"); + } => {} + } + }); + + Ok(ServerHandle { + shutdown_tx, + task, + }) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_detect_static_path_explicit_missing() { + assert!(detect_static_path(Some("/nonexistent/path")).is_none()); + } + + #[test] + fn test_detect_static_path_none() { + // In the test environment there's unlikely to be a dist/ with index.html. + // Just verify it doesn't panic. 
+ let _ = detect_static_path(None); + } + + #[test] + fn test_detect_content_type() { + assert_eq!(detect_content_type("{\"ok\":true}"), "application/json"); + assert_eq!(detect_content_type("[]"), "application/json"); + assert_eq!(detect_content_type(""), "application/json"); + assert_eq!(detect_content_type("plain text"), "application/json"); + assert_eq!( + detect_content_type(""), + "text/html; charset=utf-8" + ); + assert_eq!( + detect_content_type(" \nhi"), + "text/html; charset=utf-8" + ); + } + + #[test] + fn test_parse_query() { + let q = parse_query("stream_id=abc&foo=bar"); + assert_eq!(q.get("stream_id").unwrap(), "abc"); + assert_eq!(q.get("foo").unwrap(), "bar"); + } + + #[test] + fn test_parse_query_empty() { + let q = parse_query(""); + assert!(q.is_empty()); + } + + #[test] + fn test_collect_endpoint_info() { + pyo3::prepare_freethreaded_python(); + let routes = vec![( + "my-plugin".to_string(), + vec![ + RouteDefinition { + method: HttpMethod::Get, + path: "/items".into(), + handler: Python::with_gil(|py| py.None().into()), + is_stream: false, + }, + RouteDefinition { + method: HttpMethod::Post, + path: "/items".into(), + handler: Python::with_gil(|py| py.None().into()), + is_stream: false, + }, + ], + )]; + let info = collect_endpoint_info(&routes); + assert_eq!(info.len(), 2); + assert_eq!(info[0].method, "GET"); + assert_eq!(info[0].path, "/api/my-plugin/items"); + assert_eq!(info[1].method, "POST"); + } + + #[test] + fn test_py_server_config_defaults() { + let cfg = PyServerConfig::new("0.0.0.0".into(), 8000, true, None); + assert_eq!(cfg.host, "0.0.0.0"); + assert_eq!(cfg.port, 8000); + assert!(cfg.auto_start); + assert!(cfg.static_path.is_none()); + } + + #[test] + fn test_extract_request_data() { + let method = Method::POST; + let uri: Uri = "/api/test?foo=bar".parse().unwrap(); + let mut headers = HeaderMap::new(); + headers.insert("content-type", "application/json".parse().unwrap()); + let body = b"{\"key\":\"val\"}"; + + let req = 
extract_request_data(&method, &uri, &headers, body); + assert_eq!(req.method, "POST"); + assert_eq!(req.path, "/api/test"); + assert_eq!(req.query, "foo=bar"); + assert_eq!(req.body, "{\"key\":\"val\"}"); + assert_eq!(req.headers.get("content-type").unwrap(), "application/json"); + } +} diff --git a/packages/appkit-rs/src/stream.rs b/packages/appkit-rs/src/stream.rs new file mode 100644 index 00000000..5a6546bc --- /dev/null +++ b/packages/appkit-rs/src/stream.rs @@ -0,0 +1,646 @@ +//! SSE stream manager — event ring buffer, stream lifecycle, reconnection, +//! heartbeat, and multi-client broadcasting. +//! +//! Ports `packages/appkit/src/stream/stream-manager.ts` and +//! `packages/appkit/src/stream/buffers.ts`. + +use std::collections::HashMap; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use tokio::sync::{broadcast, mpsc, Mutex, RwLock}; + +// --------------------------------------------------------------------------- +// Ring buffer +// --------------------------------------------------------------------------- + +/// Generic fixed-capacity ring buffer with key-based O(1) lookup. +/// Matches the TypeScript `RingBuffer` in `buffers.ts`. +pub struct RingBuffer { + buffer: Vec>, + capacity: usize, + write_index: usize, + size: usize, + key_index: HashMap, + key_fn: Box String + Send + Sync>, +} + +impl RingBuffer { + pub fn new(capacity: usize, key_fn: impl Fn(&T) -> String + Send + Sync + 'static) -> Self { + assert!(capacity > 0, "capacity must be > 0"); + Self { + buffer: (0..capacity).map(|_| None).collect(), + capacity, + write_index: 0, + size: 0, + key_index: HashMap::new(), + key_fn: Box::new(key_fn), + } + } + + /// Insert an item. If the key already exists, update in place. + /// Otherwise, write at the current position (evicting the oldest if full). + pub fn add(&mut self, item: T) { + let key = (self.key_fn)(&item); + + // Update in-place if key already exists. 
+ if let Some(&idx) = self.key_index.get(&key) { + self.buffer[idx] = Some(item); + return; + } + + // Evict old occupant at write position. + if let Some(ref old) = self.buffer[self.write_index] { + let old_key = (self.key_fn)(old); + self.key_index.remove(&old_key); + } + + self.key_index.insert(key, self.write_index); + self.buffer[self.write_index] = Some(item); + self.write_index = (self.write_index + 1) % self.capacity; + if self.size < self.capacity { + self.size += 1; + } + } + + pub fn get(&self, key: &str) -> Option<&T> { + self.key_index.get(key).and_then(|&i| self.buffer[i].as_ref()) + } + + pub fn has(&self, key: &str) -> bool { + self.key_index.contains_key(key) + } + + /// All items in insertion order (oldest first). + pub fn get_all(&self) -> Vec<&T> { + (0..self.size) + .filter_map(|i| { + let idx = (self.write_index + self.capacity - self.size + i) % self.capacity; + self.buffer[idx].as_ref() + }) + .collect() + } + + pub fn clear(&mut self) { + self.buffer.iter_mut().for_each(|s| *s = None); + self.key_index.clear(); + self.write_index = 0; + self.size = 0; + } + + pub fn len(&self) -> usize { + self.size + } + + pub fn is_empty(&self) -> bool { + self.size == 0 + } +} + +// --------------------------------------------------------------------------- +// Buffered event +// --------------------------------------------------------------------------- + +/// An event stored in the ring buffer for replay on reconnection. +#[derive(Clone, Debug)] +pub struct BufferedEvent { + pub id: String, + pub event_type: Option, + pub data: String, + pub timestamp: Instant, +} + +// --------------------------------------------------------------------------- +// Event ring buffer +// --------------------------------------------------------------------------- + +/// Event-specific ring buffer. Default capacity: 100 (matches TS). 
+pub struct EventRingBuffer {
+    inner: RingBuffer<BufferedEvent>,
+}
+
+impl EventRingBuffer {
+    pub fn new(capacity: usize) -> Self {
+        Self {
+            inner: RingBuffer::new(capacity, |e: &BufferedEvent| e.id.clone()),
+        }
+    }
+
+    pub fn add(&mut self, event: BufferedEvent) {
+        self.inner.add(event);
+    }
+
+    pub fn has(&self, event_id: &str) -> bool {
+        self.inner.has(event_id)
+    }
+
+    /// Events after `last_event_id`, oldest first.
+    /// Returns empty vec if `last_event_id` is not found in the buffer.
+    pub fn get_events_since(&self, last_event_id: &str) -> Vec<BufferedEvent> {
+        let all = self.inner.get_all();
+        let mut found = false;
+        all.into_iter()
+            .filter(|e| {
+                if found {
+                    return true;
+                }
+                // The matching event itself is excluded; only later ones pass.
+                if e.id == last_event_id {
+                    found = true;
+                }
+                false
+            })
+            .cloned()
+            .collect()
+    }
+
+    pub fn clear(&mut self) {
+        self.inner.clear();
+    }
+}
+
+// ---------------------------------------------------------------------------
+// SSE event (broadcast payload)
+// ---------------------------------------------------------------------------
+
+/// Event broadcast to subscribers. Converted to SSE wire format by the server.
+#[derive(Clone, Debug)]
+pub struct SseEvent {
+    pub id: String,
+    pub event_type: Option<String>,
+    pub data: String,
+}
+
+// ---------------------------------------------------------------------------
+// Stream config
+// ---------------------------------------------------------------------------
+
+/// Configuration for the stream manager. Defaults match the TS implementation.
+pub struct StreamConfig {
+    pub max_active_streams: usize,
+    pub max_event_size: usize,
+    /// How long to keep the buffer after a stream completes (for reconnection).
+    pub buffer_ttl: Duration,
+    /// Ring buffer capacity per stream.
+    pub buffer_size: usize,
+    /// Keep-alive comment interval (used by the server's SSE response).
+    pub heartbeat_interval: Duration,
+}
+
+impl Default for StreamConfig {
+    fn default() -> Self {
+        Self {
+            max_active_streams: 1000,
+            max_event_size: 1024 * 1024, // 1 MB
+            buffer_ttl: Duration::from_secs(600), // 10 minutes
+            buffer_size: 100,
+            heartbeat_interval: Duration::from_secs(10),
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Stream entry (internal state per stream)
+// ---------------------------------------------------------------------------
+
+struct StreamEntry {
+    buffer: Mutex<EventRingBuffer>,
+    tx: broadcast::Sender<SseEvent>,
+    completed: AtomicBool,
+    cancel_tx: tokio::sync::watch::Sender<bool>,
+}
+
+// ---------------------------------------------------------------------------
+// Stream manager
+// ---------------------------------------------------------------------------
+
+/// Manages SSE stream lifecycle: creation, event buffering, multi-client
+/// subscription, reconnection replay, and cancellation.
+///
+/// Mirrors the TypeScript `StreamManager` in `stream-manager.ts`.
+pub struct StreamManager {
+    pub config: StreamConfig,
+    streams: RwLock<HashMap<String, Arc<StreamEntry>>>,
+}
+
+impl StreamManager {
+    pub fn new(config: StreamConfig) -> Arc<Self> {
+        Arc::new(Self {
+            config,
+            streams: RwLock::new(HashMap::new()),
+        })
+    }
+
+    /// Create a new stream. Events arrive via `item_rx` as `(event_type, data)`
+    /// tuples. They are buffered for reconnection and broadcast to all
+    /// subscribers. The stream is automatically cleaned up after `buffer_ttl`
+    /// once the item source is exhausted or cancelled.
+    pub async fn create_stream(
+        self: &Arc<Self>,
+        stream_id: String,
+        mut item_rx: mpsc::Receiver<(Option<String>, String)>,
+    ) -> Result<(), String> {
+        {
+            let streams = self.streams.read().await;
+            if streams.len() >= self.config.max_active_streams {
+                return Err("Maximum active streams exceeded".into());
+            }
+            if streams.contains_key(&stream_id) {
+                return Err(format!("Stream {stream_id} already exists"));
+            }
+        }
+
+        let (tx, _) = broadcast::channel::<SseEvent>(256);
+        let (cancel_tx, mut cancel_rx) = tokio::sync::watch::channel(false);
+
+        let entry = Arc::new(StreamEntry {
+            buffer: Mutex::new(EventRingBuffer::new(self.config.buffer_size)),
+            tx: tx.clone(),
+            completed: AtomicBool::new(false),
+            cancel_tx,
+        });
+
+        self.streams
+            .write()
+            .await
+            .insert(stream_id.clone(), entry.clone());
+
+        let max_event_size = self.config.max_event_size;
+        let buffer_ttl = self.config.buffer_ttl;
+        let manager = Arc::clone(self);
+        let sid = stream_id;
+
+        tokio::spawn(async move {
+            loop {
+                tokio::select! {
+                    // Cancellation: wake on watch change, break once it reads true
+                    // (or the sender side is dropped).
+                    _ = async {
+                        loop {
+                            if cancel_rx.changed().await.is_err() { break; }
+                            if *cancel_rx.borrow() { break; }
+                        }
+                    } => break,
+                    item = item_rx.recv() => {
+                        match item {
+                            Some((event_type, data)) => {
+                                if data.len() > max_event_size {
+                                    let _ = tx.send(SseEvent {
+                                        id: uuid::Uuid::new_v4().to_string(),
+                                        event_type: Some("error".into()),
+                                        data: "Event exceeds maximum size".into(),
+                                    });
+                                    break;
+                                }
+
+                                let event = SseEvent {
+                                    id: uuid::Uuid::new_v4().to_string(),
+                                    event_type,
+                                    data,
+                                };
+
+                                // Buffer for replay, then broadcast.
+                                {
+                                    let mut buf = entry.buffer.lock().await;
+                                    buf.add(BufferedEvent {
+                                        id: event.id.clone(),
+                                        event_type: event.event_type.clone(),
+                                        data: event.data.clone(),
+                                        timestamp: Instant::now(),
+                                    });
+                                }
+
+                                let _ = tx.send(event);
+                            }
+                            None => break, // Source exhausted.
+                        }
+                    }
+                }
+            }
+
+            entry.completed.store(true, Ordering::SeqCst);
+
+            // Delayed cleanup — keep buffer available for reconnection.
+            tokio::spawn(async move {
+                tokio::time::sleep(buffer_ttl).await;
+                manager.streams.write().await.remove(&sid);
+            });
+        });
+
+        Ok(())
+    }
+
+    /// Subscribe to a stream's events. Returns an `mpsc::Receiver` that yields
+    /// replay events (if reconnecting) followed by live broadcast events.
+    pub async fn subscribe(
+        &self,
+        stream_id: &str,
+        last_event_id: Option<&str>,
+    ) -> Result<mpsc::Receiver<SseEvent>, String> {
+        let entry = {
+            let streams = self.streams.read().await;
+            streams
+                .get(stream_id)
+                .cloned()
+                .ok_or_else(|| format!("Stream {stream_id} not found"))?
+        };
+
+        let (tx, rx) = mpsc::channel::<SseEvent>(256);
+
+        // Subscribe to broadcast BEFORE reading buffer to avoid missing events.
+        let mut broadcast_rx = entry.tx.subscribe();
+
+        // Replay missed events if reconnecting. Record the replayed ids so that
+        // events broadcast between the subscribe() above and the buffer read
+        // are not delivered twice.
+        let mut replayed_ids: std::collections::HashSet<String> =
+            std::collections::HashSet::new();
+        if let Some(last_id) = last_event_id {
+            let buffer = entry.buffer.lock().await;
+            if !buffer.has(last_id) {
+                let _ = tx
+                    .send(SseEvent {
+                        id: uuid::Uuid::new_v4().to_string(),
+                        event_type: Some("warning".into()),
+                        data: "Buffer overflow: some events may have been missed".into(),
+                    })
+                    .await;
+            }
+            for event in buffer.get_events_since(last_id) {
+                let _ = tx
+                    .send(SseEvent {
+                        id: event.id.clone(),
+                        event_type: event.event_type.clone(),
+                        data: event.data.clone(),
+                    })
+                    .await;
+                replayed_ids.insert(event.id);
+            }
+        }
+
+        // Forward live broadcast events, skipping any that overlap with replay.
+        //
+        // BUGFIX: the previous implementation waited for the broadcast channel
+        // to deliver the last replayed event id before forwarding anything.
+        // Replayed events were usually broadcast *before* this subscriber
+        // attached, so that id never arrived on `broadcast_rx` and every live
+        // event after a reconnect was silently dropped. Instead, skip exactly
+        // the ids we replayed; because broadcast delivery is FIFO, the first
+        // event NOT in that set ends the overlap window and the set is cleared.
+        tokio::spawn(async move {
+            let mut dedup = replayed_ids;
+            loop {
+                match broadcast_rx.recv().await {
+                    Ok(event) => {
+                        if !dedup.is_empty() {
+                            if dedup.remove(&event.id) {
+                                continue; // Duplicate of a replayed event.
+                            }
+                            dedup.clear(); // Overlap window is over.
+                        }
+                        if tx.send(event).await.is_err() {
+                            break; // Subscriber disconnected.
+                        }
+                    }
+                    Err(broadcast::error::RecvError::Closed) => break,
+                    Err(broadcast::error::RecvError::Lagged(_)) => continue,
+                }
+            }
+        });
+
+        Ok(rx)
+    }
+
+    /// Cancel all active streams.
+    pub async fn abort_all(&self) {
+        let streams = self.streams.read().await;
+        for entry in streams.values() {
+            let _ = entry.cancel_tx.send(true);
+        }
+    }
+
+    /// Number of active (not-yet-cleaned-up) streams.
+    pub async fn active_count(&self) -> usize {
+        self.streams.read().await.len()
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Tests
+// ---------------------------------------------------------------------------
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // -- RingBuffer --
+
+    #[test]
+    fn test_ring_buffer_basic_ops() {
+        let mut rb = RingBuffer::new(3, |s: &String| s.clone());
+
+        rb.add("a".into());
+        rb.add("b".into());
+        assert_eq!(rb.len(), 2);
+        assert!(rb.has("a"));
+        assert!(rb.has("b"));
+        assert_eq!(rb.get("a"), Some(&"a".into()));
+    }
+
+    #[test]
+    fn test_ring_buffer_overflow_evicts_oldest() {
+        let mut rb = RingBuffer::new(2, |s: &String| s.clone());
+
+        rb.add("a".into());
+        rb.add("b".into());
+        rb.add("c".into()); // evicts "a"
+
+        assert!(!rb.has("a"));
+        assert!(rb.has("b"));
+        assert!(rb.has("c"));
+        assert_eq!(rb.len(), 2);
+    }
+
+    #[test]
+    fn test_ring_buffer_update_in_place() {
+        let mut rb = RingBuffer::new(3, |s: &(String, i32)| s.0.clone());
+
+        rb.add(("x".into(), 1));
+        rb.add(("y".into(), 2));
+        rb.add(("x".into(), 3)); // update, not new insert
+
+        assert_eq!(rb.len(), 2);
+        let val = rb.get("x").unwrap();
+        assert_eq!(val.1, 3);
+    }
+
+    #[test]
+    fn test_ring_buffer_get_all_order() {
+        let mut rb = RingBuffer::new(4, |s: &String| s.clone());
+
+        rb.add("a".into());
+        rb.add("b".into());
+        rb.add("c".into());
+
+        let all: Vec<&str> = rb.get_all().into_iter().map(|s| s.as_str()).collect();
+        assert_eq!(all, vec!["a", "b", "c"]);
+    }
+
+    #[test]
+    fn test_ring_buffer_get_all_after_wrap() {
+        let mut rb = RingBuffer::new(3, |s: &String| s.clone());
+
+        rb.add("a".into());
+        rb.add("b".into());
+        rb.add("c".into());
+        rb.add("d".into()); // evicts "a"
+
+        let all: Vec<&str> = rb.get_all().into_iter().map(|s| s.as_str()).collect();
+        assert_eq!(all, vec!["b", "c", "d"]);
+    }
+
+    #[test]
+    fn test_ring_buffer_clear() {
+        let mut rb = RingBuffer::new(3, |s: &String| s.clone());
+        rb.add("a".into());
+        rb.add("b".into());
+        rb.clear();
+        assert!(rb.is_empty());
+        assert!(!rb.has("a"));
+    }
+
+    // -- EventRingBuffer --
+
+    fn make_event(id: &str, data: &str) -> BufferedEvent {
+        BufferedEvent {
+            id: id.into(),
+            event_type: None,
+            data: data.into(),
+            timestamp: Instant::now(),
+        }
+    }
+
+    #[test]
+    fn test_event_ring_buffer_get_events_since() {
+        let mut erb = EventRingBuffer::new(10);
+        erb.add(make_event("1", "a"));
+        erb.add(make_event("2", "b"));
+        erb.add(make_event("3", "c"));
+        erb.add(make_event("4", "d"));
+
+        let since = erb.get_events_since("2");
+        let ids: Vec<&str> = since.iter().map(|e| e.id.as_str()).collect();
+        assert_eq!(ids, vec!["3", "4"]);
+    }
+
+    #[test]
+    fn test_event_ring_buffer_get_events_since_not_found() {
+        let mut erb = EventRingBuffer::new(10);
+        erb.add(make_event("1", "a"));
+        erb.add(make_event("2", "b"));
+
+        let since = erb.get_events_since("99");
+        assert!(since.is_empty());
+    }
+
+    #[test]
+    fn test_event_ring_buffer_has() {
+        let mut erb = EventRingBuffer::new(5);
+        erb.add(make_event("x", "data"));
+        assert!(erb.has("x"));
+        assert!(!erb.has("y"));
+    }
+
+    // -- StreamManager --
+
+    #[tokio::test]
+    async fn test_stream_manager_create_and_subscribe() {
+        let sm = StreamManager::new(StreamConfig::default());
+        let (item_tx, item_rx) = mpsc::channel(32);
+
+        sm.create_stream("s1".into(), item_rx).await.unwrap();
+        assert_eq!(sm.active_count().await, 1);
+
+        let mut rx = sm.subscribe("s1", None).await.unwrap();
+
+        item_tx.send((None, "hello".into())).await.unwrap();
+
+        let event = tokio::time::timeout(Duration::from_secs(1), rx.recv())
+            .await
+            .unwrap()
+            .unwrap();
+        assert_eq!(event.data, "hello");
+    }
+
+    #[tokio::test]
+    async fn test_stream_manager_reconnection_replay() {
+        let sm = StreamManager::new(StreamConfig::default());
+        let (item_tx, item_rx) = mpsc::channel(32);
+
+        sm.create_stream("s1".into(), item_rx).await.unwrap();
+
+        // First subscriber to capture event IDs.
+        let mut rx1 = sm.subscribe("s1", None).await.unwrap();
+
+        item_tx.send((None, "event-a".into())).await.unwrap();
+        item_tx.send((None, "event-b".into())).await.unwrap();
+
+        let ev1 = tokio::time::timeout(Duration::from_secs(1), rx1.recv())
+            .await
+            .unwrap()
+            .unwrap();
+        let ev2 = tokio::time::timeout(Duration::from_secs(1), rx1.recv())
+            .await
+            .unwrap()
+            .unwrap();
+
+        // Reconnect with last_event_id = ev1.id → should replay ev2.
+        let mut rx2 = sm.subscribe("s1", Some(&ev1.id)).await.unwrap();
+        let replayed = tokio::time::timeout(Duration::from_secs(1), rx2.recv())
+            .await
+            .unwrap()
+            .unwrap();
+        assert_eq!(replayed.data, "event-b");
+        assert_eq!(replayed.id, ev2.id);
+    }
+
+    #[tokio::test]
+    async fn test_stream_manager_abort_all() {
+        let sm = StreamManager::new(StreamConfig {
+            buffer_ttl: Duration::from_millis(50),
+            ..Default::default()
+        });
+        let (_item_tx, item_rx) = mpsc::channel::<(Option<String>, String)>(32);
+        sm.create_stream("s1".into(), item_rx).await.unwrap();
+
+        sm.abort_all().await;
+
+        // Wait for cleanup.
+        tokio::time::sleep(Duration::from_millis(100)).await;
+        assert_eq!(sm.active_count().await, 0);
+    }
+
+    #[tokio::test]
+    async fn test_stream_manager_duplicate_stream_id_rejected() {
+        let sm = StreamManager::new(StreamConfig::default());
+        let (_tx1, rx1) = mpsc::channel(1);
+        let (_tx2, rx2) = mpsc::channel(1);
+
+        sm.create_stream("dup".into(), rx1).await.unwrap();
+        let result = sm.create_stream("dup".into(), rx2).await;
+        assert!(result.is_err());
+    }
+
+    #[tokio::test]
+    async fn test_stream_manager_oversized_event_terminates_stream() {
+        let sm = StreamManager::new(StreamConfig {
+            max_event_size: 10,
+            buffer_ttl: Duration::from_millis(50),
+            ..Default::default()
+        });
+        let (item_tx, item_rx) = mpsc::channel(32);
+        sm.create_stream("s1".into(), item_rx).await.unwrap();
+
+        let mut rx = sm.subscribe("s1", None).await.unwrap();
+
+        // Send oversized event.
+        item_tx
+            .send((None, "x".repeat(100)))
+            .await
+            .unwrap();
+
+        let event = tokio::time::timeout(Duration::from_secs(1), rx.recv())
+            .await
+            .unwrap()
+            .unwrap();
+        assert_eq!(event.event_type.as_deref(), Some("error"));
+    }
+}
diff --git a/packages/appkit-rs/src/telemetry.rs b/packages/appkit-rs/src/telemetry.rs
new file mode 100644
index 00000000..327d2d6c
--- /dev/null
+++ b/packages/appkit-rs/src/telemetry.rs
@@ -0,0 +1,256 @@
+//! OpenTelemetry integration — TelemetryManager (singleton) and per-plugin
+//! TelemetryProvider.
+//!
+//! Mirrors the TypeScript `TelemetryManager` / `TelemetryProvider` pattern:
+//! - Global singleton initializes OTLP exporters when an endpoint is configured
+//! - Per-plugin providers scope tracers/meters by plugin name
+//! - When no endpoint is configured, the global API returns noop implementations
+
+use opentelemetry::KeyValue;
+use opentelemetry_otlp::WithExportConfig;
+use opentelemetry_sdk::trace::TracerProvider;
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex, OnceLock};
+
+use crate::config::AppConfig;
+
+// ---------------------------------------------------------------------------
+// Configuration types
+// ---------------------------------------------------------------------------
+
+/// Global telemetry configuration.
+#[derive(Clone, Debug, Default)]
+pub struct TelemetryConfig {
+    /// OTLP exporter endpoint. If `None`, telemetry is disabled (noop).
+    pub endpoint: Option<String>,
+    /// Service name for resource attributes.
+    pub service_name: Option<String>,
+}
+
+impl TelemetryConfig {
+    pub fn from_app_config(config: &AppConfig) -> Self {
+        Self {
+            endpoint: config.otel_endpoint.clone(),
+            service_name: None,
+        }
+    }
+}
+
+/// Per-plugin telemetry options — controls which signals are active.
+#[derive(Clone, Debug)]
+pub struct TelemetryOptions {
+    pub traces: bool,
+    pub metrics: bool,
+    pub logs: bool,
+}
+
+impl Default for TelemetryOptions {
+    fn default() -> Self {
+        Self {
+            traces: true,
+            metrics: true,
+            logs: true,
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// TelemetryManager — singleton
+// ---------------------------------------------------------------------------
+
+const DEFAULT_SERVICE_NAME: &str = "databricks-app";
+
+/// Global telemetry manager. Initializes the OpenTelemetry SDK when an OTLP
+/// endpoint is configured; otherwise all providers return noop implementations.
+pub struct TelemetryManager {
+    active: bool,
+}
+
+static INSTANCE: OnceLock<Arc<TelemetryManager>> = OnceLock::new();
+
+static INTERNED_NAMES: OnceLock<Mutex<HashMap<String, &'static str>>> = OnceLock::new();
+
+/// Intern a plugin name so it can be used as `&'static str` without leaking
+/// memory on every call. Each unique name is leaked exactly once.
+fn intern_name(name: &str) -> &'static str {
+    let map = INTERNED_NAMES.get_or_init(|| Mutex::new(HashMap::new()));
+    let mut guard = map.lock().unwrap();
+    if let Some(&existing) = guard.get(name) {
+        existing
+    } else {
+        let leaked: &'static str = Box::leak(name.to_string().into_boxed_str());
+        guard.insert(name.to_string(), leaked);
+        leaked
+    }
+}
+
+impl TelemetryManager {
+    /// Initialize the global singleton. Idempotent — subsequent calls return
+    /// the existing instance.
+    pub fn initialize(config: &TelemetryConfig) -> Arc<Self> {
+        INSTANCE
+            .get_or_init(|| {
+                if let Some(ref endpoint) = config.endpoint {
+                    // BUGFIX: `active` previously reported true whenever an
+                    // endpoint was configured, even when the exporter failed to
+                    // build — contradicting `is_active`'s documented meaning.
+                    // It now reflects actual exporter initialization.
+                    let active =
+                        Self::init_with_endpoint(endpoint, config.service_name.as_deref());
+                    Arc::new(TelemetryManager { active })
+                } else {
+                    Arc::new(TelemetryManager { active: false })
+                }
+            })
+            .clone()
+    }
+
+    /// Get the singleton instance. Returns `None` if not yet initialized.
+    pub fn get_instance() -> Option<Arc<TelemetryManager>> {
+        INSTANCE.get().cloned()
+    }
+
+    /// Create a per-plugin scoped `TelemetryProvider`.
+    pub fn get_provider(
+        plugin_name: &str,
+        options: Option<TelemetryOptions>,
+    ) -> TelemetryProvider {
+        TelemetryProvider {
+            plugin_name: plugin_name.to_string(),
+            plugin_name_static: intern_name(plugin_name),
+            options: options.unwrap_or_default(),
+        }
+    }
+
+    /// Whether the OTLP exporter was successfully initialized.
+    pub fn is_active(&self) -> bool {
+        self.active
+    }
+
+    /// Build and register the OTLP trace exporter. Returns `true` on success.
+    fn init_with_endpoint(endpoint: &str, service_name: Option<&str>) -> bool {
+        let service = service_name.unwrap_or(DEFAULT_SERVICE_NAME);
+        let resource = opentelemetry_sdk::Resource::new(vec![KeyValue::new(
+            "service.name",
+            service.to_string(),
+        )]);
+
+        // Trace exporter via OTLP/gRPC
+        if let Ok(exporter) = opentelemetry_otlp::SpanExporter::builder()
+            .with_tonic()
+            .with_endpoint(endpoint)
+            .build()
+        {
+            let tracer_provider = TracerProvider::builder()
+                .with_resource(resource)
+                .with_batch_exporter(exporter, opentelemetry_sdk::runtime::Tokio)
+                .build();
+
+            // Register as the global tracer provider so that
+            // `opentelemetry::global::tracer(name)` returns real tracers.
+            opentelemetry::global::set_tracer_provider(tracer_provider);
+            true
+        } else {
+            false
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// TelemetryProvider — per-plugin scoped
+// ---------------------------------------------------------------------------
+
+/// Per-plugin telemetry provider. When the global manager is not active or a
+/// particular signal is disabled, the OpenTelemetry global API transparently
+/// returns noop implementations at zero cost.
+pub struct TelemetryProvider {
+    plugin_name: String,
+    /// Leaked static str for APIs requiring `&'static str`.
+    plugin_name_static: &'static str,
+    options: TelemetryOptions,
+}
+
+impl TelemetryProvider {
+    /// Create a provider with all signals disabled (for testing).
+    pub fn new_disabled(plugin_name: &str) -> Self {
+        Self {
+            plugin_name: plugin_name.to_string(),
+            plugin_name_static: intern_name(plugin_name),
+            options: TelemetryOptions {
+                traces: false,
+                metrics: false,
+                logs: false,
+            },
+        }
+    }
+
+    pub fn plugin_name(&self) -> &str {
+        &self.plugin_name
+    }
+
+    pub fn traces_enabled(&self) -> bool {
+        self.options.traces
+    }
+
+    pub fn metrics_enabled(&self) -> bool {
+        self.options.metrics
+    }
+
+    /// Get a tracer scoped to this plugin.
+    /// Returns a noop tracer when the
+    /// global provider is not configured.
+    pub fn tracer(&self) -> opentelemetry::global::BoxedTracer {
+        opentelemetry::global::tracer(self.plugin_name_static)
+    }
+
+    /// Get a meter scoped to this plugin.
+    pub fn meter(&self) -> opentelemetry::metrics::Meter {
+        opentelemetry::global::meter(self.plugin_name_static)
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Tests
+// ---------------------------------------------------------------------------
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_telemetry_config_default() {
+        let config = TelemetryConfig::default();
+        assert!(config.endpoint.is_none());
+        assert!(config.service_name.is_none());
+    }
+
+    #[test]
+    fn test_telemetry_config_from_app_config() {
+        let app = AppConfig::new(
+            "https://host.databricks.com".into(),
+            None,
+            None,
+            None,
+            8000,
+            "0.0.0.0".into(),
+            Some("http://otel:4317".into()),
+        );
+        let tc = TelemetryConfig::from_app_config(&app);
+        assert_eq!(tc.endpoint.as_deref(), Some("http://otel:4317"));
+    }
+
+    #[test]
+    fn test_telemetry_options_default() {
+        let opts = TelemetryOptions::default();
+        assert!(opts.traces);
+        assert!(opts.metrics);
+        assert!(opts.logs);
+    }
+
+    #[test]
+    fn test_provider_traces_enabled() {
+        let provider = TelemetryManager::get_provider("test-plugin", None);
+        assert!(provider.traces_enabled());
+        assert_eq!(provider.plugin_name(), "test-plugin");
+
+        let disabled = TelemetryManager::get_provider(
+            "quiet",
+            Some(TelemetryOptions {
+                traces: false,
+                metrics: false,
+                logs: false,
+            }),
+        );
+        assert!(!disabled.traces_enabled());
+    }
+}
diff --git a/packages/appkit-rs/tests/conftest.py b/packages/appkit-rs/tests/conftest.py
new file mode 100644
index 00000000..83e89ff3
--- /dev/null
+++ b/packages/appkit-rs/tests/conftest.py
@@ -0,0 +1,41 @@
+"""Shared fixtures for appkit integration tests.
+
+Run with: maturin develop && pytest tests/
+"""
+
+import pytest
+
+
+@pytest.fixture
+def app_config():
+    """Minimal AppConfig for testing (no real Databricks connection)."""
+    import appkit
+
+    return appkit.AppConfig(
+        "https://test.databricks.com",
+        client_id="test-client-id",
+        client_secret="test-client-secret",
+        warehouse_id="test-warehouse-id",
+    )
+
+
+@pytest.fixture
+def user_context():
+    """A sample UserContext for testing."""
+    import appkit
+
+    return appkit.UserContext(
+        "test-token",
+        "user-42",
+        user_name="Alice",
+        workspace_id="ws-123",
+        warehouse_id="wh-456",
+    )
+
+
+@pytest.fixture
+def cache_config():
+    """CacheConfig with short TTL for testing."""
+    import appkit
+
+    return appkit.CacheConfig(ttl=10, max_size=50)
diff --git a/packages/appkit-rs/tests/test_auth.py b/packages/appkit-rs/tests/test_auth.py
new file mode 100644
index 00000000..bd185d53
--- /dev/null
+++ b/packages/appkit-rs/tests/test_auth.py
@@ -0,0 +1,72 @@
+"""Integration tests for ServiceContext and UserContext."""
+
+import pytest
+
+import appkit
+
+
+class TestUserContext:
+    def test_constructor(self, user_context):
+        assert user_context.token == "test-token"
+        assert user_context.user_id == "user-42"
+        assert user_context.user_name == "Alice"
+        assert user_context.workspace_id == "ws-123"
+        assert user_context.warehouse_id == "wh-456"
+
+    def test_is_user_context_property(self, user_context):
+        assert user_context.is_user_context is True
+
+    def test_keyword_only_workspace_id(self):
+        ctx = appkit.UserContext(
+            "tok",
+            "uid",
+            workspace_id="ws",
+        )
+        assert ctx.workspace_id == "ws"
+        assert ctx.warehouse_id is None
+        assert ctx.user_name is None
+
+    def test_repr(self, user_context):
+        r = repr(user_context)
+        assert "UserContext" in r
+        assert "user-42" in r
+
+    def test_equality(self):
+        a = appkit.UserContext("tok", "u1", workspace_id="ws")
+        b = appkit.UserContext("tok", "u1", workspace_id="ws")
+        c = appkit.UserContext("tok2", "u1", workspace_id="ws")
+        assert a == b
+        assert a != c
+
+    def test_hashable(self):
+        ctx = appkit.UserContext("tok", "u1", workspace_id="ws")
+        s = {ctx}
+        assert len(s) == 1
+
+    def test_frozen(self, user_context):
+        # Instances are immutable; attribute assignment must fail.
+        with pytest.raises(AttributeError):
+            user_context.token = "new"
+
+
+class TestServiceContext:
+    def test_constructor(self, app_config):
+        svc = appkit.ServiceContext(app_config)
+        assert svc.config == app_config
+
+    def test_missing_client_id(self):
+        cfg = appkit.AppConfig("https://host.databricks.com")
+        with pytest.raises(ValueError, match="CLIENT_ID"):
+            appkit.ServiceContext(cfg)
+
+    def test_missing_client_secret(self):
+        cfg = appkit.AppConfig(
+            "https://host.databricks.com", client_id="cid"
+        )
+        with pytest.raises(ValueError, match="CLIENT_SECRET"):
+            appkit.ServiceContext(cfg)
+
+    def test_repr(self, app_config):
+        svc = appkit.ServiceContext(app_config)
+        r = repr(svc)
+        assert "ServiceContext" in r
+        assert "test.databricks.com" in r
diff --git a/packages/appkit-rs/tests/test_cache.py b/packages/appkit-rs/tests/test_cache.py
new file mode 100644
index 00000000..d75cf762
--- /dev/null
+++ b/packages/appkit-rs/tests/test_cache.py
@@ -0,0 +1,129 @@
+"""Integration tests for CacheConfig and CacheManager."""
+
+import asyncio
+
+import pytest
+
+import appkit
+
+
+class TestCacheConfig:
+    def test_defaults(self):
+        cfg = appkit.CacheConfig()
+        assert cfg.enabled is True
+        assert cfg.ttl == 3600
+        assert cfg.max_size == 1000
+        assert cfg.cleanup_probability == pytest.approx(0.01)
+
+    def test_custom_values(self, cache_config):
+        assert cache_config.ttl == 10
+        assert cache_config.max_size == 50
+
+    def test_repr(self):
+        r = repr(appkit.CacheConfig())
+        assert "CacheConfig" in r
+        assert "3600" in r
+
+    def test_equality(self):
+        a = appkit.CacheConfig(ttl=60)
+        b = appkit.CacheConfig(ttl=60)
+        c = appkit.CacheConfig(ttl=120)
+        assert a == b
+        assert a != c
+
+    def test_hashable(self):
+        cfg = appkit.CacheConfig()
+        s = {cfg}
+        assert len(s) == 1
+
+    def test_frozen(self):
+        cfg = appkit.CacheConfig()
+        with pytest.raises(AttributeError):
+            cfg.ttl = 999
+
+
+class TestCacheManager:
+    def test_constructor(self):
+        cm = appkit.CacheManager()
+        assert repr(cm).startswith("CacheManager")
+
+    def test_with_config(self, cache_config):
+        cm = appkit.CacheManager(cache_config)
+        assert bool(cm) is True
+
+    def test_disabled(self):
+        cfg = appkit.CacheConfig(enabled=False)
+        cm = appkit.CacheManager(cfg)
+        assert bool(cm) is False
+
+    def test_generate_key_deterministic(self):
+        k1 = appkit.CacheManager.generate_key(["q", "p"], "user-1")
+        k2 = appkit.CacheManager.generate_key(["q", "p"], "user-1")
+        assert k1 == k2
+        assert len(k1) == 64  # SHA256 hex
+
+    def test_generate_key_varies_by_user(self):
+        k1 = appkit.CacheManager.generate_key(["q"], "alice")
+        k2 = appkit.CacheManager.generate_key(["q"], "bob")
+        assert k1 != k2
+
+    @pytest.mark.asyncio
+    async def test_set_get_delete(self):
+        cm = appkit.CacheManager()
+        await cm.set("key1", '{"value": 42}')
+        result = await cm.get("key1")
+        assert result is not None
+        assert "42" in result
+
+        await cm.delete("key1")
+        result = await cm.get("key1")
+        assert result is None
+
+    @pytest.mark.asyncio
+    async def test_has_and_size(self):
+        cm = appkit.CacheManager()
+        assert await cm.has("nonexistent") is False
+        assert await cm.size() == 0
+
+        await cm.set("k", '"v"')
+        assert await cm.has("k") is True
+        assert await cm.size() == 1
+
+    @pytest.mark.asyncio
+    async def test_clear(self):
+        cm = appkit.CacheManager()
+        await cm.set("a", '"1"')
+        await cm.set("b", '"2"')
+        assert await cm.size() == 2
+
+        await cm.clear()
+        assert await cm.size() == 0
+
+    @pytest.mark.asyncio
+    async def test_get_or_execute(self):
+        cm = appkit.CacheManager()
+        call_count = 0
+
+        async def compute():
+            nonlocal call_count
+            call_count += 1
+            return '{"computed": true}'
+
+        # First call executes
+        result = await cm.get_or_execute("k", compute)
+        assert "computed" in result
+        assert call_count == 1
+
+        # Second call hits cache
+        result = await cm.get_or_execute("k", compute)
+        assert "computed" in result
+        assert call_count == 1
+
+    @pytest.mark.asyncio
+    async def test_set_with_custom_ttl(self):
+        cm = appkit.CacheManager(appkit.CacheConfig(ttl=3600))
+        await cm.set("k", '"v"', ttl=0)
+        # With ttl=0, entry expires almost immediately
+        await asyncio.sleep(0.01)
+        result = await cm.get("k")
+        assert result is None
diff --git a/packages/appkit-rs/tests/test_config.py b/packages/appkit-rs/tests/test_config.py
new file mode 100644
index 00000000..663413fc
--- /dev/null
+++ b/packages/appkit-rs/tests/test_config.py
@@ -0,0 +1,78 @@
+"""Integration tests for AppConfig."""
+
+import os
+
+import pytest
+
+import appkit
+
+
+class TestAppConfig:
+    def test_constructor_defaults(self):
+        cfg = appkit.AppConfig("https://host.databricks.com")
+        assert cfg.databricks_host == "https://host.databricks.com"
+        assert cfg.client_id is None
+        assert cfg.client_secret is None
+        assert cfg.warehouse_id is None
+        assert cfg.app_port == 8000
+        assert cfg.host == "0.0.0.0"
+        assert cfg.otel_endpoint is None
+
+    def test_constructor_keyword_only(self):
+        cfg = appkit.AppConfig(
+            "https://host.databricks.com",
+            client_id="cid",
+            client_secret="secret",
+            warehouse_id="wh-1",
+            app_port=9090,
+            host="127.0.0.1",
+            otel_endpoint="http://otel:4317",
+        )
+        assert cfg.client_id == "cid"
+        assert cfg.client_secret == "secret"
+        assert cfg.warehouse_id == "wh-1"
+        assert cfg.app_port == 9090
+        assert cfg.host == "127.0.0.1"
+        assert cfg.otel_endpoint == "http://otel:4317"
+
+    def test_from_env(self):
+        # NOTE(review): mutates process env and pops without restoring any
+        # pre-existing values — consider pytest's monkeypatch fixture.
+        os.environ["DATABRICKS_HOST"] = "https://env.databricks.com"
+        os.environ["DATABRICKS_CLIENT_ID"] = "env-cid"
+        os.environ["DATABRICKS_APP_PORT"] = "7070"
+        try:
+            cfg = appkit.AppConfig.from_env()
+            assert cfg.databricks_host == "https://env.databricks.com"
+            assert cfg.client_id == "env-cid"
+            assert cfg.app_port == 7070
+        finally:
+            os.environ.pop("DATABRICKS_HOST", None)
+            os.environ.pop("DATABRICKS_CLIENT_ID", None)
+            os.environ.pop("DATABRICKS_APP_PORT", None)
+
+    def test_from_env_missing_host(self):
+        os.environ.pop("DATABRICKS_HOST", None)
+        with pytest.raises(ValueError, match="DATABRICKS_HOST"):
+            appkit.AppConfig.from_env()
+
+    def test_repr(self):
+        cfg = appkit.AppConfig("https://host.databricks.com", app_port=8080)
+        r = repr(cfg)
+        assert "AppConfig" in r
+        assert "host.databricks.com" in r
+
+    def test_equality(self):
+        a = appkit.AppConfig("https://host.databricks.com", client_id="cid")
+        b = appkit.AppConfig("https://host.databricks.com", client_id="cid")
+        c = appkit.AppConfig("https://other.databricks.com")
+        assert a == b
+        assert a != c
+
+    def test_hashable(self):
+        cfg = appkit.AppConfig("https://host.databricks.com")
+        s = {cfg}  # should be hashable
+        assert len(s) == 1
+
+    def test_frozen(self):
+        cfg = appkit.AppConfig("https://host.databricks.com")
+        with pytest.raises(AttributeError):
+            cfg.databricks_host = "changed"
diff --git a/packages/appkit-rs/tests/test_connectors.py b/packages/appkit-rs/tests/test_connectors.py
new file mode 100644
index 00000000..6ba2213c
--- /dev/null
+++ b/packages/appkit-rs/tests/test_connectors.py
@@ -0,0 +1,161 @@
+"""Integration tests for connector types and construction.
+
+These tests verify Python-side construction and type behavior.
+Actual HTTP calls are not made (no live Databricks workspace).
+"""
+
+import os
+
+import pytest
+
+import appkit
+
+
+class TestFilesConnector:
+    def test_constructor(self):
+        fc = appkit.FilesConnector("https://host.databricks.com")
+        assert "FilesConnector" in repr(fc)
+
+    def test_with_default_volume(self):
+        fc = appkit.FilesConnector(
+            "https://host.databricks.com",
+            default_volume="/Volumes/cat/sch/vol",
+        )
+        resolved = fc.resolve_path("file.txt")
+        assert resolved == "/Volumes/cat/sch/vol/file.txt"
+
+    def test_resolve_absolute_path(self):
+        fc = appkit.FilesConnector("https://host.databricks.com")
+        resolved = fc.resolve_path("/Volumes/cat/sch/vol/file.txt")
+        assert resolved == "/Volumes/cat/sch/vol/file.txt"
+
+    def test_path_traversal_rejected(self):
+        fc = appkit.FilesConnector("https://host.databricks.com")
+        with pytest.raises(ValueError, match="traversal"):
+            fc.resolve_path("/Volumes/cat/sch/vol/../../../etc/passwd")
+
+    def test_non_volumes_absolute_rejected(self):
+        fc = appkit.FilesConnector("https://host.databricks.com")
+        with pytest.raises(ValueError, match="/Volumes/"):
+            fc.resolve_path("/etc/passwd")
+
+    def test_relative_without_default_volume(self):
+        fc = appkit.FilesConnector("https://host.databricks.com")
+        with pytest.raises(ValueError, match="default volume"):
+            fc.resolve_path("file.txt")
+
+
+class TestFileDirectoryEntry:
+    def test_repr(self):
+        # FileDirectoryEntry is created by the connector, not directly.
+        # Verify type exists and is importable.
+        assert hasattr(appkit, "FileDirectoryEntry")
+
+
+class TestSqlWarehouseConnector:
+    def test_constructor(self):
+        sw = appkit.SqlWarehouseConnector("https://host.databricks.com")
+        assert "SqlWarehouseConnector" in repr(sw)
+
+    def test_with_timeout(self):
+        sw = appkit.SqlWarehouseConnector(
+            "https://host.databricks.com", timeout_ms=30000
+        )
+        assert "30000" in repr(sw)
+
+
+class TestSqlColumn:
+    def test_type_exists(self):
+        assert hasattr(appkit, "SqlColumn")
+
+
+class TestSqlStatementResult:
+    def test_type_exists(self):
+        assert hasattr(appkit, "SqlStatementResult")
+
+
+class TestGenieConnector:
+    def test_constructor(self):
+        gc = appkit.GenieConnector("https://host.databricks.com")
+        assert "GenieConnector" in repr(gc)
+
+    def test_with_options(self):
+        gc = appkit.GenieConnector(
+            "https://host.databricks.com",
+            timeout_ms=60000,
+            max_messages=100,
+        )
+        r = repr(gc)
+        assert "60000" in r
+        assert "100" in r
+
+
+class TestServingConnector:
+    def test_constructor(self):
+        sc = appkit.ServingConnector("https://host.databricks.com")
+        assert "ServingConnector" in repr(sc)
+
+
+class TestServingResponse:
+    def test_type_exists(self):
+        assert hasattr(appkit, "ServingResponse")
+
+
+class TestLakebaseConnector:
+    def test_constructor(self):
+        lc = appkit.LakebaseConnector("https://host.databricks.com")
+        assert "LakebaseConnector" in repr(lc)
+
+
+class TestLakebasePgConfig:
+    def test_explicit_values(self):
+        cfg = appkit.LakebasePgConfig(
+            host="db.example.com",
+            database="mydb",
+            port=5433,
+            ssl_mode="prefer",
+            app_name="myapp",
+        )
+        assert cfg.host == "db.example.com"
+        assert cfg.database == "mydb"
+        assert cfg.port == 5433
+        assert cfg.ssl_mode == "prefer"
+        assert cfg.app_name == "myapp"
+
+    def test_from_env(self):
+        os.environ["PGHOST"] = "env-host.example.com"
+        os.environ["PGDATABASE"] = "envdb"
+        try:
+            cfg = appkit.LakebasePgConfig.from_env()
+            assert cfg.host == "env-host.example.com"
+            assert cfg.database == "envdb"
+            assert cfg.port == 5432  # default
+            assert cfg.ssl_mode == "require"  # default
+        finally:
+            os.environ.pop("PGHOST", None)
+            os.environ.pop("PGDATABASE", None)
+
+    def test_missing_host_raises(self):
+        os.environ.pop("PGHOST", None)
+        os.environ.pop("LAKEBASE_ENDPOINT", None)
+        os.environ["PGDATABASE"] = "db"
+        try:
+            with pytest.raises(ValueError, match="host"):
+                appkit.LakebasePgConfig.from_env()
+        finally:
+            os.environ.pop("PGDATABASE", None)
+
+    def test_equality(self):
+        a = appkit.LakebasePgConfig(host="h", database="d")
+        b = appkit.LakebasePgConfig(host="h", database="d")
+        assert a == b
+
+    def test_hashable(self):
+        cfg = appkit.LakebasePgConfig(host="h", database="d")
+        s = {cfg}
+        assert len(s) == 1
+
+
+class TestDatabaseCredential:
+    def test_type_exists(self):
+        assert hasattr(appkit, "DatabaseCredential")
diff --git a/packages/appkit-rs/tests/test_context.py b/packages/appkit-rs/tests/test_context.py
new file mode 100644
index 00000000..241aba85
--- /dev/null
+++ b/packages/appkit-rs/tests/test_context.py
@@ -0,0 +1,72 @@
+"""Integration tests for contextvars-based execution context."""
+
+import pytest
+
+import appkit
+
+
+class TestContextVars:
+    def test_get_current_user_none_by_default(self):
+        user = appkit.get_current_user()
+        assert user is None
+
+    def test_is_in_user_context_false_by_default(self):
+        assert appkit.is_in_user_context() is False
+
+    def test_run_in_user_context(self, user_context):
+        def check():
+            user = appkit.get_current_user()
+            assert user is not None
+            assert user.user_id == "user-42"
+            assert user.workspace_id == "ws-123"
+            assert appkit.is_in_user_context() is True
+            return user.user_id
+
+        result = appkit.run_in_user_context(user_context, check)
+        assert result == "user-42"
+
+        # After the call, context should be reset
+        assert appkit.get_current_user() is None
+        assert appkit.is_in_user_context() is False
+
+    def test_run_in_user_context_exception_resets(self, user_context):
+        def raise_error():
+            assert
appkit.is_in_user_context() is True + raise ValueError("test error") + + with pytest.raises(ValueError, match="test error"): + appkit.run_in_user_context(user_context, raise_error) + + # Context should be reset even after error + assert appkit.get_current_user() is None + + @pytest.mark.asyncio + async def test_as_user(self, user_context): + async def async_check(): + user = appkit.get_current_user() + assert user is not None + assert user.user_id == "user-42" + return user.user_id + + result = await appkit.as_user(user_context, async_check) + assert result == "user-42" + + def test_nested_contexts(self): + outer = appkit.UserContext("tok-outer", "outer-user", workspace_id="ws-1") + inner = appkit.UserContext("tok-inner", "inner-user", workspace_id="ws-2") + + def outer_fn(): + assert appkit.get_current_user().user_id == "outer-user" + + def inner_fn(): + assert appkit.get_current_user().user_id == "inner-user" + return "inner-done" + + result = appkit.run_in_user_context(inner, inner_fn) + # After inner context completes, we should be back to outer + # Note: contextvars.ContextVar.reset() restores the previous token, + # so outer context is restored. 
+ return result + + result = appkit.run_in_user_context(outer, outer_fn) + assert result == "inner-done" diff --git a/packages/appkit-rs/tests/test_create_app.py b/packages/appkit-rs/tests/test_create_app.py new file mode 100644 index 00000000..d8d2c697 --- /dev/null +++ b/packages/appkit-rs/tests/test_create_app.py @@ -0,0 +1,131 @@ +"""Integration tests for the top-level create_app() orchestrator.""" + +import pytest + +import appkit + + +class TestCreateApp: + @pytest.mark.asyncio + async def test_create_app_with_plugin(self, app_config): + """Full lifecycle through create_app: register, init, plugin ready.""" + + class TestPlugin(appkit.Plugin): + def __init__(self): + super().__init__( + "test-plugin", + manifest=appkit.PluginManifest("test-plugin"), + ) + self.setup_called = False + + async def setup(self): + self.setup_called = True + + plugin = TestPlugin() + app = await appkit.create_app( + config=app_config, + plugins=[plugin], + auto_start=False, + ) + assert isinstance(app, appkit.AppKit) + assert bool(app) is True + assert plugin.is_ready is True + assert plugin.setup_called is True + + @pytest.mark.asyncio + async def test_create_app_no_plugins(self, app_config): + app = await appkit.create_app( + config=app_config, + auto_start=False, + ) + assert isinstance(app, appkit.AppKit) + assert len(app) == 0 + + @pytest.mark.asyncio + async def test_create_app_with_cache_config(self, app_config, cache_config): + app = await appkit.create_app( + config=app_config, + cache_config=cache_config, + auto_start=False, + ) + assert isinstance(app, appkit.AppKit) + + @pytest.mark.asyncio + async def test_create_app_with_server_config(self, app_config): + server_cfg = appkit.ServerConfig( + port=9999, + auto_start=False, + ) + app = await appkit.create_app( + config=app_config, + server_config=server_cfg, + auto_start=False, + ) + assert isinstance(app, appkit.AppKit) + + @pytest.mark.asyncio + async def test_create_app_multiple_plugins_phase_order(self, app_config): 
+ """Verify create_app respects phase ordering for multiple plugins.""" + order = [] + + class OrderPlugin(appkit.Plugin): + def __init__(self, name, phase): + super().__init__( + name, + phase=phase, + manifest=appkit.PluginManifest(name), + ) + + async def setup(self): + order.append(self.name) + + plugins = [ + OrderPlugin("deferred-p", "deferred"), + OrderPlugin("core-p", "core"), + OrderPlugin("normal-p", "normal"), + ] + + await appkit.create_app( + config=app_config, + plugins=plugins, + auto_start=False, + ) + + assert order.index("core-p") < order.index("normal-p") + assert order.index("normal-p") < order.index("deferred-p") + + @pytest.mark.asyncio + async def test_create_app_execute_through_plugin(self, app_config): + """End-to-end: create_app → execute function through interceptor chain.""" + + class ApiPlugin(appkit.Plugin): + def __init__(self): + super().__init__( + "api", + manifest=appkit.PluginManifest("api"), + ) + + plugin = ApiPlugin() + app = await appkit.create_app( + config=app_config, + plugins=[plugin], + auto_start=False, + ) + + async def compute(): + return '{"status": "ok", "count": 7}' + + result = await plugin.execute(compute, user_key="test-user") + assert result.ok is True + assert "count" in result.data + assert "7" in result.data + + @pytest.mark.asyncio + async def test_keyword_only_signature(self, app_config): + """Verify create_app requires keyword arguments.""" + # This should work + await appkit.create_app(config=app_config, auto_start=False) + + # Positional args should fail + with pytest.raises(TypeError): + await appkit.create_app(app_config, [], None, None, False) diff --git a/packages/appkit-rs/tests/test_plugin.py b/packages/appkit-rs/tests/test_plugin.py new file mode 100644 index 00000000..71f2627e --- /dev/null +++ b/packages/appkit-rs/tests/test_plugin.py @@ -0,0 +1,321 @@ +"""Integration tests for Plugin, PluginManifest, ExecutionResult, and AppKit.""" + +import pytest + +import appkit + + +class TestPluginPhase: 
+ def test_constants(self): + assert appkit.PluginPhase.CORE == "core" + assert appkit.PluginPhase.NORMAL == "normal" + assert appkit.PluginPhase.DEFERRED == "deferred" + + +class TestPluginManifest: + def test_constructor(self): + m = appkit.PluginManifest("my-plugin") + assert m.name == "my-plugin" + assert m.display_name is None + assert m.description is None + + def test_with_all_fields(self): + m = appkit.PluginManifest( + "my-plugin", + display_name="My Plugin", + description="A test plugin", + ) + assert m.display_name == "My Plugin" + assert m.description == "A test plugin" + + def test_repr(self): + m = appkit.PluginManifest("my-plugin") + assert "my-plugin" in repr(m) + + def test_equality(self): + a = appkit.PluginManifest("p1") + b = appkit.PluginManifest("p1") + c = appkit.PluginManifest("p2") + assert a == b + assert a != c + + def test_hashable(self): + m = appkit.PluginManifest("p1") + s = {m} + assert len(s) == 1 + + +class TestExecutionResult: + def test_ok_result(self): + # ExecutionResult is only created by the framework, not directly. + # Test via Plugin.execute() in the full lifecycle tests. + pass + + def test_repr_format(self): + # ExecutionResult instances come from execute(), tested in lifecycle. 
+ pass + + +class TestPlugin: + def test_constructor(self): + m = appkit.PluginManifest("test") + p = appkit.Plugin("test", manifest=m) + assert p.name == "test" + assert p.phase == "normal" + assert p.is_ready is False + + def test_phase_validation(self): + m = appkit.PluginManifest("test") + with pytest.raises(ValueError, match="Invalid phase"): + appkit.Plugin("test", phase="invalid", manifest=m) + + def test_core_phase(self): + m = appkit.PluginManifest("test") + p = appkit.Plugin("test", phase="core", manifest=m) + assert p.phase == "core" + + def test_repr(self): + m = appkit.PluginManifest("test") + p = appkit.Plugin("test", manifest=m) + r = repr(p) + assert "Plugin" in r + assert "test" in r + + def test_subclassing(self): + class MyPlugin(appkit.Plugin): + def __init__(self): + super().__init__( + "my-custom", + manifest=appkit.PluginManifest("my-custom"), + ) + + p = MyPlugin() + assert p.name == "my-custom" + assert isinstance(p, appkit.Plugin) + + def test_exports_default_empty(self): + m = appkit.PluginManifest("test") + p = appkit.Plugin("test", manifest=m) + assert p.exports() == {} + + def test_client_config_default_empty(self): + m = appkit.PluginManifest("test") + p = appkit.Plugin("test", manifest=m) + assert p.client_config() == {} + + def test_execute_before_init_raises(self): + m = appkit.PluginManifest("test") + p = appkit.Plugin("test", manifest=m) + with pytest.raises(RuntimeError, match="not initialized"): + import asyncio + + asyncio.run( # get_event_loop() is deprecated and flaky under pytest-asyncio + p.execute(lambda: None) + ) + + +class TestAppKit: + def test_constructor(self): + app = appkit.AppKit() + assert repr(app).startswith("AppKit") + assert len(app) == 0 + assert bool(app) is False # not initialized + + def test_register(self): + app = appkit.AppKit() + m = appkit.PluginManifest("p1") + p = appkit.Plugin("p1", manifest=m) + app.register(p) + assert len(app) == 1 + assert "p1" in app + + def test_contains(self): + app = appkit.AppKit() + m = 
appkit.PluginManifest("p1") + p = appkit.Plugin("p1", manifest=m) + app.register(p) + assert "p1" in app + assert "nonexistent" not in app + + @pytest.mark.asyncio + async def test_initialize(self, app_config): + app = appkit.AppKit() + m = appkit.PluginManifest("p1") + p = appkit.Plugin("p1", manifest=m) + app.register(p) + + await app.initialize(app_config) + assert bool(app) is True # initialized + assert p.is_ready is True + + @pytest.mark.asyncio + async def test_double_initialize_raises(self, app_config): + app = appkit.AppKit() + await app.initialize(app_config) + with pytest.raises(RuntimeError, match="already initialized"): + await app.initialize(app_config) + + @pytest.mark.asyncio + async def test_register_after_init_raises(self, app_config): + app = appkit.AppKit() + await app.initialize(app_config) + m = appkit.PluginManifest("late") + p = appkit.Plugin("late", manifest=m) + with pytest.raises(RuntimeError, match="Cannot register"): + app.register(p) + + @pytest.mark.asyncio + async def test_get_plugin(self, app_config): + app = appkit.AppKit() + m = appkit.PluginManifest("p1") + p = appkit.Plugin("p1", manifest=m) + app.register(p) + await app.initialize(app_config) + + found = app.get_plugin("p1") + assert found is not None + assert found.name == "p1" + + assert app.get_plugin("nonexistent") is None + + @pytest.mark.asyncio + async def test_plugin_names(self, app_config): + app = appkit.AppKit() + for name in ["alpha", "beta", "gamma"]: + m = appkit.PluginManifest(name) + p = appkit.Plugin(name, manifest=m) + app.register(p) + await app.initialize(app_config) + + names = app.plugin_names() + assert set(names) == {"alpha", "beta", "gamma"} + + @pytest.mark.asyncio + async def test_phase_ordering(self, app_config): + """Verify plugins are initialized in phase order.""" + order = [] + + class TrackingPlugin(appkit.Plugin): + def __init__(self, name, phase): + super().__init__( + name, + phase=phase, + manifest=appkit.PluginManifest(name), + ) + + async 
def setup(self): + order.append(self.name) + + app = appkit.AppKit() + app.register(TrackingPlugin("deferred-1", "deferred")) + app.register(TrackingPlugin("core-1", "core")) + app.register(TrackingPlugin("normal-1", "normal")) + app.register(TrackingPlugin("core-2", "core")) + await app.initialize(app_config) + + # Core plugins first, then normal, then deferred + core_indices = [order.index(n) for n in order if n.startswith("core")] + normal_indices = [order.index(n) for n in order if n.startswith("normal")] + deferred_indices = [order.index(n) for n in order if n.startswith("deferred")] + assert all(c < n for c in core_indices for n in normal_indices) + assert all(n < d for n in normal_indices for d in deferred_indices) + + @pytest.mark.asyncio + async def test_execute_through_plugin(self, app_config): + """Full lifecycle: register, init, execute through interceptor chain.""" + + class ComputePlugin(appkit.Plugin): + def __init__(self): + super().__init__( + "compute", + manifest=appkit.PluginManifest("compute"), + ) + + app = appkit.AppKit() + plugin = ComputePlugin() + app.register(plugin) + await app.initialize(app_config) + + async def my_func(): + return '{"result": 42}' + + result = await plugin.execute(my_func, user_key="user-1") + assert result.ok is True + assert result.data is not None + assert "42" in result.data + assert bool(result) is True + + @pytest.mark.asyncio + async def test_execute_error(self, app_config): + app = appkit.AppKit() + m = appkit.PluginManifest("err-plugin") + p = appkit.Plugin("err-plugin", manifest=m) + app.register(p) + await app.initialize(app_config) + + async def failing_func(): + raise RuntimeError("something broke") + + result = await p.execute(failing_func) + assert result.ok is False + assert result.status == 500 + assert "something broke" in result.message + assert bool(result) is False + + @pytest.mark.asyncio + async def test_execute_with_cache(self, app_config): + """Verify cache interceptor deduplicates calls.""" + 
call_count = 0 + + app = appkit.AppKit() + m = appkit.PluginManifest("cached") + p = appkit.Plugin("cached", manifest=m) + app.register(p) + await app.initialize(app_config) + + async def compute(): + nonlocal call_count + call_count += 1 + return '{"val": "computed"}' + + r1 = await p.execute( + compute, + user_key="u1", + cache_key=["test-key"], + cache_ttl=60, + ) + assert r1.ok + assert call_count == 1 + + r2 = await p.execute( + compute, + user_key="u1", + cache_key=["test-key"], + cache_ttl=60, + ) + assert r2.ok + assert call_count == 1 # cache hit + + @pytest.mark.asyncio + async def test_execute_with_timeout(self, app_config): + import asyncio + + app = appkit.AppKit() + m = appkit.PluginManifest("timeout-test") + p = appkit.Plugin("timeout-test", manifest=m) + app.register(p) + await app.initialize(app_config) + + async def slow_func(): + await asyncio.sleep(10) + return '"never"' + + result = await p.execute(slow_func, timeout_ms=50) + assert result.ok is False + assert result.status == 408 + assert "timed out" in result.message.lower() + + def test_shutdown_before_start_raises(self): + app = appkit.AppKit() + with pytest.raises(RuntimeError, match="not running"): + app.shutdown() diff --git a/packages/appkit-rs/tests/test_plugins.py b/packages/appkit-rs/tests/test_plugins.py new file mode 100644 index 00000000..b7582f0d --- /dev/null +++ b/packages/appkit-rs/tests/test_plugins.py @@ -0,0 +1,406 @@ +"""Integration tests for shipped Python plugin wrappers. + +Covers construction, client_config, route registration, and OBO auth +extraction. Tests use a lightweight ``_MockRouter`` that records +``inject_routes`` calls so we can assert on method/path/stream flags +without booting the Rust HTTP server. 
+""" + +from __future__ import annotations + +import os +from typing import Any + +import pytest + +import appkit +from appkit.plugins import ( + AnalyticsPlugin, + AnalyticsPluginConfig, + FilesPlugin, + FilesPluginConfig, + GeniePlugin, + GeniePluginConfig, + LakebasePlugin, + LakebasePluginConfig, + ServerPlugin, + ServerPluginConfig, + ServingEndpointConfig, + ServingPlugin, + ServingPluginConfig, + VectorSearchIndexConfig, + VectorSearchPlugin, + VectorSearchPluginConfig, + VolumeConfig, + analytics, + files, + genie, + lakebase, + server, + serving, + vector_search, +) +from appkit.plugins.analytics import _extract_param_names + + +HOST = "https://test.databricks.com" + + +class _MockRouter: + """Minimal router replacement that records route registrations.""" + + def __init__(self) -> None: + self.routes: list[tuple[str, str, bool]] = [] + + def _add(self, method: str, path: str, _handler: Any, stream: bool) -> None: + self.routes.append((method, path, stream)) + + def get(self, path, handler, *, stream=False): + self._add("GET", path, handler, stream) + + def post(self, path, handler, *, stream=False): + self._add("POST", path, handler, stream) + + def put(self, path, handler, *, stream=False): + self._add("PUT", path, handler, stream) + + def delete(self, path, handler, *, stream=False): + self._add("DELETE", path, handler, stream) + + def patch(self, path, handler, *, stream=False): + self._add("PATCH", path, handler, stream) + + +@pytest.fixture(autouse=True) +def _clean_env(monkeypatch): + """Ensure tests do not leak DATABRICKS_HOST across cases.""" + monkeypatch.setenv("DATABRICKS_HOST", HOST) + yield + + +# --------------------------------------------------------------------------- +# Analytics +# --------------------------------------------------------------------------- + + +class TestAnalyticsPlugin: + def test_requires_warehouse(self, monkeypatch): + monkeypatch.delenv("DATABRICKS_WAREHOUSE_ID", raising=False) + with pytest.raises(ValueError, 
match="warehouse_id"): + AnalyticsPlugin(AnalyticsPluginConfig()) + + def test_construction(self): + plugin = AnalyticsPlugin( + AnalyticsPluginConfig(warehouse_id="wh-1", queries_dir="/tmp/q") + ) + assert plugin.name == "analytics" + assert plugin.warehouse_id == "wh-1" + assert plugin.client_config() == {"warehouse_id": "wh-1"} + + def test_routes_registered(self): + plugin = AnalyticsPlugin(AnalyticsPluginConfig(warehouse_id="wh-1")) + router = _MockRouter() + plugin.inject_routes(router) + assert ("POST", "/query/:query_key", False) in router.routes + assert ("GET", "/queries", False) in router.routes + + def test_entry_point(self): + assert isinstance( + analytics(AnalyticsPluginConfig(warehouse_id="wh-1")), + AnalyticsPlugin, + ) + + +class TestParamExtraction: + def test_basic(self): + assert _extract_param_names("SELECT :foo, :bar") == ["foo", "bar"] + + def test_skips_single_quoted(self): + assert _extract_param_names("SELECT ':not_a_param'") == [] + + def test_skips_line_comment(self): + assert _extract_param_names("-- :debug\nSELECT :real") == ["real"] + + def test_skips_block_comment(self): + assert _extract_param_names("/* :debug */ SELECT :real") == ["real"] + + def test_skips_nested_block_comment(self): + q = "/* outer /* :inner */ still_out */ SELECT :real" + assert _extract_param_names(q) == ["real"] + + def test_cast_ignored(self): + assert _extract_param_names("SELECT 1::INT") == [] + + def test_double_quoted_identifier(self): + assert _extract_param_names('SELECT ":not_a_param" FROM t') == [] + + def test_dedup(self): + assert _extract_param_names("SELECT :x, :x, :y") == ["x", "y"] + + +# --------------------------------------------------------------------------- +# Vector Search +# --------------------------------------------------------------------------- + + +class TestVectorSearchPlugin: + def _cfg(self): + return VectorSearchPluginConfig( + indexes={ + "docs": VectorSearchIndexConfig( + index_name="catalog.schema.docs_idx", + 
columns=["id", "content"], + ), + } + ) + + def test_construction(self): + plugin = VectorSearchPlugin(self._cfg()) + assert plugin.name == "vector-search" + assert plugin.client_config() == {"indexes": "docs"} + + def test_routes(self): + plugin = VectorSearchPlugin(self._cfg()) + router = _MockRouter() + plugin.inject_routes(router) + assert ("POST", "/query", False) in router.routes + assert ("POST", "/query-next-page", False) in router.routes + + def test_invalid_query_type_rejected(self): + with pytest.raises(ValueError, match="query_type"): + VectorSearchIndexConfig(index_name="a.b.c", query_type="bogus") + + def test_entry_point(self): + assert isinstance(vector_search(self._cfg()), VectorSearchPlugin) + + +# --------------------------------------------------------------------------- +# Server +# --------------------------------------------------------------------------- + + +class TestServerPlugin: + def test_defaults(self): + plugin = ServerPlugin() + assert plugin.name == "server" + assert plugin.phase == appkit.PluginPhase.CORE + cfg = plugin.to_server_config() + assert cfg.host == "0.0.0.0" + assert cfg.port == 8000 + assert cfg.auto_start is True + + def test_custom_config(self): + plugin = ServerPlugin( + ServerPluginConfig(host="127.0.0.1", port=9090, auto_start=False) + ) + cfg = plugin.to_server_config() + assert cfg.host == "127.0.0.1" + assert cfg.port == 9090 + assert cfg.auto_start is False + + def test_inject_routes_noop(self): + plugin = ServerPlugin() + router = _MockRouter() + plugin.inject_routes(router) + assert router.routes == [] + + def test_entry_point(self): + assert isinstance(server(), ServerPlugin) + + +# --------------------------------------------------------------------------- +# Files +# --------------------------------------------------------------------------- + + +class TestFilesPlugin: + def _cfg(self): + return FilesPluginConfig( + volumes={"uploads": VolumeConfig(path="/Volumes/c/s/uploads")} + ) + + def 
test_construction(self): + plugin = FilesPlugin(self._cfg()) + assert plugin.name == "files" + assert plugin.client_config() == {"volumes": "uploads"} + + def test_routes(self): + plugin = FilesPlugin(self._cfg()) + router = _MockRouter() + plugin.inject_routes(router) + methods = {(m, p) for m, p, _ in router.routes} + assert ("GET", "/list") in methods + assert ("POST", "/mkdir") in methods + assert ("DELETE", "/delete") in methods + + def test_entry_point(self): + assert isinstance(files(self._cfg()), FilesPlugin) + + +# --------------------------------------------------------------------------- +# Genie +# --------------------------------------------------------------------------- + + +class TestGeniePlugin: + def _cfg(self): + return GeniePluginConfig(spaces={"sales": "space-123"}) + + def test_construction(self): + plugin = GeniePlugin(self._cfg()) + assert plugin.name == "genie" + assert plugin.client_config() == {"spaces": "sales"} + + def test_routes(self): + plugin = GeniePlugin(self._cfg()) + router = _MockRouter() + plugin.inject_routes(router) + methods = {(m, p) for m, p, _ in router.routes} + assert ("POST", "/message") in methods + assert ("GET", "/conversation") in methods + assert ("GET", "/query-result") in methods + + def test_entry_point(self): + assert isinstance(genie(self._cfg()), GeniePlugin) + + +# --------------------------------------------------------------------------- +# Serving +# --------------------------------------------------------------------------- + + +class TestServingPlugin: + def _cfg(self): + return ServingPluginConfig( + endpoints={ + "chat": ServingEndpointConfig(env="CHAT_ENDPOINT"), + } + ) + + def test_requires_endpoint(self): + with pytest.raises(ValueError, match="at least one endpoint"): + ServingPluginConfig(endpoints={}) + + def test_construction(self): + plugin = ServingPlugin(self._cfg()) + assert plugin.name == "serving" + assert plugin.client_config() == {"endpoints": "chat"} + + def test_routes(self): + 
plugin = ServingPlugin(self._cfg()) + router = _MockRouter() + plugin.inject_routes(router) + assert ("POST", "/invoke/:endpoint", False) in router.routes + assert ("POST", "/stream/:endpoint", True) in router.routes + + def test_resolve_endpoint_reads_env(self, monkeypatch): + plugin = ServingPlugin(self._cfg()) + monkeypatch.setenv("CHAT_ENDPOINT", "databricks-dbrx") + assert plugin.resolve_endpoint("chat") == "databricks-dbrx" + + def test_resolve_endpoint_missing_env(self, monkeypatch): + plugin = ServingPlugin(self._cfg()) + monkeypatch.delenv("CHAT_ENDPOINT", raising=False) + with pytest.raises(appkit.ValidationError, match="not set"): + plugin.resolve_endpoint("chat") + + def test_resolve_endpoint_unknown_alias(self): + plugin = ServingPlugin(self._cfg()) + with pytest.raises(appkit.ValidationError, match="Unknown endpoint"): + plugin.resolve_endpoint("does-not-exist") + + def test_entry_point(self): + assert isinstance(serving(self._cfg()), ServingPlugin) + + +# --------------------------------------------------------------------------- +# Lakebase +# --------------------------------------------------------------------------- + + +class TestLakebasePlugin: + def _cfg(self): + return LakebasePluginConfig( + pg_config=appkit.LakebasePgConfig(host="db.example.com", database="mydb") + ) + + def test_construction(self): + plugin = LakebasePlugin(self._cfg()) + assert plugin.name == "lakebase" + assert plugin.pg_config.host == "db.example.com" + assert plugin.pg_config.database == "mydb" + + def test_client_config(self): + plugin = LakebasePlugin(self._cfg()) + cfg = plugin.client_config() + assert cfg["database"] == "mydb" + assert "ssl_mode" in cfg + + def test_exports(self): + plugin = LakebasePlugin(self._cfg()) + exports = plugin.exports() + assert exports["pg_host"] == "db.example.com" + assert exports["pg_database"] == "mydb" + assert exports["pg_port"] == "5432" + + def test_inject_routes_noop(self): + plugin = LakebasePlugin(self._cfg()) + router = 
_MockRouter() + plugin.inject_routes(router) + assert router.routes == [] + + def test_entry_point(self): + # PGHOST/PGDATABASE required to default-construct LakebasePgConfig. + os.environ["PGHOST"] = "h" + os.environ["PGDATABASE"] = "d" + try: + assert isinstance(lakebase(), LakebasePlugin) + finally: + os.environ.pop("PGHOST", None) + os.environ.pop("PGDATABASE", None) + + +# --------------------------------------------------------------------------- +# End-to-end: plugins subclass appkit.Plugin and register cleanly +# --------------------------------------------------------------------------- + + +class TestPluginRegistration: + @pytest.mark.asyncio + async def test_register_and_initialize(self, app_config): + os.environ["CHAT_ENDPOINT"] = "dbrx-chat" + try: + plugins = [ + server(), + analytics(AnalyticsPluginConfig(warehouse_id="wh-1")), + vector_search( + VectorSearchPluginConfig( + indexes={ + "docs": VectorSearchIndexConfig( + index_name="c.s.i", columns=["id"] + ) + } + ) + ), + serving( + ServingPluginConfig( + endpoints={ + "chat": ServingEndpointConfig(env="CHAT_ENDPOINT") + } + ) + ), + ] + app = appkit.AppKit() + for p in plugins: + app.register(p) + await app.initialize(app_config) + assert set(app.plugin_names()) == { + "server", + "analytics", + "vector-search", + "serving", + } + for p in plugins: + assert p.is_ready is True + finally: + os.environ.pop("CHAT_ENDPOINT", None) diff --git a/packages/appkit-rs/tests/test_server.py b/packages/appkit-rs/tests/test_server.py new file mode 100644 index 00000000..f05b0c4e --- /dev/null +++ b/packages/appkit-rs/tests/test_server.py @@ -0,0 +1,59 @@ +"""Integration tests for Server types.""" + +import pytest + +import appkit + + +class TestServerConfig: + def test_defaults(self): + cfg = appkit.ServerConfig() + assert cfg.host == "0.0.0.0" + assert cfg.port == 8000 + assert cfg.auto_start is True + assert cfg.static_path is None + + def test_custom_values(self): + cfg = appkit.ServerConfig( + 
host="127.0.0.1", + port=9090, + auto_start=False, + static_path="/dist", + ) + assert cfg.host == "127.0.0.1" + assert cfg.port == 9090 + assert cfg.auto_start is False + assert cfg.static_path == "/dist" + + def test_repr(self): + cfg = appkit.ServerConfig() + r = repr(cfg) + assert "ServerConfig" in r + assert "8000" in r + + def test_equality(self): + a = appkit.ServerConfig(port=8000) + b = appkit.ServerConfig(port=8000) + c = appkit.ServerConfig(port=9090) + assert a == b + assert a != c + + def test_hashable(self): + cfg = appkit.ServerConfig() + s = {cfg} + assert len(s) == 1 + + def test_frozen(self): + cfg = appkit.ServerConfig() + with pytest.raises(AttributeError): + cfg.port = 9999 + + +class TestRouter: + def test_type_exists(self): + assert hasattr(appkit, "Router") + + +class TestRequest: + def test_type_exists(self): + assert hasattr(appkit, "Request") diff --git a/template-python/.env.example b/template-python/.env.example new file mode 100644 index 00000000..cf28e14e --- /dev/null +++ b/template-python/.env.example @@ -0,0 +1,15 @@ +# Required — your Databricks workspace URL +DATABRICKS_HOST=https://your-workspace.cloud.databricks.com + +# Required for service-principal auth (not needed when running as a Databricks App) +# DATABRICKS_CLIENT_ID=your-client-id +# DATABRICKS_CLIENT_SECRET=your-client-secret + +# Optional — SQL warehouse for analytics queries +# DATABRICKS_WAREHOUSE_ID=your-warehouse-id + +# Optional — OpenTelemetry collector endpoint +# OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318 + +# Server settings +DATABRICKS_APP_PORT=8000 diff --git a/template-python/README.md b/template-python/README.md new file mode 100644 index 00000000..98e134e1 --- /dev/null +++ b/template-python/README.md @@ -0,0 +1,97 @@ +# AppKit Python Template + +A minimal, runnable scaffold for Python backend applications built on [appkit](../packages/appkit-rs/). 
+ +## Prerequisites + +- Python 3.11+ +- Rust toolchain (for building from source) — install via [rustup](https://rustup.rs/) +- [maturin](https://www.maturin.rs/) (`pip install maturin`) + +## Setup + +```bash +# 1. Create and activate a virtual environment +python -m venv .venv +source .venv/bin/activate # Linux/macOS +# .venv\Scripts\activate # Windows + +# 2. Install appkit (build from source) +cd ../packages/appkit-rs +maturin develop +cd ../../template-python + +# 3. Install remaining dependencies +pip install -r requirements.txt + +# 4. Configure environment +cp .env.example .env +# Edit .env with your Databricks workspace URL and credentials +``` + +## Run + +```bash +python server/main.py +``` + +The server starts on `http://0.0.0.0:8000` by default. Test it: + +```bash +# Health check +curl http://localhost:8000/api/example/health + +# Greeting (with caching and timeout via the interceptor chain) +curl -X POST http://localhost:8000/api/example/greet \ + -H "Content-Type: application/json" \ + -d '{"name": "Alice"}' + +# Streaming +curl http://localhost:8000/api/example/stream +``` + +## Test + +```bash +pytest +``` + +Tests exercise plugin registration, the interceptor chain (`execute`), streaming (`execute_stream`), and error handling — all without requiring a live Databricks workspace. + +## Project Structure + +``` +template-python/ +├── server/ +│ ├── main.py # Entry point — creates AppConfig, registers plugins, starts server +│ └── example_plugin.py # Sample Plugin subclass with routes +├── tests/ +│ └── test_example.py # pytest tests for the example plugin +├── pyproject.toml # Python project metadata +├── requirements.txt # Pinned dependencies +├── .env.example # Environment variable documentation +├── app.yaml.tmpl # Databricks Apps deployment config +├── databricks.yml.tmpl # Databricks Bundle config +└── README.md # This file +``` + +## Deploy to Databricks + +1. Install the [Databricks CLI](https://docs.databricks.com/dev-tools/cli/index.html). 
+2. Copy the deployment templates: + ```bash + cp app.yaml.tmpl app.yaml + cp databricks.yml.tmpl databricks.yml + ``` +3. Edit `databricks.yml` — set your project name, workspace host, and any resource references. +4. Deploy: + ```bash + databricks bundle deploy + ``` + +## Customizing + +- **Add plugins:** Create a new file in `server/`, subclass `appkit.Plugin`, and register it in `main.py`. +- **Add routes:** Override `inject_routes(self, router)` in your plugin. Routes are automatically namespaced under `/api/{plugin-name}/`. +- **Use connectors:** Access Databricks services via `appkit.SqlWarehouseConnector`, `appkit.FilesConnector`, `appkit.GenieConnector`, `appkit.ServingConnector`, or `appkit.LakebaseConnector`. +- **Caching/retry/timeout:** Pass options to `self.execute()` to leverage the built-in interceptor chain. diff --git a/template-python/_gitignore b/template-python/_gitignore new file mode 100644 index 00000000..a3714379 --- /dev/null +++ b/template-python/_gitignore @@ -0,0 +1,27 @@ +# Python +__pycache__/ +*.py[cod] +*.egg-info/ +*.egg +dist/ +build/ +.venv/ +venv/ + +# Environment +.env + +# Testing +.pytest_cache/ +htmlcov/ +.coverage + +# IDE +.vscode/ +.idea/ + +# OS +.DS_Store + +# Databricks +.databricks/ diff --git a/template-python/app.yaml.tmpl b/template-python/app.yaml.tmpl new file mode 100644 index 00000000..f4227095 --- /dev/null +++ b/template-python/app.yaml.tmpl @@ -0,0 +1,5 @@ +command: ['python', 'server/main.py'] +{{- if .appEnv}} +env: +{{.appEnv}} +{{- end}} diff --git a/template-python/databricks.yml.tmpl b/template-python/databricks.yml.tmpl new file mode 100644 index 00000000..a800feb4 --- /dev/null +++ b/template-python/databricks.yml.tmpl @@ -0,0 +1,35 @@ +bundle: + name: {{.projectName}} +{{- if .bundle.variables}} + +variables: +{{.bundle.variables}} +{{- end}} + +resources: + apps: + app: + name: "{{.projectName}}" + description: "{{.appDescription}}" + source_code_path: ./ + + # Uncomment to enable on behalf of 
user API scopes.
+      # Available scopes: sql, dashboards.genie, files.files, serving.serving-endpoints
+      # user_api_scopes:
+      #   - sql
+{{- if .bundle.resources}}
+
+  resources:
+{{.bundle.resources}}
+{{- end}}
+
+targets:
+  default:
+    default: true
+    workspace:
+      host: {{.workspaceHost}}
+{{- if .bundle.targetVariables}}
+
+    variables:
+{{.bundle.targetVariables}}
+{{- end}}
diff --git a/template-python/pyproject.toml b/template-python/pyproject.toml
new file mode 100644
index 00000000..b43ab89f
--- /dev/null
+++ b/template-python/pyproject.toml
@@ -0,0 +1,22 @@
+[build-system]
+requires = ["setuptools>=68.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "my-appkit-app"
+version = "0.1.0"
+requires-python = ">=3.11"
+description = "A Databricks AppKit application"
+dependencies = [
+    "databricks-appkit>=0.1.0",
+]
+
+[project.optional-dependencies]
+dev = [
+    "pytest>=8.0",
+    "pytest-asyncio>=0.23",
+]
+
+[tool.pytest.ini_options]
+asyncio_mode = "auto"
+testpaths = ["tests"]
diff --git a/template-python/requirements.txt b/template-python/requirements.txt
new file mode 100644
index 00000000..df4ee825
--- /dev/null
+++ b/template-python/requirements.txt
@@ -0,0 +1,9 @@
+# Core dependency — install from local wheel or PyPI once published
+databricks-appkit>=0.1.0
+
+# .env file loading
+python-dotenv>=1.0
+
+# Dev / test
+pytest>=8.0
+pytest-asyncio>=0.23
diff --git a/template-python/server/example_plugin.py b/template-python/server/example_plugin.py
new file mode 100644
index 00000000..584655b0
--- /dev/null
+++ b/template-python/server/example_plugin.py
@@ -0,0 +1,85 @@
+"""Example plugin demonstrating the AppKit Python SDK.
+
+Shows how to subclass Plugin, register routes, use the interceptor
+chain (execute), and stream results (execute_stream).
+""" + +import json + +import appkit + + +class ExamplePlugin(appkit.Plugin): + """A sample plugin with health-check, greeting, and streaming endpoints.""" + + def __init__(self): + super().__init__( + "example", + manifest=appkit.PluginManifest( + "example", + display_name="Example Plugin", + description="Demonstrates setup, routes, execute, and streaming", + ), + ) + + async def setup(self): + """Called once during app initialization (after registration).""" + print(f"[{self.name}] plugin initialized") + + def inject_routes(self, router: appkit.Router): + """Register HTTP endpoints under /api/example/...""" + + router.get("/health", self._health_handler) + router.post("/greet", self._greet_handler) + router.get("/stream", self._stream_handler, stream=True) + + # ------------------------------------------------------------------ + # Route handlers + # ------------------------------------------------------------------ + + async def _health_handler(self, request: appkit.Request) -> str: + """GET /api/example/health — simple health check.""" + return json.dumps({"status": "ok", "plugin": self.name}) + + async def _greet_handler(self, request: appkit.Request) -> str: + """POST /api/example/greet — runs through the interceptor chain. + + Expects JSON body: {"name": "Alice"} + Uses execute() for caching, timeout, and retry support. + """ + body = request.json() + name = body.get("name", "World") + + async def build_greeting(): + return json.dumps({"message": f"Hello, {name}!"}) + + result = await self.execute( + build_greeting, + user_key=name, + cache_key=["greet", name], + cache_ttl=60, + timeout_ms=5000, + ) + + if result.ok: + return result.data + return json.dumps( + {"error": result.message}, + ) + + async def _stream_handler(self, request: appkit.Request): + """GET /api/example/stream — demonstrates execute_stream(). + + Yields numbered items as SSE events. 
The framework detects this + async generator (stream=True route) and bridges each yielded + string to a Server-Sent Event automatically. + """ + + async def generate_items(): + for i in range(5): + yield json.dumps({"item": i, "total": 5}) + + stream = await self.execute_stream(generate_items, timeout_ms=10000) + + async for chunk in stream: + yield chunk diff --git a/template-python/server/main.py b/template-python/server/main.py new file mode 100644 index 00000000..4b0c3fef --- /dev/null +++ b/template-python/server/main.py @@ -0,0 +1,36 @@ +"""Entry point for the AppKit Python application.""" + +import asyncio +from pathlib import Path + +import appkit +from dotenv import load_dotenv + +from example_plugin import ExamplePlugin + +# Load .env from the template-python root (one level up from server/). +load_dotenv(Path(__file__).resolve().parent.parent / ".env") + + +async def main(): + config = appkit.AppConfig.from_env() + + app = await appkit.create_app( + config=config, + plugins=[ExamplePlugin()], + cache_config=appkit.CacheConfig(ttl=3600), + ) + + # The server starts automatically (auto_start=True by default). + # To keep the process alive, await shutdown or Ctrl-C. + print(f"AppKit running on http://{config.host}:{config.app_port}") + try: + await asyncio.Event().wait() # Block until interrupted + except asyncio.CancelledError: + pass + finally: + app.shutdown() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/template-python/tests/test_example.py b/template-python/tests/test_example.py new file mode 100644 index 00000000..9adba36b --- /dev/null +++ b/template-python/tests/test_example.py @@ -0,0 +1,118 @@ +"""Tests for the example plugin. 
+ +Run with: pytest +(Requires appkit-rs to be installed: maturin develop or pip install appkit-rs) +""" + +import json +import sys +from pathlib import Path + +import pytest + +import appkit + +# Allow importing from server/ +sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "server")) + +from example_plugin import ExamplePlugin + + +@pytest.fixture +def app_config(): + """Minimal AppConfig for testing (no real Databricks connection).""" + return appkit.AppConfig( + "https://test.databricks.com", + client_id="test-client-id", + client_secret="test-client-secret", + ) + + +@pytest.fixture +def example_plugin(): + return ExamplePlugin() + + +class TestExamplePluginRegistration: + def test_plugin_name(self, example_plugin): + assert example_plugin.name == "example" + + def test_plugin_phase(self, example_plugin): + assert example_plugin.phase == "normal" + + def test_manifest_fields(self, example_plugin): + m = example_plugin.manifest + assert m.name == "example" + assert m.display_name == "Example Plugin" + assert m.description is not None + + +class TestExamplePluginLifecycle: + @pytest.mark.asyncio + async def test_setup_marks_ready(self, app_config, example_plugin): + app = await appkit.create_app( + config=app_config, + plugins=[example_plugin], + auto_start=False, + ) + assert example_plugin.is_ready is True + assert "example" in app + + @pytest.mark.asyncio + async def test_execute_greet(self, app_config, example_plugin): + """Verify the greeting logic works through the interceptor chain.""" + await appkit.create_app( + config=app_config, + plugins=[example_plugin], + auto_start=False, + ) + + async def greet(): + return json.dumps({"message": "Hello, Test!"}) + + result = await example_plugin.execute( + greet, + user_key="test", + cache_key=["greet", "test"], + cache_ttl=60, + ) + assert result.ok is True + data = json.loads(result.data) + assert data["message"] == "Hello, Test!" 
+ + @pytest.mark.asyncio + async def test_execute_stream(self, app_config, example_plugin): + """Verify streaming execution collects all chunks.""" + await appkit.create_app( + config=app_config, + plugins=[example_plugin], + auto_start=False, + ) + + async def generate(): + for i in range(3): + yield json.dumps({"i": i}) + + stream = await example_plugin.execute_stream(generate) + items = [] + async for chunk in stream: + items.append(json.loads(chunk)) + assert len(items) == 3 + assert items[0]["i"] == 0 + assert items[2]["i"] == 2 + + @pytest.mark.asyncio + async def test_execute_error_handling(self, app_config, example_plugin): + """Verify errors are captured as failed ExecutionResult.""" + await appkit.create_app( + config=app_config, + plugins=[example_plugin], + auto_start=False, + ) + + async def fail(): + raise ValueError("boom") + + result = await example_plugin.execute(fail) + assert result.ok is False + assert "boom" in result.message