|
12 | 12 | from mlcroissant._src.core.constants import EncodingFormat |
13 | 13 | from mlcroissant._src.core.git import download_git_lfs_file |
14 | 14 | from mlcroissant._src.core.git import is_git_lfs_file |
| 15 | +from mlcroissant._src.core.optional import deps |
15 | 16 | from mlcroissant._src.core.path import Path |
16 | 17 | from mlcroissant._src.operation_graph.base_operation import Operation |
17 | 18 | from mlcroissant._src.operation_graph.operations.download import is_url |
|
21 | 22 | from mlcroissant._src.structure_graph.nodes.file_set import FileSet |
22 | 23 | from mlcroissant._src.structure_graph.nodes.source import FileProperty |
23 | 24 |
|
# scipy is an optional dependency: resolve it lazily through `deps` and fall
# back to None when it is not installed. The ARFF reading path checks this
# sentinel and raises NotImplementedError(INSTALL_MESSAGE) when missing.
try:
    scipy = deps.scipy
except ModuleNotFoundError:
    scipy = None
# Message surfaced when an ARFF file is read but scipy is not installed.
INSTALL_MESSAGE = "scipy is not installed and is a dependency."
| 30 | + |
24 | 31 |
|
25 | 32 | class ReadingMethod(enum.Enum): |
26 | 33 | """Reading method derived from the fields that consume the FileObject/FileSet.""" |
@@ -82,65 +89,80 @@ class Read(Operation): |
82 | 89 | folder: epath.Path |
83 | 90 | fields: tuple[Field, ...] |
84 | 91 |
|
85 | | - def _read_file_content(self, encoding_format: str, file: Path) -> pd.DataFrame: |
| 92 | + def _read_file_content( |
| 93 | + self, encoding_formats: list[str], file: Path |
| 94 | + ) -> pd.DataFrame: |
86 | 95 | """Extracts the `source` file to `target`.""" |
87 | 96 | filepath = file.filepath |
88 | 97 | if is_git_lfs_file(filepath): |
89 | 98 | download_git_lfs_file(file) |
90 | 99 | reading_method = _reading_method(self.node, self.fields) |
| 100 | + if EncodingFormat.ARFF in encoding_formats: |
| 101 | + if scipy is None: |
| 102 | + raise NotImplementedError(INSTALL_MESSAGE) |
| 103 | + |
| 104 | + data = scipy.io.arff.loadarff(filepath) |
| 105 | + if not isinstance(data, list) or len(data) != 1: |
| 106 | + raise ValueError( |
| 107 | + "The loaded data from scipy.io.arff does not have the expected" |
| 108 | + " shape (a list with one element). Please ensure the ARFF file is" |
| 109 | + " valid." |
| 110 | + ) |
| 111 | + return pd.DataFrame(data[0]) |
91 | 112 |
|
92 | 113 | with filepath.open("rb") as file: |
93 | | - # TODO(https://github.com/mlcommons/croissant/issues/635). |
94 | | - if filepath.suffix == ".gz": |
95 | | - file = gzip.open(file, "rt", newline="") |
96 | | - if encoding_format == EncodingFormat.CSV: |
97 | | - return pd.read_csv(file) |
98 | | - elif encoding_format == EncodingFormat.TSV: |
99 | | - return pd.read_csv(file, sep="\t") |
100 | | - elif encoding_format == EncodingFormat.JSON: |
101 | | - json_content = json.load(file) |
102 | | - if reading_method == ReadingMethod.JSON: |
103 | | - return parse_json_content(json_content, self.fields) |
104 | | - else: |
105 | | - # Raw files are returned as a one-line pd.DataFrame. |
106 | | - return pd.DataFrame({ |
107 | | - FileProperty.content: [json_content], |
108 | | - }) |
109 | | - elif encoding_format == EncodingFormat.JSON_LINES: |
110 | | - return pd.read_json(file, lines=True) |
111 | | - elif encoding_format == EncodingFormat.PARQUET: |
112 | | - try: |
113 | | - df = pd.read_parquet(file) |
114 | | - # Sometimes the author already set an index in Parquet, so we want |
115 | | - # to reset it to always have the same format. |
116 | | - df.reset_index(inplace=True) |
117 | | - return df |
118 | | - except ImportError as e: |
119 | | - raise ImportError( |
120 | | - "Missing dependency to read Parquet files. pyarrow is not" |
121 | | - " installed. Please, install `pip install" |
122 | | - " mlcroissant[parquet]`." |
123 | | - ) from e |
124 | | - elif encoding_format == EncodingFormat.TEXT: |
125 | | - if reading_method == ReadingMethod.LINES: |
126 | | - return pd.read_csv( |
127 | | - filepath, header=None, names=[FileProperty.lines] |
128 | | - ) |
129 | | - else: |
| 114 | + for encoding_format in encoding_formats: |
| 115 | + # TODO(https://github.com/mlcommons/croissant/issues/635). |
| 116 | + if filepath.suffix == ".gz": |
| 117 | + file = gzip.open(file, "rt", newline="") |
| 118 | + if encoding_format == EncodingFormat.CSV: |
| 119 | + return pd.read_csv(file) |
| 120 | + elif encoding_format == EncodingFormat.TSV: |
| 121 | + return pd.read_csv(file, sep="\t") |
| 122 | + elif encoding_format == EncodingFormat.JSON: |
| 123 | + json_content = json.load(file) |
| 124 | + if reading_method == ReadingMethod.JSON: |
| 125 | + return parse_json_content(json_content, self.fields) |
| 126 | + else: |
| 127 | + # Raw files are returned as a one-line pd.DataFrame. |
| 128 | + return pd.DataFrame({ |
| 129 | + FileProperty.content: [json_content], |
| 130 | + }) |
| 131 | + elif encoding_format == EncodingFormat.JSON_LINES: |
| 132 | + return pd.read_json(file, lines=True) |
| 133 | + elif encoding_format == EncodingFormat.PARQUET: |
| 134 | + try: |
| 135 | + df = pd.read_parquet(file) |
| 136 | + # Sometimes the author already set an index in Parquet, so we |
| 137 | + # want to reset it to always have the same format. |
| 138 | + df.reset_index(inplace=True) |
| 139 | + return df |
| 140 | + except ImportError as e: |
| 141 | + raise ImportError( |
| 142 | + "Missing dependency to read Parquet files. pyarrow is not" |
| 143 | + " installed. Please, install `pip install" |
| 144 | + " mlcroissant[parquet]`." |
| 145 | + ) from e |
| 146 | + elif encoding_format == EncodingFormat.TEXT: |
| 147 | + if reading_method == ReadingMethod.LINES: |
| 148 | + return pd.read_csv( |
| 149 | + filepath, header=None, names=[FileProperty.lines] |
| 150 | + ) |
| 151 | + else: |
| 152 | + return pd.DataFrame({ |
| 153 | + FileProperty.content: [file.read()], |
| 154 | + }) |
| 155 | + elif ( |
| 156 | + encoding_format == EncodingFormat.MP3 |
| 157 | + or encoding_format == EncodingFormat.JPG |
| 158 | + ): |
130 | 159 | return pd.DataFrame({ |
131 | 160 | FileProperty.content: [file.read()], |
132 | 161 | }) |
133 | | - elif ( |
134 | | - encoding_format == EncodingFormat.MP3 |
135 | | - or encoding_format == EncodingFormat.JPG |
136 | | - ): |
137 | | - return pd.DataFrame({ |
138 | | - FileProperty.content: [file.read()], |
139 | | - }) |
140 | | - else: |
141 | | - raise ValueError( |
142 | | - f"Unsupported encoding format for file: {encoding_format}" |
143 | | - ) |
| 162 | + raise ValueError( |
| 163 | + f"None of the provided encoding formats: {encoding_format} for file" |
| 164 | + f" {filepath} returned a valid pandas dataframe." |
| 165 | + ) |
144 | 166 |
|
145 | 167 | def call(self, files: list[Path] | Path) -> pd.DataFrame: |
146 | 168 | """See class' docstring.""" |
@@ -170,8 +192,8 @@ def call(self, files: list[Path] | Path) -> pd.DataFrame: |
170 | 192 | f'In node "{self.node.uuid}", file "{self.node.content_url}" is' |
171 | 193 | " either an invalid URL or an invalid path." |
172 | 194 | ) |
173 | | - assert self.node.encoding_format, "Encoding format is not specified." |
174 | | - file_content = self._read_file_content(self.node.encoding_format, file) |
| 195 | + assert self.node.encoding_formats, "Encoding format is not specified." |
| 196 | + file_content = self._read_file_content(self.node.encoding_formats, file) |
175 | 197 | if _should_append_line_numbers(self.fields): |
176 | 198 | file_content[FileProperty.lineNumbers] = range(len(file_content)) |
177 | 199 | file_content[FileProperty.filepath] = file.filepath |
|
0 commit comments