Skip to content

Commit 5ef49cc

Browse files
committed
[owl] Validate row data before performing LLM generation (#836)
Backend - owl (API server)
- Validate row data before performing LLM generation
- Only log request start for mutating HTTP methods
1 parent 32188e2 commit 5ef49cc

File tree

6 files changed

+63
-107
lines changed

6 files changed

+63
-107
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ jobs:
127127
# Replace the org with the key in the .env file
128128
sed -i "s/$org=.*/$org=$key/g" .env
129129
done
130-
echo "OWL_DB_INIT=False" >> .env
130+
echo "OWL_DB_INIT=0" >> .env
131131
echo "OWL_COMPUTE_STORAGE_PERIOD_SEC=15" >> .env
132132
echo "OWL_STRIPE_WEBHOOK_SECRET_TEST=${OWL_STRIPE_WEBHOOK_SECRET_TEST}" >> .env
133133
echo "OWL_STRIPE_PUBLISHABLE_KEY_TEST=${OWL_STRIPE_PUBLISHABLE_KEY_TEST}" >> .env

services/api/src/owl/entrypoints/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ async def log_request(request: Request, call_next):
247247

248248
# Call request
249249
path = request.url.path
250-
if "api/health" not in path:
250+
if request.method in ("POST", "PATCH", "PUT", "DELETE"):
251251
logger.info(make_request_log_str(request))
252252
response = await call_next(request)
253253
response.headers["x-request-id"] = request_id

services/api/src/owl/routers/gen_table.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -553,18 +553,7 @@ async def add_rows(
553553
billing.has_db_storage_quota()
554554
billing.has_egress_quota()
555555
# Validate data
556-
try:
557-
[table.validate_row_data(d) for d in body.data]
558-
except Exception as e:
559-
logger.info(
560-
(
561-
"Row data validation failed. "
562-
f'Table={table.schema_id}."{table.table_metadata.short_id}" '
563-
f"Org={org.id} "
564-
f"User={user.id} "
565-
f"Error={repr(e)}"
566-
)
567-
)
556+
[table.validate_row_data(d) for d in body.data]
568557
executor = MultiRowGenExecutor(
569558
request=request,
570559
table=table,
@@ -854,18 +843,7 @@ async def update_rows(
854843
billing.has_db_storage_quota()
855844
billing.has_egress_quota()
856845
# Validate data
857-
try:
858-
{row_id: table.validate_row_data(d) for row_id, d in body.data.items()}
859-
except Exception as e:
860-
logger.info(
861-
(
862-
"Row data validation failed. "
863-
f'Table={table.schema_id}."{table.table_metadata.short_id}" '
864-
f"Org={org.id} "
865-
f"User={user.id} "
866-
f"Error={repr(e)}"
867-
)
868-
)
846+
{row_id: table.validate_row_data(d) for row_id, d in body.data.items()}
869847
await table.update_rows(body.data)
870848
return OkResponse()
871849

services/api/src/owl/utils/lm.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ async def _get_deployment(
384384
)
385385
try:
386386
logger.info(
387-
f'{self.id} - Request started for model "{self.config.id}" ({provider=}, {routing_id=}).'
387+
f'Request started for model "{self.config.id}" ({provider=}, {routing_id=}, {self.id}).'
388388
)
389389
t0 = perf_counter()
390390
self.request.state.model_start_time = t0
@@ -396,7 +396,9 @@ async def _get_deployment(
396396
is_reasoning_model=is_reasoning_model,
397397
)
398398
self.request.state.timing["external_call"] = perf_counter() - t0
399-
logger.info(f'{self.id} - Request completed for model "{self.config.id}".')
399+
logger.info(
400+
f'Request completed for model "{self.config.id}" ({provider=}, {routing_id=}, {self.id}).'
401+
)
400402
except Exception as e:
401403
mapped_e = self._map_and_log_exception(e, deployment, api_key, **hyperparams)
402404
if isinstance(mapped_e, (ModelOverloadError, RateLimitExceedError)):

services/api/tests/gen_table/test_row_ops.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -987,22 +987,17 @@ def test_add_row_wrong_dtype(
987987
assert "concept" in row.columns
988988

989989
# Test adding data with wrong dtype
990-
response = add_table_rows(
991-
client,
992-
table_type,
993-
table.id,
994-
[dict(good="dummy1", words="dummy2", stars="dummy3", inputs=TEXT)],
995-
stream=stream,
996-
)
997-
rows = list_table_rows(client, table_type, TABLE_ID_A)
998-
assert len(rows.items) == 2
999-
row = rows.items[-1]
1000-
assert row["good"]["value"] is None, row["good"]
1001-
assert row["good"]["original"] == "dummy1", row["good"]
1002-
assert row["words"]["value"] is None, row["words"]
1003-
assert row["words"]["original"] == "dummy2", row["words"]
1004-
assert row["stars"]["value"] is None, row["stars"]
1005-
assert row["stars"]["original"] == "dummy3", row["stars"]
990+
with pytest.raises(BadInputError) as e:
991+
add_table_rows(
992+
client,
993+
table_type,
994+
table.id,
995+
[dict(good="dummy1", words="dummy2", stars="dummy3", inputs=TEXT)],
996+
stream=stream,
997+
)
998+
assert 'Column "good": Input should be a valid boolean' in str(e.value)
999+
assert 'Column "words": Input should be a valid integer' in str(e.value)
1000+
assert 'Column "stars": Input should be a valid number' in str(e.value)
10061001

10071002

10081003
@pytest.mark.parametrize("table_type", TABLE_TYPES)
@@ -1047,7 +1042,7 @@ def test_add_row_missing_columns(
10471042
table_type,
10481043
stream,
10491044
TABLE_ID_A,
1050-
data=dict(good="dummy1", inputs=TEXT),
1045+
data=dict(good=True, inputs=TEXT),
10511046
)
10521047
if stream:
10531048
responses = [r for r in response]
@@ -1057,8 +1052,7 @@ def test_add_row_missing_columns(
10571052
rows = list_table_rows(client, table_type, TABLE_ID_A)
10581053
assert len(rows.items) == 2
10591054
row = rows.items[-1]
1060-
assert row["good"]["value"] is None, row["good"]
1061-
assert row["good"]["original"] == "dummy1", row["good"]
1055+
assert row["good"]["value"] is True, row["good"]
10621056
assert row["words"]["value"] is None, row["words"]
10631057
assert "original" not in row["words"], row["words"]
10641058
assert row["stars"]["value"] is None, row["stars"]

services/api/tests/gen_table/test_row_ops_v2.py

Lines changed: 42 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -617,26 +617,15 @@ def test_knowledge_table_embedding(
617617
row = rows.values[3]
618618
assert row["Title"] is None, row
619619
assert row["Text"] is None, row
620-
# If embedding with invalid length is added, it will be coerced to None
621-
# Original vector will be saved into state
622-
response = add_table_rows(
623-
client,
624-
table_type,
625-
table.id,
626-
[{"Title": "test", "Title Embed": [1, 2, 3]}],
627-
stream=stream,
628-
)
629-
# We currently dont return anything if LLM is not called
630-
assert len(response.rows) == 0 if stream else 1
631-
assert all(len(r.columns) == 0 for r in response.rows)
632-
# Check the vectors
633-
rows = list_table_rows(client, table_type, table.id)
634-
assert rows.total == 5
635-
row = rows.values[-1]
636-
assert row["Title"] == "test", f"{row['Title']=}"
637-
assert row["Title Embed"] is None, f"{row['Title Embed']=}"
638-
assert row["Text"] is None, f"{row['Title']=}"
639-
assert_is_vector_or_none(row["Text Embed"], allow_none=False)
620+
# Embedding with invalid length will be rejected
621+
with pytest.raises(BadInputError, match="Array input must have length 256"):
622+
add_table_rows(
623+
client,
624+
table_type,
625+
table.id,
626+
[{"Title": "test", "Title Embed": [1, 2, 3]}],
627+
stream=stream,
628+
)
640629

641630

642631
@flaky(max_runs=3, min_passes=1)
@@ -1697,6 +1686,7 @@ def test_public_web_image(setup: ServingContext):
16971686
- As input to model
16981687
- Has valid raw and thumbnail URLs
16991688
- Reject private URLs
1689+
- Reject malformed URL
17001690
- Empty input is OK
17011691
"""
17021692
table_type = TableType.ACTION
@@ -1729,18 +1719,21 @@ def test_public_web_image(setup: ServingContext):
17291719
response = client.file.get_thumbnail_urls([image_uri])
17301720
assert isinstance(response, GetURLResponse)
17311721
assert response.urls[0] == image_uri
1732-
# Private URLs should be rejected
1733-
row = add_table_rows(
1734-
data=[dict(image="https://host.docker.internal:8080")], **kwargs
1735-
).rows[0]
1736-
assert "cannot be opened" in row.columns["ocr"].content, row
1737-
row = add_table_rows(data=[dict(image="https://localhost")], **kwargs).rows[0]
1738-
assert "cannot be opened" in row.columns["ocr"].content, row
1739-
row = add_table_rows(data=[dict(image="https://192.168.0.1")], **kwargs).rows[0]
1740-
assert "cannot be opened" in row.columns["ocr"].content, row
17411722
# Empty is OK
17421723
row = add_table_rows(data=[dict()], **kwargs).rows[0]
17431724
assert len(row.columns["ocr"].content) > 0, row
1725+
# Private URLs should be rejected
1726+
with pytest.raises(BadInputError, match="URL .+ invalid"):
1727+
add_table_rows(data=[dict(image="https://host.docker.internal:8080")], **kwargs)
1728+
with pytest.raises(BadInputError, match="URL .+ invalid"):
1729+
add_table_rows(data=[dict(image="https://localhost")], **kwargs)
1730+
with pytest.raises(BadInputError, match="URL .+ invalid"):
1731+
add_table_rows(data=[dict(image="https://192.168.0.1")], **kwargs)
1732+
# Malformed URL
1733+
with pytest.raises(BadInputError, match="URL .+ invalid"):
1734+
add_table_rows(
1735+
data=[dict(image='{"url": "https://host.docker.internal:8080"}')], **kwargs
1736+
)
17441737

17451738

17461739
@pytest.mark.parametrize("table_type", TABLE_TYPES)
@@ -2130,20 +2123,15 @@ def test_update_row(
21302123

21312124
# Test updating data with wrong dtype
21322125
data = dict(ID="2", int="str", float="str", bool="str")
2133-
response = client.table.update_table_rows(
2134-
table_type,
2135-
MultiRowUpdateRequest(table_id=table.id, data={row["ID"]: data}),
2136-
)
2137-
assert isinstance(response, OkResponse)
2138-
_rows = list_table_rows(client, table_type, table.id)
2139-
assert len(_rows.items) == 1
2140-
_row = _rows.values[0]
2141-
t2 = datetime.fromisoformat(_row["Updated at"])
2142-
assert _row["int"] is None
2143-
assert _row["float"] is None
2144-
assert _row["bool"] is None
2145-
_assert_dict_equal(row, _row, exclude=["Updated at", "int", "float", "bool"])
2146-
assert t2 > t1
2126+
with pytest.raises(BadInputError) as e:
2127+
client.table.update_table_rows(
2128+
table_type,
2129+
MultiRowUpdateRequest(table_id=table.id, data={row["ID"]: data}),
2130+
)
2131+
assert 'Column "int": Input should be a valid integer' in str(e.value)
2132+
assert 'Column "float": Input should be a valid number' in str(e.value)
2133+
assert 'Column "bool": Input should be a valid boolean' in str(e.value)
2134+
_assert_dict_equal(_row, list_table_rows(client, table_type, table.id).values[0])
21472135

21482136
if table_type == TableType.KNOWLEDGE:
21492137
# Test updating embedding columns directly
@@ -2163,26 +2151,20 @@ def test_update_row(
21632151
_rows = list_table_rows(client, table_type, table.id)
21642152
assert len(_rows.items) == 1
21652153
_row = _rows.values[0]
2166-
t3 = datetime.fromisoformat(_row["Updated at"])
2154+
t2 = datetime.fromisoformat(_row["Updated at"])
21672155
assert sum(_row["Title Embed"]) == 0
21682156
assert sum(_row["Text Embed"]) == len(row["Text Embed"])
2169-
assert t3 > t2
2157+
assert t2 > t1
21702158
# Test updating embedding columns with wrong length
2171-
response = client.table.update_table_rows(
2172-
table_type,
2173-
MultiRowUpdateRequest(
2174-
table_id=table.id,
2175-
data={row["ID"]: {"Title Embed": [0], "Text Embed": [0]}},
2176-
),
2177-
)
2178-
assert isinstance(response, OkResponse)
2179-
_rows = list_table_rows(client, table_type, table.id)
2180-
assert len(_rows.items) == 1
2181-
_row = _rows.values[0]
2182-
t4 = datetime.fromisoformat(_row["Updated at"])
2183-
assert _row["Title Embed"] is None
2184-
assert _row["Text Embed"] is None
2185-
assert t4 > t3
2159+
with pytest.raises(BadInputError, match="Array input must have length 256"):
2160+
client.table.update_table_rows(
2161+
table_type,
2162+
MultiRowUpdateRequest(
2163+
table_id=table.id,
2164+
data={row["ID"]: {"Title Embed": [0], "Text Embed": [0]}},
2165+
),
2166+
)
2167+
_assert_dict_equal(_row, list_table_rows(client, table_type, table.id).values[0])
21862168

21872169

21882170
@pytest.mark.parametrize("stream", **STREAM_PARAMS)

0 commit comments

Comments
 (0)