Allow more column types to be interpolated #421
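This PR removes the blanket restriction that interpolation target columns be one of `int`, `bigint`, `float`, or `double`. Fill-based methods (`ffill`, `bfill`, `null`) now accept columns of any type, while the value-computing methods (`linear`, `zero`) are still rejected for non-numeric columns — the check now happens per column at interpolation time instead of as an up-front type whitelist. A minimal usage sketch (assuming a running SparkSession named `spark`; the `state` column and sample rows are illustrative, not from the test fixtures):

import pyspark.sql.functions as sfn
from tempo import TSDF

df = spark.createDataFrame(
    [("A", "2024-01-01 00:00:00", "on"), ("A", "2024-01-01 00:02:00", "off")],
    "partition_a string, event_ts string, state string",
).withColumn("event_ts", sfn.to_timestamp("event_ts"))

tsdf = TSDF(df, ts_col="event_ts", partition_cols=["partition_a"])

# a string target column previously raised a TypeError; ffill now works
filled = tsdf.interpolate(freq="30 seconds", func="ceil", method="ffill").df

# linear interpolation on a string column now raises a per-column ValueError:
# tsdf.interpolate(freq="30 seconds", func="ceil", method="linear", target_cols=["state"])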

23 changes: 18 additions & 5 deletions python/tempo/interpol.py
@@ -4,6 +4,7 @@

 from pyspark.sql.dataframe import DataFrame
 import pyspark.sql.functions as sfn
+from pyspark.sql.types import NumericType
 from pyspark.sql.window import Window

 import tempo.resample as t_resample
@@ -12,7 +13,6 @@

 # Interpolation fill options
 method_options = ["zero", "null", "bfill", "ffill", "linear"]
-supported_target_col_types = ["int", "bigint", "float", "double"]


 class Interpolation:
@@ -58,10 +58,6 @@ def __validate_col(
                 raise ValueError(
                     f"Target Column: '{column}' does not exist in DataFrame."
                 )
-            if df.select(column).dtypes[0][1] not in supported_target_col_types:
-                raise TypeError(
-                    f"Target Column needs to be one of the following types: {supported_target_col_types}"
-                )

         if ts_col not in str(df.columns):
             raise ValueError(
@@ -105,6 +101,17 @@ def __calc_linear_spark(
         # Preserve column order
         return interpolated.select(*df.columns)

+    def _is_valid_method_for_column(
+        self, series: DataFrame, method: str, col_name: str
+    ) -> bool:
+        """
+        zero and linear interpolation are only valid for numeric columns
+        """
+        if method in ["linear", "zero"]:
+            return isinstance(series.schema[col_name].dataType, NumericType)
+        else:
+            return True
+
     def __interpolate_column(
         self,
         series: DataFrame,
@@ -120,6 +127,12 @@ def __interpolate_column(
         :param target_col: column to interpolate
         :param method: interpolation function to fill missing values
         """
+
+        if not self._is_valid_method_for_column(series, method, target_col):
+            raise ValueError(
+                f"Interpolation method '{method}' is not supported for column '{target_col}' of type '{series.schema[target_col].dataType}'"
+            )
+
         output_df: DataFrame = series

         # create new column for if target column is interpolated
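The new gate keys off PySpark's `NumericType` base class (which covers integral, floating-point, and decimal types) rather than a hard-coded list of dtype strings. A standalone sketch of the rule it enforces, outside the diff, for illustration:

from pyspark.sql.types import DecimalType, NumericType, StringType, TimestampType

def is_valid_method_for_type(method: str, data_type) -> bool:
    # linear/zero compute with the values, so they need numbers;
    # ffill/bfill/null only move or blank values, so any type works
    if method in ("linear", "zero"):
        return isinstance(data_type, NumericType)
    return True

assert is_valid_method_for_type("linear", DecimalType(10, 2))
assert not is_valid_method_for_type("zero", StringType())
assert is_valid_method_for_type("ffill", TimestampType())

Because `DecimalType` subclasses `NumericType`, decimal columns — previously rejected by the string-based whitelist — are now accepted for linear and zero interpolation as well.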
34 changes: 6 additions & 28 deletions python/tempo/tsdf.py
@@ -30,6 +30,8 @@ class TSDF:
     This object is the main wrapper over a Spark data frame which allows a user to parallelize time series computations on a Spark data frame by various dimensions. The two dimensions required are partition_cols (list of columns by which to summarize) and ts_col (timestamp column, which can be epoch or TimestampType).
     """

+    summarizable_types = ["int", "bigint", "float", "double"]
+
     def __init__(
         self,
         df: DataFrame,
@@ -1136,14 +1138,12 @@ def withRangeStats(
         prohibited_cols = [self.ts_col.lower()]
         if self.partitionCols:
             prohibited_cols.extend([pc.lower() for pc in self.partitionCols])
-        # types that can be summarized
-        summarizable_types = ["int", "bigint", "float", "double"]
         # filter columns to find summarizable columns
         colsToSummarize = [
             datatype[0]
             for datatype in self.df.dtypes
             if (
-                (datatype[1] in summarizable_types)
+                (datatype[1] in self.summarizable_types)
                 and (datatype[0].lower() not in prohibited_cols)
             )
         ]
@@ -1202,14 +1202,12 @@ def withGroupedStats(
         prohibited_cols = [self.ts_col.lower()]
         if self.partitionCols:
             prohibited_cols.extend([pc.lower() for pc in self.partitionCols])
-        # types that can be summarized
-        summarizable_types = ["int", "bigint", "float", "double"]
         # filter columns to find summarizable columns
         metricCols = [
             datatype[0]
             for datatype in self.df.dtypes
             if (
-                (datatype[1] in summarizable_types)
+                (datatype[1] in self.summarizable_types)
                 and (datatype[0].lower() not in prohibited_cols)
             )
         ]
@@ -1332,17 +1330,7 @@ def interpolate(
             partition_cols = self.partitionCols
         if target_cols is None:
             prohibited_cols: List[str] = partition_cols + [ts_col]
-            summarizable_types = ["int", "bigint", "float", "double"]
-
-            # get summarizable find summarizable columns
-            target_cols = [
-                datatype[0]
-                for datatype in self.df.dtypes
-                if (
-                    (datatype[1] in summarizable_types)
-                    and (datatype[0].lower() not in prohibited_cols)
-                )
-            ]
+            target_cols = [col for col in self.df.columns if col not in prohibited_cols]

         interpolate_service = t_interpolation.Interpolation(is_resampled=False)
         tsdf_input = TSDF(self.df, ts_col=ts_col, partition_cols=partition_cols)
@@ -1673,17 +1661,7 @@ def interpolate(
         # Set defaults for target columns, timestamp column and partition columns when not provided
         if target_cols is None:
             prohibited_cols: List[str] = self.partitionCols + [self.ts_col]
-            summarizable_types = ["int", "bigint", "float", "double"]
-
-            # get summarizable find summarizable columns
-            target_cols = [
-                datatype[0]
-                for datatype in self.df.dtypes
-                if (
-                    (datatype[1] in summarizable_types)
-                    and (datatype[0].lower() not in prohibited_cols)
-                )
-            ]
+            target_cols = [col for col in self.df.columns if col not in prohibited_cols]

         interpolate_service = t_interpolation.Interpolation(is_resampled=True)
         tsdf_input = TSDF(
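With type checking moved into `Interpolation`, both `interpolate` entry points default `target_cols` to every non-partition, non-timestamp column instead of pre-filtering to numeric dtypes. Paraphrasing the change outside the diff:

prohibited_cols = partition_cols + [ts_col]

# before: only int/bigint/float/double columns were interpolated by default
target_cols = [
    name
    for name, dtype in df.dtypes
    if dtype in ("int", "bigint", "float", "double")
    and name.lower() not in prohibited_cols
]

# after: any non-prohibited column is a candidate; method/type compatibility
# is enforced per column by Interpolation._is_valid_method_for_column
target_cols = [col for col in df.columns if col not in prohibited_cols]

One subtle behavior change worth noting: the old filter compared lowercased column names against prohibited_cols, while the new default matches names exactly.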
23 changes: 23 additions & 0 deletions python/tests/base.py
@@ -141,6 +141,17 @@ def as_sdf(self) -> DataFrame:
                 )
             else:
                 df = df.withColumn(ts_col, sfn.to_timestamp(ts_col))
+        if "ts_convert_ntz" in self.df:
+            for ts_col in self.df["ts_convert_ntz"]:
+                # handle nested columns
+                if "." in ts_col:
+                    col, field = ts_col.split(".")
+                    convert_field_expr = sfn.to_timestamp_ntz(sfn.col(col).getField(field))
+                    df = df.withColumn(
+                        col, sfn.col(col).withField(field, convert_field_expr)
+                    )
+                else:
+                    df = df.withColumn(ts_col, sfn.to_timestamp_ntz(ts_col))
         # convert date columns
         if "date_convert" in self.df:
             for date_col in self.df["date_convert"]:
@@ -154,8 +165,20 @@ def as_sdf(self) -> DataFrame:
             else:
                 df = df.withColumn(date_col, sfn.to_date(date_col))

+        if "decimal_convert" in self.df:
+            for decimal_col in self.df["decimal_convert"]:
+                # handle nested columns
+                if "." in decimal_col:
+                    col, field = decimal_col.split(".")
+                    convert_field_expr = sfn.col(col).getField(field).cast("decimal")
+                    df = df.withColumn(
+                        col, sfn.col(col).withField(field, convert_field_expr)
+                    )
+                else:
+                    df = df.withColumn(decimal_col, sfn.col(decimal_col).cast("decimal"))
+
         return df


     def as_tsdf(self) -> TSDF:
         """
         Constructs a TSDF from the test data
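The test data builder gains two conversion hooks alongside the existing `ts_convert` and `date_convert`: `ts_convert_ntz` (via `sfn.to_timestamp_ntz`, which requires a recent PySpark) and `decimal_convert` (a cast to `decimal`), each supporting one level of struct nesting. A sketch of the nested-field pattern they share, using a hypothetical `meta.ts` struct field:

import pyspark.sql.functions as sfn

# rewrite a single struct field in place, leaving sibling fields untouched
df = df.withColumn(
    "meta",
    sfn.col("meta").withField(
        "ts", sfn.to_timestamp_ntz(sfn.col("meta").getField("ts"))
    ),
)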
88 changes: 76 additions & 12 deletions python/tests/interpol_tests.py
@@ -53,18 +53,6 @@ def test_validate_col_exist_in_df(self):
             "wrongly_named",
         )

-    def test_validate_col_target_cols_data_type(self):
-        input_df: DataFrame = self.get_test_df_builder("init").as_sdf()
-
-        self.assertRaises(
-            TypeError,
-            self.interpolate_helper._Interpolation__validate_col,
-            input_df,
-            ["partition_a", "partition_b"],
-            ["string_target", "float_target"],
-            "event_ts",
-        )
-
     def test_fill_validation(self):
         """Test fill parameter is valid."""

@@ -450,6 +438,82 @@ def test_interpolation_freq_is_not_supported_type(self):
             True,
         )

+    def test_non_numeric_forward_fill(self):
+        """Verify that forward fill interpolation works on non-numeric columns."""
+
+        # load test data
+        simple_input_tsdf: TSDF = self.get_test_df_builder("non_numeric_init").as_tsdf()
+        expected_df: DataFrame = self.get_test_df_builder("expected").as_sdf()
+
+        actual_df: DataFrame = simple_input_tsdf.interpolate(
+            freq="30 seconds", func="ceil", method="ffill", ts_col="event_ts",
+            partition_cols=["partition_a", "partition_b"]
+        ).df
+
+        self.assertDataFrameEquality(expected_df, actual_df, ignore_nullable=True)
+
+    def test_non_numeric_back_fill(self):
+        """Verify that backward fill interpolation works on non-numeric columns."""
+
+        # load test data
+        simple_input_tsdf: TSDF = self.get_test_df_builder("non_numeric_init").as_tsdf()
+        expected_df: DataFrame = self.get_test_df_builder("expected").as_sdf()
+
+        actual_df: DataFrame = simple_input_tsdf.interpolate(
+            freq="30 seconds", func="ceil", method="bfill", ts_col="event_ts",
+            partition_cols=["partition_a", "partition_b"]
+        ).df
+
+        self.assertDataFrameEquality(expected_df, actual_df, ignore_nullable=True)
+
+    def test_non_numeric_null_fill(self):
+        """Verify that null method interpolation works on non-numeric columns."""
+
+        # load test data
+        simple_input_tsdf: TSDF = self.get_test_df_builder("non_numeric_init").as_tsdf()
+        expected_df: DataFrame = self.get_test_df_builder("expected").as_sdf()
+
+        actual_df: DataFrame = simple_input_tsdf.interpolate(
+            freq="30 seconds", func="ceil", method="null", ts_col="event_ts",
+            partition_cols=["partition_a", "partition_b"]
+        ).df
+
+        self.assertDataFrameEquality(expected_df, actual_df, ignore_nullable=True)
+
+    def test_non_numeric_linear(self):
+        """Verify that linear interpolation is prohibited for non-numeric columns."""
+
+        # load test data
+        simple_input_tsdf: TSDF = self.get_test_df_builder("non_numeric_init").as_tsdf()
+
+        self.assertRaises(
+            ValueError,
+            self.interpolate_helper.interpolate,
+            simple_input_tsdf,
+            freq="30 seconds", func="ceil", method="linear", ts_col="event_ts",
+            partition_cols=["partition_a", "partition_b"], target_cols=["string_col", "timestamp_col"],
+            show_interpolated=False
+        )
+
+    def test_non_numeric_zero(self):
+        """Verify that zero interpolation is prohibited for non-numeric columns."""
+
+        # load test data
+        simple_input_tsdf: TSDF = self.get_test_df_builder("non_numeric_init").as_tsdf()
+
+        self.assertRaises(
+            ValueError,
+            self.interpolate_helper.interpolate,
+            simple_input_tsdf,
+            freq="30 seconds", func="ceil", method="zero", ts_col="event_ts",
+            partition_cols=["partition_a", "partition_b"], target_cols=["string_col", "timestamp_col"],
+            show_interpolated=False
+        )
+

 class InterpolationIntegrationTest(SparkTest):
     def test_interpolation_using_default_tsdf_params(self):