Skip to content

Commit 55b928d

Browse files
authored
chore: improve XXE attack prevention (#2008)
Reported by Aikido earlier today.
1 parent 1057837 commit 55b928d

File tree

8 files changed

+206
-287
lines changed

8 files changed

+206
-287
lines changed

benchmark/src/main/java/ai/timefold/solver/benchmark/impl/io/PlannerBenchmarkConfigIO.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@ public PlannerBenchmarkConfig read(Reader reader) {
2828
* the solver element in benchmark configuration and thus the solver element's namespace needs to be overridden.
2929
*/
3030
return genericJaxbIO.readOverridingNamespace(document,
31-
ElementNamespaceOverride.of(SolverConfig.XML_ELEMENT_NAME, SolverConfig.XML_NAMESPACE));
31+
new ElementNamespaceOverride(SolverConfig.XML_ELEMENT_NAME, SolverConfig.XML_NAMESPACE));
3232
} else if (rootElementNamespace == null || rootElementNamespace.isEmpty()) {
3333
// If not, add the missing namespace to maintain backward compatibility.
3434
return genericJaxbIO.readOverridingNamespace(document,
35-
ElementNamespaceOverride.of(PlannerBenchmarkConfig.XML_ELEMENT_NAME, PlannerBenchmarkConfig.XML_NAMESPACE),
36-
ElementNamespaceOverride.of(SolverConfig.XML_ELEMENT_NAME, SolverConfig.XML_NAMESPACE));
35+
new ElementNamespaceOverride(PlannerBenchmarkConfig.XML_ELEMENT_NAME, PlannerBenchmarkConfig.XML_NAMESPACE),
36+
new ElementNamespaceOverride(SolverConfig.XML_ELEMENT_NAME, SolverConfig.XML_NAMESPACE));
3737
} else { // If there is an unexpected namespace, fail fast.
3838
String errorMessage = String.format("The <%s/> element belongs to a different namespace (%s) than expected (%s).",
3939
PlannerBenchmarkConfig.XML_ELEMENT_NAME, rootElementNamespace, PlannerBenchmarkConfig.XML_NAMESPACE);

benchmark/src/main/java/ai/timefold/solver/benchmark/impl/result/BenchmarkResultIO.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ protected PlannerBenchmarkResult readPlannerBenchmarkResult(File plannerBenchmar
9292

9393
protected PlannerBenchmarkResult read(Reader reader) {
9494
return genericJaxbIO.readOverridingNamespace(reader,
95-
ElementNamespaceOverride.of(SOLVER_CONFIG_XML_ELEMENT_NAME, SolverConfig.XML_NAMESPACE));
95+
new ElementNamespaceOverride(SOLVER_CONFIG_XML_ELEMENT_NAME, SolverConfig.XML_NAMESPACE));
9696
}
9797

9898
protected void write(PlannerBenchmarkResult plannerBenchmarkResult, Writer writer) {

benchmark/src/main/java/ai/timefold/solver/benchmark/impl/xsd/XsdAggregator.java

Lines changed: 55 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -8,52 +8,52 @@
88
import java.util.function.Predicate;
99
import java.util.function.UnaryOperator;
1010

11-
import javax.xml.XMLConstants;
1211
import javax.xml.parsers.DocumentBuilder;
1312
import javax.xml.parsers.DocumentBuilderFactory;
1413
import javax.xml.parsers.ParserConfigurationException;
1514
import javax.xml.transform.OutputKeys;
16-
import javax.xml.transform.Result;
1715
import javax.xml.transform.Transformer;
1816
import javax.xml.transform.TransformerConfigurationException;
1917
import javax.xml.transform.TransformerException;
20-
import javax.xml.transform.TransformerFactory;
2118
import javax.xml.transform.dom.DOMSource;
2219
import javax.xml.transform.stream.StreamResult;
2320

2421
import ai.timefold.solver.core.config.solver.SolverConfig;
22+
import ai.timefold.solver.core.impl.io.jaxb.GenericJaxbIO;
2523

24+
import org.jspecify.annotations.NullMarked;
2625
import org.w3c.dom.Attr;
2726
import org.w3c.dom.Document;
2827
import org.w3c.dom.Element;
2928
import org.w3c.dom.Node;
30-
import org.w3c.dom.NodeList;
3129
import org.xml.sax.SAXException;
3230

3331
/**
3432
* This class merges solver.xsd and benchmark.xsd into a single XML Schema file that contains both Solver and Benchmark XML
3533
* types under a single namespace of the benchmark.xsd.
36-
*
34+
* <p>
3735
* Both solver.xsd and benchmark.xsd declare its own namespace as they are supposed to be used for different purposes. As the
3836
* benchmark configuration contains solver configuration, the benchmark.xsd imports the solver.xsd. To avoid distributing
3937
* dependent schemas and using prefixes in users' XML configuration files, the types defined by solver.xsd are merged to
4038
* the benchmark.xsd under its namespace.
4139
*/
40+
@NullMarked
4241
public final class XsdAggregator {
4342

4443
private static final String TNS_PREFIX = "tns";
4544

4645
public static void main(String[] args) {
4746
if (args.length != 3) {
48-
String msg = "The XSD Aggregator expects 3 arguments:\n"
49-
+ "1) a path to the solver XSD file. \n"
50-
+ "2) a path to the benchmark XSD file. \n"
51-
+ "3) a path to an output file where the merged benchmark XSD should be saved to.";
47+
var msg = """
48+
The XSD Aggregator expects 3 arguments:
49+
1) a path to the solver XSD file.
50+
2) a path to the benchmark XSD file.
51+
3) a path to an output file where the merged benchmark XSD should be saved to.""";
5252
throw new IllegalArgumentException(msg);
5353
}
54-
File solverXsd = checkFileExists(new File(args[0]));
55-
File benchmarkXsd = checkFileExists(new File(args[1]));
56-
File outputXsd = new File(args[2]);
54+
var solverXsd = checkFileExists(new File(args[0]));
55+
var benchmarkXsd = checkFileExists(new File(args[1]));
56+
var outputXsd = new File(args[2]);
5757

5858
if (!outputXsd.getParentFile().exists()) {
5959
outputXsd.getParentFile().mkdirs();
@@ -71,53 +71,46 @@ private static File checkFileExists(File file) {
7171
}
7272

7373
private void mergeXmlSchemas(File solverSchemaFile, File benchmarkSchemaFile, File outputSchemaFile) {
74-
DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
75-
Document solverSchema = parseXml(solverSchemaFile, factory);
76-
Element solverRootElement = solverSchema.getDocumentElement();
77-
Document benchmarkSchema = parseXml(benchmarkSchemaFile, factory);
74+
var factory = GenericJaxbIO.createDocumentBuilderFactory();
75+
var solverSchema = parseXml(solverSchemaFile, factory);
76+
var solverRootElement = solverSchema.getDocumentElement();
77+
var benchmarkSchema = parseXml(benchmarkSchemaFile, factory);
7878

7979
removeReferencesToSolverConfig(benchmarkSchema, benchmarkSchemaFile);
8080

8181
copySolverConfigTypes(benchmarkSchema, solverRootElement);
8282

83-
Transformer transformer = createTransformer();
84-
DOMSource source = new DOMSource(benchmarkSchema);
85-
Result result = new StreamResult(outputSchemaFile);
83+
var source = new DOMSource(benchmarkSchema);
84+
var result = new StreamResult(outputSchemaFile);
8685
try {
87-
transformer.transform(source, result);
86+
createTransformer().transform(source, result);
8887
} catch (TransformerException e) {
89-
throw new IllegalArgumentException(
90-
"Failed to write the resulting XSD to a file (" + outputSchemaFile.getAbsolutePath() + ").", e);
88+
throw new IllegalArgumentException("Failed to write the resulting XSD to a file (%s)."
89+
.formatted(outputSchemaFile.getAbsolutePath()), e);
9190
}
9291
}
9392

9493
private Document parseXml(File xmlFile, DocumentBuilderFactory documentBuilderFactory) {
95-
DocumentBuilder builder;
9694
try {
97-
builder = documentBuilderFactory.newDocumentBuilder();
95+
return documentBuilderFactory.newDocumentBuilder().parse(xmlFile);
9896
} catch (ParserConfigurationException e) {
99-
throw new IllegalArgumentException("Failed to create a " + DocumentBuilder.class.getName() + "instance.", e);
100-
}
101-
102-
try {
103-
return builder.parse(xmlFile);
104-
} catch (SAXException saxException) {
105-
throw new IllegalArgumentException("Failed to parse an XML file (" + xmlFile.getAbsolutePath() + ").",
106-
saxException);
107-
} catch (IOException ioException) {
108-
throw new IllegalArgumentException("Failed to open an XML file (" + xmlFile.getAbsolutePath() + ").", ioException);
97+
throw new IllegalArgumentException("Failed to create a %s instance."
98+
.formatted(DocumentBuilder.class.getSimpleName()), e);
99+
} catch (SAXException | IOException exception) {
100+
throw new IllegalArgumentException("Failed to parse an XML file (%s)."
101+
.formatted(xmlFile.getAbsolutePath()), exception);
109102
}
110103
}
111104

112105
private void removeReferencesToSolverConfig(Document benchmarkSchema, File benchmarkSchemaFile) {
113-
boolean solverNamespaceRemoved = false;
114-
boolean solverElementRefRemoved = false;
115-
boolean importRemoved = false;
106+
var solverNamespaceRemoved = false;
107+
var solverElementRefRemoved = false;
108+
var importRemoved = false;
116109

117-
NodeList nodeList = benchmarkSchema.getElementsByTagName("*");
118-
for (int i = 0; i < nodeList.getLength(); i++) {
119-
Node node = nodeList.item(i);
120-
Element element = (Element) node;
110+
var nodeList = benchmarkSchema.getElementsByTagName("*");
111+
for (var i = 0; i < nodeList.getLength(); i++) {
112+
var node = Objects.requireNonNull(nodeList.item(i));
113+
var element = (Element) node;
121114

122115
if ("xs:schema".equals(node.getNodeName())) { // Remove the solver namespace declaration.
123116
element.removeAttribute("xmlns:" + SOLVER_NAMESPACE_PREFIX);
@@ -148,31 +141,30 @@ private void removeReferencesToSolverConfig(Document benchmarkSchema, File bench
148141
* a successful validation by the resulting XML schema.
149142
*/
150143
if (!solverElementRefRemoved) {
151-
String msg = String.format("An expected reference to the solver element was not found. Check the content of (%s).",
144+
var msg = String.format("An expected reference to the solver element was not found. Check the content of (%s).",
152145
benchmarkSchemaFile);
153146
throw new AssertionError(msg);
154147
}
155148

156149
if (!solverNamespaceRemoved) {
157-
String msg = String.format("An expected namespace (%s) declaration was not found. Check the content of (%s).",
150+
var msg = String.format("An expected namespace (%s) declaration was not found. Check the content of (%s).",
158151
SolverConfig.XML_NAMESPACE, benchmarkSchemaFile);
159152
throw new AssertionError(msg);
160153
}
161154

162155
if (!importRemoved) {
163-
String msg = String.format("An expected import element was not found. Check the content of (%s).",
164-
benchmarkSchemaFile);
156+
var msg =
157+
String.format("An expected import element was not found. Check the content of (%s).", benchmarkSchemaFile);
165158
throw new AssertionError(msg);
166159
}
167160
}
168161

169162
private void copySolverConfigTypes(Document benchmarkSchema, Element solverSchemaRoot) {
170-
Element benchmarkSchemaRoot = benchmarkSchema.getDocumentElement();
171-
NodeList solverChildNodes = solverSchemaRoot.getChildNodes();
172-
for (int i = 0; i < solverChildNodes.getLength(); i++) {
173-
Node node = solverChildNodes.item(i);
174-
boolean isSolverElementDeclaration =
175-
isXsElement(node) && hasAttribute(node, "name", SolverConfig.XML_ELEMENT_NAME);
163+
var benchmarkSchemaRoot = benchmarkSchema.getDocumentElement();
164+
var solverChildNodes = solverSchemaRoot.getChildNodes();
165+
for (var i = 0; i < solverChildNodes.getLength(); i++) {
166+
var node = solverChildNodes.item(i);
167+
var isSolverElementDeclaration = isXsElement(node) && hasAttribute(node, "name", SolverConfig.XML_ELEMENT_NAME);
176168
if (!isSolverElementDeclaration) { // Skip the solver root element.
177169
benchmarkSchemaRoot.appendChild(benchmarkSchema.importNode(node, true));
178170
}
@@ -184,41 +176,31 @@ private boolean isXsElement(Node node) {
184176
}
185177

186178
private boolean hasAttribute(Node node, String attributeName, String attributeValue) {
187-
Objects.requireNonNull(node);
188-
Objects.requireNonNull(attributeName);
189-
Objects.requireNonNull(attributeValue);
190-
191-
Attr attribute = ((Element) node).getAttributeNode(attributeName);
179+
var attribute = ((Element) node).getAttributeNode(attributeName);
192180
return (attribute != null && attributeValue.equals(attribute.getValue()));
193181
}
194182

195183
private void updateNodeAttributes(Node node, Predicate<Attr> attributePredicate, UnaryOperator<String> valueFunction) {
196-
Objects.requireNonNull(node);
197-
Objects.requireNonNull(attributePredicate);
198-
Objects.requireNonNull(valueFunction);
199-
for (int i = 0; i < node.getAttributes().getLength(); i++) {
200-
Attr attribute = (Attr) node.getAttributes().item(i);
184+
for (var i = 0; i < node.getAttributes().getLength(); i++) {
185+
var attribute = (Attr) node.getAttributes().item(i);
201186
if (attributePredicate.test(attribute)) {
202187
attribute.setValue(valueFunction.apply(attribute.getValue()));
203188
}
204189
}
205190
}
206191

207192
private Transformer createTransformer() {
208-
TransformerFactory transformerFactory = TransformerFactory.newInstance();
209-
// Protect the Transformer from XXE attacks.
210-
transformerFactory.setAttribute(XMLConstants.ACCESS_EXTERNAL_DTD, "");
211-
transformerFactory.setAttribute(XMLConstants.ACCESS_EXTERNAL_STYLESHEET, "");
212-
Transformer transformer;
193+
var transformerFactory = GenericJaxbIO.createTransformerFactory();
213194
try {
214-
transformer = transformerFactory.newTransformer();
195+
var transformer = transformerFactory.newTransformer();
196+
transformer.setOutputProperty(OutputKeys.INDENT, "yes");
197+
transformer.setOutputProperty(OutputKeys.STANDALONE, "yes");
198+
transformer.setOutputProperty(OutputKeys.ENCODING, "utf-8");
199+
transformer.setOutputProperty("{http://xml.apache.org/xslt}indent-amount", "2");
200+
return transformer;
215201
} catch (TransformerConfigurationException e) {
216-
throw new IllegalArgumentException("Failed to create a " + Transformer.class.getName() + ".", e);
202+
throw new IllegalArgumentException("Failed to create a %s."
203+
.formatted(Transformer.class.getSimpleName()), e);
217204
}
218-
transformer.setOutputProperty(OutputKeys.INDENT, "yes");
219-
transformer.setOutputProperty(OutputKeys.STANDALONE, "yes");
220-
transformer.setOutputProperty(OutputKeys.ENCODING, "utf-8");
221-
transformer.setOutputProperty("{http://xml.apache.org/xslt}indent-amount", "2");
222-
return transformer;
223205
}
224206
}

benchmark/src/test/java/ai/timefold/solver/benchmark/config/PlannerBenchmarkConfigTest.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import ai.timefold.solver.core.testdomain.TestdataSolution;
1717
import ai.timefold.solver.jackson.impl.domain.solution.JacksonSolutionFileIO;
1818
import ai.timefold.solver.persistence.common.api.domain.solution.RigidTestdataSolutionFileIO;
19-
import ai.timefold.solver.persistence.common.api.domain.solution.SolutionFileIO;
2019

2120
import org.apache.commons.io.IOUtils;
2221
import org.junit.jupiter.api.Test;
@@ -32,7 +31,7 @@ class PlannerBenchmarkConfigTest {
3231
@ParameterizedTest
3332
@ValueSource(strings = { TEST_PLANNER_BENCHMARK_CONFIG_WITHOUT_NAMESPACE, TEST_PLANNER_BENCHMARK_CONFIG_WITH_NAMESPACE })
3433
void xmlConfigFileRemainsSameAfterReadWrite(String xmlBenchmarkConfigResource) throws IOException {
35-
PlannerBenchmarkConfigIO xmlIO = new PlannerBenchmarkConfigIO();
34+
var xmlIO = new PlannerBenchmarkConfigIO();
3635
PlannerBenchmarkConfig jaxbBenchmarkConfig;
3736

3837
try (Reader reader = new InputStreamReader(
@@ -44,13 +43,13 @@ void xmlConfigFileRemainsSameAfterReadWrite(String xmlBenchmarkConfigResource) t
4443

4544
Writer stringWriter = new StringWriter();
4645
xmlIO.write(jaxbBenchmarkConfig, stringWriter);
47-
String jaxbString = stringWriter.toString();
46+
var jaxbString = stringWriter.toString();
4847

49-
String originalXml = IOUtils.toString(PlannerBenchmarkConfigTest.class.getResourceAsStream(xmlBenchmarkConfigResource),
48+
var originalXml = IOUtils.toString(PlannerBenchmarkConfigTest.class.getResourceAsStream(xmlBenchmarkConfigResource),
5049
StandardCharsets.UTF_8);
5150

5251
// During writing the benchmark config, the benchmark element's namespace is removed.
53-
String benchmarkElementWithNamespace =
52+
var benchmarkElementWithNamespace =
5453
PlannerBenchmarkConfig.XML_ELEMENT_NAME + " xmlns=\"" + PlannerBenchmarkConfig.XML_NAMESPACE + "\"";
5554
if (originalXml.contains(benchmarkElementWithNamespace)) {
5655
originalXml = originalXml.replace(benchmarkElementWithNamespace, PlannerBenchmarkConfig.XML_ELEMENT_NAME);
@@ -60,8 +59,8 @@ void xmlConfigFileRemainsSameAfterReadWrite(String xmlBenchmarkConfigResource) t
6059

6160
@Test
6261
void readAndValidateInvalidBenchmarkConfig_failsIndicatingTheIssue() {
63-
PlannerBenchmarkConfigIO xmlIO = new PlannerBenchmarkConfigIO();
64-
String benchmarkConfigXml = "<plannerBenchmark xmlns=\"https://timefold.ai/xsd/benchmark\">\n"
62+
var xmlIO = new PlannerBenchmarkConfigIO();
63+
var benchmarkConfigXml = "<plannerBenchmark xmlns=\"https://timefold.ai/xsd/benchmark\">\n"
6564
+ " <benchmarkDirectory>data</benchmarkDirectory>\n"
6665
+ " <parallelBenchmarkCount>AUTO</parallelBenchmarkCount>\n"
6766
+ " <solverBenchmark>\n"
@@ -78,19 +77,20 @@ void readAndValidateInvalidBenchmarkConfig_failsIndicatingTheIssue() {
7877
+ " </solverBenchmark>\n"
7978
+ "</plannerBenchmark>\n";
8079

81-
StringReader stringReader = new StringReader(benchmarkConfigXml);
80+
var stringReader = new StringReader(benchmarkConfigXml);
8281
assertThatExceptionOfType(TimefoldXmlSerializationException.class)
8382
.isThrownBy(() -> xmlIO.read(stringReader))
84-
.withRootCauseExactlyInstanceOf(SAXParseException.class)
83+
.havingRootCause()
84+
.isInstanceOf(SAXParseException.class)
8585
.withMessageContaining("solutionKlazz");
8686
}
8787

8888
@Test
8989
public void assignCustomSolutionIO() {
90-
ProblemBenchmarksConfig pbc = new ProblemBenchmarksConfig();
90+
var pbc = new ProblemBenchmarksConfig();
9191
pbc.setSolutionFileIOClass(RigidTestdataSolutionFileIO.class);
9292

93-
Class<? extends SolutionFileIO<?>> configured = pbc.getSolutionFileIOClass();
93+
var configured = pbc.getSolutionFileIOClass();
9494
assertThat(configured).isNotNull();
9595
}
9696

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,5 @@
11
package ai.timefold.solver.core.impl.io.jaxb;
22

3-
public class ElementNamespaceOverride {
3+
public record ElementNamespaceOverride(String elementLocalName, String namespaceOverride) {
44

5-
public static ElementNamespaceOverride of(String elementLocalName, String namespaceOverride) {
6-
return new ElementNamespaceOverride(elementLocalName, namespaceOverride);
7-
}
8-
9-
private final String elementLocalName;
10-
private final String namespaceOverride;
11-
12-
private ElementNamespaceOverride(String elementLocalName, String namespaceOverride) {
13-
this.elementLocalName = elementLocalName;
14-
this.namespaceOverride = namespaceOverride;
15-
}
16-
17-
public String getElementLocalName() {
18-
return elementLocalName;
19-
}
20-
21-
public String getNamespaceOverride() {
22-
return namespaceOverride;
23-
}
24-
25-
@Override
26-
public String toString() {
27-
return "ElementNamespaceOverride{" +
28-
"elementLocalName='" + elementLocalName + '\'' +
29-
", namespaceOverride='" + namespaceOverride + '\'' +
30-
'}';
31-
}
325
}

0 commit comments

Comments
 (0)