Skip to content
Open
10 changes: 7 additions & 3 deletions samples/hello_world_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from a2a.server.events.event_queue import EventQueue
from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler
from a2a.server.routes import (
add_a2a_routes_to_fastapi,
create_agent_card_routes,
create_jsonrpc_routes,
create_rest_routes,
Expand Down Expand Up @@ -220,9 +221,12 @@ async def serve(
agent_card=agent_card,
)
app = FastAPI()
app.routes.extend(jsonrpc_routes)
app.routes.extend(agent_card_routes)
app.routes.extend(rest_routes)
add_a2a_routes_to_fastapi(
app,
agent_card_routes=agent_card_routes,
jsonrpc_routes=jsonrpc_routes,
rest_routes=rest_routes,
)

grpc_server = grpc.aio.server()
grpc_server.add_insecure_port(f'{host}:{grpc_port}')
Expand Down
71 changes: 65 additions & 6 deletions src/a2a/server/routes/_proto_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import Any

from google.api import field_behavior_pb2 as fb
from google.protobuf.descriptor import Descriptor, FieldDescriptor
from google.protobuf.message import Message

Expand Down Expand Up @@ -33,6 +34,12 @@
FieldDescriptor.TYPE_SINT64: {'type': 'string'},
}


def _is_required(field: FieldDescriptor) -> bool:
"""Returns True if the field carries google.api.field_behavior = REQUIRED."""
return fb.REQUIRED in field.GetOptions().Extensions[fb.field_behavior] # type: ignore[index] # ty: ignore[invalid-argument-type]


