Skip to content

feat: add derivation expression evaluator #63

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .devcontainer/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Dev container image: Python 3.10 plus a JDK (installed through SDKMAN) and
# the ANTLR tool jar that the Makefile's `antlr` target runs via ${ANTLR_JAR}.
FROM mcr.microsoft.com/vscode/devcontainers/python:3.10-buster
USER vscode
# Install SDKMAN for the vscode user.
RUN curl -s "https://get.sdkman.io" | bash
# `source` is a bash builtin, so the remaining RUN steps need bash rather than sh.
SHELL ["/bin/bash", "-c"]
RUN source "/home/vscode/.sdkman/bin/sdkman-init.sh" && sdk install java 20.0.2-graalce
# Download over HTTPS; -f makes curl fail the build on an HTTP error instead of
# saving the error page under the jar's name.
RUN mkdir -p /home/vscode/lib && cd /home/vscode/lib && curl -fL -O https://www.antlr.org/download/antlr-4.13.1-complete.jar
# Absolute path on purpose: ENV stores "~" literally (Docker performs no tilde
# expansion), so the previous value only worked when an unquoted shell word
# happened to expand it, and then relative to whichever user was running.
ENV ANTLR_JAR="/home/vscode/lib/antlr-4.13.1-complete.jar"
USER root
24 changes: 24 additions & 0 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
{
"name": "substrait-python-devcontainer",
"build": {
"context": "..",
"dockerfile": "Dockerfile"
},

// Features to add to the dev container. More info: https://containers.dev/features.
// "features": {
// "ghcr.io/devcontainers/features/nix:1": {}
// },

// Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [],

// Use 'postCreateCommand' to run commands after the container is created.
// "postCreateCommand": "poetry install"

// Configure tool-specific properties.
// "customizations": {},

// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "root"
}
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Regenerate the Python lexer/parser for SubstraitType.g4 into
# src/substrait/gen/antlr. ANTLR_JAR must point at the ANTLR "complete" jar.
# Declared phony so that a file or directory named `antlr` can never make the
# target look up to date.
.PHONY: antlr
antlr:
	java -jar ${ANTLR_JAR} -o src/substrait/gen/antlr -Dlanguage=Python3 SubstraitType.g4
209 changes: 209 additions & 0 deletions SubstraitType.g4
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
// Combined lexer/parser grammar for Substrait type expressions and the
// derivation expressions that compute a result type from parameters.
grammar SubstraitType;

// Single-letter fragments that match either case. Keyword tokens are spelled
// letter by letter out of these fragments, which is what makes keywords
// case-insensitive.
fragment A : [aA];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When we move this into core we should use:

options {
    caseInsensitive = true;
}

so we can eliminate all these fragments.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can open an issue in substrait-java for this..

// Remaining case-insensitive letter fragments (B through Z). Fragments never
// produce tokens on their own; they are only building blocks for other lexer
// rules.
fragment B : [bB];
fragment C : [cC];
fragment D : [dD];
fragment E : [eE];
fragment F : [fF];
fragment G : [gG];
fragment H : [hH];
fragment I : [iI];
fragment J : [jJ];
fragment K : [kK];
fragment L : [lL];
fragment M : [mM];
fragment N : [nN];
fragment O : [oO];
fragment P : [pP];
fragment Q : [qQ];
fragment R : [rR];
fragment S : [sS];
fragment T : [tT];
fragment U : [uU];
fragment V : [vV];
fragment W : [wW];
fragment X : [xX];
fragment Y : [yY];
fragment Z : [zZ];


// Keywords of the `if ... then ... else ...` form of `expr`.
If : I F;
Then : T H E N;
Else : E L S E;

