From 1728c15b2ed689ff81bf6a874314d57c07b50354 Mon Sep 17 00:00:00 2001 From: tobiogunbi Date: Mon, 2 Feb 2026 16:48:26 -0500 Subject: [PATCH 01/11] DGS-23437: Enable client omit Confluent --- examples/oauth_schema_registry.py | 27 ++++ .../_async/schema_registry_client.py | 27 ++-- .../_sync/schema_registry_client.py | 32 +++-- .../common/schema_registry_client.py | 16 ++- .../_async/test_bearer_field_provider.py | 133 +++++++++++++++++- .../_sync/test_bearer_field_provider.py | 4 +- 6 files changed, 205 insertions(+), 34 deletions(-) diff --git a/examples/oauth_schema_registry.py b/examples/oauth_schema_registry.py index fd1115b5c..10856ac48 100644 --- a/examples/oauth_schema_registry.py +++ b/examples/oauth_schema_registry.py @@ -71,6 +71,33 @@ def custom_oauth_function(config): custom_sr_client = SchemaRegistryClient(custom_sr_config) print(custom_sr_client.get_subjects()) + # Example: Using union-of-pools with comma-separated pool IDs + union_of_pools_config = { + 'url': 'https://psrc-123456.us-east-1.aws.confluent.cloud', + 'bearer.auth.credentials.source': 'STATIC_TOKEN', + 'bearer.auth.token': 'multi-pool-token', + 'bearer.auth.logical.cluster': 'lsrc-12345', + 'bearer.auth.identity.pool.id': 'pool-abc,pool-def,pool-ghi', + } + + union_sr_client = SchemaRegistryClient(union_of_pools_config) + print(union_sr_client.get_subjects()) + + # Example: Omitting identity pool for auto pool mapping + auto_pool_config = { + 'url': 'https://psrc-123456.us-east-1.aws.confluent.cloud', + 'bearer.auth.credentials.source': 'OAUTHBEARER', + 'bearer.auth.client.id': 'client-id', + 'bearer.auth.client.secret': 'client-secret', + 'bearer.auth.scope': 'schema_registry', + 'bearer.auth.issuer.endpoint.url': 'https://yourauthprovider.com/v1/token', + 'bearer.auth.logical.cluster': 'lsrc-12345', + # bearer.auth.identity.pool.id is omitted - SR will use auto pool mapping + } + + auto_pool_sr_client = SchemaRegistryClient(auto_pool_config) + print(auto_pool_sr_client.get_subjects()) + if __name__ == '__main__': main() diff --git a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py index e0efe351e..0abde8f3d 100644 --- a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py @@ -97,10 +97,10 @@ def __init__( scope: str, token_endpoint: str, logical_cluster: str, - identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int, + identity_pool: Optional[str] = None, ): self.token = None self.logical_cluster = logical_cluster @@ -113,11 +113,13 @@ def __init__( self.token_expiry_threshold = 0.8 async def get_bearer_fields(self) -> dict: - return { + fields = { 'bearer.auth.token': await self.get_access_token(), 'bearer.auth.logical.cluster': self.logical_cluster, - 'bearer.auth.identity.pool.id': self.identity_pool, } + if self.identity_pool is not None: + fields['bearer.auth.identity.pool.id'] = self.identity_pool + return fields def token_expired(self) -> bool: if self.token is None: @@ -283,19 +285,17 @@ def __init__(self, conf: dict): self.auth = None if self.bearer_auth_credentials_source in {'OAUTHBEARER', 'STATIC_TOKEN'}: - headers = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id'] - missing_headers = [header for header in headers if header not in conf_copy] - if missing_headers: + if 'bearer.auth.logical.cluster' not in conf_copy: raise ValueError( - "Missing required bearer configuration 
properties: {}".format(", ".join(missing_headers)) + "Missing required bearer configuration property: bearer.auth.logical.cluster" ) logical_cluster = conf_copy.pop('bearer.auth.logical.cluster') if not isinstance(logical_cluster, str): raise TypeError("logical cluster must be a str, not " + str(type(logical_cluster))) - identity_pool = conf_copy.pop('bearer.auth.identity.pool.id') - if not isinstance(identity_pool, str): + identity_pool = conf_copy.pop('bearer.auth.identity.pool.id', None) + if identity_pool is not None and not isinstance(identity_pool, str): raise TypeError("identity pool id must be a str, not " + str(type(identity_pool))) if self.bearer_auth_credentials_source == 'OAUTHBEARER': @@ -335,10 +335,10 @@ def __init__(self, conf: dict): self.scope, self.token_endpoint, logical_cluster, - identity_pool, self.max_retries, self.retries_wait_ms, self.retries_max_wait_ms, + identity_pool, ) else: # STATIC_TOKEN if 'bearer.auth.token' not in conf_copy: @@ -412,7 +412,8 @@ async def handle_bearer_auth(self, headers: dict) -> None: if self.bearer_field_provider is None: raise ValueError("Bearer field provider is not set") bearer_fields = await self.bearer_field_provider.get_bearer_fields() - required_fields = ['bearer.auth.token', 'bearer.auth.identity.pool.id', 'bearer.auth.logical.cluster'] + # Note: bearer.auth.identity.pool.id is optional; only token and logical.cluster are required + required_fields = ['bearer.auth.token', 'bearer.auth.logical.cluster'] missing_fields = [] for field in required_fields: @@ -427,9 +428,11 @@ async def handle_bearer_auth(self, headers: dict) -> None: ) headers["Authorization"] = "Bearer {}".format(bearer_fields['bearer.auth.token']) - headers['Confluent-Identity-Pool-Id'] = bearer_fields['bearer.auth.identity.pool.id'] headers['target-sr-cluster'] = bearer_fields['bearer.auth.logical.cluster'] + if 'bearer.auth.identity.pool.id' in bearer_fields: + headers['Confluent-Identity-Pool-Id'] = bearer_fields['bearer.auth.identity.pool.id'] + async def get(self, url: str, query: Optional[dict] = None) -> Any: return await self.send_request(url, method='GET', query=query) diff --git a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py index 74435e108..5d30267aa 100644 --- a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py @@ -97,10 +97,10 @@ def __init__( scope: str, token_endpoint: str, logical_cluster: str, - identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int, + identity_pool: Optional[str] = None, ): self.token = None self.logical_cluster = logical_cluster @@ -113,11 +113,13 @@ def __init__( self.token_expiry_threshold = 0.8 def get_bearer_fields(self) -> dict: - return { + fields = { 'bearer.auth.token': self.get_access_token(), 'bearer.auth.logical.cluster': self.logical_cluster, - 'bearer.auth.identity.pool.id': self.identity_pool, } + if self.identity_pool is not None: + fields['bearer.auth.identity.pool.id'] = self.identity_pool + return fields def token_expired(self) -> bool: if self.token is None: @@ -283,19 +285,20 @@ def __init__(self, conf: dict): self.auth = None if self.bearer_auth_credentials_source in {'OAUTHBEARER', 'STATIC_TOKEN'}: - headers = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id'] - missing_headers = [header for header in headers if header not in conf_copy] - if missing_headers: + if 
'bearer.auth.logical.cluster' not in conf_copy: raise ValueError( - "Missing required bearer configuration properties: {}".format(", ".join(missing_headers)) + "Missing required bearer configuration property: bearer.auth.logical.cluster" ) logical_cluster = conf_copy.pop('bearer.auth.logical.cluster') if not isinstance(logical_cluster, str): raise TypeError("logical cluster must be a str, not " + str(type(logical_cluster))) - identity_pool = conf_copy.pop('bearer.auth.identity.pool.id') - if not isinstance(identity_pool, str): + # Identity pool is optional and may be a single pool ID or comma-separated list of pool IDs. + # The value is passed as-is in the Confluent-Identity-Pool-Id header to Schema Registry. + # When union-of-pools is enabled on the server, multiple IDs are interpreted as a union. + identity_pool = conf_copy.pop('bearer.auth.identity.pool.id', None) + if identity_pool is not None and not isinstance(identity_pool, str): raise TypeError("identity pool id must be a str, not " + str(type(identity_pool))) if self.bearer_auth_credentials_source == 'OAUTHBEARER': @@ -335,10 +338,10 @@ def __init__(self, conf: dict): self.scope, self.token_endpoint, logical_cluster, - identity_pool, self.max_retries, self.retries_wait_ms, self.retries_max_wait_ms, + identity_pool, ) else: # STATIC_TOKEN if 'bearer.auth.token' not in conf_copy: @@ -412,7 +415,8 @@ def handle_bearer_auth(self, headers: dict) -> None: if self.bearer_field_provider is None: raise ValueError("Bearer field provider is not set") bearer_fields = self.bearer_field_provider.get_bearer_fields() - required_fields = ['bearer.auth.token', 'bearer.auth.identity.pool.id', 'bearer.auth.logical.cluster'] + # Note: bearer.auth.identity.pool.id is optional; only token and logical.cluster are required + required_fields = ['bearer.auth.token', 'bearer.auth.logical.cluster'] missing_fields = [] for field in required_fields: @@ -427,9 +431,13 @@ def handle_bearer_auth(self, headers: dict) -> None: ) headers["Authorization"] = "Bearer {}".format(bearer_fields['bearer.auth.token']) - headers['Confluent-Identity-Pool-Id'] = bearer_fields['bearer.auth.identity.pool.id'] headers['target-sr-cluster'] = bearer_fields['bearer.auth.logical.cluster'] + # Identity pool may be a single pool ID or comma-separated list of pool IDs. + # Pass the value as-is to the server; it handles union-of-pools semantics. 
+ if 'bearer.auth.identity.pool.id' in bearer_fields: + headers['Confluent-Identity-Pool-Id'] = bearer_fields['bearer.auth.identity.pool.id'] + def get(self, url: str, query: Optional[dict] = None) -> Any: return self.send_request(url, method='GET', query=query) diff --git a/src/confluent_kafka/schema_registry/common/schema_registry_client.py b/src/confluent_kafka/schema_registry/common/schema_registry_client.py index afdefd08f..0464ab609 100644 --- a/src/confluent_kafka/schema_registry/common/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/common/schema_registry_client.py @@ -71,33 +71,37 @@ async def get_bearer_fields(self) -> dict: class _StaticFieldProvider(_BearerFieldProvider): """Synchronous static token bearer field provider.""" - def __init__(self, token: str, logical_cluster: str, identity_pool: str): + def __init__(self, token: str, logical_cluster: str, identity_pool: Optional[str] = None): self.token = token self.logical_cluster = logical_cluster self.identity_pool = identity_pool def get_bearer_fields(self) -> dict: - return { + fields = { 'bearer.auth.token': self.token, 'bearer.auth.logical.cluster': self.logical_cluster, - 'bearer.auth.identity.pool.id': self.identity_pool, } + if self.identity_pool is not None: + fields['bearer.auth.identity.pool.id'] = self.identity_pool + return fields class _AsyncStaticFieldProvider(_AsyncBearerFieldProvider): """Asynchronous static token bearer field provider.""" - def __init__(self, token: str, logical_cluster: str, identity_pool: str): + def __init__(self, token: str, logical_cluster: str, identity_pool: Optional[str] = None): self.token = token self.logical_cluster = logical_cluster self.identity_pool = identity_pool async def get_bearer_fields(self) -> dict: - return { + fields = { 'bearer.auth.token': self.token, 'bearer.auth.logical.cluster': self.logical_cluster, - 'bearer.auth.identity.pool.id': self.identity_pool, } + if self.identity_pool is not None: + fields['bearer.auth.identity.pool.id'] = self.identity_pool + return fields def is_success(status_code: int) -> bool: diff --git a/tests/schema_registry/_async/test_bearer_field_provider.py b/tests/schema_registry/_async/test_bearer_field_provider.py index ffedc15b6..2301c768e 100644 --- a/tests/schema_registry/_async/test_bearer_field_provider.py +++ b/tests/schema_registry/_async/test_bearer_field_provider.py @@ -51,7 +51,7 @@ async def custom_oauth_function(config: dict) -> dict: def test_expiry(): - oauth_client = _AsyncOAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 2, 1000, 20000) + oauth_client = _AsyncOAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, 2, 1000, 20000, TEST_POOL) # Use consistent test data: expires_at and expires_in should match # Token expires in 2 seconds, with 0.8 threshold, should refresh after 1.6 seconds (when 0.4s remaining) oauth_client.token = {'expires_at': time.time() + 2, 'expires_in': 2} @@ -61,7 +61,7 @@ def test_expiry(): async def test_get_token(): - oauth_client = _AsyncOAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 2, 1000, 20000) + oauth_client = _AsyncOAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, 2, 1000, 20000, TEST_POOL) def update_token1(): oauth_client.token = {'expires_at': 0, 'expires_in': 1, 'access_token': '123'} @@ -158,3 +158,132 @@ async def test_bearer_field_headers_valid(): assert headers['Authorization'] == "Bearer {}".format(TEST_CONFIG['bearer.auth.token']) assert headers['Confluent-Identity-Pool-Id'] == 
TEST_CONFIG['bearer.auth.identity.pool.id'] assert headers['target-sr-cluster'] == TEST_CONFIG['bearer.auth.logical.cluster'] + + +async def test_bearer_field_headers_optional_identity_pool(): + """Test that identity pool is optional and header is omitted when not provided.""" + + async def custom_oauth_no_pool(config: dict) -> dict: + return { + 'bearer.auth.token': TEST_TOKEN, + 'bearer.auth.logical.cluster': TEST_CLUSTER, + # bearer.auth.identity.pool.id is intentionally omitted + } + + conf = { + 'url': TEST_URL, + 'bearer.auth.credentials.source': 'CUSTOM', + 'bearer.auth.custom.provider.function': custom_oauth_no_pool, + 'bearer.auth.custom.provider.config': {}, + } + + client = AsyncSchemaRegistryClient(conf) + + headers = { + 'Accept': "application/vnd.schemaregistry.v1+json," " application/vnd.schemaregistry+json," " application/json" + } + + await client._rest_client.handle_bearer_auth(headers) + + assert 'Authorization' in headers + assert 'target-sr-cluster' in headers + assert headers['Authorization'] == "Bearer {}".format(TEST_TOKEN) + assert headers['target-sr-cluster'] == TEST_CLUSTER + # Confluent-Identity-Pool-Id should NOT be present when identity pool is omitted + assert 'Confluent-Identity-Pool-Id' not in headers + + +async def test_bearer_field_headers_comma_separated_pools(): + """Test that comma-separated pool IDs are passed through as-is in the header.""" + comma_separated_pools = 'pool-1,pool-2,pool-3' + + async def custom_oauth_multi_pool(config: dict) -> dict: + return { + 'bearer.auth.token': TEST_TOKEN, + 'bearer.auth.logical.cluster': TEST_CLUSTER, + 'bearer.auth.identity.pool.id': comma_separated_pools, + } + + conf = { + 'url': TEST_URL, + 'bearer.auth.credentials.source': 'CUSTOM', + 'bearer.auth.custom.provider.function': custom_oauth_multi_pool, + 'bearer.auth.custom.provider.config': {}, + } + + client = AsyncSchemaRegistryClient(conf) + + headers = { + 'Accept': "application/vnd.schemaregistry.v1+json," " application/vnd.schemaregistry+json," " application/json" + } + + await client._rest_client.handle_bearer_auth(headers) + + assert 'Authorization' in headers + assert 'Confluent-Identity-Pool-Id' in headers + assert 'target-sr-cluster' in headers + # Verify comma-separated value is passed through unchanged + assert headers['Confluent-Identity-Pool-Id'] == comma_separated_pools + + +async def test_static_token_optional_identity_pool(): + """Test that STATIC_TOKEN credential source works without identity pool.""" + conf = { + 'url': TEST_URL, + 'bearer.auth.credentials.source': 'STATIC_TOKEN', + 'bearer.auth.token': TEST_TOKEN, + 'bearer.auth.logical.cluster': TEST_CLUSTER, + # bearer.auth.identity.pool.id is intentionally omitted + } + + client = AsyncSchemaRegistryClient(conf) + + headers = { + 'Accept': "application/vnd.schemaregistry.v1+json," " application/vnd.schemaregistry+json," " application/json" + } + + await client._rest_client.handle_bearer_auth(headers) + + assert 'Authorization' in headers + assert 'target-sr-cluster' in headers + assert 'Confluent-Identity-Pool-Id' not in headers + + +async def test_static_token_comma_separated_pools(): + """Test that STATIC_TOKEN credential source supports comma-separated pool IDs.""" + comma_separated_pools = 'pool-abc,pool-def,pool-ghi' + + conf = { + 'url': TEST_URL, + 'bearer.auth.credentials.source': 'STATIC_TOKEN', + 'bearer.auth.token': TEST_TOKEN, + 'bearer.auth.logical.cluster': TEST_CLUSTER, + 'bearer.auth.identity.pool.id': comma_separated_pools, + } + + client = 
AsyncSchemaRegistryClient(conf) + + headers = { + 'Accept': "application/vnd.schemaregistry.v1+json," " application/vnd.schemaregistry+json," " application/json" + } + + await client._rest_client.handle_bearer_auth(headers) + + assert 'Confluent-Identity-Pool-Id' in headers + assert headers['Confluent-Identity-Pool-Id'] == comma_separated_pools + + +def test_static_field_provider_optional_pool(): + """Test that _StaticFieldProvider works with optional identity pool.""" + from confluent_kafka.schema_registry.common.schema_registry_client import _AsyncStaticFieldProvider + import asyncio + + async def check_provider(): + static_field_provider = _AsyncStaticFieldProvider(TEST_TOKEN, TEST_CLUSTER, None) + bearer_fields = await static_field_provider.get_bearer_fields() + + assert bearer_fields['bearer.auth.token'] == TEST_TOKEN + assert bearer_fields['bearer.auth.logical.cluster'] == TEST_CLUSTER + assert 'bearer.auth.identity.pool.id' not in bearer_fields + + asyncio.run(check_provider()) diff --git a/tests/schema_registry/_sync/test_bearer_field_provider.py b/tests/schema_registry/_sync/test_bearer_field_provider.py index ebeceaa7e..86dcb12e5 100644 --- a/tests/schema_registry/_sync/test_bearer_field_provider.py +++ b/tests/schema_registry/_sync/test_bearer_field_provider.py @@ -51,7 +51,7 @@ def custom_oauth_function(config: dict) -> dict: def test_expiry(): - oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 2, 1000, 20000) + oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, 2, 1000, 20000, TEST_POOL) # Use consistent test data: expires_at and expires_in should match # Token expires in 2 seconds, with 0.8 threshold, should refresh after 1.6 seconds (when 0.4s remaining) oauth_client.token = {'expires_at': time.time() + 2, 'expires_in': 2} @@ -61,7 +61,7 @@ def test_expiry(): def test_get_token(): - oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 2, 1000, 20000) + oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, 2, 1000, 20000, TEST_POOL) def update_token1(): oauth_client.token = {'expires_at': 0, 'expires_in': 1, 'access_token': '123'} From b0940f3a2ea0ddba1a6c4b92477097b180d31ef3 Mon Sep 17 00:00:00 2001 From: tobiogunbi Date: Mon, 2 Feb 2026 17:14:24 -0500 Subject: [PATCH 02/11] Fix to sync version --- .../_sync/schema_registry_client.py | 2 - .../_sync/test_bearer_field_provider.py | 123 ++++++++++++++++++ 2 files changed, 123 insertions(+), 2 deletions(-) diff --git a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py index 5d30267aa..9fd6832f7 100644 --- a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py @@ -433,8 +433,6 @@ def handle_bearer_auth(self, headers: dict) -> None: headers["Authorization"] = "Bearer {}".format(bearer_fields['bearer.auth.token']) headers['target-sr-cluster'] = bearer_fields['bearer.auth.logical.cluster'] - # Identity pool may be a single pool ID or comma-separated list of pool IDs. - # Pass the value as-is to the server; it handles union-of-pools semantics. 
if 'bearer.auth.identity.pool.id' in bearer_fields: headers['Confluent-Identity-Pool-Id'] = bearer_fields['bearer.auth.identity.pool.id'] diff --git a/tests/schema_registry/_sync/test_bearer_field_provider.py b/tests/schema_registry/_sync/test_bearer_field_provider.py index 86dcb12e5..a54845fd7 100644 --- a/tests/schema_registry/_sync/test_bearer_field_provider.py +++ b/tests/schema_registry/_sync/test_bearer_field_provider.py @@ -158,3 +158,126 @@ def test_bearer_field_headers_valid(): assert headers['Authorization'] == "Bearer {}".format(TEST_CONFIG['bearer.auth.token']) assert headers['Confluent-Identity-Pool-Id'] == TEST_CONFIG['bearer.auth.identity.pool.id'] assert headers['target-sr-cluster'] == TEST_CONFIG['bearer.auth.logical.cluster'] + + +def test_bearer_field_headers_optional_identity_pool(): + """Test that identity pool is optional and header is omitted when not provided.""" + + def custom_oauth_no_pool(config: dict) -> dict: + return { + 'bearer.auth.token': TEST_TOKEN, + 'bearer.auth.logical.cluster': TEST_CLUSTER, + # bearer.auth.identity.pool.id is intentionally omitted + } + + conf = { + 'url': TEST_URL, + 'bearer.auth.credentials.source': 'CUSTOM', + 'bearer.auth.custom.provider.function': custom_oauth_no_pool, + 'bearer.auth.custom.provider.config': {}, + } + + client = SchemaRegistryClient(conf) + + headers = { + 'Accept': "application/vnd.schemaregistry.v1+json," " application/vnd.schemaregistry+json," " application/json" + } + + client._rest_client.handle_bearer_auth(headers) + + assert 'Authorization' in headers + assert 'target-sr-cluster' in headers + assert headers['Authorization'] == "Bearer {}".format(TEST_TOKEN) + assert headers['target-sr-cluster'] == TEST_CLUSTER + # Confluent-Identity-Pool-Id should NOT be present when identity pool is omitted + assert 'Confluent-Identity-Pool-Id' not in headers + + +def test_bearer_field_headers_comma_separated_pools(): + """Test that comma-separated pool IDs are passed through as-is in the header.""" + comma_separated_pools = 'pool-1,pool-2,pool-3' + + def custom_oauth_multi_pool(config: dict) -> dict: + return { + 'bearer.auth.token': TEST_TOKEN, + 'bearer.auth.logical.cluster': TEST_CLUSTER, + 'bearer.auth.identity.pool.id': comma_separated_pools, + } + + conf = { + 'url': TEST_URL, + 'bearer.auth.credentials.source': 'CUSTOM', + 'bearer.auth.custom.provider.function': custom_oauth_multi_pool, + 'bearer.auth.custom.provider.config': {}, + } + + client = SchemaRegistryClient(conf) + + headers = { + 'Accept': "application/vnd.schemaregistry.v1+json," " application/vnd.schemaregistry+json," " application/json" + } + + client._rest_client.handle_bearer_auth(headers) + + assert 'Authorization' in headers + assert 'Confluent-Identity-Pool-Id' in headers + assert 'target-sr-cluster' in headers + # Verify comma-separated value is passed through unchanged + assert headers['Confluent-Identity-Pool-Id'] == comma_separated_pools + + +def test_static_token_optional_identity_pool(): + """Test that STATIC_TOKEN credential source works without identity pool.""" + conf = { + 'url': TEST_URL, + 'bearer.auth.credentials.source': 'STATIC_TOKEN', + 'bearer.auth.token': TEST_TOKEN, + 'bearer.auth.logical.cluster': TEST_CLUSTER, + # bearer.auth.identity.pool.id is intentionally omitted + } + + client = SchemaRegistryClient(conf) + + headers = { + 'Accept': "application/vnd.schemaregistry.v1+json," " application/vnd.schemaregistry+json," " application/json" + } + + client._rest_client.handle_bearer_auth(headers) + + assert 'Authorization' in 
headers + assert 'target-sr-cluster' in headers + assert 'Confluent-Identity-Pool-Id' not in headers + + +def test_static_token_comma_separated_pools(): + """Test that STATIC_TOKEN credential source supports comma-separated pool IDs.""" + comma_separated_pools = 'pool-abc,pool-def,pool-ghi' + + conf = { + 'url': TEST_URL, + 'bearer.auth.credentials.source': 'STATIC_TOKEN', + 'bearer.auth.token': TEST_TOKEN, + 'bearer.auth.logical.cluster': TEST_CLUSTER, + 'bearer.auth.identity.pool.id': comma_separated_pools, + } + + client = SchemaRegistryClient(conf) + + headers = { + 'Accept': "application/vnd.schemaregistry.v1+json," " application/vnd.schemaregistry+json," " application/json" + } + + client._rest_client.handle_bearer_auth(headers) + + assert 'Confluent-Identity-Pool-Id' in headers + assert headers['Confluent-Identity-Pool-Id'] == comma_separated_pools + + +def test_static_field_provider_optional_pool(): + """Test that _StaticFieldProvider works with optional identity pool.""" + static_field_provider = _StaticFieldProvider(TEST_TOKEN, TEST_CLUSTER, None) + bearer_fields = static_field_provider.get_bearer_fields() + + assert bearer_fields['bearer.auth.token'] == TEST_TOKEN + assert bearer_fields['bearer.auth.logical.cluster'] == TEST_CLUSTER + assert 'bearer.auth.identity.pool.id' not in bearer_fields From efd2db4207da8f0f200a77f17cf163f11c308649 Mon Sep 17 00:00:00 2001 From: tobiogunbi Date: Mon, 2 Feb 2026 17:36:53 -0500 Subject: [PATCH 03/11] Fix tests --- tests/schema_registry/_async/test_bearer_field_provider.py | 2 +- tests/schema_registry/_async/test_config.py | 2 +- tests/schema_registry/_sync/test_bearer_field_provider.py | 2 +- tests/schema_registry/_sync/test_config.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/schema_registry/_async/test_bearer_field_provider.py b/tests/schema_registry/_async/test_bearer_field_provider.py index 2301c768e..89616cac9 100644 --- a/tests/schema_registry/_async/test_bearer_field_provider.py +++ b/tests/schema_registry/_async/test_bearer_field_provider.py @@ -86,7 +86,7 @@ def update_token2(): async def test_generate_token_retry_logic(): - oauth_client = _AsyncOAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 5, 1000, 20000) + oauth_client = _AsyncOAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, 5, 1000, 20000, TEST_POOL) with ( patch("confluent_kafka.schema_registry._async.schema_registry_client.asyncio.sleep") as mock_sleep, diff --git a/tests/schema_registry/_async/test_config.py b/tests/schema_registry/_async/test_config.py index 98dd79c66..e374b0fb9 100644 --- a/tests/schema_registry/_async/test_config.py +++ b/tests/schema_registry/_async/test_config.py @@ -124,7 +124,7 @@ def test_config_auth_userinfo_invalid(): def test_bearer_config(): conf = {'url': TEST_URL, 'bearer.auth.credentials.source': "OAUTHBEARER"} - with pytest.raises(ValueError, match=r"Missing required bearer configuration properties: (.*)"): + with pytest.raises(ValueError, match=r"Missing required bearer configuration property: bearer.auth.logical.cluster"): AsyncSchemaRegistryClient(conf) diff --git a/tests/schema_registry/_sync/test_bearer_field_provider.py b/tests/schema_registry/_sync/test_bearer_field_provider.py index a54845fd7..5f72d03e3 100644 --- a/tests/schema_registry/_sync/test_bearer_field_provider.py +++ b/tests/schema_registry/_sync/test_bearer_field_provider.py @@ -86,7 +86,7 @@ def update_token2(): def test_generate_token_retry_logic(): - oauth_client = _OAuthClient('id', 'secret', 
'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 5, 1000, 20000) + oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, 5, 1000, 20000, TEST_POOL) with ( patch("confluent_kafka.schema_registry._sync.schema_registry_client.time.sleep") as mock_sleep, diff --git a/tests/schema_registry/_sync/test_config.py b/tests/schema_registry/_sync/test_config.py index e5b74aeda..541d1b761 100644 --- a/tests/schema_registry/_sync/test_config.py +++ b/tests/schema_registry/_sync/test_config.py @@ -124,7 +124,7 @@ def test_config_auth_userinfo_invalid(): def test_bearer_config(): conf = {'url': TEST_URL, 'bearer.auth.credentials.source': "OAUTHBEARER"} - with pytest.raises(ValueError, match=r"Missing required bearer configuration properties: (.*)"): + with pytest.raises(ValueError, match=r"Missing required bearer configuration property: bearer.auth.logical.cluster"): SchemaRegistryClient(conf) From da6ece1e102528bd6514315d7942fefa477fdad3 Mon Sep 17 00:00:00 2001 From: tobiogunbi Date: Mon, 2 Feb 2026 17:58:18 -0500 Subject: [PATCH 04/11] Unasync run --- .../schema_registry/_sync/avro.py | 13 ++++++-- .../schema_registry/_sync/json_schema.py | 11 +++++-- .../schema_registry/_sync/protobuf.py | 11 +++++-- .../_sync/schema_registry_client.py | 31 ++++++++++++------- .../_sync/test_json_serializers.py | 2 +- .../schema_registry/_sync/test_api_client.py | 6 ++-- tests/schema_registry/_sync/test_avro.py | 8 +++-- .../schema_registry/_sync/test_avro_serdes.py | 2 +- .../_sync/test_bearer_field_provider.py | 16 +++++++--- tests/schema_registry/_sync/test_json.py | 2 +- .../schema_registry/_sync/test_json_serdes.py | 2 +- 11 files changed, 72 insertions(+), 32 deletions(-) diff --git a/src/confluent_kafka/schema_registry/_sync/avro.py b/src/confluent_kafka/schema_registry/_sync/avro.py index 289367265..23467f322 100644 --- a/src/confluent_kafka/schema_registry/_sync/avro.py +++ b/src/confluent_kafka/schema_registry/_sync/avro.py @@ -21,13 +21,14 @@ from fastavro import schemaless_reader, schemaless_writer from confluent_kafka.schema_registry import ( + SchemaRegistryClient, RuleMode, Schema, - SchemaRegistryClient, dual_schema_id_deserializer, prefix_schema_id_serializer, topic_subject_name_strategy, ) + from confluent_kafka.schema_registry.common.avro import ( AVRO_TYPE, AvroSchema, @@ -54,7 +55,9 @@ ] -def _resolve_named_schema(schema: Schema, schema_registry_client: SchemaRegistryClient) -> Dict[str, AvroSchema]: +def _resolve_named_schema( + schema: Schema, schema_registry_client: SchemaRegistryClient +) -> Dict[str, AvroSchema]: """ Resolves named schemas referenced by the provided schema recursively. :param schema: Schema to resolve named schemas for. @@ -79,6 +82,7 @@ def _resolve_named_schema(schema: Schema, schema_registry_client: SchemaRegistry return named_schemas + class AvroSerializer(BaseSerializer): """ Serializer that outputs Avro binary encoded data with Confluent Schema Registry framing. 
@@ -458,6 +462,7 @@ def _get_parsed_schema(self, schema: Schema) -> AvroSchema: return parsed_schema + class AvroDeserializer(BaseDeserializer): """ Deserializer for Avro binary encoded data with Confluent Schema Registry @@ -607,7 +612,9 @@ def __init_impl( __init__ = __init_impl - def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: + def __call__( + self, data: Optional[bytes], ctx: Optional[SerializationContext] = None + ) -> Union[dict, object, None]: return self.__deserialize(data, ctx) def __deserialize( diff --git a/src/confluent_kafka/schema_registry/_sync/json_schema.py b/src/confluent_kafka/schema_registry/_sync/json_schema.py index bc6fd4d4a..2e3a6046f 100644 --- a/src/confluent_kafka/schema_registry/_sync/json_schema.py +++ b/src/confluent_kafka/schema_registry/_sync/json_schema.py @@ -14,9 +14,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import threading as _locks import io import logging -import threading as _locks from typing import Any, Callable, Optional, Tuple, Union, cast import orjson @@ -27,13 +27,14 @@ from referencing import Registry, Resource from confluent_kafka.schema_registry import ( + SchemaRegistryClient, RuleMode, Schema, - SchemaRegistryClient, dual_schema_id_deserializer, prefix_schema_id_serializer, topic_subject_name_strategy, ) + from confluent_kafka.schema_registry.common.json_schema import ( DEFAULT_SPEC, JSON_TYPE, @@ -87,6 +88,7 @@ def _resolve_named_schema( return ref_registry + class JSONSerializer(BaseSerializer): """ Serializer that outputs JSON encoded data with Confluent Schema Registry framing. @@ -457,6 +459,7 @@ def _get_validator(self, schema: Schema, parsed_schema: JsonSchema, registry: Re return validator + class JSONDeserializer(BaseDeserializer): """ Deserializer for JSON encoded data with Confluent Schema Registry @@ -618,7 +621,9 @@ def __init_impl( __init__ = __init_impl - def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__( + self, data: Optional[bytes], ctx: Optional[SerializationContext] = None + ) -> Optional[bytes]: return self.__deserialize(data, ctx) def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: diff --git a/src/confluent_kafka/schema_registry/_sync/protobuf.py b/src/confluent_kafka/schema_registry/_sync/protobuf.py index 5e1620276..24984cd99 100644 --- a/src/confluent_kafka/schema_registry/_sync/protobuf.py +++ b/src/confluent_kafka/schema_registry/_sync/protobuf.py @@ -33,6 +33,7 @@ reference_subject_name_strategy, topic_subject_name_strategy, ) + from confluent_kafka.schema_registry.common.protobuf import ( PROTOBUF_TYPE, _bytes, @@ -97,6 +98,7 @@ def _resolve_named_schema( pool.Add(file_descriptor_proto) + class ProtobufSerializer(BaseSerializer): """ Serializer for Protobuf Message derived classes. 
Serialization format is Protobuf, @@ -359,7 +361,9 @@ def _encode_varints(buf: io.BytesIO, ints: List[int], zigzag: bool = True): for value in ints: ProtobufSerializer._write_varint(buf, value, zigzag=zigzag) - def _resolve_dependencies(self, ctx: SerializationContext, file_desc: FileDescriptor) -> List[SchemaReference]: + def _resolve_dependencies( + self, ctx: SerializationContext, file_desc: FileDescriptor + ) -> List[SchemaReference]: """ Resolves and optionally registers schema references recursively. @@ -485,6 +489,7 @@ def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescrip return fd_proto, pool + class ProtobufDeserializer(BaseDeserializer): """ Deserializer for Protobuf serialized data with Confluent Schema Registry framing. @@ -593,7 +598,9 @@ def __init_impl( __init__ = __init_impl - def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__( + self, data: Optional[bytes], ctx: Optional[SerializationContext] = None + ) -> Optional[bytes]: return self.__deserialize(data, ctx) def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: diff --git a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py index 9fd6832f7..96d1bfae5 100644 --- a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py @@ -16,11 +16,11 @@ # limitations under the License. # +import threading as _locks import json import logging import os import ssl -import threading as _locks import time import urllib from typing import Any, Callable, Dict, List, Literal, Optional, Union @@ -39,8 +39,8 @@ SchemaVersion, ServerConfig, _BearerFieldProvider, - _SchemaCache, _StaticFieldProvider, + _SchemaCache, full_jitter, is_retriable, is_success, @@ -294,9 +294,6 @@ def __init__(self, conf: dict): if not isinstance(logical_cluster, str): raise TypeError("logical cluster must be a str, not " + str(type(logical_cluster))) - # Identity pool is optional and may be a single pool ID or comma-separated list of pool IDs. - # The value is passed as-is in the Confluent-Identity-Pool-Id header to Schema Registry. - # When union-of-pools is enabled on the server, multiple IDs are interpreted as a union. identity_pool = conf_copy.pop('bearer.auth.identity.pool.id', None) if identity_pool is not None and not isinstance(identity_pool, str): raise TypeError("identity pool id must be a str, not " + str(type(identity_pool))) @@ -448,7 +445,9 @@ def delete(self, url: str) -> Any: def put(self, url: str, body: Optional[dict] = None) -> Any: return self.send_request(url, method='PUT', body=body) - def send_request(self, url: str, method: str, body: Optional[dict] = None, query: Optional[dict] = None) -> Any: + def send_request( + self, url: str, method: str, body: Optional[dict] = None, query: Optional[dict] = None + ) -> Any: """ Sends HTTP request to the SchemaRegistry, trying each base URL in turn. 
@@ -953,7 +952,9 @@ def lookup_schema( query_string = '&'.join(f"{key}={value}" for key, value in query_params.items()) - response = self._rest_client.post('subjects/{}?{}'.format(_urlencode(subject_name), query_string), body=request) + response = self._rest_client.post( + 'subjects/{}?{}'.format(_urlencode(subject_name), query_string), body=request + ) result = RegisteredSchema.from_dict(response) @@ -1055,7 +1056,9 @@ def get_latest_version(self, subject_name: str, fmt: Optional[str] = None) -> 'R return registered_schema query = {'format': fmt} if fmt is not None else None - response = self._rest_client.get('subjects/{}/versions/{}'.format(_urlencode(subject_name), 'latest'), query) + response = self._rest_client.get( + 'subjects/{}/versions/{}'.format(_urlencode(subject_name), 'latest'), query + ) registered_schema = RegisteredSchema.from_dict(response) @@ -1138,7 +1141,9 @@ def get_version( return registered_schema query: dict[str, Any] = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} - response = self._rest_client.get('subjects/{}/versions/{}'.format(_urlencode(subject_name), version), query) + response = self._rest_client.get( + 'subjects/{}/versions/{}'.format(_urlencode(subject_name), version), query + ) registered_schema = RegisteredSchema.from_dict(response) @@ -1225,7 +1230,9 @@ def delete_version(self, subject_name: str, version: int, permanent: bool = Fals 'subjects/{}/versions/{}?permanent=true'.format(_urlencode(subject_name), version) ) else: - response = self._rest_client.delete('subjects/{}/versions/{}'.format(_urlencode(subject_name), version)) + response = self._rest_client.delete( + 'subjects/{}/versions/{}'.format(_urlencode(subject_name), version) + ) # Clear cache for both soft and hard deletes to maintain consistency self._cache.remove_by_subject_version(subject_name, version) @@ -1353,7 +1360,9 @@ def test_compatibility_all_versions( ) return response['is_compatible'] - def set_config(self, subject_name: Optional[str] = None, config: Optional['ServerConfig'] = None) -> 'ServerConfig': + def set_config( + self, subject_name: Optional[str] = None, config: Optional['ServerConfig'] = None + ) -> 'ServerConfig': """ Update global or subject config. 
diff --git a/tests/integration/schema_registry/_sync/test_json_serializers.py b/tests/integration/schema_registry/_sync/test_json_serializers.py index db27ce5aa..95f9b8a44 100644 --- a/tests/integration/schema_registry/_sync/test_json_serializers.py +++ b/tests/integration/schema_registry/_sync/test_json_serializers.py @@ -19,7 +19,7 @@ from confluent_kafka import TopicPartition from confluent_kafka.error import ConsumeError, ValueSerializationError -from confluent_kafka.schema_registry import Schema, SchemaReference, SchemaRegistryClient +from confluent_kafka.schema_registry import SchemaRegistryClient, Schema, SchemaReference from confluent_kafka.schema_registry.json_schema import JSONDeserializer, JSONSerializer diff --git a/tests/schema_registry/_sync/test_api_client.py b/tests/schema_registry/_sync/test_api_client.py index 821854c37..760f585aa 100644 --- a/tests/schema_registry/_sync/test_api_client.py +++ b/tests/schema_registry/_sync/test_api_client.py @@ -22,7 +22,7 @@ from confluent_kafka.schema_registry.common.schema_registry_client import SchemaVersion from confluent_kafka.schema_registry.error import SchemaRegistryError -from confluent_kafka.schema_registry.schema_registry_client import Schema, SchemaRegistryClient +from confluent_kafka.schema_registry.schema_registry_client import SchemaRegistryClient, Schema from tests.schema_registry.conftest import COUNTER, SCHEMA, SCHEMA_ID, SUBJECTS, USERINFO, VERSION, VERSIONS """ @@ -459,7 +459,9 @@ def test_schema_equivilence(load_avsc): ('test-key', 1, True), ], ) -def test_test_compatibility_no_error(mock_schema_registry, load_avsc, subject_name, version, expected_compatibility): +def test_test_compatibility_no_error( + mock_schema_registry, load_avsc, subject_name, version, expected_compatibility +): conf = {'url': TEST_URL} sr = SchemaRegistryClient(conf) schema = Schema(load_avsc('basic_schema.avsc'), schema_type='AVRO') diff --git a/tests/schema_registry/_sync/test_avro.py b/tests/schema_registry/_sync/test_avro.py index 7a9872285..8fb415ff5 100644 --- a/tests/schema_registry/_sync/test_avro.py +++ b/tests/schema_registry/_sync/test_avro.py @@ -106,7 +106,9 @@ def test_avro_serializer_config_subject_name_strategy(): conf = {'url': TEST_URL} test_client = SchemaRegistryClient(conf) - test_serializer = AvroSerializer(test_client, '"int"', conf={'subject.name.strategy': record_subject_name_strategy}) + test_serializer = AvroSerializer( + test_client, '"int"', conf={'subject.name.strategy': record_subject_name_strategy} + ) assert test_serializer._subject_name_func is record_subject_name_strategy @@ -145,7 +147,9 @@ def test_avro_serializer_record_subject_name_strategy_primitive(load_avsc): """ conf = {'url': TEST_URL} test_client = SchemaRegistryClient(conf) - test_serializer = AvroSerializer(test_client, '"int"', conf={'subject.name.strategy': record_subject_name_strategy}) + test_serializer = AvroSerializer( + test_client, '"int"', conf={'subject.name.strategy': record_subject_name_strategy} + ) ctx = SerializationContext( 'test_subj', diff --git a/tests/schema_registry/_sync/test_avro_serdes.py b/tests/schema_registry/_sync/test_avro_serdes.py index 0c8ca7977..f027176ee 100644 --- a/tests/schema_registry/_sync/test_avro_serdes.py +++ b/tests/schema_registry/_sync/test_avro_serdes.py @@ -23,10 +23,10 @@ from fastavro._logical_readers import UUID from confluent_kafka.schema_registry import ( + SchemaRegistryClient, Metadata, MetadataProperties, Schema, - SchemaRegistryClient, header_schema_id_serializer, ) from 
confluent_kafka.schema_registry.avro import AvroDeserializer, AvroSerializer diff --git a/tests/schema_registry/_sync/test_bearer_field_provider.py b/tests/schema_registry/_sync/test_bearer_field_provider.py index 5f72d03e3..8dcf20bf0 100644 --- a/tests/schema_registry/_sync/test_bearer_field_provider.py +++ b/tests/schema_registry/_sync/test_bearer_field_provider.py @@ -275,9 +275,15 @@ def test_static_token_comma_separated_pools(): def test_static_field_provider_optional_pool(): """Test that _StaticFieldProvider works with optional identity pool.""" - static_field_provider = _StaticFieldProvider(TEST_TOKEN, TEST_CLUSTER, None) - bearer_fields = static_field_provider.get_bearer_fields() + from confluent_kafka.schema_registry.common.schema_registry_client import _StaticFieldProvider + + + def check_provider(): + static_field_provider = _StaticFieldProvider(TEST_TOKEN, TEST_CLUSTER, None) + bearer_fields = static_field_provider.get_bearer_fields() + + assert bearer_fields['bearer.auth.token'] == TEST_TOKEN + assert bearer_fields['bearer.auth.logical.cluster'] == TEST_CLUSTER + assert 'bearer.auth.identity.pool.id' not in bearer_fields - assert bearer_fields['bearer.auth.token'] == TEST_TOKEN - assert bearer_fields['bearer.auth.logical.cluster'] == TEST_CLUSTER - assert 'bearer.auth.identity.pool.id' not in bearer_fields + check_provider() diff --git a/tests/schema_registry/_sync/test_json.py b/tests/schema_registry/_sync/test_json.py index 7d3d912e1..b9aec592a 100644 --- a/tests/schema_registry/_sync/test_json.py +++ b/tests/schema_registry/_sync/test_json.py @@ -23,10 +23,10 @@ import pytest from confluent_kafka.schema_registry import ( + SchemaRegistryClient, RegisteredSchema, Schema, SchemaReference, - SchemaRegistryClient, ) from confluent_kafka.schema_registry.json_schema import JSONDeserializer, JSONSerializer from confluent_kafka.schema_registry.rule_registry import RuleRegistry diff --git a/tests/schema_registry/_sync/test_json_serdes.py b/tests/schema_registry/_sync/test_json_serdes.py index 83de26355..66540d286 100644 --- a/tests/schema_registry/_sync/test_json_serdes.py +++ b/tests/schema_registry/_sync/test_json_serdes.py @@ -21,10 +21,10 @@ import pytest from confluent_kafka.schema_registry import ( + SchemaRegistryClient, Metadata, MetadataProperties, Schema, - SchemaRegistryClient, header_schema_id_serializer, ) from confluent_kafka.schema_registry.json_schema import JSONDeserializer, JSONSerializer From 2f45d4903d7cb839556f0f000b60f7934bbe0753 Mon Sep 17 00:00:00 2001 From: tobiogunbi Date: Tue, 3 Feb 2026 12:33:36 -0500 Subject: [PATCH 05/11] restoring --- .../schema_registry/_sync/avro.py | 13 +---- .../schema_registry/_sync/json_schema.py | 11 +--- .../schema_registry/_sync/protobuf.py | 11 +--- .../_sync/schema_registry_client.py | 55 +++++++------------ .../schema_registry/_sync/test_api_client.py | 6 +- .../schema_registry/_sync/test_avro_serdes.py | 2 +- tests/schema_registry/_sync/test_json.py | 2 +- .../schema_registry/_sync/test_json_serdes.py | 2 +- 8 files changed, 33 insertions(+), 69 deletions(-) diff --git a/src/confluent_kafka/schema_registry/_sync/avro.py b/src/confluent_kafka/schema_registry/_sync/avro.py index 23467f322..289367265 100644 --- a/src/confluent_kafka/schema_registry/_sync/avro.py +++ b/src/confluent_kafka/schema_registry/_sync/avro.py @@ -21,14 +21,13 @@ from fastavro import schemaless_reader, schemaless_writer from confluent_kafka.schema_registry import ( - SchemaRegistryClient, RuleMode, Schema, + SchemaRegistryClient, 
dual_schema_id_deserializer, prefix_schema_id_serializer, topic_subject_name_strategy, ) - from confluent_kafka.schema_registry.common.avro import ( AVRO_TYPE, AvroSchema, @@ -55,9 +54,7 @@ ] -def _resolve_named_schema( - schema: Schema, schema_registry_client: SchemaRegistryClient -) -> Dict[str, AvroSchema]: +def _resolve_named_schema(schema: Schema, schema_registry_client: SchemaRegistryClient) -> Dict[str, AvroSchema]: """ Resolves named schemas referenced by the provided schema recursively. :param schema: Schema to resolve named schemas for. @@ -82,7 +79,6 @@ def _resolve_named_schema( return named_schemas - class AvroSerializer(BaseSerializer): """ Serializer that outputs Avro binary encoded data with Confluent Schema Registry framing. @@ -462,7 +458,6 @@ def _get_parsed_schema(self, schema: Schema) -> AvroSchema: return parsed_schema - class AvroDeserializer(BaseDeserializer): """ Deserializer for Avro binary encoded data with Confluent Schema Registry @@ -612,9 +607,7 @@ def __init_impl( __init__ = __init_impl - def __call__( - self, data: Optional[bytes], ctx: Optional[SerializationContext] = None - ) -> Union[dict, object, None]: + def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: return self.__deserialize(data, ctx) def __deserialize( diff --git a/src/confluent_kafka/schema_registry/_sync/json_schema.py b/src/confluent_kafka/schema_registry/_sync/json_schema.py index 2e3a6046f..bc6fd4d4a 100644 --- a/src/confluent_kafka/schema_registry/_sync/json_schema.py +++ b/src/confluent_kafka/schema_registry/_sync/json_schema.py @@ -14,9 +14,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import threading as _locks import io import logging +import threading as _locks from typing import Any, Callable, Optional, Tuple, Union, cast import orjson @@ -27,14 +27,13 @@ from referencing import Registry, Resource from confluent_kafka.schema_registry import ( - SchemaRegistryClient, RuleMode, Schema, + SchemaRegistryClient, dual_schema_id_deserializer, prefix_schema_id_serializer, topic_subject_name_strategy, ) - from confluent_kafka.schema_registry.common.json_schema import ( DEFAULT_SPEC, JSON_TYPE, @@ -88,7 +87,6 @@ def _resolve_named_schema( return ref_registry - class JSONSerializer(BaseSerializer): """ Serializer that outputs JSON encoded data with Confluent Schema Registry framing. 
@@ -459,7 +457,6 @@ def _get_validator(self, schema: Schema, parsed_schema: JsonSchema, registry: Re return validator - class JSONDeserializer(BaseDeserializer): """ Deserializer for JSON encoded data with Confluent Schema Registry @@ -621,9 +618,7 @@ def __init_impl( __init__ = __init_impl - def __call__( - self, data: Optional[bytes], ctx: Optional[SerializationContext] = None - ) -> Optional[bytes]: + def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: return self.__deserialize(data, ctx) def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: diff --git a/src/confluent_kafka/schema_registry/_sync/protobuf.py b/src/confluent_kafka/schema_registry/_sync/protobuf.py index 24984cd99..5e1620276 100644 --- a/src/confluent_kafka/schema_registry/_sync/protobuf.py +++ b/src/confluent_kafka/schema_registry/_sync/protobuf.py @@ -33,7 +33,6 @@ reference_subject_name_strategy, topic_subject_name_strategy, ) - from confluent_kafka.schema_registry.common.protobuf import ( PROTOBUF_TYPE, _bytes, @@ -98,7 +97,6 @@ def _resolve_named_schema( pool.Add(file_descriptor_proto) - class ProtobufSerializer(BaseSerializer): """ Serializer for Protobuf Message derived classes. Serialization format is Protobuf, @@ -361,9 +359,7 @@ def _encode_varints(buf: io.BytesIO, ints: List[int], zigzag: bool = True): for value in ints: ProtobufSerializer._write_varint(buf, value, zigzag=zigzag) - def _resolve_dependencies( - self, ctx: SerializationContext, file_desc: FileDescriptor - ) -> List[SchemaReference]: + def _resolve_dependencies(self, ctx: SerializationContext, file_desc: FileDescriptor) -> List[SchemaReference]: """ Resolves and optionally registers schema references recursively. @@ -489,7 +485,6 @@ def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescrip return fd_proto, pool - class ProtobufDeserializer(BaseDeserializer): """ Deserializer for Protobuf serialized data with Confluent Schema Registry framing. @@ -598,9 +593,7 @@ def __init_impl( __init__ = __init_impl - def __call__( - self, data: Optional[bytes], ctx: Optional[SerializationContext] = None - ) -> Optional[bytes]: + def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: return self.__deserialize(data, ctx) def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: diff --git a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py index 96d1bfae5..74435e108 100644 --- a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py @@ -16,11 +16,11 @@ # limitations under the License. 
# -import threading as _locks import json import logging import os import ssl +import threading as _locks import time import urllib from typing import Any, Callable, Dict, List, Literal, Optional, Union @@ -39,8 +39,8 @@ SchemaVersion, ServerConfig, _BearerFieldProvider, - _StaticFieldProvider, _SchemaCache, + _StaticFieldProvider, full_jitter, is_retriable, is_success, @@ -97,10 +97,10 @@ def __init__( scope: str, token_endpoint: str, logical_cluster: str, + identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int, - identity_pool: Optional[str] = None, ): self.token = None self.logical_cluster = logical_cluster @@ -113,13 +113,11 @@ def __init__( self.token_expiry_threshold = 0.8 def get_bearer_fields(self) -> dict: - fields = { + return { 'bearer.auth.token': self.get_access_token(), 'bearer.auth.logical.cluster': self.logical_cluster, + 'bearer.auth.identity.pool.id': self.identity_pool, } - if self.identity_pool is not None: - fields['bearer.auth.identity.pool.id'] = self.identity_pool - return fields def token_expired(self) -> bool: if self.token is None: @@ -285,17 +283,19 @@ def __init__(self, conf: dict): self.auth = None if self.bearer_auth_credentials_source in {'OAUTHBEARER', 'STATIC_TOKEN'}: - if 'bearer.auth.logical.cluster' not in conf_copy: + headers = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id'] + missing_headers = [header for header in headers if header not in conf_copy] + if missing_headers: raise ValueError( - "Missing required bearer configuration property: bearer.auth.logical.cluster" + "Missing required bearer configuration properties: {}".format(", ".join(missing_headers)) ) logical_cluster = conf_copy.pop('bearer.auth.logical.cluster') if not isinstance(logical_cluster, str): raise TypeError("logical cluster must be a str, not " + str(type(logical_cluster))) - identity_pool = conf_copy.pop('bearer.auth.identity.pool.id', None) - if identity_pool is not None and not isinstance(identity_pool, str): + identity_pool = conf_copy.pop('bearer.auth.identity.pool.id') + if not isinstance(identity_pool, str): raise TypeError("identity pool id must be a str, not " + str(type(identity_pool))) if self.bearer_auth_credentials_source == 'OAUTHBEARER': @@ -335,10 +335,10 @@ def __init__(self, conf: dict): self.scope, self.token_endpoint, logical_cluster, + identity_pool, self.max_retries, self.retries_wait_ms, self.retries_max_wait_ms, - identity_pool, ) else: # STATIC_TOKEN if 'bearer.auth.token' not in conf_copy: @@ -412,8 +412,7 @@ def handle_bearer_auth(self, headers: dict) -> None: if self.bearer_field_provider is None: raise ValueError("Bearer field provider is not set") bearer_fields = self.bearer_field_provider.get_bearer_fields() - # Note: bearer.auth.identity.pool.id is optional; only token and logical.cluster are required - required_fields = ['bearer.auth.token', 'bearer.auth.logical.cluster'] + required_fields = ['bearer.auth.token', 'bearer.auth.identity.pool.id', 'bearer.auth.logical.cluster'] missing_fields = [] for field in required_fields: @@ -428,11 +427,9 @@ def handle_bearer_auth(self, headers: dict) -> None: ) headers["Authorization"] = "Bearer {}".format(bearer_fields['bearer.auth.token']) + headers['Confluent-Identity-Pool-Id'] = bearer_fields['bearer.auth.identity.pool.id'] headers['target-sr-cluster'] = bearer_fields['bearer.auth.logical.cluster'] - if 'bearer.auth.identity.pool.id' in bearer_fields: - headers['Confluent-Identity-Pool-Id'] = bearer_fields['bearer.auth.identity.pool.id'] - def get(self, url: 
str, query: Optional[dict] = None) -> Any: return self.send_request(url, method='GET', query=query) @@ -445,9 +442,7 @@ def delete(self, url: str) -> Any: def put(self, url: str, body: Optional[dict] = None) -> Any: return self.send_request(url, method='PUT', body=body) - def send_request( - self, url: str, method: str, body: Optional[dict] = None, query: Optional[dict] = None - ) -> Any: + def send_request(self, url: str, method: str, body: Optional[dict] = None, query: Optional[dict] = None) -> Any: """ Sends HTTP request to the SchemaRegistry, trying each base URL in turn. @@ -952,9 +947,7 @@ def lookup_schema( query_string = '&'.join(f"{key}={value}" for key, value in query_params.items()) - response = self._rest_client.post( - 'subjects/{}?{}'.format(_urlencode(subject_name), query_string), body=request - ) + response = self._rest_client.post('subjects/{}?{}'.format(_urlencode(subject_name), query_string), body=request) result = RegisteredSchema.from_dict(response) @@ -1056,9 +1049,7 @@ def get_latest_version(self, subject_name: str, fmt: Optional[str] = None) -> 'R return registered_schema query = {'format': fmt} if fmt is not None else None - response = self._rest_client.get( - 'subjects/{}/versions/{}'.format(_urlencode(subject_name), 'latest'), query - ) + response = self._rest_client.get('subjects/{}/versions/{}'.format(_urlencode(subject_name), 'latest'), query) registered_schema = RegisteredSchema.from_dict(response) @@ -1141,9 +1132,7 @@ def get_version( return registered_schema query: dict[str, Any] = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} - response = self._rest_client.get( - 'subjects/{}/versions/{}'.format(_urlencode(subject_name), version), query - ) + response = self._rest_client.get('subjects/{}/versions/{}'.format(_urlencode(subject_name), version), query) registered_schema = RegisteredSchema.from_dict(response) @@ -1230,9 +1219,7 @@ def delete_version(self, subject_name: str, version: int, permanent: bool = Fals 'subjects/{}/versions/{}?permanent=true'.format(_urlencode(subject_name), version) ) else: - response = self._rest_client.delete( - 'subjects/{}/versions/{}'.format(_urlencode(subject_name), version) - ) + response = self._rest_client.delete('subjects/{}/versions/{}'.format(_urlencode(subject_name), version)) # Clear cache for both soft and hard deletes to maintain consistency self._cache.remove_by_subject_version(subject_name, version) @@ -1360,9 +1347,7 @@ def test_compatibility_all_versions( ) return response['is_compatible'] - def set_config( - self, subject_name: Optional[str] = None, config: Optional['ServerConfig'] = None - ) -> 'ServerConfig': + def set_config(self, subject_name: Optional[str] = None, config: Optional['ServerConfig'] = None) -> 'ServerConfig': """ Update global or subject config. 
diff --git a/tests/schema_registry/_sync/test_api_client.py b/tests/schema_registry/_sync/test_api_client.py index 760f585aa..821854c37 100644 --- a/tests/schema_registry/_sync/test_api_client.py +++ b/tests/schema_registry/_sync/test_api_client.py @@ -22,7 +22,7 @@ from confluent_kafka.schema_registry.common.schema_registry_client import SchemaVersion from confluent_kafka.schema_registry.error import SchemaRegistryError -from confluent_kafka.schema_registry.schema_registry_client import SchemaRegistryClient, Schema +from confluent_kafka.schema_registry.schema_registry_client import Schema, SchemaRegistryClient from tests.schema_registry.conftest import COUNTER, SCHEMA, SCHEMA_ID, SUBJECTS, USERINFO, VERSION, VERSIONS """ @@ -459,9 +459,7 @@ def test_schema_equivilence(load_avsc): ('test-key', 1, True), ], ) -def test_test_compatibility_no_error( - mock_schema_registry, load_avsc, subject_name, version, expected_compatibility -): +def test_test_compatibility_no_error(mock_schema_registry, load_avsc, subject_name, version, expected_compatibility): conf = {'url': TEST_URL} sr = SchemaRegistryClient(conf) schema = Schema(load_avsc('basic_schema.avsc'), schema_type='AVRO') diff --git a/tests/schema_registry/_sync/test_avro_serdes.py b/tests/schema_registry/_sync/test_avro_serdes.py index f027176ee..0c8ca7977 100644 --- a/tests/schema_registry/_sync/test_avro_serdes.py +++ b/tests/schema_registry/_sync/test_avro_serdes.py @@ -23,10 +23,10 @@ from fastavro._logical_readers import UUID from confluent_kafka.schema_registry import ( - SchemaRegistryClient, Metadata, MetadataProperties, Schema, + SchemaRegistryClient, header_schema_id_serializer, ) from confluent_kafka.schema_registry.avro import AvroDeserializer, AvroSerializer diff --git a/tests/schema_registry/_sync/test_json.py b/tests/schema_registry/_sync/test_json.py index b9aec592a..7d3d912e1 100644 --- a/tests/schema_registry/_sync/test_json.py +++ b/tests/schema_registry/_sync/test_json.py @@ -23,10 +23,10 @@ import pytest from confluent_kafka.schema_registry import ( - SchemaRegistryClient, RegisteredSchema, Schema, SchemaReference, + SchemaRegistryClient, ) from confluent_kafka.schema_registry.json_schema import JSONDeserializer, JSONSerializer from confluent_kafka.schema_registry.rule_registry import RuleRegistry diff --git a/tests/schema_registry/_sync/test_json_serdes.py b/tests/schema_registry/_sync/test_json_serdes.py index 66540d286..83de26355 100644 --- a/tests/schema_registry/_sync/test_json_serdes.py +++ b/tests/schema_registry/_sync/test_json_serdes.py @@ -21,10 +21,10 @@ import pytest from confluent_kafka.schema_registry import ( - SchemaRegistryClient, Metadata, MetadataProperties, Schema, + SchemaRegistryClient, header_schema_id_serializer, ) from confluent_kafka.schema_registry.json_schema import JSONDeserializer, JSONSerializer From 001f1fae81d7ec467e4693dbddfd30214553e821 Mon Sep 17 00:00:00 2001 From: tobiogunbi Date: Tue, 3 Feb 2026 12:38:37 -0500 Subject: [PATCH 06/11] unasync --- .../_sync/schema_registry_client.py | 55 ++++++++++++------- tests/schema_registry/_sync/test_avro.py | 8 +-- 2 files changed, 37 insertions(+), 26 deletions(-) diff --git a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py index 74435e108..96d1bfae5 100644 --- a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py @@ -16,11 +16,11 @@ # limitations under the License. 
# +import threading as _locks import json import logging import os import ssl -import threading as _locks import time import urllib from typing import Any, Callable, Dict, List, Literal, Optional, Union @@ -39,8 +39,8 @@ SchemaVersion, ServerConfig, _BearerFieldProvider, - _SchemaCache, _StaticFieldProvider, + _SchemaCache, full_jitter, is_retriable, is_success, @@ -97,10 +97,10 @@ def __init__( scope: str, token_endpoint: str, logical_cluster: str, - identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int, + identity_pool: Optional[str] = None, ): self.token = None self.logical_cluster = logical_cluster @@ -113,11 +113,13 @@ def __init__( self.token_expiry_threshold = 0.8 def get_bearer_fields(self) -> dict: - return { + fields = { 'bearer.auth.token': self.get_access_token(), 'bearer.auth.logical.cluster': self.logical_cluster, - 'bearer.auth.identity.pool.id': self.identity_pool, } + if self.identity_pool is not None: + fields['bearer.auth.identity.pool.id'] = self.identity_pool + return fields def token_expired(self) -> bool: if self.token is None: @@ -283,19 +285,17 @@ def __init__(self, conf: dict): self.auth = None if self.bearer_auth_credentials_source in {'OAUTHBEARER', 'STATIC_TOKEN'}: - headers = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id'] - missing_headers = [header for header in headers if header not in conf_copy] - if missing_headers: + if 'bearer.auth.logical.cluster' not in conf_copy: raise ValueError( - "Missing required bearer configuration properties: {}".format(", ".join(missing_headers)) + "Missing required bearer configuration property: bearer.auth.logical.cluster" ) logical_cluster = conf_copy.pop('bearer.auth.logical.cluster') if not isinstance(logical_cluster, str): raise TypeError("logical cluster must be a str, not " + str(type(logical_cluster))) - identity_pool = conf_copy.pop('bearer.auth.identity.pool.id') - if not isinstance(identity_pool, str): + identity_pool = conf_copy.pop('bearer.auth.identity.pool.id', None) + if identity_pool is not None and not isinstance(identity_pool, str): raise TypeError("identity pool id must be a str, not " + str(type(identity_pool))) if self.bearer_auth_credentials_source == 'OAUTHBEARER': @@ -335,10 +335,10 @@ def __init__(self, conf: dict): self.scope, self.token_endpoint, logical_cluster, - identity_pool, self.max_retries, self.retries_wait_ms, self.retries_max_wait_ms, + identity_pool, ) else: # STATIC_TOKEN if 'bearer.auth.token' not in conf_copy: @@ -412,7 +412,8 @@ def handle_bearer_auth(self, headers: dict) -> None: if self.bearer_field_provider is None: raise ValueError("Bearer field provider is not set") bearer_fields = self.bearer_field_provider.get_bearer_fields() - required_fields = ['bearer.auth.token', 'bearer.auth.identity.pool.id', 'bearer.auth.logical.cluster'] + # Note: bearer.auth.identity.pool.id is optional; only token and logical.cluster are required + required_fields = ['bearer.auth.token', 'bearer.auth.logical.cluster'] missing_fields = [] for field in required_fields: @@ -427,9 +428,11 @@ def handle_bearer_auth(self, headers: dict) -> None: ) headers["Authorization"] = "Bearer {}".format(bearer_fields['bearer.auth.token']) - headers['Confluent-Identity-Pool-Id'] = bearer_fields['bearer.auth.identity.pool.id'] headers['target-sr-cluster'] = bearer_fields['bearer.auth.logical.cluster'] + if 'bearer.auth.identity.pool.id' in bearer_fields: + headers['Confluent-Identity-Pool-Id'] = bearer_fields['bearer.auth.identity.pool.id'] + def get(self, url: str, 
query: Optional[dict] = None) -> Any: return self.send_request(url, method='GET', query=query) @@ -442,7 +445,9 @@ def delete(self, url: str) -> Any: def put(self, url: str, body: Optional[dict] = None) -> Any: return self.send_request(url, method='PUT', body=body) - def send_request(self, url: str, method: str, body: Optional[dict] = None, query: Optional[dict] = None) -> Any: + def send_request( + self, url: str, method: str, body: Optional[dict] = None, query: Optional[dict] = None + ) -> Any: """ Sends HTTP request to the SchemaRegistry, trying each base URL in turn. @@ -947,7 +952,9 @@ def lookup_schema( query_string = '&'.join(f"{key}={value}" for key, value in query_params.items()) - response = self._rest_client.post('subjects/{}?{}'.format(_urlencode(subject_name), query_string), body=request) + response = self._rest_client.post( + 'subjects/{}?{}'.format(_urlencode(subject_name), query_string), body=request + ) result = RegisteredSchema.from_dict(response) @@ -1049,7 +1056,9 @@ def get_latest_version(self, subject_name: str, fmt: Optional[str] = None) -> 'R return registered_schema query = {'format': fmt} if fmt is not None else None - response = self._rest_client.get('subjects/{}/versions/{}'.format(_urlencode(subject_name), 'latest'), query) + response = self._rest_client.get( + 'subjects/{}/versions/{}'.format(_urlencode(subject_name), 'latest'), query + ) registered_schema = RegisteredSchema.from_dict(response) @@ -1132,7 +1141,9 @@ def get_version( return registered_schema query: dict[str, Any] = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} - response = self._rest_client.get('subjects/{}/versions/{}'.format(_urlencode(subject_name), version), query) + response = self._rest_client.get( + 'subjects/{}/versions/{}'.format(_urlencode(subject_name), version), query + ) registered_schema = RegisteredSchema.from_dict(response) @@ -1219,7 +1230,9 @@ def delete_version(self, subject_name: str, version: int, permanent: bool = Fals 'subjects/{}/versions/{}?permanent=true'.format(_urlencode(subject_name), version) ) else: - response = self._rest_client.delete('subjects/{}/versions/{}'.format(_urlencode(subject_name), version)) + response = self._rest_client.delete( + 'subjects/{}/versions/{}'.format(_urlencode(subject_name), version) + ) # Clear cache for both soft and hard deletes to maintain consistency self._cache.remove_by_subject_version(subject_name, version) @@ -1347,7 +1360,9 @@ def test_compatibility_all_versions( ) return response['is_compatible'] - def set_config(self, subject_name: Optional[str] = None, config: Optional['ServerConfig'] = None) -> 'ServerConfig': + def set_config( + self, subject_name: Optional[str] = None, config: Optional['ServerConfig'] = None + ) -> 'ServerConfig': """ Update global or subject config. 
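At this point the regenerated sync client, like the async one, attaches the Confluent-Identity-Pool-Id header only when a pool is configured. A standalone sketch of that header construction, mirroring the logic in the diff above (the helper name build_auth_headers is illustrative, not part of the library):

def build_auth_headers(bearer_fields: dict) -> dict:
    # Token and logical cluster are always present; the identity pool header is optional.
    headers = {
        "Authorization": "Bearer {}".format(bearer_fields["bearer.auth.token"]),
        "target-sr-cluster": bearer_fields["bearer.auth.logical.cluster"],
    }
    if "bearer.auth.identity.pool.id" in bearer_fields:
        headers["Confluent-Identity-Pool-Id"] = bearer_fields["bearer.auth.identity.pool.id"]
    return headers

# Omitting the pool keeps the header out of the request entirely.
print(build_auth_headers({
    "bearer.auth.token": "token",
    "bearer.auth.logical.cluster": "lsrc-12345",
}))
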
diff --git a/tests/schema_registry/_sync/test_avro.py b/tests/schema_registry/_sync/test_avro.py index 8fb415ff5..7a9872285 100644 --- a/tests/schema_registry/_sync/test_avro.py +++ b/tests/schema_registry/_sync/test_avro.py @@ -106,9 +106,7 @@ def test_avro_serializer_config_subject_name_strategy(): conf = {'url': TEST_URL} test_client = SchemaRegistryClient(conf) - test_serializer = AvroSerializer( - test_client, '"int"', conf={'subject.name.strategy': record_subject_name_strategy} - ) + test_serializer = AvroSerializer(test_client, '"int"', conf={'subject.name.strategy': record_subject_name_strategy}) assert test_serializer._subject_name_func is record_subject_name_strategy @@ -147,9 +145,7 @@ def test_avro_serializer_record_subject_name_strategy_primitive(load_avsc): """ conf = {'url': TEST_URL} test_client = SchemaRegistryClient(conf) - test_serializer = AvroSerializer( - test_client, '"int"', conf={'subject.name.strategy': record_subject_name_strategy} - ) + test_serializer = AvroSerializer(test_client, '"int"', conf={'subject.name.strategy': record_subject_name_strategy}) ctx = SerializationContext( 'test_subj', From 1e86713d212d5ef27c4363d90b3da74a6729e326 Mon Sep 17 00:00:00 2001 From: tobiogunbi Date: Tue, 3 Feb 2026 14:41:18 -0500 Subject: [PATCH 07/11] undo unrelated unasync build --- .../integration/schema_registry/_sync/test_json_serializers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/schema_registry/_sync/test_json_serializers.py b/tests/integration/schema_registry/_sync/test_json_serializers.py index 95f9b8a44..db27ce5aa 100644 --- a/tests/integration/schema_registry/_sync/test_json_serializers.py +++ b/tests/integration/schema_registry/_sync/test_json_serializers.py @@ -19,7 +19,7 @@ from confluent_kafka import TopicPartition from confluent_kafka.error import ConsumeError, ValueSerializationError -from confluent_kafka.schema_registry import SchemaRegistryClient, Schema, SchemaReference +from confluent_kafka.schema_registry import Schema, SchemaReference, SchemaRegistryClient from confluent_kafka.schema_registry.json_schema import JSONDeserializer, JSONSerializer From 2f5b75a969b32e06418acec9c475fc38985494c2 Mon Sep 17 00:00:00 2001 From: tobiogunbi Date: Tue, 3 Feb 2026 15:51:58 -0500 Subject: [PATCH 08/11] unasync run again --- src/confluent_kafka/schema_registry/_sync/avro.py | 13 ++++++++++--- .../schema_registry/_sync/json_schema.py | 11 ++++++++--- .../schema_registry/_sync/protobuf.py | 11 +++++++++-- .../schema_registry/_sync/test_json_serializers.py | 2 +- tests/schema_registry/_sync/test_api_client.py | 6 ++++-- tests/schema_registry/_sync/test_avro.py | 8 ++++++-- tests/schema_registry/_sync/test_avro_serdes.py | 2 +- tests/schema_registry/_sync/test_json.py | 2 +- tests/schema_registry/_sync/test_json_serdes.py | 2 +- 9 files changed, 41 insertions(+), 16 deletions(-) diff --git a/src/confluent_kafka/schema_registry/_sync/avro.py b/src/confluent_kafka/schema_registry/_sync/avro.py index 289367265..23467f322 100644 --- a/src/confluent_kafka/schema_registry/_sync/avro.py +++ b/src/confluent_kafka/schema_registry/_sync/avro.py @@ -21,13 +21,14 @@ from fastavro import schemaless_reader, schemaless_writer from confluent_kafka.schema_registry import ( + SchemaRegistryClient, RuleMode, Schema, - SchemaRegistryClient, dual_schema_id_deserializer, prefix_schema_id_serializer, topic_subject_name_strategy, ) + from confluent_kafka.schema_registry.common.avro import ( AVRO_TYPE, AvroSchema, @@ -54,7 +55,9 @@ ] -def 
_resolve_named_schema(schema: Schema, schema_registry_client: SchemaRegistryClient) -> Dict[str, AvroSchema]: +def _resolve_named_schema( + schema: Schema, schema_registry_client: SchemaRegistryClient +) -> Dict[str, AvroSchema]: """ Resolves named schemas referenced by the provided schema recursively. :param schema: Schema to resolve named schemas for. @@ -79,6 +82,7 @@ def _resolve_named_schema(schema: Schema, schema_registry_client: SchemaRegistry return named_schemas + class AvroSerializer(BaseSerializer): """ Serializer that outputs Avro binary encoded data with Confluent Schema Registry framing. @@ -458,6 +462,7 @@ def _get_parsed_schema(self, schema: Schema) -> AvroSchema: return parsed_schema + class AvroDeserializer(BaseDeserializer): """ Deserializer for Avro binary encoded data with Confluent Schema Registry @@ -607,7 +612,9 @@ def __init_impl( __init__ = __init_impl - def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: + def __call__( + self, data: Optional[bytes], ctx: Optional[SerializationContext] = None + ) -> Union[dict, object, None]: return self.__deserialize(data, ctx) def __deserialize( diff --git a/src/confluent_kafka/schema_registry/_sync/json_schema.py b/src/confluent_kafka/schema_registry/_sync/json_schema.py index bc6fd4d4a..2e3a6046f 100644 --- a/src/confluent_kafka/schema_registry/_sync/json_schema.py +++ b/src/confluent_kafka/schema_registry/_sync/json_schema.py @@ -14,9 +14,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import threading as _locks import io import logging -import threading as _locks from typing import Any, Callable, Optional, Tuple, Union, cast import orjson @@ -27,13 +27,14 @@ from referencing import Registry, Resource from confluent_kafka.schema_registry import ( + SchemaRegistryClient, RuleMode, Schema, - SchemaRegistryClient, dual_schema_id_deserializer, prefix_schema_id_serializer, topic_subject_name_strategy, ) + from confluent_kafka.schema_registry.common.json_schema import ( DEFAULT_SPEC, JSON_TYPE, @@ -87,6 +88,7 @@ def _resolve_named_schema( return ref_registry + class JSONSerializer(BaseSerializer): """ Serializer that outputs JSON encoded data with Confluent Schema Registry framing. 
@@ -457,6 +459,7 @@ def _get_validator(self, schema: Schema, parsed_schema: JsonSchema, registry: Re return validator + class JSONDeserializer(BaseDeserializer): """ Deserializer for JSON encoded data with Confluent Schema Registry @@ -618,7 +621,9 @@ def __init_impl( __init__ = __init_impl - def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__( + self, data: Optional[bytes], ctx: Optional[SerializationContext] = None + ) -> Optional[bytes]: return self.__deserialize(data, ctx) def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: diff --git a/src/confluent_kafka/schema_registry/_sync/protobuf.py b/src/confluent_kafka/schema_registry/_sync/protobuf.py index 5e1620276..24984cd99 100644 --- a/src/confluent_kafka/schema_registry/_sync/protobuf.py +++ b/src/confluent_kafka/schema_registry/_sync/protobuf.py @@ -33,6 +33,7 @@ reference_subject_name_strategy, topic_subject_name_strategy, ) + from confluent_kafka.schema_registry.common.protobuf import ( PROTOBUF_TYPE, _bytes, @@ -97,6 +98,7 @@ def _resolve_named_schema( pool.Add(file_descriptor_proto) + class ProtobufSerializer(BaseSerializer): """ Serializer for Protobuf Message derived classes. Serialization format is Protobuf, @@ -359,7 +361,9 @@ def _encode_varints(buf: io.BytesIO, ints: List[int], zigzag: bool = True): for value in ints: ProtobufSerializer._write_varint(buf, value, zigzag=zigzag) - def _resolve_dependencies(self, ctx: SerializationContext, file_desc: FileDescriptor) -> List[SchemaReference]: + def _resolve_dependencies( + self, ctx: SerializationContext, file_desc: FileDescriptor + ) -> List[SchemaReference]: """ Resolves and optionally registers schema references recursively. @@ -485,6 +489,7 @@ def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescrip return fd_proto, pool + class ProtobufDeserializer(BaseDeserializer): """ Deserializer for Protobuf serialized data with Confluent Schema Registry framing. 
@@ -593,7 +598,9 @@ def __init_impl( __init__ = __init_impl - def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __call__( + self, data: Optional[bytes], ctx: Optional[SerializationContext] = None + ) -> Optional[bytes]: return self.__deserialize(data, ctx) def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: diff --git a/tests/integration/schema_registry/_sync/test_json_serializers.py b/tests/integration/schema_registry/_sync/test_json_serializers.py index db27ce5aa..95f9b8a44 100644 --- a/tests/integration/schema_registry/_sync/test_json_serializers.py +++ b/tests/integration/schema_registry/_sync/test_json_serializers.py @@ -19,7 +19,7 @@ from confluent_kafka import TopicPartition from confluent_kafka.error import ConsumeError, ValueSerializationError -from confluent_kafka.schema_registry import Schema, SchemaReference, SchemaRegistryClient +from confluent_kafka.schema_registry import SchemaRegistryClient, Schema, SchemaReference from confluent_kafka.schema_registry.json_schema import JSONDeserializer, JSONSerializer diff --git a/tests/schema_registry/_sync/test_api_client.py b/tests/schema_registry/_sync/test_api_client.py index 821854c37..760f585aa 100644 --- a/tests/schema_registry/_sync/test_api_client.py +++ b/tests/schema_registry/_sync/test_api_client.py @@ -22,7 +22,7 @@ from confluent_kafka.schema_registry.common.schema_registry_client import SchemaVersion from confluent_kafka.schema_registry.error import SchemaRegistryError -from confluent_kafka.schema_registry.schema_registry_client import Schema, SchemaRegistryClient +from confluent_kafka.schema_registry.schema_registry_client import SchemaRegistryClient, Schema from tests.schema_registry.conftest import COUNTER, SCHEMA, SCHEMA_ID, SUBJECTS, USERINFO, VERSION, VERSIONS """ @@ -459,7 +459,9 @@ def test_schema_equivilence(load_avsc): ('test-key', 1, True), ], ) -def test_test_compatibility_no_error(mock_schema_registry, load_avsc, subject_name, version, expected_compatibility): +def test_test_compatibility_no_error( + mock_schema_registry, load_avsc, subject_name, version, expected_compatibility +): conf = {'url': TEST_URL} sr = SchemaRegistryClient(conf) schema = Schema(load_avsc('basic_schema.avsc'), schema_type='AVRO') diff --git a/tests/schema_registry/_sync/test_avro.py b/tests/schema_registry/_sync/test_avro.py index 7a9872285..8fb415ff5 100644 --- a/tests/schema_registry/_sync/test_avro.py +++ b/tests/schema_registry/_sync/test_avro.py @@ -106,7 +106,9 @@ def test_avro_serializer_config_subject_name_strategy(): conf = {'url': TEST_URL} test_client = SchemaRegistryClient(conf) - test_serializer = AvroSerializer(test_client, '"int"', conf={'subject.name.strategy': record_subject_name_strategy}) + test_serializer = AvroSerializer( + test_client, '"int"', conf={'subject.name.strategy': record_subject_name_strategy} + ) assert test_serializer._subject_name_func is record_subject_name_strategy @@ -145,7 +147,9 @@ def test_avro_serializer_record_subject_name_strategy_primitive(load_avsc): """ conf = {'url': TEST_URL} test_client = SchemaRegistryClient(conf) - test_serializer = AvroSerializer(test_client, '"int"', conf={'subject.name.strategy': record_subject_name_strategy}) + test_serializer = AvroSerializer( + test_client, '"int"', conf={'subject.name.strategy': record_subject_name_strategy} + ) ctx = SerializationContext( 'test_subj', diff --git a/tests/schema_registry/_sync/test_avro_serdes.py 
b/tests/schema_registry/_sync/test_avro_serdes.py index 0c8ca7977..f027176ee 100644 --- a/tests/schema_registry/_sync/test_avro_serdes.py +++ b/tests/schema_registry/_sync/test_avro_serdes.py @@ -23,10 +23,10 @@ from fastavro._logical_readers import UUID from confluent_kafka.schema_registry import ( + SchemaRegistryClient, Metadata, MetadataProperties, Schema, - SchemaRegistryClient, header_schema_id_serializer, ) from confluent_kafka.schema_registry.avro import AvroDeserializer, AvroSerializer diff --git a/tests/schema_registry/_sync/test_json.py b/tests/schema_registry/_sync/test_json.py index 7d3d912e1..b9aec592a 100644 --- a/tests/schema_registry/_sync/test_json.py +++ b/tests/schema_registry/_sync/test_json.py @@ -23,10 +23,10 @@ import pytest from confluent_kafka.schema_registry import ( + SchemaRegistryClient, RegisteredSchema, Schema, SchemaReference, - SchemaRegistryClient, ) from confluent_kafka.schema_registry.json_schema import JSONDeserializer, JSONSerializer from confluent_kafka.schema_registry.rule_registry import RuleRegistry diff --git a/tests/schema_registry/_sync/test_json_serdes.py b/tests/schema_registry/_sync/test_json_serdes.py index 83de26355..66540d286 100644 --- a/tests/schema_registry/_sync/test_json_serdes.py +++ b/tests/schema_registry/_sync/test_json_serdes.py @@ -21,10 +21,10 @@ import pytest from confluent_kafka.schema_registry import ( + SchemaRegistryClient, Metadata, MetadataProperties, Schema, - SchemaRegistryClient, header_schema_id_serializer, ) from confluent_kafka.schema_registry.json_schema import JSONDeserializer, JSONSerializer From ad173f9ee1171888c0e1fc3ddb07548ce1e36889 Mon Sep 17 00:00:00 2001 From: tobiogunbi Date: Tue, 3 Feb 2026 16:21:12 -0500 Subject: [PATCH 09/11] unasync --- .../schema_registry/_sync/avro.py | 13 ++------ .../schema_registry/_sync/json_schema.py | 11 ++----- .../schema_registry/_sync/protobuf.py | 11 ++----- .../_sync/schema_registry_client.py | 32 ++++++------------- .../_sync/test_json_serializers.py | 2 +- .../schema_registry/_sync/test_api_client.py | 6 ++-- tests/schema_registry/_sync/test_avro.py | 8 ++--- .../schema_registry/_sync/test_avro_serdes.py | 2 +- .../_sync/test_bearer_field_provider.py | 1 - tests/schema_registry/_sync/test_config.py | 4 ++- tests/schema_registry/_sync/test_json.py | 2 +- .../schema_registry/_sync/test_json_serdes.py | 2 +- 12 files changed, 28 insertions(+), 66 deletions(-) diff --git a/src/confluent_kafka/schema_registry/_sync/avro.py b/src/confluent_kafka/schema_registry/_sync/avro.py index 23467f322..289367265 100644 --- a/src/confluent_kafka/schema_registry/_sync/avro.py +++ b/src/confluent_kafka/schema_registry/_sync/avro.py @@ -21,14 +21,13 @@ from fastavro import schemaless_reader, schemaless_writer from confluent_kafka.schema_registry import ( - SchemaRegistryClient, RuleMode, Schema, + SchemaRegistryClient, dual_schema_id_deserializer, prefix_schema_id_serializer, topic_subject_name_strategy, ) - from confluent_kafka.schema_registry.common.avro import ( AVRO_TYPE, AvroSchema, @@ -55,9 +54,7 @@ ] -def _resolve_named_schema( - schema: Schema, schema_registry_client: SchemaRegistryClient -) -> Dict[str, AvroSchema]: +def _resolve_named_schema(schema: Schema, schema_registry_client: SchemaRegistryClient) -> Dict[str, AvroSchema]: """ Resolves named schemas referenced by the provided schema recursively. :param schema: Schema to resolve named schemas for. 
@@ -82,7 +79,6 @@ def _resolve_named_schema( return named_schemas - class AvroSerializer(BaseSerializer): """ Serializer that outputs Avro binary encoded data with Confluent Schema Registry framing. @@ -462,7 +458,6 @@ def _get_parsed_schema(self, schema: Schema) -> AvroSchema: return parsed_schema - class AvroDeserializer(BaseDeserializer): """ Deserializer for Avro binary encoded data with Confluent Schema Registry @@ -612,9 +607,7 @@ def __init_impl( __init__ = __init_impl - def __call__( - self, data: Optional[bytes], ctx: Optional[SerializationContext] = None - ) -> Union[dict, object, None]: + def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: return self.__deserialize(data, ctx) def __deserialize( diff --git a/src/confluent_kafka/schema_registry/_sync/json_schema.py b/src/confluent_kafka/schema_registry/_sync/json_schema.py index 2e3a6046f..bc6fd4d4a 100644 --- a/src/confluent_kafka/schema_registry/_sync/json_schema.py +++ b/src/confluent_kafka/schema_registry/_sync/json_schema.py @@ -14,9 +14,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import threading as _locks import io import logging +import threading as _locks from typing import Any, Callable, Optional, Tuple, Union, cast import orjson @@ -27,14 +27,13 @@ from referencing import Registry, Resource from confluent_kafka.schema_registry import ( - SchemaRegistryClient, RuleMode, Schema, + SchemaRegistryClient, dual_schema_id_deserializer, prefix_schema_id_serializer, topic_subject_name_strategy, ) - from confluent_kafka.schema_registry.common.json_schema import ( DEFAULT_SPEC, JSON_TYPE, @@ -88,7 +87,6 @@ def _resolve_named_schema( return ref_registry - class JSONSerializer(BaseSerializer): """ Serializer that outputs JSON encoded data with Confluent Schema Registry framing. @@ -459,7 +457,6 @@ def _get_validator(self, schema: Schema, parsed_schema: JsonSchema, registry: Re return validator - class JSONDeserializer(BaseDeserializer): """ Deserializer for JSON encoded data with Confluent Schema Registry @@ -621,9 +618,7 @@ def __init_impl( __init__ = __init_impl - def __call__( - self, data: Optional[bytes], ctx: Optional[SerializationContext] = None - ) -> Optional[bytes]: + def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: return self.__deserialize(data, ctx) def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: diff --git a/src/confluent_kafka/schema_registry/_sync/protobuf.py b/src/confluent_kafka/schema_registry/_sync/protobuf.py index 24984cd99..5e1620276 100644 --- a/src/confluent_kafka/schema_registry/_sync/protobuf.py +++ b/src/confluent_kafka/schema_registry/_sync/protobuf.py @@ -33,7 +33,6 @@ reference_subject_name_strategy, topic_subject_name_strategy, ) - from confluent_kafka.schema_registry.common.protobuf import ( PROTOBUF_TYPE, _bytes, @@ -98,7 +97,6 @@ def _resolve_named_schema( pool.Add(file_descriptor_proto) - class ProtobufSerializer(BaseSerializer): """ Serializer for Protobuf Message derived classes. 
Serialization format is Protobuf, @@ -361,9 +359,7 @@ def _encode_varints(buf: io.BytesIO, ints: List[int], zigzag: bool = True): for value in ints: ProtobufSerializer._write_varint(buf, value, zigzag=zigzag) - def _resolve_dependencies( - self, ctx: SerializationContext, file_desc: FileDescriptor - ) -> List[SchemaReference]: + def _resolve_dependencies(self, ctx: SerializationContext, file_desc: FileDescriptor) -> List[SchemaReference]: """ Resolves and optionally registers schema references recursively. @@ -489,7 +485,6 @@ def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescrip return fd_proto, pool - class ProtobufDeserializer(BaseDeserializer): """ Deserializer for Protobuf serialized data with Confluent Schema Registry framing. @@ -598,9 +593,7 @@ def __init_impl( __init__ = __init_impl - def __call__( - self, data: Optional[bytes], ctx: Optional[SerializationContext] = None - ) -> Optional[bytes]: + def __call__(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: return self.__deserialize(data, ctx) def __deserialize(self, data: Optional[bytes], ctx: Optional[SerializationContext] = None) -> Optional[bytes]: diff --git a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py index 96d1bfae5..358e2e684 100644 --- a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py @@ -16,11 +16,11 @@ # limitations under the License. # -import threading as _locks import json import logging import os import ssl +import threading as _locks import time import urllib from typing import Any, Callable, Dict, List, Literal, Optional, Union @@ -39,8 +39,8 @@ SchemaVersion, ServerConfig, _BearerFieldProvider, - _StaticFieldProvider, _SchemaCache, + _StaticFieldProvider, full_jitter, is_retriable, is_success, @@ -286,9 +286,7 @@ def __init__(self, conf: dict): if self.bearer_auth_credentials_source in {'OAUTHBEARER', 'STATIC_TOKEN'}: if 'bearer.auth.logical.cluster' not in conf_copy: - raise ValueError( - "Missing required bearer configuration property: bearer.auth.logical.cluster" - ) + raise ValueError("Missing required bearer configuration property: bearer.auth.logical.cluster") logical_cluster = conf_copy.pop('bearer.auth.logical.cluster') if not isinstance(logical_cluster, str): @@ -445,9 +443,7 @@ def delete(self, url: str) -> Any: def put(self, url: str, body: Optional[dict] = None) -> Any: return self.send_request(url, method='PUT', body=body) - def send_request( - self, url: str, method: str, body: Optional[dict] = None, query: Optional[dict] = None - ) -> Any: + def send_request(self, url: str, method: str, body: Optional[dict] = None, query: Optional[dict] = None) -> Any: """ Sends HTTP request to the SchemaRegistry, trying each base URL in turn. 
@@ -952,9 +948,7 @@ def lookup_schema( query_string = '&'.join(f"{key}={value}" for key, value in query_params.items()) - response = self._rest_client.post( - 'subjects/{}?{}'.format(_urlencode(subject_name), query_string), body=request - ) + response = self._rest_client.post('subjects/{}?{}'.format(_urlencode(subject_name), query_string), body=request) result = RegisteredSchema.from_dict(response) @@ -1056,9 +1050,7 @@ def get_latest_version(self, subject_name: str, fmt: Optional[str] = None) -> 'R return registered_schema query = {'format': fmt} if fmt is not None else None - response = self._rest_client.get( - 'subjects/{}/versions/{}'.format(_urlencode(subject_name), 'latest'), query - ) + response = self._rest_client.get('subjects/{}/versions/{}'.format(_urlencode(subject_name), 'latest'), query) registered_schema = RegisteredSchema.from_dict(response) @@ -1141,9 +1133,7 @@ def get_version( return registered_schema query: dict[str, Any] = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} - response = self._rest_client.get( - 'subjects/{}/versions/{}'.format(_urlencode(subject_name), version), query - ) + response = self._rest_client.get('subjects/{}/versions/{}'.format(_urlencode(subject_name), version), query) registered_schema = RegisteredSchema.from_dict(response) @@ -1230,9 +1220,7 @@ def delete_version(self, subject_name: str, version: int, permanent: bool = Fals 'subjects/{}/versions/{}?permanent=true'.format(_urlencode(subject_name), version) ) else: - response = self._rest_client.delete( - 'subjects/{}/versions/{}'.format(_urlencode(subject_name), version) - ) + response = self._rest_client.delete('subjects/{}/versions/{}'.format(_urlencode(subject_name), version)) # Clear cache for both soft and hard deletes to maintain consistency self._cache.remove_by_subject_version(subject_name, version) @@ -1360,9 +1348,7 @@ def test_compatibility_all_versions( ) return response['is_compatible'] - def set_config( - self, subject_name: Optional[str] = None, config: Optional['ServerConfig'] = None - ) -> 'ServerConfig': + def set_config(self, subject_name: Optional[str] = None, config: Optional['ServerConfig'] = None) -> 'ServerConfig': """ Update global or subject config. 
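After this pass, bearer.auth.logical.cluster is the only header-related property the constructor insists on; the identity pool is popped with a None default and type-checked only when supplied. A minimal sketch of that parsing, assuming a plain conf dict (parse_bearer_conf is an illustrative name, not a library function):

def parse_bearer_conf(conf: dict):
    if 'bearer.auth.logical.cluster' not in conf:
        raise ValueError("Missing required bearer configuration property: bearer.auth.logical.cluster")
    logical_cluster = conf.pop('bearer.auth.logical.cluster')
    # Optional: a single pool ID or a comma-separated string of pool IDs.
    identity_pool = conf.pop('bearer.auth.identity.pool.id', None)
    if identity_pool is not None and not isinstance(identity_pool, str):
        raise TypeError("identity pool id must be a str, not " + str(type(identity_pool)))
    return logical_cluster, identity_pool

# parse_bearer_conf({'bearer.auth.logical.cluster': 'lsrc-12345'}) returns ('lsrc-12345', None)
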
diff --git a/tests/integration/schema_registry/_sync/test_json_serializers.py b/tests/integration/schema_registry/_sync/test_json_serializers.py index 95f9b8a44..db27ce5aa 100644 --- a/tests/integration/schema_registry/_sync/test_json_serializers.py +++ b/tests/integration/schema_registry/_sync/test_json_serializers.py @@ -19,7 +19,7 @@ from confluent_kafka import TopicPartition from confluent_kafka.error import ConsumeError, ValueSerializationError -from confluent_kafka.schema_registry import SchemaRegistryClient, Schema, SchemaReference +from confluent_kafka.schema_registry import Schema, SchemaReference, SchemaRegistryClient from confluent_kafka.schema_registry.json_schema import JSONDeserializer, JSONSerializer diff --git a/tests/schema_registry/_sync/test_api_client.py b/tests/schema_registry/_sync/test_api_client.py index 760f585aa..821854c37 100644 --- a/tests/schema_registry/_sync/test_api_client.py +++ b/tests/schema_registry/_sync/test_api_client.py @@ -22,7 +22,7 @@ from confluent_kafka.schema_registry.common.schema_registry_client import SchemaVersion from confluent_kafka.schema_registry.error import SchemaRegistryError -from confluent_kafka.schema_registry.schema_registry_client import SchemaRegistryClient, Schema +from confluent_kafka.schema_registry.schema_registry_client import Schema, SchemaRegistryClient from tests.schema_registry.conftest import COUNTER, SCHEMA, SCHEMA_ID, SUBJECTS, USERINFO, VERSION, VERSIONS """ @@ -459,9 +459,7 @@ def test_schema_equivilence(load_avsc): ('test-key', 1, True), ], ) -def test_test_compatibility_no_error( - mock_schema_registry, load_avsc, subject_name, version, expected_compatibility -): +def test_test_compatibility_no_error(mock_schema_registry, load_avsc, subject_name, version, expected_compatibility): conf = {'url': TEST_URL} sr = SchemaRegistryClient(conf) schema = Schema(load_avsc('basic_schema.avsc'), schema_type='AVRO') diff --git a/tests/schema_registry/_sync/test_avro.py b/tests/schema_registry/_sync/test_avro.py index 8fb415ff5..7a9872285 100644 --- a/tests/schema_registry/_sync/test_avro.py +++ b/tests/schema_registry/_sync/test_avro.py @@ -106,9 +106,7 @@ def test_avro_serializer_config_subject_name_strategy(): conf = {'url': TEST_URL} test_client = SchemaRegistryClient(conf) - test_serializer = AvroSerializer( - test_client, '"int"', conf={'subject.name.strategy': record_subject_name_strategy} - ) + test_serializer = AvroSerializer(test_client, '"int"', conf={'subject.name.strategy': record_subject_name_strategy}) assert test_serializer._subject_name_func is record_subject_name_strategy @@ -147,9 +145,7 @@ def test_avro_serializer_record_subject_name_strategy_primitive(load_avsc): """ conf = {'url': TEST_URL} test_client = SchemaRegistryClient(conf) - test_serializer = AvroSerializer( - test_client, '"int"', conf={'subject.name.strategy': record_subject_name_strategy} - ) + test_serializer = AvroSerializer(test_client, '"int"', conf={'subject.name.strategy': record_subject_name_strategy}) ctx = SerializationContext( 'test_subj', diff --git a/tests/schema_registry/_sync/test_avro_serdes.py b/tests/schema_registry/_sync/test_avro_serdes.py index f027176ee..0c8ca7977 100644 --- a/tests/schema_registry/_sync/test_avro_serdes.py +++ b/tests/schema_registry/_sync/test_avro_serdes.py @@ -23,10 +23,10 @@ from fastavro._logical_readers import UUID from confluent_kafka.schema_registry import ( - SchemaRegistryClient, Metadata, MetadataProperties, Schema, + SchemaRegistryClient, header_schema_id_serializer, ) from 
confluent_kafka.schema_registry.avro import AvroDeserializer, AvroSerializer diff --git a/tests/schema_registry/_sync/test_bearer_field_provider.py b/tests/schema_registry/_sync/test_bearer_field_provider.py index 8dcf20bf0..0d4ce3e23 100644 --- a/tests/schema_registry/_sync/test_bearer_field_provider.py +++ b/tests/schema_registry/_sync/test_bearer_field_provider.py @@ -276,7 +276,6 @@ def test_static_token_comma_separated_pools(): def test_static_field_provider_optional_pool(): """Test that _StaticFieldProvider works with optional identity pool.""" from confluent_kafka.schema_registry.common.schema_registry_client import _StaticFieldProvider - def check_provider(): static_field_provider = _StaticFieldProvider(TEST_TOKEN, TEST_CLUSTER, None) diff --git a/tests/schema_registry/_sync/test_config.py b/tests/schema_registry/_sync/test_config.py index 541d1b761..a927a0bbf 100644 --- a/tests/schema_registry/_sync/test_config.py +++ b/tests/schema_registry/_sync/test_config.py @@ -124,7 +124,9 @@ def test_config_auth_userinfo_invalid(): def test_bearer_config(): conf = {'url': TEST_URL, 'bearer.auth.credentials.source': "OAUTHBEARER"} - with pytest.raises(ValueError, match=r"Missing required bearer configuration property: bearer.auth.logical.cluster"): + with pytest.raises( + ValueError, match=r"Missing required bearer configuration property: bearer.auth.logical.cluster" + ): SchemaRegistryClient(conf) diff --git a/tests/schema_registry/_sync/test_json.py b/tests/schema_registry/_sync/test_json.py index b9aec592a..7d3d912e1 100644 --- a/tests/schema_registry/_sync/test_json.py +++ b/tests/schema_registry/_sync/test_json.py @@ -23,10 +23,10 @@ import pytest from confluent_kafka.schema_registry import ( - SchemaRegistryClient, RegisteredSchema, Schema, SchemaReference, + SchemaRegistryClient, ) from confluent_kafka.schema_registry.json_schema import JSONDeserializer, JSONSerializer from confluent_kafka.schema_registry.rule_registry import RuleRegistry diff --git a/tests/schema_registry/_sync/test_json_serdes.py b/tests/schema_registry/_sync/test_json_serdes.py index 66540d286..83de26355 100644 --- a/tests/schema_registry/_sync/test_json_serdes.py +++ b/tests/schema_registry/_sync/test_json_serdes.py @@ -21,10 +21,10 @@ import pytest from confluent_kafka.schema_registry import ( - SchemaRegistryClient, Metadata, MetadataProperties, Schema, + SchemaRegistryClient, header_schema_id_serializer, ) from confluent_kafka.schema_registry.json_schema import JSONDeserializer, JSONSerializer From a8d41943fe498d6858f1d21a6929c4a79f892d76 Mon Sep 17 00:00:00 2001 From: tobiogunbi Date: Tue, 3 Feb 2026 17:43:17 -0500 Subject: [PATCH 10/11] make style-fix and helper method --- .../_async/schema_registry_client.py | 9 +-- .../_sync/schema_registry_client.py | 5 +- .../common/schema_registry_client.py | 30 ++++++++++ .../_async/test_bearer_field_provider.py | 3 +- tests/schema_registry/_async/test_config.py | 4 +- .../_sync/test_bearer_field_provider.py | 1 + tests/test_Admin.py | 4 +- tests/test_unasync.py | 60 ++++++++++++------- 8 files changed, 84 insertions(+), 32 deletions(-) diff --git a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py index 0abde8f3d..0564c2e22 100644 --- a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py @@ -44,6 +44,7 @@ full_jitter, is_retriable, is_success, + normalize_identity_pool, ) from 
confluent_kafka.schema_registry.error import OAuthTokenError, SchemaRegistryError @@ -286,17 +287,13 @@ def __init__(self, conf: dict): if self.bearer_auth_credentials_source in {'OAUTHBEARER', 'STATIC_TOKEN'}: if 'bearer.auth.logical.cluster' not in conf_copy: - raise ValueError( - "Missing required bearer configuration property: bearer.auth.logical.cluster" - ) + raise ValueError("Missing required bearer configuration property: bearer.auth.logical.cluster") logical_cluster = conf_copy.pop('bearer.auth.logical.cluster') if not isinstance(logical_cluster, str): raise TypeError("logical cluster must be a str, not " + str(type(logical_cluster))) - identity_pool = conf_copy.pop('bearer.auth.identity.pool.id', None) - if identity_pool is not None and not isinstance(identity_pool, str): - raise TypeError("identity pool id must be a str, not " + str(type(identity_pool))) + identity_pool = normalize_identity_pool(conf_copy.pop('bearer.auth.identity.pool.id', None)) if self.bearer_auth_credentials_source == 'OAUTHBEARER': properties_list = [ diff --git a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py index 358e2e684..ee9847863 100644 --- a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py @@ -44,6 +44,7 @@ full_jitter, is_retriable, is_success, + normalize_identity_pool, ) from confluent_kafka.schema_registry.error import OAuthTokenError, SchemaRegistryError @@ -292,9 +293,7 @@ def __init__(self, conf: dict): if not isinstance(logical_cluster, str): raise TypeError("logical cluster must be a str, not " + str(type(logical_cluster))) - identity_pool = conf_copy.pop('bearer.auth.identity.pool.id', None) - if identity_pool is not None and not isinstance(identity_pool, str): - raise TypeError("identity pool id must be a str, not " + str(type(identity_pool))) + identity_pool = normalize_identity_pool(conf_copy.pop('bearer.auth.identity.pool.id', None)) if self.bearer_auth_credentials_source == 'OAUTHBEARER': properties_list = [ diff --git a/src/confluent_kafka/schema_registry/common/schema_registry_client.py b/src/confluent_kafka/schema_registry/common/schema_registry_client.py index 0464ab609..82ae27d89 100644 --- a/src/confluent_kafka/schema_registry/common/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/common/schema_registry_client.py @@ -31,6 +31,7 @@ 'is_success', 'is_retriable', 'full_jitter', + 'normalize_identity_pool', '_StaticFieldProvider', '_AsyncStaticFieldProvider', '_SchemaCache', @@ -117,6 +118,35 @@ def full_jitter(base_delay_ms: int, max_delay_ms: int, retries_attempted: int) - return random.random() * min(no_jitter_delay, max_delay_ms) +def normalize_identity_pool(identity_pool_raw: Any) -> Optional[str]: + """ + Normalize identity pool configuration to a comma-separated string. + + Identity pool can be provided as: + - None: Returns None (no identity pool configured) + - str: Returns as-is (single pool ID or already comma-separated) + - list[str]: Joins with commas (multiple pool IDs) + + Args: + identity_pool_raw: The raw identity pool configuration value. + + Returns: + A comma-separated string of identity pool IDs, or None. + + Raises: + TypeError: If identity_pool_raw is not None, str, or list of strings. 
+ """ + if identity_pool_raw is None: + return None + if isinstance(identity_pool_raw, str): + return identity_pool_raw + if isinstance(identity_pool_raw, list): + if not all(isinstance(item, str) for item in identity_pool_raw): + raise TypeError("All items in identity pool list must be strings") + return ",".join(identity_pool_raw) + raise TypeError("identity pool id must be a str or list, not " + str(type(identity_pool_raw))) + + class _SchemaCache(object): """ Thread-safe cache for use with the Schema Registry Client. diff --git a/tests/schema_registry/_async/test_bearer_field_provider.py b/tests/schema_registry/_async/test_bearer_field_provider.py index 89616cac9..61162165d 100644 --- a/tests/schema_registry/_async/test_bearer_field_provider.py +++ b/tests/schema_registry/_async/test_bearer_field_provider.py @@ -275,9 +275,10 @@ async def test_static_token_comma_separated_pools(): def test_static_field_provider_optional_pool(): """Test that _StaticFieldProvider works with optional identity pool.""" - from confluent_kafka.schema_registry.common.schema_registry_client import _AsyncStaticFieldProvider import asyncio + from confluent_kafka.schema_registry.common.schema_registry_client import _AsyncStaticFieldProvider + async def check_provider(): static_field_provider = _AsyncStaticFieldProvider(TEST_TOKEN, TEST_CLUSTER, None) bearer_fields = await static_field_provider.get_bearer_fields() diff --git a/tests/schema_registry/_async/test_config.py b/tests/schema_registry/_async/test_config.py index e374b0fb9..4424a27e1 100644 --- a/tests/schema_registry/_async/test_config.py +++ b/tests/schema_registry/_async/test_config.py @@ -124,7 +124,9 @@ def test_config_auth_userinfo_invalid(): def test_bearer_config(): conf = {'url': TEST_URL, 'bearer.auth.credentials.source': "OAUTHBEARER"} - with pytest.raises(ValueError, match=r"Missing required bearer configuration property: bearer.auth.logical.cluster"): + with pytest.raises( + ValueError, match=r"Missing required bearer configuration property: bearer.auth.logical.cluster" + ): AsyncSchemaRegistryClient(conf) diff --git a/tests/schema_registry/_sync/test_bearer_field_provider.py b/tests/schema_registry/_sync/test_bearer_field_provider.py index 0d4ce3e23..9c41b33c3 100644 --- a/tests/schema_registry/_sync/test_bearer_field_provider.py +++ b/tests/schema_registry/_sync/test_bearer_field_provider.py @@ -275,6 +275,7 @@ def test_static_token_comma_separated_pools(): def test_static_field_provider_optional_pool(): """Test that _StaticFieldProvider works with optional identity pool.""" + from confluent_kafka.schema_registry.common.schema_registry_client import _StaticFieldProvider def check_provider(): diff --git a/tests/test_Admin.py b/tests/test_Admin.py index a6d67628a..1e8f05146 100644 --- a/tests/test_Admin.py +++ b/tests/test_Admin.py @@ -153,7 +153,9 @@ def test_create_topics_api(): try: a.create_topics([NewTopic("mytopic")]) except Exception as err: - assert False, f"When none of the partitions, \ + assert ( + False + ), f"When none of the partitions, \ replication and assignment is present, the request should not fail, but it does with error {err}" fs = a.create_topics([NewTopic("mytopic", 3, 2)]) with pytest.raises(KafkaException): diff --git a/tests/test_unasync.py b/tests/test_unasync.py index e854ec547..8681c6b42 100644 --- a/tests/test_unasync.py +++ b/tests/test_unasync.py @@ -46,14 +46,18 @@ def test_unasync_file_check(temp_dirs): os.makedirs(os.path.dirname(sync_file), exist_ok=True) with open(async_file, "w") as f: - f.write("""async 
def test(): + f.write( + """async def test(): await asyncio.sleep(1) -""") +""" + ) with open(sync_file, "w") as f: - f.write("""def test(): + f.write( + """def test(): time.sleep(1) -""") +""" + ) # This should return True assert unasync_file_check(async_file, sync_file) is True @@ -64,15 +68,19 @@ def test_unasync_file_check(temp_dirs): os.makedirs(os.path.dirname(sync_file), exist_ok=True) with open(async_file, "w") as f: - f.write("""async def test(): + f.write( + """async def test(): await asyncio.sleep(1) -""") +""" + ) with open(sync_file, "w") as f: - f.write("""def test(): + f.write( + """def test(): # This is wrong asyncio.sleep(1) -""") +""" + ) # This should return False assert unasync_file_check(async_file, sync_file) is False @@ -83,15 +91,19 @@ def test_unasync_file_check(temp_dirs): os.makedirs(os.path.dirname(sync_file), exist_ok=True) with open(async_file, "w") as f: - f.write("""async def test(): + f.write( + """async def test(): await asyncio.sleep(1) return "test" -""") +""" + ) with open(sync_file, "w") as f: - f.write("""def test(): + f.write( + """def test(): time.sleep(1) -""") +""" + ) # This should return False assert unasync_file_check(async_file, sync_file) is False @@ -108,14 +120,16 @@ def test_unasync_generation(temp_dirs): # Create a test async file test_file = os.path.join(async_dir, "test.py") with open(test_file, "w") as f: - f.write("""async def test_func(): + f.write( + """async def test_func(): await asyncio.sleep(1) return "test" class AsyncTest: async def test_method(self): await self.some_async() -""") +""" + ) # Run unasync with test directories dir_pairs = [(async_dir, sync_dir)] @@ -149,20 +163,24 @@ def test_unasync_check(temp_dirs): # Create a test async file test_file = os.path.join(async_dir, "test.py") with open(test_file, "w") as f: - f.write("""async def test_func(): + f.write( + """async def test_func(): await asyncio.sleep(1) return "test" -""") +""" + ) # Create an incorrect sync file sync_file = os.path.join(sync_dir, "test.py") os.makedirs(os.path.dirname(sync_file), exist_ok=True) with open(sync_file, "w") as f: - f.write("""def test_func(): + f.write( + """def test_func(): time.sleep(1) return "test" # Extra line that shouldn't be here -""") +""" + ) # Run unasync check with test directories dir_pairs = [(async_dir, sync_dir)] @@ -178,10 +196,12 @@ def test_unasync_missing_sync_file(temp_dirs): # Create a test async file test_file = os.path.join(async_dir, "test.py") with open(test_file, "w") as f: - f.write("""async def test_func(): + f.write( + """async def test_func(): await asyncio.sleep(1) return "test" -""") +""" + ) # Run unasync check with test directories dir_pairs = [(async_dir, sync_dir)] From 7257b6e6fc93188852b8bb97f440a7428a80d406 Mon Sep 17 00:00:00 2001 From: tobiogunbi Date: Tue, 3 Feb 2026 18:10:01 -0500 Subject: [PATCH 11/11] test fix --- tests/schema_registry/_async/test_config.py | 2 +- tests/schema_registry/_sync/test_config.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/schema_registry/_async/test_config.py b/tests/schema_registry/_async/test_config.py index 4424a27e1..2a1111ac6 100644 --- a/tests/schema_registry/_async/test_config.py +++ b/tests/schema_registry/_async/test_config.py @@ -150,7 +150,7 @@ def test_oauth_bearer_config_invalid(): 'bearer.auth.identity.pool.id': 1, } - with pytest.raises(TypeError, match=r"identity pool id must be a str, not (.*)"): + with pytest.raises(TypeError, match=r"identity pool id must be a str or list, not (.*)"): 
AsyncSchemaRegistryClient(conf) conf = { diff --git a/tests/schema_registry/_sync/test_config.py b/tests/schema_registry/_sync/test_config.py index a927a0bbf..11d5fc278 100644 --- a/tests/schema_registry/_sync/test_config.py +++ b/tests/schema_registry/_sync/test_config.py @@ -150,7 +150,7 @@ def test_oauth_bearer_config_invalid(): 'bearer.auth.identity.pool.id': 1, } - with pytest.raises(TypeError, match=r"identity pool id must be a str, not (.*)"): + with pytest.raises(TypeError, match=r"identity pool id must be a str or list, not (.*)"): SchemaRegistryClient(conf) conf = {