Skip to content

Commit 7f32288

Browse files
fix: data type for static schema (#1235)
if string parse able to int, consider it valid if string parse able to float, consider it valid
1 parent 7d9b9ab commit 7f32288

File tree

4 files changed

+114
-48
lines changed

4 files changed

+114
-48
lines changed

src/event/format/json.rs

+105-46
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use anyhow::anyhow;
2323
use arrow_array::RecordBatch;
2424
use arrow_json::reader::{infer_json_schema_from_iterator, ReaderBuilder};
2525
use arrow_schema::{DataType, Field, Fields, Schema};
26-
use chrono::{DateTime, NaiveDateTime, Utc};
26+
use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc};
2727
use datafusion::arrow::util::bit_util::round_upto_multiple_of_64;
2828
use itertools::Itertools;
2929
use serde_json::Value;
@@ -62,6 +62,7 @@ impl EventFormat for Event {
6262
schema: &HashMap<String, Arc<Field>>,
6363
time_partition: Option<&String>,
6464
schema_version: SchemaVersion,
65+
static_schema_flag: bool,
6566
) -> Result<(Self::Data, Vec<Arc<Field>>, bool), anyhow::Error> {
6667
let stream_schema = schema;
6768

@@ -111,7 +112,7 @@ impl EventFormat for Event {
111112

112113
if value_arr
113114
.iter()
114-
.any(|value| fields_mismatch(&schema, value, schema_version))
115+
.any(|value| fields_mismatch(&schema, value, schema_version, static_schema_flag))
115116
{
116117
return Err(anyhow!(
117118
"Could not process this event due to mismatch in datatype"
@@ -253,73 +254,131 @@ fn collect_keys<'a>(values: impl Iterator<Item = &'a Value>) -> Result<Vec<&'a s
253254
Ok(keys)
254255
}
255256

256-
fn fields_mismatch(schema: &[Arc<Field>], body: &Value, schema_version: SchemaVersion) -> bool {
257+
fn fields_mismatch(
258+
schema: &[Arc<Field>],
259+
body: &Value,
260+
schema_version: SchemaVersion,
261+
static_schema_flag: bool,
262+
) -> bool {
257263
for (name, val) in body.as_object().expect("body is of object variant") {
258264
if val.is_null() {
259265
continue;
260266
}
261267
let Some(field) = get_field(schema, name) else {
262268
return true;
263269
};
264-
if !valid_type(field.data_type(), val, schema_version) {
270+
if !valid_type(field, val, schema_version, static_schema_flag) {
265271
return true;
266272
}
267273
}
268274
false
269275
}
270276

271-
fn valid_type(data_type: &DataType, value: &Value, schema_version: SchemaVersion) -> bool {
272-
match data_type {
277+
fn valid_type(
278+
field: &Field,
279+
value: &Value,
280+
schema_version: SchemaVersion,
281+
static_schema_flag: bool,
282+
) -> bool {
283+
match field.data_type() {
273284
DataType::Boolean => value.is_boolean(),
274-
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => value.is_i64(),
285+
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
286+
validate_int(value, static_schema_flag)
287+
}
275288
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => value.is_u64(),
276289
DataType::Float16 | DataType::Float32 => value.is_f64(),
277-
// All numbers can be cast as Float64 from schema version v1
278-
DataType::Float64 if schema_version == SchemaVersion::V1 => value.is_number(),
279-
DataType::Float64 if schema_version != SchemaVersion::V1 => value.is_f64(),
290+
DataType::Float64 => validate_float(value, schema_version, static_schema_flag),
280291
DataType::Utf8 => value.is_string(),
281-
DataType::List(field) => {
282-
let data_type = field.data_type();
283-
if let Value::Array(arr) = value {
284-
for elem in arr {
285-
if elem.is_null() {
286-
continue;
287-
}
288-
if !valid_type(data_type, elem, schema_version) {
289-
return false;
290-
}
291-
}
292-
}
293-
true
294-
}
292+
DataType::List(field) => validate_list(field, value, schema_version, static_schema_flag),
295293
DataType::Struct(fields) => {
296-
if let Value::Object(val) = value {
297-
for (key, value) in val {
298-
let field = (0..fields.len())
299-
.find(|idx| fields[*idx].name() == key)
300-
.map(|idx| &fields[idx]);
301-
302-
if let Some(field) = field {
303-
if value.is_null() {
304-
continue;
305-
}
306-
if !valid_type(field.data_type(), value, schema_version) {
307-
return false;
308-
}
309-
} else {
310-
return false;
311-
}
312-
}
313-
true
314-
} else {
315-
false
294+
validate_struct(fields, value, schema_version, static_schema_flag)
295+
}
296+
DataType::Date32 => {
297+
if let Value::String(s) = value {
298+
return NaiveDate::parse_from_str(s, "%Y-%m-%d").is_ok();
316299
}
300+
false
317301
}
318302
DataType::Timestamp(_, _) => value.is_string() || value.is_number(),
319303
_ => {
320-
error!("Unsupported datatype {:?}, value {:?}", data_type, value);
321-
unreachable!()
304+
error!(
305+
"Unsupported datatype {:?}, value {:?}",
306+
field.data_type(),
307+
value
308+
);
309+
false
310+
}
311+
}
312+
}
313+
314+
fn validate_int(value: &Value, static_schema_flag: bool) -> bool {
315+
// allow casting string to int for static schema
316+
if static_schema_flag {
317+
if let Value::String(s) = value {
318+
return s.trim().parse::<i64>().is_ok();
319+
}
320+
}
321+
value.is_i64()
322+
}
323+
324+
fn validate_float(value: &Value, schema_version: SchemaVersion, static_schema_flag: bool) -> bool {
325+
// allow casting string to int for static schema
326+
if static_schema_flag {
327+
if let Value::String(s) = value.clone() {
328+
let trimmed = s.trim();
329+
return trimmed.parse::<f64>().is_ok() || trimmed.parse::<i64>().is_ok();
330+
}
331+
return value.is_number();
332+
}
333+
match schema_version {
334+
SchemaVersion::V1 => value.is_number(),
335+
_ => value.is_f64(),
336+
}
337+
}
338+
339+
fn validate_list(
340+
field: &Field,
341+
value: &Value,
342+
schema_version: SchemaVersion,
343+
static_schema_flag: bool,
344+
) -> bool {
345+
if let Value::Array(arr) = value {
346+
for elem in arr {
347+
if elem.is_null() {
348+
continue;
349+
}
350+
if !valid_type(field, elem, schema_version, static_schema_flag) {
351+
return false;
352+
}
353+
}
354+
}
355+
true
356+
}
357+
358+
fn validate_struct(
359+
fields: &Fields,
360+
value: &Value,
361+
schema_version: SchemaVersion,
362+
static_schema_flag: bool,
363+
) -> bool {
364+
if let Value::Object(val) = value {
365+
for (key, value) in val {
366+
let field = fields.iter().find(|f| f.name() == key);
367+
368+
if let Some(field) = field {
369+
if value.is_null() {
370+
continue;
371+
}
372+
if !valid_type(field, value, schema_version, static_schema_flag) {
373+
return false;
374+
}
375+
} else {
376+
return false;
377+
}
322378
}
379+
true
380+
} else {
381+
false
323382
}
324383
}
325384

src/event/format/mod.rs

+7-2
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ pub trait EventFormat: Sized {
102102
schema: &HashMap<String, Arc<Field>>,
103103
time_partition: Option<&String>,
104104
schema_version: SchemaVersion,
105+
static_schema_flag: bool,
105106
) -> Result<(Self::Data, EventSchema, bool), AnyError>;
106107

107108
fn decode(data: Self::Data, schema: Arc<Schema>) -> Result<RecordBatch, AnyError>;
@@ -117,8 +118,12 @@ pub trait EventFormat: Sized {
117118
schema_version: SchemaVersion,
118119
) -> Result<(RecordBatch, bool), AnyError> {
119120
let p_timestamp = self.get_p_timestamp();
120-
let (data, mut schema, is_first) =
121-
self.to_data(storage_schema, time_partition, schema_version)?;
121+
let (data, mut schema, is_first) = self.to_data(
122+
storage_schema,
123+
time_partition,
124+
schema_version,
125+
static_schema_flag,
126+
)?;
122127

123128
if get_field(&schema, DEFAULT_TIMESTAMP_KEY).is_some() {
124129
return Err(anyhow!(

src/query/stream_schema_provider.rs

+1
Original file line numberDiff line numberDiff line change
@@ -967,6 +967,7 @@ fn cast_or_none(scalar: &ScalarValue) -> Option<CastRes<'_>> {
967967
ScalarValue::UInt32(val) => val.map(|val| CastRes::Int(val as i64)),
968968
ScalarValue::UInt64(val) => val.map(|val| CastRes::Int(val as i64)),
969969
ScalarValue::Utf8(val) => val.as_ref().map(|val| CastRes::String(val)),
970+
ScalarValue::Date32(val) => val.map(|val| CastRes::Int(val as i64)),
970971
ScalarValue::TimestampMillisecond(val, _) => val.map(CastRes::Int),
971972
_ => None,
972973
}

src/static_schema.rs

+1
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ pub fn convert_static_schema_to_arrow_schema(
111111
"boolean" => DataType::Boolean,
112112
"string" => DataType::Utf8,
113113
"datetime" => DataType::Timestamp(TimeUnit::Millisecond, None),
114+
"date" => DataType::Date32,
114115
"string_list" => {
115116
DataType::List(Arc::new(Field::new("item", DataType::Utf8, true)))
116117
}

0 commit comments

Comments
 (0)