1414
1515import ai .djl .ndarray .NDArray ;
1616import ai .djl .ndarray .NDList ;
17+ import ai .djl .ndarray .NDManager ;
18+ import ai .djl .ndarray .types .DataType ;
1719import ai .djl .translate .TranslateException ;
1820import ai .djl .translate .Translator ;
1921import ai .djl .translate .TranslatorContext ;
2628import java .io .IOException ;
2729import java .io .StringReader ;
2830import java .io .StringWriter ;
29- import java .util .ArrayList ;
3031import java .util .List ;
31- import java .util .Map ;
3232
3333/** A {@link Translator} that converts between CSV text and {@link NDList}. */
3434public class CsvTranslator implements Translator <String , String > {
3535
3636 private final CSVFormat csvFormat ;
3737
38- /**
39- * Constructs a CsvTranslator.
40- *
41- * @param arguments the arguments (unused but required for reflection)
42- */
43- @ SuppressWarnings ({"PMD.UnusedFormalParameter" , "deprecation" })
44- public CsvTranslator (Map <String , ?> arguments ) {
45- this .csvFormat = CSVFormat .newFormat (',' ).withRecordSeparator ("\n " );
38+ /** Constructs a CsvTranslator. */
39+ public CsvTranslator () {
40+ this .csvFormat = CSVFormat .INFORMIX_UNLOAD_CSV ;
4641 }
4742
4843 /** {@inheritDoc} */
4944 @ Override
5045 public NDList processInput (TranslatorContext ctx , String csvData ) throws TranslateException {
51- try (CSVParser parser = csvFormat .parse (new StringReader (csvData ))) {
52- List <float []> rows = new ArrayList <>();
53- int expectedCols = -1 ;
54- boolean headerSkipped = false ;
55-
56- for (CSVRecord record : parser ) {
57- if (!headerSkipped && isHeaderRow (record )) {
58- headerSkipped = true ;
59- continue ;
60- }
46+ StringReader reader = new StringReader (csvData );
47+
48+ try (CSVParser parser = csvFormat .parse (reader )) {
49+ List <CSVRecord > records = parser .getRecords ();
50+ if (records .isEmpty ()) {
51+ throw new TranslateException ("CSV data is empty" );
52+ }
53+
54+ int rowStart = 0 ;
55+
56+ // Skip header if present
57+ if (isHeaderRow (records .get (0 ))) {
58+ rowStart = 1 ;
59+ }
6160
62- if (expectedCols == -1 ) {
63- expectedCols = record .size ();
64- } else if (record .size () != expectedCols ) {
61+ int numRows = records .size () - rowStart ;
62+ int expectedCols =
63+ records .get (rowStart ).size (); // assume first data row sets column count
64+
65+ float [][] data = new float [numRows ][expectedCols ];
66+
67+ for (int i = 0 ; i < numRows ; i ++) {
68+ CSVRecord record = records .get (i + rowStart );
69+
70+ if (record .size () != expectedCols ) {
6571 throw new TranslateException (
66- String .format (
67- "Row %d has %d columns, expected %d" ,
68- record .getRecordNumber (), record .size (), expectedCols ));
72+ "Row "
73+ + record .getRecordNumber ()
74+ + " has "
75+ + record .size ()
76+ + " columns, expected "
77+ + expectedCols );
6978 }
7079
71- float [] row = new float [expectedCols ];
7280 for (int j = 0 ; j < expectedCols ; j ++) {
73- row [j ] = parseFloat (record .get (j ), record .getRecordNumber (), j );
81+ data [ i ] [j ] = parseFloat (record .get (j ), record .getRecordNumber (), j );
7482 }
75- rows .add (row );
7683 }
7784
78- if ( rows . isEmpty ()) {
79- throw new TranslateException ( "CSV data is empty" );
80- }
85+ NDManager manager = ctx . getNDManager ();
86+ NDArray ndArray = manager . create ( data );
87+ return new NDList ( ndArray );
8188
82- float [][] data = rows .toArray (new float [0 ][]);
83- return new NDList (ctx .getNDManager ().create (data ));
8489 } catch (IOException e ) {
8590 throw new TranslateException ("Failed to process CSV input" , e );
8691 }
@@ -99,49 +104,65 @@ private boolean isNumeric(String str) {
99104 if (str == null || str .isEmpty ()) {
100105 return false ;
101106 }
102- try {
103- Float .parseFloat (str );
104- return true ;
105- } catch (NumberFormatException e ) {
106- return false ;
107+ int len = str .length ();
108+ for (int i = 0 ; i < len ; i ++) {
109+ char c = str .charAt (i );
110+ if ((c < '0' || c > '9' ) && c != '-' && c != '.' && c != 'e' && c != 'E' && c != '+' ) {
111+ return false ;
112+ }
107113 }
114+ return true ;
108115 }
109116
110117 private float parseFloat (String value , long row , int col ) throws TranslateException {
118+ if (value == null || value .isEmpty ()) {
119+ return Float .NaN ;
120+ }
121+
122+ int len = value .length ();
123+ if (len > 0 && (value .charAt (0 ) <= ' ' || value .charAt (len - 1 ) <= ' ' )) {
124+ value = value .trim ();
125+ if (value .isEmpty ()) {
126+ return Float .NaN ;
127+ }
128+ }
129+
111130 try {
112- return Float .parseFloat (value . trim () );
131+ return Float .parseFloat (value );
113132 } catch (NumberFormatException e ) {
114133 throw new TranslateException (
115- String .format ("Non-numeric value '%s' at row %d, column %d" , value , row , col ),
116- e );
134+ "Non-numeric value '" + value + "' at row " + row + ", column " + col , e );
117135 }
118136 }
119137
120138 /** {@inheritDoc} */
121139 @ Override
122140 public String processOutput (TranslatorContext ctx , NDList list ) throws TranslateException {
141+ NDArray array = list .singletonOrThrow ();
142+
123143 try (StringWriter writer = new StringWriter ();
124144 CSVPrinter printer = new CSVPrinter (writer , csvFormat )) {
125145
126- for (NDArray array : list ) {
127- float [] data =
128- array .toType (ai .djl .ndarray .types .DataType .FLOAT32 , false ).toFloatArray ();
129- long [] shape = array .getShape ().getShape ();
130-
131- if (shape .length == 1 ) {
132- // 1D array → single row
133- printRow (printer , data , 0 , data .length );
134- } else if (shape .length == 2 ) {
135- int rows = (int ) shape [0 ];
136- int cols = (int ) shape [1 ];
137- for (int i = 0 ; i < rows ; i ++) {
138- printRow (printer , data , i * cols , cols );
139- }
140- } else {
141- throw new TranslateException (
142- "Only 1D or 2D arrays can be converted to CSV, found shape: "
143- + array .getShape ());
146+ // Extract name as column names
147+ if (array .getName () != null && !array .getName ().isEmpty ()) {
148+ printer .print (array .getName ());
149+ printer .println ();
150+ }
151+
152+ float [] data = array .toType (DataType .FLOAT32 , false ).toFloatArray ();
153+ long [] shape = array .getShape ().getShape ();
154+
155+ if (shape .length == 1 ) {
156+ printRow (printer , data , 0 , data .length );
157+ } else if (shape .length == 2 ) {
158+ int rows = (int ) shape [0 ];
159+ int cols = (int ) shape [1 ];
160+ for (int i = 0 ; i < rows ; i ++) {
161+ printRow (printer , data , i * cols , cols );
144162 }
163+ } else {
164+ throw new TranslateException (
165+ "Only 1D or 2D arrays supported, found shape: " + array .getShape ());
145166 }
146167
147168 return writer .toString ();
0 commit comments