Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions timesketch/lib/aggregators/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,26 @@ def to_chart(
raise RuntimeError(f"No such chart type: {chart_name:s}")

try:
# We need to check if there is an encoding.
encoding = self.encoding
values_dataframe = self.to_pandas()

if not encoding:
logger.warning(
"No encoding found for chart [%s] with title [%s]. "
"Skipping chart generation.",
chart_name,
chart_title,
)
if as_html:
return ""
if as_chart:
return None
return {}

chart_data = {"values": values_dataframe, "encoding": encoding}
chart_object = chart_class(
self.to_pandas(),
chart_data,
title=chart_title,
sketch_url=self._sketch_url,
field=self.field,
Expand Down Expand Up @@ -192,19 +210,25 @@ def __init__(self, sketch_id=None, indices=None, timeline_ids=None):

self.opensearch = OpenSearchDataStore()

self._sketch_url = f"/sketch/{sketch_id:d}/explore"
if sketch_id:
self._sketch_url = f"/sketch/{sketch_id:d}/explore"
self.sketch = SQLSketch.get_by_id(sketch_id)
else:
self._sketch_url = ""
self.sketch = None

self.field = ""
self.indices = indices
self.sketch = SQLSketch.get_by_id(sketch_id)
self.timeline_ids = None

active_timelines = self.sketch.active_timelines
if not self.indices:
self.indices = [t.searchindex.index_name for t in active_timelines]
if self.sketch:
active_timelines = self.sketch.active_timelines
if not self.indices:
self.indices = [t.searchindex.index_name for t in active_timelines]

if timeline_ids:
valid_ids = [t.id for t in active_timelines]
self.timeline_ids = [t for t in timeline_ids if t in valid_ids]
if timeline_ids:
valid_ids = [t.id for t in active_timelines]
self.timeline_ids = [t for t in timeline_ids if t in valid_ids]

@property
def chart_title(self):
Expand Down
148 changes: 148 additions & 0 deletions timesketch/lib/aggregators/interface_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Copyright 2026 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for aggregator interface."""

import unittest
from unittest import mock

from timesketch.lib.aggregators import interface
from timesketch.lib.charts import interface as chart_interface


class MockChart(chart_interface.BaseChart):
"""Mock chart class."""

NAME = "mockchart"

def generate(self):
return mock.Mock(
to_dict=lambda: {"chart": "data"}, to_html=lambda: "<html></html>"
)


class TestAggregationResult(unittest.TestCase):
"""Tests for AggregationResult."""

def setUp(self):
self.values = [{"field": "foo", "count": 10}, {"field": "bar", "count": 20}]
self.encoding = {
"x": {"field": "field", "type": "nominal"},
"y": {"field": "count", "type": "quantitative"},
}
self.result = interface.AggregationResult(
encoding=self.encoding, values=self.values
)

@mock.patch("timesketch.lib.charts.manager.ChartManager.get_chart")
def test_to_chart_valid(self, mock_get_chart):
"""Test to_chart with valid data."""
mock_get_chart.return_value = MockChart

# Call to_chart
result = self.result.to_chart(chart_name="mockchart")

# Assert chart was generated and returned as dict by default
self.assertEqual(result, {"chart": "data"})

@mock.patch("timesketch.lib.aggregators.interface.logger")
@mock.patch("timesketch.lib.charts.manager.ChartManager.get_chart")
def test_to_chart_missing_encoding(self, mock_get_chart, mock_logger):
"""Test to_chart with missing encoding (should skip generation)."""
mock_get_chart.return_value = MockChart

# Result without encoding
result_obj = interface.AggregationResult(encoding=None, values=self.values)

result = result_obj.to_chart(chart_name="mockchart")

# Should return empty dict
self.assertEqual(result, {})

# Verify logger.warning was called
mock_logger.warning.assert_called_with(
"No encoding found for chart [%s] with title [%s]. "
"Skipping chart generation.",
"mockchart",
"",
)

@mock.patch("timesketch.lib.aggregators.interface.logger")
@mock.patch("timesketch.lib.charts.manager.ChartManager.get_chart")
def test_to_chart_missing_values_and_encoding(self, mock_get_chart, mock_logger):
"""Test to_chart with missing values and encoding."""
mock_get_chart.return_value = MockChart

# Result with empty values and no encoding
result_obj = interface.AggregationResult(encoding=None, values=[])

result = result_obj.to_chart(chart_name="mockchart")

# Should return empty dict
self.assertEqual(result, {})

# Verify logger.warning was called
mock_logger.warning.assert_called_with(
"No encoding found for chart [%s] with title [%s]. "
"Skipping chart generation.",
"mockchart",
"",
)

@mock.patch("timesketch.lib.aggregators.interface.logger")
@mock.patch("timesketch.lib.charts.manager.ChartManager.get_chart")
def test_to_chart_missing_encoding_unable_to_guess(
self, mock_get_chart, mock_logger
):
"""Test to_chart when unable to guess encoding (e.g. 1 column)."""
mock_get_chart.return_value = MockChart

values = [{"field": "foo"}] # Only 1 column
result_obj = interface.AggregationResult(encoding=None, values=values)

result = result_obj.to_chart(chart_name="mockchart")
self.assertEqual(result, {})

# Verify logger.warning was called
mock_logger.warning.assert_called()


class TestBaseAggregator(unittest.TestCase):
"""Tests for BaseAggregator."""

@mock.patch("timesketch.lib.aggregators.interface.OpenSearchDataStore")
@mock.patch("timesketch.lib.aggregators.interface.SQLSketch")
def test_init_no_sketch_id(self, _mock_sketch, _mock_ds):
"""Test initialization without sketch_id."""
agg = interface.BaseAggregator(indices=["index1"])
self.assertIsNone(agg.sketch)
self.assertEqual(agg.indices, ["index1"])
# pylint: disable=protected-access
self.assertEqual(agg._sketch_url, "")

@mock.patch("timesketch.lib.aggregators.interface.OpenSearchDataStore")
@mock.patch("timesketch.lib.aggregators.interface.SQLSketch")
def test_init_with_sketch_id(self, mock_sketch, _mock_ds):
"""Test initialization with sketch_id."""
mock_sketch_obj = mock.Mock()
mock_t1 = mock.Mock()
mock_t1.searchindex.index_name = "index1"
mock_t1.id = 1
mock_sketch_obj.active_timelines = [mock_t1]
mock_sketch.get_by_id.return_value = mock_sketch_obj

agg = interface.BaseAggregator(sketch_id=1)
self.assertEqual(agg.sketch, mock_sketch_obj)
self.assertEqual(agg.indices, ["index1"])
# pylint: disable=protected-access
self.assertEqual(agg._sketch_url, "/sketch/1/explore")
24 changes: 20 additions & 4 deletions timesketch/lib/charts/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,18 +105,34 @@ def _get_chart_with_transform(self):
A LayerChart object, either with added transform or not, depending
on whether sketch URL and field are set.
"""
chart = alt.Chart(self.values)
if not self._sketch_url:
return chart
data = self.values
# We need to convert the dataframe to a dict to avoid issues with
# newer versions of pandas and older versions of altair.
# See https://github.com/altair-viz/altair/issues/2763
if isinstance(data, pd.DataFrame):
data = data.to_dict(orient="records")

if not isinstance(data, list):
logger.error(
"Chart data is not in a supported format. "
"Expected pandas DataFrame or list, got %s",
type(data),
)
data = []

if not self._field:
chart = alt.Chart(data)

if not self._sketch_url or not self._field:
return chart

datum = getattr(alt.datum, self._field)
if self._aggregation_id:
agg_string = f"a={self._aggregation_id:d}&"
else:
agg_string = ""

# Construct Vega-Lite expression
# usage: url + datum.field + '" ' + extra
url = f'{self._sketch_url:s}?{agg_string:s}q={self._field:s}:"'
return chart.transform_calculate(
url=url + datum + '" ' + self._extra_query_url
Expand Down
70 changes: 70 additions & 0 deletions timesketch/lib/charts/interface_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2026 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for chart interface."""

import unittest
from unittest import mock
import pandas as pd

from timesketch.lib.charts import interface


class TestBaseChart(unittest.TestCase):
"""Tests for BaseChart."""

def test_get_chart_with_transform_converts_to_dict(self):
"""Test that _get_chart_with_transform converts DataFrame to dict."""
data = {
"values": pd.DataFrame([{"a": 1, "b": 2}]),
"encoding": {"x": "a", "y": "b"},
}
chart = interface.BaseChart(data)

# We want to verify that alt.Chart is called with a list of dicts,
# not a DataFrame.
with mock.patch("timesketch.lib.charts.interface.alt.Chart") as mock_chart:
# pylint: disable=protected-access
chart._get_chart_with_transform()

# Get the argument passed to alt.Chart
args, _ = mock_chart.call_args
passed_data = args[0]

# Assert it is a list (result of to_dict)
self.assertIsInstance(passed_data, list)
self.assertEqual(passed_data, [{"a": 1, "b": 2}])

def test_get_chart_with_transform_invalid_data(self):
"""Test that _get_chart_with_transform handles invalid data gracefully."""
# Setup mock data where values is not a DataFrame or list
# Since init forces DataFrame, we have to mock self.values
# on the instance.

data = {"values": pd.DataFrame(), "encoding": {"x": "a", "y": "b"}}
chart = interface.BaseChart(data)

# Override values with invalid type
chart.values = "invalid_string"

with mock.patch("timesketch.lib.charts.interface.logger") as mock_logger:
with mock.patch("timesketch.lib.charts.interface.alt.Chart") as mock_chart:
# pylint: disable=protected-access
chart._get_chart_with_transform()

# Check error logged
mock_logger.error.assert_called()

# Check alt.Chart called with empty list
args, _ = mock_chart.call_args
self.assertEqual(args[0], [])