Skip to content
7 changes: 5 additions & 2 deletions src/Core/Resolvers/SqlMutationEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,12 @@ await queryExecutor.ExecuteQueryAsync(
case EntityActionOperation.Insert:

HttpContext httpContext = GetHttpContext();
// Use scheme/host from X-Forwarded-* headers if present, else fallback to request values
string scheme = SqlPaginationUtil.ResolveRequestScheme(httpContext.Request);
string host = SqlPaginationUtil.ResolveRequestHost(httpContext.Request);
string locationHeaderURL = UriHelper.BuildAbsolute(
scheme: httpContext.Request.Scheme,
host: httpContext.Request.Host,
scheme: scheme,
host: new HostString(host),
pathBase: GetBaseRouteFromConfig(_runtimeConfigProvider.GetConfig()),
path: httpContext.Request.Path);

Expand Down
10 changes: 5 additions & 5 deletions src/Core/Resolvers/SqlPaginationUtil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -751,12 +751,12 @@ public static string FormatQueryString(NameValueCollection? queryStringParameter
}

/// <summary>
/// Extracts and request scheme from "X-Forwarded-Proto" or falls back to the request scheme.
/// Extracts the request scheme from "X-Forwarded-Proto" or falls back to the request scheme.
/// Invalid forwarded values are ignored.
/// </summary>
/// <param name="req">The HTTP request.</param>
/// <returns>The scheme string ("http" or "https").</returns>
/// <exception cref="DataApiBuilderException">Thrown when client explicitly sets an invalid scheme.</exception>
private static string ResolveRequestScheme(HttpRequest req)
internal static string ResolveRequestScheme(HttpRequest req)
{
string? rawScheme = req.Headers["X-Forwarded-Proto"].FirstOrDefault();
string? normalized = rawScheme?.Trim().ToLowerInvariant();
Expand All @@ -776,11 +776,11 @@ private static string ResolveRequestScheme(HttpRequest req)

/// <summary>
/// Extracts the request host from "X-Forwarded-Host" or falls back to the request host.
/// Invalid forwarded values are ignored.
/// </summary>
/// <param name="req">The HTTP request.</param>
/// <returns>The host string.</returns>
/// <exception cref="DataApiBuilderException">Thrown when client explicitly sets an invalid host.</exception>
private static string ResolveRequestHost(HttpRequest req)
internal static string ResolveRequestHost(HttpRequest req)
{
string? rawHost = req.Headers["X-Forwarded-Host"].FirstOrDefault();
string? trimmed = rawHost?.Trim();
Expand Down
7 changes: 5 additions & 2 deletions src/Core/Resolvers/SqlResponseHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,12 @@ HttpContext httpContext
// The third part is the computed primary key route.
if (operationType is EntityActionOperation.Insert && !string.IsNullOrEmpty(primaryKeyRoute))
{
// Use scheme/host from X-Forwarded-* headers if present, else fallback to request values
string scheme = SqlPaginationUtil.ResolveRequestScheme(httpContext.Request);
string host = SqlPaginationUtil.ResolveRequestHost(httpContext.Request);
locationHeaderURL = UriHelper.BuildAbsolute(
scheme: httpContext.Request.Scheme,
host: httpContext.Request.Host,
scheme: scheme,
host: new HostString(host),
pathBase: baseRoute,
path: httpContext.Request.Path);

Expand Down
130 changes: 130 additions & 0 deletions src/Service.Tests/Configuration/ConfigurationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@
using Azure.DataApiBuilder.Service.Tests.SqlTests;
using HotChocolate;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Hosting.Server.Features;
using Microsoft.AspNetCore.TestHost;
using Microsoft.Data.SqlClient;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.IdentityModel.Tokens;
Expand Down Expand Up @@ -3448,6 +3451,133 @@ public async Task ValidateLocationHeaderWhenBaseRouteIsConfigured(
}
}

/// <summary>
/// Validates that the Location header returned for POST requests respects X-Forwarded-Host and X-Forwarded-Proto.
/// This covers both table and stored procedure insert endpoints.
/// </summary>
/// <param name="entityType">Type of entity under test.</param>
/// <param name="requestPath">REST endpoint path for POST request.</param>
/// <param name="forwardedHost">Value for X-Forwarded-Host header.</param>
/// <param name="forwardedProto">Value for X-Forwarded-Proto header.</param>
/// <param name="expectedScheme">Expected scheme in Location header.</param>
[DataTestMethod]
[TestCategory(TestCategory.MSSQL)]
[DataRow(EntitySourceType.Table, "/api/Book", null, null, "http", DisplayName = "Location header uses local http scheme when no forwarded headers are present for table POST")]
[DataRow(EntitySourceType.StoredProcedure, "/api/GetBooks", null, null, "http", DisplayName = "Location header uses local http scheme when no forwarded headers are present for stored procedure POST")]
[DataRow(EntitySourceType.Table, "/api/Book", "api.contoso.com", "http", "http", DisplayName = "Location header uses forwarded http scheme for table POST")]
[DataRow(EntitySourceType.StoredProcedure, "/api/GetBooks", "api.contoso.com", "http", "http", DisplayName = "Location header uses forwarded http scheme for stored procedure POST")]
[DataRow(EntitySourceType.Table, "/api/Book", "api.contoso.com", "https", "https", DisplayName = "Location header uses forwarded https scheme/host for table POST")]
[DataRow(EntitySourceType.StoredProcedure, "/api/GetBooks", "api.contoso.com", "https", "https", DisplayName = "Location header uses forwarded https scheme/host for stored procedure POST")]
public async Task ValidateLocationHeaderRespectsXForwardedHostAndProto(
EntitySourceType entityType,
string requestPath,
string forwardedHost,
string forwardedProto,
string expectedScheme)
{
TestHelper.SetupDatabaseEnvironment(MSSQL_ENVIRONMENT);

GraphQLRuntimeOptions graphqlOptions = new(Enabled: false);
RestRuntimeOptions restRuntimeOptions = new(Enabled: true);
McpRuntimeOptions mcpRuntimeOptions = new(Enabled: false);

SqlConnectionStringBuilder connectionStringBuilder = new(GetConnectionStringFromEnvironmentConfig(environment: TestCategory.MSSQL))
{
TrustServerCertificate = true
};

DataSource dataSource = new(DatabaseType.MSSQL,
connectionStringBuilder.ConnectionString, Options: null);

RuntimeConfig configuration;
if (entityType is EntitySourceType.StoredProcedure)
{
Entity entity = new(Source: new("get_books", EntitySourceType.StoredProcedure, null, null),
Fields: null,
Rest: new(new SupportedHttpVerb[] { SupportedHttpVerb.Get, SupportedHttpVerb.Post }),
GraphQL: null,
Permissions: new[] { GetMinimalPermissionConfig(AuthorizationResolver.ROLE_ANONYMOUS) },
Relationships: null,
Mappings: null
);

configuration = InitMinimalRuntimeConfig(dataSource, graphqlOptions, restRuntimeOptions, mcpRuntimeOptions, entity, entityName: "GetBooks");
}
else
{
configuration = InitMinimalRuntimeConfig(dataSource, graphqlOptions, restRuntimeOptions, mcpRuntimeOptions);
}

const string CUSTOM_CONFIG = "custom-config.json";
File.WriteAllText(CUSTOM_CONFIG, configuration.ToJson());
string[] args = new[] { $"--ConfigFileName={CUSTOM_CONFIG}" };

// Intentionally bind HTTP to simulate the proxy-to-app internal hop.
using IWebHost host = Program.CreateWebHostBuilder(args)
.UseUrls("http://127.0.0.1:0")
.Build();
await host.StartAsync();

IServerAddressesFeature addresses = host.ServerFeatures.Get<IServerAddressesFeature>();
Assert.IsNotNull(addresses);

string baseAddress = addresses.Addresses.FirstOrDefault();
Assert.IsFalse(string.IsNullOrEmpty(baseAddress));

using HttpClient client = new()
{
BaseAddress = new Uri(baseAddress)
};

HttpRequestMessage request = new(HttpMethod.Post, requestPath);
if (!string.IsNullOrEmpty(forwardedHost))
{
request.Headers.Add("X-Forwarded-Host", forwardedHost);
}

if (!string.IsNullOrEmpty(forwardedProto))
{
request.Headers.Add("X-Forwarded-Proto", forwardedProto);
}

if (entityType is EntitySourceType.Table)
{
JsonElement requestBodyElement = JsonDocument.Parse(@"{
""title"": ""Forwarded Header Location Test"",
""publisher_id"": 1234
}").RootElement.Clone();

request.Content = JsonContent.Create(requestBodyElement);
}

HttpResponseMessage response = await client.SendAsync(request);

Assert.AreEqual(HttpStatusCode.Created, response.StatusCode);
Assert.IsNotNull(response.Headers.Location, "Location header should be present for successful POST create.");

Uri location = response.Headers.Location;
Assert.AreEqual(expectedScheme, location.Scheme, $"Expected Location scheme '{expectedScheme}', got '{location.Scheme}'.");

if (!string.IsNullOrEmpty(forwardedHost))
{
Assert.AreEqual(forwardedHost, location.Host, $"Expected Location host '{forwardedHost}', got '{location.Host}'.");
}

// Since forwarded host is external, validate follow-up using local path only.
string localPathAndQuery = string.IsNullOrEmpty(location.Query) ? location.AbsolutePath : location.AbsolutePath + location.Query;
HttpRequestMessage followUpRequest = new(HttpMethod.Get, localPathAndQuery);
HttpResponseMessage followUpResponse = await client.SendAsync(followUpRequest);
Assert.AreEqual(HttpStatusCode.OK, followUpResponse.StatusCode);

if (entityType is EntitySourceType.Table)
{
HttpRequestMessage cleanupRequest = new(HttpMethod.Delete, localPathAndQuery);
await client.SendAsync(cleanupRequest);
}

await host.StopAsync();
}

/// <summary>
/// Test to validate that when the property rest.request-body-strict is absent from the rest runtime section in config file, DAB runs in strict mode.
/// In strict mode, presence of extra fields in the request body is not permitted and leads to HTTP 400 - BadRequest error.
Expand Down