Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/google/adk/plugins/bigquery_agent_analytics_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1781,6 +1781,21 @@ def _get_events_schema() -> list[bigquery.SchemaField]:
),
"JSON_VALUE(attributes, '$.model_version') AS model_version",
"JSON_QUERY(attributes, '$.usage_metadata') AS usage_metadata",
(
"CAST(JSON_VALUE(attributes,"
" '$.usage_metadata.cached_content_token_count')"
" AS INT64) AS usage_cached_tokens"
),
(
"SAFE_DIVIDE("
"CAST(JSON_VALUE(attributes,"
" '$.usage_metadata.cached_content_token_count')"
" AS INT64),"
"CAST(JSON_VALUE(content, '$.usage.prompt')"
" AS INT64)"
") AS context_cache_hit_rate"
),
"JSON_QUERY(attributes, '$.cache_metadata') AS cache_metadata",
],
"LLM_ERROR": [
"CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms",
Expand Down Expand Up @@ -1860,6 +1875,7 @@ class EventData:
model: Optional[str] = None
model_version: Optional[str] = None
usage_metadata: Any = None
cache_metadata: Any = None
status: str = "OK"
error_message: Optional[str] = None
extra_attributes: dict[str, Any] = field(default_factory=dict)
Expand Down Expand Up @@ -2640,6 +2656,14 @@ def _enrich_attributes(
attrs["usage_metadata"] = usage_dict
else:
attrs["usage_metadata"] = event_data.usage_metadata
if event_data.cache_metadata:
cache_meta_dict, _ = _recursive_smart_truncate(
event_data.cache_metadata, self.config.max_content_length
)
if isinstance(cache_meta_dict, dict):
attrs["cache_metadata"] = cache_meta_dict
else:
attrs["cache_metadata"] = event_data.cache_metadata

if self.config.log_session_metadata:
try:
Expand Down Expand Up @@ -3172,6 +3196,7 @@ async def after_model_callback(
time_to_first_token_ms=tfft,
model_version=llm_response.model_version,
usage_metadata=llm_response.usage_metadata,
cache_metadata=llm_response.cache_metadata,
span_id_override=span_id if use_override else None,
parent_span_id_override=parent_span_id if use_override else None,
),
Expand Down
92 changes: 92 additions & 0 deletions tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6672,3 +6672,95 @@ def _fake_run_coroutine_threadsafe(coro, loop):
mock_rcts.assert_called()
call_args = mock_rcts.call_args
assert call_args[0][1] is other_loop


# ================================================================
# TEST CLASS: Context cache metrics (Issue #5210)
# ================================================================
class TestContextCacheMetrics:
  """Covers context-cache metric plumbing (Issue #5210).

  Checks that cache metadata from an LlmResponse is persisted into the
  logged event attributes, that nothing is logged when it is absent, and
  that the LLM_RESPONSE view exposes the derived cache metric columns.
  """

  @pytest.mark.asyncio
  async def test_cache_metadata_stored_in_attributes(
      self,
      bq_plugin_inst,
      mock_write_client,
      callback_context,
      dummy_arrow_schema,
  ):
    """A populated LlmResponse.cache_metadata lands in the row attributes."""
    from google.adk.models.cache_metadata import CacheMetadata

    metadata = CacheMetadata(
        cache_name="projects/p/locations/us/cachedContents/123",
        fingerprint="abc123",
        invocations_used=5,
        contents_count=10,
        expire_time=1700000000.0,
        created_at=1699000000.0,
    )

    # Drive a full before/after model-callback cycle through the plugin.
    bigquery_agent_analytics_plugin.TraceManager.push_span(callback_context)
    request = llm_request_lib.LlmRequest(
        model="gemini-pro",
        contents=[types.Content(parts=[types.Part(text="test")])],
    )
    await bq_plugin_inst.before_model_callback(
        callback_context=callback_context, llm_request=request
    )

    response = llm_response_lib.LlmResponse(
        content=types.Content(parts=[types.Part(text="Response")]),
        cache_metadata=metadata,
    )
    await bq_plugin_inst.after_model_callback(
        callback_context=callback_context, llm_response=response
    )
    # Give the background writer a moment to flush the captured rows.
    await asyncio.sleep(0.05)

    rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema)
    entry = next(r for r in rows if r["event_type"] == "LLM_RESPONSE")
    attrs = json.loads(entry["attributes"])

    assert "cache_metadata" in attrs
    logged = attrs["cache_metadata"]
    assert logged["cache_name"] == "projects/p/locations/us/cachedContents/123"
    assert logged["fingerprint"] == "abc123"
    assert logged["invocations_used"] == 5

  @pytest.mark.asyncio
  async def test_no_cache_metadata_when_absent(
      self,
      bq_plugin_inst,
      mock_write_client,
      callback_context,
      dummy_arrow_schema,
  ):
    """Attributes omit cache_metadata when the LlmResponse carries none."""
    bigquery_agent_analytics_plugin.TraceManager.push_span(callback_context)
    request = llm_request_lib.LlmRequest(
        model="gemini-pro",
        contents=[types.Content(parts=[types.Part(text="test")])],
    )
    await bq_plugin_inst.before_model_callback(
        callback_context=callback_context, llm_request=request
    )

    # Response deliberately built without cache_metadata.
    response = llm_response_lib.LlmResponse(
        content=types.Content(parts=[types.Part(text="Response")]),
    )
    await bq_plugin_inst.after_model_callback(
        callback_context=callback_context, llm_response=response
    )
    await asyncio.sleep(0.05)

    rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema)
    entry = next(r for r in rows if r["event_type"] == "LLM_RESPONSE")
    attrs = json.loads(entry["attributes"])
    assert "cache_metadata" not in attrs

  def test_view_def_includes_cache_columns(self):
    """The LLM_RESPONSE view definition lists the cache metric columns."""
    columns = bigquery_agent_analytics_plugin._EVENT_VIEW_DEFS["LLM_RESPONSE"]
    joined = " ".join(columns)
    for expected in (
        "usage_cached_tokens",
        "context_cache_hit_rate",
        "cache_metadata",
    ):
      assert expected in joined