diff --git a/examples/oauth_schema_registry.py b/examples/oauth_schema_registry.py index fd1115b5c..10856ac48 100644 --- a/examples/oauth_schema_registry.py +++ b/examples/oauth_schema_registry.py @@ -71,6 +71,33 @@ def custom_oauth_function(config): custom_sr_client = SchemaRegistryClient(custom_sr_config) print(custom_sr_client.get_subjects()) + # Example: Using union-of-pools with comma-separated pool IDs + union_of_pools_config = { + 'url': 'https://psrc-123456.us-east-1.aws.confluent.cloud', + 'bearer.auth.credentials.source': 'STATIC_TOKEN', + 'bearer.auth.token': 'multi-pool-token', + 'bearer.auth.logical.cluster': 'lsrc-12345', + 'bearer.auth.identity.pool.id': 'pool-abc,pool-def,pool-ghi', + } + + union_sr_client = SchemaRegistryClient(union_of_pools_config) + print(union_sr_client.get_subjects()) + + # Example: Omitting identity pool for auto pool mapping + auto_pool_config = { + 'url': 'https://psrc-123456.us-east-1.aws.confluent.cloud', + 'bearer.auth.credentials.source': 'OAUTHBEARER', + 'bearer.auth.client.id': 'client-id', + 'bearer.auth.client.secret': 'client-secret', + 'bearer.auth.scope': 'schema_registry', + 'bearer.auth.issuer.endpoint.url': 'https://yourauthprovider.com/v1/token', + 'bearer.auth.logical.cluster': 'lsrc-12345', + # bearer.auth.identity.pool.id is omitted - SR will use auto pool mapping + } + + auto_pool_sr_client = SchemaRegistryClient(auto_pool_config) + print(auto_pool_sr_client.get_subjects()) + if __name__ == '__main__': main() diff --git a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py index e0efe351e..0564c2e22 100644 --- a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py @@ -44,6 +44,7 @@ full_jitter, is_retriable, is_success, + normalize_identity_pool, ) from confluent_kafka.schema_registry.error import OAuthTokenError, SchemaRegistryError @@ -97,10 +98,10 @@ def __init__( scope: str, token_endpoint: str, logical_cluster: str, - identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int, + identity_pool: Optional[str] = None, ): self.token = None self.logical_cluster = logical_cluster @@ -113,11 +114,13 @@ def __init__( self.token_expiry_threshold = 0.8 async def get_bearer_fields(self) -> dict: - return { + fields = { 'bearer.auth.token': await self.get_access_token(), 'bearer.auth.logical.cluster': self.logical_cluster, - 'bearer.auth.identity.pool.id': self.identity_pool, } + if self.identity_pool is not None: + fields['bearer.auth.identity.pool.id'] = self.identity_pool + return fields def token_expired(self) -> bool: if self.token is None: @@ -283,20 +286,14 @@ def __init__(self, conf: dict): self.auth = None if self.bearer_auth_credentials_source in {'OAUTHBEARER', 'STATIC_TOKEN'}: - headers = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id'] - missing_headers = [header for header in headers if header not in conf_copy] - if missing_headers: - raise ValueError( - "Missing required bearer configuration properties: {}".format(", ".join(missing_headers)) - ) + if 'bearer.auth.logical.cluster' not in conf_copy: + raise ValueError("Missing required bearer configuration property: bearer.auth.logical.cluster") logical_cluster = conf_copy.pop('bearer.auth.logical.cluster') if not isinstance(logical_cluster, str): raise TypeError("logical cluster must be a str, not " + str(type(logical_cluster))) - identity_pool = conf_copy.pop('bearer.auth.identity.pool.id') - if not isinstance(identity_pool, str): - raise TypeError("identity pool id must be a str, not " + str(type(identity_pool))) + identity_pool = normalize_identity_pool(conf_copy.pop('bearer.auth.identity.pool.id', None)) if self.bearer_auth_credentials_source == 'OAUTHBEARER': properties_list = [ @@ -335,10 +332,10 @@ def __init__(self, conf: dict): self.scope, self.token_endpoint, logical_cluster, - identity_pool, self.max_retries, self.retries_wait_ms, self.retries_max_wait_ms, + identity_pool, ) else: # STATIC_TOKEN if 'bearer.auth.token' not in conf_copy: @@ -412,7 +409,8 @@ async def handle_bearer_auth(self, headers: dict) -> None: if self.bearer_field_provider is None: raise ValueError("Bearer field provider is not set") bearer_fields = await self.bearer_field_provider.get_bearer_fields() - required_fields = ['bearer.auth.token', 'bearer.auth.identity.pool.id', 'bearer.auth.logical.cluster'] + # Note: bearer.auth.identity.pool.id is optional; only token and logical.cluster are required + required_fields = ['bearer.auth.token', 'bearer.auth.logical.cluster'] missing_fields = [] for field in required_fields: @@ -427,9 +425,11 @@ async def handle_bearer_auth(self, headers: dict) -> None: ) headers["Authorization"] = "Bearer {}".format(bearer_fields['bearer.auth.token']) - headers['Confluent-Identity-Pool-Id'] = bearer_fields['bearer.auth.identity.pool.id'] headers['target-sr-cluster'] = bearer_fields['bearer.auth.logical.cluster'] + if 'bearer.auth.identity.pool.id' in bearer_fields: + headers['Confluent-Identity-Pool-Id'] = bearer_fields['bearer.auth.identity.pool.id'] + async def get(self, url: str, query: Optional[dict] = None) -> Any: return await self.send_request(url, method='GET', query=query) diff --git a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py index 74435e108..ee9847863 100644 --- a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py @@ -44,6 +44,7 @@ full_jitter, is_retriable, is_success, + normalize_identity_pool, ) from confluent_kafka.schema_registry.error import OAuthTokenError, SchemaRegistryError @@ -97,10 +98,10 @@ def __init__( scope: str, token_endpoint: str, logical_cluster: str, - identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int, + identity_pool: Optional[str] = None, ): self.token = None self.logical_cluster = logical_cluster @@ -113,11 +114,13 @@ def __init__( self.token_expiry_threshold = 0.8 def get_bearer_fields(self) -> dict: - return { + fields = { 'bearer.auth.token': self.get_access_token(), 'bearer.auth.logical.cluster': self.logical_cluster, - 'bearer.auth.identity.pool.id': self.identity_pool, } + if self.identity_pool is not None: + fields['bearer.auth.identity.pool.id'] = self.identity_pool + return fields def token_expired(self) -> bool: if self.token is None: @@ -283,20 +286,14 @@ def __init__(self, conf: dict): self.auth = None if self.bearer_auth_credentials_source in {'OAUTHBEARER', 'STATIC_TOKEN'}: - headers = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id'] - missing_headers = [header for header in headers if header not in conf_copy] - if missing_headers: - raise ValueError( - "Missing required bearer configuration properties: {}".format(", ".join(missing_headers)) - ) + if 'bearer.auth.logical.cluster' not in conf_copy: + raise ValueError("Missing required bearer configuration property: bearer.auth.logical.cluster") logical_cluster = conf_copy.pop('bearer.auth.logical.cluster') if not isinstance(logical_cluster, str): raise TypeError("logical cluster must be a str, not " + str(type(logical_cluster))) - identity_pool = conf_copy.pop('bearer.auth.identity.pool.id') - if not isinstance(identity_pool, str): - raise TypeError("identity pool id must be a str, not " + str(type(identity_pool))) + identity_pool = normalize_identity_pool(conf_copy.pop('bearer.auth.identity.pool.id', None)) if self.bearer_auth_credentials_source == 'OAUTHBEARER': properties_list = [ @@ -335,10 +332,10 @@ def __init__(self, conf: dict): self.scope, self.token_endpoint, logical_cluster, - identity_pool, self.max_retries, self.retries_wait_ms, self.retries_max_wait_ms, + identity_pool, ) else: # STATIC_TOKEN if 'bearer.auth.token' not in conf_copy: @@ -412,7 +409,8 @@ def handle_bearer_auth(self, headers: dict) -> None: if self.bearer_field_provider is None: raise ValueError("Bearer field provider is not set") bearer_fields = self.bearer_field_provider.get_bearer_fields() - required_fields = ['bearer.auth.token', 'bearer.auth.identity.pool.id', 'bearer.auth.logical.cluster'] + # Note: bearer.auth.identity.pool.id is optional; only token and logical.cluster are required + required_fields = ['bearer.auth.token', 'bearer.auth.logical.cluster'] missing_fields = [] for field in required_fields: @@ -427,9 +425,11 @@ def handle_bearer_auth(self, headers: dict) -> None: ) headers["Authorization"] = "Bearer {}".format(bearer_fields['bearer.auth.token']) - headers['Confluent-Identity-Pool-Id'] = bearer_fields['bearer.auth.identity.pool.id'] headers['target-sr-cluster'] = bearer_fields['bearer.auth.logical.cluster'] + if 'bearer.auth.identity.pool.id' in bearer_fields: + headers['Confluent-Identity-Pool-Id'] = bearer_fields['bearer.auth.identity.pool.id'] + def get(self, url: str, query: Optional[dict] = None) -> Any: return self.send_request(url, method='GET', query=query) diff --git a/src/confluent_kafka/schema_registry/common/schema_registry_client.py b/src/confluent_kafka/schema_registry/common/schema_registry_client.py index afdefd08f..82ae27d89 100644 --- a/src/confluent_kafka/schema_registry/common/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/common/schema_registry_client.py @@ -31,6 +31,7 @@ 'is_success', 'is_retriable', 'full_jitter', + 'normalize_identity_pool', '_StaticFieldProvider', '_AsyncStaticFieldProvider', '_SchemaCache', @@ -71,33 +72,37 @@ async def get_bearer_fields(self) -> dict: class _StaticFieldProvider(_BearerFieldProvider): """Synchronous static token bearer field provider.""" - def __init__(self, token: str, logical_cluster: str, identity_pool: str): + def __init__(self, token: str, logical_cluster: str, identity_pool: Optional[str] = None): self.token = token self.logical_cluster = logical_cluster self.identity_pool = identity_pool def get_bearer_fields(self) -> dict: - return { + fields = { 'bearer.auth.token': self.token, 'bearer.auth.logical.cluster': self.logical_cluster, - 'bearer.auth.identity.pool.id': self.identity_pool, } + if self.identity_pool is not None: + fields['bearer.auth.identity.pool.id'] = self.identity_pool + return fields class _AsyncStaticFieldProvider(_AsyncBearerFieldProvider): """Asynchronous static token bearer field provider.""" - def __init__(self, token: str, logical_cluster: str, identity_pool: str): + def __init__(self, token: str, logical_cluster: str, identity_pool: Optional[str] = None): self.token = token self.logical_cluster = logical_cluster self.identity_pool = identity_pool async def get_bearer_fields(self) -> dict: - return { + fields = { 'bearer.auth.token': self.token, 'bearer.auth.logical.cluster': self.logical_cluster, - 'bearer.auth.identity.pool.id': self.identity_pool, } + if self.identity_pool is not None: + fields['bearer.auth.identity.pool.id'] = self.identity_pool + return fields def is_success(status_code: int) -> bool: @@ -113,6 +118,35 @@ def full_jitter(base_delay_ms: int, max_delay_ms: int, retries_attempted: int) - return random.random() * min(no_jitter_delay, max_delay_ms) +def normalize_identity_pool(identity_pool_raw: Any) -> Optional[str]: + """ + Normalize identity pool configuration to a comma-separated string. + + Identity pool can be provided as: + - None: Returns None (no identity pool configured) + - str: Returns as-is (single pool ID or already comma-separated) + - list[str]: Joins with commas (multiple pool IDs) + + Args: + identity_pool_raw: The raw identity pool configuration value. + + Returns: + A comma-separated string of identity pool IDs, or None. + + Raises: + TypeError: If identity_pool_raw is not None, str, or list of strings. + """ + if identity_pool_raw is None: + return None + if isinstance(identity_pool_raw, str): + return identity_pool_raw + if isinstance(identity_pool_raw, list): + if not all(isinstance(item, str) for item in identity_pool_raw): + raise TypeError("All items in identity pool list must be strings") + return ",".join(identity_pool_raw) + raise TypeError("identity pool id must be a str or list, not " + str(type(identity_pool_raw))) + + class _SchemaCache(object): """ Thread-safe cache for use with the Schema Registry Client. diff --git a/tests/schema_registry/_async/test_bearer_field_provider.py b/tests/schema_registry/_async/test_bearer_field_provider.py index ffedc15b6..61162165d 100644 --- a/tests/schema_registry/_async/test_bearer_field_provider.py +++ b/tests/schema_registry/_async/test_bearer_field_provider.py @@ -51,7 +51,7 @@ async def custom_oauth_function(config: dict) -> dict: def test_expiry(): - oauth_client = _AsyncOAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 2, 1000, 20000) + oauth_client = _AsyncOAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, 2, 1000, 20000, TEST_POOL) # Use consistent test data: expires_at and expires_in should match # Token expires in 2 seconds, with 0.8 threshold, should refresh after 1.6 seconds (when 0.4s remaining) oauth_client.token = {'expires_at': time.time() + 2, 'expires_in': 2} @@ -61,7 +61,7 @@ def test_expiry(): async def test_get_token(): - oauth_client = _AsyncOAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 2, 1000, 20000) + oauth_client = _AsyncOAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, 2, 1000, 20000, TEST_POOL) def update_token1(): oauth_client.token = {'expires_at': 0, 'expires_in': 1, 'access_token': '123'} @@ -86,7 +86,7 @@ def update_token2(): async def test_generate_token_retry_logic(): - oauth_client = _AsyncOAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 5, 1000, 20000) + oauth_client = _AsyncOAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, 5, 1000, 20000, TEST_POOL) with ( patch("confluent_kafka.schema_registry._async.schema_registry_client.asyncio.sleep") as mock_sleep, @@ -158,3 +158,133 @@ async def test_bearer_field_headers_valid(): assert headers['Authorization'] == "Bearer {}".format(TEST_CONFIG['bearer.auth.token']) assert headers['Confluent-Identity-Pool-Id'] == TEST_CONFIG['bearer.auth.identity.pool.id'] assert headers['target-sr-cluster'] == TEST_CONFIG['bearer.auth.logical.cluster'] + + +async def test_bearer_field_headers_optional_identity_pool(): + """Test that identity pool is optional and header is omitted when not provided.""" + + async def custom_oauth_no_pool(config: dict) -> dict: + return { + 'bearer.auth.token': TEST_TOKEN, + 'bearer.auth.logical.cluster': TEST_CLUSTER, + # bearer.auth.identity.pool.id is intentionally omitted + } + + conf = { + 'url': TEST_URL, + 'bearer.auth.credentials.source': 'CUSTOM', + 'bearer.auth.custom.provider.function': custom_oauth_no_pool, + 'bearer.auth.custom.provider.config': {}, + } + + client = AsyncSchemaRegistryClient(conf) + + headers = { + 'Accept': "application/vnd.schemaregistry.v1+json," " application/vnd.schemaregistry+json," " application/json" + } + + await client._rest_client.handle_bearer_auth(headers) + + assert 'Authorization' in headers + assert 'target-sr-cluster' in headers + assert headers['Authorization'] == "Bearer {}".format(TEST_TOKEN) + assert headers['target-sr-cluster'] == TEST_CLUSTER + # Confluent-Identity-Pool-Id should NOT be present when identity pool is omitted + assert 'Confluent-Identity-Pool-Id' not in headers + + +async def test_bearer_field_headers_comma_separated_pools(): + """Test that comma-separated pool IDs are passed through as-is in the header.""" + comma_separated_pools = 'pool-1,pool-2,pool-3' + + async def custom_oauth_multi_pool(config: dict) -> dict: + return { + 'bearer.auth.token': TEST_TOKEN, + 'bearer.auth.logical.cluster': TEST_CLUSTER, + 'bearer.auth.identity.pool.id': comma_separated_pools, + } + + conf = { + 'url': TEST_URL, + 'bearer.auth.credentials.source': 'CUSTOM', + 'bearer.auth.custom.provider.function': custom_oauth_multi_pool, + 'bearer.auth.custom.provider.config': {}, + } + + client = AsyncSchemaRegistryClient(conf) + + headers = { + 'Accept': "application/vnd.schemaregistry.v1+json," " application/vnd.schemaregistry+json," " application/json" + } + + await client._rest_client.handle_bearer_auth(headers) + + assert 'Authorization' in headers + assert 'Confluent-Identity-Pool-Id' in headers + assert 'target-sr-cluster' in headers + # Verify comma-separated value is passed through unchanged + assert headers['Confluent-Identity-Pool-Id'] == comma_separated_pools + + +async def test_static_token_optional_identity_pool(): + """Test that STATIC_TOKEN credential source works without identity pool.""" + conf = { + 'url': TEST_URL, + 'bearer.auth.credentials.source': 'STATIC_TOKEN', + 'bearer.auth.token': TEST_TOKEN, + 'bearer.auth.logical.cluster': TEST_CLUSTER, + # bearer.auth.identity.pool.id is intentionally omitted + } + + client = AsyncSchemaRegistryClient(conf) + + headers = { + 'Accept': "application/vnd.schemaregistry.v1+json," " application/vnd.schemaregistry+json," " application/json" + } + + await client._rest_client.handle_bearer_auth(headers) + + assert 'Authorization' in headers + assert 'target-sr-cluster' in headers + assert 'Confluent-Identity-Pool-Id' not in headers + + +async def test_static_token_comma_separated_pools(): + """Test that STATIC_TOKEN credential source supports comma-separated pool IDs.""" + comma_separated_pools = 'pool-abc,pool-def,pool-ghi' + + conf = { + 'url': TEST_URL, + 'bearer.auth.credentials.source': 'STATIC_TOKEN', + 'bearer.auth.token': TEST_TOKEN, + 'bearer.auth.logical.cluster': TEST_CLUSTER, + 'bearer.auth.identity.pool.id': comma_separated_pools, + } + + client = AsyncSchemaRegistryClient(conf) + + headers = { + 'Accept': "application/vnd.schemaregistry.v1+json," " application/vnd.schemaregistry+json," " application/json" + } + + await client._rest_client.handle_bearer_auth(headers) + + assert 'Confluent-Identity-Pool-Id' in headers + assert headers['Confluent-Identity-Pool-Id'] == comma_separated_pools + + +def test_static_field_provider_optional_pool(): + """Test that _StaticFieldProvider works with optional identity pool.""" + import asyncio + + from confluent_kafka.schema_registry.common.schema_registry_client import _AsyncStaticFieldProvider + + async def check_provider(): + static_field_provider = _AsyncStaticFieldProvider(TEST_TOKEN, TEST_CLUSTER, None) + bearer_fields = await static_field_provider.get_bearer_fields() + + assert bearer_fields['bearer.auth.token'] == TEST_TOKEN + assert bearer_fields['bearer.auth.logical.cluster'] == TEST_CLUSTER + assert 'bearer.auth.identity.pool.id' not in bearer_fields + + asyncio.run(check_provider()) diff --git a/tests/schema_registry/_async/test_config.py b/tests/schema_registry/_async/test_config.py index 98dd79c66..2a1111ac6 100644 --- a/tests/schema_registry/_async/test_config.py +++ b/tests/schema_registry/_async/test_config.py @@ -124,7 +124,9 @@ def test_config_auth_userinfo_invalid(): def test_bearer_config(): conf = {'url': TEST_URL, 'bearer.auth.credentials.source': "OAUTHBEARER"} - with pytest.raises(ValueError, match=r"Missing required bearer configuration properties: (.*)"): + with pytest.raises( + ValueError, match=r"Missing required bearer configuration property: bearer.auth.logical.cluster" + ): AsyncSchemaRegistryClient(conf) @@ -148,7 +150,7 @@ def test_oauth_bearer_config_invalid(): 'bearer.auth.identity.pool.id': 1, } - with pytest.raises(TypeError, match=r"identity pool id must be a str, not (.*)"): + with pytest.raises(TypeError, match=r"identity pool id must be a str or list, not (.*)"): AsyncSchemaRegistryClient(conf) conf = { diff --git a/tests/schema_registry/_sync/test_bearer_field_provider.py b/tests/schema_registry/_sync/test_bearer_field_provider.py index ebeceaa7e..9c41b33c3 100644 --- a/tests/schema_registry/_sync/test_bearer_field_provider.py +++ b/tests/schema_registry/_sync/test_bearer_field_provider.py @@ -51,7 +51,7 @@ def custom_oauth_function(config: dict) -> dict: def test_expiry(): - oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 2, 1000, 20000) + oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, 2, 1000, 20000, TEST_POOL) # Use consistent test data: expires_at and expires_in should match # Token expires in 2 seconds, with 0.8 threshold, should refresh after 1.6 seconds (when 0.4s remaining) oauth_client.token = {'expires_at': time.time() + 2, 'expires_in': 2} @@ -61,7 +61,7 @@ def test_expiry(): def test_get_token(): - oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 2, 1000, 20000) + oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, 2, 1000, 20000, TEST_POOL) def update_token1(): oauth_client.token = {'expires_at': 0, 'expires_in': 1, 'access_token': '123'} @@ -86,7 +86,7 @@ def update_token2(): def test_generate_token_retry_logic(): - oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 5, 1000, 20000) + oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, 5, 1000, 20000, TEST_POOL) with ( patch("confluent_kafka.schema_registry._sync.schema_registry_client.time.sleep") as mock_sleep, @@ -158,3 +158,132 @@ def test_bearer_field_headers_valid(): assert headers['Authorization'] == "Bearer {}".format(TEST_CONFIG['bearer.auth.token']) assert headers['Confluent-Identity-Pool-Id'] == TEST_CONFIG['bearer.auth.identity.pool.id'] assert headers['target-sr-cluster'] == TEST_CONFIG['bearer.auth.logical.cluster'] + + +def test_bearer_field_headers_optional_identity_pool(): + """Test that identity pool is optional and header is omitted when not provided.""" + + def custom_oauth_no_pool(config: dict) -> dict: + return { + 'bearer.auth.token': TEST_TOKEN, + 'bearer.auth.logical.cluster': TEST_CLUSTER, + # bearer.auth.identity.pool.id is intentionally omitted + } + + conf = { + 'url': TEST_URL, + 'bearer.auth.credentials.source': 'CUSTOM', + 'bearer.auth.custom.provider.function': custom_oauth_no_pool, + 'bearer.auth.custom.provider.config': {}, + } + + client = SchemaRegistryClient(conf) + + headers = { + 'Accept': "application/vnd.schemaregistry.v1+json," " application/vnd.schemaregistry+json," " application/json" + } + + client._rest_client.handle_bearer_auth(headers) + + assert 'Authorization' in headers + assert 'target-sr-cluster' in headers + assert headers['Authorization'] == "Bearer {}".format(TEST_TOKEN) + assert headers['target-sr-cluster'] == TEST_CLUSTER + # Confluent-Identity-Pool-Id should NOT be present when identity pool is omitted + assert 'Confluent-Identity-Pool-Id' not in headers + + +def test_bearer_field_headers_comma_separated_pools(): + """Test that comma-separated pool IDs are passed through as-is in the header.""" + comma_separated_pools = 'pool-1,pool-2,pool-3' + + def custom_oauth_multi_pool(config: dict) -> dict: + return { + 'bearer.auth.token': TEST_TOKEN, + 'bearer.auth.logical.cluster': TEST_CLUSTER, + 'bearer.auth.identity.pool.id': comma_separated_pools, + } + + conf = { + 'url': TEST_URL, + 'bearer.auth.credentials.source': 'CUSTOM', + 'bearer.auth.custom.provider.function': custom_oauth_multi_pool, + 'bearer.auth.custom.provider.config': {}, + } + + client = SchemaRegistryClient(conf) + + headers = { + 'Accept': "application/vnd.schemaregistry.v1+json," " application/vnd.schemaregistry+json," " application/json" + } + + client._rest_client.handle_bearer_auth(headers) + + assert 'Authorization' in headers + assert 'Confluent-Identity-Pool-Id' in headers + assert 'target-sr-cluster' in headers + # Verify comma-separated value is passed through unchanged + assert headers['Confluent-Identity-Pool-Id'] == comma_separated_pools + + +def test_static_token_optional_identity_pool(): + """Test that STATIC_TOKEN credential source works without identity pool.""" + conf = { + 'url': TEST_URL, + 'bearer.auth.credentials.source': 'STATIC_TOKEN', + 'bearer.auth.token': TEST_TOKEN, + 'bearer.auth.logical.cluster': TEST_CLUSTER, + # bearer.auth.identity.pool.id is intentionally omitted + } + + client = SchemaRegistryClient(conf) + + headers = { + 'Accept': "application/vnd.schemaregistry.v1+json," " application/vnd.schemaregistry+json," " application/json" + } + + client._rest_client.handle_bearer_auth(headers) + + assert 'Authorization' in headers + assert 'target-sr-cluster' in headers + assert 'Confluent-Identity-Pool-Id' not in headers + + +def test_static_token_comma_separated_pools(): + """Test that STATIC_TOKEN credential source supports comma-separated pool IDs.""" + comma_separated_pools = 'pool-abc,pool-def,pool-ghi' + + conf = { + 'url': TEST_URL, + 'bearer.auth.credentials.source': 'STATIC_TOKEN', + 'bearer.auth.token': TEST_TOKEN, + 'bearer.auth.logical.cluster': TEST_CLUSTER, + 'bearer.auth.identity.pool.id': comma_separated_pools, + } + + client = SchemaRegistryClient(conf) + + headers = { + 'Accept': "application/vnd.schemaregistry.v1+json," " application/vnd.schemaregistry+json," " application/json" + } + + client._rest_client.handle_bearer_auth(headers) + + assert 'Confluent-Identity-Pool-Id' in headers + assert headers['Confluent-Identity-Pool-Id'] == comma_separated_pools + + +def test_static_field_provider_optional_pool(): + """Test that _StaticFieldProvider works with optional identity pool.""" + + from confluent_kafka.schema_registry.common.schema_registry_client import _StaticFieldProvider + + def check_provider(): + static_field_provider = _StaticFieldProvider(TEST_TOKEN, TEST_CLUSTER, None) + bearer_fields = static_field_provider.get_bearer_fields() + + assert bearer_fields['bearer.auth.token'] == TEST_TOKEN + assert bearer_fields['bearer.auth.logical.cluster'] == TEST_CLUSTER + assert 'bearer.auth.identity.pool.id' not in bearer_fields + + check_provider() diff --git a/tests/schema_registry/_sync/test_config.py b/tests/schema_registry/_sync/test_config.py index e5b74aeda..11d5fc278 100644 --- a/tests/schema_registry/_sync/test_config.py +++ b/tests/schema_registry/_sync/test_config.py @@ -124,7 +124,9 @@ def test_config_auth_userinfo_invalid(): def test_bearer_config(): conf = {'url': TEST_URL, 'bearer.auth.credentials.source': "OAUTHBEARER"} - with pytest.raises(ValueError, match=r"Missing required bearer configuration properties: (.*)"): + with pytest.raises( + ValueError, match=r"Missing required bearer configuration property: bearer.auth.logical.cluster" + ): SchemaRegistryClient(conf) @@ -148,7 +150,7 @@ def test_oauth_bearer_config_invalid(): 'bearer.auth.identity.pool.id': 1, } - with pytest.raises(TypeError, match=r"identity pool id must be a str, not (.*)"): + with pytest.raises(TypeError, match=r"identity pool id must be a str or list, not (.*)"): SchemaRegistryClient(conf) conf = { diff --git a/tests/test_Admin.py b/tests/test_Admin.py index a6d67628a..1e8f05146 100644 --- a/tests/test_Admin.py +++ b/tests/test_Admin.py @@ -153,7 +153,9 @@ def test_create_topics_api(): try: a.create_topics([NewTopic("mytopic")]) except Exception as err: - assert False, f"When none of the partitions, \ + assert ( + False + ), f"When none of the partitions, \ replication and assignment is present, the request should not fail, but it does with error {err}" fs = a.create_topics([NewTopic("mytopic", 3, 2)]) with pytest.raises(KafkaException): diff --git a/tests/test_unasync.py b/tests/test_unasync.py index e854ec547..8681c6b42 100644 --- a/tests/test_unasync.py +++ b/tests/test_unasync.py @@ -46,14 +46,18 @@ def test_unasync_file_check(temp_dirs): os.makedirs(os.path.dirname(sync_file), exist_ok=True) with open(async_file, "w") as f: - f.write("""async def test(): + f.write( + """async def test(): await asyncio.sleep(1) -""") +""" + ) with open(sync_file, "w") as f: - f.write("""def test(): + f.write( + """def test(): time.sleep(1) -""") +""" + ) # This should return True assert unasync_file_check(async_file, sync_file) is True @@ -64,15 +68,19 @@ def test_unasync_file_check(temp_dirs): os.makedirs(os.path.dirname(sync_file), exist_ok=True) with open(async_file, "w") as f: - f.write("""async def test(): + f.write( + """async def test(): await asyncio.sleep(1) -""") +""" + ) with open(sync_file, "w") as f: - f.write("""def test(): + f.write( + """def test(): # This is wrong asyncio.sleep(1) -""") +""" + ) # This should return False assert unasync_file_check(async_file, sync_file) is False @@ -83,15 +91,19 @@ def test_unasync_file_check(temp_dirs): os.makedirs(os.path.dirname(sync_file), exist_ok=True) with open(async_file, "w") as f: - f.write("""async def test(): + f.write( + """async def test(): await asyncio.sleep(1) return "test" -""") +""" + ) with open(sync_file, "w") as f: - f.write("""def test(): + f.write( + """def test(): time.sleep(1) -""") +""" + ) # This should return False assert unasync_file_check(async_file, sync_file) is False @@ -108,14 +120,16 @@ def test_unasync_generation(temp_dirs): # Create a test async file test_file = os.path.join(async_dir, "test.py") with open(test_file, "w") as f: - f.write("""async def test_func(): + f.write( + """async def test_func(): await asyncio.sleep(1) return "test" class AsyncTest: async def test_method(self): await self.some_async() -""") +""" + ) # Run unasync with test directories dir_pairs = [(async_dir, sync_dir)] @@ -149,20 +163,24 @@ def test_unasync_check(temp_dirs): # Create a test async file test_file = os.path.join(async_dir, "test.py") with open(test_file, "w") as f: - f.write("""async def test_func(): + f.write( + """async def test_func(): await asyncio.sleep(1) return "test" -""") +""" + ) # Create an incorrect sync file sync_file = os.path.join(sync_dir, "test.py") os.makedirs(os.path.dirname(sync_file), exist_ok=True) with open(sync_file, "w") as f: - f.write("""def test_func(): + f.write( + """def test_func(): time.sleep(1) return "test" # Extra line that shouldn't be here -""") +""" + ) # Run unasync check with test directories dir_pairs = [(async_dir, sync_dir)] @@ -178,10 +196,12 @@ def test_unasync_missing_sync_file(temp_dirs): # Create a test async file test_file = os.path.join(async_dir, "test.py") with open(test_file, "w") as f: - f.write("""async def test_func(): + f.write( + """async def test_func(): await asyncio.sleep(1) return "test" -""") +""" + ) # Run unasync check with test directories dir_pairs = [(async_dir, sync_dir)]