Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ jobs:

- name: Install dependencies
run: |
python -m pip install -e ".[tests,tracking-client,graphviz]"
python -m pip install -e ".[tests,tracking-client,graphviz,tracking-server-s3]"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tracking-server-s3 in the main CI install and apache-burr[bedrock] in [tests] (pyproject.toml:101) means every contributor now pulls boto3, aiobotocore, tortoise-orm, aerich even for unrelated PRs. Keep AWS deps in a separate CI job and test group, like the existing test-persister-dbs pattern.


- name: Run tests
run: |
Expand Down
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,11 @@ burr/tracking/server/build
examples/*/statemachine
examples/*/*/statemachine
.vscode

# Terraform (see also examples/deployment/aws/terraform/.gitignore)
**/.terraform.lock.hcl
examples/deployment/aws/terraform/.terraform/
examples/deployment/aws/terraform/*.tfstate
examples/deployment/aws/terraform/*.tfstate.*
examples/deployment/aws/terraform/.terraform.tfstate.lock.info
examples/deployment/aws/terraform/*.tfplan
11 changes: 11 additions & 0 deletions burr/integrations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


def __getattr__(name: str):
"""Lazy load Bedrock integration to avoid requiring boto3 unless used."""
if name == "BedrockAction":
from burr.integrations.bedrock import BedrockAction
return BedrockAction
if name == "BedrockStreamingAction":
from burr.integrations.bedrock import BedrockStreamingAction
return BedrockStreamingAction
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
261 changes: 261 additions & 0 deletions burr/integrations/bedrock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Amazon Bedrock integration for Burr.

This module provides Action classes for invoking Amazon Bedrock models
within Burr applications.

Example usage:
from burr.integrations.bedrock import BedrockAction

def prompt_mapper(state):
return {
"messages": [{"role": "user", "content": state["user_input"]}],
"system": [{"text": "You are a helpful assistant."}],
}

# With default client (created lazily on first use):
action = BedrockAction(
model_id="anthropic.claude-3-sonnet-20240229-v1:0",
input_mapper=prompt_mapper,
reads=["user_input"],
writes=["response"],
)

# With injected client (for tests or distributed execution):
# client = boto3.client("bedrock-runtime", region_name="us-east-1")
# action = BedrockAction(..., client=client)
"""

import logging
from typing import Any, Generator, Optional, Protocol

from burr.core.action import SingleStepAction, StreamingAction
from burr.core.state import State
from burr.integrations.base import require_plugin

logger = logging.getLogger(__name__)

# Type for injected Bedrock client (avoids boto3 import at type-check time)
BedrockClient = Any

try:
import boto3
from botocore.config import Config
from botocore.exceptions import ClientError
except ImportError as e:
require_plugin(e, "bedrock")


class StateToPromptMapper(Protocol):
    """Callable contract translating a Burr ``State`` into the pieces of a
    Bedrock Converse request: a dict containing ``messages`` and, optionally,
    ``system``."""

    def __call__(self, state: State) -> dict[str, Any]:
        ...


class BedrockAction(SingleStepAction):
    """Action that invokes Amazon Bedrock models using the Converse API.

    The prompt is built from state by ``input_mapper``; the model's reply
    text, token usage, and stop reason are written back into state for
    whichever of ``response`` / ``usage`` / ``stop_reason`` appear in
    ``writes``.
    """

    def __init__(
        self,
        model_id: str,
        input_mapper: StateToPromptMapper,
        reads: list[str],
        writes: list[str],
        name: str = "bedrock_invoke",
        region: Optional[str] = None,
        guardrail_id: Optional[str] = None,
        guardrail_version: Optional[str] = None,
        inference_config: Optional[dict[str, Any]] = None,
        max_retries: int = 3,
        client: Optional[BedrockClient] = None,
    ):
        """Initialize the action.

        :param model_id: Bedrock model identifier to invoke.
        :param input_mapper: Callable mapping state to the Converse request
            (must return a dict with ``messages`` and optionally ``system``).
        :param reads: State keys this action reads.
        :param writes: State keys this action writes; only keys among
            ``response``/``usage``/``stop_reason`` are populated.
        :param name: Action name within the application graph.
        :param region: AWS region for the lazily created client; ignored when
            ``client`` is injected.
        :param guardrail_id: Optional guardrail identifier.
        :param guardrail_version: Guardrail version; defaults to ``"DRAFT"``
            when a guardrail is configured without an explicit version.
        :param inference_config: Converse ``inferenceConfig``. ``None`` means
            ``{"maxTokens": 4096}``; an explicitly-empty dict is honored.
        :param max_retries: Max attempts for the default client's adaptive
            retry mode.
        :param client: Optional pre-built ``bedrock-runtime`` client (useful
            for tests or distributed execution).
        """
        super().__init__()
        self._model_id = model_id
        self._input_mapper = input_mapper
        self._reads = reads
        self._writes = writes
        self._name = name
        self._region = region
        self._guardrail_id = guardrail_id
        # NOTE(review): silently falls back to the unpublished DRAFT version
        # when only guardrail_id is set -- confirm this is intended for prod.
        self._guardrail_version = (
            guardrail_version if guardrail_version is not None else "DRAFT"
        )
        # `is not None` (not `or`): an explicitly-empty config {} must be
        # respected rather than silently replaced by the default.
        self._inference_config = (
            inference_config if inference_config is not None else {"maxTokens": 4096}
        )
        self._max_retries = max_retries
        self._client = client

    def _get_client(self) -> BedrockClient:
        """Return the Bedrock client, creating one lazily if none was injected."""
        if self._client is None:
            config = Config(
                retries={"max_attempts": self._max_retries, "mode": "adaptive"}
            )
            self._client = boto3.client(
                "bedrock-runtime", region_name=self._region, config=config
            )
        return self._client

    @property
    def reads(self) -> list[str]:
        return self._reads

    @property
    def writes(self) -> list[str]:
        return self._writes

    @property
    def name(self) -> str:
        return self._name

    def run_and_update(self, state: State, **run_kwargs) -> tuple[dict, State]:
        """Invoke the model synchronously and return ``(result, new_state)``.

        :raises botocore.exceptions.ClientError: re-raised after logging when
            the Bedrock API call fails.
        """
        prompt = self._input_mapper(state)

        request: dict[str, Any] = {
            "modelId": self._model_id,
            "messages": prompt["messages"],
            "inferenceConfig": self._inference_config,
        }
        if "system" in prompt:
            request["system"] = prompt["system"]
        if self._guardrail_id:
            request["guardrailConfig"] = {
                "guardrailIdentifier": self._guardrail_id,
                "guardrailVersion": self._guardrail_version,
            }

        try:
            response = self._get_client().converse(**request)
        except ClientError as e:
            logger.error("Bedrock API error: %s", e)
            raise

        # Join every text content block rather than only the first, so
        # multi-block replies are not silently truncated.
        content_blocks = response["output"]["message"].get("content", [])
        text = "".join(block.get("text", "") for block in content_blocks)

        result: dict[str, Any] = {
            "response": text,
            "usage": response.get("usage", {}),
            "stop_reason": response.get("stopReason"),
        }
        # Only write the keys the caller declared in `writes`.
        updates = {key: result[key] for key in self._writes if key in result}
        return result, state.update(**updates)


class BedrockStreamingAction(StreamingAction):
    """Streaming variant of ``BedrockAction`` using the Converse Stream API.

    Yields incremental ``{"chunk", "response"}`` dicts as text deltas arrive,
    followed by a final dict with ``complete=True``; ``update`` then writes
    the accumulated ``response`` into state (when listed in ``writes``).
    """

    def __init__(
        self,
        model_id: str,
        input_mapper: StateToPromptMapper,
        reads: list[str],
        writes: list[str],
        name: str = "bedrock_stream",
        region: Optional[str] = None,
        guardrail_id: Optional[str] = None,
        guardrail_version: Optional[str] = None,
        inference_config: Optional[dict[str, Any]] = None,
        max_retries: int = 3,
        client: Optional[BedrockClient] = None,
    ):
        """Initialize the streaming action.

        :param model_id: Bedrock model identifier to invoke.
        :param input_mapper: Callable mapping state to the Converse request
            (must return a dict with ``messages`` and optionally ``system``).
        :param reads: State keys this action reads.
        :param writes: State keys this action writes; only ``response`` is
            populated by ``update``.
        :param name: Action name within the application graph.
        :param region: AWS region for the lazily created client; ignored when
            ``client`` is injected.
        :param guardrail_id: Optional guardrail identifier.
        :param guardrail_version: Guardrail version; defaults to ``"DRAFT"``
            when a guardrail is configured without an explicit version.
        :param inference_config: Converse ``inferenceConfig``. ``None`` means
            ``{"maxTokens": 4096}``; an explicitly-empty dict is honored.
        :param max_retries: Max attempts for the default client's adaptive
            retry mode.
        :param client: Optional pre-built ``bedrock-runtime`` client (useful
            for tests or distributed execution).
        """
        super().__init__()
        self._model_id = model_id
        self._input_mapper = input_mapper
        self._reads = reads
        self._writes = writes
        self._name = name
        self._region = region
        self._guardrail_id = guardrail_id
        # NOTE(review): silently falls back to the unpublished DRAFT version
        # when only guardrail_id is set -- confirm this is intended for prod.
        self._guardrail_version = (
            guardrail_version if guardrail_version is not None else "DRAFT"
        )
        # `is not None` (not `or`): an explicitly-empty config {} must be
        # respected rather than silently replaced by the default.
        self._inference_config = (
            inference_config if inference_config is not None else {"maxTokens": 4096}
        )
        self._max_retries = max_retries
        self._client = client

    def _get_client(self) -> BedrockClient:
        """Return the Bedrock client, creating one lazily if none was injected."""
        if self._client is None:
            config = Config(
                retries={"max_attempts": self._max_retries, "mode": "adaptive"}
            )
            self._client = boto3.client(
                "bedrock-runtime", region_name=self._region, config=config
            )
        return self._client

    @property
    def reads(self) -> list[str]:
        return self._reads

    @property
    def writes(self) -> list[str]:
        return self._writes

    @property
    def name(self) -> str:
        return self._name

    def stream_run(
        self, state: State, **run_kwargs
    ) -> Generator[dict, None, None]:
        """Stream model output, yielding one dict per text delta plus a final
        ``complete=True`` dict carrying the full accumulated response.

        :raises botocore.exceptions.ClientError: re-raised after logging when
            the Bedrock API call fails.
        """
        prompt = self._input_mapper(state)

        request: dict[str, Any] = {
            "modelId": self._model_id,
            "messages": prompt["messages"],
            "inferenceConfig": self._inference_config,
        }
        if "system" in prompt:
            request["system"] = prompt["system"]
        if self._guardrail_id:
            request["guardrailConfig"] = {
                "guardrailIdentifier": self._guardrail_id,
                "guardrailVersion": self._guardrail_version,
            }

        try:
            response = self._get_client().converse_stream(**request)
        except ClientError as e:
            logger.error("Bedrock streaming API error: %s", e)
            raise

        full_response = ""
        # Non-text events (metadata, message start/stop, etc.) are skipped.
        for event in response.get("stream", []):
            if "contentBlockDelta" in event:
                chunk = event["contentBlockDelta"]["delta"].get("text", "")
                full_response += chunk
                yield {"chunk": chunk, "response": full_response}

        yield {"chunk": "", "response": full_response, "complete": True}

    def update(self, result: dict, state: State) -> State:
        """Apply the final streamed result to state.

        Intermediate (non-complete) chunks leave state untouched; only the
        terminal result writes ``response``, and only if it is in ``writes``.
        """
        if result.get("complete"):
            updates = {"response": result.get("response", "")}
            filtered = {k: v for k, v in updates.items() if k in self._writes}
            return state.update(**filtered)
        return state
25 changes: 25 additions & 0 deletions burr/tracking/server/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,31 @@ def snapshot_interval_milliseconds(self) -> Optional[int]:
pass


class EventDrivenBackendMixin(abc.ABC):
    """Mixin contract for backends that react to push-style notifications.

    Implementors receive real-time update events (e.g. from a message queue)
    instead of the server polling for new files.
    """

    @abc.abstractmethod
    def is_event_driven(self) -> bool:
        """Report whether event-driven mode is enabled and configured.

        :return: True when the backend should be driven by events, False otherwise
        """
        ...

    @abc.abstractmethod
    async def start_event_consumer(self):
        """Run the event consumer loop for event-driven tracking.

        Implementations are expected to run indefinitely, processing event
        notifications from the configured message queue.
        """
        ...


class BackendBase(abc.ABC):
async def lifespan(self, app: FastAPI):
"""Quick tool to allow plugin to the app's lifecycle.
Expand Down
31 changes: 25 additions & 6 deletions burr/tracking/server/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.

import asyncio
import importlib
import logging
import os
Expand All @@ -29,6 +30,7 @@
from burr.tracking.server.backend import (
AnnotationsBackendMixin,
BackendBase,
EventDrivenBackendMixin,
IndexingBackendMixin,
SnapshottingBackendMixin,
)
Expand Down Expand Up @@ -135,9 +137,20 @@ async def lifespan(app: FastAPI):
await backend.lifespan(app).__anext__()
await sync_index() # this will trigger the repeat every N seconds
await save_snapshot() # this will trigger the repeat every N seconds
# Start event consumer for event-driven tracking when configured
event_consumer_task = None
if isinstance(backend, EventDrivenBackendMixin) and backend.is_event_driven():
event_consumer_task = asyncio.create_task(backend.start_event_consumer())
global initialized
initialized = True
yield
# Graceful shutdown: cancel event consumer task
if event_consumer_task is not None:
event_consumer_task.cancel()
try:
await event_consumer_task
except asyncio.CancelledError:
pass
await backend.lifespan(app).__anext__()


Expand Down Expand Up @@ -172,12 +185,18 @@ def get_app_spec():
logger = logging.getLogger(__name__)

if app_spec.indexing:
update_interval = backend.update_interval_milliseconds() / 1000 if app_spec.indexing else None
sync_index = repeat_every(
seconds=backend.update_interval_milliseconds() / 1000,
wait_first=True,
logger=logger,
)(sync_index)
# Only use polling when not in event-driven mode
if not (
isinstance(backend, EventDrivenBackendMixin) and backend.is_event_driven()
):
update_interval = (
backend.update_interval_milliseconds() / 1000 if app_spec.indexing else None
)
sync_index = repeat_every(
seconds=backend.update_interval_milliseconds() / 1000,
wait_first=True,
logger=logger,
)(sync_index)

if app_spec.snapshotting:
snapshot_interval = (
Expand Down
Loading