// TYPES
// One token per type name. All of these are declared before `Identifier`, so
// when a type name and `Identifier` match the same text, the type token wins
// (ANTLR prefers the rule declared first on equal-length matches).
Boolean : B O O L E A N;
I8 : I '8';
I16 : I '16';
I32 : I '32';
I64 : I '64';
FP32 : F P '32';
FP64 : F P '64';
String : S T R I N G;
Binary : B I N A R Y;
Timestamp: T I M E S T A M P;
TimestampTZ: T I M E S T A M P '_' T Z;
Date : D A T E;
Time : T I M E;
IntervalYear: I N T E R V A L '_' Y E A R;
IntervalDay: I N T E R V A L '_' D A Y;
IntervalCompound: I N T E R V A L '_' C O M P O U N D;
UUID : U U I D;
Decimal : D E C I M A L;
PrecisionTimestamp: P R E C I S I O N '_' T I M E S T A M P;
PrecisionTimestampTZ: P R E C I S I O N '_' T I M E S T A M P '_' T Z;
FixedChar: F I X E D C H A R;
VarChar : V A R C H A R;
FixedBinary: F I X E D B I N A R Y;
Struct : S T R U C T;
NStruct : N S T R U C T;
List : L I S T;
Map : M A P;
ANY : A N Y;
// Prefix of a user-defined type: `u!` immediately followed by an Identifier
// (see the `userDefined` alternative of `scalarType`).
UserDefined: U '!';


// OPERATIONS
// NOTE(review): `Assign` is not referenced by any parser rule in this grammar;
// `MultilineDefinition` uses `Eq` for assignment.
And : A N D;
Or : O R;
Assign : ':=';

// COMPARE
// `Lt` and `Gt` double as the angle brackets around type parameters in
// `parameterizedType`.
Eq : '=';
NotEquals: '!=';
Gte : '>=';
Lte : '<=';
Gt : '>';
Lt : '<';
Bang : '!';


// MATH
// NOTE(review): `Percent` is tokenised but is not one of the operators
// accepted by the `BinaryExpr` alternative of `expr`.
Plus : '+';
Minus : '-';
Asterisk : '*';
ForwardSlash : '/';
Percent : '%';

// ORGANIZE
// NOTE(review): OBracket, CBracket, SColon and SingleQuote are not referenced
// by any parser rule in this grammar.
OBracket : '[';
CBracket : ']';
OParen : '(';
CParen : ')';
SColon : ';';
Comma : ',';
QMark : '?';
Colon : ':';
SingleQuote: '\'';


// Integer literal with an optional leading minus sign.
// NOTE(review): because the sign is part of the token and the lexer prefers
// the longest match, `-5` is always a single Number. Input such as `P-5`
// therefore lexes as Identifier Number rather than Identifier Minus Number.
Number
: '-'? Int
;

// Parameter and function names: a letter, `_` or `$`, followed by any number
// of letters, digits, `_` or `$`.
Identifier
: ('a'..'z' | 'A'..'Z' | '_' | '$') ('a'..'z' | 'A'..'Z' | '_' | '$' | Digit)*
;

// Line comment running to the end of the line; kept off the parser's channel.
LineComment
: '//' ~[\r\n]* -> channel(HIDDEN)
;

