Skip to content

Commit 138d8a6

Browse files
committed
Made suggested changes and optimizations
1 parent 43819cd commit 138d8a6

File tree

3 files changed

+91
-71
lines changed

3 files changed

+91
-71
lines changed

api/src/main/java/ai/djl/translate/NoopServingTranslatorFactory.java

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,7 @@ private void initializeCsvTranslator() {
8484
if (csvTranslatorClass == null) {
8585
csvTranslatorClass =
8686
Class.forName("ai.djl.basicdataset.tabular.CsvTranslator");
87-
csvConstructor = csvTranslatorClass.getConstructor(Map.class);
87+
csvConstructor = csvTranslatorClass.getConstructor();
8888
csvProcessInputMethod =
8989
csvTranslatorClass.getMethod(
9090
"processInput", TranslatorContext.class, String.class);
@@ -94,7 +94,7 @@ private void initializeCsvTranslator() {
9494
}
9595
}
9696
}
97-
csvTranslator = csvConstructor.newInstance(Collections.emptyMap());
97+
csvTranslator = csvConstructor.newInstance();
9898
this.csvProcessInput = NoopServingTranslatorFactory.csvProcessInputMethod;
9999
this.csvProcessOutput = NoopServingTranslatorFactory.csvProcessOutputMethod;
100100
} catch (ReflectiveOperationException e) {
@@ -202,8 +202,6 @@ public Output processOutput(TranslatorContext ctx, NDList list) {
202202
return output;
203203
}
204204

205-
// --- CSV helper methods ---
206-
207205
private NDList processCsvInput(TranslatorContext ctx, String csvData)
208206
throws TranslateException {
209207
try {

api/src/test/java/ai/djl/translate/NoopServingTranslatorFactoryTest.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,15 @@ public void testNoopTranslatorFactory() throws ModelException, IOException, Tran
113113
msg.contains("Non-numeric"), "Unexpected exception message: " + msg);
114114
}
115115

116-
// Header row should be skipped
117-
Input headerIn = new Input();
118-
headerIn.addProperty("Content-Type", "text/csv");
119-
headerIn.addProperty("Accept", "application/json");
120-
headerIn.add("feature1,feature2\n1.0,0.1\n2.0,0.2\n");
121-
Output headerOut = predictor.predict(headerIn);
116+
// CSV with header should skip header and return predictions
117+
Input csvWithHeaderIn = new Input();
118+
csvWithHeaderIn.addProperty("Content-Type", "text/csv");
119+
csvWithHeaderIn.addProperty("Accept", "application/json");
120+
csvWithHeaderIn.add("feature1,feature2\n1.0,0.1\n2.0,0.2\n");
121+
Output csvWithHeaderOut = predictor.predict(csvWithHeaderIn);
122122
Assert.assertEquals(
123-
headerOut.getData().getAsString(), "{\"predictions\":[[1.0,0.1],[2.0,0.2]]}");
123+
csvWithHeaderOut.getData().getAsString(),
124+
"{\"predictions\":[[1.0,0.1],[2.0,0.2]]}");
124125

125126
// Empty CSV should fail
126127
Input emptyIn = new Input();

basicdataset/src/main/java/ai/djl/basicdataset/tabular/CsvTranslator.java

Lines changed: 81 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import ai.djl.ndarray.NDArray;
1616
import ai.djl.ndarray.NDList;
17+
import ai.djl.ndarray.NDManager;
18+
import ai.djl.ndarray.types.DataType;
1719
import ai.djl.translate.TranslateException;
1820
import ai.djl.translate.Translator;
1921
import ai.djl.translate.TranslatorContext;
@@ -26,61 +28,64 @@
2628
import java.io.IOException;
2729
import java.io.StringReader;
2830
import java.io.StringWriter;
29-
import java.util.ArrayList;
3031
import java.util.List;
31-
import java.util.Map;
3232

