Skip to content

Commit 56371a4

Browse files
committed
[CELEBORN-2309] Introduce JavaDeserializerFilter to prevent deserialization attacks of CWE-502
1 parent a56f69a commit 56371a4

4 files changed

Lines changed: 554 additions & 14 deletions

File tree

Lines changed: 325 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,325 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.celeborn.common.serializer;
19+
20+
import java.io.IOException;
21+
import java.io.InputStream;
22+
import java.io.InvalidClassException;
23+
import java.io.ObjectInputStream;
24+
import java.io.ObjectStreamClass;
25+
import java.lang.reflect.Method;
26+
import java.lang.reflect.Proxy;
27+
28+
import org.slf4j.Logger;
29+
import org.slf4j.LoggerFactory;
30+
31+
/**
32+
* Allowlist-based deserialization filter to prevent CWE-502 (Deserialization of Untrusted Data)
33+
* attacks on Celeborn's internal RPC channel.
34+
*
35+
* <p>Provides dual-layer defense:
36+
*
37+
* <ul>
38+
* <li><b>resolveClass allowlist</b> — enforced on all JDK versions via {@link
39+
* #createValidatingInputStream}. Blocks any class whose name does not match an allowed
40+
* package prefix.
41+
* <li><b>JVM-level ObjectInputFilter</b> — enforced on JDK 9+ via {@link #apply}. Adds resource
42+
* limits (maxdepth, maxarray, maxrefs, maxbytes) and logs rejected classes. Accessed through
43+
* reflection to maintain JDK 8 compatibility; gracefully degrades to no-op on older JDKs.
44+
* </ul>
45+
*
46+
* @see <a href="https://cwe.mitre.org/data/definitions/502.html">CWE-502</a>
47+
*/
48+
public class JavaDeserializerFilter {
49+
private static final Logger LOG = LoggerFactory.getLogger(JavaDeserializerFilter.class);
50+
51+
private static final String[] DEFAULT_ALLOWED_PACKAGES = {
52+
"java.", "scala.", "org.apache.celeborn.", "com.google.protobuf.", "["
53+
};
54+
55+
// JDK 9+ ObjectInputFilter API handles, resolved via reflection for JDK 8 compatibility.
56+
// All fields are null when running on JDK 8.
57+
private static final Method SET_FILTER_METHOD; // ObjectInputStream.setObjectInputFilter
58+
private static final Method CREATE_FILTER_METHOD; // ObjectInputFilter.Config.createFilter
59+
private static final Object STATUS_REJECTED; // ObjectInputFilter.Status.REJECTED
60+
private static final Object STATUS_ALLOWED; // ObjectInputFilter.Status.ALLOWED
61+
private static final Method CHECK_METHOD; // ObjectInputFilter.checkInput
62+
private static final Method SERIAL_CLASS_METHOD; // ObjectInputFilter.FilterInfo.serialClass
63+
private static final Class<?>[] FILTER_CLASSES; // [ObjectInputFilter.class] for Proxy
64+
private static final ClassLoader FILTER_CLASS_LOADER;
65+
66+
// All-or-nothing initialization: if any lookup fails, all handles remain null (JDK 8 path).
67+
static {
68+
Method setMethod;
69+
Method createMethod;
70+
Object rejected;
71+
Object allowed;
72+
Method checkMtd;
73+
Method serialClassMtd;
74+
Class<?>[] filterClasses;
75+
ClassLoader filterCl;
76+
try {
77+
Class<?> filterCls = Class.forName("java.io.ObjectInputFilter");
78+
Class<?> filterInfoCls = Class.forName("java.io.ObjectInputFilter$FilterInfo");
79+
setMethod = ObjectInputStream.class.getMethod("setObjectInputFilter", filterCls);
80+
createMethod =
81+
Class.forName("java.io.ObjectInputFilter$Config").getMethod("createFilter", String.class);
82+
Class<?> statusCls = Class.forName("java.io.ObjectInputFilter$Status");
83+
rejected = statusCls.getField("REJECTED").get(null);
84+
allowed = statusCls.getField("ALLOWED").get(null);
85+
checkMtd = filterCls.getMethod("checkInput", filterInfoCls);
86+
serialClassMtd = filterInfoCls.getMethod("serialClass");
87+
filterClasses = new Class<?>[] {filterCls};
88+
filterCl = filterCls.getClassLoader();
89+
if (filterCl == null) {
90+
filterCl = JavaDeserializerFilter.class.getClassLoader();
91+
}
92+
} catch (Exception ignored) {
93+
// JDK 8: ObjectInputFilter does not exist — force all handles to null.
94+
setMethod = null;
95+
createMethod = null;
96+
rejected = null;
97+
allowed = null;
98+
checkMtd = null;
99+
serialClassMtd = null;
100+
filterClasses = null;
101+
filterCl = null;
102+
}
103+
SET_FILTER_METHOD = setMethod;
104+
CREATE_FILTER_METHOD = createMethod;
105+
STATUS_REJECTED = rejected;
106+
STATUS_ALLOWED = allowed;
107+
CHECK_METHOD = checkMtd;
108+
SERIAL_CLASS_METHOD = serialClassMtd;
109+
FILTER_CLASSES = filterClasses;
110+
FILTER_CLASS_LOADER = filterCl;
111+
}
112+
113+
private final String[] allowedPackages;
114+
/**
115+
* Cached JDK 9+ ObjectInputFilter proxy that delegates to the base filter and logs rejections.
116+
*/
117+
private final Object loggingFilter;
118+
119+
private JavaDeserializerFilter(
120+
String[] allowedPackages,
121+
int maxDepth,
122+
int maxArrayLength,
123+
long maxReferences,
124+
long maxStreamBytes) {
125+
if (allowedPackages == null || allowedPackages.length == 0) {
126+
throw new IllegalArgumentException("allowedPackages must not be null or empty");
127+
}
128+
for (String allowedPackage : allowedPackages) {
129+
if (allowedPackage == null || allowedPackage.isEmpty()) {
130+
throw new IllegalArgumentException("allowedPackages entry must not be null or empty");
131+
}
132+
}
133+
this.allowedPackages = allowedPackages.clone();
134+
this.loggingFilter =
135+
CREATE_FILTER_METHOD != null
136+
? createLoggingFilter(
137+
buildFilterPattern(
138+
allowedPackages, maxDepth, maxArrayLength, maxReferences, maxStreamBytes))
139+
: null;
140+
}
141+
142+
/** Wraps the JDK base filter in a Proxy that logs REJECTED classes at WARN level. */
143+
private Object createLoggingFilter(String filterPattern) {
144+
try {
145+
Object baseFilter = CREATE_FILTER_METHOD.invoke(null, filterPattern);
146+
Object proxy =
147+
Proxy.newProxyInstance(
148+
FILTER_CLASS_LOADER,
149+
FILTER_CLASSES,
150+
(Object p, Method method, Object[] args) -> {
151+
Object result = CHECK_METHOD.invoke(baseFilter, args);
152+
if (STATUS_REJECTED.equals(result)) {
153+
try {
154+
Class<?> serialClass = (Class<?>) SERIAL_CLASS_METHOD.invoke(args[0]);
155+
if (serialClass != null) {
156+
// JDK ObjectInputFilter unwraps arrays to their primitive component type,
157+
// which then gets rejected by "!*" since primitives match no pattern.
158+
// Override REJECTED for arrays whose base component type is primitive
159+
// or whose component class name is in the allowlist.
160+
if (serialClass.isArray()) {
161+
Class<?> component = serialClass;
162+
while (component.isArray()) {
163+
component = component.getComponentType();
164+
}
165+
if (component.isPrimitive() || isAllowed(component.getName())) {
166+
return STATUS_ALLOWED;
167+
}
168+
}
169+
LOG.error("ObjectInputFilter REJECTED class: {}.", serialClass.getName());
170+
}
171+
} catch (Exception exception) {
172+
LOG.error("Error logging rejected class.", exception);
173+
}
174+
}
175+
return result;
176+
});
177+
LOG.debug("Created deserialization filter with pattern: {}.", filterPattern);
178+
return proxy;
179+
} catch (Exception exception) {
180+
LOG.error("Failed to create deserialization filter.", exception);
181+
return null;
182+
}
183+
}
184+
185+
private boolean isAllowed(String className) {
186+
for (String allowedPackage : allowedPackages) {
187+
if (className.startsWith(allowedPackage)) {
188+
return true;
189+
}
190+
}
191+
return false;
192+
}
193+
194+
/**
195+
* Builds a JDK ObjectInputFilter pattern string. Uses "**" for package prefixes to match
196+
* subpackages at any depth, and ends with "!*" to reject everything not explicitly allowed. A
197+
* limit value of 0 means "no limit" per the JDK ObjectInputFilter specification.
198+
*/
199+
private static String buildFilterPattern(
200+
String[] allowedPackages,
201+
int maxDepth,
202+
int maxArrayLength,
203+
long maxReferences,
204+
long maxStreamBytes) {
205+
StringBuilder pattern = new StringBuilder();
206+
pattern.append("maxdepth=").append(maxDepth);
207+
pattern.append(";maxarray=").append(maxArrayLength);
208+
pattern.append(";maxrefs=").append(maxReferences);
209+
pattern.append(";maxbytes=").append(maxStreamBytes).append(';');
210+
for (String allowedPackage : allowedPackages) {
211+
// Skip "[" — array types are handled by the logging filter proxy which
212+
// checks component types. The JDK filter unwraps arrays to their primitive
213+
// component type, making pattern-based matching ineffective for arrays.
214+
if (allowedPackage.equals("[")) {
215+
continue;
216+
}
217+
pattern.append(allowedPackage).append("**;");
218+
}
219+
pattern.append("!*");
220+
return pattern.toString();
221+
}
222+
223+
/** Creates a filter with custom allowed packages and resource limits. */
224+
public static JavaDeserializerFilter create(
225+
String[] allowedPackages,
226+
int maxDepth,
227+
int maxArrayLength,
228+
long maxReferences,
229+
long maxStreamBytes) {
230+
return new JavaDeserializerFilter(
231+
allowedPackages, maxDepth, maxArrayLength, maxReferences, maxStreamBytes);
232+
}
233+
234+
/** Creates a filter with default allowed packages and no resource limits. */
235+
public static JavaDeserializerFilter create() {
236+
return new JavaDeserializerFilter(DEFAULT_ALLOWED_PACKAGES, 0, 0, 0, 0);
237+
}
238+
239+
/** Applies the JDK 9+ ObjectInputFilter to the stream. No-op on JDK 8. */
240+
public void apply(ObjectInputStream inputStream) {
241+
if (loggingFilter == null) {
242+
return;
243+
}
244+
try {
245+
SET_FILTER_METHOD.invoke(inputStream, loggingFilter);
246+
} catch (Exception exception) {
247+
LOG.error("Failed to apply logging filter.", exception);
248+
}
249+
}
250+
251+
/**
252+
* Creates an ObjectInputStream that checks each class against the allowlist before resolving.
253+
*
254+
* @param classLoader the class loader for resolving classes; null uses the bootstrap loader.
255+
*/
256+
public ObjectInputStream createValidatingInputStream(
257+
InputStream inputStream, ClassLoader classLoader) throws IOException {
258+
return new ObjectInputStream(inputStream) {
259+
@Override
260+
protected Class<?> resolveClass(ObjectStreamClass desc)
261+
throws IOException, ClassNotFoundException {
262+
String className = desc.getName();
263+
Class<?> primitive = resolvePrimitiveClass(className);
264+
if (primitive != null) {
265+
return primitive;
266+
}
267+
if (!isClassAllowed(className)) {
268+
LOG.error("REJECTED class during deserialization: {}.", className);
269+
throw new InvalidClassException(className, "Blocked");
270+
}
271+
return Class.forName(className, false, classLoader);
272+
}
273+
274+
@Override
275+
protected Class<?> resolveProxyClass(String[] interfaces)
276+
throws IOException, ClassNotFoundException {
277+
ClassLoader cl =
278+
classLoader != null ? classLoader : Thread.currentThread().getContextClassLoader();
279+
Class<?>[] ifaceClasses = new Class<?>[interfaces.length];
280+
for (int i = 0; i < interfaces.length; i++) {
281+
if (!isClassAllowed(interfaces[i])) {
282+
LOG.error("REJECTED proxy interface during deserialization: {}.", interfaces[i]);
283+
throw new InvalidClassException(interfaces[i], "Blocked proxy interface");
284+
}
285+
ifaceClasses[i] = Class.forName(interfaces[i], false, cl);
286+
}
287+
return Proxy.getProxyClass(cl, ifaceClasses);
288+
}
289+
};
290+
}
291+
292+
private static Class<?> resolvePrimitiveClass(String name) {
293+
switch (name) {
294+
case "boolean":
295+
return boolean.class;
296+
case "byte":
297+
return byte.class;
298+
case "char":
299+
return char.class;
300+
case "short":
301+
return short.class;
302+
case "int":
303+
return int.class;
304+
case "long":
305+
return long.class;
306+
case "float":
307+
return float.class;
308+
case "double":
309+
return double.class;
310+
case "void":
311+
return void.class;
312+
default:
313+
return null;
314+
}
315+
}
316+
317+
public boolean isClassAllowed(String className) {
318+
for (String allowedPackage : allowedPackages) {
319+
if (className.startsWith(allowedPackage)) {
320+
return true;
321+
}
322+
}
323+
return false;
324+
}
325+
}

0 commit comments

Comments
 (0)