Skip to content

Commit 43819cd

Browse files
committed
Add CSV translator
1 parent e7627a5 commit 43819cd

File tree

5 files changed

+319
-1
lines changed

5 files changed

+319
-1
lines changed

api/build.gradle.kts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ dependencies {
1313

1414
testImplementation(project(":testing"))
1515
testImplementation(libs.testng)
16+
runtimeOnly(project(":basicdataset"))
17+
testImplementation(project(":basicdataset"))
1618
testImplementation(libs.slf4j.simple)
1719
testRuntimeOnly(project(":engines:pytorch:pytorch-model-zoo"))
1820
testRuntimeOnly(project(":engines:pytorch:pytorch-jni"))

api/src/main/java/ai/djl/translate/NoopServingTranslatorFactory.java

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import com.google.gson.JsonElement;
2626
import com.google.gson.JsonObject;
2727

28+
import java.lang.reflect.Constructor;
29+
import java.lang.reflect.Method;
2830
import java.lang.reflect.Type;
2931
import java.util.ArrayList;
3032
import java.util.Arrays;
@@ -38,6 +40,12 @@
3840
/** A {@link TranslatorFactory} that creates a {@code RawTranslator} instance. */
3941
public class NoopServingTranslatorFactory implements TranslatorFactory {
4042

43+
private static final Object LOCK = new Object();
44+
private static Class<?> csvTranslatorClass;
45+
private static Constructor<?> csvConstructor;
46+
private static Method csvProcessInputMethod;
47+
private static Method csvProcessOutputMethod;
48+
4149
/** {@inheritDoc} */
4250
@Override
4351
public Set<Pair<Type, Type>> getSupportedTypes() {
@@ -59,9 +67,40 @@ public <I, O> Translator<I, O> newInstance(
5967
static final class NoopServingTranslator implements Translator<Input, Output> {
6068

6169
private Batchifier batchifier;
70+
private Object csvTranslator;
71+
private Method csvProcessInput;
72+
private Method csvProcessOutput;
6273

6374
NoopServingTranslator(Batchifier batchifier) {
6475
this.batchifier = batchifier;
76+
initializeCsvTranslator();
77+
}
78+
79+
private void initializeCsvTranslator() {
80+
try {
81+
// Use cached reflection objects if available
82+
if (csvTranslatorClass == null) {
83+
synchronized (LOCK) {
84+
if (csvTranslatorClass == null) {
85+
csvTranslatorClass =
86+
Class.forName("ai.djl.basicdataset.tabular.CsvTranslator");
87+
csvConstructor = csvTranslatorClass.getConstructor(Map.class);
88+
csvProcessInputMethod =
89+
csvTranslatorClass.getMethod(
90+
"processInput", TranslatorContext.class, String.class);
91+
csvProcessOutputMethod =
92+
csvTranslatorClass.getMethod(
93+
"processOutput", TranslatorContext.class, NDList.class);
94+
}
95+
}
96+
}
97+
csvTranslator = csvConstructor.newInstance(Collections.emptyMap());
98+
this.csvProcessInput = NoopServingTranslatorFactory.csvProcessInputMethod;
99+
this.csvProcessOutput = NoopServingTranslatorFactory.csvProcessOutputMethod;
100+
} catch (ReflectiveOperationException e) {
101+
// CSV translator not available - silently continue without CSV support
102+
csvTranslator = null;
103+
}
65104
}
66105

67106
/** {@inheritDoc} */
@@ -82,7 +121,7 @@ public NDList processInput(TranslatorContext ctx, Input input) throws TranslateE
82121
if (pos > 0) {
83122
contentType = contentType.substring(0, pos);
84123
}
85-
if ("application/json".equals(contentType)) {
124+
if ("application/json".equalsIgnoreCase(contentType)) {
86125
String data = input.getData().getAsString();
87126
JsonElement element = JsonUtils.GSON.fromJson(data, JsonElement.class);
88127
if (element.isJsonObject()) {
@@ -97,6 +136,12 @@ public NDList processInput(TranslatorContext ctx, Input input) throws TranslateE
97136
} else {
98137
throw new TranslateException("Input is not a supported json format");
99138
}
139+
} else if ("text/csv".equalsIgnoreCase(contentType)) {
140+
if (csvTranslator == null) {
141+
throw new TranslateException(
142+
"CSV support not available. Add basicdataset dependency.");
143+
}
144+
return processCsvInput(ctx, input.getData().getAsString());
100145
}
101146
}
102147

@@ -127,6 +172,14 @@ public Output processOutput(TranslatorContext ctx, NDList list) {
127172
|| "tensor/safetensors".equalsIgnoreCase(contentType)) {
128173
output.add(list.encode(NDList.Encoding.SAFETENSORS));
129174
output.addProperty("Content-Type", "tensor/safetensors");
175+
} else if ("text/csv".equalsIgnoreCase(accept)) {
176+
if (csvTranslator == null) {
177+
throw new IllegalArgumentException(
178+
"CSV support not available. Add basicdataset dependency.");
179+
}
180+
String csvOutput = processCsvOutput(ctx, list);
181+
output.add(csvOutput);
182+
output.addProperty("Content-Type", "text/csv");
130183
} else if ("application/json".equalsIgnoreCase(accept)
131184
|| "application/json".equalsIgnoreCase(contentType)) {
132185
List<Object> ret;
@@ -141,13 +194,39 @@ public Output processOutput(TranslatorContext ctx, NDList list) {
141194
Map<String, List<Object>> map = new ConcurrentHashMap<>();
142195
map.put("predictions", ret);
143196
output.add("predictions", BytesSupplier.wrapAsJson(map));
197+
144198
} else {
145199
output.add(list.encode());
146200
output.addProperty("Content-Type", "tensor/ndlist");
147201
}
148202
return output;
149203
}
150204

205+
// --- CSV helper methods ---
206+
207+
private NDList processCsvInput(TranslatorContext ctx, String csvData)
208+
throws TranslateException {
209+
try {
210+
return (NDList) csvProcessInput.invoke(csvTranslator, ctx, csvData);
211+
} catch (ReflectiveOperationException e) {
212+
Throwable cause = e.getCause();
213+
if (cause instanceof TranslateException) {
214+
TranslateException te = (TranslateException) cause;
215+
te.addSuppressed(e);
216+
throw te;
217+
}
218+
throw new TranslateException("Failed to process CSV input", e);
219+
}
220+
}
221+
222+
private String processCsvOutput(TranslatorContext ctx, NDList list) {
223+
try {
224+
return (String) csvProcessOutput.invoke(csvTranslator, ctx, list);
225+
} catch (ReflectiveOperationException e) {
226+
throw new IllegalStateException("Failed to process CSV output", e);
227+
}
228+
}
229+
151230
private NDList toNDList(NDManager manager, JsonElement element) {
152231
JsonElement e = element.getAsJsonArray().get(0);
153232
if (e.isJsonArray()) {

api/src/test/java/ai/djl/translate/NoopServingTranslatorFactoryTest.java

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ public void testNoopTranslatorFactory() throws ModelException, IOException, Tran
5151

5252
try (ZooModel<Input, Output> model = criteria.loadModel();
5353
Predictor<Input, Output> predictor = model.newPredictor()) {
54+
5455
Input in = new Input();
5556
in.addProperty("Content-Type", "application/json; charset=UTF-8");
5657
Map<String, List<List<Number>>> data = new ConcurrentHashMap<>();
@@ -62,6 +63,81 @@ public void testNoopTranslatorFactory() throws ModelException, IOException, Tran
6263
Output out = predictor.predict(in);
6364
BytesSupplier outData = out.getData();
6465
Assert.assertEquals(outData.getAsString(), "{\"predictions\":[[1.0,0.1],[2.0,0.2]]}");
66+
67+
// CSV input / JSON output
68+
Input csvJsonIn = new Input();
69+
csvJsonIn.addProperty("Content-Type", "text/csv");
70+
csvJsonIn.addProperty("Accept", "application/json");
71+
csvJsonIn.add("1.0,0.1\n2.0,0.2\n");
72+
Output csvJsonOut = predictor.predict(csvJsonIn);
73+
Assert.assertEquals(
74+
csvJsonOut.getData().getAsString(), "{\"predictions\":[[1.0,0.1],[2.0,0.2]]}");
75+
76+
// CSV input / CSV output
77+
Input csvCsvIn = new Input();
78+
csvCsvIn.addProperty("Content-Type", "text/csv");
79+
csvCsvIn.addProperty("Accept", "text/csv");
80+
csvCsvIn.add("1.0,0.1\n2.0,0.2\n");
81+
Output csvCsvOut = predictor.predict(csvCsvIn);
82+
Assert.assertEquals(csvCsvOut.getData().getAsString(), "1.0,0.1\n2.0,0.2\n");
83+
84+
// Uneven rows should fail
85+
Input unevenRowsIn = new Input();
86+
unevenRowsIn.addProperty("Content-Type", "text/csv");
87+
unevenRowsIn.add("1.0,0.1\n2.0,0.2,0.3\n");
88+
try {
89+
predictor.predict(unevenRowsIn);
90+
Assert.fail("Should have thrown exception for uneven rows");
91+
} catch (Exception e) {
92+
String msg = e.getMessage();
93+
if (e.getCause() != null) {
94+
msg += " | cause: " + e.getCause().getMessage();
95+
}
96+
Assert.assertTrue(
97+
msg.contains("columns, expected"), "Unexpected exception message: " + msg);
98+
}
99+
100+
// Non-numeric should fail
101+
Input nonNumericIn = new Input();
102+
nonNumericIn.addProperty("Content-Type", "text/csv");
103+
nonNumericIn.add("1.0,hello\n2.0,0.2\n");
104+
try {
105+
predictor.predict(nonNumericIn);
106+
Assert.fail("Should have thrown exception for non-numeric data");
107+
} catch (Exception e) {
108+
String msg = e.getMessage();
109+
if (e.getCause() != null) {
110+
msg += " | cause: " + e.getCause().getMessage();
111+
}
112+
Assert.assertTrue(
113+
msg.contains("Non-numeric"), "Unexpected exception message: " + msg);
114+
}
115+
116+
// Header row should be skipped
117+
Input headerIn = new Input();
118+
headerIn.addProperty("Content-Type", "text/csv");
119+
headerIn.addProperty("Accept", "application/json");
120+
headerIn.add("feature1,feature2\n1.0,0.1\n2.0,0.2\n");
121+
Output headerOut = predictor.predict(headerIn);
122+
Assert.assertEquals(
123+
headerOut.getData().getAsString(), "{\"predictions\":[[1.0,0.1],[2.0,0.2]]}");
124+
125+
// Empty CSV should fail
126+
Input emptyIn = new Input();
127+
emptyIn.addProperty("Content-Type", "text/csv");
128+
emptyIn.add("");
129+
try {
130+
predictor.predict(emptyIn);
131+
Assert.fail("Should have thrown exception for empty CSV");
132+
} catch (Exception e) {
133+
String msg = e.getMessage();
134+
if (e.getCause() != null) {
135+
msg += " | cause: " + e.getCause().getMessage();
136+
}
137+
Assert.assertTrue(
138+
msg.toLowerCase().contains("csv") && msg.toLowerCase().contains("empty"),
139+
"Unexpected exception message: " + msg);
140+
}
65141
}
66142
}
67143
}

basicdataset/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ plugins {
88
dependencies {
99
api(project(":api"))
1010
api(libs.apache.commons.csv)
11+
compileOnly("com.github.spotbugs:spotbugs-annotations:4.7.3")
1112

1213
// Add following dependency to your project for COCO dataset
1314
// runtimeOnly(libs.twelvemonkeys.imageio)

0 commit comments

Comments
 (0)