3333
/** A {@link Translator} that converts between CSV text and {@link NDList}. */
3434
public class CsvTranslator implements Translator<String, String> {
3535

3636
private final CSVFormat csvFormat;
3737

38-
/**
39-
* Constructs a CsvTranslator.
40-
*
41-
* @param arguments the arguments (unused but required for reflection)
42-
*/
43-
@SuppressWarnings({"PMD.UnusedFormalParameter", "deprecation"})
44-
public CsvTranslator(Map<String, ?> arguments) {
45-
this.csvFormat = CSVFormat.newFormat(',').withRecordSeparator("\n");
38+
/** Constructs a CsvTranslator. */
39+
public CsvTranslator() {
40+
this.csvFormat = CSVFormat.INFORMIX_UNLOAD_CSV;
4641
}
4742

4843
/** {@inheritDoc} */
4944
@Override
5045
public NDList processInput(TranslatorContext ctx, String csvData) throws TranslateException {
51-
try (CSVParser parser = csvFormat.parse(new StringReader(csvData))) {
52-
List<float[]> rows = new ArrayList<>();
53-
int expectedCols = -1;
54-
boolean headerSkipped = false;
55-
56-
for (CSVRecord record : parser) {
57-
if (!headerSkipped && isHeaderRow(record)) {
58-
headerSkipped = true;
59-
continue;
60-
}
46+
StringReader reader = new StringReader(csvData);
47+
48+
try (CSVParser parser = csvFormat.parse(reader)) {
49+
List<CSVRecord> records = parser.getRecords();
50+
if (records.isEmpty()) {
51+
throw new TranslateException("CSV data is empty");
52+
}
53+
54+
int rowStart = 0;
55+
56+
// Skip header if present
57+
if (isHeaderRow(records.get(0))) {
58+
rowStart = 1;
59+
}
6160

62-
if (expectedCols == -1) {
63-
expectedCols = record.size();
64-
} else if (record.size() != expectedCols) {
61+
int numRows = records.size() - rowStart;
62+
int expectedCols =
63+
records.get(rowStart).size(); // assume first data row sets column count
64+
65+
float[][] data = new float[numRows][expectedCols];
66+
67+
for (int i = 0; i < numRows; i++) {
68+
CSVRecord record = records.get(i + rowStart);
69+
70+
if (record.size() != expectedCols) {
6571
throw new TranslateException(
66-
String.format(
67-
"Row %d has %d columns, expected %d",
68-
record.getRecordNumber(), record.size(), expectedCols));
72+
"Row "
73+
+ record.getRecordNumber()
74+
+ " has "
75+
+ record.size()
76+
+ " columns, expected "
77+
+ expectedCols);
6978
}
7079

71-
float[] row = new float[expectedCols];
7280
for (int j = 0; j < expectedCols; j++) {
73-
row[j] = parseFloat(record.get(j), record.getRecordNumber(), j);
81+
data[i][j] = parseFloat(record.get(j), record.getRecordNumber(), j);
7482
}
75-
rows.add(row);
7683
}
7784

78-
if (rows.isEmpty()) {
79-
throw new TranslateException("CSV data is empty");
80-
}
85+
NDManager manager = ctx.getNDManager();
86+
NDArray ndArray = manager.create(data);
87+
return new NDList(ndArray);
8188

82-
float[][] data = rows.toArray(new float[0][]);
83-
return new NDList(ctx.getNDManager().create(data));
8489
} catch (IOException e) {
8590
throw new TranslateException("Failed to process CSV input", e);
8691
}
@@ -99,49 +104,65 @@ private boolean isNumeric(String str) {
99104
if (str == null || str.isEmpty()) {
100105
return false;
101106
}
102-
try {
103-
Float.parseFloat(str);
104-
return true;
105-
} catch (NumberFormatException e) {
106-
return false;
107+
int len = str.length();
108+
for (int i = 0; i < len; i++) {
109+
char c = str.charAt(i);
110+
if ((c < '0' || c > '9') && c != '-' && c != '.' && c != 'e' && c != 'E' && c != '+') {
111+
return false;
112+
}
107113
}
114+
return true;
108115
}
109116

110117
private float parseFloat(String value, long row, int col) throws TranslateException {
118+
if (value == null || value.isEmpty()) {
119+
return Float.NaN;
120+
}
121+
122+
int len = value.length();
123+
if (len > 0 && (value.charAt(0) <= ' ' || value.charAt(len - 1) <= ' ')) {
124+
value = value.trim();
125+
if (value.isEmpty()) {
126+
return Float.NaN;
127+
}
128+
}
129+
111130
try {
112-
return Float.parseFloat(value.trim());
131+
return Float.parseFloat(value);
113132
} catch (NumberFormatException e) {
114133
throw new TranslateException(
115-
String.format("Non-numeric value '%s' at row %d, column %d", value, row, col),
116-
e);
134+
"Non-numeric value '" + value + "' at row " + row + ", column " + col, e);
117135
}
118136
}
119137

120138
/** {@inheritDoc} */
121139
@Override
122140
public String processOutput(TranslatorContext ctx, NDList list) throws TranslateException {
141+
NDArray array = list.singletonOrThrow();
142+
123143
try (StringWriter writer = new StringWriter();
124144
CSVPrinter printer = new CSVPrinter(writer, csvFormat)) {
125145

126-
for (NDArray array : list) {
127-
float[] data =
128-
array.toType(ai.djl.ndarray.types.DataType.FLOAT32, false).toFloatArray();
129-
long[] shape = array.getShape().getShape();
130-
131-
if (shape.length == 1) {
132-
// 1D array → single row
133-
printRow(printer, data, 0, data.length);
134-
} else if (shape.length == 2) {
135-
int rows = (int) shape[0];
136-
int cols = (int) shape[1];
137-
for (int i = 0; i < rows; i++) {
138-
printRow(printer, data, i * cols, cols);
139-
}
140-
} else {
141-
throw new TranslateException(
142-
"Only 1D or 2D arrays can be converted to CSV, found shape: "
143-
+ array.getShape());
146+
// If the array has a non-empty name, write it as a leading header row
147+
if (array.getName() != null && !array.getName().isEmpty()) {
148+
printer.print(array.getName());
149+
printer.println();
150+
}
151+
152+
float[] data = array.toType(DataType.FLOAT32, false).toFloatArray();
153+
long[] shape = array.getShape().getShape();
154+
155+
if (shape.length == 1) {
156+
printRow(printer, data, 0, data.length);
157+
} else if (shape.length == 2) {
158+
int rows = (int) shape[0];
159+
int cols = (int) shape[1];
160+
for (int i = 0; i < rows; i++) {
161+
printRow(printer, data, i * cols, cols);
144162
}
163+
} else {
164+
throw new TranslateException(
165+
"Only 1D or 2D arrays supported, found shape: " + array.getShape());
145166
}
146167

147168
return writer.toString();

0 commit comments

Comments
 (0)