diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index e1f18de74..bddd4e325 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -69,7 +69,7 @@ jobs: - name: Install dependencies run: | - python -m pip install -e ".[tests,tracking-client,graphviz]" + python -m pip install -e ".[tests,tracking-client,graphviz,tracking-server-s3]" - name: Run tests run: | diff --git a/.gitignore b/.gitignore index a92fd35b4..1982de051 100644 --- a/.gitignore +++ b/.gitignore @@ -193,3 +193,11 @@ burr/tracking/server/build examples/*/statemachine examples/*/*/statemachine .vscode + +# Terraform (see also examples/deployment/aws/terraform/.gitignore) +**/.terraform.lock.hcl +examples/deployment/aws/terraform/.terraform/ +examples/deployment/aws/terraform/*.tfstate +examples/deployment/aws/terraform/*.tfstate.* +examples/deployment/aws/terraform/.terraform.tfstate.lock.info +examples/deployment/aws/terraform/*.tfplan diff --git a/burr/integrations/__init__.py b/burr/integrations/__init__.py index 13a83393a..956579056 100644 --- a/burr/integrations/__init__.py +++ b/burr/integrations/__init__.py @@ -14,3 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+ + +def __getattr__(name: str): + """Lazy load Bedrock integration to avoid requiring boto3 unless used.""" + if name == "BedrockAction": + from burr.integrations.bedrock import BedrockAction + return BedrockAction + if name == "BedrockStreamingAction": + from burr.integrations.bedrock import BedrockStreamingAction + return BedrockStreamingAction + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/burr/integrations/bedrock.py b/burr/integrations/bedrock.py new file mode 100644 index 000000000..36969feaf --- /dev/null +++ b/burr/integrations/bedrock.py @@ -0,0 +1,261 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Amazon Bedrock integration for Burr. + +This module provides Action classes for invoking Amazon Bedrock models +within Burr applications. 
+ +Example usage: + from burr.integrations.bedrock import BedrockAction + + def prompt_mapper(state): + return { + "messages": [{"role": "user", "content": state["user_input"]}], + "system": [{"text": "You are a helpful assistant."}], + } + + # With default client (created lazily on first use): + action = BedrockAction( + model_id="anthropic.claude-3-sonnet-20240229-v1:0", + input_mapper=prompt_mapper, + reads=["user_input"], + writes=["response"], + ) + + # With injected client (for tests or distributed execution): + # client = boto3.client("bedrock-runtime", region_name="us-east-1") + # action = BedrockAction(..., client=client) +""" + +import logging +from typing import Any, Generator, Optional, Protocol + +from burr.core.action import SingleStepAction, StreamingAction +from burr.core.state import State +from burr.integrations.base import require_plugin + +logger = logging.getLogger(__name__) + +# Type for injected Bedrock client (avoids boto3 import at type-check time) +BedrockClient = Any + +try: + import boto3 + from botocore.config import Config + from botocore.exceptions import ClientError +except ImportError as e: + require_plugin(e, "bedrock") + + +class StateToPromptMapper(Protocol): + """Protocol for mapping Burr state to Bedrock prompt format.""" + + def __call__(self, state: State) -> dict[str, Any]: + ... 
+ + +class BedrockAction(SingleStepAction): + """Action that invokes Amazon Bedrock models using the Converse API.""" + + def __init__( + self, + model_id: str, + input_mapper: StateToPromptMapper, + reads: list[str], + writes: list[str], + name: str = "bedrock_invoke", + region: Optional[str] = None, + guardrail_id: Optional[str] = None, + guardrail_version: Optional[str] = None, + inference_config: Optional[dict[str, Any]] = None, + max_retries: int = 3, + client: Optional[BedrockClient] = None, + ): + super().__init__() + self._model_id = model_id + self._input_mapper = input_mapper + self._reads = reads + self._writes = writes + self._name = name + self._region = region + self._guardrail_id = guardrail_id + self._guardrail_version = guardrail_version or "DRAFT" + self._inference_config = inference_config or {"maxTokens": 4096} + self._max_retries = max_retries + self._client = client + + def _get_client(self) -> BedrockClient: + """Return the Bedrock client, creating lazily if not injected.""" + if self._client is not None: + return self._client + config = Config( + retries={"max_attempts": self._max_retries, "mode": "adaptive"} + ) + self._client = boto3.client( + "bedrock-runtime", region_name=self._region, config=config + ) + return self._client + + @property + def reads(self) -> list[str]: + return self._reads + + @property + def writes(self) -> list[str]: + return self._writes + + @property + def name(self) -> str: + return self._name + + def run_and_update(self, state: State, **run_kwargs) -> tuple[dict, State]: + prompt = self._input_mapper(state) + + request: dict[str, Any] = { + "modelId": self._model_id, + "messages": prompt["messages"], + "inferenceConfig": self._inference_config, + } + + if "system" in prompt: + request["system"] = prompt["system"] + + if self._guardrail_id: + request["guardrailConfig"] = { + "guardrailIdentifier": self._guardrail_id, + "guardrailVersion": self._guardrail_version, + } + + try: + response = 
self._get_client().converse(**request) + except ClientError as e: + logger.error("Bedrock API error: %s", e) + raise + + output_message = response["output"]["message"] + content_blocks = output_message.get("content", []) + text = content_blocks[0]["text"] if content_blocks else "" + + result: dict[str, Any] = { + "response": text, + "usage": response.get("usage", {}), + "stop_reason": response.get("stopReason"), + } + + updates = {key: result[key] for key in self._writes if key in result} + new_state = state.update(**updates) + + return result, new_state + + +class BedrockStreamingAction(StreamingAction): + """Streaming variant of BedrockAction using Converse Stream API.""" + + def __init__( + self, + model_id: str, + input_mapper: StateToPromptMapper, + reads: list[str], + writes: list[str], + name: str = "bedrock_stream", + region: Optional[str] = None, + guardrail_id: Optional[str] = None, + guardrail_version: Optional[str] = None, + inference_config: Optional[dict[str, Any]] = None, + max_retries: int = 3, + client: Optional[BedrockClient] = None, + ): + super().__init__() + self._model_id = model_id + self._input_mapper = input_mapper + self._reads = reads + self._writes = writes + self._name = name + self._region = region + self._guardrail_id = guardrail_id + self._guardrail_version = guardrail_version or "DRAFT" + self._inference_config = inference_config or {"maxTokens": 4096} + self._max_retries = max_retries + self._client = client + + def _get_client(self) -> BedrockClient: + """Return the Bedrock client, creating lazily if not injected.""" + if self._client is not None: + return self._client + config = Config( + retries={"max_attempts": self._max_retries, "mode": "adaptive"} + ) + self._client = boto3.client( + "bedrock-runtime", region_name=self._region, config=config + ) + return self._client + + @property + def reads(self) -> list[str]: + return self._reads + + @property + def writes(self) -> list[str]: + return self._writes + + @property + def 
name(self) -> str: + return self._name + + def stream_run( + self, state: State, **run_kwargs + ) -> Generator[dict, None, None]: + prompt = self._input_mapper(state) + + request: dict[str, Any] = { + "modelId": self._model_id, + "messages": prompt["messages"], + "inferenceConfig": self._inference_config, + } + + if "system" in prompt: + request["system"] = prompt["system"] + + if self._guardrail_id: + request["guardrailConfig"] = { + "guardrailIdentifier": self._guardrail_id, + "guardrailVersion": self._guardrail_version, + } + + try: + response = self._get_client().converse_stream(**request) + except ClientError as e: + logger.error("Bedrock streaming API error: %s", e) + raise + + full_response = "" + stream = response.get("stream", []) + for event in stream: + if "contentBlockDelta" in event: + chunk = event["contentBlockDelta"]["delta"].get("text", "") + full_response += chunk + yield {"chunk": chunk, "response": full_response} + + yield {"chunk": "", "response": full_response, "complete": True} + + def update(self, result: dict, state: State) -> State: + if result.get("complete"): + updates = {"response": result.get("response", "")} + filtered = {k: v for k, v in updates.items() if k in self._writes} + return state.update(**filtered) + return state diff --git a/burr/tracking/server/backend.py b/burr/tracking/server/backend.py index e33cab9b8..904fe0019 100644 --- a/burr/tracking/server/backend.py +++ b/burr/tracking/server/backend.py @@ -162,6 +162,31 @@ def snapshot_interval_milliseconds(self) -> Optional[int]: pass +class EventDrivenBackendMixin(abc.ABC): + """Mixin for backends that support event-driven updates. + + Enables backends to receive real-time notifications instead of polling + for new files. + """ + + @abc.abstractmethod + async def start_event_consumer(self): + """Start the event consumer for event-driven tracking. + + This method should run indefinitely, processing event notifications + from the configured message queue. 
+ """ + pass + + @abc.abstractmethod + def is_event_driven(self) -> bool: + """Check if this backend is configured for event-driven updates. + + :return: True if event-driven mode is enabled and configured, False otherwise + """ + pass + + class BackendBase(abc.ABC): async def lifespan(self, app: FastAPI): """Quick tool to allow plugin to the app's lifecycle. diff --git a/burr/tracking/server/run.py b/burr/tracking/server/run.py index 0e5ce62b0..95ef427ff 100644 --- a/burr/tracking/server/run.py +++ b/burr/tracking/server/run.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +import asyncio import importlib import logging import os @@ -29,6 +30,7 @@ from burr.tracking.server.backend import ( AnnotationsBackendMixin, BackendBase, + EventDrivenBackendMixin, IndexingBackendMixin, SnapshottingBackendMixin, ) @@ -135,9 +137,20 @@ async def lifespan(app: FastAPI): await backend.lifespan(app).__anext__() await sync_index() # this will trigger the repeat every N seconds await save_snapshot() # this will trigger the repeat every N seconds + # Start event consumer for event-driven tracking when configured + event_consumer_task = None + if isinstance(backend, EventDrivenBackendMixin) and backend.is_event_driven(): + event_consumer_task = asyncio.create_task(backend.start_event_consumer()) global initialized initialized = True yield + # Graceful shutdown: cancel event consumer task + if event_consumer_task is not None: + event_consumer_task.cancel() + try: + await event_consumer_task + except asyncio.CancelledError: + pass await backend.lifespan(app).__anext__() @@ -172,12 +185,18 @@ def get_app_spec(): logger = logging.getLogger(__name__) if app_spec.indexing: - update_interval = backend.update_interval_milliseconds() / 1000 if app_spec.indexing else None - sync_index = repeat_every( - seconds=backend.update_interval_milliseconds() / 1000, - wait_first=True, - logger=logger, - )(sync_index) + # Only use polling when not in 
event-driven mode + if not ( + isinstance(backend, EventDrivenBackendMixin) and backend.is_event_driven() + ): + update_interval = ( + backend.update_interval_milliseconds() / 1000 if app_spec.indexing else None + ) + sync_index = repeat_every( + seconds=backend.update_interval_milliseconds() / 1000, + wait_first=True, + logger=logger, + )(sync_index) if app_spec.snapshotting: snapshot_interval = ( diff --git a/burr/tracking/server/s3/README.md b/burr/tracking/server/s3/README.md index 0dbd7608f..62a6035b2 100644 --- a/burr/tracking/server/s3/README.md +++ b/burr/tracking/server/s3/README.md @@ -59,8 +59,9 @@ This will immediately start indexing your s3 bucket (or pick up from the last sn To track your data, you use the S3TrackingClient. You pass the tracker to the `ApplicationBuilder`: - ```python +from burr.tracking.s3client import S3TrackingClient + app = ( ApplicationBuilder() .with_graph(graph) diff --git a/burr/tracking/server/s3/backend.py b/burr/tracking/server/s3/backend.py index 7f48f88ad..04c1051cd 100644 --- a/burr/tracking/server/s3/backend.py +++ b/burr/tracking/server/s3/backend.py @@ -15,21 +15,25 @@ # specific language governing permissions and limitations # under the License. 
+import asyncio import dataclasses import datetime +import enum import functools import itertools import json import logging import operator import os.path +import tempfile import uuid from collections import Counter -from typing import List, Literal, Optional, Sequence, Tuple, Type, TypeVar, Union +from typing import Literal, Optional, Sequence, Type, TypeVar, Union import fastapi import pydantic from aiobotocore import session +from pydantic import field_validator from fastapi import FastAPI from pydantic_settings import BaseSettings from tortoise import functions, transactions @@ -42,6 +46,7 @@ from burr.tracking.server.backend import ( BackendBase, BurrSettings, + EventDrivenBackendMixin, IndexingBackendMixin, SnapshottingBackendMixin, ) @@ -67,10 +72,33 @@ async def _query_s3_file( bucket: str, key: str, client: session.AioBaseClient, -) -> Union[ContentsModel, List[ContentsModel]]: + buffer_size_mb: int = 10, +) -> bytes: + """Query S3 file with buffering to handle large files. + + BIP-0042: Uses SpooledTemporaryFile to buffer content, spilling to disk + if the file exceeds buffer_size_mb. This ensures the returned bytes object + is seekable for pickle/json deserialization, fixing the UnsupportedOperation + error on large state files. 
+ + :param bucket: S3 bucket name + :param key: S3 object key + :param client: aiobotocore S3 client + :param buffer_size_mb: Max MB to hold in RAM before spilling to disk (default 10MB) + :return: File contents as bytes + """ response = await client.get_object(Bucket=bucket, Key=key) - body = await response["Body"].read() - return body + buffer_size = buffer_size_mb * 1024 * 1024 + + with tempfile.SpooledTemporaryFile(max_size=buffer_size, mode="w+b") as tmp: + async with response["Body"] as stream: + while True: + chunk = await stream.read(8192) + if not chunk: + break + tmp.write(chunk) + tmp.seek(0) + return tmp.read() @dataclasses.dataclass @@ -133,6 +161,13 @@ def from_path(cls, path: str, created_date: datetime.datetime) -> "DataFile": ) +class TrackingMode(str, enum.Enum): + """Tracking mode for S3 backend: polling or event-driven.""" + + POLLING = "POLLING" + EVENT_DRIVEN = "EVENT_DRIVEN" + + class S3Settings(BurrSettings): s3_bucket: str update_interval_milliseconds: int = 120_000 @@ -140,6 +175,20 @@ class S3Settings(BurrSettings): snapshot_interval_milliseconds: int = 3_600_000 load_snapshot_on_start: bool = True prior_snapshots_to_keep: int = 5 + # BIP-0042: Event-driven tracking settings + tracking_mode: TrackingMode = TrackingMode.POLLING + sqs_queue_url: Optional[str] = None + sqs_region: Optional[str] = None + sqs_wait_time_seconds: int = 20 # SQS long polling timeout + s3_buffer_size_mb: int = 10 # RAM buffer before spilling to disk + + @field_validator("tracking_mode", mode="before") + @classmethod + def coerce_tracking_mode(cls, v: object) -> object: + """Coerce legacy 'SQS' string to EVENT_DRIVEN for backward compatibility.""" + if v == "SQS": + return TrackingMode.EVENT_DRIVEN + return v def timestamp_to_reverse_alphabetical(timestamp: datetime) -> str: @@ -156,7 +205,7 @@ def timestamp_to_reverse_alphabetical(timestamp: datetime) -> str: return inverted_str + "-" + timestamp.isoformat() -class SQLiteS3Backend(BackendBase, 
IndexingBackendMixin, SnapshottingBackendMixin): +class SQLiteS3Backend(BackendBase, IndexingBackendMixin, SnapshottingBackendMixin, EventDrivenBackendMixin): def __init__( self, s3_bucket: str, @@ -165,6 +214,12 @@ def __init__( snapshot_interval_milliseconds: int, load_snapshot_on_start: bool, prior_snapshots_to_keep: int, + # BIP-0042: New parameters for event-driven tracking + tracking_mode: Union[TrackingMode, str] = TrackingMode.POLLING, + sqs_queue_url: Optional[str] = None, + sqs_region: Optional[str] = None, + sqs_wait_time_seconds: int = 20, + s3_buffer_size_mb: int = 10, ): self._backend_id = system.now().isoformat() + str(uuid.uuid4()) self._bucket = s3_bucket @@ -177,6 +232,17 @@ def __init__( self._load_snapshot_on_start = load_snapshot_on_start self._snapshot_key_history = [] self._prior_snapshots_to_keep = prior_snapshots_to_keep + # BIP-0042: Store event-driven tracking settings (normalize str to enum) + if isinstance(tracking_mode, TrackingMode): + self._tracking_mode = tracking_mode + elif tracking_mode == "SQS": + self._tracking_mode = TrackingMode.EVENT_DRIVEN + else: + self._tracking_mode = TrackingMode(tracking_mode) + self._sqs_queue_url = sqs_queue_url + self._sqs_region = sqs_region + self._sqs_wait_time_seconds = sqs_wait_time_seconds + self._s3_buffer_size_mb = s3_buffer_size_mb async def load_snapshot(self): if not self._load_snapshot_on_start: @@ -215,7 +281,7 @@ async def snapshot(self): s3_key = f"{self._snapshot_prefix}/{timestamp}/{self._backend_id}/snapshot.db" # TODO -- copy the path at snapshot_path to s3 using aiobotocore session = self._session - logger.info(f"Saving db snapshot at: {s3_key}") + logger.info("Saving db snapshot at: %s", s3_key) async with session.create_client("s3") as s3_client: with open(path, "rb") as file_data: await s3_client.put_object(Bucket=self._bucket, Key=s3_key, Body=file_data) @@ -223,7 +289,7 @@ async def snapshot(self): self._snapshot_key_history.append(s3_key) if len(self._snapshot_key_history) 
> 5: old_snapshot_to_remove = self._snapshot_key_history.pop(0) - logger.info(f"Removing old snapshot: {old_snapshot_to_remove}") + logger.info("Removing old snapshot: %s", old_snapshot_to_remove) async with session.create_client("s3") as s3_client: await s3_client.delete_object(Bucket=self._bucket, Key=old_snapshot_to_remove) @@ -247,7 +313,7 @@ async def _s3_get_first_write_date(self, project_id: str): async def _update_projects(self): current_projects = await Project.all() project_names = {project.name for project in current_projects} - logger.info(f"Current projects: {project_names}") + logger.info("Current projects: %s", project_names) async with self._session.create_client("s3") as client: paginator = client.get_paginator("list_objects_v2") async for result in paginator.paginate( @@ -257,7 +323,7 @@ async def _update_projects(self): project_name = prefix.get("Prefix").split("/")[-2] if project_name not in project_names: now = system.now() - logger.info(f"Creating project: {project_name}") + logger.info("Creating project: %s", project_name) await Project.create( name=project_name, uri=None, @@ -283,7 +349,7 @@ async def query_applications_by_key( async def _gather_metadata_files( self, - metadata_files: List[DataFile], + metadata_files: list[DataFile], ) -> Sequence[dict]: """Gives a list of metadata files so we can update the application""" @@ -312,7 +378,7 @@ async def _query_metadata_file(metadata_file: DataFile) -> dict: ) return out - async def _gather_log_file_data(self, log_files: List[DataFile]) -> Sequence[dict]: + async def _gather_log_file_data(self, log_files: list[DataFile]) -> Sequence[dict]: """Gives a list of log files so we can update the application""" async def _query_log_file(log_file: DataFile) -> dict: @@ -344,9 +410,9 @@ async def _gather_paths_to_update( :return: list of paths to update """ - logger.info(f"Scanning db with highwatermark: {high_watermark_s3_path}") + logger.info("Scanning db with highwatermark: %s", 
high_watermark_s3_path) paths_to_update = [] - logger.info(f"Scanning log data for project: {project.name}") + logger.info("Scanning log data for project: %s", project.name) async with self._session.create_client("s3") as client: paginator = client.get_paginator("list_objects_v2") async for result in paginator.paginate( @@ -358,11 +424,11 @@ async def _gather_paths_to_update( key = content["Key"] last_modified = content["LastModified"] # Created == last_modified as we have an immutable data model - logger.info(f"Found new file: {key}") + logger.info("Found new file: %s", key) paths_to_update.append(DataFile.from_path(key, created_date=last_modified)) if len(paths_to_update) >= max_paths: break - logger.info(f"Found {len(paths_to_update)} new files to index") + logger.info("Found %s new files to index", len(paths_to_update)) return paths_to_update async def _ensure_applications_exist( @@ -379,10 +445,13 @@ async def _ensure_applications_exist( ) counter = Counter([path.file_type for path in paths_to_update]) logger.info( - f"Found {len(all_application_keys)} applications in the scan, " - f"including: {counter['log']} log files, " - f"{counter['metadata']} metadata files, and {counter['graph']} graph files, " - f"and {len(paths_to_update) - len(all_application_keys)} other files." 
+ "Found %s applications in the scan, including: %s log files, " + "%s metadata files, and %s graph files, and %s other files.", + len(all_application_keys), + counter["log"], + counter["metadata"], + counter["graph"], + len(paths_to_update) - len(all_application_keys), ) # First, let's create all applications, ignoring them if they exist @@ -406,7 +475,9 @@ async def _ensure_applications_exist( ] logger.info( - f"Creating {len(apps_to_create)} new applications, with keys: {[(app.name, app.partition_key) for app in apps_to_create]}" + "Creating %s new applications, with keys: %s", + len(apps_to_create), + [(app.name, app.partition_key) for app in apps_to_create], ) await Application.bulk_create(apps_to_create) all_applications = await self.query_applications_by_key(all_application_keys) @@ -421,7 +492,7 @@ async def _update_all_applications( :param paths_to_update: All paths to update :return: """ - logger.info(f"found: {len(all_applications)} applications to update in the db") + logger.info("found: %s applications to update in the db", len(all_applications)) metadata_data = [path for path in paths_to_update if path.file_type == "metadata"] graph_data = [path for path in paths_to_update if path.file_type == "graph"] metadata_objects = await self._gather_metadata_files(metadata_data) @@ -484,7 +555,7 @@ async def _update_high_watermark( async def _scan_and_update_db_for_project( self, project: Project, indexing_job: IndexingJob - ) -> Tuple[IndexStatus, int]: + ) -> tuple[IndexStatus, int]: """Scans and updates the database for a project. 
TODO -- break this up into functions @@ -499,7 +570,7 @@ async def _scan_and_update_db_for_project( ) # This way we can sort by the latest captured time high_watermark = current_status.s3_highwatermark if current_status is not None else "" - logger.info(f"Scanning db with highwatermark: {high_watermark}") + logger.info("Scanning db with highwatermark: %s", high_watermark) paths_to_update = await self._gather_paths_to_update( project=project, high_watermark_s3_path=high_watermark ) @@ -524,7 +595,7 @@ async def _scan_and_update_db(self): # TODO -- add error catching status, num_files = await self._scan_and_update_db_for_project(project, indexing_job) - logger.info(f"Scanned: {num_files} files with status stored at ID={status.id}") + logger.info("Scanned: %s files with status stored at ID=%s", num_files, status.id) indexing_job.records_processed = num_files indexing_job.end_time = system.now() @@ -571,7 +642,7 @@ async def list_apps( partition_key: Optional[str], limit: int = 100, offset: int = 0, - ) -> Tuple[Sequence[schema.ApplicationSummary], int]: + ) -> tuple[Sequence[schema.ApplicationSummary], int]: # TODO -- distinctify between project name and project ID # Currently they're the same in the UI but we'll want to have them decoupled app_query = ( @@ -631,13 +702,22 @@ async def get_application_logs( "-created_at" ) async with self._session.create_client("s3") as client: - # Get all the files + # Get all the files (BIP-0042: use buffered reading for large files) files = await utils.gather_with_concurrency( 1, - _query_s3_file(self._bucket, application.graph_file_pointer, client), - # _query_s3_files(self.bucket, application.metadata_file_pointer, client), + _query_s3_file( + self._bucket, + application.graph_file_pointer, + client, + self._s3_buffer_size_mb, + ), *itertools.chain( - _query_s3_file(self._bucket, log_file.s3_path, client) + _query_s3_file( + self._bucket, + log_file.s3_path, + client, + self._s3_buffer_size_mb, + ) for log_file in 
application_logs ), ) @@ -656,6 +736,114 @@ async def get_application_logs( application=graph_data, ) + # BIP-0042: Event-driven tracking methods + async def _handle_s3_event(self, s3_key: str, event_time: datetime.datetime) -> None: + """Handle a single S3 event notification - index the file immediately. + + :param s3_key: The S3 object key from the event + :param event_time: When the event occurred + """ + try: + data_file = DataFile.from_path(s3_key, created_date=event_time) + # Path structure: data/{project}/yyyy/mm/dd/hh/minutes/pk/app_id/filename + project_name = s3_key.split("/")[1] + + project = await Project.filter(name=project_name).first() + if project is None: + logger.info("Creating project %s from S3 event", project_name) + project = await Project.create( + name=project_name, + uri=None, + created_at=event_time, + indexed_at=event_time, + updated_at=event_time, + ) + + all_applications = await self._ensure_applications_exist([data_file], project) + await self._update_all_applications(all_applications, [data_file]) + await self.update_log_files([data_file], all_applications) + + logger.info("Indexed S3 event: %s", s3_key) + except Exception as e: + logger.error("Failed to handle S3 event %s: %s", s3_key, e) + raise # Re-raise so message stays in queue for retry / DLQ + + async def start_event_consumer(self) -> None: + """Start the event consumer for event-driven tracking. + + Runs indefinitely, processing S3 event notifications from the configured + message queue. Handles both EventBridge and direct S3 notification formats. 
+ """ + if self._tracking_mode != TrackingMode.EVENT_DRIVEN or not self._sqs_queue_url: + logger.info("Event consumer not configured, skipping") + return + + logger.info("Starting event consumer for queue: %s", self._sqs_queue_url) + + async with self._session.create_client("sqs", region_name=self._sqs_region) as sqs_client: + try: + while True: + try: + response = await sqs_client.receive_message( + QueueUrl=self._sqs_queue_url, + MaxNumberOfMessages=10, + WaitTimeSeconds=self._sqs_wait_time_seconds, + VisibilityTimeout=300, + ) + + messages = response.get("Messages", []) + for message in messages: + try: + body = json.loads(message["Body"]) + s3_key = None + event_time = None + + # Handle EventBridge wrapped S3 events (one event per message) + if "detail" in body: + s3_keys_with_times = [ + ( + body["detail"]["object"]["key"], + datetime.datetime.fromisoformat( + body["time"].replace("Z", "+00:00") + ), + ) + ] + elif "Records" in body: + s3_keys_with_times = [ + ( + record["s3"]["object"]["key"], + datetime.datetime.fromisoformat( + record["eventTime"].replace("Z", "+00:00") + ), + ) + for record in body["Records"] + ] + else: + logger.warning("Unknown message format: %s", body) + continue + + for s3_key, event_time in s3_keys_with_times: + if s3_key and s3_key.endswith(".jsonl"): + await self._handle_s3_event(s3_key, event_time) + + await sqs_client.delete_message( + QueueUrl=self._sqs_queue_url, + ReceiptHandle=message["ReceiptHandle"], + ) + except Exception as e: + logger.error("Failed to process SQS message: %s", e) + + except Exception as e: + logger.error("Event consumer error: %s", e) + await asyncio.sleep(5) + except asyncio.CancelledError: + logger.info("Event consumer shutting down") + raise + + def is_event_driven(self) -> bool: + """Check if this backend is configured for event-driven updates.""" + return self._tracking_mode == TrackingMode.EVENT_DRIVEN and self._sqs_queue_url is not None + async def indexing_jobs( self, offset: int = 0, limit: int 
= 100, filter_empty: bool = True ) -> Sequence[schema.IndexingJob]: diff --git a/burr/version.py b/burr/version.py index 8555b4cce..bc7c5aa75 100644 --- a/burr/version.py +++ b/burr/version.py @@ -20,5 +20,9 @@ try: __version__ = importlib.metadata.version("apache-burr") except importlib.metadata.PackageNotFoundError: - # Fallback for older installations or development - __version__ = importlib.metadata.version("burr") + try: + # Fallback for older installations + __version__ = importlib.metadata.version("burr") + except importlib.metadata.PackageNotFoundError: + # Development / source tree: no package metadata + __version__ = "0.0.0.dev" diff --git a/examples/deployment/aws/terraform/.gitignore b/examples/deployment/aws/terraform/.gitignore new file mode 100644 index 000000000..00a20986e --- /dev/null +++ b/examples/deployment/aws/terraform/.gitignore @@ -0,0 +1,11 @@ +# Terraform +.terraform/ +.terraform.lock.hcl +*.tfstate +*.tfstate.* +.terraform.tfstate.lock.info +*.tfplan +crash.log +override.tf +override.tf.json +*.tfvars.backup diff --git a/examples/deployment/aws/terraform/dev.tfvars b/examples/deployment/aws/terraform/dev.tfvars new file mode 100644 index 000000000..86378ba96 --- /dev/null +++ b/examples/deployment/aws/terraform/dev.tfvars @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# Development environment configuration +# Bucket name is auto-generated: burr-tracking-{env}-{region}-{account_id}-{random} +# account_id: leave empty to auto-fetch from AWS credentials, or set explicitly + +aws_region = "us-east-1" +environment = "dev" + +# account_id = "" # Optional. Empty = auto-fetch. Or set: account_id = "123456789012" + +sqs_queue_name = "burr-s3-events-dev" + +# S3 only (polling mode) - simpler for dev; set to true for event-driven +enable_sqs = false + +log_retention_days = 30 +snapshot_retention_days = 14 + +sqs_message_retention_seconds = 86400 +sqs_visibility_timeout_seconds = 120 +sqs_receive_wait_time_seconds = 20 +sqs_max_receive_count = 3 diff --git a/examples/deployment/aws/terraform/main.tf b/examples/deployment/aws/terraform/main.tf new file mode 100644 index 000000000..7c6b5bbc2 --- /dev/null +++ b/examples/deployment/aws/terraform/main.tf @@ -0,0 +1,183 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +terraform { + required_version = ">= 1.0" + + required_providers { + aws = { + source = "hashicorp/aws" + version = ">= 5.0" + } + random = { + source = "hashicorp/random" + version = ">= 3.0" + } + } +} + +provider "aws" { + region = var.aws_region +} + +data "aws_caller_identity" "current" {} +data "aws_region" "current" {} + +resource "random_id" "bucket_suffix" { + byte_length = 4 +} + +locals { + region_short = replace(data.aws_region.current.name, "-", "") + account_id = var.account_id != "" ? var.account_id : data.aws_caller_identity.current.account_id + auto_bucket = "burr-tracking-${var.environment}-${local.region_short}-${local.account_id}-${random_id.bucket_suffix.hex}" + bucket_name = var.s3_bucket_name != "" ? var.s3_bucket_name : local.auto_bucket +} + +module "s3" { + source = "./modules/s3" + + bucket_name = local.bucket_name + tags = local.common_tags + + lifecycle_rules = [ + { + id = "expire-old-logs" + prefix = "data/" + enabled = true + expiration_days = var.log_retention_days + noncurrent_days = 7 + }, + { + id = "expire-old-snapshots" + prefix = "snapshots/" + enabled = true + expiration_days = var.snapshot_retention_days + noncurrent_days = null + } + ] +} + +module "sqs" { + source = "./modules/sqs" + count = var.enable_sqs ? 1 : 0 + + queue_name = var.sqs_queue_name + message_retention_seconds = var.sqs_message_retention_seconds + visibility_timeout_seconds = var.sqs_visibility_timeout_seconds + receive_wait_time_seconds = var.sqs_receive_wait_time_seconds + max_receive_count = var.sqs_max_receive_count + tags = local.common_tags +} + +resource "aws_sqs_queue_policy" "s3_notifications" { + count = var.enable_sqs ? 
1 : 0 + + queue_url = module.sqs[0].queue_id + + policy = jsonencode({ + Version = "2012-10-17" + Statement = [ + { + Sid = "AllowS3Notifications" + Effect = "Allow" + Principal = { + Service = "s3.amazonaws.com" + } + Action = "sqs:SendMessage" + Resource = module.sqs[0].queue_arn + Condition = { + ArnLike = { + "aws:SourceArn" = module.s3.bucket_arn + } + } + } + ] + }) +} + +resource "aws_s3_bucket_notification" "burr_logs" { + count = var.enable_sqs ? 1 : 0 + + bucket = module.s3.bucket_id + + queue { + queue_arn = module.sqs[0].queue_arn + events = ["s3:ObjectCreated:*"] + filter_prefix = "data/" + filter_suffix = ".jsonl" + } + + depends_on = [aws_sqs_queue_policy.s3_notifications] +} + +resource "aws_sns_topic" "dlq_alarm" { + count = var.enable_sqs ? 1 : 0 + + name = "${var.environment}-burr-dlq-alarm" + display_name = "Burr DLQ Alarm - ${var.environment}" + tags = local.common_tags +} + +resource "aws_sns_topic_subscription" "dlq_alarm_email" { + count = var.enable_sqs && length(var.dlq_alarm_notification_emails) > 0 ? length(var.dlq_alarm_notification_emails) : 0 + + topic_arn = aws_sns_topic.dlq_alarm[0].arn + protocol = "email" + endpoint = var.dlq_alarm_notification_emails[count.index] +} + +resource "aws_cloudwatch_metric_alarm" "dlq_messages" { + count = var.enable_sqs ? 
1 : 0 + + alarm_name = "${var.environment}-burr-dlq-messages" + alarm_description = "Alarm when messages appear in Burr SQS dead letter queue" + comparison_operator = "GreaterThanThreshold" + evaluation_periods = 1 + metric_name = "ApproximateNumberOfMessagesVisible" + namespace = "AWS/SQS" + period = 60 + statistic = "Sum" + threshold = 0 + + alarm_actions = [aws_sns_topic.dlq_alarm[0].arn] + ok_actions = [aws_sns_topic.dlq_alarm[0].arn] + + dimensions = { + QueueName = module.sqs[0].dlq_name + } + + tags = local.common_tags +} + +module "iam" { + source = "./modules/iam" + + role_name = "${var.environment}-burr-server-role" + s3_bucket_arn = module.s3.bucket_arn + sqs_queue_arn = var.enable_sqs ? module.sqs[0].queue_arn : "" + enable_sqs = var.enable_sqs + tags = local.common_tags +} + +locals { + common_tags = { + Environment = var.environment + Project = "burr-tracking" + ManagedBy = "terraform" + } +} diff --git a/examples/deployment/aws/terraform/modules/iam/main.tf b/examples/deployment/aws/terraform/modules/iam/main.tf new file mode 100644 index 000000000..b63284f19 --- /dev/null +++ b/examples/deployment/aws/terraform/modules/iam/main.tf @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +data "aws_iam_policy_document" "assume_role" { + statement { + effect = "Allow" + actions = ["sts:AssumeRole"] + principals { + type = "Service" + identifiers = var.trusted_services + } + } +} + +resource "aws_iam_role" "burr_server" { + name = var.role_name + assume_role_policy = data.aws_iam_policy_document.assume_role.json + + tags = merge(var.tags, { + Name = var.role_name + }) +} + +data "aws_iam_policy_document" "s3_least_privilege" { + statement { + sid = "S3ListBucket" + effect = "Allow" + actions = [ + "s3:ListBucket", + "s3:GetBucketLocation" + ] + resources = [var.s3_bucket_arn] + } + + statement { + sid = "S3ObjectOperations" + effect = "Allow" + actions = [ + "s3:GetObject", + "s3:PutObject", + "s3:DeleteObject", + "s3:HeadObject" + ] + resources = ["${var.s3_bucket_arn}/*"] + } +} + +resource "aws_iam_role_policy" "s3" { + name = "${var.role_name}-s3" + role = aws_iam_role.burr_server.id + policy = data.aws_iam_policy_document.s3_least_privilege.json +} + +data "aws_iam_policy_document" "sqs_least_privilege" { + count = var.enable_sqs ? 1 : 0 + + statement { + sid = "SQSConsume" + effect = "Allow" + actions = [ + "sqs:ReceiveMessage", + "sqs:DeleteMessage", + "sqs:GetQueueAttributes" + ] + resources = [var.sqs_queue_arn] + } +} + +resource "aws_iam_role_policy" "sqs" { + count = var.enable_sqs ? 1 : 0 + name = "${var.role_name}-sqs" + role = aws_iam_role.burr_server.id + policy = data.aws_iam_policy_document.sqs_least_privilege[0].json +} + diff --git a/examples/deployment/aws/terraform/modules/iam/outputs.tf b/examples/deployment/aws/terraform/modules/iam/outputs.tf new file mode 100644 index 000000000..ccf3003e6 --- /dev/null +++ b/examples/deployment/aws/terraform/modules/iam/outputs.tf @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +output "role_arn" { + description = "ARN of the IAM role" + value = aws_iam_role.burr_server.arn +} + +output "role_name" { + description = "Name of the IAM role" + value = aws_iam_role.burr_server.name +} diff --git a/examples/deployment/aws/terraform/modules/iam/variables.tf b/examples/deployment/aws/terraform/modules/iam/variables.tf new file mode 100644 index 000000000..9a2e83cc9 --- /dev/null +++ b/examples/deployment/aws/terraform/modules/iam/variables.tf @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +variable "role_name" { + description = "Name of the IAM role for Burr server" + type = string +} + +variable "trusted_services" { + description = "List of AWS services that can assume this role" + type = list(string) + default = ["ecs-tasks.amazonaws.com", "ec2.amazonaws.com", "lambda.amazonaws.com"] +} + +variable "s3_bucket_arn" { + description = "ARN of the S3 bucket for least privilege access" + type = string +} + +variable "enable_sqs" { + description = "Enable SQS IAM permissions" + type = bool + default = true +} + +variable "sqs_queue_arn" { + description = "ARN of the SQS queue for least privilege access" + type = string + default = "" +} + +variable "tags" { + description = "Tags to apply to resources" + type = map(string) + default = {} +} diff --git a/examples/deployment/aws/terraform/modules/s3/main.tf b/examples/deployment/aws/terraform/modules/s3/main.tf new file mode 100644 index 000000000..67163ee09 --- /dev/null +++ b/examples/deployment/aws/terraform/modules/s3/main.tf @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +resource "aws_s3_bucket" "this" { + bucket = var.bucket_name + + tags = merge(var.tags, { + Name = var.bucket_name + }) +} + +resource "aws_s3_bucket_versioning" "this" { + bucket = aws_s3_bucket.this.id + + versioning_configuration { + status = "Enabled" + } +} + +resource "aws_s3_bucket_server_side_encryption_configuration" "this" { + bucket = aws_s3_bucket.this.id + + rule { + apply_server_side_encryption_by_default { + sse_algorithm = "AES256" + } + } +} + +resource "aws_s3_bucket_lifecycle_configuration" "this" { + bucket = aws_s3_bucket.this.id + + dynamic "rule" { + for_each = var.lifecycle_rules + content { + id = rule.value.id + status = rule.value.enabled ? "Enabled" : "Disabled" + + filter { + prefix = rule.value.prefix + } + + expiration { + days = rule.value.expiration_days + } + + dynamic "noncurrent_version_expiration" { + for_each = try(rule.value.noncurrent_days, null) != null ? [1] : [] + content { + noncurrent_days = rule.value.noncurrent_days + } + } + } + } +} + +resource "aws_s3_bucket_public_access_block" "this" { + bucket = aws_s3_bucket.this.id + + block_public_acls = true + block_public_policy = true + ignore_public_acls = true + restrict_public_buckets = true +} diff --git a/examples/deployment/aws/terraform/modules/s3/outputs.tf b/examples/deployment/aws/terraform/modules/s3/outputs.tf new file mode 100644 index 000000000..5ffc964b6 --- /dev/null +++ b/examples/deployment/aws/terraform/modules/s3/outputs.tf @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +output "bucket_id" { + description = "ID of the S3 bucket" + value = aws_s3_bucket.this.id +} + +output "bucket_arn" { + description = "ARN of the S3 bucket" + value = aws_s3_bucket.this.arn +} diff --git a/examples/deployment/aws/terraform/modules/s3/variables.tf b/examples/deployment/aws/terraform/modules/s3/variables.tf new file mode 100644 index 000000000..580cc967d --- /dev/null +++ b/examples/deployment/aws/terraform/modules/s3/variables.tf @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +variable "bucket_name" { + description = "Name of the S3 bucket" + type = string +} + +variable "lifecycle_rules" { + description = "List of lifecycle rules for the bucket" + type = list(object({ + id = string + prefix = string + enabled = bool + expiration_days = number + noncurrent_days = optional(number) + })) +} + +variable "tags" { + description = "Tags to apply to resources" + type = map(string) + default = {} +} diff --git a/examples/deployment/aws/terraform/modules/sqs/main.tf b/examples/deployment/aws/terraform/modules/sqs/main.tf new file mode 100644 index 000000000..4eb3bea9d --- /dev/null +++ b/examples/deployment/aws/terraform/modules/sqs/main.tf @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +resource "aws_sqs_queue" "main" { + name = var.queue_name + message_retention_seconds = var.message_retention_seconds + visibility_timeout_seconds = var.visibility_timeout_seconds + receive_wait_time_seconds = var.receive_wait_time_seconds + + tags = merge(var.tags, { + Name = var.queue_name + }) +} + +resource "aws_sqs_queue" "dlq" { + name = "${var.queue_name}-dlq" + message_retention_seconds = var.dlq_message_retention_seconds + + tags = merge(var.tags, { + Name = "${var.queue_name}-dlq" + }) +} + +resource "aws_sqs_queue_redrive_policy" "main" { + queue_url = aws_sqs_queue.main.id + redrive_policy = jsonencode({ + deadLetterTargetArn = aws_sqs_queue.dlq.arn + maxReceiveCount = var.max_receive_count + }) +} diff --git a/examples/deployment/aws/terraform/modules/sqs/outputs.tf b/examples/deployment/aws/terraform/modules/sqs/outputs.tf new file mode 100644 index 000000000..5b7ccd098 --- /dev/null +++ b/examples/deployment/aws/terraform/modules/sqs/outputs.tf @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +output "queue_id" { + description = "URL of the SQS queue" + value = aws_sqs_queue.main.id +} + +output "queue_url" { + description = "URL of the SQS queue" + value = aws_sqs_queue.main.url +} + +output "queue_arn" { + description = "ARN of the SQS queue" + value = aws_sqs_queue.main.arn +} + +output "dlq_url" { + description = "URL of the dead letter queue" + value = aws_sqs_queue.dlq.url +} + +output "dlq_arn" { + description = "ARN of the dead letter queue" + value = aws_sqs_queue.dlq.arn +} + +output "dlq_name" { + description = "Name of the dead letter queue (for CloudWatch dimensions)" + value = aws_sqs_queue.dlq.name +} diff --git a/examples/deployment/aws/terraform/modules/sqs/variables.tf b/examples/deployment/aws/terraform/modules/sqs/variables.tf new file mode 100644 index 000000000..47e67f3ba --- /dev/null +++ b/examples/deployment/aws/terraform/modules/sqs/variables.tf @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +variable "queue_name" { + description = "Name of the SQS queue" + type = string +} + +variable "message_retention_seconds" { + description = "Message retention period in seconds" + type = number + default = 1209600 +} + +variable "visibility_timeout_seconds" { + description = "Visibility timeout for messages in seconds" + type = number + default = 300 +} + +variable "receive_wait_time_seconds" { + description = "Long polling wait time in seconds" + type = number + default = 20 +} + +variable "dlq_message_retention_seconds" { + description = "DLQ message retention period in seconds" + type = number + default = 1209600 +} + +variable "max_receive_count" { + description = "Max receive count before message moves to DLQ" + type = number + default = 3 +} + +variable "tags" { + description = "Tags to apply to resources" + type = map(string) + default = {} +} diff --git a/examples/deployment/aws/terraform/outputs.tf b/examples/deployment/aws/terraform/outputs.tf new file mode 100644 index 000000000..627a98bc0 --- /dev/null +++ b/examples/deployment/aws/terraform/outputs.tf @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +output "s3_bucket_name" { + description = "Name of the S3 bucket for Burr logs" + value = module.s3.bucket_id +} + +output "s3_bucket_arn" { + description = "ARN of the S3 bucket" + value = module.s3.bucket_arn +} + +output "sqs_queue_url" { + description = "URL of the SQS queue for S3 events" + value = var.enable_sqs ? module.sqs[0].queue_url : null +} + +output "sqs_queue_arn" { + description = "ARN of the SQS queue" + value = var.enable_sqs ? module.sqs[0].queue_arn : null +} + +output "sqs_dlq_url" { + description = "URL of the dead letter queue" + value = var.enable_sqs ? module.sqs[0].dlq_url : null +} + +output "dlq_alarm_arn" { + description = "ARN of the CloudWatch alarm for DLQ messages" + value = var.enable_sqs ? aws_cloudwatch_metric_alarm.dlq_messages[0].arn : null +} + +output "dlq_alarm_sns_topic_arn" { + description = "ARN of the SNS topic for DLQ alarm notifications" + value = var.enable_sqs ? aws_sns_topic.dlq_alarm[0].arn : null +} + +output "iam_role_arn" { + description = "ARN of the IAM role for Burr server" + value = module.iam.role_arn +} + +output "iam_role_name" { + description = "Name of the IAM role for Burr server" + value = module.iam.role_name +} + +output "burr_environment_variables" { + description = "Environment variables to configure Burr server" + value = var.enable_sqs ? 
{ + BURR_S3_BUCKET = module.s3.bucket_id + BURR_TRACKING_MODE = "EVENT_DRIVEN" + BURR_SQS_QUEUE_URL = module.sqs[0].queue_url + BURR_SQS_REGION = data.aws_region.current.name + BURR_SQS_WAIT_TIME_SECONDS = "20" + BURR_S3_BUFFER_SIZE_MB = "10" + } : { + BURR_S3_BUCKET = module.s3.bucket_id + BURR_TRACKING_MODE = "POLLING" + BURR_SQS_QUEUE_URL = "" + BURR_SQS_REGION = data.aws_region.current.name + BURR_SQS_WAIT_TIME_SECONDS = "20" + BURR_S3_BUFFER_SIZE_MB = "10" + } +} diff --git a/examples/deployment/aws/terraform/prod.tfvars b/examples/deployment/aws/terraform/prod.tfvars new file mode 100644 index 000000000..43b9e8a95 --- /dev/null +++ b/examples/deployment/aws/terraform/prod.tfvars @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Production environment configuration +# Bucket name is auto-generated: burr-tracking-{env}-{region}-{account_id}-{random} +# account_id: leave empty to auto-fetch from AWS credentials, or set explicitly + +aws_region = "us-east-1" +environment = "prod" + +# account_id = "" # Optional. Empty = auto-fetch. 
Or set: account_id = "123456789012" + +sqs_queue_name = "burr-s3-events-prod" + +enable_sqs = true + +log_retention_days = 90 +snapshot_retention_days = 30 + +sqs_message_retention_seconds = 1209600 +sqs_visibility_timeout_seconds = 300 +sqs_receive_wait_time_seconds = 20 +sqs_max_receive_count = 3 + +# Optional: receive email when messages land in DLQ +# dlq_alarm_notification_emails = ["ops@example.com"] diff --git a/examples/deployment/aws/terraform/tutorial.md b/examples/deployment/aws/terraform/tutorial.md new file mode 100644 index 000000000..93883f90d --- /dev/null +++ b/examples/deployment/aws/terraform/tutorial.md @@ -0,0 +1,212 @@ +# Apache Burr AWS Tracking Infrastructure Tutorial + +This tutorial explains how to deploy Apache Burr tracking infrastructure on AWS using Terraform. All Terraform code lives in `examples/deployment/aws/terraform/`. It covers deployment with S3 only (polling mode), with S3 and SQS (event-driven mode), and local development without AWS. + +## Quick Start + +```bash +cd examples/deployment/aws/terraform +terraform init +terraform apply -var-file=dev.tfvars # S3 only, polling mode +# or +terraform apply -var-file=prod.tfvars # S3 + SQS, event-driven + DLQ alarm +``` + +Bucket names are auto-generated. After apply, run `terraform output burr_environment_variables` and set those on your Burr server. + +## Overview + +The Terraform configuration provisions: + +- **S3 bucket**: Stores Burr application logs and database snapshots. Name is auto-generated (`burr-tracking-{env}-{region}-{account_id}-{random}`) when not specified. 
+- **SQS queue** (optional): Receives S3 event notifications for real-time tracking; controlled by `enable_sqs` +- **CloudWatch alarm + SNS**: Alerts when messages land in the dead letter queue; optional email subscriptions +- **IAM role**: Least-privilege permissions for the Burr server + +## Directory Structure + +All code is in `examples/deployment/aws/terraform/`: + +``` +examples/deployment/aws/terraform/ +├── main.tf # Root module: S3, SQS, CloudWatch alarm, SNS, IAM +├── variables.tf # Input variables +├── outputs.tf # Output values +├── dev.tfvars # Development: S3 only (enable_sqs = false) +├── prod.tfvars # Production: S3 + SQS + DLQ alarm (enable_sqs = true) +├── tutorial.md # This file +└── modules/ + ├── s3/ # S3 bucket with versioning, encryption, lifecycle + ├── sqs/ # SQS queue with DLQ and redrive policy + └── iam/ # IAM role with least-privilege policies +``` + +## Prerequisites + +- Terraform >= 1.0 +- AWS CLI configured with credentials + +No manual bucket naming required; names are auto-generated. `account_id` is fetched from AWS credentials when not set. For a custom bucket name, set `s3_bucket_name` in your tfvars. + +## Using tfvars Files + +| File | Mode | enable_sqs | Resources created | +|-------------|-------------------|------------|--------------------------------------------------------| +| dev.tfvars | S3 only (polling) | false | S3 bucket, IAM role | +| prod.tfvars | S3 + SQS (event) | true | S3 bucket, SQS queue, DLQ, CloudWatch alarm, SNS, IAM | + +### Development (dev.tfvars) - S3 Only + +Uses S3 polling mode (no SQS). Bucket name is auto-generated (`burr-tracking-{env}-{region}-{account_id}-{random}`). Override with `s3_bucket_name = "my-bucket"` in tfvars if needed. + +Deploy: + +```bash +cd examples/deployment/aws/terraform +terraform init +terraform plan -var-file=dev.tfvars +terraform apply -var-file=dev.tfvars +``` + +### Production (prod.tfvars) - S3 + SQS + +Uses event-driven mode with SQS. 
Terraform will create only the S3 bucket and the IAM role — no SQS queue and no S3 event notifications are provisioned in this mode.
+ +## Key Variables + +| Variable | Description | Default | +|----------|-------------|---------| +| aws_region | AWS region | us-east-1 | +| environment | Environment name (dev, prod) | dev | +| account_id | AWS account ID. Empty = auto-fetch from credentials | "" | +| s3_bucket_name | S3 bucket name. Empty = auto-generated (env, region, account_id, random) | "" | +| enable_sqs | Create SQS for event-driven tracking | true | +| sqs_queue_name | Name of the SQS queue | burr-s3-events | +| log_retention_days | Days to retain logs in S3 | 90 | +| snapshot_retention_days | Days to retain DB snapshots | 30 | +| dlq_alarm_notification_emails | Emails to notify when DLQ has messages (confirm via AWS email) | [] | + +## CloudWatch DLQ Alarm and SNS Notifications + +When SQS is enabled, a CloudWatch alarm fires when messages appear in the dead letter queue. An SNS topic is created for notifications. To receive email alerts, add your addresses to `dlq_alarm_notification_emails` in your tfvars: + +``` +dlq_alarm_notification_emails = ["ops@example.com", "oncall@example.com"] +``` + +Each email will receive a confirmation request from AWS; you must confirm the subscription before alerts are delivered. To use Slack or other endpoints, subscribe them to the SNS topic ARN (see `terraform output dlq_alarm_sns_topic_arn`) after apply. 
**SQS policy errors**: Ensure the S3 bucket notification depends on the queue policy; the provided configuration enforces this ordering with `depends_on`.
diff --git a/examples/deployment/aws/terraform/variables.tf b/examples/deployment/aws/terraform/variables.tf
new file mode 100644
index 000000000..0af4960ab
--- /dev/null
+++ b/examples/deployment/aws/terraform/variables.tf
@@ -0,0 +1,94 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+variable "aws_region" {
+  description = "AWS region for resources"
+  type        = string
+  default     = "us-east-1"
+}
+
+variable "environment" {
+  description = "Environment name (dev, staging, prod)"
+  type        = string
+  default     = "dev"
+}
+
+variable "account_id" {
+  description = "AWS account ID for bucket name. Leave empty to auto-fetch from AWS credentials."
+  type        = string
+  default     = ""
+}
+
+variable "s3_bucket_name" {
+  description = "Name of the S3 bucket for Burr logs. If empty, auto-generated from environment, region, account ID, and a random suffix."
+  type        = string
+  default     = ""
+}
+
+variable "enable_sqs" {
+  description = "Enable SQS for event-driven tracking. When false, Burr uses S3 polling mode."
+ type = bool + default = true +} + +variable "sqs_queue_name" { + description = "Name of the SQS queue for S3 events" + type = string + default = "burr-s3-events" +} + +variable "log_retention_days" { + description = "Days to retain log files in S3" + type = number + default = 90 +} + +variable "snapshot_retention_days" { + description = "Days to retain database snapshots in S3" + type = number + default = 30 +} + +variable "sqs_message_retention_seconds" { + description = "SQS message retention period in seconds" + type = number + default = 1209600 +} + +variable "sqs_visibility_timeout_seconds" { + description = "SQS visibility timeout in seconds" + type = number + default = 300 +} + +variable "sqs_receive_wait_time_seconds" { + description = "SQS long polling wait time in seconds" + type = number + default = 20 +} + +variable "sqs_max_receive_count" { + description = "Max receive count before message moves to DLQ" + type = number + default = 3 +} + +variable "dlq_alarm_notification_emails" { + description = "Email addresses to notify when messages land in the DLQ. Empty = no email subscriptions." + type = list(string) + default = [] +} diff --git a/pyproject.toml b/pyproject.toml index 3cb698ede..4704cfb2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,8 @@ tests = [ "apache-burr[redis]", "apache-burr[opentelemetry]", "apache-burr[haystack]", - "apache-burr[ray]" + "apache-burr[ray]", + "apache-burr[bedrock]" ] documentation = [ @@ -128,6 +129,10 @@ tracking-client-s3 = [ "boto3" ] +bedrock = [ + "boto3" +] + tracking-server-s3 = [ "aerich", "aiobotocore", diff --git a/tests/integrations/test_bip0042_bedrock.py b/tests/integrations/test_bip0042_bedrock.py new file mode 100644 index 000000000..c0c0522d0 --- /dev/null +++ b/tests/integrations/test_bip0042_bedrock.py @@ -0,0 +1,168 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests for Bedrock integration.""" + +import inspect +from unittest.mock import MagicMock + +import pytest + +boto3 = pytest.importorskip("boto3", reason="boto3 required for Bedrock tests") + + +class TestBedrockImports: + """Test that Bedrock classes can be imported via lazy loading.""" + + def test_lazy_import_bedrock_action(self): + """Verify BedrockAction can be imported from burr.integrations.""" + from burr.integrations import BedrockAction + + assert BedrockAction is not None + + def test_lazy_import_bedrock_streaming_action(self): + """Verify BedrockStreamingAction can be imported from burr.integrations.""" + from burr.integrations import BedrockStreamingAction + + assert BedrockStreamingAction is not None + + def test_direct_import_bedrock_module(self): + """Verify bedrock.py module exists and has expected classes.""" + from burr.integrations.bedrock import ( + BedrockAction, + BedrockStreamingAction, + StateToPromptMapper, + ) + + assert BedrockAction is not None + assert BedrockStreamingAction is not None + assert StateToPromptMapper is not None + + +class TestBedrockActionInterface: + """Test BedrockAction class interface with mocked boto3.""" + + def test_bedrock_action_extends_single_step_action(self): + """Verify BedrockAction extends SingleStepAction.""" + from 
burr.core.action import SingleStepAction + from burr.integrations.bedrock import BedrockAction + + assert issubclass(BedrockAction, SingleStepAction) + + def test_bedrock_streaming_action_extends_streaming_action(self): + """Verify BedrockStreamingAction extends StreamingAction.""" + from burr.core.action import StreamingAction + from burr.integrations.bedrock import BedrockStreamingAction + + assert issubclass(BedrockStreamingAction, StreamingAction) + + def test_bedrock_action_has_required_properties(self): + """Verify BedrockAction has reads, writes, name properties.""" + from burr.integrations.bedrock import BedrockAction + + action = BedrockAction( + model_id="test-model", + input_mapper=lambda s: {"messages": []}, + reads=["input"], + writes=["output"], + ) + assert action.reads == ["input"] + assert action.writes == ["output"] + assert action.name == "bedrock_invoke" + + def test_bedrock_action_accepts_all_parameters(self): + """Verify BedrockAction accepts all specified parameters.""" + from burr.integrations.bedrock import BedrockAction + + sig = inspect.signature(BedrockAction.__init__) + params = list(sig.parameters.keys()) + assert "model_id" in params + assert "input_mapper" in params + assert "reads" in params + assert "writes" in params + assert "name" in params + assert "region" in params + assert "guardrail_id" in params + assert "guardrail_version" in params + assert "inference_config" in params + assert "max_retries" in params + assert "client" in params + + def test_bedrock_action_uses_injected_client(self): + """Verify BedrockAction uses injected client when provided.""" + from burr.integrations.bedrock import BedrockAction + + mock_client = MagicMock() + mock_client.converse.return_value = { + "output": {"message": {"content": [{"text": "hi"}]}}, + "usage": {}, + "stopReason": "end_turn", + } + + action = BedrockAction( + model_id="test-model", + input_mapper=lambda s: {"messages": [{"role": "user", "content": "hi"}]}, + reads=[], + 
writes=["response"], + client=mock_client, + ) + + result, _ = action.run_and_update({}) + assert result["response"] == "hi" + mock_client.converse.assert_called_once() + + +class TestBedrockStreamingActionInterface: + """Test BedrockStreamingAction class interface with mocked boto3.""" + + def test_bedrock_streaming_action_uses_injected_client(self): + """Verify BedrockStreamingAction uses injected client when provided.""" + from burr.integrations.bedrock import BedrockStreamingAction + + mock_client = MagicMock() + mock_client.converse_stream.return_value = { + "stream": [ + {"contentBlockDelta": {"delta": {"text": "hello "}}}, + {"contentBlockDelta": {"delta": {"text": "world"}}}, + ] + } + + action = BedrockStreamingAction( + model_id="test-model", + input_mapper=lambda s: {"messages": [{"role": "user", "content": "hi"}]}, + reads=[], + writes=["response"], + client=mock_client, + ) + + chunks = list(action.stream_run({})) + assert len(chunks) == 3 # 2 content chunks + 1 complete + assert chunks[0]["chunk"] == "hello " + assert chunks[1]["chunk"] == "world" + assert chunks[2]["complete"] is True + assert chunks[2]["response"] == "hello world" + mock_client.converse_stream.assert_called_once() + + +class TestStateToPromptMapperProtocol: + """Test StateToPromptMapper Protocol exists.""" + + def test_protocol_exists(self): + """Verify StateToPromptMapper Protocol is defined.""" + from burr.integrations.bedrock import StateToPromptMapper + + assert StateToPromptMapper is not None diff --git a/tests/tracking/test_bip0042_s3_buffering.py b/tests/tracking/test_bip0042_s3_buffering.py new file mode 100644 index 000000000..bbe7eadd6 --- /dev/null +++ b/tests/tracking/test_bip0042_s3_buffering.py @@ -0,0 +1,160 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""BIP-0042: Tests for S3 buffering fix and settings.""" + +import inspect + +import pytest + + +class TestS3Settings: + """Test that S3Settings has all BIP-0042 fields with correct defaults.""" + + def test_s3_settings_has_tracking_mode(self): + """Verify tracking_mode field exists with POLLING default.""" + from burr.tracking.server.s3.backend import S3Settings, TrackingMode + + assert "tracking_mode" in S3Settings.model_fields + assert S3Settings.model_fields["tracking_mode"].default == TrackingMode.POLLING + + def test_s3_settings_has_sqs_queue_url(self): + """Verify sqs_queue_url field exists with None default.""" + from burr.tracking.server.s3.backend import S3Settings + + assert "sqs_queue_url" in S3Settings.model_fields + assert S3Settings.model_fields["sqs_queue_url"].default is None + + def test_s3_settings_has_sqs_region(self): + """Verify sqs_region field exists with None default.""" + from burr.tracking.server.s3.backend import S3Settings + + assert "sqs_region" in S3Settings.model_fields + assert S3Settings.model_fields["sqs_region"].default is None + + def test_s3_settings_has_sqs_wait_time_seconds(self): + """Verify sqs_wait_time_seconds field exists with 20 default.""" + from burr.tracking.server.s3.backend import S3Settings + + assert "sqs_wait_time_seconds" in S3Settings.model_fields + assert S3Settings.model_fields["sqs_wait_time_seconds"].default == 20 + + def 
test_s3_settings_has_s3_buffer_size_mb(self): + """Verify s3_buffer_size_mb field exists with 10 default.""" + from burr.tracking.server.s3.backend import S3Settings + + assert "s3_buffer_size_mb" in S3Settings.model_fields + assert S3Settings.model_fields["s3_buffer_size_mb"].default == 10 + + def test_s3_settings_coerces_sqs_to_event_driven(self): + """Verify legacy 'SQS' string coerces to EVENT_DRIVEN for backward compatibility.""" + from burr.tracking.server.s3.backend import S3Settings, TrackingMode + + settings = S3Settings(s3_bucket="test", tracking_mode="SQS") + assert settings.tracking_mode == TrackingMode.EVENT_DRIVEN + + +class TestSQLiteS3BackendInit: + """Test SQLiteS3Backend accepts and stores BIP-0042 parameters.""" + + def test_backend_accepts_new_parameters(self): + """Verify __init__ accepts all 5 new BIP-0042 parameters.""" + from burr.tracking.server.s3.backend import SQLiteS3Backend + + sig = inspect.signature(SQLiteS3Backend.__init__) + params = list(sig.parameters.keys()) + + assert "tracking_mode" in params + assert "sqs_queue_url" in params + assert "sqs_region" in params + assert "sqs_wait_time_seconds" in params + assert "s3_buffer_size_mb" in params + + def test_backend_has_event_driven_methods(self): + """Verify SQLiteS3Backend has event-driven methods.""" + from burr.tracking.server.s3.backend import SQLiteS3Backend + + assert hasattr(SQLiteS3Backend, "_handle_s3_event") + assert hasattr(SQLiteS3Backend, "start_event_consumer") + assert hasattr(SQLiteS3Backend, "is_event_driven") + assert callable(getattr(SQLiteS3Backend, "_handle_s3_event")) + assert callable(getattr(SQLiteS3Backend, "start_event_consumer")) + assert callable(getattr(SQLiteS3Backend, "is_event_driven")) + + +class TestEventDrivenBackendMixin: + """Test EventDrivenBackendMixin exists and has correct interface.""" + + def test_mixin_exists(self): + """Verify EventDrivenBackendMixin exists in backend.py.""" + from burr.tracking.server.backend import 
EventDrivenBackendMixin + + assert EventDrivenBackendMixin is not None + + def test_mixin_has_abstract_methods(self): + """Verify mixin has abstract start_event_consumer and is_event_driven.""" + import abc + + from burr.tracking.server.backend import EventDrivenBackendMixin + + assert issubclass(EventDrivenBackendMixin, abc.ABC) + assert hasattr(EventDrivenBackendMixin, "start_event_consumer") + assert hasattr(EventDrivenBackendMixin, "is_event_driven") + + def test_sqlite_s3_backend_inherits_mixin(self): + """Verify SQLiteS3Backend inherits from EventDrivenBackendMixin.""" + from burr.tracking.server.backend import EventDrivenBackendMixin + from burr.tracking.server.s3.backend import SQLiteS3Backend + + assert issubclass(SQLiteS3Backend, EventDrivenBackendMixin) + + +class TestQueryS3FileBuffering: + """Test _query_s3_file function signature includes buffer_size_mb.""" + + def test_query_s3_file_has_buffer_param(self): + """Verify _query_s3_file accepts buffer_size_mb parameter.""" + from burr.tracking.server.s3.backend import _query_s3_file + + sig = inspect.signature(_query_s3_file) + params = list(sig.parameters.keys()) + + assert "buffer_size_mb" in params + assert sig.parameters["buffer_size_mb"].default == 10 + + +class TestHandleS3Event: + """Test _handle_s3_event creates project if it doesn't exist.""" + + def test_handle_s3_event_method_exists(self): + """Verify _handle_s3_event method exists and is async.""" + from burr.tracking.server.s3.backend import SQLiteS3Backend + + assert hasattr(SQLiteS3Backend, "_handle_s3_event") + method = getattr(SQLiteS3Backend, "_handle_s3_event") + assert inspect.iscoroutinefunction(method) + + def test_handle_s3_event_signature(self): + """Verify _handle_s3_event accepts s3_key and event_time parameters.""" + from burr.tracking.server.s3.backend import SQLiteS3Backend + + sig = inspect.signature(SQLiteS3Backend._handle_s3_event) + params = list(sig.parameters.keys()) + + assert "self" in params + assert "s3_key" in 
params + assert "event_time" in params