Skip to content

Commit 6ca7551

Browse files
amaloneyahuang11
andauthored
allow latin1 encoding for csv files (#1203)
Co-authored-by: Andrew <[email protected]>
1 parent c5df684 commit 6ca7551

File tree

7 files changed

+102
-10
lines changed

7 files changed

+102
-10
lines changed

lumen/ai/controls.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
)
1616

1717
from ..sources.duckdb import DuckDBSource
18+
from ..util import detect_file_encoding
1819
from .memory import _Memory, memory
1920

2021
TABLE_EXTENSIONS = ("csv", "parquet", "parq", "json", "xlsx", "geojson", "wkt", "zip")
@@ -197,7 +198,8 @@ def _generate_media_controls(self, event):
197198
self._upload_tabs.clear()
198199
self._media_controls.clear()
199200
for filename, file in self._file_input.value.items():
200-
file_obj = io.BytesIO(file) if isinstance(file, bytes) else io.StringIO(file)
201+
encoding = detect_file_encoding(file_obj=file)
202+
file_obj = io.BytesIO(file.decode(encoding).encode("utf-8")) if isinstance(file, bytes) else io.StringIO(file)
201203
if filename.lower().endswith(TABLE_EXTENSIONS):
202204
table_controls = TableControls(
203205
file_obj,

lumen/ai/ui.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from ..sources import Source
3333
from ..sources.duckdb import DuckDBSource
3434
from ..transforms.sql import SQLLimit
35-
from ..util import log
35+
from ..util import detect_file_encoding, log
3636
from .agents import (
3737
AnalysisAgent, AnalystAgent, ChatAgent, DocumentListAgent, SourceAgent,
3838
SQLAgent, TableListAgent, VegaLiteAgent,
@@ -257,7 +257,8 @@ def _resolve_data(self, data: DataT | list[DataT] | None):
257257
if src.endswith(('.parq', '.parquet')):
258258
table = f"read_parquet('{src}')"
259259
elif src.endswith(".csv"):
260-
table = f"read_csv('{src}')"
260+
encoding = detect_file_encoding(file_obj=src)
261+
table = f"read_csv('{src}', encoding='{encoding}')"
261262
elif src.endswith(".json"):
262263
table = f"read_json_auto('{src}')"
263264
else:

lumen/tests/transforms/test_sql.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,13 @@ def test_sql_comments():
109109
assert result == expected
110110

111111

112+
def test_add_encoding_to_read_csv():
113+
expression: list = sqlglot.parse("READ_CSV('data/life-expectancy.csv')")
114+
result = SQLTransform(identify=True)._add_encoding_to_read_csv(expression[0])
115+
expected = "READ_CSV('data/life-expectancy.csv', encoding='utf-8')"
116+
assert result.sql() == expected
117+
118+
112119
def test_sql_error_level():
113120
with pytest.raises(
114121
sqlglot.errors.ParseError, match="Expected table name but got"

lumen/transforms/sql.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22

33
import datetime as dt
4+
import pathlib
45
import re
56

7+
from copy import deepcopy
68
from typing import ClassVar
79

810
import param # type: ignore
@@ -11,12 +13,13 @@
1113
from sqlglot import parse
1214
from sqlglot.expressions import (
1315
LT, Column, Expression, Identifier, Literal as SQLLiteral, Max, Min, Null,
14-
Select, Star, Table, TableSample, and_, func, or_, replace_placeholders,
15-
replace_tables, select,
16+
ReadCSV, Select, Star, Table, TableSample, and_, func, or_,
17+
replace_placeholders, replace_tables, select,
1618
)
1719
from sqlglot.optimizer import optimize
1820

1921
from ..config import SOURCE_TABLE_SEPARATOR
22+
from ..util import detect_file_encoding
2023
from .base import Transform
2124

2225

@@ -140,6 +143,30 @@ def parse_sql(self, sql_in: str) -> Expression:
140143
expression = expressions[0]
141144
return expression
142145

146+
def _add_encoding_to_read_csv(self, expression: Expression) -> Expression:
147+
"""
148+
Add file encoding when reading CSV files using DuckDB.
149+
150+
Parameters
151+
----------
152+
expression : Expression
153+
An sqlglot expression object.
154+
155+
Returns
156+
-------
157+
Expression
158+
A modified expression that includes the file encoding.
159+
"""
160+
expr = deepcopy(expression)
161+
if isinstance(expr, ReadCSV):
162+
read_csv = expr.find(ReadCSV) or ReadCSV()
163+
literal = read_csv.find(SQLLiteral) or SQLLiteral()
164+
if pathlib.Path(literal.this).suffix.lower() == ".csv" and "encoding" not in literal.this:
165+
encoding = detect_file_encoding(file_obj=literal.this)
166+
expr.find(ReadCSV).find(SQLLiteral).replace(Identifier(this=f"'{literal.this}', encoding='{encoding}'", is_string=literal.is_string))
167+
168+
return expr
169+
143170
def to_sql(self, expression: Expression) -> str:
144171
"""
145172
Convert sqlglot expression back to SQL string.
@@ -157,6 +184,8 @@ def to_sql(self, expression: Expression) -> str:
157184
if self.optimize:
158185
expression = optimize(expression, dialect=self.read)
159186

187+
expression = self._add_encoding_to_read_csv(expression=expression)
188+
160189
return expression.sql(
161190
comments=self.comments,
162191
dialect=self.write,
@@ -208,10 +237,12 @@ def apply(self, sql_in: str) -> str:
208237
sql_template = re.sub(r'\{(\w+)\}', r':\1', sql_in)
209238
expression = self.parse_sql(sql_template)
210239
if self.parameters:
211-
parameters = {
212-
k: Identifier(this=v, quoted=self.identify) if isinstance(v, str) else v
213-
for k, v in self.parameters.items()
214-
}
240+
parameters = {}
241+
for k, v in self.parameters.items():
242+
if isinstance(v, str):
243+
parameters[k] = Identifier(this=v, quoted=self.identify)
244+
else:
245+
parameters[k] = v
215246
replaced_expression = replace_placeholders(expression, **parameters)
216247
return self.to_sql(replaced_expression,)
217248

lumen/util.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,19 @@
22

33
import datetime as dt
44
import importlib
5+
import io
56
import os
67
import re
78
import sys
89
import unicodedata
910

1011
from functools import partial, wraps
1112
from logging import getLogger
13+
from pathlib import Path
1214
from subprocess import check_output
1315

1416
import bokeh
17+
import chardet
1518
import pandas as pd
1619
import panel as pn
1720
import param
@@ -349,3 +352,47 @@ def slugify(value, allow_unicode=False) -> str:
349352
)
350353
value = re.sub(r"[^\w\s-]", "", value.lower())
351354
return re.sub(r"[-\s]+", "-", value).strip("-_")
355+
356+
357+
def detect_file_encoding(file_obj: Path | io.BytesIO | io.StringIO) -> str:
358+
"""
359+
Detects the given file object's encoding.
360+
361+
Parameters
362+
----------
363+
file_obj : Path | io.BytesIO | io.StringIO
364+
File object or path object to detect encoding.
365+
366+
Returns
367+
-------
368+
str
369+
"""
370+
if isinstance(file_obj, str):
371+
try:
372+
path_exists = Path(file_obj).exists()
373+
if path_exists:
374+
file_obj = Path(file_obj)
375+
except OSError:
376+
pass
377+
378+
# Handle if a path is given.
379+
if isinstance(file_obj, Path):
380+
with file_obj.open("rb") as f:
381+
data = f.read()
382+
detected_encoding = chardet.detect(data)
383+
encoding = detected_encoding["encoding"]
384+
385+
# Handle if a string or bytes object is given.
386+
if isinstance(file_obj, bytes):
387+
detected_encoding = chardet.detect(file_obj)
388+
elif isinstance(file_obj, str):
389+
detected_encoding = chardet.detect(file_obj.encode())
390+
391+
encoding = detected_encoding["encoding"]
392+
393+
if encoding == "ISO-8859-1":
394+
encoding = "latin-1"
395+
elif encoding == "ascii":
396+
encoding = "utf-8"
397+
398+
return encoding.lower()

pixi.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ bq-dev = ["py313", "ai", "ai-local", "ai-llama", "bigquery", "lint", "sql", "tes
2121

2222
[dependencies]
2323
bokeh = "*"
24+
chardet = "*"
2425
holoviews = ">=1.17.0"
2526
hvplot = "*"
2627
intake = "<2"

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,10 @@ HoloViz = "https://holoviz.org/"
4848
[project.optional-dependencies]
4949
tests = ['pytest', 'pytest-rerunfailures', 'pytest-asyncio']
5050
sql = ['duckdb', 'intake-sql', 'sqlalchemy']
51-
ai = ['griffe', 'nbformat', 'duckdb', 'pyarrow', 'instructor >=1.6.4', 'pydantic >=2.8.0', 'pydantic-extra-types', 'panel-graphic-walker[kernel] >=0.6.4', 'markitdown', 'semchunk', 'tiktoken']
51+
ai = [
52+
'griffe', 'nbformat', 'duckdb', 'pyarrow', 'instructor >=1.6.4', 'pydantic >=2.8.0', 'pydantic-extra-types', 'panel-graphic-walker[kernel] >=0.6.4',
53+
'markitdown', 'semchunk', 'tiktoken', 'chardet',
54+
]
5255
ai-local = ['lumen[ai]', 'huggingface_hub']
5356
ai-openai = ['lumen[ai]', 'openai']
5457
ai-mistralai = ['lumen[ai]', 'mistralai']

0 commit comments

Comments
 (0)