diff --git a/core/src/main/java/google/registry/model/eppcommon/EppXmlTransformer.java b/core/src/main/java/google/registry/model/eppcommon/EppXmlTransformer.java index 165046612df..8fe731c6550 100644 --- a/core/src/main/java/google/registry/model/eppcommon/EppXmlTransformer.java +++ b/core/src/main/java/google/registry/model/eppcommon/EppXmlTransformer.java @@ -20,9 +20,14 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import google.registry.flows.FeeExtensionXmlTagNormalizer; import google.registry.model.ImmutableObject; +import google.registry.model.domain.fee.FeeCheckResponseExtension; +import google.registry.model.domain.fee.FeeTransformResponseExtension; +import google.registry.model.domain.fee06.FeeInfoResponseExtensionV06; import google.registry.model.eppinput.EppInput; import google.registry.model.eppoutput.EppOutput; +import google.registry.model.eppoutput.EppResponse; import google.registry.util.RegistryEnvironment; import google.registry.xml.ValidationMode; import google.registry.xml.XmlException; @@ -98,8 +103,31 @@ private static byte[] marshal( return byteArrayOutputStream.toByteArray(); } + private static boolean hasFeeExtension(EppOutput eppOutput) { + if (!eppOutput.isResponse()) { + return false; + } + return eppOutput.getResponse().getExtensions().stream() + .map(EppResponse.ResponseExtension::getClass) + .filter(EppXmlTransformer::isFeeExtension) + .findAny() + .isPresent(); + } + + @VisibleForTesting + static boolean isFeeExtension(Class clazz) { + return FeeCheckResponseExtension.class.isAssignableFrom(clazz) + || FeeTransformResponseExtension.class.isAssignableFrom(clazz) + || FeeInfoResponseExtensionV06.class.isAssignableFrom(clazz); + } + public static byte[] marshal(EppOutput root, ValidationMode validation) throws XmlException { - return marshal(OUTPUT_TRANSFORMER, root, validation); + byte[] bytes = marshal(OUTPUT_TRANSFORMER, root, validation); + if (!RegistryEnvironment.PRODUCTION.equals(RegistryEnvironment.get()) + && hasFeeExtension(root)) { + return FeeExtensionXmlTagNormalizer.normalize(new String(bytes, UTF_8)).getBytes(UTF_8); + } + return bytes; } @VisibleForTesting diff --git a/core/src/test/java/google/registry/flows/FeeExtensionXmlTagNormalizerTest.java b/core/src/test/java/google/registry/flows/FeeExtensionXmlTagNormalizerTest.java index 15545ec9283..eafb0e254d8 100644 --- a/core/src/test/java/google/registry/flows/FeeExtensionXmlTagNormalizerTest.java +++ b/core/src/test/java/google/registry/flows/FeeExtensionXmlTagNormalizerTest.java @@ -17,8 +17,10 @@ import static com.google.common.truth.Truth.assertThat; import static google.registry.flows.FeeExtensionXmlTagNormalizer.feeExtensionInUseRegex; import static google.registry.flows.FeeExtensionXmlTagNormalizer.normalize; +import static google.registry.flows.FlowTestCase.verifyFeeTagNormalized; import static google.registry.model.eppcommon.EppXmlTransformer.validateOutput; import static google.registry.testing.TestDataHelper.loadFile; +import static org.junit.jupiter.api.Assertions.assertThrows; import java.util.stream.Stream; import org.junit.jupiter.api.Test; @@ -41,6 +43,13 @@ void normalize_noFeeExtensions() throws Exception { assertThat(normalized).isEqualTo(xml); } + @Test + void normalize_greetingUnchanged() throws Exception { + String xml = loadFile(getClass(), "greeting.xml"); + String normalized = normalize(xml); + assertThat(normalized).isEqualTo(xml); + } + @ParameterizedTest(name = "normalize_withFeeExtension-{0}") @MethodSource("provideTestCombinations") @SuppressWarnings("unused") // Parameter 'name' is part of test case name @@ -55,6 +64,28 @@ void normalize_withFeeExtension(String name, String inputXmlFilename, String exp assertThat(normalized).isEqualTo(expected); } + // Piggyback tests for FlowTestCase.verifyFeeTagNormalized here. + @ParameterizedTest(name = "verifyFeeTagNormalized-{0}") + @MethodSource("provideTestCombinations") + @SuppressWarnings("unused") // Parameter 'name' is part of test case name + void verifyFeeTagNormalized_success( + String name, String inputXmlFilename, String expectedXmlFilename) throws Exception { + String original = loadFile(getClass(), inputXmlFilename); + String expected = loadFile(getClass(), expectedXmlFilename); + + if (name.equals("v06")) { + // Fee-06 already uses 'fee'. Non-normalized tags only appear in header. + verifyFeeTagNormalized(original); + } else { + assertThrows( + AssertionError.class, + () -> { + verifyFeeTagNormalized(original); + }); + } + verifyFeeTagNormalized(expected); + } + @SuppressWarnings("unused") static Stream provideTestCombinations() { return Stream.of( diff --git a/core/src/test/java/google/registry/flows/FlowTestCase.java b/core/src/test/java/google/registry/flows/FlowTestCase.java index e27f5a62293..ba2b31f61af 100644 --- a/core/src/test/java/google/registry/flows/FlowTestCase.java +++ b/core/src/test/java/google/registry/flows/FlowTestCase.java @@ -19,6 +19,7 @@ import static com.google.common.collect.Sets.difference; import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; +import static google.registry.flows.FlowUtils.marshalWithLenientRetry; import static google.registry.model.eppcommon.EppXmlTransformer.marshal; import static google.registry.persistence.transaction.TransactionManagerFactory.tm; import static google.registry.testing.DatabaseHelper.stripBillingEventId; @@ -55,6 +56,7 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.regex.Pattern; import javax.annotation.Nullable; import org.joda.time.DateTime; import org.junit.jupiter.api.BeforeEach; @@ -284,6 +286,7 @@ public EppOutput runFlowAssertResponse( Arrays.toString(marshal(output, ValidationMode.LENIENT))), e); } + verifyFeeTagNormalized(new String(marshalWithLenientRetry(output), UTF_8)); return output; } @@ -298,4 +301,14 @@ public EppOutput dryRunFlowAssertResponse(String xml, String... ignoredPaths) th public EppOutput runFlowAssertResponse(String xml, String... ignoredPaths) throws Exception { return runFlowAssertResponse(CommitMode.LIVE, UserPrivileges.NORMAL, xml, ignoredPaths); } + + // Pattern for non-normalized tags in use. Occurrences in namespace declarations ignored. + private static final Pattern NON_NORMALIZED_FEE_TAGS = + Pattern.compile("\\bfee11:|\\bfee12:|\\bfee_1_00:"); + + static void verifyFeeTagNormalized(String xml) { + assertWithMessage("Unexpected un-normalized Fee tags found in message.") + .that(NON_NORMALIZED_FEE_TAGS.matcher(xml).find()) + .isFalse(); + } } diff --git a/core/src/test/java/google/registry/model/eppcommon/EppXmlTransformerTest.java b/core/src/test/java/google/registry/model/eppcommon/EppXmlTransformerTest.java index 5d77124b7c3..2054769e665 100644 --- a/core/src/test/java/google/registry/model/eppcommon/EppXmlTransformerTest.java +++ b/core/src/test/java/google/registry/model/eppcommon/EppXmlTransformerTest.java @@ -15,18 +15,48 @@ package google.registry.model.eppcommon; import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; +import static google.registry.model.eppcommon.EppXmlTransformer.isFeeExtension; import static google.registry.model.eppcommon.EppXmlTransformer.unmarshal; import static google.registry.testing.TestDataHelper.loadBytes; import static org.junit.jupiter.api.Assertions.assertThrows; +import com.google.common.collect.ImmutableSet; +import google.registry.model.domain.bulktoken.BulkTokenResponseExtension; +import google.registry.model.domain.launch.LaunchCheckResponseExtension; +import google.registry.model.domain.rgp.RgpInfoExtension; +import google.registry.model.domain.secdns.SecDnsInfoExtension; import google.registry.model.eppinput.EppInput; import google.registry.model.eppoutput.EppOutput; +import google.registry.model.eppoutput.EppResponse; import google.registry.util.RegistryEnvironment; +import jakarta.xml.bind.annotation.XmlElementRef; +import jakarta.xml.bind.annotation.XmlElementRefs; +import java.util.Arrays; import org.junit.jupiter.api.Test; /** Tests for {@link EppXmlTransformer}. */ class EppXmlTransformerTest { + // Non-fee extensions allowed in {@code Response.extensions}. + private static final ImmutableSet> NON_FEE_EXTENSIONS = + ImmutableSet.of( + BulkTokenResponseExtension.class, + LaunchCheckResponseExtension.class, + RgpInfoExtension.class, + SecDnsInfoExtension.class); + + @Test + void isFeeExtension_eppResponse() throws Exception { + var xmlRefs = + EppResponse.class.getDeclaredField("extensions").getAnnotation(XmlElementRefs.class); + Arrays.stream(xmlRefs.value()) + .map(XmlElementRef::type) + .filter(type -> !NON_FEE_EXTENSIONS.contains(type)) + .forEach( + type -> assertWithMessage(type.getSimpleName()).that(isFeeExtension(type)).isTrue()); + } + @Test void testUnmarshalingEppInput() throws Exception { EppInput input = unmarshal(EppInput.class, loadBytes(getClass(), "contact_info.xml").read());