_WELL_KNOWN_SCHEMAS: dict[str, dict[str, Any]] = {
'google.protobuf.Timestamp': {'type': 'string', 'format': 'date-time'},
'google.protobuf.Duration': {'type': 'string'},
Expand All @@ -57,16 +64,46 @@ def field_schema(

if field.type == FieldDescriptor.TYPE_MESSAGE:
item = message_schema(field.message_type, components)
# Well-known types return an inline schema (no $ref); don't wrap them as
# nullable — they're already inlined as their JSON-Schema equivalent.
# Repeated fields must not return early here — they fall through to the
# array-wrapping block below.
if not field.is_repeated and not _is_required(field) and '$ref' in item:
return {'oneOf': [item, {'type': 'null'}], 'example': None}
elif field.type == FieldDescriptor.TYPE_ENUM:
item = {
'type': 'string',
'enum': [v.name for v in field.enum_type.values],
}
values = [v.name for v in field.enum_type.values]
example = next(
(
v
for v in values
if 'UNSPECIFIED' not in v and 'UNKNOWN' not in v
),
values[0] if values else None,
)
item: dict[str, Any] = {'type': 'string', 'enum': values}
if example:
item['example'] = example
else:
item = dict(_PROTO_SCALAR_SCHEMAS.get(field.type, {'type': 'string'}))
if field.type == FieldDescriptor.TYPE_STRING:
# REQUIRED fields must be non-empty; use the field name as a
# recognisable placeholder. All other strings default to "".
item['example'] = field.name if _is_required(field) else ''
elif field.type == FieldDescriptor.TYPE_BOOL:
item['example'] = False

if field.is_repeated:
return {'type': 'array', 'items': item}
array_schema: dict[str, Any] = {'type': 'array', 'items': item}
# Propagate the item example to the array so Swagger pre-fills one entry
# instead of generating one entry per oneOf branch.
item_example = (
components.get(item['$ref'].split('/')[-1], {}).get('example')
if '$ref' in item
else item.get('example')
)
if item_example is not None:
array_schema['example'] = [item_example]
return array_schema
return item


Expand Down Expand Up @@ -114,5 +151,27 @@ def message_schema(
if base_properties:
parts.append({'type': 'object', 'properties': base_properties})
parts.extend(oneof_constraints)
components[name] = parts[0] if len(parts) == 1 else {'allOf': parts}
schema: dict[str, Any] = parts[0] if len(parts) == 1 else {'allOf': parts}
# Provide a single concrete example using the first oneof variant so Swagger
# doesn't expand every branch into separate array items.
first_oneof_field = real_oneofs[0].fields[0]
first_field_schema = field_schema(first_oneof_field, components)
if 'example' in first_field_schema:
first_example: Any = first_field_schema['example']
elif '$ref' in first_field_schema:
ref_name = first_field_schema['$ref'].split('/')[-1]
first_example = components.get(ref_name, {}).get('example')
else:
_type_defaults: dict[str, Any] = {
'integer': 0,
'number': 0.0,
'boolean': False,
'array': [],
'object': {},
}
first_example = _type_defaults.get(
first_field_schema.get('type', 'string'), ''
)
schema['example'] = {first_oneof_field.name: first_example}
components[name] = schema
return ref
145 changes: 103 additions & 42 deletions tests/server/routes/test_proto_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,48 +59,6 @@ def test_message_schema_oneof_variants_have_required():
assert len(variant['required']) == 1


def test_message_schema_multiple_oneofs_use_allof_not_cartesian_product():
# Simulate a descriptor with two oneofs: verify allOf has one constraint
# per oneof rather than a flat list of cross-product variants.
from unittest.mock import MagicMock

def _make_field(name):
f = MagicMock()
f.name = name
f.message_type = None
f.type = 9 # TYPE_STRING
f.is_repeated = False
return f

def _make_oneof(fields):
o = MagicMock()
o.fields = fields
return o

f_a, f_b = _make_field('a'), _make_field('b')
f_x, f_y = _make_field('x'), _make_field('y')
oneof1 = _make_oneof([f_a, f_b])
oneof2 = _make_oneof([f_x, f_y])

descriptor = MagicMock()
descriptor.full_name = 'test.MultiOneof'
descriptor.name = 'MultiOneof'
descriptor.oneofs = [oneof1, oneof2]
descriptor.fields = [f_a, f_b, f_x, f_y]

components = {}
message_schema(descriptor, components)
schema = components['MultiOneof']

# Should be allOf with two oneOf constraints (one per oneof group),
# NOT a flat oneOf with 2*2=4 Cartesian-product variants.
assert 'allOf' in schema
one_of_constraints = [p for p in schema['allOf'] if 'oneOf' in p]
assert len(one_of_constraints) == 2
assert len(one_of_constraints[0]['oneOf']) == 2
assert len(one_of_constraints[1]['oneOf']) == 2


def test_field_schema_repeated_wraps_in_array():
components = {}
msg_descriptor = SendMessageRequest.DESCRIPTOR.fields_by_name[
Expand All @@ -120,6 +78,92 @@ def test_field_schema_enum():
assert 'ROLE_AGENT' in schema['enum']


def test_field_schema_enum_example_skips_unspecified():
role_field = Message.DESCRIPTOR.fields_by_name['role']
schema = field_schema(role_field, {})
assert schema['example'] == 'ROLE_USER'


def test_field_schema_string_example_is_empty():
context_id_field = Message.DESCRIPTOR.fields_by_name['context_id']
schema = field_schema(context_id_field, {})
assert schema['example'] == ''


def test_field_schema_string_required_uses_field_name():
# REQUIRED string fields must be non-empty; the field name is the placeholder.
message_id_field = Message.DESCRIPTOR.fields_by_name['message_id']
schema = field_schema(message_id_field, {})
assert schema['example'] == 'message_id'


def test_field_schema_bool_example_is_false():
from a2a.types.a2a_pb2 import SendMessageConfiguration

field = SendMessageConfiguration.DESCRIPTOR.fields_by_name[
'return_immediately'
]
schema = field_schema(field, {})
assert schema['example'] is False


def test_field_schema_optional_message_is_nullable():
# Non-REQUIRED message fields default to null so Swagger doesn't pre-fill them
# with empty sub-fields that trigger server-side required-field validation.
from a2a.types.a2a_pb2 import SendMessageConfiguration

field = SendMessageConfiguration.DESCRIPTOR.fields_by_name[
'task_push_notification_config'
]
schema = field_schema(field, {})
assert schema['example'] is None
assert any(v == {'type': 'null'} for v in schema['oneOf'])


def test_field_schema_required_message_is_not_nullable():
from a2a.types.a2a_pb2 import SendMessageRequest

field = SendMessageRequest.DESCRIPTOR.fields_by_name['message']
schema = field_schema(field, {})
assert '$ref' in schema
assert 'oneOf' not in schema


def test_field_schema_repeated_optional_message_is_array_not_nullable():

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use Task.history here instead of mocking? Based on the test description it should suit it: repeated and non-required.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in dc7f066 — the test now uses Task.history (a real repeated, non-required message field) instead of a mock.

# Repeated non-REQUIRED message fields must be wrapped as an array, not
# returned early as a nullable oneOf — the is_repeated check must come
# first. Task.history is a real repeated, non-required message field.
from a2a.types.a2a_pb2 import Task

field = Task.DESCRIPTOR.fields_by_name['history']
schema = field_schema(field, {})
assert schema['type'] == 'array'
assert 'oneOf' not in schema
assert '$ref' in schema['items']


def test_message_schema_oneof_example_uses_first_variant_only():
components = {}
message_schema(Part.DESCRIPTOR, components)
example = components['Part']['example']
assert example == {'text': ''}
# base properties (metadata, filename, media_type) must not appear in the
# example — they are objects/strings that would be wrong if sent as "".
assert 'metadata' not in example
assert 'filename' not in example


def test_field_schema_repeated_ref_example_propagated():
components = {}
msg_descriptor = SendMessageRequest.DESCRIPTOR.fields_by_name[
'message'
].message_type
parts_field = msg_descriptor.fields_by_name['parts']
schema = field_schema(parts_field, components)
assert schema['type'] == 'array'
assert schema['example'] == [{'text': ''}]


def test_field_schema_map_entry():
metadata_field = SendMessageRequest.DESCRIPTOR.fields_by_name['metadata']
schema = field_schema(metadata_field, {})
Expand All @@ -130,3 +174,20 @@ def test_rest_body_types_coverage():
assert ('/message:send', 'POST') in REST_BODY_TYPES
assert ('/message:stream', 'POST') in REST_BODY_TYPES
assert ('/tasks/{id}/pushNotificationConfigs', 'POST') in REST_BODY_TYPES


def test_full_schema_builds_for_all_rest_body_types():
# Safety net: build the complete schema for every registered REST body
# type into a shared components dict. Any proto field structure we don't
# support (or stop supporting after a proto change) fails right here
# rather than silently producing a broken Swagger document.
components: dict = {}
for msg in REST_BODY_TYPES.values():
ref = message_schema(msg.DESCRIPTOR, components)
assert ref['$ref'].startswith('#/components/schemas/')

# Every registered schema must be a non-empty object/composition (the
# cyclic-type placeholder is filled in before the build returns).
for name, schema in components.items():
assert schema, f'{name} resolved to an empty schema'
assert 'type' in schema or 'allOf' in schema or '$ref' in schema
Loading