Skip to content

Commit 7bd91ce

Browse files
jaegeral and jkppr authored
Fix: chart generation robustness (#3661)
* Fix chart generation robustness --------- Co-authored-by: Janosch <[email protected]>
1 parent ee96b46 commit 7bd91ce

File tree

4 files changed

+271
-13
lines changed

4 files changed

+271
-13
lines changed

timesketch/lib/aggregators/interface.py

Lines changed: 33 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -121,8 +121,26 @@ def to_chart(
121121
raise RuntimeError(f"No such chart type: {chart_name:s}")
122122

123123
try:
124+
# We need to check if there is an encoding.
125+
encoding = self.encoding
126+
values_dataframe = self.to_pandas()
127+
128+
if not encoding:
129+
logger.warning(
130+
"No encoding found for chart [%s] with title [%s]. "
131+
"Skipping chart generation.",
132+
chart_name,
133+
chart_title,
134+
)
135+
if as_html:
136+
return ""
137+
if as_chart:
138+
return None
139+
return {}
140+
141+
chart_data = {"values": values_dataframe, "encoding": encoding}
124142
chart_object = chart_class(
125-
self.to_pandas(),
143+
chart_data,
126144
title=chart_title,
127145
sketch_url=self._sketch_url,
128146
field=self.field,
@@ -192,19 +210,25 @@ def __init__(self, sketch_id=None, indices=None, timeline_ids=None):
192210

193211
self.opensearch = OpenSearchDataStore()
194212

195-
self._sketch_url = f"/sketch/{sketch_id:d}/explore"
213+
if sketch_id:
214+
self._sketch_url = f"/sketch/{sketch_id:d}/explore"
215+
self.sketch = SQLSketch.get_by_id(sketch_id)
216+
else:
217+
self._sketch_url = ""
218+
self.sketch = None
219+
196220
self.field = ""
197221
self.indices = indices
198-
self.sketch = SQLSketch.get_by_id(sketch_id)
199222
self.timeline_ids = None
200223

201-
active_timelines = self.sketch.active_timelines
202-
if not self.indices:
203-
self.indices = [t.searchindex.index_name for t in active_timelines]
224+
if self.sketch:
225+
active_timelines = self.sketch.active_timelines
226+
if not self.indices:
227+
self.indices = [t.searchindex.index_name for t in active_timelines]
204228

205-
if timeline_ids:
206-
valid_ids = [t.id for t in active_timelines]
207-
self.timeline_ids = [t for t in timeline_ids if t in valid_ids]
229+
if timeline_ids:
230+
valid_ids = [t.id for t in active_timelines]
231+
self.timeline_ids = [t for t in timeline_ids if t in valid_ids]
208232

209233
@property
210234
def chart_title(self):
Lines changed: 148 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,148 @@
1+
# Copyright 2026 Google Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for aggregator interface."""
15+
16+
import unittest
17+
from unittest import mock
18+
19+
from timesketch.lib.aggregators import interface
20+
from timesketch.lib.charts import interface as chart_interface
21+
22+
23+
class MockChart(chart_interface.BaseChart):
    """Minimal chart stand-in whose generated object renders canned output."""

    NAME = "mockchart"

    def generate(self):
        """Return a mock chart exposing fixed to_dict/to_html results."""
        rendered = mock.Mock()
        rendered.to_dict = lambda: {"chart": "data"}
        rendered.to_html = lambda: "<html></html>"
        return rendered
32+
33+
34+
class TestAggregationResult(unittest.TestCase):
    """Tests for AggregationResult.to_chart.

    Covers normal chart generation as well as the early-return path taken
    when the result carries no encoding, for each return format.
    """

    # Positional arguments of the warning that to_chart logs when it skips
    # generation for chart "mockchart" with an empty title.
    _SKIP_WARNING_ARGS = (
        "No encoding found for chart [%s] with title [%s]. "
        "Skipping chart generation.",
        "mockchart",
        "",
    )

    def setUp(self):
        """Build an AggregationResult with two value rows and an encoding."""
        self.values = [{"field": "foo", "count": 10}, {"field": "bar", "count": 20}]
        self.encoding = {
            "x": {"field": "field", "type": "nominal"},
            "y": {"field": "count", "type": "quantitative"},
        }
        self.result = interface.AggregationResult(
            encoding=self.encoding, values=self.values
        )

    @staticmethod
    def _result_without_encoding(values):
        """Return an AggregationResult for the given values with no encoding."""
        return interface.AggregationResult(encoding=None, values=values)

    @mock.patch("timesketch.lib.charts.manager.ChartManager.get_chart")
    def test_to_chart_valid(self, mock_get_chart):
        """Test to_chart with valid data."""
        mock_get_chart.return_value = MockChart

        result = self.result.to_chart(chart_name="mockchart")

        # The chart is returned as a dict by default.
        self.assertEqual(result, {"chart": "data"})

    @mock.patch("timesketch.lib.aggregators.interface.logger")
    @mock.patch("timesketch.lib.charts.manager.ChartManager.get_chart")
    def test_to_chart_missing_encoding(self, mock_get_chart, mock_logger):
        """Test to_chart with missing encoding (should skip generation)."""
        mock_get_chart.return_value = MockChart
        result_obj = self._result_without_encoding(self.values)

        result = result_obj.to_chart(chart_name="mockchart")

        self.assertEqual(result, {})
        mock_logger.warning.assert_called_with(*self._SKIP_WARNING_ARGS)

    @mock.patch("timesketch.lib.aggregators.interface.logger")
    @mock.patch("timesketch.lib.charts.manager.ChartManager.get_chart")
    def test_to_chart_missing_encoding_as_html(self, mock_get_chart, mock_logger):
        """Test that a skipped chart requested as HTML is an empty string."""
        mock_get_chart.return_value = MockChart
        result_obj = self._result_without_encoding(self.values)

        result = result_obj.to_chart(chart_name="mockchart", as_html=True)

        self.assertEqual(result, "")
        mock_logger.warning.assert_called_with(*self._SKIP_WARNING_ARGS)

    @mock.patch("timesketch.lib.aggregators.interface.logger")
    @mock.patch("timesketch.lib.charts.manager.ChartManager.get_chart")
    def test_to_chart_missing_encoding_as_chart(self, mock_get_chart, mock_logger):
        """Test that a skipped chart requested as a chart object is None."""
        mock_get_chart.return_value = MockChart
        result_obj = self._result_without_encoding(self.values)

        result = result_obj.to_chart(chart_name="mockchart", as_chart=True)

        self.assertIsNone(result)
        mock_logger.warning.assert_called_with(*self._SKIP_WARNING_ARGS)

    @mock.patch("timesketch.lib.aggregators.interface.logger")
    @mock.patch("timesketch.lib.charts.manager.ChartManager.get_chart")
    def test_to_chart_missing_values_and_encoding(self, mock_get_chart, mock_logger):
        """Test to_chart with missing values and encoding."""
        mock_get_chart.return_value = MockChart
        result_obj = self._result_without_encoding([])

        result = result_obj.to_chart(chart_name="mockchart")

        self.assertEqual(result, {})
        mock_logger.warning.assert_called_with(*self._SKIP_WARNING_ARGS)

    @mock.patch("timesketch.lib.aggregators.interface.logger")
    @mock.patch("timesketch.lib.charts.manager.ChartManager.get_chart")
    def test_to_chart_missing_encoding_unable_to_guess(
        self, mock_get_chart, mock_logger
    ):
        """Test to_chart when unable to guess encoding (e.g. 1 column)."""
        mock_get_chart.return_value = MockChart
        # A single column leaves nothing to pair up as x/y.
        result_obj = self._result_without_encoding([{"field": "foo"}])

        result = result_obj.to_chart(chart_name="mockchart")

        self.assertEqual(result, {})
        mock_logger.warning.assert_called()
118+
119+
120+
class TestBaseAggregator(unittest.TestCase):
    """Exercise BaseAggregator construction with and without a sketch."""

    @mock.patch("timesketch.lib.aggregators.interface.OpenSearchDataStore")
    @mock.patch("timesketch.lib.aggregators.interface.SQLSketch")
    def test_init_no_sketch_id(self, _mock_sketch, _mock_ds):
        """Without a sketch ID no sketch is loaded and indices are kept."""
        aggregator = interface.BaseAggregator(indices=["index1"])

        self.assertIsNone(aggregator.sketch)
        self.assertEqual(aggregator.indices, ["index1"])
        # pylint: disable=protected-access
        self.assertEqual(aggregator._sketch_url, "")

    @mock.patch("timesketch.lib.aggregators.interface.OpenSearchDataStore")
    @mock.patch("timesketch.lib.aggregators.interface.SQLSketch")
    def test_init_with_sketch_id(self, mock_sketch, _mock_ds):
        """With a sketch ID, indices come from the sketch's active timelines."""
        timeline = mock.Mock(id=1)
        timeline.searchindex.index_name = "index1"
        sketch = mock.Mock(active_timelines=[timeline])
        mock_sketch.get_by_id.return_value = sketch

        aggregator = interface.BaseAggregator(sketch_id=1)

        self.assertEqual(aggregator.sketch, sketch)
        self.assertEqual(aggregator.indices, ["index1"])
        # pylint: disable=protected-access
        self.assertEqual(aggregator._sketch_url, "/sketch/1/explore")

timesketch/lib/charts/interface.py

Lines changed: 20 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -105,18 +105,34 @@ def _get_chart_with_transform(self):
105105
A LayerChart object, either with added transform or not, depending
106106
on whether sketch URL and field are set.
107107
"""
108-
chart = alt.Chart(self.values)
109-
if not self._sketch_url:
110-
return chart
108+
data = self.values
109+
# We need to convert the dataframe to a dict to avoid issues with
110+
# newer versions of pandas and older versions of altair.
111+
# See https://github.com/altair-viz/altair/issues/2763
112+
if isinstance(data, pd.DataFrame):
113+
data = data.to_dict(orient="records")
114+
115+
if not isinstance(data, list):
116+
logger.error(
117+
"Chart data is not in a supported format. "
118+
"Expected pandas DataFrame or list, got %s",
119+
type(data),
120+
)
121+
data = []
111122

112-
if not self._field:
123+
chart = alt.Chart(data)
124+
125+
if not self._sketch_url or not self._field:
113126
return chart
114127

115128
datum = getattr(alt.datum, self._field)
116129
if self._aggregation_id:
117130
agg_string = f"a={self._aggregation_id:d}&"
118131
else:
119132
agg_string = ""
133+
134+
# Construct Vega-Lite expression
135+
# usage: url + datum.field + '" ' + extra
120136
url = f'{self._sketch_url:s}?{agg_string:s}q={self._field:s}:"'
121137
return chart.transform_calculate(
122138
url=url + datum + '" ' + self._extra_query_url
Lines changed: 70 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,70 @@
1+
# Copyright 2026 Google Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for chart interface."""
15+
16+
import unittest
17+
from unittest import mock
18+
import pandas as pd
19+
20+
from timesketch.lib.charts import interface
21+
22+
23+
class TestBaseChart(unittest.TestCase):
    """Check what BaseChart._get_chart_with_transform hands to alt.Chart."""

    def test_get_chart_with_transform_converts_to_dict(self):
        """A DataFrame of values reaches alt.Chart as a list of records."""
        frame = pd.DataFrame([{"a": 1, "b": 2}])
        chart = interface.BaseChart(
            {"values": frame, "encoding": {"x": "a", "y": "b"}}
        )

        # Patch alt.Chart so the data argument it receives can be inspected.
        chart_patch = mock.patch("timesketch.lib.charts.interface.alt.Chart")
        with chart_patch as mock_chart:
            # pylint: disable=protected-access
            chart._get_chart_with_transform()

        positional_args = mock_chart.call_args[0]
        passed_data = positional_args[0]

        # The DataFrame must have been converted, not passed through.
        self.assertIsInstance(passed_data, list)
        self.assertEqual(passed_data, [{"a": 1, "b": 2}])

    def test_get_chart_with_transform_invalid_data(self):
        """Unsupported value types are logged and replaced by an empty list."""
        chart = interface.BaseChart(
            {"values": pd.DataFrame(), "encoding": {"x": "a", "y": "b"}}
        )
        # The unsupported type is assigned after construction, directly on
        # the instance.
        chart.values = "invalid_string"

        logger_patch = mock.patch("timesketch.lib.charts.interface.logger")
        chart_patch = mock.patch("timesketch.lib.charts.interface.alt.Chart")
        with logger_patch as mock_logger, chart_patch as mock_chart:
            # pylint: disable=protected-access
            chart._get_chart_with_transform()

        mock_logger.error.assert_called()

        positional_args = mock_chart.call_args[0]
        self.assertEqual(positional_args[0], [])

0 commit comments

Comments
 (0)