// Block comment; the recursive reference to BlockComment allows nesting.
// Kept off the parser's channel.
BlockComment
: ( '/*'
( '/'* BlockComment
| ~[/*]
| '/'+ ~[/*]
| '*'+ ~[/*]
)*
'*'*
'*/'
) -> channel(HIDDEN)
;

// Spaces and tabs are hidden from the parser.
Whitespace
: [ \t]+ -> channel(HIDDEN)
;

// Line breaks are NOT hidden: `MultilineDefinition` in `expr` uses Newline to
// separate assignments, so the parser sees every line break.
Newline
: ( '\r' '\n'?
| '\n'
)
;


// Non-negative integer without leading zeros (`0` itself is allowed).
fragment Int
: '1'..'9' Digit*
| '0'
;

// A single decimal digit.
fragment Digit
: '0'..'9'
;

// Entry rule that requires the whole input to be consumed (`EOF`).
start: expr EOF;

// Types that take no parameters. Each alternative is labelled, so the
// generated parser has one context class per label.
scalarType
: Boolean #Boolean
| I8 #i8
| I16 #i16
| I32 #i32
| I64 #i64
| FP32 #fp32
| FP64 #fp64
| String #string
| Binary #binary
| Timestamp #timestamp
| TimestampTZ #timestampTz
| Date #date
| Time #time
| IntervalYear #intervalYear
| UUID #uuid
| UserDefined Identifier #userDefined
;

// Types that take parameters inside `<...>`. The optional `?` nullability
// marker sits between the type name and the opening `<`, and is captured in
// the `isnull` label of each alternative.
parameterizedType
: FixedChar isnull='?'? Lt len=numericParameter Gt #fixedChar
| VarChar isnull='?'? Lt len=numericParameter Gt #varChar
| FixedBinary isnull='?'? Lt len=numericParameter Gt #fixedBinary
| Decimal isnull='?'? Lt precision=numericParameter Comma scale=numericParameter Gt #decimal
| IntervalDay isnull='?'? Lt precision=numericParameter Gt #intervalDay
| IntervalCompound isnull='?'? Lt precision=numericParameter Gt #intervalCompound
| PrecisionTimestamp isnull='?'? Lt precision=numericParameter Gt #precisionTimestamp
| PrecisionTimestampTZ isnull='?'? Lt precision=numericParameter Gt #precisionTimestampTZ
| Struct isnull='?'? Lt expr (Comma expr)* Gt #struct
| NStruct isnull='?'? Lt Identifier expr (Comma Identifier expr)* Gt #nStruct
| List isnull='?'? Lt expr Gt #list
| Map isnull='?'? Lt key=expr Comma value=expr Gt #map
;

// A numeric type parameter: a literal, a bare parameter name, or an arbitrary
// expression.
// NOTE(review): a lone Number or Identifier also matches `expr` (via
// LiteralNumber / TypeParam), so the alternatives overlap; presumably the
// first two win because they are listed first - confirm against generated
// parser behaviour.
numericParameter
: Number #numericLiteral
| Identifier #numericParameterName
| expr #numericExpression
;

// The `any` wildcard type.
anyType: ANY;

// Any type. Scalar and `any` types take their optional `?` nullability marker
// here; parameterized types carry their own `isnull` marker instead.
type
: scalarType isnull='?'?
| parameterizedType
| anyType isnull='?'?
;

// Derivation expression. This rule is left-recursive, so ANTLR derives
// operator precedence from the order of the alternatives.
//
// NOTE(review): every binary operator lives in the single `BinaryExpr`
// alternative, so they all share one precedence level and associate to the
// left: `1 + 2 * 3` parses as `(1 + 2) * 3`. Use parentheses where
// conventional arithmetic precedence is required.
//
// `MultilineDefinition` is a sequence of `name = expr` lines, each terminated
// by at least one Newline, followed by the final result type.
//
// Stale draft alternative kept from the original source:
// : (OParen innerExpr CParen | innerExpr)

expr
: OParen expr CParen #ParenExpression
| Identifier Eq expr Newline+ (Identifier Eq expr Newline+)* finalType=type Newline* #MultilineDefinition
| type #TypeLiteral
| number=Number #LiteralNumber
| identifier=Identifier isnull='?'? #TypeParam
| Identifier OParen (expr (Comma expr)*)? CParen #FunctionCall
| left=expr op=(And | Or | Plus | Minus | Lt | Gt | Eq | NotEquals | Lte | Gte | Asterisk | ForwardSlash) right=expr #BinaryExpr
| If ifExpr=expr Then thenExpr=expr Else elseExpr=expr #IfExpr
| (Bang) expr #NotExpr
| ifExpr=expr QMark thenExpr=expr Colon elseExpr=expr #Ternary
;
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ dynamic = ["version"]
write_to = "src/substrait/_version.py"

[project.optional-dependencies]
extensions = ["antlr4-python3-runtime"]
gen_proto = ["protobuf == 3.20.1", "protoletariat >= 2.0.0"]
test = ["pytest >= 7.0.0"]
test = ["pytest >= 7.0.0", "antlr4-python3-runtime"]

[tool.pytest.ini_options]
pythonpath = "src"
Expand Down
102 changes: 102 additions & 0 deletions src/substrait/derivation_expression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from typing import Optional
from antlr4 import InputStream, CommonTokenStream
from substrait.gen.antlr.SubstraitTypeLexer import SubstraitTypeLexer
from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser
from substrait.gen.proto.type_pb2 import Type


# Binary operators understood by the evaluator, keyed by lower-cased token
# text (the `and` / `or` keywords are case-insensitive in the grammar).
# NOTE(review): `/` is accepted by the grammar but deliberately not mapped
# here because the intended division semantics (integer vs. true division)
# are not established by this module.
_BINARY_OPS = {
    "+": lambda a, b: a + b,
    "-": lambda a, b: a - b,
    "*": lambda a, b: a * b,
    ">": lambda a, b: a > b,
    ">=": lambda a, b: a >= b,
    "<": lambda a, b: a < b,
    "<=": lambda a, b: a <= b,
    "=": lambda a, b: a == b,
    "!=": lambda a, b: a != b,
    "and": lambda a, b: bool(a) and bool(b),
    "or": lambda a, b: bool(a) or bool(b),
}

# Name of the parser context class generated for each labelled `scalarType`
# alternative -> (field of `Type` to populate, nested `Type` message name).
# Looked up by name so that one unexpected entry cannot break the others.
_SCALAR_TYPES = {
    "BooleanContext": ("bool", "Boolean"),
    "I8Context": ("i8", "I8"),
    "I16Context": ("i16", "I16"),
    "I32Context": ("i32", "I32"),
    "I64Context": ("i64", "I64"),
    "Fp32Context": ("fp32", "FP32"),
    "Fp64Context": ("fp64", "FP64"),
    "StringContext": ("string", "String"),
    "BinaryContext": ("binary", "Binary"),
    "TimestampContext": ("timestamp", "Timestamp"),
    "TimestampTzContext": ("timestamp_tz", "TimestampTZ"),
    "DateContext": ("date", "Date"),
    "TimeContext": ("time", "Time"),
    "IntervalYearContext": ("interval_year", "IntervalYear"),
    "UuidContext": ("uuid", "UUID"),
}


def _nullability_kwargs(isnull) -> dict:
    """Return the proto kwargs that mark a type nullable when `?` was parsed.

    Nothing is set when the marker is absent, so types written without `?`
    produce exactly the same message as before.
    """
    if isnull is None:
        return {}
    return {"nullability": Type.NULLABILITY_NULLABLE}


def _evaluate_type(x, values: dict):
    """Build the `Type` message described by a `type` parse-tree node."""
    scalar_type = x.scalarType()
    parametrized_type = x.parameterizedType()

    if scalar_type is not None:
        entry = _SCALAR_TYPES.get(type(scalar_type).__name__)
        if entry is None:
            raise Exception(f"Unknown scalar type {type(scalar_type)}")
        field, message = entry
        # For scalar types the `?` marker is captured on the enclosing `type` node.
        return Type(**{field: getattr(Type, message)(**_nullability_kwargs(x.isnull))})

    if parametrized_type is not None:
        if isinstance(parametrized_type, SubstraitTypeParser.DecimalContext):
            precision = _evaluate(parametrized_type.precision, values)
            scale = _evaluate(parametrized_type.scale, values)
            # Parameterized types carry their own `?` marker.
            return Type(
                decimal=Type.Decimal(
                    precision=precision,
                    scale=scale,
                    **_nullability_kwargs(parametrized_type.isnull),
                )
            )
        raise Exception(f"Unknown parametrized type {type(parametrized_type)}")

    raise Exception("either scalar_type or parametrized_type is required")


def _evaluate(x, values: dict):
    """Recursively evaluate a parse-tree node of a derivation expression.

    Args:
        x: A context object produced by ``SubstraitTypeParser``.
        values: Parameter name -> value bindings. Multi-line definitions add
            their intermediate assignments to this dict in place.

    Returns:
        An ``int`` or ``bool`` for arithmetic, comparison and logical
        expressions, or a ``Type`` message for type expressions.

    Raises:
        KeyError: If an identifier has no binding in ``values``.
        Exception: If the node, operator, function or type is not supported.
    """
    parser = SubstraitTypeParser

    if isinstance(x, parser.BinaryExprContext):
        left = _evaluate(x.left, values)
        right = _evaluate(x.right, values)
        operation = _BINARY_OPS.get(x.op.text.lower())
        if operation is None:
            raise Exception(f"Unknown binary op {x.op.text}")
        return operation(left, right)

    if isinstance(x, parser.LiteralNumberContext):
        return int(x.number.text)

    if isinstance(x, parser.TypeParamContext):
        return values[x.identifier.text]

    if isinstance(x, parser.NumericParameterNameContext):
        return values[x.Identifier().symbol.text]

    # Both of these simply wrap a single inner expression.
    if isinstance(x, (parser.ParenExpressionContext, parser.NumericExpressionContext)):
        return _evaluate(x.expr(), values)

    if isinstance(x, parser.FunctionCallContext):
        args = [_evaluate(e, values) for e in x.expr()]
        func = x.Identifier().symbol.text

        # The list is passed as a single iterable: unpacking it (`min(*args)`)
        # raises TypeError when the call has exactly one argument.
        if func == "min":
            return min(args)
        if func == "max":
            return max(args)
        raise Exception(f"Unknown function {func}")

    if isinstance(x, parser.TypeContext):
        return _evaluate_type(x, values)

    # `cond ? a : b` and `if cond then a else b` share the same labels.
    if isinstance(x, (parser.TernaryContext, parser.IfExprContext)):
        # Only the selected branch is evaluated, so an error in the branch
        # that is not taken cannot fail the whole expression.
        chosen = x.thenExpr if _evaluate(x.ifExpr, values) else x.elseExpr
        return _evaluate(chosen, values)

    if isinstance(x, parser.NotExprContext):
        return not _evaluate(x.expr(), values)

    if isinstance(x, parser.MultilineDefinitionContext):
        # Identifiers and expressions pair up positionally; each assignment is
        # visible to the lines after it and to the final type.
        for name_node, expression in zip(x.Identifier(), x.expr()):
            values[name_node.symbol.text] = _evaluate(expression, values)
        return _evaluate(x.finalType, values)

    if isinstance(x, parser.TypeLiteralContext):
        return _evaluate(x.type_(), values)

    raise Exception(f"Unknown token type {type(x)}")


def evaluate(x: str, values: Optional[dict] = None):
    """Parse and evaluate a Substrait derivation expression.

    Args:
        x: Source text of the expression, e.g. ``"decimal<P + 1, S>"``.
        values: Optional parameter name -> value bindings. The mapping is
            copied, so assignments made by multi-line definitions never leak
            into the caller's dict.

    Returns:
        Whatever the expression evaluates to: a number, a boolean or a
        ``Type`` message.
    """
    lexer = SubstraitTypeLexer(InputStream(x))
    stream = CommonTokenStream(lexer)
    parser = SubstraitTypeParser(stream)
    # The default of None used to be passed straight through, which failed
    # with TypeError on the first identifier lookup or assignment.
    bindings = dict(values) if values is not None else {}
    return _evaluate(parser.expr(), bindings)
Loading
Loading