diff --git a/src/Core/Resolvers/SqlMutationEngine.cs b/src/Core/Resolvers/SqlMutationEngine.cs index 69fefe4341..970a58bb51 100644 --- a/src/Core/Resolvers/SqlMutationEngine.cs +++ b/src/Core/Resolvers/SqlMutationEngine.cs @@ -397,9 +397,12 @@ await queryExecutor.ExecuteQueryAsync( case EntityActionOperation.Insert: HttpContext httpContext = GetHttpContext(); + // Use scheme/host from X-Forwarded-* headers if present, else fallback to request values + string scheme = SqlPaginationUtil.ResolveRequestScheme(httpContext.Request); + string host = SqlPaginationUtil.ResolveRequestHost(httpContext.Request); string locationHeaderURL = UriHelper.BuildAbsolute( - scheme: httpContext.Request.Scheme, - host: httpContext.Request.Host, + scheme: scheme, + host: new HostString(host), pathBase: GetBaseRouteFromConfig(_runtimeConfigProvider.GetConfig()), path: httpContext.Request.Path); diff --git a/src/Core/Resolvers/SqlPaginationUtil.cs b/src/Core/Resolvers/SqlPaginationUtil.cs index bb9362015b..852138ef09 100644 --- a/src/Core/Resolvers/SqlPaginationUtil.cs +++ b/src/Core/Resolvers/SqlPaginationUtil.cs @@ -751,12 +751,12 @@ public static string FormatQueryString(NameValueCollection? queryStringParameter } /// - /// Extracts and request scheme from "X-Forwarded-Proto" or falls back to the request scheme. + /// Extracts the request scheme from "X-Forwarded-Proto" or falls back to the request scheme. + /// Invalid forwarded values are ignored. /// /// The HTTP request. /// The scheme string ("http" or "https"). - /// Thrown when client explicitly sets an invalid scheme. - private static string ResolveRequestScheme(HttpRequest req) + internal static string ResolveRequestScheme(HttpRequest req) { string? rawScheme = req.Headers["X-Forwarded-Proto"].FirstOrDefault(); string? normalized = rawScheme?.Trim().ToLowerInvariant(); @@ -776,11 +776,11 @@ private static string ResolveRequestScheme(HttpRequest req) /// /// Extracts the request host from "X-Forwarded-Host" or falls back to the request host. + /// Invalid forwarded values are ignored. /// /// The HTTP request. /// The host string. - /// Thrown when client explicitly sets an invalid host. - private static string ResolveRequestHost(HttpRequest req) + internal static string ResolveRequestHost(HttpRequest req) { string? rawHost = req.Headers["X-Forwarded-Host"].FirstOrDefault(); string? trimmed = rawHost?.Trim(); diff --git a/src/Core/Resolvers/SqlResponseHelpers.cs b/src/Core/Resolvers/SqlResponseHelpers.cs index d0bf768281..d955a7a36b 100644 --- a/src/Core/Resolvers/SqlResponseHelpers.cs +++ b/src/Core/Resolvers/SqlResponseHelpers.cs @@ -381,9 +381,12 @@ HttpContext httpContext // The third part is the computed primary key route. if (operationType is EntityActionOperation.Insert && !string.IsNullOrEmpty(primaryKeyRoute)) { + // Use scheme/host from X-Forwarded-* headers if present, else fallback to request values + string scheme = SqlPaginationUtil.ResolveRequestScheme(httpContext.Request); + string host = SqlPaginationUtil.ResolveRequestHost(httpContext.Request); locationHeaderURL = UriHelper.BuildAbsolute( - scheme: httpContext.Request.Scheme, - host: httpContext.Request.Host, + scheme: scheme, + host: new HostString(host), pathBase: baseRoute, path: httpContext.Request.Path); diff --git a/src/Service.Tests/Configuration/ConfigurationTests.cs b/src/Service.Tests/Configuration/ConfigurationTests.cs index 0ef9b67a4b..536d7d87f0 100644 --- a/src/Service.Tests/Configuration/ConfigurationTests.cs +++ b/src/Service.Tests/Configuration/ConfigurationTests.cs @@ -40,7 +40,10 @@ using Azure.DataApiBuilder.Service.Tests.SqlTests; using HotChocolate; using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Hosting.Server.Features; using Microsoft.AspNetCore.TestHost; +using Microsoft.Data.SqlClient; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.IdentityModel.Tokens; @@ -3448,6 +3451,133 @@ public async Task ValidateLocationHeaderWhenBaseRouteIsConfigured( } } + /// + /// Validates that the Location header returned for POST requests respects X-Forwarded-Host and X-Forwarded-Proto. + /// This covers both table and stored procedure insert endpoints. + /// + /// Type of entity under test. + /// REST endpoint path for POST request. + /// Value for X-Forwarded-Host header. + /// Value for X-Forwarded-Proto header. + /// Expected scheme in Location header. + [DataTestMethod] + [TestCategory(TestCategory.MSSQL)] + [DataRow(EntitySourceType.Table, "/api/Book", null, null, "http", DisplayName = "Location header uses local http scheme when no forwarded headers are present for table POST")] + [DataRow(EntitySourceType.StoredProcedure, "/api/GetBooks", null, null, "http", DisplayName = "Location header uses local http scheme when no forwarded headers are present for stored procedure POST")] + [DataRow(EntitySourceType.Table, "/api/Book", "api.contoso.com", "http", "http", DisplayName = "Location header uses forwarded http scheme for table POST")] + [DataRow(EntitySourceType.StoredProcedure, "/api/GetBooks", "api.contoso.com", "http", "http", DisplayName = "Location header uses forwarded http scheme for stored procedure POST")] + [DataRow(EntitySourceType.Table, "/api/Book", "api.contoso.com", "https", "https", DisplayName = "Location header uses forwarded https scheme/host for table POST")] + [DataRow(EntitySourceType.StoredProcedure, "/api/GetBooks", "api.contoso.com", "https", "https", DisplayName = "Location header uses forwarded https scheme/host for stored procedure POST")] + public async Task ValidateLocationHeaderRespectsXForwardedHostAndProto( + EntitySourceType entityType, + string requestPath, + string forwardedHost, + string forwardedProto, + string expectedScheme) + { + TestHelper.SetupDatabaseEnvironment(MSSQL_ENVIRONMENT); + + GraphQLRuntimeOptions graphqlOptions = new(Enabled: false); + RestRuntimeOptions restRuntimeOptions = new(Enabled: true); + McpRuntimeOptions mcpRuntimeOptions = new(Enabled: false); + + SqlConnectionStringBuilder connectionStringBuilder = new(GetConnectionStringFromEnvironmentConfig(environment: TestCategory.MSSQL)) + { + TrustServerCertificate = true + }; + + DataSource dataSource = new(DatabaseType.MSSQL, + connectionStringBuilder.ConnectionString, Options: null); + + RuntimeConfig configuration; + if (entityType is EntitySourceType.StoredProcedure) + { + Entity entity = new(Source: new("get_books", EntitySourceType.StoredProcedure, null, null), + Fields: null, + Rest: new(new SupportedHttpVerb[] { SupportedHttpVerb.Get, SupportedHttpVerb.Post }), + GraphQL: null, + Permissions: new[] { GetMinimalPermissionConfig(AuthorizationResolver.ROLE_ANONYMOUS) }, + Relationships: null, + Mappings: null + ); + + configuration = InitMinimalRuntimeConfig(dataSource, graphqlOptions, restRuntimeOptions, mcpRuntimeOptions, entity, entityName: "GetBooks"); + } + else + { + configuration = InitMinimalRuntimeConfig(dataSource, graphqlOptions, restRuntimeOptions, mcpRuntimeOptions); + } + + const string CUSTOM_CONFIG = "custom-config.json"; + File.WriteAllText(CUSTOM_CONFIG, configuration.ToJson()); + string[] args = new[] { $"--ConfigFileName={CUSTOM_CONFIG}" }; + + // Intentionally bind HTTP to simulate the proxy-to-app internal hop. + using IWebHost host = Program.CreateWebHostBuilder(args) + .UseUrls("http://127.0.0.1:0") + .Build(); + await host.StartAsync(); + + IServerAddressesFeature addresses = host.ServerFeatures.Get(); + Assert.IsNotNull(addresses); + + string baseAddress = addresses.Addresses.FirstOrDefault(); + Assert.IsFalse(string.IsNullOrEmpty(baseAddress)); + + using HttpClient client = new() + { + BaseAddress = new Uri(baseAddress) + }; + + HttpRequestMessage request = new(HttpMethod.Post, requestPath); + if (!string.IsNullOrEmpty(forwardedHost)) + { + request.Headers.Add("X-Forwarded-Host", forwardedHost); + } + + if (!string.IsNullOrEmpty(forwardedProto)) + { + request.Headers.Add("X-Forwarded-Proto", forwardedProto); + } + + if (entityType is EntitySourceType.Table) + { + JsonElement requestBodyElement = JsonDocument.Parse(@"{ + ""title"": ""Forwarded Header Location Test"", + ""publisher_id"": 1234 + }").RootElement.Clone(); + + request.Content = JsonContent.Create(requestBodyElement); + } + + HttpResponseMessage response = await client.SendAsync(request); + + Assert.AreEqual(HttpStatusCode.Created, response.StatusCode); + Assert.IsNotNull(response.Headers.Location, "Location header should be present for successful POST create."); + + Uri location = response.Headers.Location; + Assert.AreEqual(expectedScheme, location.Scheme, $"Expected Location scheme '{expectedScheme}', got '{location.Scheme}'."); + + if (!string.IsNullOrEmpty(forwardedHost)) + { + Assert.AreEqual(forwardedHost, location.Host, $"Expected Location host '{forwardedHost}', got '{location.Host}'."); + } + + // Since forwarded host is external, validate follow-up using local path only. + string localPathAndQuery = string.IsNullOrEmpty(location.Query) ? location.AbsolutePath : location.AbsolutePath + location.Query; + HttpRequestMessage followUpRequest = new(HttpMethod.Get, localPathAndQuery); + HttpResponseMessage followUpResponse = await client.SendAsync(followUpRequest); + Assert.AreEqual(HttpStatusCode.OK, followUpResponse.StatusCode); + + if (entityType is EntitySourceType.Table) + { + HttpRequestMessage cleanupRequest = new(HttpMethod.Delete, localPathAndQuery); + await client.SendAsync(cleanupRequest); + } + + await host.StopAsync(); + } + /// /// Test to validate that when the property rest.request-body-strict is absent from the rest runtime section in config file, DAB runs in strict mode. /// In strict mode, presence of extra fields in the request body is not permitted and leads to HTTP 400 - BadRequest error.