From ccebb1baee8ff3e0b53ae375765cd58657fc83cf Mon Sep 17 00:00:00 2001 From: vaquarkhan Date: Sat, 28 Feb 2026 13:19:33 -0600 Subject: [PATCH 1/8] feat: BIP-0042 Cloud-Native Architecture for Apache Burr on AWS (#664) - Event-driven SQS telemetry: S3 notifications to SQS, near-instant updates - Buffered S3 persistence: SpooledTemporaryFile fixes seek errors on large files - Native BedrockAction and BedrockStreamingAction for Bedrock integration - Terraform module: S3, SQS, IAM with dev/prod tfvars and tutorial --- .gitignore | 7 + burr/integrations/__init__.py | 11 + burr/integrations/bedrock.py | 227 ++++++++++++++++++++ burr/tracking/__init__.py | 3 +- burr/tracking/server/backend.py | 25 +++ burr/tracking/server/run.py | 24 ++- burr/tracking/server/s3/backend.py | 167 +++++++++++++- terraform/.gitignore | 10 + terraform/.terraform.lock.hcl | 25 +++ terraform/dev.tfvars | 38 ++++ terraform/main.tf | 132 ++++++++++++ terraform/modules/iam/main.tf | 109 ++++++++++ terraform/modules/iam/outputs.tf | 26 +++ terraform/modules/iam/variables.tf | 67 ++++++ terraform/modules/s3/main.tf | 78 +++++++ terraform/modules/s3/outputs.tf | 26 +++ terraform/modules/s3/variables.tf | 38 ++++ terraform/modules/sqs/main.tf | 44 ++++ terraform/modules/sqs/outputs.tf | 41 ++++ terraform/modules/sqs/variables.tf | 57 +++++ terraform/outputs.tf | 70 ++++++ terraform/prod.tfvars | 37 ++++ terraform/tutorial.md | 194 +++++++++++++++++ terraform/variables.tf | 93 ++++++++ tests/integrations/test_bip0042_bedrock.py | 168 +++++++++++++++ tests/tracking/test_bip0042_s3_buffering.py | 153 +++++++++++++ 26 files changed, 1855 insertions(+), 15 deletions(-) create mode 100644 burr/integrations/bedrock.py create mode 100644 terraform/.gitignore create mode 100644 terraform/.terraform.lock.hcl create mode 100644 terraform/dev.tfvars create mode 100644 terraform/main.tf create mode 100644 terraform/modules/iam/main.tf create mode 100644 terraform/modules/iam/outputs.tf create mode 100644 terraform/modules/iam/variables.tf create mode 100644 terraform/modules/s3/main.tf create mode 100644 terraform/modules/s3/outputs.tf create mode 100644 terraform/modules/s3/variables.tf create mode 100644 terraform/modules/sqs/main.tf create mode 100644 terraform/modules/sqs/outputs.tf create mode 100644 terraform/modules/sqs/variables.tf create mode 100644 terraform/outputs.tf create mode 100644 terraform/prod.tfvars create mode 100644 terraform/tutorial.md create mode 100644 terraform/variables.tf create mode 100644 tests/integrations/test_bip0042_bedrock.py create mode 100644 tests/tracking/test_bip0042_s3_buffering.py diff --git a/.gitignore b/.gitignore index a92fd35b4..c56ee23ed 100644 --- a/.gitignore +++ b/.gitignore @@ -193,3 +193,10 @@ burr/tracking/server/build examples/*/statemachine examples/*/*/statemachine .vscode + +# Terraform (see also terraform/.gitignore) +terraform/.terraform/ +terraform/*.tfstate +terraform/*.tfstate.* +terraform/.terraform.tfstate.lock.info +terraform/*.tfplan diff --git a/burr/integrations/__init__.py b/burr/integrations/__init__.py index 13a83393a..956579056 100644 --- a/burr/integrations/__init__.py +++ b/burr/integrations/__init__.py @@ -14,3 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+
+
+def __getattr__(name: str):
+    """Lazy load Bedrock integration to avoid requiring boto3 unless used."""
+    if name == "BedrockAction":
+        from burr.integrations.bedrock import BedrockAction
+        return BedrockAction
+    if name == "BedrockStreamingAction":
+        from burr.integrations.bedrock import BedrockStreamingAction
+        return BedrockStreamingAction
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
diff --git a/burr/integrations/bedrock.py b/burr/integrations/bedrock.py
new file mode 100644
index 000000000..9a6353869
--- /dev/null
+++ b/burr/integrations/bedrock.py
@@ -0,0 +1,227 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Amazon Bedrock integration for Burr.
+
+BIP-0042: This module provides Action classes for invoking Amazon Bedrock models
+within Burr applications.
+
+Example usage:
+    from burr.integrations.bedrock import BedrockAction
+
+    def prompt_mapper(state):
+        return {
+            "messages": [{"role": "user", "content": [{"text": state["user_input"]}]}],
+            "system": [{"text": "You are a helpful assistant."}],
+        }
+
+    action = BedrockAction(
+        model_id="anthropic.claude-3-sonnet-20240229-v1:0",
+        input_mapper=prompt_mapper,
+        reads=["user_input"],
+        writes=["response"],
+    )
+"""
+
+import logging
+from typing import Any, Dict, Generator, List, Optional, Protocol, Tuple
+
+from burr.core.action import SingleStepAction, StreamingAction
+from burr.core.state import State
+from burr.integrations.base import require_plugin
+
+logger = logging.getLogger(__name__)
+
+try:
+    import boto3
+    from botocore.config import Config
+    from botocore.exceptions import ClientError
+except ImportError as e:
+    require_plugin(e, "bedrock")
+
+
+class StateToPromptMapper(Protocol):
+    """Protocol for mapping Burr state to Bedrock prompt format."""
+
+    def __call__(self, state: State) -> Dict[str, Any]:
+        ...
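Any callable with this shape satisfies the protocol. As a hedged illustration (the `chat_history` and `user_input` state keys are invented for this example, not part of the patch), a mapper can also replay prior turns in the Converse content-block format:

```python
from typing import Any, Dict

from burr.core.state import State


def history_aware_mapper(state: State) -> Dict[str, Any]:
    # Replay prior turns; Converse expects message content as a list of content blocks.
    messages = [
        {"role": turn["role"], "content": [{"text": turn["text"]}]}
        for turn in state.get("chat_history", [])
    ]
    messages.append({"role": "user", "content": [{"text": state["user_input"]}]})
    return {
        "messages": messages,
        "system": [{"text": "You are a helpful assistant."}],
    }
```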
+ + +class BedrockAction(SingleStepAction): + """Action that invokes Amazon Bedrock models using the Converse API.""" + + def __init__( + self, + model_id: str, + input_mapper: StateToPromptMapper, + reads: List[str], + writes: List[str], + name: str = "bedrock_invoke", + region: Optional[str] = None, + guardrail_id: Optional[str] = None, + guardrail_version: Optional[str] = None, + inference_config: Optional[Dict[str, Any]] = None, + max_retries: int = 3, + ): + super().__init__() + self._model_id = model_id + self._input_mapper = input_mapper + self._reads = reads + self._writes = writes + self._name = name + self._region = region + self._guardrail_id = guardrail_id + self._guardrail_version = guardrail_version or "DRAFT" + self._inference_config = inference_config or {"maxTokens": 4096} + + config = Config(retries={"max_attempts": max_retries, "mode": "adaptive"}) + self._client = boto3.client("bedrock-runtime", region_name=region, config=config) + + @property + def reads(self) -> List[str]: + return self._reads + + @property + def writes(self) -> List[str]: + return self._writes + + @property + def name(self) -> str: + return self._name + + def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: + prompt = self._input_mapper(state) + + request: Dict[str, Any] = { + "modelId": self._model_id, + "messages": prompt["messages"], + "inferenceConfig": self._inference_config, + } + + if "system" in prompt: + request["system"] = prompt["system"] + + if self._guardrail_id: + request["guardrailConfig"] = { + "guardrailIdentifier": self._guardrail_id, + "guardrailVersion": self._guardrail_version, + } + + try: + response = self._client.converse(**request) + except ClientError as e: + logger.error(f"Bedrock API error: {e}") + raise + + output_message = response["output"]["message"] + content_blocks = output_message.get("content", []) + text = content_blocks[0]["text"] if content_blocks else "" + + result: Dict[str, Any] = { + "response": text, + "usage": response.get("usage", {}), + "stop_reason": response.get("stopReason"), + } + + updates = {key: result[key] for key in self._writes if key in result} + new_state = state.update(**updates) + + return result, new_state + + +class BedrockStreamingAction(StreamingAction): + """Streaming variant of BedrockAction using Converse Stream API.""" + + def __init__( + self, + model_id: str, + input_mapper: StateToPromptMapper, + reads: List[str], + writes: List[str], + name: str = "bedrock_stream", + region: Optional[str] = None, + guardrail_id: Optional[str] = None, + guardrail_version: Optional[str] = None, + inference_config: Optional[Dict[str, Any]] = None, + max_retries: int = 3, + ): + super().__init__() + self._model_id = model_id + self._input_mapper = input_mapper + self._reads = reads + self._writes = writes + self._name = name + self._region = region + self._guardrail_id = guardrail_id + self._guardrail_version = guardrail_version or "DRAFT" + self._inference_config = inference_config or {"maxTokens": 4096} + + config = Config(retries={"max_attempts": max_retries, "mode": "adaptive"}) + self._client = boto3.client("bedrock-runtime", region_name=region, config=config) + + @property + def reads(self) -> List[str]: + return self._reads + + @property + def writes(self) -> List[str]: + return self._writes + + @property + def name(self) -> str: + return self._name + + def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, None]: + prompt = self._input_mapper(state) + + request: Dict[str, Any] = { + "modelId": 
self._model_id, + "messages": prompt["messages"], + "inferenceConfig": self._inference_config, + } + + if "system" in prompt: + request["system"] = prompt["system"] + + if self._guardrail_id: + request["guardrailConfig"] = { + "guardrailIdentifier": self._guardrail_id, + "guardrailVersion": self._guardrail_version, + } + + try: + response = self._client.converse_stream(**request) + except ClientError as e: + logger.error(f"Bedrock streaming API error: {e}") + raise + + full_response = "" + stream = response.get("stream", []) + for event in stream: + if "contentBlockDelta" in event: + chunk = event["contentBlockDelta"]["delta"].get("text", "") + full_response += chunk + yield {"chunk": chunk, "response": full_response} + + yield {"chunk": "", "response": full_response, "complete": True} + + def update(self, result: dict, state: State) -> State: + if result.get("complete"): + updates = {"response": result.get("response", "")} + filtered = {k: v for k, v in updates.items() if k in self._writes} + return state.update(**filtered) + return state diff --git a/burr/tracking/__init__.py b/burr/tracking/__init__.py index bc2581ee6..a62e68fb1 100644 --- a/burr/tracking/__init__.py +++ b/burr/tracking/__init__.py @@ -16,5 +16,6 @@ # under the License. from .client import LocalTrackingClient +from .s3client import S3TrackingClient -__all__ = ["LocalTrackingClient"] +__all__ = ["LocalTrackingClient", "S3TrackingClient"] diff --git a/burr/tracking/server/backend.py b/burr/tracking/server/backend.py index e33cab9b8..bf422c369 100644 --- a/burr/tracking/server/backend.py +++ b/burr/tracking/server/backend.py @@ -162,6 +162,31 @@ def snapshot_interval_milliseconds(self) -> Optional[int]: pass +class EventDrivenBackendMixin(abc.ABC): + """Mixin for backends that support event-driven updates via SQS. + + BIP-0042: This mixin enables backends to receive real-time notifications + from SQS instead of polling S3 for new files. + """ + + @abc.abstractmethod + async def start_sqs_consumer(self): + """Start the SQS consumer for event-driven tracking. + + This method should run indefinitely, processing S3 event notifications + from the configured SQS queue. + """ + pass + + @abc.abstractmethod + def is_event_driven(self) -> bool: + """Check if this backend is configured for event-driven updates. + + :return: True if SQS mode is enabled and configured, False otherwise + """ + pass + + class BackendBase(abc.ABC): async def lifespan(self, app: FastAPI): """Quick tool to allow plugin to the app's lifecycle. 
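To show how the pieces above fit together, here is a minimal end-to-end sketch of wiring `BedrockAction` into a Burr application. It assumes boto3 is installed and AWS credentials/region are configured; the model ID, prompt, and state keys are illustrative, not mandated by the patch:

```python
from burr.core import ApplicationBuilder
from burr.integrations.bedrock import BedrockAction


def prompt_mapper(state) -> dict:
    # Converse expects message content as a list of content blocks.
    return {"messages": [{"role": "user", "content": [{"text": state["user_input"]}]}]}


app = (
    ApplicationBuilder()
    .with_actions(
        BedrockAction(
            model_id="anthropic.claude-3-sonnet-20240229-v1:0",  # illustrative
            input_mapper=prompt_mapper,
            reads=["user_input"],
            writes=["response"],
        )
    )
    .with_state(user_input="Summarize BIP-0042 in one sentence.")
    .with_entrypoint("bedrock_invoke")  # the action's default name
    .build()
)

_, result, state = app.run(halt_after=["bedrock_invoke"])
print(state["response"])
```

Because `run_and_update` filters its result dict against `writes`, only keys you declare in `writes` (here, `response`) land in state; `usage` and `stop_reason` are available in `result` without polluting state.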
diff --git a/burr/tracking/server/run.py b/burr/tracking/server/run.py index 0e5ce62b0..cc9627006 100644 --- a/burr/tracking/server/run.py +++ b/burr/tracking/server/run.py @@ -29,6 +29,7 @@ from burr.tracking.server.backend import ( AnnotationsBackendMixin, BackendBase, + EventDrivenBackendMixin, IndexingBackendMixin, SnapshottingBackendMixin, ) @@ -128,6 +129,8 @@ async def save_snapshot(): @asynccontextmanager async def lifespan(app: FastAPI): + import asyncio + # Download if it does it # For now we do this before the lifespan await download_snapshot() @@ -135,6 +138,9 @@ async def lifespan(app: FastAPI): await backend.lifespan(app).__anext__() await sync_index() # this will trigger the repeat every N seconds await save_snapshot() # this will trigger the repeat every N seconds + # BIP-0042: Start SQS consumer for event-driven tracking when configured + if isinstance(backend, EventDrivenBackendMixin) and backend.is_event_driven(): + asyncio.create_task(backend.start_sqs_consumer()) global initialized initialized = True yield @@ -172,12 +178,18 @@ def get_app_spec(): logger = logging.getLogger(__name__) if app_spec.indexing: - update_interval = backend.update_interval_milliseconds() / 1000 if app_spec.indexing else None - sync_index = repeat_every( - seconds=backend.update_interval_milliseconds() / 1000, - wait_first=True, - logger=logger, - )(sync_index) + # BIP-0042: Only use polling when not in event-driven (SQS) mode + if not ( + isinstance(backend, EventDrivenBackendMixin) and backend.is_event_driven() + ): + update_interval = ( + backend.update_interval_milliseconds() / 1000 if app_spec.indexing else None + ) + sync_index = repeat_every( + seconds=backend.update_interval_milliseconds() / 1000, + wait_first=True, + logger=logger, + )(sync_index) if app_spec.snapshotting: snapshot_interval = ( diff --git a/burr/tracking/server/s3/backend.py b/burr/tracking/server/s3/backend.py index 7f48f88ad..d72ecf0f3 100644 --- a/burr/tracking/server/s3/backend.py +++ b/burr/tracking/server/s3/backend.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +import asyncio import dataclasses import datetime import functools @@ -23,6 +24,7 @@ import logging import operator import os.path +import tempfile import uuid from collections import Counter from typing import List, Literal, Optional, Sequence, Tuple, Type, TypeVar, Union @@ -42,6 +44,7 @@ from burr.tracking.server.backend import ( BackendBase, BurrSettings, + EventDrivenBackendMixin, IndexingBackendMixin, SnapshottingBackendMixin, ) @@ -67,10 +70,33 @@ async def _query_s3_file( bucket: str, key: str, client: session.AioBaseClient, -) -> Union[ContentsModel, List[ContentsModel]]: + buffer_size_mb: int = 10, +) -> bytes: + """Query S3 file with buffering to handle large files. + + BIP-0042: Uses SpooledTemporaryFile to buffer content, spilling to disk + if the file exceeds buffer_size_mb. This ensures the returned bytes object + is seekable for pickle/json deserialization, fixing the UnsupportedOperation + error on large state files. 
+ + :param bucket: S3 bucket name + :param key: S3 object key + :param client: aiobotocore S3 client + :param buffer_size_mb: Max MB to hold in RAM before spilling to disk (default 10MB) + :return: File contents as bytes + """ response = await client.get_object(Bucket=bucket, Key=key) - body = await response["Body"].read() - return body + buffer_size = buffer_size_mb * 1024 * 1024 + + with tempfile.SpooledTemporaryFile(max_size=buffer_size, mode="w+b") as tmp: + async with response["Body"] as stream: + while True: + chunk = await stream.read(8192) + if not chunk: + break + tmp.write(chunk) + tmp.seek(0) + return tmp.read() @dataclasses.dataclass @@ -140,6 +166,12 @@ class S3Settings(BurrSettings): snapshot_interval_milliseconds: int = 3_600_000 load_snapshot_on_start: bool = True prior_snapshots_to_keep: int = 5 + # BIP-0042: Event-driven tracking settings + tracking_mode: str = "POLLING" # "POLLING" or "SQS" - POLLING is default for backward compatibility + sqs_queue_url: Optional[str] = None + sqs_region: Optional[str] = None + sqs_wait_time_seconds: int = 20 # SQS long polling timeout + s3_buffer_size_mb: int = 10 # RAM buffer before spilling to disk def timestamp_to_reverse_alphabetical(timestamp: datetime) -> str: @@ -156,7 +188,7 @@ def timestamp_to_reverse_alphabetical(timestamp: datetime) -> str: return inverted_str + "-" + timestamp.isoformat() -class SQLiteS3Backend(BackendBase, IndexingBackendMixin, SnapshottingBackendMixin): +class SQLiteS3Backend(BackendBase, IndexingBackendMixin, SnapshottingBackendMixin, EventDrivenBackendMixin): def __init__( self, s3_bucket: str, @@ -165,6 +197,12 @@ def __init__( snapshot_interval_milliseconds: int, load_snapshot_on_start: bool, prior_snapshots_to_keep: int, + # BIP-0042: New parameters for event-driven tracking + tracking_mode: str = "POLLING", + sqs_queue_url: Optional[str] = None, + sqs_region: Optional[str] = None, + sqs_wait_time_seconds: int = 20, + s3_buffer_size_mb: int = 10, ): self._backend_id = system.now().isoformat() + str(uuid.uuid4()) self._bucket = s3_bucket @@ -177,6 +215,12 @@ def __init__( self._load_snapshot_on_start = load_snapshot_on_start self._snapshot_key_history = [] self._prior_snapshots_to_keep = prior_snapshots_to_keep + # BIP-0042: Store event-driven tracking settings + self._tracking_mode = tracking_mode + self._sqs_queue_url = sqs_queue_url + self._sqs_region = sqs_region + self._sqs_wait_time_seconds = sqs_wait_time_seconds + self._s3_buffer_size_mb = s3_buffer_size_mb async def load_snapshot(self): if not self._load_snapshot_on_start: @@ -631,13 +675,22 @@ async def get_application_logs( "-created_at" ) async with self._session.create_client("s3") as client: - # Get all the files + # Get all the files (BIP-0042: use buffered reading for large files) files = await utils.gather_with_concurrency( 1, - _query_s3_file(self._bucket, application.graph_file_pointer, client), - # _query_s3_files(self.bucket, application.metadata_file_pointer, client), + _query_s3_file( + self._bucket, + application.graph_file_pointer, + client, + self._s3_buffer_size_mb, + ), *itertools.chain( - _query_s3_file(self._bucket, log_file.s3_path, client) + _query_s3_file( + self._bucket, + log_file.s3_path, + client, + self._s3_buffer_size_mb, + ) for log_file in application_logs ), ) @@ -656,6 +709,104 @@ async def get_application_logs( application=graph_data, ) + # BIP-0042: Event-driven tracking methods + async def _handle_s3_event(self, s3_key: str, event_time: datetime.datetime) -> None: + """Handle a single S3 event notification 
- index the file immediately. + + :param s3_key: The S3 object key from the event + :param event_time: When the event occurred + """ + try: + data_file = DataFile.from_path(s3_key, created_date=event_time) + # Path structure: data/{project}/yyyy/mm/dd/hh/minutes/pk/app_id/filename + project_name = s3_key.split("/")[1] + + project = await Project.filter(name=project_name).first() + if project is None: + logger.info(f"Creating project {project_name} from S3 event") + project = await Project.create( + name=project_name, + uri=None, + created_at=event_time, + indexed_at=event_time, + updated_at=event_time, + ) + + all_applications = await self._ensure_applications_exist([data_file], project) + await self._update_all_applications(all_applications, [data_file]) + await self.update_log_files([data_file], all_applications) + + logger.info(f"Indexed S3 event: {s3_key}") + except Exception as e: + logger.error(f"Failed to handle S3 event {s3_key}: {e}") + + async def start_sqs_consumer(self) -> None: + """Start the SQS consumer for event-driven tracking. + + BIP-0042: This method runs indefinitely, processing S3 event notifications + from the configured SQS queue. It handles both EventBridge and direct S3 + notification formats. + """ + if self._tracking_mode != "SQS" or not self._sqs_queue_url: + logger.info("SQS consumer not configured, skipping") + return + + logger.info(f"Starting SQS consumer for queue: {self._sqs_queue_url}") + + async with self._session.create_client("sqs", region_name=self._sqs_region) as sqs_client: + while True: + try: + response = await sqs_client.receive_message( + QueueUrl=self._sqs_queue_url, + MaxNumberOfMessages=10, + WaitTimeSeconds=self._sqs_wait_time_seconds, + VisibilityTimeout=300, + ) + + messages = response.get("Messages", []) + for message in messages: + try: + body = json.loads(message["Body"]) + s3_key = None + event_time = None + + # Handle EventBridge wrapped S3 events + if "detail" in body: + s3_key = body["detail"]["object"]["key"] + event_time = datetime.datetime.fromisoformat( + body["time"].replace("Z", "+00:00") + ) + elif "Records" in body: + record = body["Records"][0] + s3_key = record["s3"]["object"]["key"] + event_time = datetime.datetime.fromisoformat( + record["eventTime"].replace("Z", "+00:00") + ) + else: + logger.warning(f"Unknown message format: {body}") + continue + + if s3_key and s3_key.endswith(".jsonl"): + await self._handle_s3_event(s3_key, event_time) + + await sqs_client.delete_message( + QueueUrl=self._sqs_queue_url, + ReceiptHandle=message["ReceiptHandle"], + ) + except Exception as e: + logger.error(f"Failed to process SQS message: {e}") + + except Exception as e: + logger.error(f"SQS consumer error: {e}") + await asyncio.sleep(5) + + def is_event_driven(self) -> bool: + """Check if this backend is configured for event-driven updates. + + BIP-0042: Returns True if tracking_mode is SQS and queue URL is configured. 
+ """ + return self._tracking_mode == "SQS" and self._sqs_queue_url is not None + async def indexing_jobs( self, offset: int = 0, limit: int = 100, filter_empty: bool = True ) -> Sequence[schema.IndexingJob]: diff --git a/terraform/.gitignore b/terraform/.gitignore new file mode 100644 index 000000000..5c111d166 --- /dev/null +++ b/terraform/.gitignore @@ -0,0 +1,10 @@ +# Terraform +.terraform/ +*.tfstate +*.tfstate.* +.terraform.tfstate.lock.info +*.tfplan +crash.log +override.tf +override.tf.json +*.tfvars.backup diff --git a/terraform/.terraform.lock.hcl b/terraform/.terraform.lock.hcl new file mode 100644 index 000000000..abbb80149 --- /dev/null +++ b/terraform/.terraform.lock.hcl @@ -0,0 +1,25 @@ +# This file is maintained automatically by "terraform init". +# Manual edits may be lost in future updates. + +provider "registry.terraform.io/hashicorp/aws" { + version = "6.34.0" + constraints = ">= 5.0.0" + hashes = [ + "h1:Qzr5C24XLiHmkJVuao/Kb+jFLPaxGE/D5GUgko5VdWg=", + "zh:1e49dc96bf50633583e3cbe23bb357642e7e9afe135f54e061e26af6310e50d2", + "zh:45651bb4dad681f17782d99d9324de182a7bb9fbe9dd22f120fdb7fe42969cc9", + "zh:5880c306a427128124585b460c53bbcab9fb3767f26f796eae204f65f111a927", + "zh:71fa9170989b3a1a6913c369bd4a792f4a3e2aab4024c2aff0911e704020b058", + "zh:8d48628fb30f11b04215e06f4dd8a3b32f5f9ea2ed116d0c81c686bf678f9185", + "zh:9b12af85486a96aedd8d7984b0ff811a4b42e3d88dad1a3fb4c0b580d04fa425", + "zh:a6885766588fcad887bdac8c3665e048480eda028e492759a1ea29d22b98d509", + "zh:a6ce9f5e7edc2258733e978bba147600b42a979e18575ce2c7d7dcb6d0b9911f", + "zh:c88d8b7d344e745b191509c29ca773d696da8ca3443f62b20f97982d2d33ea00", + "zh:cae90d6641728ad0219b6a84746bf86dd1dda3e31560d6495a202213ef0258b6", + "zh:cc35927d9d41878049c4221beb1d580a3dbadaca7ba39fb267e001ef9c59ccb3", + "zh:d9e1cb00dc33998e1242fb844e4e3e6cf95e57c664dc1eb55bb7d24f8324bad3", + "zh:f3dbf4a1b7020722145312eb4425f3ea356276d741e3f60fb703fc59a1e2d9fd", + "zh:faba832cc9d99a83f42aaf5a27a4c7309401200169ef04643104cfc8f522d007", + "zh:fcd3f30b91dbcc7db67d5d39268741ffa46696a230a1f2aef32d245ace54bf65", + ] +} diff --git a/terraform/dev.tfvars b/terraform/dev.tfvars new file mode 100644 index 000000000..b35fd2a92 --- /dev/null +++ b/terraform/dev.tfvars @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# Development environment configuration +# Replace ACCOUNT_ID with your AWS account ID for s3_bucket_name + +aws_region = "us-east-1" +environment = "dev" + +s3_bucket_name = "burr-tracking-logs-dev-ACCOUNT_ID" +sqs_queue_name = "burr-s3-events-dev" + +# S3 only (polling mode) - simpler for dev; set to true for event-driven +enable_sqs = false + +log_retention_days = 30 +snapshot_retention_days = 14 + +sqs_message_retention_seconds = 86400 +sqs_visibility_timeout_seconds = 120 +sqs_receive_wait_time_seconds = 20 +sqs_max_receive_count = 3 + +enable_bedrock = false diff --git a/terraform/main.tf b/terraform/main.tf new file mode 100644 index 000000000..9ca90dba8 --- /dev/null +++ b/terraform/main.tf @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +terraform { + required_version = ">= 1.0" + + required_providers { + aws = { + source = "hashicorp/aws" + version = ">= 5.0" + } + } +} + +provider "aws" { + region = var.aws_region +} + +data "aws_caller_identity" "current" {} +data "aws_region" "current" {} + +module "s3" { + source = "./modules/s3" + + bucket_name = var.s3_bucket_name + tags = local.common_tags + + lifecycle_rules = [ + { + id = "expire-old-logs" + prefix = "data/" + enabled = true + expiration_days = var.log_retention_days + noncurrent_days = 7 + }, + { + id = "expire-old-snapshots" + prefix = "snapshots/" + enabled = true + expiration_days = var.snapshot_retention_days + noncurrent_days = null + } + ] +} + +module "sqs" { + source = "./modules/sqs" + count = var.enable_sqs ? 1 : 0 + + queue_name = var.sqs_queue_name + message_retention_seconds = var.sqs_message_retention_seconds + visibility_timeout_seconds = var.sqs_visibility_timeout_seconds + receive_wait_time_seconds = var.sqs_receive_wait_time_seconds + max_receive_count = var.sqs_max_receive_count + tags = local.common_tags +} + +resource "aws_sqs_queue_policy" "s3_notifications" { + count = var.enable_sqs ? 1 : 0 + + queue_url = module.sqs[0].queue_id + + policy = jsonencode({ + Version = "2012-10-17" + Statement = [ + { + Sid = "AllowS3Notifications" + Effect = "Allow" + Principal = { + Service = "s3.amazonaws.com" + } + Action = "sqs:SendMessage" + Resource = module.sqs[0].queue_arn + Condition = { + ArnLike = { + "aws:SourceArn" = module.s3.bucket_arn + } + } + } + ] + }) +} + +resource "aws_s3_bucket_notification" "burr_logs" { + count = var.enable_sqs ? 1 : 0 + + bucket = module.s3.bucket_id + + queue { + queue_arn = module.sqs[0].queue_arn + events = ["s3:ObjectCreated:*"] + filter_prefix = "data/" + filter_suffix = ".jsonl" + } + + depends_on = [aws_sqs_queue_policy.s3_notifications] +} + +module "iam" { + source = "./modules/iam" + + role_name = "${var.environment}-burr-server-role" + s3_bucket_arn = module.s3.bucket_arn + sqs_queue_arn = var.enable_sqs ? 
module.sqs[0].queue_arn : "" + enable_sqs = var.enable_sqs + enable_bedrock = var.enable_bedrock + bedrock_model_arns = var.bedrock_model_arns + bedrock_foundation_model_arn = "arn:aws:bedrock:${data.aws_region.current.name}::foundation-model/*" + tags = local.common_tags +} + +locals { + common_tags = { + Environment = var.environment + Project = "burr-tracking" + ManagedBy = "terraform" + } +} diff --git a/terraform/modules/iam/main.tf b/terraform/modules/iam/main.tf new file mode 100644 index 000000000..40d7943c9 --- /dev/null +++ b/terraform/modules/iam/main.tf @@ -0,0 +1,109 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +data "aws_iam_policy_document" "assume_role" { + statement { + effect = "Allow" + actions = ["sts:AssumeRole"] + principals { + type = "Service" + identifiers = var.trusted_services + } + } +} + +resource "aws_iam_role" "burr_server" { + name = var.role_name + assume_role_policy = data.aws_iam_policy_document.assume_role.json + + tags = merge(var.tags, { + Name = var.role_name + }) +} + +data "aws_iam_policy_document" "s3_least_privilege" { + statement { + sid = "S3ListBucket" + effect = "Allow" + actions = [ + "s3:ListBucket", + "s3:GetBucketLocation" + ] + resources = [var.s3_bucket_arn] + } + + statement { + sid = "S3ObjectOperations" + effect = "Allow" + actions = [ + "s3:GetObject", + "s3:PutObject", + "s3:DeleteObject", + "s3:HeadObject" + ] + resources = ["${var.s3_bucket_arn}/*"] + } +} + +resource "aws_iam_role_policy" "s3" { + name = "${var.role_name}-s3" + role = aws_iam_role.burr_server.id + policy = data.aws_iam_policy_document.s3_least_privilege.json +} + +data "aws_iam_policy_document" "sqs_least_privilege" { + count = var.enable_sqs && var.sqs_queue_arn != "" ? 1 : 0 + + statement { + sid = "SQSConsume" + effect = "Allow" + actions = [ + "sqs:ReceiveMessage", + "sqs:DeleteMessage", + "sqs:GetQueueAttributes" + ] + resources = [var.sqs_queue_arn] + } +} + +resource "aws_iam_role_policy" "sqs" { + count = var.enable_sqs && var.sqs_queue_arn != "" ? 1 : 0 + name = "${var.role_name}-sqs" + role = aws_iam_role.burr_server.id + policy = data.aws_iam_policy_document.sqs_least_privilege[0].json +} + +data "aws_iam_policy_document" "bedrock_least_privilege" { + count = var.enable_bedrock ? 1 : 0 + + statement { + sid = "BedrockInvokeModels" + effect = "Allow" + actions = [ + "bedrock:InvokeModel", + "bedrock:InvokeModelWithResponseStream" + ] + resources = length(var.bedrock_model_arns) > 0 ? var.bedrock_model_arns : [var.bedrock_foundation_model_arn] + } +} + +resource "aws_iam_role_policy" "bedrock" { + count = var.enable_bedrock ? 
1 : 0 + name = "${var.role_name}-bedrock" + role = aws_iam_role.burr_server.id + policy = data.aws_iam_policy_document.bedrock_least_privilege[0].json +} diff --git a/terraform/modules/iam/outputs.tf b/terraform/modules/iam/outputs.tf new file mode 100644 index 000000000..ccf3003e6 --- /dev/null +++ b/terraform/modules/iam/outputs.tf @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +output "role_arn" { + description = "ARN of the IAM role" + value = aws_iam_role.burr_server.arn +} + +output "role_name" { + description = "Name of the IAM role" + value = aws_iam_role.burr_server.name +} diff --git a/terraform/modules/iam/variables.tf b/terraform/modules/iam/variables.tf new file mode 100644 index 000000000..304680676 --- /dev/null +++ b/terraform/modules/iam/variables.tf @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +variable "role_name" { + description = "Name of the IAM role for Burr server" + type = string +} + +variable "trusted_services" { + description = "List of AWS services that can assume this role" + type = list(string) + default = ["ecs-tasks.amazonaws.com", "ec2.amazonaws.com", "lambda.amazonaws.com"] +} + +variable "s3_bucket_arn" { + description = "ARN of the S3 bucket for least privilege access" + type = string +} + +variable "enable_sqs" { + description = "Enable SQS IAM permissions" + type = bool + default = true +} + +variable "sqs_queue_arn" { + description = "ARN of the SQS queue for least privilege access" + type = string + default = "" +} + +variable "enable_bedrock" { + description = "Enable Bedrock IAM permissions" + type = bool + default = false +} + +variable "bedrock_model_arns" { + description = "List of specific Bedrock model ARNs for least privilege. If empty, uses foundation model wildcard." 
+ type = list(string) + default = [] +} + +variable "bedrock_foundation_model_arn" { + description = "Bedrock foundation model ARN wildcard when bedrock_model_arns is empty" + type = string +} + +variable "tags" { + description = "Tags to apply to resources" + type = map(string) + default = {} +} diff --git a/terraform/modules/s3/main.tf b/terraform/modules/s3/main.tf new file mode 100644 index 000000000..67163ee09 --- /dev/null +++ b/terraform/modules/s3/main.tf @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +resource "aws_s3_bucket" "this" { + bucket = var.bucket_name + + tags = merge(var.tags, { + Name = var.bucket_name + }) +} + +resource "aws_s3_bucket_versioning" "this" { + bucket = aws_s3_bucket.this.id + + versioning_configuration { + status = "Enabled" + } +} + +resource "aws_s3_bucket_server_side_encryption_configuration" "this" { + bucket = aws_s3_bucket.this.id + + rule { + apply_server_side_encryption_by_default { + sse_algorithm = "AES256" + } + } +} + +resource "aws_s3_bucket_lifecycle_configuration" "this" { + bucket = aws_s3_bucket.this.id + + dynamic "rule" { + for_each = var.lifecycle_rules + content { + id = rule.value.id + status = rule.value.enabled ? "Enabled" : "Disabled" + + filter { + prefix = rule.value.prefix + } + + expiration { + days = rule.value.expiration_days + } + + dynamic "noncurrent_version_expiration" { + for_each = try(rule.value.noncurrent_days, null) != null ? [1] : [] + content { + noncurrent_days = rule.value.noncurrent_days + } + } + } + } +} + +resource "aws_s3_bucket_public_access_block" "this" { + bucket = aws_s3_bucket.this.id + + block_public_acls = true + block_public_policy = true + ignore_public_acls = true + restrict_public_buckets = true +} diff --git a/terraform/modules/s3/outputs.tf b/terraform/modules/s3/outputs.tf new file mode 100644 index 000000000..5ffc964b6 --- /dev/null +++ b/terraform/modules/s3/outputs.tf @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +output "bucket_id" { + description = "ID of the S3 bucket" + value = aws_s3_bucket.this.id +} + +output "bucket_arn" { + description = "ARN of the S3 bucket" + value = aws_s3_bucket.this.arn +} diff --git a/terraform/modules/s3/variables.tf b/terraform/modules/s3/variables.tf new file mode 100644 index 000000000..580cc967d --- /dev/null +++ b/terraform/modules/s3/variables.tf @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +variable "bucket_name" { + description = "Name of the S3 bucket" + type = string +} + +variable "lifecycle_rules" { + description = "List of lifecycle rules for the bucket" + type = list(object({ + id = string + prefix = string + enabled = bool + expiration_days = number + noncurrent_days = optional(number) + })) +} + +variable "tags" { + description = "Tags to apply to resources" + type = map(string) + default = {} +} diff --git a/terraform/modules/sqs/main.tf b/terraform/modules/sqs/main.tf new file mode 100644 index 000000000..4eb3bea9d --- /dev/null +++ b/terraform/modules/sqs/main.tf @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +resource "aws_sqs_queue" "main" { + name = var.queue_name + message_retention_seconds = var.message_retention_seconds + visibility_timeout_seconds = var.visibility_timeout_seconds + receive_wait_time_seconds = var.receive_wait_time_seconds + + tags = merge(var.tags, { + Name = var.queue_name + }) +} + +resource "aws_sqs_queue" "dlq" { + name = "${var.queue_name}-dlq" + message_retention_seconds = var.dlq_message_retention_seconds + + tags = merge(var.tags, { + Name = "${var.queue_name}-dlq" + }) +} + +resource "aws_sqs_queue_redrive_policy" "main" { + queue_url = aws_sqs_queue.main.id + redrive_policy = jsonencode({ + deadLetterTargetArn = aws_sqs_queue.dlq.arn + maxReceiveCount = var.max_receive_count + }) +} diff --git a/terraform/modules/sqs/outputs.tf b/terraform/modules/sqs/outputs.tf new file mode 100644 index 000000000..ef5c86e88 --- /dev/null +++ b/terraform/modules/sqs/outputs.tf @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +output "queue_id" { + description = "URL of the SQS queue" + value = aws_sqs_queue.main.id +} + +output "queue_url" { + description = "URL of the SQS queue" + value = aws_sqs_queue.main.url +} + +output "queue_arn" { + description = "ARN of the SQS queue" + value = aws_sqs_queue.main.arn +} + +output "dlq_url" { + description = "URL of the dead letter queue" + value = aws_sqs_queue.dlq.url +} + +output "dlq_arn" { + description = "ARN of the dead letter queue" + value = aws_sqs_queue.dlq.arn +} diff --git a/terraform/modules/sqs/variables.tf b/terraform/modules/sqs/variables.tf new file mode 100644 index 000000000..47e67f3ba --- /dev/null +++ b/terraform/modules/sqs/variables.tf @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +variable "queue_name" { + description = "Name of the SQS queue" + type = string +} + +variable "message_retention_seconds" { + description = "Message retention period in seconds" + type = number + default = 1209600 +} + +variable "visibility_timeout_seconds" { + description = "Visibility timeout for messages in seconds" + type = number + default = 300 +} + +variable "receive_wait_time_seconds" { + description = "Long polling wait time in seconds" + type = number + default = 20 +} + +variable "dlq_message_retention_seconds" { + description = "DLQ message retention period in seconds" + type = number + default = 1209600 +} + +variable "max_receive_count" { + description = "Max receive count before message moves to DLQ" + type = number + default = 3 +} + +variable "tags" { + description = "Tags to apply to resources" + type = map(string) + default = {} +} diff --git a/terraform/outputs.tf b/terraform/outputs.tf new file mode 100644 index 000000000..a14b0ef99 --- /dev/null +++ b/terraform/outputs.tf @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +output "s3_bucket_name" { + description = "Name of the S3 bucket for Burr logs" + value = module.s3.bucket_id +} + +output "s3_bucket_arn" { + description = "ARN of the S3 bucket" + value = module.s3.bucket_arn +} + +output "sqs_queue_url" { + description = "URL of the SQS queue for S3 events" + value = var.enable_sqs ? module.sqs[0].queue_url : null +} + +output "sqs_queue_arn" { + description = "ARN of the SQS queue" + value = var.enable_sqs ? module.sqs[0].queue_arn : null +} + +output "sqs_dlq_url" { + description = "URL of the dead letter queue" + value = var.enable_sqs ? module.sqs[0].dlq_url : null +} + +output "iam_role_arn" { + description = "ARN of the IAM role for Burr server" + value = module.iam.role_arn +} + +output "iam_role_name" { + description = "Name of the IAM role for Burr server" + value = module.iam.role_name +} + +output "burr_environment_variables" { + description = "Environment variables to configure Burr server" + value = var.enable_sqs ? { + BURR_S3_BUCKET = module.s3.bucket_id + BURR_TRACKING_MODE = "SQS" + BURR_SQS_QUEUE_URL = module.sqs[0].queue_url + BURR_SQS_REGION = data.aws_region.current.name + BURR_SQS_WAIT_TIME_SECONDS = "20" + BURR_S3_BUFFER_SIZE_MB = "10" + } : { + BURR_S3_BUCKET = module.s3.bucket_id + BURR_TRACKING_MODE = "POLLING" + BURR_SQS_QUEUE_URL = "" + BURR_SQS_REGION = data.aws_region.current.name + BURR_SQS_WAIT_TIME_SECONDS = "20" + BURR_S3_BUFFER_SIZE_MB = "10" + } +} diff --git a/terraform/prod.tfvars b/terraform/prod.tfvars new file mode 100644 index 000000000..9dc6410a5 --- /dev/null +++ b/terraform/prod.tfvars @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Production environment configuration +# Replace ACCOUNT_ID with your AWS account ID for s3_bucket_name + +aws_region = "us-east-1" +environment = "prod" + +s3_bucket_name = "burr-tracking-logs-prod-ACCOUNT_ID" +sqs_queue_name = "burr-s3-events-prod" + +enable_sqs = true + +log_retention_days = 90 +snapshot_retention_days = 30 + +sqs_message_retention_seconds = 1209600 +sqs_visibility_timeout_seconds = 300 +sqs_receive_wait_time_seconds = 20 +sqs_max_receive_count = 3 + +enable_bedrock = false diff --git a/terraform/tutorial.md b/terraform/tutorial.md new file mode 100644 index 000000000..90e389937 --- /dev/null +++ b/terraform/tutorial.md @@ -0,0 +1,194 @@ +# Apache Burr AWS Tracking Infrastructure Tutorial + +This tutorial explains how to deploy Apache Burr tracking infrastructure on AWS using Terraform. All Terraform code lives in the `terraform/` folder. It covers deployment with S3 only (polling mode), with S3 and SQS (event-driven mode), and local development without AWS. + +## Overview + +The Terraform configuration provisions: + +- **S3 bucket**: Stores Burr application logs and database snapshots (always created for AWS deployment) +- **SQS queue** (optional): Receives S3 event notifications for real-time tracking; controlled by `enable_sqs` +- **IAM role**: Least-privilege permissions for the Burr server + +## Directory Structure + +All code is in `terraform/`: + +``` +terraform/ +├── main.tf # Root module wiring S3, SQS, IAM +├── variables.tf # Input variables +├── outputs.tf # Output values +├── dev.tfvars # Development: S3 only (enable_sqs = false) +├── prod.tfvars # Production: S3 + SQS (enable_sqs = true) +├── tutorial.md # This file +└── modules/ + ├── s3/ # S3 bucket with versioning, encryption, lifecycle + ├── sqs/ # SQS queue with DLQ and redrive policy + └── iam/ # IAM role with least-privilege policies +``` + +## Prerequisites + +- Terraform >= 1.0 +- AWS CLI configured with credentials +- AWS account ID (for unique S3 bucket names) + +Get your AWS account ID: + +```bash +aws sts get-caller-identity --query Account --output text +``` + +## Using tfvars Files + +| File | Mode | enable_sqs | Resources created | +|-------------|-------------------|------------|--------------------------| +| dev.tfvars | S3 only (polling) | false | S3 bucket, IAM role | +| prod.tfvars | S3 + SQS (event) | true | S3 bucket, SQS queue, IAM | + +### Development (dev.tfvars) - S3 Only + +Uses S3 polling mode (no SQS). Edit and replace `ACCOUNT_ID` in `s3_bucket_name`: + +``` +s3_bucket_name = "burr-tracking-logs-dev-123456789012" +``` + +Deploy: + +```bash +cd terraform +terraform init +terraform plan -var-file=dev.tfvars +terraform apply -var-file=dev.tfvars +``` + +### Production (prod.tfvars) - S3 + SQS + +Uses event-driven mode with SQS. 
Edit and replace `ACCOUNT_ID`: + +``` +s3_bucket_name = "burr-tracking-logs-prod-123456789012" +``` + +Deploy: + +```bash +terraform plan -var-file=prod.tfvars +terraform apply -var-file=prod.tfvars +``` + +### Override Mode in Any tfvars + +To deploy with SQS using dev.tfvars, override: `terraform apply -var-file=dev.tfvars -var="enable_sqs=true"`. To deploy S3-only with prod.tfvars: `terraform apply -var-file=prod.tfvars -var="enable_sqs=false"`. + +## Deployment Modes + +### With S3 and SQS (Event-Driven Mode) + +Default configuration. Provides near-instant telemetry updates (~200ms latency). + +1. Set `enable_sqs = true` in your tfvars (e.g. prod.tfvars). +2. Deploy with `terraform apply -var-file=prod.tfvars`. +3. Configure the Burr server with the output environment variables: + +```bash +terraform output burr_environment_variables +``` + +4. Set these on your Burr server (ECS task, EC2, etc.): + +- BURR_S3_BUCKET +- BURR_TRACKING_MODE=SQS +- BURR_SQS_QUEUE_URL +- BURR_SQS_REGION +- BURR_SQS_WAIT_TIME_SECONDS +- BURR_S3_BUFFER_SIZE_MB + +### With S3 Only (Polling Mode) + +Use when you prefer simpler infrastructure or cannot use SQS. Burr polls S3 periodically (default 120 seconds). + +1. Set `enable_sqs = false` in your tfvars. +2. Deploy: + +```bash +terraform apply -var-file=dev.tfvars +``` + +3. Configure the Burr server: + +- BURR_S3_BUCKET +- BURR_TRACKING_MODE=POLLING +- BURR_SQS_QUEUE_URL="" (leave empty) +- BURR_SQS_REGION +- BURR_S3_BUFFER_SIZE_MB + +The Terraform will create only the S3 bucket and IAM role. No SQS queue or S3 event notifications. + +### Without S3 and SQS (Local Mode) + +For local development, no Terraform deployment is needed. Burr uses the local filesystem for tracking. + +1. Run the Burr server locally: + +```bash +burr --no-open +``` + +2. Use `LocalTrackingClient` in your application instead of `S3TrackingClient`. + +3. Data is stored in `~/.burr` by default. + +## Key Variables + +| Variable | Description | Default | +|----------|-------------|---------| +| aws_region | AWS region | us-east-1 | +| environment | Environment name (dev, prod) | dev | +| s3_bucket_name | S3 bucket name (must be globally unique) | (required) | +| enable_sqs | Create SQS for event-driven tracking | true | +| log_retention_days | Days to retain logs in S3 | 90 | +| snapshot_retention_days | Days to retain DB snapshots | 30 | +| enable_bedrock | Add Bedrock IAM permissions | false | + +## Outputs + +After apply, useful outputs: + +```bash +terraform output s3_bucket_name +terraform output sqs_queue_url +terraform output burr_environment_variables +``` + +## IAM Least Privilege + +The IAM role grants only: + +- **S3**: ListBucket, GetBucketLocation, GetObject, PutObject, DeleteObject, HeadObject on the specific bucket +- **SQS** (when enabled): ReceiveMessage, DeleteMessage, GetQueueAttributes on the specific queue +- **Bedrock** (when enabled): InvokeModel, InvokeModelWithResponseStream on specified model ARNs + +## Cleanup + +To destroy all resources: + +```bash +terraform destroy -var-file=dev.tfvars +``` + +For S3 buckets with versioning, you may need to empty the bucket first: + +```bash +aws s3api list-object-versions --bucket BUCKET_NAME --output json | jq -r '.Versions[],.DeleteMarkers[]|.Key+" "+.VersionId' | while read key vid; do aws s3api delete-object --bucket BUCKET_NAME --key "$key" --version-id "$vid"; done +``` + +## Troubleshooting + +**S3 bucket name already exists**: S3 bucket names are globally unique. Use your account ID or a random suffix. 
+ +**SQS policy errors**: Ensure the S3 bucket notification depends on the queue policy. The Terraform handles this with `depends_on`. + +**Burr server not receiving events**: Verify BURR_SQS_QUEUE_URL is set and the IAM role has sqs:ReceiveMessage. Check CloudWatch for the SQS consumer. diff --git a/terraform/variables.tf b/terraform/variables.tf new file mode 100644 index 000000000..29fd36aef --- /dev/null +++ b/terraform/variables.tf @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +variable "aws_region" { + description = "AWS region for resources" + type = string + default = "us-east-1" +} + +variable "environment" { + description = "Environment name (dev, staging, prod)" + type = string + default = "dev" +} + +variable "s3_bucket_name" { + description = "Name of the S3 bucket for Burr logs" + type = string +} + +variable "enable_sqs" { + description = "Enable SQS for event-driven tracking. When false, Burr uses S3 polling mode." + type = bool + default = true +} + +variable "sqs_queue_name" { + description = "Name of the SQS queue for S3 events" + type = string + default = "burr-s3-events" +} + +variable "log_retention_days" { + description = "Days to retain log files in S3" + type = number + default = 90 +} + +variable "snapshot_retention_days" { + description = "Days to retain database snapshots in S3" + type = number + default = 30 +} + +variable "sqs_message_retention_seconds" { + description = "SQS message retention period in seconds" + type = number + default = 1209600 +} + +variable "sqs_visibility_timeout_seconds" { + description = "SQS visibility timeout in seconds" + type = number + default = 300 +} + +variable "sqs_receive_wait_time_seconds" { + description = "SQS long polling wait time in seconds" + type = number + default = 20 +} + +variable "sqs_max_receive_count" { + description = "Max receive count before message moves to DLQ" + type = number + default = 3 +} + +variable "enable_bedrock" { + description = "Enable Bedrock IAM permissions for BedrockAction" + type = bool + default = false +} + +variable "bedrock_model_arns" { + description = "List of specific Bedrock model ARNs for least privilege. Empty uses foundation-model/*" + type = list(string) + default = [] +} diff --git a/tests/integrations/test_bip0042_bedrock.py b/tests/integrations/test_bip0042_bedrock.py new file mode 100644 index 000000000..45087e013 --- /dev/null +++ b/tests/integrations/test_bip0042_bedrock.py @@ -0,0 +1,168 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""BIP-0042: Tests for Bedrock integration.""" + +import inspect + +import pytest + + +class TestBedrockImports: + """Test that Bedrock classes can be imported via lazy loading.""" + + def test_lazy_import_bedrock_action(self): + """Verify BedrockAction can be imported from burr.integrations.""" + try: + from burr.integrations import BedrockAction + + assert BedrockAction is not None + except ImportError as e: + assert "bedrock" in str(e).lower() or "boto3" in str(e).lower() + + def test_lazy_import_bedrock_streaming_action(self): + """Verify BedrockStreamingAction can be imported from burr.integrations.""" + try: + from burr.integrations import BedrockStreamingAction + + assert BedrockStreamingAction is not None + except ImportError as e: + assert "bedrock" in str(e).lower() or "boto3" in str(e).lower() + + def test_direct_import_bedrock_module(self): + """Verify bedrock.py module exists and has expected classes.""" + try: + from burr.integrations.bedrock import ( + BedrockAction, + BedrockStreamingAction, + StateToPromptMapper, + ) + + assert BedrockAction is not None + assert BedrockStreamingAction is not None + assert StateToPromptMapper is not None + except ImportError as e: + assert "bedrock" in str(e).lower() or "boto3" in str(e).lower() + + +class TestBedrockActionInterface: + """Test BedrockAction class interface (without boto3).""" + + @pytest.fixture + def mock_boto3(self, monkeypatch): + """Mock boto3 to allow testing without AWS credentials.""" + import sys + from unittest.mock import MagicMock + + mock_boto = MagicMock() + mock_client = MagicMock() + mock_boto.client.return_value = mock_client + + mock_botocore = MagicMock() + mock_botocore.config.Config = MagicMock + mock_botocore.exceptions.ClientError = Exception + + monkeypatch.setitem(sys.modules, "boto3", mock_boto) + monkeypatch.setitem(sys.modules, "botocore", mock_botocore) + monkeypatch.setitem(sys.modules, "botocore.config", mock_botocore.config) + monkeypatch.setitem(sys.modules, "botocore.exceptions", mock_botocore.exceptions) + + return mock_boto, mock_client + + def test_bedrock_action_extends_single_step_action(self, mock_boto3): + """Verify BedrockAction extends SingleStepAction.""" + import importlib + + import burr.integrations.bedrock as bedrock_module + + importlib.reload(bedrock_module) + + from burr.core.action import SingleStepAction + from burr.integrations.bedrock import BedrockAction + + assert issubclass(BedrockAction, SingleStepAction) + + def test_bedrock_streaming_action_extends_streaming_action(self, mock_boto3): + """Verify BedrockStreamingAction extends StreamingAction.""" + import importlib + + import burr.integrations.bedrock as bedrock_module + + importlib.reload(bedrock_module) + + from burr.core.action import StreamingAction + from burr.integrations.bedrock import BedrockStreamingAction + + assert issubclass(BedrockStreamingAction, StreamingAction) + + def test_bedrock_action_has_required_properties(self, 
mock_boto3): + """Verify BedrockAction has reads, writes, name properties.""" + import importlib + + import burr.integrations.bedrock as bedrock_module + + importlib.reload(bedrock_module) + + from burr.integrations.bedrock import BedrockAction + + action = BedrockAction( + model_id="test-model", + input_mapper=lambda s: {"messages": []}, + reads=["input"], + writes=["output"], + ) + + assert action.reads == ["input"] + assert action.writes == ["output"] + assert action.name == "bedrock_invoke" + + def test_bedrock_action_accepts_all_parameters(self, mock_boto3): + """Verify BedrockAction accepts all BIP-0042 specified parameters.""" + import importlib + + import burr.integrations.bedrock as bedrock_module + + importlib.reload(bedrock_module) + + from burr.integrations.bedrock import BedrockAction + + sig = inspect.signature(BedrockAction.__init__) + params = list(sig.parameters.keys()) + + assert "model_id" in params + assert "input_mapper" in params + assert "reads" in params + assert "writes" in params + assert "name" in params + assert "region" in params + assert "guardrail_id" in params + assert "guardrail_version" in params + assert "inference_config" in params + assert "max_retries" in params + + +class TestStateToPromptMapperProtocol: + """Test StateToPromptMapper Protocol exists.""" + + def test_protocol_exists(self): + """Verify StateToPromptMapper Protocol is defined.""" + try: + from burr.integrations.bedrock import StateToPromptMapper + + assert StateToPromptMapper is not None + except ImportError: + pytest.skip("boto3 not installed") diff --git a/tests/tracking/test_bip0042_s3_buffering.py b/tests/tracking/test_bip0042_s3_buffering.py new file mode 100644 index 000000000..f445b0233 --- /dev/null +++ b/tests/tracking/test_bip0042_s3_buffering.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""BIP-0042: Tests for S3 buffering fix and settings.""" + +import inspect + +import pytest + + +class TestS3Settings: + """Test that S3Settings has all BIP-0042 fields with correct defaults.""" + + def test_s3_settings_has_tracking_mode(self): + """Verify tracking_mode field exists with POLLING default.""" + from burr.tracking.server.s3.backend import S3Settings + + assert "tracking_mode" in S3Settings.model_fields + assert S3Settings.model_fields["tracking_mode"].default == "POLLING" + + def test_s3_settings_has_sqs_queue_url(self): + """Verify sqs_queue_url field exists with None default.""" + from burr.tracking.server.s3.backend import S3Settings + + assert "sqs_queue_url" in S3Settings.model_fields + assert S3Settings.model_fields["sqs_queue_url"].default is None + + def test_s3_settings_has_sqs_region(self): + """Verify sqs_region field exists with None default.""" + from burr.tracking.server.s3.backend import S3Settings + + assert "sqs_region" in S3Settings.model_fields + assert S3Settings.model_fields["sqs_region"].default is None + + def test_s3_settings_has_sqs_wait_time_seconds(self): + """Verify sqs_wait_time_seconds field exists with 20 default.""" + from burr.tracking.server.s3.backend import S3Settings + + assert "sqs_wait_time_seconds" in S3Settings.model_fields + assert S3Settings.model_fields["sqs_wait_time_seconds"].default == 20 + + def test_s3_settings_has_s3_buffer_size_mb(self): + """Verify s3_buffer_size_mb field exists with 10 default.""" + from burr.tracking.server.s3.backend import S3Settings + + assert "s3_buffer_size_mb" in S3Settings.model_fields + assert S3Settings.model_fields["s3_buffer_size_mb"].default == 10 + + +class TestSQLiteS3BackendInit: + """Test SQLiteS3Backend accepts and stores BIP-0042 parameters.""" + + def test_backend_accepts_new_parameters(self): + """Verify __init__ accepts all 5 new BIP-0042 parameters.""" + from burr.tracking.server.s3.backend import SQLiteS3Backend + + sig = inspect.signature(SQLiteS3Backend.__init__) + params = list(sig.parameters.keys()) + + assert "tracking_mode" in params + assert "sqs_queue_url" in params + assert "sqs_region" in params + assert "sqs_wait_time_seconds" in params + assert "s3_buffer_size_mb" in params + + def test_backend_has_event_driven_methods(self): + """Verify SQLiteS3Backend has BIP-0042 event-driven methods.""" + from burr.tracking.server.s3.backend import SQLiteS3Backend + + assert hasattr(SQLiteS3Backend, "_handle_s3_event") + assert hasattr(SQLiteS3Backend, "start_sqs_consumer") + assert hasattr(SQLiteS3Backend, "is_event_driven") + assert callable(getattr(SQLiteS3Backend, "_handle_s3_event")) + assert callable(getattr(SQLiteS3Backend, "start_sqs_consumer")) + assert callable(getattr(SQLiteS3Backend, "is_event_driven")) + + +class TestEventDrivenBackendMixin: + """Test EventDrivenBackendMixin exists and has correct interface.""" + + def test_mixin_exists(self): + """Verify EventDrivenBackendMixin exists in backend.py.""" + from burr.tracking.server.backend import EventDrivenBackendMixin + + assert EventDrivenBackendMixin is not None + + def test_mixin_has_abstract_methods(self): + """Verify mixin has abstract start_sqs_consumer and is_event_driven.""" + import abc + + from burr.tracking.server.backend import EventDrivenBackendMixin + + assert issubclass(EventDrivenBackendMixin, abc.ABC) + assert hasattr(EventDrivenBackendMixin, "start_sqs_consumer") + assert hasattr(EventDrivenBackendMixin, "is_event_driven") + + def test_sqlite_s3_backend_inherits_mixin(self): + """Verify 
SQLiteS3Backend inherits from EventDrivenBackendMixin.""" + from burr.tracking.server.backend import EventDrivenBackendMixin + from burr.tracking.server.s3.backend import SQLiteS3Backend + + assert issubclass(SQLiteS3Backend, EventDrivenBackendMixin) + + +class TestQueryS3FileBuffering: + """Test _query_s3_file function signature includes buffer_size_mb.""" + + def test_query_s3_file_has_buffer_param(self): + """Verify _query_s3_file accepts buffer_size_mb parameter.""" + from burr.tracking.server.s3.backend import _query_s3_file + + sig = inspect.signature(_query_s3_file) + params = list(sig.parameters.keys()) + + assert "buffer_size_mb" in params + assert sig.parameters["buffer_size_mb"].default == 10 + + +class TestHandleS3Event: + """Test _handle_s3_event creates project if it doesn't exist.""" + + def test_handle_s3_event_method_exists(self): + """Verify _handle_s3_event method exists and is async.""" + from burr.tracking.server.s3.backend import SQLiteS3Backend + + assert hasattr(SQLiteS3Backend, "_handle_s3_event") + method = getattr(SQLiteS3Backend, "_handle_s3_event") + assert inspect.iscoroutinefunction(method) + + def test_handle_s3_event_signature(self): + """Verify _handle_s3_event accepts s3_key and event_time parameters.""" + from burr.tracking.server.s3.backend import SQLiteS3Backend + + sig = inspect.signature(SQLiteS3Backend._handle_s3_event) + params = list(sig.parameters.keys()) + + assert "self" in params + assert "s3_key" in params + assert "event_time" in params From 7687e0eecbb1780a5be2937f757b729a10b31e3f Mon Sep 17 00:00:00 2001 From: vaquarkhan Date: Sun, 1 Mar 2026 01:45:16 -0600 Subject: [PATCH 2/8] feat: BIP-0042 Cloud-Native Architecture for Apache Burr on AWS (#664) --- terraform/modules/iam/main.tf | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/terraform/modules/iam/main.tf b/terraform/modules/iam/main.tf index 40d7943c9..bb7e6a955 100644 --- a/terraform/modules/iam/main.tf +++ b/terraform/modules/iam/main.tf @@ -66,7 +66,7 @@ resource "aws_iam_role_policy" "s3" { } data "aws_iam_policy_document" "sqs_least_privilege" { - count = var.enable_sqs && var.sqs_queue_arn != "" ? 1 : 0 + count = var.enable_sqs ? 1 : 0 statement { sid = "SQSConsume" @@ -81,7 +81,7 @@ data "aws_iam_policy_document" "sqs_least_privilege" { } resource "aws_iam_role_policy" "sqs" { - count = var.enable_sqs && var.sqs_queue_arn != "" ? 1 : 0 + count = var.enable_sqs ? 
1 : 0 name = "${var.role_name}-sqs" role = aws_iam_role.burr_server.id policy = data.aws_iam_policy_document.sqs_least_privilege[0].json From 2fcaae023eb400b6a2686af805c7e643d1ce11e9 Mon Sep 17 00:00:00 2001 From: vaquarkhan Date: Sun, 1 Mar 2026 02:02:18 -0600 Subject: [PATCH 3/8] feat: BIP-0042 Cloud-Native Architecture for Apache Burr on AWS (#664) --- terraform/.terraform.lock.hcl | 20 +++++++++ terraform/dev.tfvars | 6 ++- terraform/main.tf | 56 ++++++++++++++++++++++++- terraform/modules/sqs/outputs.tf | 5 +++ terraform/outputs.tf | 10 +++++ terraform/prod.tfvars | 9 +++- terraform/tutorial.md | 70 ++++++++++++++++++++------------ terraform/variables.tf | 15 ++++++- 8 files changed, 160 insertions(+), 31 deletions(-) diff --git a/terraform/.terraform.lock.hcl b/terraform/.terraform.lock.hcl index abbb80149..ddfac895c 100644 --- a/terraform/.terraform.lock.hcl +++ b/terraform/.terraform.lock.hcl @@ -23,3 +23,23 @@ provider "registry.terraform.io/hashicorp/aws" { "zh:fcd3f30b91dbcc7db67d5d39268741ffa46696a230a1f2aef32d245ace54bf65", ] } + +provider "registry.terraform.io/hashicorp/random" { + version = "3.8.1" + constraints = ">= 3.0.0" + hashes = [ + "h1:osH3aBqEARwOz3VBJKdpFKJJCNIdgRC6k8vPojkLmlY=", + "zh:08dd03b918c7b55713026037c5400c48af5b9f468f483463321bd18e17b907b4", + "zh:0eee654a5542dc1d41920bbf2419032d6f0d5625b03bd81339e5b33394a3e0ae", + "zh:229665ddf060aa0ed315597908483eee5b818a17d09b6417a0f52fd9405c4f57", + "zh:2469d2e48f28076254a2a3fc327f184914566d9e40c5780b8d96ebf7205f8bc0", + "zh:37d7eb334d9561f335e748280f5535a384a88675af9a9eac439d4cfd663bcb66", + "zh:741101426a2f2c52dee37122f0f4a2f2d6af6d852cb1db634480a86398fa3511", + "zh:78d5eefdd9e494defcb3c68d282b8f96630502cac21d1ea161f53cfe9bb483b3", + "zh:a902473f08ef8df62cfe6116bd6c157070a93f66622384300de235a533e9d4a9", + "zh:b85c511a23e57a2147355932b3b6dce2a11e856b941165793a0c3d7578d94d05", + "zh:c5172226d18eaac95b1daac80172287b69d4ce32750c82ad77fa0768be4ea4b8", + "zh:dab4434dba34aad569b0bc243c2d3f3ff86dd7740def373f2a49816bd2ff819b", + "zh:f49fd62aa8c5525a5c17abd51e27ca5e213881d58882fd42fec4a545b53c9699", + ] +} diff --git a/terraform/dev.tfvars b/terraform/dev.tfvars index b35fd2a92..40c0bad98 100644 --- a/terraform/dev.tfvars +++ b/terraform/dev.tfvars @@ -16,12 +16,14 @@ # under the License. # Development environment configuration -# Replace ACCOUNT_ID with your AWS account ID for s3_bucket_name +# Bucket name is auto-generated: burr-tracking-{env}-{region}-{account_id}-{random} +# account_id: leave empty to auto-fetch from AWS credentials, or set explicitly aws_region = "us-east-1" environment = "dev" -s3_bucket_name = "burr-tracking-logs-dev-ACCOUNT_ID" +# account_id = "" # Optional. Empty = auto-fetch. Or set: account_id = "123456789012" + sqs_queue_name = "burr-s3-events-dev" # S3 only (polling mode) - simpler for dev; set to true for event-driven diff --git a/terraform/main.tf b/terraform/main.tf index 9ca90dba8..68cd3a852 100644 --- a/terraform/main.tf +++ b/terraform/main.tf @@ -23,6 +23,10 @@ terraform { source = "hashicorp/aws" version = ">= 5.0" } + random = { + source = "hashicorp/random" + version = ">= 3.0" + } } } @@ -33,10 +37,21 @@ provider "aws" { data "aws_caller_identity" "current" {} data "aws_region" "current" {} +resource "random_id" "bucket_suffix" { + byte_length = 4 +} + +locals { + region_short = replace(data.aws_region.current.name, "-", "") + account_id = var.account_id != "" ? 
var.account_id : data.aws_caller_identity.current.account_id + auto_bucket = "burr-tracking-${var.environment}-${local.region_short}-${local.account_id}-${random_id.bucket_suffix.hex}" + bucket_name = var.s3_bucket_name != "" ? var.s3_bucket_name : local.auto_bucket +} + module "s3" { source = "./modules/s3" - bucket_name = var.s3_bucket_name + bucket_name = local.bucket_name tags = local.common_tags lifecycle_rules = [ @@ -110,6 +125,45 @@ resource "aws_s3_bucket_notification" "burr_logs" { depends_on = [aws_sqs_queue_policy.s3_notifications] } +resource "aws_sns_topic" "dlq_alarm" { + count = var.enable_sqs ? 1 : 0 + + name = "${var.environment}-burr-dlq-alarm" + display_name = "Burr DLQ Alarm - ${var.environment}" + tags = local.common_tags +} + +resource "aws_sns_topic_subscription" "dlq_alarm_email" { + count = var.enable_sqs && length(var.dlq_alarm_notification_emails) > 0 ? length(var.dlq_alarm_notification_emails) : 0 + + topic_arn = aws_sns_topic.dlq_alarm[0].arn + protocol = "email" + endpoint = var.dlq_alarm_notification_emails[count.index] +} + +resource "aws_cloudwatch_metric_alarm" "dlq_messages" { + count = var.enable_sqs ? 1 : 0 + + alarm_name = "${var.environment}-burr-dlq-messages" + alarm_description = "Alarm when messages appear in Burr SQS dead letter queue" + comparison_operator = "GreaterThanThreshold" + evaluation_periods = 1 + metric_name = "ApproximateNumberOfMessagesVisible" + namespace = "AWS/SQS" + period = 60 + statistic = "Sum" + threshold = 0 + + alarm_actions = [aws_sns_topic.dlq_alarm[0].arn] + ok_actions = [aws_sns_topic.dlq_alarm[0].arn] + + dimensions = { + QueueName = module.sqs[0].dlq_name + } + + tags = local.common_tags +} + module "iam" { source = "./modules/iam" diff --git a/terraform/modules/sqs/outputs.tf b/terraform/modules/sqs/outputs.tf index ef5c86e88..5b7ccd098 100644 --- a/terraform/modules/sqs/outputs.tf +++ b/terraform/modules/sqs/outputs.tf @@ -39,3 +39,8 @@ output "dlq_arn" { description = "ARN of the dead letter queue" value = aws_sqs_queue.dlq.arn } + +output "dlq_name" { + description = "Name of the dead letter queue (for CloudWatch dimensions)" + value = aws_sqs_queue.dlq.name +} diff --git a/terraform/outputs.tf b/terraform/outputs.tf index a14b0ef99..8dba499a7 100644 --- a/terraform/outputs.tf +++ b/terraform/outputs.tf @@ -40,6 +40,16 @@ output "sqs_dlq_url" { value = var.enable_sqs ? module.sqs[0].dlq_url : null } +output "dlq_alarm_arn" { + description = "ARN of the CloudWatch alarm for DLQ messages" + value = var.enable_sqs ? aws_cloudwatch_metric_alarm.dlq_messages[0].arn : null +} + +output "dlq_alarm_sns_topic_arn" { + description = "ARN of the SNS topic for DLQ alarm notifications" + value = var.enable_sqs ? aws_sns_topic.dlq_alarm[0].arn : null +} + output "iam_role_arn" { description = "ARN of the IAM role for Burr server" value = module.iam.role_arn diff --git a/terraform/prod.tfvars b/terraform/prod.tfvars index 9dc6410a5..2e1e30a80 100644 --- a/terraform/prod.tfvars +++ b/terraform/prod.tfvars @@ -16,12 +16,14 @@ # under the License. # Production environment configuration -# Replace ACCOUNT_ID with your AWS account ID for s3_bucket_name +# Bucket name is auto-generated: burr-tracking-{env}-{region}-{account_id}-{random} +# account_id: leave empty to auto-fetch from AWS credentials, or set explicitly aws_region = "us-east-1" environment = "prod" -s3_bucket_name = "burr-tracking-logs-prod-ACCOUNT_ID" +# account_id = "" # Optional. Empty = auto-fetch. 
Or set: account_id = "123456789012" + sqs_queue_name = "burr-s3-events-prod" enable_sqs = true @@ -35,3 +37,6 @@ sqs_receive_wait_time_seconds = 20 sqs_max_receive_count = 3 enable_bedrock = false + +# Optional: receive email when messages land in DLQ +# dlq_alarm_notification_emails = ["ops@example.com"] diff --git a/terraform/tutorial.md b/terraform/tutorial.md index 90e389937..28132a4c5 100644 --- a/terraform/tutorial.md +++ b/terraform/tutorial.md @@ -2,12 +2,25 @@ This tutorial explains how to deploy Apache Burr tracking infrastructure on AWS using Terraform. All Terraform code lives in the `terraform/` folder. It covers deployment with S3 only (polling mode), with S3 and SQS (event-driven mode), and local development without AWS. +## Quick Start + +```bash +cd terraform +terraform init +terraform apply -var-file=dev.tfvars # S3 only, polling mode +# or +terraform apply -var-file=prod.tfvars # S3 + SQS, event-driven + DLQ alarm +``` + +Bucket names are auto-generated. After apply, run `terraform output burr_environment_variables` and set those on your Burr server. + ## Overview The Terraform configuration provisions: -- **S3 bucket**: Stores Burr application logs and database snapshots (always created for AWS deployment) +- **S3 bucket**: Stores Burr application logs and database snapshots. Name is auto-generated (`burr-tracking-{env}-{region}-{account_id}-{random}`) when not specified. - **SQS queue** (optional): Receives S3 event notifications for real-time tracking; controlled by `enable_sqs` +- **CloudWatch alarm + SNS**: Alerts when messages land in the dead letter queue; optional email subscriptions - **IAM role**: Least-privilege permissions for the Burr server ## Directory Structure @@ -16,11 +29,11 @@ All code is in `terraform/`: ``` terraform/ -├── main.tf # Root module wiring S3, SQS, IAM +├── main.tf # Root module: S3, SQS, CloudWatch alarm, SNS, IAM ├── variables.tf # Input variables ├── outputs.tf # Output values ├── dev.tfvars # Development: S3 only (enable_sqs = false) -├── prod.tfvars # Production: S3 + SQS (enable_sqs = true) +├── prod.tfvars # Production: S3 + SQS + DLQ alarm (enable_sqs = true) ├── tutorial.md # This file └── modules/ ├── s3/ # S3 bucket with versioning, encryption, lifecycle @@ -32,28 +45,19 @@ terraform/ - Terraform >= 1.0 - AWS CLI configured with credentials -- AWS account ID (for unique S3 bucket names) - -Get your AWS account ID: -```bash -aws sts get-caller-identity --query Account --output text -``` +No manual bucket naming required; names are auto-generated. `account_id` is fetched from AWS credentials when not set. For a custom bucket name, set `s3_bucket_name` in your tfvars. ## Using tfvars Files -| File | Mode | enable_sqs | Resources created | -|-------------|-------------------|------------|--------------------------| -| dev.tfvars | S3 only (polling) | false | S3 bucket, IAM role | -| prod.tfvars | S3 + SQS (event) | true | S3 bucket, SQS queue, IAM | +| File | Mode | enable_sqs | Resources created | +|-------------|-------------------|------------|--------------------------------------------------------| +| dev.tfvars | S3 only (polling) | false | S3 bucket, IAM role | +| prod.tfvars | S3 + SQS (event) | true | S3 bucket, SQS queue, DLQ, CloudWatch alarm, SNS, IAM | ### Development (dev.tfvars) - S3 Only -Uses S3 polling mode (no SQS). Edit and replace `ACCOUNT_ID` in `s3_bucket_name`: - -``` -s3_bucket_name = "burr-tracking-logs-dev-123456789012" -``` +Uses S3 polling mode (no SQS). 
Bucket name is auto-generated (`burr-tracking-{env}-{region}-{account_id}-{random}`). Override with `s3_bucket_name = "my-bucket"` in tfvars if needed. Deploy: @@ -66,11 +70,7 @@ terraform apply -var-file=dev.tfvars ### Production (prod.tfvars) - S3 + SQS -Uses event-driven mode with SQS. Edit and replace `ACCOUNT_ID`: - -``` -s3_bucket_name = "burr-tracking-logs-prod-123456789012" -``` +Uses event-driven mode with SQS. Bucket name is auto-generated (`burr-tracking-{env}-{region}-{account_id}-{random}`). A CloudWatch alarm fires when messages land in the DLQ. Deploy: @@ -147,11 +147,24 @@ burr --no-open |----------|-------------|---------| | aws_region | AWS region | us-east-1 | | environment | Environment name (dev, prod) | dev | -| s3_bucket_name | S3 bucket name (must be globally unique) | (required) | +| account_id | AWS account ID. Empty = auto-fetch from credentials | "" | +| s3_bucket_name | S3 bucket name. Empty = auto-generated (env, region, account_id, random) | "" | | enable_sqs | Create SQS for event-driven tracking | true | +| sqs_queue_name | Name of the SQS queue | burr-s3-events | | log_retention_days | Days to retain logs in S3 | 90 | | snapshot_retention_days | Days to retain DB snapshots | 30 | | enable_bedrock | Add Bedrock IAM permissions | false | +| dlq_alarm_notification_emails | Emails to notify when DLQ has messages (confirm via AWS email) | [] | + +## CloudWatch DLQ Alarm and SNS Notifications + +When SQS is enabled, a CloudWatch alarm fires when messages appear in the dead letter queue. An SNS topic is created for notifications. To receive email alerts, add your addresses to `dlq_alarm_notification_emails` in your tfvars: + +``` +dlq_alarm_notification_emails = ["ops@example.com", "oncall@example.com"] +``` + +Each email will receive a confirmation request from AWS; you must confirm the subscription before alerts are delivered. To use Slack or other endpoints, subscribe them to the SNS topic ARN (see `terraform output dlq_alarm_sns_topic_arn`) after apply. ## Outputs @@ -160,6 +173,9 @@ After apply, useful outputs: ```bash terraform output s3_bucket_name terraform output sqs_queue_url +terraform output sqs_dlq_url +terraform output dlq_alarm_arn +terraform output dlq_alarm_sns_topic_arn terraform output burr_environment_variables ``` @@ -187,8 +203,12 @@ aws s3api list-object-versions --bucket BUCKET_NAME --output json | jq -r '.Vers ## Troubleshooting -**S3 bucket name already exists**: S3 bucket names are globally unique. Use your account ID or a random suffix. +**S3 bucket name already exists**: S3 bucket names are globally unique. With auto-generation, each apply gets a new random suffix. For a fixed name, set `s3_bucket_name` explicitly. **SQS policy errors**: Ensure the S3 bucket notification depends on the queue policy. The Terraform handles this with `depends_on`. **Burr server not receiving events**: Verify BURR_SQS_QUEUE_URL is set and the IAM role has sqs:ReceiveMessage. Check CloudWatch for the SQS consumer. + +**DLQ alarm firing**: Messages in the DLQ mean the Burr server failed to process S3 events (e.g. crashed, timeout). Check the DLQ in the AWS Console, inspect failed messages, and fix the root cause. Confirm SNS email subscriptions via the link AWS sends. + +**No email from DLQ alarm**: Check your spam folder for the SNS confirmation email. Subscriptions are pending until confirmed. 
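+
+As an illustrative sketch (the output name is real; the HTTPS endpoint URL is a placeholder you must replace), a non-email subscriber such as a Slack relay can be attached with the AWS CLI after apply:
+
+```bash
+# Subscribe an HTTPS endpoint to the DLQ alarm topic; SNS will POST a
+# confirmation request to the endpoint before delivering alerts.
+TOPIC_ARN=$(terraform output -raw dlq_alarm_sns_topic_arn)
+aws sns subscribe \
+  --topic-arn "$TOPIC_ARN" \
+  --protocol https \
+  --notification-endpoint "https://example.com/burr-dlq-hook"
+```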
diff --git a/terraform/variables.tf b/terraform/variables.tf index 29fd36aef..5940be786 100644 --- a/terraform/variables.tf +++ b/terraform/variables.tf @@ -27,9 +27,16 @@ variable "environment" { default = "dev" } +variable "account_id" { + description = "AWS account ID for bucket name. Leave empty to auto-fetch from AWS credentials." + type = string + default = "" +} + variable "s3_bucket_name" { - description = "Name of the S3 bucket for Burr logs" + description = "Name of the S3 bucket for Burr logs. If empty, auto-generated from environment, region, and random suffix." type = string + default = "" } variable "enable_sqs" { @@ -91,3 +98,9 @@ variable "bedrock_model_arns" { type = list(string) default = [] } + +variable "dlq_alarm_notification_emails" { + description = "Email addresses to notify when messages land in the DLQ. Empty = no email subscriptions." + type = list(string) + default = [] +} From 57eda919c7b665fec314f89a0a5b8bc70459d851 Mon Sep 17 00:00:00 2001 From: vaquarkhan Date: Sun, 15 Mar 2026 17:21:00 -0500 Subject: [PATCH 4/8] BIP-0042 Cloud-Native Architecture for Apache Burr on AWS (#664) fixed build error --- burr/tracking/__init__.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/burr/tracking/__init__.py b/burr/tracking/__init__.py index a62e68fb1..36578873e 100644 --- a/burr/tracking/__init__.py +++ b/burr/tracking/__init__.py @@ -16,6 +16,14 @@ # under the License. from .client import LocalTrackingClient -from .s3client import S3TrackingClient + + +def __getattr__(name: str): + """Lazy load S3TrackingClient to avoid requiring boto3 unless used.""" + if name == "S3TrackingClient": + from burr.tracking.s3client import S3TrackingClient + return S3TrackingClient + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + __all__ = ["LocalTrackingClient", "S3TrackingClient"] From 931299fbcb80d311c9359edcc030ef32c71a0d00 Mon Sep 17 00:00:00 2001 From: vaquarkhan Date: Mon, 16 Mar 2026 00:13:19 -0500 Subject: [PATCH 5/8] chore: move Bedrock integration to separate PR and review comments fixed --- .gitignore | 13 +- bedrock-changes.patch | Bin 0 -> 35472 bytes burr/integrations/__init__.py | 13 +- burr/integrations/bedrock.py | 227 ------------------ burr/tracking/__init__.py | 11 +- burr/tracking/server/backend.py | 16 +- burr/tracking/server/run.py | 17 +- burr/tracking/server/s3/README.md | 3 +- burr/tracking/server/s3/backend.py | 141 ++++++----- burr/version.py | 8 +- .../deployment/aws/terraform}/.gitignore | 1 + .../deployment/aws/terraform}/dev.tfvars | 2 - .../deployment/aws/terraform}/main.tf | 13 +- .../aws/terraform}/modules/iam/main.tf | 20 -- .../aws/terraform}/modules/iam/outputs.tf | 0 .../aws/terraform}/modules/iam/variables.tf | 17 -- .../aws/terraform}/modules/s3/main.tf | 0 .../aws/terraform}/modules/s3/outputs.tf | 0 .../aws/terraform}/modules/s3/variables.tf | 0 .../aws/terraform}/modules/sqs/main.tf | 0 .../aws/terraform}/modules/sqs/outputs.tf | 0 .../aws/terraform}/modules/sqs/variables.tf | 0 .../deployment/aws/terraform}/outputs.tf | 2 +- .../deployment/aws/terraform}/prod.tfvars | 2 - .../deployment/aws/terraform}/tutorial.md | 14 +- .../deployment/aws/terraform}/variables.tf | 12 - terraform/.terraform.lock.hcl | 45 ---- tests/integrations/test_bip0042_bedrock.py | 168 ------------- tests/tracking/test_bip0042_s3_buffering.py | 21 +- 29 files changed, 146 insertions(+), 620 deletions(-) create mode 100644 bedrock-changes.patch delete mode 100644 burr/integrations/bedrock.py rename {terraform => 
examples/deployment/aws/terraform}/.gitignore (87%)
 rename {terraform => examples/deployment/aws/terraform}/dev.tfvars (98%)
 rename {terraform => examples/deployment/aws/terraform}/main.tf (90%)
 rename {terraform => examples/deployment/aws/terraform}/modules/iam/main.tf (79%)
 rename {terraform => examples/deployment/aws/terraform}/modules/iam/outputs.tf (100%)
 rename {terraform => examples/deployment/aws/terraform}/modules/iam/variables.tf (76%)
 rename {terraform => examples/deployment/aws/terraform}/modules/s3/main.tf (100%)
 rename {terraform => examples/deployment/aws/terraform}/modules/s3/outputs.tf (100%)
 rename {terraform => examples/deployment/aws/terraform}/modules/s3/variables.tf (100%)
 rename {terraform => examples/deployment/aws/terraform}/modules/sqs/main.tf (100%)
 rename {terraform => examples/deployment/aws/terraform}/modules/sqs/outputs.tf (100%)
 rename {terraform => examples/deployment/aws/terraform}/modules/sqs/variables.tf (100%)
 rename {terraform => examples/deployment/aws/terraform}/outputs.tf (98%)
 rename {terraform => examples/deployment/aws/terraform}/prod.tfvars (98%)
 rename {terraform => examples/deployment/aws/terraform}/tutorial.md (94%)
 rename {terraform => examples/deployment/aws/terraform}/variables.tf (89%)
 delete mode 100644 terraform/.terraform.lock.hcl
 delete mode 100644 tests/integrations/test_bip0042_bedrock.py

diff --git a/.gitignore b/.gitignore
index c56ee23ed..1982de051 100644
--- a/.gitignore
+++ b/.gitignore
@@ -194,9 +194,10 @@ examples/*/statemachine
 examples/*/*/statemachine
 .vscode
 
-# Terraform (see also terraform/.gitignore)
-terraform/.terraform/
-terraform/*.tfstate
-terraform/*.tfstate.*
-terraform/.terraform.tfstate.lock.info
-terraform/*.tfplan
+# Terraform (see also examples/deployment/aws/terraform/.gitignore)
+**/.terraform.lock.hcl
+examples/deployment/aws/terraform/.terraform/
+examples/deployment/aws/terraform/*.tfstate
+examples/deployment/aws/terraform/*.tfstate.*
+examples/deployment/aws/terraform/.terraform.tfstate.lock.info
+examples/deployment/aws/terraform/*.tfplan
diff --git a/bedrock-changes.patch b/bedrock-changes.patch
new file mode 100644
index 0000000000000000000000000000000000000000..fac3e09828f861cfe0890e9d31ef71d016f2d833
GIT binary patch
literal 35472
[35472 bytes of base85-encoded binary patch data omitted]
zuLOs%DMRw!f+dRH53H_)oVI1G)!zOWxfcW=;t3cfwhIaEwpO;~3W>*6O9P|)pJ)%S z#P6;=*dlTtj$=a=FSPtLk`6>PN6}(mI>4gt2?@zH5%JV=xgF!H^PUUA13VHvU=O6< zP=&fBRN0F^Ea*X!8npUCd`G8{h)dDJ^eHK`p45`>Es|v^$8x0_nR_W~Ig&Wl>7XTg zM4jnj!k!pQPh{JE655BfD{rVz*f>GBpZMq2O&wyl2sA3G@9pi zpVxw(mS|Z-ajqOq=J*g*;hZdU7c1zcJ8kYj6R6#--BZ~^FQ8vNRydHeiM=ERFP1V| zVr|(~zKyJQ#T?Q<$l3pK*%KQg@DwT0b2&cCkmFWlj%@faK71!yB?5bpXqI)2=rZdU)vXJ&|BU~(wzrc- zm3@^)-|2D5eamB%SoLXNkFvCk+kuDZ)Uur|`j=2t%YaJAzw;ct`#M>UauRcm*b@68 z4bp6oEXqzu&yDa1eg9{gl>7i5k*fG1+W0viGna<%ccaEU+tZDZ43SDRGd&G|7+ zXhD__Bk}{5y{{#vAF~;wGU8jWr7db6wR}z21}MK#j;r`-n3{~hF%>YoQ9~Y!`uQQ* zQ0&eyJk1;)MXauQ33QLx{zJjGEXp#UMK4IpgP;nOQS}i!lD*1sA7j?yPS6#vIYm2c z+pC}#J*!@h*3ut1q0d0M+@yUO^9*BI+n&;EP2t$LJeK9-6hF%6e;k$o-LCcAp5HBe zw5HRhtJNB7_v*0Bm2g_RE_kiG({Qt<3sOFxZd+=*T9;R*+gs<=)@Rf0z|p#$vG$ee z_AOl&tk&K4gInuyyNB*`#92f@-=wo%fo7k0rmTR}k&5y0QEf|}lw-#Z z6GeGu%l5oAN<&2ta|N^oz2<`3mYzraE3=qZWzDVCGxLvlk}Ii(YDpqOBwtkAZ+R`T z^OH3cWgNE5uEa8HO6EY0jE+*t93Igs-k_#BUN7f@-P+M`FQz~@E?Td$R8j*f*mg<| zb4)-@D4z{u)Ef6X(osYHdYK`q8c)rlzHd#Y2l#Vbup?PR>#s&FoWw_k? z_xPy7x)=URD?c7ai!2qnDtJ5k$#ut^DYvuWI&VjNOFe4EU|{bmmvyC z^{L*+a(GhRRtpX?{%`BC&2nOV%p?ZJe>CvXP~wu~l+^KGqf9YelrYN`==@ zd~|Ic?k*DD%k`Vx=vehPD#u>ui9EcC^S8mL7RK%?uN@uJ8|JcE>BN6Bb*Sw13ZikB z^5}2biuudFt>fyjT@g3%Y1BRTBKm_W%)$_V<55cJU7S20!MtJ(N%M%equ&-~=#!az z^-{O4@wgqCokrz0j0fAtda|*WaEp~BR8(74?93sk&5D7Or0Hkj>AKvKK6?~XnyUG2 zt`9bH-S z-lPnlxht*wH+OC?uOkDVr=%9Z`aHaEZKcaN)G|{)tx<~e#?tihks08ZL8BZw+2%Qt zT7>N6G03WwSBUw>bRR5o?poqR9NF)`gr)D~a-Mgpt=Dcq1|swLQ0QSbPt0t{*!_M* zQdNb1t9fLkc+$u5Uv*Ti`K+<2=TMMa;WRYsQTD;C%4QDaA9D`mxR9XsS|DW5nLS#) znFE>17KW)b=9?VR6)80b8|%%^b*sG+W%~Ma@a)QeM-A?YNe* z0bSY5flN_gGY6tfZ8HZVjq7{f+WM_JUt=$v=N+;=ROH>VeR$=ZRoa?TWjD+*t6WC) z)CZUm$vB*T?xVl9x}F2cF>GGbvzZO4WvFYzhvqqm>bv#)y1K}$^kz0>GaE8AC$X6g z8P_;$>{-o*sKTgx@5w*S+(_vs$a+~TbRP31dc%*EufsgZHUF-aEHTd6YRSd2wq8Cx zJ+tAO^g&vmqgo17(MkdHw&k&nsb_ra`HidZ=4ZYMsbj8UKP02{1n$H(DcOu)+iZV z_AobFJ@=m$m62c5(mi&Mc;E z^qc3e{CpYFm)XV?hTv7^1?J<>ELL|kE&Xl5MBjROAr0g zyQBQdX8$AJ;l|#fTSTiwExe85HyK#3ocHj;H{$nO+R6GB=)Kwh=zMo1J(&iJv*c*sUK0)@1`gsuHQ2!5+NJZjvjhq=s9*avnMW}(syQDydvjTz`|G57?y9T z^9G{YIiA$am0OnY8`n_u?v8q*7tWDy-sb4;Mj$O>CnxsG(HlLioMPP{b$dqsFybin zU5?3EIcM;=_ufHa{cC>fz8jmdmOk%hk+Y0;L@ylVF^C_*vmSZ9mS}3!<$ea^tT!v*X!wD)uT&F zvR1RV*--pl%EXO`e2Hm~BGTMX`Q%hQUTaGUtBlCuNEb3YrMXI^H*Cr`uZeTo@i(hR za@Mg(bx7W%A)@6^;V20qqtv))r2 zwh}r1p^89l-FPk1__h074e$HFfmN*cW*+NUw9%)t=GFR8h+c=W%rMFod#8b8p4V^m zFOs}?7GN*EyKxv&M8?Q`#V{W0jh^`%r1C7VH!QzHs8xxJYK>y6pD9Y4Lb;F9GOAXI z99|1~^xlJI&I7%!M^r%-D(?n4MXSEPEs{SWu&kqn*7Z`iEkm{Zl$K{mitSfy zKgVl5d(k%@3aidr4XkpJCm;{dniuq@l{4448RN^7-f8v9QH>vS*%?3LuI}Ed>v*Kn zpI|-iF75s>pUS84s?&9Flz z%hq9$VlFJ7>hm=nruZ>T%c1@pi*@Iir*g1%tT#r@Iv*h;Msn)(BogHt{k0)zibUJ;V%CGtw+{Lzx&o}z(6%*&Bzhm4uZ|xnI zrL)!N@xG+tW3+i+{U(F8HQkF$4ojqUl6fa;_BGPFaPMrT{mVYACYO4cYjUzhL^k&Bw~Z{<{?1VfT}_lf7lR*Dq4Qf&6PJQ@ECTqW0x}Eo1L= z4BHjcppIKgTcT)CZ+O8oGGEpT*LV+8cnoSek`iurJnKWww!^>i*^Q0CJ@4Q=WdEn* zsgg#Dq+AQY?s|80Dr@Wogv217tb{sGYYUbcFLDKHbj&u_zXgSKYleh*cf0~?Z-%hn z*xr9m+e-${dm893@^5~#=fCcuH?As=WhuN(Snot_+hS|tW0oIg<>;9VlS~9103X)x z_BCX&tV?SZw4twJ39!wzU&yNia!>xPG2`iDZLdljN7S76dpHfuZ2wOKrL>->fn!^F z8W@bo`#zioPHWGHvxZsk@^BjH=+W3d)~A6zykE;{AeJ@0F*voJt?LQ(J8<|ws8j9` z-P-Gcct6iQe28kFagyaLanih69n=7ZCA={0=(vTp7c9ensnX@UPwZ zJlhu^>snWNwy78LPSg9X1BIzs2Q7zLB3Sp?EIBR*7E$!=k!& z-=C+O+UayFq%`j@(tR^7G8`>#c`*60TPQRF|Rk~$wWYj~npnovf3crNz4Z~gCvXl`n{#f}S+1_5# zGjhdMRw~#B)*f(?^p#a6jF3G?SbI|Ulc8%Y`Va4fTVATMcUe7ai+0hIJNPbEO|KDVAWBjBwOki`OgWdcf4`n|glp{YOI2yf1T* za(&)!ydPetT!Ca=^O1`6*{*+|ntxkubfpX#*Y2^D`PKO^8lBxb`PRBBV_UEBI*tsO a_g?i2u3o{Fe<4K~Jo-Vr3IF__{r>}22_q!{ literal 0 
diff --git a/burr/integrations/__init__.py b/burr/integrations/__init__.py
index 956579056..9f5af4ccb 100644
--- a/burr/integrations/__init__.py
+++ b/burr/integrations/__init__.py
@@ -13,15 +13,4 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
-
-def __getattr__(name: str):
-    """Lazy load Bedrock integration to avoid requiring boto3 unless used."""
-    if name == "BedrockAction":
-        from burr.integrations.bedrock import BedrockAction
-        return BedrockAction
-    if name == "BedrockStreamingAction":
-        from burr.integrations.bedrock import BedrockStreamingAction
-        return BedrockStreamingAction
-    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
diff --git a/burr/integrations/bedrock.py b/burr/integrations/bedrock.py
deleted file mode 100644
index 9a6353869..000000000
--- a/burr/integrations/bedrock.py
+++ /dev/null
@@ -1,227 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""Amazon Bedrock integration for Burr.
-
-BIP-0042: This module provides Action classes for invoking Amazon Bedrock models
-within Burr applications.
-
-Example usage:
-    from burr.integrations.bedrock import BedrockAction
-
-    def prompt_mapper(state):
-        return {
-            "messages": [{"role": "user", "content": state["user_input"]}],
-            "system": [{"text": "You are a helpful assistant."}],
-        }
-
-    action = BedrockAction(
-        model_id="anthropic.claude-3-sonnet-20240229-v1:0",
-        input_mapper=prompt_mapper,
-        reads=["user_input"],
-        writes=["response"],
-    )
-"""
-
-import logging
-from typing import Any, Dict, Generator, List, Optional, Protocol, Tuple
-
-from burr.core.action import SingleStepAction, StreamingAction
-from burr.core.state import State
-from burr.integrations.base import require_plugin
-
-logger = logging.getLogger(__name__)
-
-try:
-    import boto3
-    from botocore.config import Config
-    from botocore.exceptions import ClientError
-except ImportError as e:
-    require_plugin(e, "bedrock")
-
-
-class StateToPromptMapper(Protocol):
-    """Protocol for mapping Burr state to Bedrock prompt format."""
-
-    def __call__(self, state: State) -> Dict[str, Any]:
-        ...
- - -class BedrockAction(SingleStepAction): - """Action that invokes Amazon Bedrock models using the Converse API.""" - - def __init__( - self, - model_id: str, - input_mapper: StateToPromptMapper, - reads: List[str], - writes: List[str], - name: str = "bedrock_invoke", - region: Optional[str] = None, - guardrail_id: Optional[str] = None, - guardrail_version: Optional[str] = None, - inference_config: Optional[Dict[str, Any]] = None, - max_retries: int = 3, - ): - super().__init__() - self._model_id = model_id - self._input_mapper = input_mapper - self._reads = reads - self._writes = writes - self._name = name - self._region = region - self._guardrail_id = guardrail_id - self._guardrail_version = guardrail_version or "DRAFT" - self._inference_config = inference_config or {"maxTokens": 4096} - - config = Config(retries={"max_attempts": max_retries, "mode": "adaptive"}) - self._client = boto3.client("bedrock-runtime", region_name=region, config=config) - - @property - def reads(self) -> List[str]: - return self._reads - - @property - def writes(self) -> List[str]: - return self._writes - - @property - def name(self) -> str: - return self._name - - def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: - prompt = self._input_mapper(state) - - request: Dict[str, Any] = { - "modelId": self._model_id, - "messages": prompt["messages"], - "inferenceConfig": self._inference_config, - } - - if "system" in prompt: - request["system"] = prompt["system"] - - if self._guardrail_id: - request["guardrailConfig"] = { - "guardrailIdentifier": self._guardrail_id, - "guardrailVersion": self._guardrail_version, - } - - try: - response = self._client.converse(**request) - except ClientError as e: - logger.error(f"Bedrock API error: {e}") - raise - - output_message = response["output"]["message"] - content_blocks = output_message.get("content", []) - text = content_blocks[0]["text"] if content_blocks else "" - - result: Dict[str, Any] = { - "response": text, - "usage": response.get("usage", {}), - "stop_reason": response.get("stopReason"), - } - - updates = {key: result[key] for key in self._writes if key in result} - new_state = state.update(**updates) - - return result, new_state - - -class BedrockStreamingAction(StreamingAction): - """Streaming variant of BedrockAction using Converse Stream API.""" - - def __init__( - self, - model_id: str, - input_mapper: StateToPromptMapper, - reads: List[str], - writes: List[str], - name: str = "bedrock_stream", - region: Optional[str] = None, - guardrail_id: Optional[str] = None, - guardrail_version: Optional[str] = None, - inference_config: Optional[Dict[str, Any]] = None, - max_retries: int = 3, - ): - super().__init__() - self._model_id = model_id - self._input_mapper = input_mapper - self._reads = reads - self._writes = writes - self._name = name - self._region = region - self._guardrail_id = guardrail_id - self._guardrail_version = guardrail_version or "DRAFT" - self._inference_config = inference_config or {"maxTokens": 4096} - - config = Config(retries={"max_attempts": max_retries, "mode": "adaptive"}) - self._client = boto3.client("bedrock-runtime", region_name=region, config=config) - - @property - def reads(self) -> List[str]: - return self._reads - - @property - def writes(self) -> List[str]: - return self._writes - - @property - def name(self) -> str: - return self._name - - def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, None]: - prompt = self._input_mapper(state) - - request: Dict[str, Any] = { - "modelId": 
self._model_id, - "messages": prompt["messages"], - "inferenceConfig": self._inference_config, - } - - if "system" in prompt: - request["system"] = prompt["system"] - - if self._guardrail_id: - request["guardrailConfig"] = { - "guardrailIdentifier": self._guardrail_id, - "guardrailVersion": self._guardrail_version, - } - - try: - response = self._client.converse_stream(**request) - except ClientError as e: - logger.error(f"Bedrock streaming API error: {e}") - raise - - full_response = "" - stream = response.get("stream", []) - for event in stream: - if "contentBlockDelta" in event: - chunk = event["contentBlockDelta"]["delta"].get("text", "") - full_response += chunk - yield {"chunk": chunk, "response": full_response} - - yield {"chunk": "", "response": full_response, "complete": True} - - def update(self, result: dict, state: State) -> State: - if result.get("complete"): - updates = {"response": result.get("response", "")} - filtered = {k: v for k, v in updates.items() if k in self._writes} - return state.update(**filtered) - return state diff --git a/burr/tracking/__init__.py b/burr/tracking/__init__.py index 36578873e..bc2581ee6 100644 --- a/burr/tracking/__init__.py +++ b/burr/tracking/__init__.py @@ -17,13 +17,4 @@ from .client import LocalTrackingClient - -def __getattr__(name: str): - """Lazy load S3TrackingClient to avoid requiring boto3 unless used.""" - if name == "S3TrackingClient": - from burr.tracking.s3client import S3TrackingClient - return S3TrackingClient - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - -__all__ = ["LocalTrackingClient", "S3TrackingClient"] +__all__ = ["LocalTrackingClient"] diff --git a/burr/tracking/server/backend.py b/burr/tracking/server/backend.py index bf422c369..904fe0019 100644 --- a/burr/tracking/server/backend.py +++ b/burr/tracking/server/backend.py @@ -163,18 +163,18 @@ def snapshot_interval_milliseconds(self) -> Optional[int]: class EventDrivenBackendMixin(abc.ABC): - """Mixin for backends that support event-driven updates via SQS. + """Mixin for backends that support event-driven updates. - BIP-0042: This mixin enables backends to receive real-time notifications - from SQS instead of polling S3 for new files. + Enables backends to receive real-time notifications instead of polling + for new files. """ @abc.abstractmethod - async def start_sqs_consumer(self): - """Start the SQS consumer for event-driven tracking. + async def start_event_consumer(self): + """Start the event consumer for event-driven tracking. - This method should run indefinitely, processing S3 event notifications - from the configured SQS queue. + This method should run indefinitely, processing event notifications + from the configured message queue. """ pass @@ -182,7 +182,7 @@ async def start_sqs_consumer(self): def is_event_driven(self) -> bool: """Check if this backend is configured for event-driven updates. - :return: True if SQS mode is enabled and configured, False otherwise + :return: True if event-driven mode is enabled and configured, False otherwise """ pass diff --git a/burr/tracking/server/run.py b/burr/tracking/server/run.py index cc9627006..95ef427ff 100644 --- a/burr/tracking/server/run.py +++ b/burr/tracking/server/run.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
+import asyncio import importlib import logging import os @@ -129,8 +130,6 @@ async def save_snapshot(): @asynccontextmanager async def lifespan(app: FastAPI): - import asyncio - # Download if it does it # For now we do this before the lifespan await download_snapshot() @@ -138,12 +137,20 @@ async def lifespan(app: FastAPI): await backend.lifespan(app).__anext__() await sync_index() # this will trigger the repeat every N seconds await save_snapshot() # this will trigger the repeat every N seconds - # BIP-0042: Start SQS consumer for event-driven tracking when configured + # Start event consumer for event-driven tracking when configured + event_consumer_task = None if isinstance(backend, EventDrivenBackendMixin) and backend.is_event_driven(): - asyncio.create_task(backend.start_sqs_consumer()) + event_consumer_task = asyncio.create_task(backend.start_event_consumer()) global initialized initialized = True yield + # Graceful shutdown: cancel event consumer task + if event_consumer_task is not None: + event_consumer_task.cancel() + try: + await event_consumer_task + except asyncio.CancelledError: + pass await backend.lifespan(app).__anext__() @@ -178,7 +185,7 @@ def get_app_spec(): logger = logging.getLogger(__name__) if app_spec.indexing: - # BIP-0042: Only use polling when not in event-driven (SQS) mode + # Only use polling when not in event-driven mode if not ( isinstance(backend, EventDrivenBackendMixin) and backend.is_event_driven() ): diff --git a/burr/tracking/server/s3/README.md b/burr/tracking/server/s3/README.md index 0dbd7608f..62a6035b2 100644 --- a/burr/tracking/server/s3/README.md +++ b/burr/tracking/server/s3/README.md @@ -59,8 +59,9 @@ This will immediately start indexing your s3 bucket (or pick up from the last sn To track your data, you use the S3TrackingClient. 
You pass the tracker to the `ApplicationBuilder`: - ```python +from burr.tracking.s3client import S3TrackingClient + app = ( ApplicationBuilder() .with_graph(graph) diff --git a/burr/tracking/server/s3/backend.py b/burr/tracking/server/s3/backend.py index d72ecf0f3..681f3b953 100644 --- a/burr/tracking/server/s3/backend.py +++ b/burr/tracking/server/s3/backend.py @@ -18,6 +18,7 @@ import asyncio import dataclasses import datetime +import enum import functools import itertools import json @@ -32,6 +33,7 @@ import fastapi import pydantic from aiobotocore import session +from pydantic import field_validator from fastapi import FastAPI from pydantic_settings import BaseSettings from tortoise import functions, transactions @@ -159,6 +161,13 @@ def from_path(cls, path: str, created_date: datetime.datetime) -> "DataFile": ) +class TrackingMode(str, enum.Enum): + """Tracking mode for S3 backend: polling or event-driven.""" + + POLLING = "POLLING" + EVENT_DRIVEN = "EVENT_DRIVEN" + + class S3Settings(BurrSettings): s3_bucket: str update_interval_milliseconds: int = 120_000 @@ -167,12 +176,20 @@ class S3Settings(BurrSettings): load_snapshot_on_start: bool = True prior_snapshots_to_keep: int = 5 # BIP-0042: Event-driven tracking settings - tracking_mode: str = "POLLING" # "POLLING" or "SQS" - POLLING is default for backward compatibility + tracking_mode: TrackingMode = TrackingMode.POLLING sqs_queue_url: Optional[str] = None sqs_region: Optional[str] = None sqs_wait_time_seconds: int = 20 # SQS long polling timeout s3_buffer_size_mb: int = 10 # RAM buffer before spilling to disk + @field_validator("tracking_mode", mode="before") + @classmethod + def coerce_tracking_mode(cls, v: object) -> object: + """Coerce legacy 'SQS' string to EVENT_DRIVEN for backward compatibility.""" + if v == "SQS": + return TrackingMode.EVENT_DRIVEN + return v + def timestamp_to_reverse_alphabetical(timestamp: datetime) -> str: # Get the inverse of the timestamp @@ -198,7 +215,7 @@ def __init__( load_snapshot_on_start: bool, prior_snapshots_to_keep: int, # BIP-0042: New parameters for event-driven tracking - tracking_mode: str = "POLLING", + tracking_mode: Union[TrackingMode, str] = TrackingMode.POLLING, sqs_queue_url: Optional[str] = None, sqs_region: Optional[str] = None, sqs_wait_time_seconds: int = 20, @@ -215,8 +232,13 @@ def __init__( self._load_snapshot_on_start = load_snapshot_on_start self._snapshot_key_history = [] self._prior_snapshots_to_keep = prior_snapshots_to_keep - # BIP-0042: Store event-driven tracking settings - self._tracking_mode = tracking_mode + # BIP-0042: Store event-driven tracking settings (normalize str to enum) + if isinstance(tracking_mode, TrackingMode): + self._tracking_mode = tracking_mode + elif tracking_mode == "SQS": + self._tracking_mode = TrackingMode.EVENT_DRIVEN + else: + self._tracking_mode = TrackingMode(tracking_mode) self._sqs_queue_url = sqs_queue_url self._sqs_region = sqs_region self._sqs_wait_time_seconds = sqs_wait_time_seconds @@ -739,73 +761,74 @@ async def _handle_s3_event(self, s3_key: str, event_time: datetime.datetime) -> logger.info(f"Indexed S3 event: {s3_key}") except Exception as e: logger.error(f"Failed to handle S3 event {s3_key}: {e}") + raise # Re-raise so message stays in queue for retry / DLQ - async def start_sqs_consumer(self) -> None: - """Start the SQS consumer for event-driven tracking. + async def start_event_consumer(self) -> None: + """Start the event consumer for event-driven tracking. 
- BIP-0042: This method runs indefinitely, processing S3 event notifications - from the configured SQS queue. It handles both EventBridge and direct S3 - notification formats. + Runs indefinitely, processing S3 event notifications from the configured + message queue. Handles both EventBridge and direct S3 notification formats. """ - if self._tracking_mode != "SQS" or not self._sqs_queue_url: - logger.info("SQS consumer not configured, skipping") + if self._tracking_mode != TrackingMode.EVENT_DRIVEN or not self._sqs_queue_url: + logger.info("Event consumer not configured, skipping") return - logger.info(f"Starting SQS consumer for queue: {self._sqs_queue_url}") + logger.info(f"Starting event consumer for queue: {self._sqs_queue_url}") async with self._session.create_client("sqs", region_name=self._sqs_region) as sqs_client: - while True: - try: - response = await sqs_client.receive_message( - QueueUrl=self._sqs_queue_url, - MaxNumberOfMessages=10, - WaitTimeSeconds=self._sqs_wait_time_seconds, - VisibilityTimeout=300, - ) + try: + while True: + try: + response = await sqs_client.receive_message( + QueueUrl=self._sqs_queue_url, + MaxNumberOfMessages=10, + WaitTimeSeconds=self._sqs_wait_time_seconds, + VisibilityTimeout=300, + ) - messages = response.get("Messages", []) - for message in messages: - try: - body = json.loads(message["Body"]) - s3_key = None - event_time = None - - # Handle EventBridge wrapped S3 events - if "detail" in body: - s3_key = body["detail"]["object"]["key"] - event_time = datetime.datetime.fromisoformat( - body["time"].replace("Z", "+00:00") - ) - elif "Records" in body: - record = body["Records"][0] - s3_key = record["s3"]["object"]["key"] - event_time = datetime.datetime.fromisoformat( - record["eventTime"].replace("Z", "+00:00") + messages = response.get("Messages", []) + for message in messages: + try: + body = json.loads(message["Body"]) + s3_key = None + event_time = None + + # Handle EventBridge wrapped S3 events + if "detail" in body: + s3_key = body["detail"]["object"]["key"] + event_time = datetime.datetime.fromisoformat( + body["time"].replace("Z", "+00:00") + ) + elif "Records" in body: + record = body["Records"][0] + s3_key = record["s3"]["object"]["key"] + event_time = datetime.datetime.fromisoformat( + record["eventTime"].replace("Z", "+00:00") + ) + else: + logger.warning(f"Unknown message format: {body}") + continue + + if s3_key and s3_key.endswith(".jsonl"): + await self._handle_s3_event(s3_key, event_time) + + await sqs_client.delete_message( + QueueUrl=self._sqs_queue_url, + ReceiptHandle=message["ReceiptHandle"], ) - else: - logger.warning(f"Unknown message format: {body}") - continue - - if s3_key and s3_key.endswith(".jsonl"): - await self._handle_s3_event(s3_key, event_time) + except Exception as e: + logger.error(f"Failed to process SQS message: {e}") - await sqs_client.delete_message( - QueueUrl=self._sqs_queue_url, - ReceiptHandle=message["ReceiptHandle"], - ) - except Exception as e: - logger.error(f"Failed to process SQS message: {e}") - - except Exception as e: - logger.error(f"SQS consumer error: {e}") - await asyncio.sleep(5) + except Exception as e: + logger.error(f"Event consumer error: {e}") + await asyncio.sleep(5) + except asyncio.CancelledError: + logger.info("Event consumer shutting down") + raise def is_event_driven(self) -> bool: - """Check if this backend is configured for event-driven updates. - - BIP-0042: Returns True if tracking_mode is SQS and queue URL is configured. 
- """ - return self._tracking_mode == "SQS" and self._sqs_queue_url is not None + """Check if this backend is configured for event-driven updates.""" + return self._tracking_mode == TrackingMode.EVENT_DRIVEN and self._sqs_queue_url is not None async def indexing_jobs( self, offset: int = 0, limit: int = 100, filter_empty: bool = True diff --git a/burr/version.py b/burr/version.py index 8555b4cce..bc7c5aa75 100644 --- a/burr/version.py +++ b/burr/version.py @@ -20,5 +20,9 @@ try: __version__ = importlib.metadata.version("apache-burr") except importlib.metadata.PackageNotFoundError: - # Fallback for older installations or development - __version__ = importlib.metadata.version("burr") + try: + # Fallback for older installations + __version__ = importlib.metadata.version("burr") + except importlib.metadata.PackageNotFoundError: + # Development / source tree: no package metadata + __version__ = "0.0.0.dev" diff --git a/terraform/.gitignore b/examples/deployment/aws/terraform/.gitignore similarity index 87% rename from terraform/.gitignore rename to examples/deployment/aws/terraform/.gitignore index 5c111d166..00a20986e 100644 --- a/terraform/.gitignore +++ b/examples/deployment/aws/terraform/.gitignore @@ -1,5 +1,6 @@ # Terraform .terraform/ +.terraform.lock.hcl *.tfstate *.tfstate.* .terraform.tfstate.lock.info diff --git a/terraform/dev.tfvars b/examples/deployment/aws/terraform/dev.tfvars similarity index 98% rename from terraform/dev.tfvars rename to examples/deployment/aws/terraform/dev.tfvars index 40c0bad98..86378ba96 100644 --- a/terraform/dev.tfvars +++ b/examples/deployment/aws/terraform/dev.tfvars @@ -36,5 +36,3 @@ sqs_message_retention_seconds = 86400 sqs_visibility_timeout_seconds = 120 sqs_receive_wait_time_seconds = 20 sqs_max_receive_count = 3 - -enable_bedrock = false diff --git a/terraform/main.tf b/examples/deployment/aws/terraform/main.tf similarity index 90% rename from terraform/main.tf rename to examples/deployment/aws/terraform/main.tf index 68cd3a852..7c6b5bbc2 100644 --- a/terraform/main.tf +++ b/examples/deployment/aws/terraform/main.tf @@ -167,14 +167,11 @@ resource "aws_cloudwatch_metric_alarm" "dlq_messages" { module "iam" { source = "./modules/iam" - role_name = "${var.environment}-burr-server-role" - s3_bucket_arn = module.s3.bucket_arn - sqs_queue_arn = var.enable_sqs ? module.sqs[0].queue_arn : "" - enable_sqs = var.enable_sqs - enable_bedrock = var.enable_bedrock - bedrock_model_arns = var.bedrock_model_arns - bedrock_foundation_model_arn = "arn:aws:bedrock:${data.aws_region.current.name}::foundation-model/*" - tags = local.common_tags + role_name = "${var.environment}-burr-server-role" + s3_bucket_arn = module.s3.bucket_arn + sqs_queue_arn = var.enable_sqs ? module.sqs[0].queue_arn : "" + enable_sqs = var.enable_sqs + tags = local.common_tags } locals { diff --git a/terraform/modules/iam/main.tf b/examples/deployment/aws/terraform/modules/iam/main.tf similarity index 79% rename from terraform/modules/iam/main.tf rename to examples/deployment/aws/terraform/modules/iam/main.tf index bb7e6a955..b63284f19 100644 --- a/terraform/modules/iam/main.tf +++ b/examples/deployment/aws/terraform/modules/iam/main.tf @@ -87,23 +87,3 @@ resource "aws_iam_role_policy" "sqs" { policy = data.aws_iam_policy_document.sqs_least_privilege[0].json } -data "aws_iam_policy_document" "bedrock_least_privilege" { - count = var.enable_bedrock ? 
1 : 0 - - statement { - sid = "BedrockInvokeModels" - effect = "Allow" - actions = [ - "bedrock:InvokeModel", - "bedrock:InvokeModelWithResponseStream" - ] - resources = length(var.bedrock_model_arns) > 0 ? var.bedrock_model_arns : [var.bedrock_foundation_model_arn] - } -} - -resource "aws_iam_role_policy" "bedrock" { - count = var.enable_bedrock ? 1 : 0 - name = "${var.role_name}-bedrock" - role = aws_iam_role.burr_server.id - policy = data.aws_iam_policy_document.bedrock_least_privilege[0].json -} diff --git a/terraform/modules/iam/outputs.tf b/examples/deployment/aws/terraform/modules/iam/outputs.tf similarity index 100% rename from terraform/modules/iam/outputs.tf rename to examples/deployment/aws/terraform/modules/iam/outputs.tf diff --git a/terraform/modules/iam/variables.tf b/examples/deployment/aws/terraform/modules/iam/variables.tf similarity index 76% rename from terraform/modules/iam/variables.tf rename to examples/deployment/aws/terraform/modules/iam/variables.tf index 304680676..9a2e83cc9 100644 --- a/terraform/modules/iam/variables.tf +++ b/examples/deployment/aws/terraform/modules/iam/variables.tf @@ -43,23 +43,6 @@ variable "sqs_queue_arn" { default = "" } -variable "enable_bedrock" { - description = "Enable Bedrock IAM permissions" - type = bool - default = false -} - -variable "bedrock_model_arns" { - description = "List of specific Bedrock model ARNs for least privilege. If empty, uses foundation model wildcard." - type = list(string) - default = [] -} - -variable "bedrock_foundation_model_arn" { - description = "Bedrock foundation model ARN wildcard when bedrock_model_arns is empty" - type = string -} - variable "tags" { description = "Tags to apply to resources" type = map(string) diff --git a/terraform/modules/s3/main.tf b/examples/deployment/aws/terraform/modules/s3/main.tf similarity index 100% rename from terraform/modules/s3/main.tf rename to examples/deployment/aws/terraform/modules/s3/main.tf diff --git a/terraform/modules/s3/outputs.tf b/examples/deployment/aws/terraform/modules/s3/outputs.tf similarity index 100% rename from terraform/modules/s3/outputs.tf rename to examples/deployment/aws/terraform/modules/s3/outputs.tf diff --git a/terraform/modules/s3/variables.tf b/examples/deployment/aws/terraform/modules/s3/variables.tf similarity index 100% rename from terraform/modules/s3/variables.tf rename to examples/deployment/aws/terraform/modules/s3/variables.tf diff --git a/terraform/modules/sqs/main.tf b/examples/deployment/aws/terraform/modules/sqs/main.tf similarity index 100% rename from terraform/modules/sqs/main.tf rename to examples/deployment/aws/terraform/modules/sqs/main.tf diff --git a/terraform/modules/sqs/outputs.tf b/examples/deployment/aws/terraform/modules/sqs/outputs.tf similarity index 100% rename from terraform/modules/sqs/outputs.tf rename to examples/deployment/aws/terraform/modules/sqs/outputs.tf diff --git a/terraform/modules/sqs/variables.tf b/examples/deployment/aws/terraform/modules/sqs/variables.tf similarity index 100% rename from terraform/modules/sqs/variables.tf rename to examples/deployment/aws/terraform/modules/sqs/variables.tf diff --git a/terraform/outputs.tf b/examples/deployment/aws/terraform/outputs.tf similarity index 98% rename from terraform/outputs.tf rename to examples/deployment/aws/terraform/outputs.tf index 8dba499a7..627a98bc0 100644 --- a/terraform/outputs.tf +++ b/examples/deployment/aws/terraform/outputs.tf @@ -64,7 +64,7 @@ output "burr_environment_variables" { description = "Environment variables to configure 
Burr server" value = var.enable_sqs ? { BURR_S3_BUCKET = module.s3.bucket_id - BURR_TRACKING_MODE = "SQS" + BURR_TRACKING_MODE = "EVENT_DRIVEN" BURR_SQS_QUEUE_URL = module.sqs[0].queue_url BURR_SQS_REGION = data.aws_region.current.name BURR_SQS_WAIT_TIME_SECONDS = "20" diff --git a/terraform/prod.tfvars b/examples/deployment/aws/terraform/prod.tfvars similarity index 98% rename from terraform/prod.tfvars rename to examples/deployment/aws/terraform/prod.tfvars index 2e1e30a80..43b9e8a95 100644 --- a/terraform/prod.tfvars +++ b/examples/deployment/aws/terraform/prod.tfvars @@ -36,7 +36,5 @@ sqs_visibility_timeout_seconds = 300 sqs_receive_wait_time_seconds = 20 sqs_max_receive_count = 3 -enable_bedrock = false - # Optional: receive email when messages land in DLQ # dlq_alarm_notification_emails = ["ops@example.com"] diff --git a/terraform/tutorial.md b/examples/deployment/aws/terraform/tutorial.md similarity index 94% rename from terraform/tutorial.md rename to examples/deployment/aws/terraform/tutorial.md index 28132a4c5..93883f90d 100644 --- a/terraform/tutorial.md +++ b/examples/deployment/aws/terraform/tutorial.md @@ -1,11 +1,11 @@ # Apache Burr AWS Tracking Infrastructure Tutorial -This tutorial explains how to deploy Apache Burr tracking infrastructure on AWS using Terraform. All Terraform code lives in the `terraform/` folder. It covers deployment with S3 only (polling mode), with S3 and SQS (event-driven mode), and local development without AWS. +This tutorial explains how to deploy Apache Burr tracking infrastructure on AWS using Terraform. All Terraform code lives in `examples/deployment/aws/terraform/`. It covers deployment with S3 only (polling mode), with S3 and SQS (event-driven mode), and local development without AWS. ## Quick Start ```bash -cd terraform +cd examples/deployment/aws/terraform terraform init terraform apply -var-file=dev.tfvars # S3 only, polling mode # or @@ -25,10 +25,10 @@ The Terraform configuration provisions: ## Directory Structure -All code is in `terraform/`: +All code is in `examples/deployment/aws/terraform/`: ``` -terraform/ +examples/deployment/aws/terraform/ ├── main.tf # Root module: S3, SQS, CloudWatch alarm, SNS, IAM ├── variables.tf # Input variables ├── outputs.tf # Output values @@ -62,7 +62,7 @@ Uses S3 polling mode (no SQS). Bucket name is auto-generated (`burr-tracking-{en Deploy: ```bash -cd terraform +cd examples/deployment/aws/terraform terraform init terraform plan -var-file=dev.tfvars terraform apply -var-file=dev.tfvars @@ -100,7 +100,7 @@ terraform output burr_environment_variables 4. 
Set these on your Burr server (ECS task, EC2, etc.): - BURR_S3_BUCKET -- BURR_TRACKING_MODE=SQS +- BURR_TRACKING_MODE=EVENT_DRIVEN - BURR_SQS_QUEUE_URL - BURR_SQS_REGION - BURR_SQS_WAIT_TIME_SECONDS @@ -153,7 +153,6 @@ burr --no-open | sqs_queue_name | Name of the SQS queue | burr-s3-events | | log_retention_days | Days to retain logs in S3 | 90 | | snapshot_retention_days | Days to retain DB snapshots | 30 | -| enable_bedrock | Add Bedrock IAM permissions | false | | dlq_alarm_notification_emails | Emails to notify when DLQ has messages (confirm via AWS email) | [] | ## CloudWatch DLQ Alarm and SNS Notifications @@ -185,7 +184,6 @@ The IAM role grants only: - **S3**: ListBucket, GetBucketLocation, GetObject, PutObject, DeleteObject, HeadObject on the specific bucket - **SQS** (when enabled): ReceiveMessage, DeleteMessage, GetQueueAttributes on the specific queue -- **Bedrock** (when enabled): InvokeModel, InvokeModelWithResponseStream on specified model ARNs ## Cleanup diff --git a/terraform/variables.tf b/examples/deployment/aws/terraform/variables.tf similarity index 89% rename from terraform/variables.tf rename to examples/deployment/aws/terraform/variables.tf index 5940be786..0af4960ab 100644 --- a/terraform/variables.tf +++ b/examples/deployment/aws/terraform/variables.tf @@ -87,18 +87,6 @@ variable "sqs_max_receive_count" { default = 3 } -variable "enable_bedrock" { - description = "Enable Bedrock IAM permissions for BedrockAction" - type = bool - default = false -} - -variable "bedrock_model_arns" { - description = "List of specific Bedrock model ARNs for least privilege. Empty uses foundation-model/*" - type = list(string) - default = [] -} - variable "dlq_alarm_notification_emails" { description = "Email addresses to notify when messages land in the DLQ. Empty = no email subscriptions." type = list(string) diff --git a/terraform/.terraform.lock.hcl b/terraform/.terraform.lock.hcl deleted file mode 100644 index ddfac895c..000000000 --- a/terraform/.terraform.lock.hcl +++ /dev/null @@ -1,45 +0,0 @@ -# This file is maintained automatically by "terraform init". -# Manual edits may be lost in future updates. 
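The `burr_environment_variables` output above, together with the settings fields exercised by the buffering tests later in this series, suggests roughly the following server-side wiring. This is a hedged sketch, not code from the patch: `S3Settings` and `TrackingMode` live in `burr.tracking.server.s3.backend` per the tests, and `s3_bucket`, `tracking_mode`, and `sqs_queue_url` are confirmed field names, but `sqs_region` and `sqs_wait_time_seconds` are assumed here.

```python
import os

# Hedged sketch: map the Terraform "burr_environment_variables" output onto
# the tracking server's settings model. s3_bucket, tracking_mode, and
# sqs_queue_url appear in the tests in this series; sqs_region and
# sqs_wait_time_seconds are assumed field names.
from burr.tracking.server.s3.backend import S3Settings, TrackingMode

settings = S3Settings(
    s3_bucket=os.environ["BURR_S3_BUCKET"],
    # "EVENT_DRIVEN" (or the legacy "SQS", which coerces) enables the consumer
    tracking_mode=os.environ.get("BURR_TRACKING_MODE", "POLLING"),
    sqs_queue_url=os.environ.get("BURR_SQS_QUEUE_URL"),
    sqs_region=os.environ.get("BURR_SQS_REGION"),  # assumed field name
    sqs_wait_time_seconds=int(
        os.environ.get("BURR_SQS_WAIT_TIME_SECONDS", "20")
    ),  # assumed field name
)
assert settings.tracking_mode in (TrackingMode.POLLING, TrackingMode.EVENT_DRIVEN)
```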
- -provider "registry.terraform.io/hashicorp/aws" { - version = "6.34.0" - constraints = ">= 5.0.0" - hashes = [ - "h1:Qzr5C24XLiHmkJVuao/Kb+jFLPaxGE/D5GUgko5VdWg=", - "zh:1e49dc96bf50633583e3cbe23bb357642e7e9afe135f54e061e26af6310e50d2", - "zh:45651bb4dad681f17782d99d9324de182a7bb9fbe9dd22f120fdb7fe42969cc9", - "zh:5880c306a427128124585b460c53bbcab9fb3767f26f796eae204f65f111a927", - "zh:71fa9170989b3a1a6913c369bd4a792f4a3e2aab4024c2aff0911e704020b058", - "zh:8d48628fb30f11b04215e06f4dd8a3b32f5f9ea2ed116d0c81c686bf678f9185", - "zh:9b12af85486a96aedd8d7984b0ff811a4b42e3d88dad1a3fb4c0b580d04fa425", - "zh:a6885766588fcad887bdac8c3665e048480eda028e492759a1ea29d22b98d509", - "zh:a6ce9f5e7edc2258733e978bba147600b42a979e18575ce2c7d7dcb6d0b9911f", - "zh:c88d8b7d344e745b191509c29ca773d696da8ca3443f62b20f97982d2d33ea00", - "zh:cae90d6641728ad0219b6a84746bf86dd1dda3e31560d6495a202213ef0258b6", - "zh:cc35927d9d41878049c4221beb1d580a3dbadaca7ba39fb267e001ef9c59ccb3", - "zh:d9e1cb00dc33998e1242fb844e4e3e6cf95e57c664dc1eb55bb7d24f8324bad3", - "zh:f3dbf4a1b7020722145312eb4425f3ea356276d741e3f60fb703fc59a1e2d9fd", - "zh:faba832cc9d99a83f42aaf5a27a4c7309401200169ef04643104cfc8f522d007", - "zh:fcd3f30b91dbcc7db67d5d39268741ffa46696a230a1f2aef32d245ace54bf65", - ] -} - -provider "registry.terraform.io/hashicorp/random" { - version = "3.8.1" - constraints = ">= 3.0.0" - hashes = [ - "h1:osH3aBqEARwOz3VBJKdpFKJJCNIdgRC6k8vPojkLmlY=", - "zh:08dd03b918c7b55713026037c5400c48af5b9f468f483463321bd18e17b907b4", - "zh:0eee654a5542dc1d41920bbf2419032d6f0d5625b03bd81339e5b33394a3e0ae", - "zh:229665ddf060aa0ed315597908483eee5b818a17d09b6417a0f52fd9405c4f57", - "zh:2469d2e48f28076254a2a3fc327f184914566d9e40c5780b8d96ebf7205f8bc0", - "zh:37d7eb334d9561f335e748280f5535a384a88675af9a9eac439d4cfd663bcb66", - "zh:741101426a2f2c52dee37122f0f4a2f2d6af6d852cb1db634480a86398fa3511", - "zh:78d5eefdd9e494defcb3c68d282b8f96630502cac21d1ea161f53cfe9bb483b3", - "zh:a902473f08ef8df62cfe6116bd6c157070a93f66622384300de235a533e9d4a9", - "zh:b85c511a23e57a2147355932b3b6dce2a11e856b941165793a0c3d7578d94d05", - "zh:c5172226d18eaac95b1daac80172287b69d4ce32750c82ad77fa0768be4ea4b8", - "zh:dab4434dba34aad569b0bc243c2d3f3ff86dd7740def373f2a49816bd2ff819b", - "zh:f49fd62aa8c5525a5c17abd51e27ca5e213881d58882fd42fec4a545b53c9699", - ] -} diff --git a/tests/integrations/test_bip0042_bedrock.py b/tests/integrations/test_bip0042_bedrock.py deleted file mode 100644 index 45087e013..000000000 --- a/tests/integrations/test_bip0042_bedrock.py +++ /dev/null @@ -1,168 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -"""BIP-0042: Tests for Bedrock integration.""" - -import inspect - -import pytest - - -class TestBedrockImports: - """Test that Bedrock classes can be imported via lazy loading.""" - - def test_lazy_import_bedrock_action(self): - """Verify BedrockAction can be imported from burr.integrations.""" - try: - from burr.integrations import BedrockAction - - assert BedrockAction is not None - except ImportError as e: - assert "bedrock" in str(e).lower() or "boto3" in str(e).lower() - - def test_lazy_import_bedrock_streaming_action(self): - """Verify BedrockStreamingAction can be imported from burr.integrations.""" - try: - from burr.integrations import BedrockStreamingAction - - assert BedrockStreamingAction is not None - except ImportError as e: - assert "bedrock" in str(e).lower() or "boto3" in str(e).lower() - - def test_direct_import_bedrock_module(self): - """Verify bedrock.py module exists and has expected classes.""" - try: - from burr.integrations.bedrock import ( - BedrockAction, - BedrockStreamingAction, - StateToPromptMapper, - ) - - assert BedrockAction is not None - assert BedrockStreamingAction is not None - assert StateToPromptMapper is not None - except ImportError as e: - assert "bedrock" in str(e).lower() or "boto3" in str(e).lower() - - -class TestBedrockActionInterface: - """Test BedrockAction class interface (without boto3).""" - - @pytest.fixture - def mock_boto3(self, monkeypatch): - """Mock boto3 to allow testing without AWS credentials.""" - import sys - from unittest.mock import MagicMock - - mock_boto = MagicMock() - mock_client = MagicMock() - mock_boto.client.return_value = mock_client - - mock_botocore = MagicMock() - mock_botocore.config.Config = MagicMock - mock_botocore.exceptions.ClientError = Exception - - monkeypatch.setitem(sys.modules, "boto3", mock_boto) - monkeypatch.setitem(sys.modules, "botocore", mock_botocore) - monkeypatch.setitem(sys.modules, "botocore.config", mock_botocore.config) - monkeypatch.setitem(sys.modules, "botocore.exceptions", mock_botocore.exceptions) - - return mock_boto, mock_client - - def test_bedrock_action_extends_single_step_action(self, mock_boto3): - """Verify BedrockAction extends SingleStepAction.""" - import importlib - - import burr.integrations.bedrock as bedrock_module - - importlib.reload(bedrock_module) - - from burr.core.action import SingleStepAction - from burr.integrations.bedrock import BedrockAction - - assert issubclass(BedrockAction, SingleStepAction) - - def test_bedrock_streaming_action_extends_streaming_action(self, mock_boto3): - """Verify BedrockStreamingAction extends StreamingAction.""" - import importlib - - import burr.integrations.bedrock as bedrock_module - - importlib.reload(bedrock_module) - - from burr.core.action import StreamingAction - from burr.integrations.bedrock import BedrockStreamingAction - - assert issubclass(BedrockStreamingAction, StreamingAction) - - def test_bedrock_action_has_required_properties(self, mock_boto3): - """Verify BedrockAction has reads, writes, name properties.""" - import importlib - - import burr.integrations.bedrock as bedrock_module - - importlib.reload(bedrock_module) - - from burr.integrations.bedrock import BedrockAction - - action = BedrockAction( - model_id="test-model", - input_mapper=lambda s: {"messages": []}, - reads=["input"], - writes=["output"], - ) - - assert action.reads == ["input"] - assert action.writes == ["output"] - assert action.name == "bedrock_invoke" - - def test_bedrock_action_accepts_all_parameters(self, mock_boto3): - """Verify 
BedrockAction accepts all BIP-0042 specified parameters.""" - import importlib - - import burr.integrations.bedrock as bedrock_module - - importlib.reload(bedrock_module) - - from burr.integrations.bedrock import BedrockAction - - sig = inspect.signature(BedrockAction.__init__) - params = list(sig.parameters.keys()) - - assert "model_id" in params - assert "input_mapper" in params - assert "reads" in params - assert "writes" in params - assert "name" in params - assert "region" in params - assert "guardrail_id" in params - assert "guardrail_version" in params - assert "inference_config" in params - assert "max_retries" in params - - -class TestStateToPromptMapperProtocol: - """Test StateToPromptMapper Protocol exists.""" - - def test_protocol_exists(self): - """Verify StateToPromptMapper Protocol is defined.""" - try: - from burr.integrations.bedrock import StateToPromptMapper - - assert StateToPromptMapper is not None - except ImportError: - pytest.skip("boto3 not installed") diff --git a/tests/tracking/test_bip0042_s3_buffering.py b/tests/tracking/test_bip0042_s3_buffering.py index f445b0233..bbe7eadd6 100644 --- a/tests/tracking/test_bip0042_s3_buffering.py +++ b/tests/tracking/test_bip0042_s3_buffering.py @@ -27,10 +27,10 @@ class TestS3Settings: def test_s3_settings_has_tracking_mode(self): """Verify tracking_mode field exists with POLLING default.""" - from burr.tracking.server.s3.backend import S3Settings + from burr.tracking.server.s3.backend import S3Settings, TrackingMode assert "tracking_mode" in S3Settings.model_fields - assert S3Settings.model_fields["tracking_mode"].default == "POLLING" + assert S3Settings.model_fields["tracking_mode"].default == TrackingMode.POLLING def test_s3_settings_has_sqs_queue_url(self): """Verify sqs_queue_url field exists with None default.""" @@ -60,6 +60,13 @@ def test_s3_settings_has_s3_buffer_size_mb(self): assert "s3_buffer_size_mb" in S3Settings.model_fields assert S3Settings.model_fields["s3_buffer_size_mb"].default == 10 + def test_s3_settings_coerces_sqs_to_event_driven(self): + """Verify legacy 'SQS' string coerces to EVENT_DRIVEN for backward compatibility.""" + from burr.tracking.server.s3.backend import S3Settings, TrackingMode + + settings = S3Settings(s3_bucket="test", tracking_mode="SQS") + assert settings.tracking_mode == TrackingMode.EVENT_DRIVEN + class TestSQLiteS3BackendInit: """Test SQLiteS3Backend accepts and stores BIP-0042 parameters.""" @@ -78,14 +85,14 @@ def test_backend_accepts_new_parameters(self): assert "s3_buffer_size_mb" in params def test_backend_has_event_driven_methods(self): - """Verify SQLiteS3Backend has BIP-0042 event-driven methods.""" + """Verify SQLiteS3Backend has event-driven methods.""" from burr.tracking.server.s3.backend import SQLiteS3Backend assert hasattr(SQLiteS3Backend, "_handle_s3_event") - assert hasattr(SQLiteS3Backend, "start_sqs_consumer") + assert hasattr(SQLiteS3Backend, "start_event_consumer") assert hasattr(SQLiteS3Backend, "is_event_driven") assert callable(getattr(SQLiteS3Backend, "_handle_s3_event")) - assert callable(getattr(SQLiteS3Backend, "start_sqs_consumer")) + assert callable(getattr(SQLiteS3Backend, "start_event_consumer")) assert callable(getattr(SQLiteS3Backend, "is_event_driven")) @@ -99,13 +106,13 @@ def test_mixin_exists(self): assert EventDrivenBackendMixin is not None def test_mixin_has_abstract_methods(self): - """Verify mixin has abstract start_sqs_consumer and is_event_driven.""" + """Verify mixin has abstract start_event_consumer and is_event_driven.""" import abc 
from burr.tracking.server.backend import EventDrivenBackendMixin assert issubclass(EventDrivenBackendMixin, abc.ABC) - assert hasattr(EventDrivenBackendMixin, "start_sqs_consumer") + assert hasattr(EventDrivenBackendMixin, "start_event_consumer") assert hasattr(EventDrivenBackendMixin, "is_event_driven") def test_sqlite_s3_backend_inherits_mixin(self): From 0b3901a8efd396321f90b263c39feac1addc049c Mon Sep 17 00:00:00 2001 From: vaquarkhan Date: Mon, 16 Mar 2026 00:50:00 -0500 Subject: [PATCH 6/8] feat: add Bedrock integration (BIP-0042) as separate PR --- burr/integrations/__init__.py | 11 + burr/integrations/bedrock.py | 261 +++++++++++++++++++++ burr/tracking/server/s3/backend.py | 94 ++++---- pyproject.toml | 7 +- tests/integrations/test_bip0042_bedrock.py | 168 +++++++++++++ 5 files changed, 500 insertions(+), 41 deletions(-) create mode 100644 burr/integrations/bedrock.py create mode 100644 tests/integrations/test_bip0042_bedrock.py diff --git a/burr/integrations/__init__.py b/burr/integrations/__init__.py index 9f5af4ccb..051fa1e10 100644 --- a/burr/integrations/__init__.py +++ b/burr/integrations/__init__.py @@ -14,3 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # in the License. + + +def __getattr__(name: str): + """Lazy load Bedrock integration to avoid requiring boto3 unless used.""" + if name == "BedrockAction": + from burr.integrations.bedrock import BedrockAction + return BedrockAction + if name == "BedrockStreamingAction": + from burr.integrations.bedrock import BedrockStreamingAction + return BedrockStreamingAction + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/burr/integrations/bedrock.py b/burr/integrations/bedrock.py new file mode 100644 index 000000000..36969feaf --- /dev/null +++ b/burr/integrations/bedrock.py @@ -0,0 +1,261 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Amazon Bedrock integration for Burr. + +This module provides Action classes for invoking Amazon Bedrock models +within Burr applications. 
+ +Example usage: + from burr.integrations.bedrock import BedrockAction + + def prompt_mapper(state): + return { + "messages": [{"role": "user", "content": state["user_input"]}], + "system": [{"text": "You are a helpful assistant."}], + } + + # With default client (created lazily on first use): + action = BedrockAction( + model_id="anthropic.claude-3-sonnet-20240229-v1:0", + input_mapper=prompt_mapper, + reads=["user_input"], + writes=["response"], + ) + + # With injected client (for tests or distributed execution): + # client = boto3.client("bedrock-runtime", region_name="us-east-1") + # action = BedrockAction(..., client=client) +""" + +import logging +from typing import Any, Generator, Optional, Protocol + +from burr.core.action import SingleStepAction, StreamingAction +from burr.core.state import State +from burr.integrations.base import require_plugin + +logger = logging.getLogger(__name__) + +# Type for injected Bedrock client (avoids boto3 import at type-check time) +BedrockClient = Any + +try: + import boto3 + from botocore.config import Config + from botocore.exceptions import ClientError +except ImportError as e: + require_plugin(e, "bedrock") + + +class StateToPromptMapper(Protocol): + """Protocol for mapping Burr state to Bedrock prompt format.""" + + def __call__(self, state: State) -> dict[str, Any]: + ... + + +class BedrockAction(SingleStepAction): + """Action that invokes Amazon Bedrock models using the Converse API.""" + + def __init__( + self, + model_id: str, + input_mapper: StateToPromptMapper, + reads: list[str], + writes: list[str], + name: str = "bedrock_invoke", + region: Optional[str] = None, + guardrail_id: Optional[str] = None, + guardrail_version: Optional[str] = None, + inference_config: Optional[dict[str, Any]] = None, + max_retries: int = 3, + client: Optional[BedrockClient] = None, + ): + super().__init__() + self._model_id = model_id + self._input_mapper = input_mapper + self._reads = reads + self._writes = writes + self._name = name + self._region = region + self._guardrail_id = guardrail_id + self._guardrail_version = guardrail_version or "DRAFT" + self._inference_config = inference_config or {"maxTokens": 4096} + self._max_retries = max_retries + self._client = client + + def _get_client(self) -> BedrockClient: + """Return the Bedrock client, creating lazily if not injected.""" + if self._client is not None: + return self._client + config = Config( + retries={"max_attempts": self._max_retries, "mode": "adaptive"} + ) + self._client = boto3.client( + "bedrock-runtime", region_name=self._region, config=config + ) + return self._client + + @property + def reads(self) -> list[str]: + return self._reads + + @property + def writes(self) -> list[str]: + return self._writes + + @property + def name(self) -> str: + return self._name + + def run_and_update(self, state: State, **run_kwargs) -> tuple[dict, State]: + prompt = self._input_mapper(state) + + request: dict[str, Any] = { + "modelId": self._model_id, + "messages": prompt["messages"], + "inferenceConfig": self._inference_config, + } + + if "system" in prompt: + request["system"] = prompt["system"] + + if self._guardrail_id: + request["guardrailConfig"] = { + "guardrailIdentifier": self._guardrail_id, + "guardrailVersion": self._guardrail_version, + } + + try: + response = self._get_client().converse(**request) + except ClientError as e: + logger.error("Bedrock API error: %s", e) + raise + + output_message = response["output"]["message"] + content_blocks = output_message.get("content", []) + text = 
content_blocks[0]["text"] if content_blocks else "" + + result: dict[str, Any] = { + "response": text, + "usage": response.get("usage", {}), + "stop_reason": response.get("stopReason"), + } + + updates = {key: result[key] for key in self._writes if key in result} + new_state = state.update(**updates) + + return result, new_state + + +class BedrockStreamingAction(StreamingAction): + """Streaming variant of BedrockAction using Converse Stream API.""" + + def __init__( + self, + model_id: str, + input_mapper: StateToPromptMapper, + reads: list[str], + writes: list[str], + name: str = "bedrock_stream", + region: Optional[str] = None, + guardrail_id: Optional[str] = None, + guardrail_version: Optional[str] = None, + inference_config: Optional[dict[str, Any]] = None, + max_retries: int = 3, + client: Optional[BedrockClient] = None, + ): + super().__init__() + self._model_id = model_id + self._input_mapper = input_mapper + self._reads = reads + self._writes = writes + self._name = name + self._region = region + self._guardrail_id = guardrail_id + self._guardrail_version = guardrail_version or "DRAFT" + self._inference_config = inference_config or {"maxTokens": 4096} + self._max_retries = max_retries + self._client = client + + def _get_client(self) -> BedrockClient: + """Return the Bedrock client, creating lazily if not injected.""" + if self._client is not None: + return self._client + config = Config( + retries={"max_attempts": self._max_retries, "mode": "adaptive"} + ) + self._client = boto3.client( + "bedrock-runtime", region_name=self._region, config=config + ) + return self._client + + @property + def reads(self) -> list[str]: + return self._reads + + @property + def writes(self) -> list[str]: + return self._writes + + @property + def name(self) -> str: + return self._name + + def stream_run( + self, state: State, **run_kwargs + ) -> Generator[dict, None, None]: + prompt = self._input_mapper(state) + + request: dict[str, Any] = { + "modelId": self._model_id, + "messages": prompt["messages"], + "inferenceConfig": self._inference_config, + } + + if "system" in prompt: + request["system"] = prompt["system"] + + if self._guardrail_id: + request["guardrailConfig"] = { + "guardrailIdentifier": self._guardrail_id, + "guardrailVersion": self._guardrail_version, + } + + try: + response = self._get_client().converse_stream(**request) + except ClientError as e: + logger.error("Bedrock streaming API error: %s", e) + raise + + full_response = "" + stream = response.get("stream", []) + for event in stream: + if "contentBlockDelta" in event: + chunk = event["contentBlockDelta"]["delta"].get("text", "") + full_response += chunk + yield {"chunk": chunk, "response": full_response} + + yield {"chunk": "", "response": full_response, "complete": True} + + def update(self, result: dict, state: State) -> State: + if result.get("complete"): + updates = {"response": result.get("response", "")} + filtered = {k: v for k, v in updates.items() if k in self._writes} + return state.update(**filtered) + return state diff --git a/burr/tracking/server/s3/backend.py b/burr/tracking/server/s3/backend.py index 681f3b953..04c1051cd 100644 --- a/burr/tracking/server/s3/backend.py +++ b/burr/tracking/server/s3/backend.py @@ -28,7 +28,7 @@ import tempfile import uuid from collections import Counter -from typing import List, Literal, Optional, Sequence, Tuple, Type, TypeVar, Union +from typing import Literal, Optional, Sequence, Type, TypeVar, Union import fastapi import pydantic @@ -281,7 +281,7 @@ async def snapshot(self): s3_key 
= f"{self._snapshot_prefix}/{timestamp}/{self._backend_id}/snapshot.db" # TODO -- copy the path at snapshot_path to s3 using aiobotocore session = self._session - logger.info(f"Saving db snapshot at: {s3_key}") + logger.info("Saving db snapshot at: %s", s3_key) async with session.create_client("s3") as s3_client: with open(path, "rb") as file_data: await s3_client.put_object(Bucket=self._bucket, Key=s3_key, Body=file_data) @@ -289,7 +289,7 @@ async def snapshot(self): self._snapshot_key_history.append(s3_key) if len(self._snapshot_key_history) > 5: old_snapshot_to_remove = self._snapshot_key_history.pop(0) - logger.info(f"Removing old snapshot: {old_snapshot_to_remove}") + logger.info("Removing old snapshot: %s", old_snapshot_to_remove) async with session.create_client("s3") as s3_client: await s3_client.delete_object(Bucket=self._bucket, Key=old_snapshot_to_remove) @@ -313,7 +313,7 @@ async def _s3_get_first_write_date(self, project_id: str): async def _update_projects(self): current_projects = await Project.all() project_names = {project.name for project in current_projects} - logger.info(f"Current projects: {project_names}") + logger.info("Current projects: %s", project_names) async with self._session.create_client("s3") as client: paginator = client.get_paginator("list_objects_v2") async for result in paginator.paginate( @@ -323,7 +323,7 @@ async def _update_projects(self): project_name = prefix.get("Prefix").split("/")[-2] if project_name not in project_names: now = system.now() - logger.info(f"Creating project: {project_name}") + logger.info("Creating project: %s", project_name) await Project.create( name=project_name, uri=None, @@ -349,7 +349,7 @@ async def query_applications_by_key( async def _gather_metadata_files( self, - metadata_files: List[DataFile], + metadata_files: list[DataFile], ) -> Sequence[dict]: """Gives a list of metadata files so we can update the application""" @@ -378,7 +378,7 @@ async def _query_metadata_file(metadata_file: DataFile) -> dict: ) return out - async def _gather_log_file_data(self, log_files: List[DataFile]) -> Sequence[dict]: + async def _gather_log_file_data(self, log_files: list[DataFile]) -> Sequence[dict]: """Gives a list of log files so we can update the application""" async def _query_log_file(log_file: DataFile) -> dict: @@ -410,9 +410,9 @@ async def _gather_paths_to_update( :return: list of paths to update """ - logger.info(f"Scanning db with highwatermark: {high_watermark_s3_path}") + logger.info("Scanning db with highwatermark: %s", high_watermark_s3_path) paths_to_update = [] - logger.info(f"Scanning log data for project: {project.name}") + logger.info("Scanning log data for project: %s", project.name) async with self._session.create_client("s3") as client: paginator = client.get_paginator("list_objects_v2") async for result in paginator.paginate( @@ -424,11 +424,11 @@ async def _gather_paths_to_update( key = content["Key"] last_modified = content["LastModified"] # Created == last_modified as we have an immutable data model - logger.info(f"Found new file: {key}") + logger.info("Found new file: %s", key) paths_to_update.append(DataFile.from_path(key, created_date=last_modified)) if len(paths_to_update) >= max_paths: break - logger.info(f"Found {len(paths_to_update)} new files to index") + logger.info("Found %s new files to index", len(paths_to_update)) return paths_to_update async def _ensure_applications_exist( @@ -445,10 +445,13 @@ async def _ensure_applications_exist( ) counter = Counter([path.file_type for path in paths_to_update]) 
logger.info( - f"Found {len(all_application_keys)} applications in the scan, " - f"including: {counter['log']} log files, " - f"{counter['metadata']} metadata files, and {counter['graph']} graph files, " - f"and {len(paths_to_update) - len(all_application_keys)} other files." + "Found %s applications in the scan, including: %s log files, " + "%s metadata files, and %s graph files, and %s other files.", + len(all_application_keys), + counter["log"], + counter["metadata"], + counter["graph"], + len(paths_to_update) - len(all_application_keys), ) # First, let's create all applications, ignoring them if they exist @@ -472,7 +475,9 @@ async def _ensure_applications_exist( ] logger.info( - f"Creating {len(apps_to_create)} new applications, with keys: {[(app.name, app.partition_key) for app in apps_to_create]}" + "Creating %s new applications, with keys: %s", + len(apps_to_create), + [(app.name, app.partition_key) for app in apps_to_create], ) await Application.bulk_create(apps_to_create) all_applications = await self.query_applications_by_key(all_application_keys) @@ -487,7 +492,7 @@ async def _update_all_applications( :param paths_to_update: All paths to update :return: """ - logger.info(f"found: {len(all_applications)} applications to update in the db") + logger.info("found: %s applications to update in the db", len(all_applications)) metadata_data = [path for path in paths_to_update if path.file_type == "metadata"] graph_data = [path for path in paths_to_update if path.file_type == "graph"] metadata_objects = await self._gather_metadata_files(metadata_data) @@ -550,7 +555,7 @@ async def _update_high_watermark( async def _scan_and_update_db_for_project( self, project: Project, indexing_job: IndexingJob - ) -> Tuple[IndexStatus, int]: + ) -> tuple[IndexStatus, int]: """Scans and updates the database for a project. 
TODO -- break this up into functions @@ -565,7 +570,7 @@ async def _scan_and_update_db_for_project( ) # This way we can sort by the latest captured time high_watermark = current_status.s3_highwatermark if current_status is not None else "" - logger.info(f"Scanning db with highwatermark: {high_watermark}") + logger.info("Scanning db with highwatermark: %s", high_watermark) paths_to_update = await self._gather_paths_to_update( project=project, high_watermark_s3_path=high_watermark ) @@ -590,7 +595,7 @@ async def _scan_and_update_db(self): # TODO -- add error catching status, num_files = await self._scan_and_update_db_for_project(project, indexing_job) - logger.info(f"Scanned: {num_files} files with status stored at ID={status.id}") + logger.info("Scanned: %s files with status stored at ID=%s", num_files, status.id) indexing_job.records_processed = num_files indexing_job.end_time = system.now() @@ -637,7 +642,7 @@ async def list_apps( partition_key: Optional[str], limit: int = 100, offset: int = 0, - ) -> Tuple[Sequence[schema.ApplicationSummary], int]: + ) -> tuple[Sequence[schema.ApplicationSummary], int]: # TODO -- distinctify between project name and project ID # Currently they're the same in the UI but we'll want to have them decoupled app_query = ( @@ -745,7 +750,7 @@ async def _handle_s3_event(self, s3_key: str, event_time: datetime.datetime) -> project = await Project.filter(name=project_name).first() if project is None: - logger.info(f"Creating project {project_name} from S3 event") + logger.info("Creating project %s from S3 event", project_name) project = await Project.create( name=project_name, uri=None, @@ -758,9 +763,9 @@ async def _handle_s3_event(self, s3_key: str, event_time: datetime.datetime) -> await self._update_all_applications(all_applications, [data_file]) await self.update_log_files([data_file], all_applications) - logger.info(f"Indexed S3 event: {s3_key}") + logger.info("Indexed S3 event: %s", s3_key) except Exception as e: - logger.error(f"Failed to handle S3 event {s3_key}: {e}") + logger.error("Failed to handle S3 event %s: %s", s3_key, e) raise # Re-raise so message stays in queue for retry / DLQ async def start_event_consumer(self) -> None: @@ -773,7 +778,7 @@ async def start_event_consumer(self) -> None: logger.info("Event consumer not configured, skipping") return - logger.info(f"Starting event consumer for queue: {self._sqs_queue_url}") + logger.info("Starting event consumer for queue: %s", self._sqs_queue_url) async with self._session.create_client("sqs", region_name=self._sqs_region) as sqs_client: try: @@ -793,34 +798,43 @@ async def start_event_consumer(self) -> None: s3_key = None event_time = None - # Handle EventBridge wrapped S3 events + # Handle EventBridge wrapped S3 events (one event per message) if "detail" in body: - s3_key = body["detail"]["object"]["key"] - event_time = datetime.datetime.fromisoformat( - body["time"].replace("Z", "+00:00") - ) + s3_keys_with_times = [ + ( + body["detail"]["object"]["key"], + datetime.datetime.fromisoformat( + body["time"].replace("Z", "+00:00") + ), + ) + ] elif "Records" in body: - record = body["Records"][0] - s3_key = record["s3"]["object"]["key"] - event_time = datetime.datetime.fromisoformat( - record["eventTime"].replace("Z", "+00:00") - ) + s3_keys_with_times = [ + ( + record["s3"]["object"]["key"], + datetime.datetime.fromisoformat( + record["eventTime"].replace("Z", "+00:00") + ), + ) + for record in body["Records"] + ] else: - logger.warning(f"Unknown message format: {body}") + logger.warning("Unknown 
message format: %s", body) continue - if s3_key and s3_key.endswith(".jsonl"): - await self._handle_s3_event(s3_key, event_time) + for s3_key, event_time in s3_keys_with_times: + if s3_key and s3_key.endswith(".jsonl"): + await self._handle_s3_event(s3_key, event_time) await sqs_client.delete_message( QueueUrl=self._sqs_queue_url, ReceiptHandle=message["ReceiptHandle"], ) except Exception as e: - logger.error(f"Failed to process SQS message: {e}") + logger.error("Failed to process SQS message: %s", e) except Exception as e: - logger.error(f"Event consumer error: {e}") + logger.error("Event consumer error: %s", e) await asyncio.sleep(5) except asyncio.CancelledError: logger.info("Event consumer shutting down") diff --git a/pyproject.toml b/pyproject.toml index 3cb698ede..4704cfb2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,8 @@ tests = [ "apache-burr[redis]", "apache-burr[opentelemetry]", "apache-burr[haystack]", - "apache-burr[ray]" + "apache-burr[ray]", + "apache-burr[bedrock]" ] documentation = [ @@ -128,6 +129,10 @@ tracking-client-s3 = [ "boto3" ] +bedrock = [ + "boto3" +] + tracking-server-s3 = [ "aerich", "aiobotocore", diff --git a/tests/integrations/test_bip0042_bedrock.py b/tests/integrations/test_bip0042_bedrock.py new file mode 100644 index 000000000..c0c0522d0 --- /dev/null +++ b/tests/integrations/test_bip0042_bedrock.py @@ -0,0 +1,168 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
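The consumer change above now fans out over every record in a direct S3 notification instead of indexing only the first one. Extracted as a pure function, that parsing logic looks roughly like this — an illustrative restatement of the patch's inline code, not a helper the patch itself adds:

```python
import datetime
from typing import List, Tuple


def parse_s3_event_keys(body: dict) -> List[Tuple[str, datetime.datetime]]:
    """Return (s3_key, event_time) pairs from one SQS message body.

    Mirrors the branches in start_event_consumer: EventBridge wraps a single
    object under "detail", while direct S3 notifications may batch several
    entries under "Records". Unknown formats yield an empty list.
    """
    if "detail" in body:
        return [
            (
                body["detail"]["object"]["key"],
                datetime.datetime.fromisoformat(body["time"].replace("Z", "+00:00")),
            )
        ]
    if "Records" in body:
        return [
            (
                record["s3"]["object"]["key"],
                datetime.datetime.fromisoformat(
                    record["eventTime"].replace("Z", "+00:00")
                ),
            )
            for record in body["Records"]
        ]
    return []
```

Only keys ending in `.jsonl` are then handed to `_handle_s3_event`, and the message is deleted afterwards, so a failed handler re-raises and the message surfaces again via the queue's visibility timeout and redrive policy.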
+ +"""Tests for Bedrock integration.""" + +import inspect +from unittest.mock import MagicMock + +import pytest + +boto3 = pytest.importorskip("boto3", reason="boto3 required for Bedrock tests") + + +class TestBedrockImports: + """Test that Bedrock classes can be imported via lazy loading.""" + + def test_lazy_import_bedrock_action(self): + """Verify BedrockAction can be imported from burr.integrations.""" + from burr.integrations import BedrockAction + + assert BedrockAction is not None + + def test_lazy_import_bedrock_streaming_action(self): + """Verify BedrockStreamingAction can be imported from burr.integrations.""" + from burr.integrations import BedrockStreamingAction + + assert BedrockStreamingAction is not None + + def test_direct_import_bedrock_module(self): + """Verify bedrock.py module exists and has expected classes.""" + from burr.integrations.bedrock import ( + BedrockAction, + BedrockStreamingAction, + StateToPromptMapper, + ) + + assert BedrockAction is not None + assert BedrockStreamingAction is not None + assert StateToPromptMapper is not None + + +class TestBedrockActionInterface: + """Test BedrockAction class interface with mocked boto3.""" + + def test_bedrock_action_extends_single_step_action(self): + """Verify BedrockAction extends SingleStepAction.""" + from burr.core.action import SingleStepAction + from burr.integrations.bedrock import BedrockAction + + assert issubclass(BedrockAction, SingleStepAction) + + def test_bedrock_streaming_action_extends_streaming_action(self): + """Verify BedrockStreamingAction extends StreamingAction.""" + from burr.core.action import StreamingAction + from burr.integrations.bedrock import BedrockStreamingAction + + assert issubclass(BedrockStreamingAction, StreamingAction) + + def test_bedrock_action_has_required_properties(self): + """Verify BedrockAction has reads, writes, name properties.""" + from burr.integrations.bedrock import BedrockAction + + action = BedrockAction( + model_id="test-model", + input_mapper=lambda s: {"messages": []}, + reads=["input"], + writes=["output"], + ) + assert action.reads == ["input"] + assert action.writes == ["output"] + assert action.name == "bedrock_invoke" + + def test_bedrock_action_accepts_all_parameters(self): + """Verify BedrockAction accepts all specified parameters.""" + from burr.integrations.bedrock import BedrockAction + + sig = inspect.signature(BedrockAction.__init__) + params = list(sig.parameters.keys()) + assert "model_id" in params + assert "input_mapper" in params + assert "reads" in params + assert "writes" in params + assert "name" in params + assert "region" in params + assert "guardrail_id" in params + assert "guardrail_version" in params + assert "inference_config" in params + assert "max_retries" in params + assert "client" in params + + def test_bedrock_action_uses_injected_client(self): + """Verify BedrockAction uses injected client when provided.""" + from burr.integrations.bedrock import BedrockAction + + mock_client = MagicMock() + mock_client.converse.return_value = { + "output": {"message": {"content": [{"text": "hi"}]}}, + "usage": {}, + "stopReason": "end_turn", + } + + action = BedrockAction( + model_id="test-model", + input_mapper=lambda s: {"messages": [{"role": "user", "content": "hi"}]}, + reads=[], + writes=["response"], + client=mock_client, + ) + + result, _ = action.run_and_update({}) + assert result["response"] == "hi" + mock_client.converse.assert_called_once() + + +class TestBedrockStreamingActionInterface: + """Test BedrockStreamingAction class 
interface with mocked boto3.""" + + def test_bedrock_streaming_action_uses_injected_client(self): + """Verify BedrockStreamingAction uses injected client when provided.""" + from burr.integrations.bedrock import BedrockStreamingAction + + mock_client = MagicMock() + mock_client.converse_stream.return_value = { + "stream": [ + {"contentBlockDelta": {"delta": {"text": "hello "}}}, + {"contentBlockDelta": {"delta": {"text": "world"}}}, + ] + } + + action = BedrockStreamingAction( + model_id="test-model", + input_mapper=lambda s: {"messages": [{"role": "user", "content": "hi"}]}, + reads=[], + writes=["response"], + client=mock_client, + ) + + chunks = list(action.stream_run({})) + assert len(chunks) == 3 # 2 content chunks + 1 complete + assert chunks[0]["chunk"] == "hello " + assert chunks[1]["chunk"] == "world" + assert chunks[2]["complete"] is True + assert chunks[2]["response"] == "hello world" + mock_client.converse_stream.assert_called_once() + + +class TestStateToPromptMapperProtocol: + """Test StateToPromptMapper Protocol exists.""" + + def test_protocol_exists(self): + """Verify StateToPromptMapper Protocol is defined.""" + from burr.integrations.bedrock import StateToPromptMapper + + assert StateToPromptMapper is not None From b4bba67bd815eb311115fb3bfe293d455956fdd4 Mon Sep 17 00:00:00 2001 From: vaquarkhan Date: Mon, 16 Mar 2026 02:41:32 -0500 Subject: [PATCH 7/8] Revert "feat: add Bedrock integration (BIP-0042) as separate PR" This reverts commit 61673cdfbc79f0c6bcc444594173042de186fab1. --- burr/integrations/__init__.py | 11 - burr/integrations/bedrock.py | 261 --------------------- burr/tracking/server/s3/backend.py | 94 ++++---- pyproject.toml | 7 +- tests/integrations/test_bip0042_bedrock.py | 168 ------------- 5 files changed, 41 insertions(+), 500 deletions(-) delete mode 100644 burr/integrations/bedrock.py delete mode 100644 tests/integrations/test_bip0042_bedrock.py diff --git a/burr/integrations/__init__.py b/burr/integrations/__init__.py index 051fa1e10..9f5af4ccb 100644 --- a/burr/integrations/__init__.py +++ b/burr/integrations/__init__.py @@ -14,14 +14,3 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # in the License. - - -def __getattr__(name: str): - """Lazy load Bedrock integration to avoid requiring boto3 unless used.""" - if name == "BedrockAction": - from burr.integrations.bedrock import BedrockAction - return BedrockAction - if name == "BedrockStreamingAction": - from burr.integrations.bedrock import BedrockStreamingAction - return BedrockStreamingAction - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/burr/integrations/bedrock.py b/burr/integrations/bedrock.py deleted file mode 100644 index 36969feaf..000000000 --- a/burr/integrations/bedrock.py +++ /dev/null @@ -1,261 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Amazon Bedrock integration for Burr. - -This module provides Action classes for invoking Amazon Bedrock models -within Burr applications. - -Example usage: - from burr.integrations.bedrock import BedrockAction - - def prompt_mapper(state): - return { - "messages": [{"role": "user", "content": state["user_input"]}], - "system": [{"text": "You are a helpful assistant."}], - } - - # With default client (created lazily on first use): - action = BedrockAction( - model_id="anthropic.claude-3-sonnet-20240229-v1:0", - input_mapper=prompt_mapper, - reads=["user_input"], - writes=["response"], - ) - - # With injected client (for tests or distributed execution): - # client = boto3.client("bedrock-runtime", region_name="us-east-1") - # action = BedrockAction(..., client=client) -""" - -import logging -from typing import Any, Generator, Optional, Protocol - -from burr.core.action import SingleStepAction, StreamingAction -from burr.core.state import State -from burr.integrations.base import require_plugin - -logger = logging.getLogger(__name__) - -# Type for injected Bedrock client (avoids boto3 import at type-check time) -BedrockClient = Any - -try: - import boto3 - from botocore.config import Config - from botocore.exceptions import ClientError -except ImportError as e: - require_plugin(e, "bedrock") - - -class StateToPromptMapper(Protocol): - """Protocol for mapping Burr state to Bedrock prompt format.""" - - def __call__(self, state: State) -> dict[str, Any]: - ... 
- - -class BedrockAction(SingleStepAction): - """Action that invokes Amazon Bedrock models using the Converse API.""" - - def __init__( - self, - model_id: str, - input_mapper: StateToPromptMapper, - reads: list[str], - writes: list[str], - name: str = "bedrock_invoke", - region: Optional[str] = None, - guardrail_id: Optional[str] = None, - guardrail_version: Optional[str] = None, - inference_config: Optional[dict[str, Any]] = None, - max_retries: int = 3, - client: Optional[BedrockClient] = None, - ): - super().__init__() - self._model_id = model_id - self._input_mapper = input_mapper - self._reads = reads - self._writes = writes - self._name = name - self._region = region - self._guardrail_id = guardrail_id - self._guardrail_version = guardrail_version or "DRAFT" - self._inference_config = inference_config or {"maxTokens": 4096} - self._max_retries = max_retries - self._client = client - - def _get_client(self) -> BedrockClient: - """Return the Bedrock client, creating lazily if not injected.""" - if self._client is not None: - return self._client - config = Config( - retries={"max_attempts": self._max_retries, "mode": "adaptive"} - ) - self._client = boto3.client( - "bedrock-runtime", region_name=self._region, config=config - ) - return self._client - - @property - def reads(self) -> list[str]: - return self._reads - - @property - def writes(self) -> list[str]: - return self._writes - - @property - def name(self) -> str: - return self._name - - def run_and_update(self, state: State, **run_kwargs) -> tuple[dict, State]: - prompt = self._input_mapper(state) - - request: dict[str, Any] = { - "modelId": self._model_id, - "messages": prompt["messages"], - "inferenceConfig": self._inference_config, - } - - if "system" in prompt: - request["system"] = prompt["system"] - - if self._guardrail_id: - request["guardrailConfig"] = { - "guardrailIdentifier": self._guardrail_id, - "guardrailVersion": self._guardrail_version, - } - - try: - response = self._get_client().converse(**request) - except ClientError as e: - logger.error("Bedrock API error: %s", e) - raise - - output_message = response["output"]["message"] - content_blocks = output_message.get("content", []) - text = content_blocks[0]["text"] if content_blocks else "" - - result: dict[str, Any] = { - "response": text, - "usage": response.get("usage", {}), - "stop_reason": response.get("stopReason"), - } - - updates = {key: result[key] for key in self._writes if key in result} - new_state = state.update(**updates) - - return result, new_state - - -class BedrockStreamingAction(StreamingAction): - """Streaming variant of BedrockAction using Converse Stream API.""" - - def __init__( - self, - model_id: str, - input_mapper: StateToPromptMapper, - reads: list[str], - writes: list[str], - name: str = "bedrock_stream", - region: Optional[str] = None, - guardrail_id: Optional[str] = None, - guardrail_version: Optional[str] = None, - inference_config: Optional[dict[str, Any]] = None, - max_retries: int = 3, - client: Optional[BedrockClient] = None, - ): - super().__init__() - self._model_id = model_id - self._input_mapper = input_mapper - self._reads = reads - self._writes = writes - self._name = name - self._region = region - self._guardrail_id = guardrail_id - self._guardrail_version = guardrail_version or "DRAFT" - self._inference_config = inference_config or {"maxTokens": 4096} - self._max_retries = max_retries - self._client = client - - def _get_client(self) -> BedrockClient: - """Return the Bedrock client, creating lazily if not injected.""" - 
if self._client is not None: - return self._client - config = Config( - retries={"max_attempts": self._max_retries, "mode": "adaptive"} - ) - self._client = boto3.client( - "bedrock-runtime", region_name=self._region, config=config - ) - return self._client - - @property - def reads(self) -> list[str]: - return self._reads - - @property - def writes(self) -> list[str]: - return self._writes - - @property - def name(self) -> str: - return self._name - - def stream_run( - self, state: State, **run_kwargs - ) -> Generator[dict, None, None]: - prompt = self._input_mapper(state) - - request: dict[str, Any] = { - "modelId": self._model_id, - "messages": prompt["messages"], - "inferenceConfig": self._inference_config, - } - - if "system" in prompt: - request["system"] = prompt["system"] - - if self._guardrail_id: - request["guardrailConfig"] = { - "guardrailIdentifier": self._guardrail_id, - "guardrailVersion": self._guardrail_version, - } - - try: - response = self._get_client().converse_stream(**request) - except ClientError as e: - logger.error("Bedrock streaming API error: %s", e) - raise - - full_response = "" - stream = response.get("stream", []) - for event in stream: - if "contentBlockDelta" in event: - chunk = event["contentBlockDelta"]["delta"].get("text", "") - full_response += chunk - yield {"chunk": chunk, "response": full_response} - - yield {"chunk": "", "response": full_response, "complete": True} - - def update(self, result: dict, state: State) -> State: - if result.get("complete"): - updates = {"response": result.get("response", "")} - filtered = {k: v for k, v in updates.items() if k in self._writes} - return state.update(**filtered) - return state diff --git a/burr/tracking/server/s3/backend.py b/burr/tracking/server/s3/backend.py index 04c1051cd..681f3b953 100644 --- a/burr/tracking/server/s3/backend.py +++ b/burr/tracking/server/s3/backend.py @@ -28,7 +28,7 @@ import tempfile import uuid from collections import Counter -from typing import Literal, Optional, Sequence, Type, TypeVar, Union +from typing import List, Literal, Optional, Sequence, Tuple, Type, TypeVar, Union import fastapi import pydantic @@ -281,7 +281,7 @@ async def snapshot(self): s3_key = f"{self._snapshot_prefix}/{timestamp}/{self._backend_id}/snapshot.db" # TODO -- copy the path at snapshot_path to s3 using aiobotocore session = self._session - logger.info("Saving db snapshot at: %s", s3_key) + logger.info(f"Saving db snapshot at: {s3_key}") async with session.create_client("s3") as s3_client: with open(path, "rb") as file_data: await s3_client.put_object(Bucket=self._bucket, Key=s3_key, Body=file_data) @@ -289,7 +289,7 @@ async def snapshot(self): self._snapshot_key_history.append(s3_key) if len(self._snapshot_key_history) > 5: old_snapshot_to_remove = self._snapshot_key_history.pop(0) - logger.info("Removing old snapshot: %s", old_snapshot_to_remove) + logger.info(f"Removing old snapshot: {old_snapshot_to_remove}") async with session.create_client("s3") as s3_client: await s3_client.delete_object(Bucket=self._bucket, Key=old_snapshot_to_remove) @@ -313,7 +313,7 @@ async def _s3_get_first_write_date(self, project_id: str): async def _update_projects(self): current_projects = await Project.all() project_names = {project.name for project in current_projects} - logger.info("Current projects: %s", project_names) + logger.info(f"Current projects: {project_names}") async with self._session.create_client("s3") as client: paginator = client.get_paginator("list_objects_v2") async for result in paginator.paginate( @@ 
-323,7 +323,7 @@ async def _update_projects(self): project_name = prefix.get("Prefix").split("/")[-2] if project_name not in project_names: now = system.now() - logger.info("Creating project: %s", project_name) + logger.info(f"Creating project: {project_name}") await Project.create( name=project_name, uri=None, @@ -349,7 +349,7 @@ async def query_applications_by_key( async def _gather_metadata_files( self, - metadata_files: list[DataFile], + metadata_files: List[DataFile], ) -> Sequence[dict]: """Gives a list of metadata files so we can update the application""" @@ -378,7 +378,7 @@ async def _query_metadata_file(metadata_file: DataFile) -> dict: ) return out - async def _gather_log_file_data(self, log_files: list[DataFile]) -> Sequence[dict]: + async def _gather_log_file_data(self, log_files: List[DataFile]) -> Sequence[dict]: """Gives a list of log files so we can update the application""" async def _query_log_file(log_file: DataFile) -> dict: @@ -410,9 +410,9 @@ async def _gather_paths_to_update( :return: list of paths to update """ - logger.info("Scanning db with highwatermark: %s", high_watermark_s3_path) + logger.info(f"Scanning db with highwatermark: {high_watermark_s3_path}") paths_to_update = [] - logger.info("Scanning log data for project: %s", project.name) + logger.info(f"Scanning log data for project: {project.name}") async with self._session.create_client("s3") as client: paginator = client.get_paginator("list_objects_v2") async for result in paginator.paginate( @@ -424,11 +424,11 @@ async def _gather_paths_to_update( key = content["Key"] last_modified = content["LastModified"] # Created == last_modified as we have an immutable data model - logger.info("Found new file: %s", key) + logger.info(f"Found new file: {key}") paths_to_update.append(DataFile.from_path(key, created_date=last_modified)) if len(paths_to_update) >= max_paths: break - logger.info("Found %s new files to index", len(paths_to_update)) + logger.info(f"Found {len(paths_to_update)} new files to index") return paths_to_update async def _ensure_applications_exist( @@ -445,13 +445,10 @@ async def _ensure_applications_exist( ) counter = Counter([path.file_type for path in paths_to_update]) logger.info( - "Found %s applications in the scan, including: %s log files, " - "%s metadata files, and %s graph files, and %s other files.", - len(all_application_keys), - counter["log"], - counter["metadata"], - counter["graph"], - len(paths_to_update) - len(all_application_keys), + f"Found {len(all_application_keys)} applications in the scan, " + f"including: {counter['log']} log files, " + f"{counter['metadata']} metadata files, and {counter['graph']} graph files, " + f"and {len(paths_to_update) - len(all_application_keys)} other files." 
         )
         # First, let's create all applications, ignoring them if they exist
@@ -475,9 +472,7 @@ async def _ensure_applications_exist(
         ]
         logger.info(
-            "Creating %s new applications, with keys: %s",
-            len(apps_to_create),
-            [(app.name, app.partition_key) for app in apps_to_create],
+            f"Creating {len(apps_to_create)} new applications, with keys: {[(app.name, app.partition_key) for app in apps_to_create]}"
         )
         await Application.bulk_create(apps_to_create)
         all_applications = await self.query_applications_by_key(all_application_keys)
@@ -492,7 +487,7 @@ async def _update_all_applications(
 
         :param paths_to_update: All paths to update
         :return:
         """
-        logger.info("found: %s applications to update in the db", len(all_applications))
+        logger.info(f"found: {len(all_applications)} applications to update in the db")
         metadata_data = [path for path in paths_to_update if path.file_type == "metadata"]
         graph_data = [path for path in paths_to_update if path.file_type == "graph"]
         metadata_objects = await self._gather_metadata_files(metadata_data)
@@ -555,7 +550,7 @@ async def _update_high_watermark(
 
     async def _scan_and_update_db_for_project(
         self, project: Project, indexing_job: IndexingJob
-    ) -> tuple[IndexStatus, int]:
+    ) -> Tuple[IndexStatus, int]:
         """Scans and updates the database for a project.
 
         TODO -- break this up into functions
@@ -570,7 +565,7 @@ async def _scan_and_update_db_for_project(
         )  # This way we can sort by the latest captured time
         high_watermark = current_status.s3_highwatermark if current_status is not None else ""
-        logger.info("Scanning db with highwatermark: %s", high_watermark)
+        logger.info(f"Scanning db with highwatermark: {high_watermark}")
         paths_to_update = await self._gather_paths_to_update(
             project=project, high_watermark_s3_path=high_watermark
         )
@@ -595,7 +590,7 @@ async def _scan_and_update_db(self):
 
                 # TODO -- add error catching
                 status, num_files = await self._scan_and_update_db_for_project(project, indexing_job)
-                logger.info("Scanned: %s files with status stored at ID=%s", num_files, status.id)
+                logger.info(f"Scanned: {num_files} files with status stored at ID={status.id}")
 
                 indexing_job.records_processed = num_files
                 indexing_job.end_time = system.now()
@@ -642,7 +637,7 @@ async def list_apps(
         partition_key: Optional[str],
         limit: int = 100,
         offset: int = 0,
-    ) -> tuple[Sequence[schema.ApplicationSummary], int]:
+    ) -> Tuple[Sequence[schema.ApplicationSummary], int]:
         # TODO -- distinctify between project name and project ID
         # Currently they're the same in the UI but we'll want to have them decoupled
         app_query = (
@@ -750,7 +745,7 @@ async def _handle_s3_event(self, s3_key: str, event_time: datetime.datetime) ->
 
             project = await Project.filter(name=project_name).first()
             if project is None:
-                logger.info("Creating project %s from S3 event", project_name)
+                logger.info(f"Creating project {project_name} from S3 event")
                 project = await Project.create(
                     name=project_name,
                     uri=None,
@@ -763,9 +758,9 @@ async def _handle_s3_event(self, s3_key: str, event_time: datetime.datetime) ->
             await self._update_all_applications(all_applications, [data_file])
             await self.update_log_files([data_file], all_applications)
 
-            logger.info("Indexed S3 event: %s", s3_key)
+            logger.info(f"Indexed S3 event: {s3_key}")
         except Exception as e:
-            logger.error("Failed to handle S3 event %s: %s", s3_key, e)
+            logger.error(f"Failed to handle S3 event {s3_key}: {e}")
             raise  # Re-raise so message stays in queue for retry / DLQ
 
     async def start_event_consumer(self) -> None:
@@ -778,7 +773,7 @@ async def start_event_consumer(self) -> None:
             logger.info("Event consumer not configured, skipping")
             return
 
-        logger.info("Starting event consumer for queue: %s", self._sqs_queue_url)
+        logger.info(f"Starting event consumer for queue: {self._sqs_queue_url}")
 
         async with self._session.create_client("sqs", region_name=self._sqs_region) as sqs_client:
             try:
@@ -798,43 +793,34 @@ async def start_event_consumer(self) -> None:
                         s3_key = None
                         event_time = None
 
-                        # Handle EventBridge wrapped S3 events (one event per message)
+                        # Handle EventBridge wrapped S3 events
                         if "detail" in body:
-                            s3_keys_with_times = [
-                                (
-                                    body["detail"]["object"]["key"],
-                                    datetime.datetime.fromisoformat(
-                                        body["time"].replace("Z", "+00:00")
-                                    ),
-                                )
-                            ]
+                            s3_key = body["detail"]["object"]["key"]
+                            event_time = datetime.datetime.fromisoformat(
+                                body["time"].replace("Z", "+00:00")
+                            )
                         elif "Records" in body:
-                            s3_keys_with_times = [
-                                (
-                                    record["s3"]["object"]["key"],
-                                    datetime.datetime.fromisoformat(
-                                        record["eventTime"].replace("Z", "+00:00")
-                                    ),
-                                )
-                                for record in body["Records"]
-                            ]
+                            record = body["Records"][0]
+                            s3_key = record["s3"]["object"]["key"]
+                            event_time = datetime.datetime.fromisoformat(
+                                record["eventTime"].replace("Z", "+00:00")
+                            )
                         else:
-                            logger.warning("Unknown message format: %s", body)
+                            logger.warning(f"Unknown message format: {body}")
                             continue
 
-                        for s3_key, event_time in s3_keys_with_times:
-                            if s3_key and s3_key.endswith(".jsonl"):
-                                await self._handle_s3_event(s3_key, event_time)
+                        if s3_key and s3_key.endswith(".jsonl"):
+                            await self._handle_s3_event(s3_key, event_time)
 
                         await sqs_client.delete_message(
                             QueueUrl=self._sqs_queue_url,
                             ReceiptHandle=message["ReceiptHandle"],
                         )
                     except Exception as e:
-                        logger.error("Failed to process SQS message: %s", e)
+                        logger.error(f"Failed to process SQS message: {e}")
             except Exception as e:
-                logger.error("Event consumer error: %s", e)
+                logger.error(f"Event consumer error: {e}")
                 await asyncio.sleep(5)
         except asyncio.CancelledError:
             logger.info("Event consumer shutting down")
diff --git a/pyproject.toml b/pyproject.toml
index 4704cfb2d..3cb698ede 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -98,8 +98,7 @@ tests = [
     "apache-burr[redis]",
     "apache-burr[opentelemetry]",
     "apache-burr[haystack]",
-    "apache-burr[ray]",
-    "apache-burr[bedrock]"
+    "apache-burr[ray]"
 ]
 
 documentation = [
@@ -129,10 +128,6 @@ tracking-client-s3 = [
     "boto3"
 ]
 
-bedrock = [
-    "boto3"
-]
-
 tracking-server-s3 = [
     "aerich",
     "aiobotocore",
diff --git a/tests/integrations/test_bip0042_bedrock.py b/tests/integrations/test_bip0042_bedrock.py
deleted file mode 100644
index c0c0522d0..000000000
--- a/tests/integrations/test_bip0042_bedrock.py
+++ /dev/null
@@ -1,168 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
- -"""Tests for Bedrock integration.""" - -import inspect -from unittest.mock import MagicMock - -import pytest - -boto3 = pytest.importorskip("boto3", reason="boto3 required for Bedrock tests") - - -class TestBedrockImports: - """Test that Bedrock classes can be imported via lazy loading.""" - - def test_lazy_import_bedrock_action(self): - """Verify BedrockAction can be imported from burr.integrations.""" - from burr.integrations import BedrockAction - - assert BedrockAction is not None - - def test_lazy_import_bedrock_streaming_action(self): - """Verify BedrockStreamingAction can be imported from burr.integrations.""" - from burr.integrations import BedrockStreamingAction - - assert BedrockStreamingAction is not None - - def test_direct_import_bedrock_module(self): - """Verify bedrock.py module exists and has expected classes.""" - from burr.integrations.bedrock import ( - BedrockAction, - BedrockStreamingAction, - StateToPromptMapper, - ) - - assert BedrockAction is not None - assert BedrockStreamingAction is not None - assert StateToPromptMapper is not None - - -class TestBedrockActionInterface: - """Test BedrockAction class interface with mocked boto3.""" - - def test_bedrock_action_extends_single_step_action(self): - """Verify BedrockAction extends SingleStepAction.""" - from burr.core.action import SingleStepAction - from burr.integrations.bedrock import BedrockAction - - assert issubclass(BedrockAction, SingleStepAction) - - def test_bedrock_streaming_action_extends_streaming_action(self): - """Verify BedrockStreamingAction extends StreamingAction.""" - from burr.core.action import StreamingAction - from burr.integrations.bedrock import BedrockStreamingAction - - assert issubclass(BedrockStreamingAction, StreamingAction) - - def test_bedrock_action_has_required_properties(self): - """Verify BedrockAction has reads, writes, name properties.""" - from burr.integrations.bedrock import BedrockAction - - action = BedrockAction( - model_id="test-model", - input_mapper=lambda s: {"messages": []}, - reads=["input"], - writes=["output"], - ) - assert action.reads == ["input"] - assert action.writes == ["output"] - assert action.name == "bedrock_invoke" - - def test_bedrock_action_accepts_all_parameters(self): - """Verify BedrockAction accepts all specified parameters.""" - from burr.integrations.bedrock import BedrockAction - - sig = inspect.signature(BedrockAction.__init__) - params = list(sig.parameters.keys()) - assert "model_id" in params - assert "input_mapper" in params - assert "reads" in params - assert "writes" in params - assert "name" in params - assert "region" in params - assert "guardrail_id" in params - assert "guardrail_version" in params - assert "inference_config" in params - assert "max_retries" in params - assert "client" in params - - def test_bedrock_action_uses_injected_client(self): - """Verify BedrockAction uses injected client when provided.""" - from burr.integrations.bedrock import BedrockAction - - mock_client = MagicMock() - mock_client.converse.return_value = { - "output": {"message": {"content": [{"text": "hi"}]}}, - "usage": {}, - "stopReason": "end_turn", - } - - action = BedrockAction( - model_id="test-model", - input_mapper=lambda s: {"messages": [{"role": "user", "content": "hi"}]}, - reads=[], - writes=["response"], - client=mock_client, - ) - - result, _ = action.run_and_update({}) - assert result["response"] == "hi" - mock_client.converse.assert_called_once() - - -class TestBedrockStreamingActionInterface: - """Test BedrockStreamingAction class 
-
-    def test_bedrock_streaming_action_uses_injected_client(self):
-        """Verify BedrockStreamingAction uses injected client when provided."""
-        from burr.integrations.bedrock import BedrockStreamingAction
-
-        mock_client = MagicMock()
-        mock_client.converse_stream.return_value = {
-            "stream": [
-                {"contentBlockDelta": {"delta": {"text": "hello "}}},
-                {"contentBlockDelta": {"delta": {"text": "world"}}},
-            ]
-        }
-
-        action = BedrockStreamingAction(
-            model_id="test-model",
-            input_mapper=lambda s: {"messages": [{"role": "user", "content": "hi"}]},
-            reads=[],
-            writes=["response"],
-            client=mock_client,
-        )
-
-        chunks = list(action.stream_run({}))
-        assert len(chunks) == 3  # 2 content chunks + 1 complete
-        assert chunks[0]["chunk"] == "hello "
-        assert chunks[1]["chunk"] == "world"
-        assert chunks[2]["complete"] is True
-        assert chunks[2]["response"] == "hello world"
-        mock_client.converse_stream.assert_called_once()
-
-
-class TestStateToPromptMapperProtocol:
-    """Test StateToPromptMapper Protocol exists."""
-
-    def test_protocol_exists(self):
-        """Verify StateToPromptMapper Protocol is defined."""
-        from burr.integrations.bedrock import StateToPromptMapper
-
-        assert StateToPromptMapper is not None

From 5ca97398ae29106e6fe4ca6be87e802fc71cc19c Mon Sep 17 00:00:00 2001
From: vaquarkhan
Date: Mon, 16 Mar 2026 02:43:54 -0500
Subject: [PATCH 8/8] fix: add tracking-server-s3 to CI deps, remove Bedrock
 (in separate PR), fix typo
---
 .github/workflows/python-package.yml |   2 +-
 bedrock-changes.patch                | Bin 35472 -> 0 bytes
 burr/integrations/__init__.py        |   2 +-
 3 files changed, 2 insertions(+), 2 deletions(-)
 delete mode 100644 bedrock-changes.patch

diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index e1f18de74..bddd4e325 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -69,7 +69,7 @@ jobs:
 
       - name: Install dependencies
         run: |
-          python -m pip install -e ".[tests,tracking-client,graphviz]"
+          python -m pip install -e ".[tests,tracking-client,graphviz,tracking-server-s3]"
 
       - name: Run tests
        run: |
diff --git a/bedrock-changes.patch b/bedrock-changes.patch
deleted file mode 100644
index fac3e09828f861cfe0890e9d31ef71d016f2d833..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001
za(&)!ydPetT!Ca=^O1`6*{*+|ntxkubfpX#*Y2^D`PKO^8lBxb`PRBBV_UEBI*tsO a_g?i2u3o{Fe<4K~Jo-Vr3IF__{r>}22_q!{ diff --git a/burr/integrations/__init__.py b/burr/integrations/__init__.py index 9f5af4ccb..13a83393a 100644 --- a/burr/integrations/__init__.py +++ b/burr/integrations/__init__.py @@ -13,4 +13,4 @@ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations -# in the License. +# under the License.