Skip to content

Commit 60e1d64

Browse files
committed
engine: support for some array builtins, like SUM
1 parent 84f106b commit 60e1d64

File tree

9 files changed

+340
-99
lines changed

9 files changed

+340
-99
lines changed

src/engine/error_codes.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ export function errorCodeDescription(code: ErrorCode): string {
3535
case ErrorCode.UnknownBuiltin:
3636
return 'Reference to unknown or unimplemented builtin';
3737
case ErrorCode.BadBuiltinArgs:
38-
return 'Builtin function arguments';
38+
return 'Incorrect arguments to a builtin function (e.g. too many, too few)';
3939
case ErrorCode.EmptyEquation:
4040
return 'Variable has empty equation';
4141
case ErrorCode.BadModuleInputDst:

src/simlin-engine/src/ast.rs

+54-4
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,18 @@ impl Expr {
329329
let a = args.remove(0);
330330
BuiltinFn::$builtin_fn(Box::new(a), Box::new(b))
331331
}};
332+
($builtin_fn:tt, 1, 2) => {{
333+
if args.len() == 1 {
334+
let a = args.remove(0);
335+
BuiltinFn::$builtin_fn(Box::new(a), None)
336+
} else if args.len() == 2 {
337+
let b = args.remove(1);
338+
let a = args.remove(0);
339+
BuiltinFn::$builtin_fn(Box::new(a), Some(Box::new(b)))
340+
} else {
341+
return eqn_err!(BadBuiltinArgs, loc.start, loc.end);
342+
}
343+
}};
332344
($builtin_fn:tt, 3) => {{
333345
if args.len() != 3 {
334346
return eqn_err!(BadBuiltinArgs, loc.start, loc.end);
@@ -339,6 +351,26 @@ impl Expr {
339351
let a = args.remove(0);
340352
BuiltinFn::$builtin_fn(Box::new(a), Box::new(b), Box::new(c))
341353
}};
354+
($builtin_fn:tt, 1, 3) => {{
355+
if args.len() == 1 {
356+
let a = args.remove(0);
357+
BuiltinFn::$builtin_fn(Box::new(a), None)
358+
} else if args.len() == 2 {
359+
let b = args.remove(1);
360+
let a = args.remove(0);
361+
BuiltinFn::$builtin_fn(Box::new(a), Some((Box::new(b), None)))
362+
} else if args.len() == 3 {
363+
let c = args.remove(2);
364+
let b = args.remove(1);
365+
let a = args.remove(0);
366+
BuiltinFn::$builtin_fn(
367+
Box::new(a),
368+
Some((Box::new(b), Some(Box::new(c)))),
369+
)
370+
} else {
371+
return eqn_err!(BadBuiltinArgs, loc.start, loc.end);
372+
}
373+
}};
342374
($builtin_fn:tt, 2, 3) => {{
343375
if args.len() == 2 {
344376
let b = args.remove(1);
@@ -381,8 +413,8 @@ impl Expr {
381413
}
382414
"ln" => check_arity!(Ln, 1),
383415
"log10" => check_arity!(Log10, 1),
384-
"max" => check_arity!(Max, 2),
385-
"min" => check_arity!(Min, 2),
416+
"max" => check_arity!(Max, 1, 2),
417+
"min" => check_arity!(Min, 1, 2),
386418
"pi" => check_arity!(Pi, 0),
387419
"pulse" => check_arity!(Pulse, 2, 3),
388420
"ramp" => check_arity!(Ramp, 2, 3),
@@ -395,6 +427,10 @@ impl Expr {
395427
"time_step" | "dt" => check_arity!(TimeStep, 0),
396428
"initial_time" => check_arity!(StartTime, 0),
397429
"final_time" => check_arity!(FinalTime, 0),
430+
"rank" => check_arity!(Rank, 1, 3),
431+
"size" => check_arity!(Size, 1),
432+
"stddev" => check_arity!(Stddev, 1),
433+
"sum" => check_arity!(Sum, 1),
398434
_ => {
399435
// TODO: this could be a table reference, array reference,
400436
// or module instantiation according to 3.3.2 of the spec
@@ -468,11 +504,11 @@ impl Expr {
468504
),
469505
BuiltinFn::Max(a, b) => BuiltinFn::Max(
470506
Box::new(a.constify_dimensions(scope)),
471-
Box::new(b.constify_dimensions(scope)),
507+
b.map(|expr| Box::new(expr.constify_dimensions(scope))),
472508
),
473509
BuiltinFn::Min(a, b) => BuiltinFn::Min(
474510
Box::new(a.constify_dimensions(scope)),
475-
Box::new(b.constify_dimensions(scope)),
511+
b.map(|expr| Box::new(expr.constify_dimensions(scope))),
476512
),
477513
BuiltinFn::Step(a, b) => BuiltinFn::Step(
478514
Box::new(a.constify_dimensions(scope)),
@@ -497,6 +533,20 @@ impl Expr {
497533
Box::new(b.constify_dimensions(scope)),
498534
c.map(|arg| Box::new(arg.constify_dimensions(scope))),
499535
),
536+
BuiltinFn::Rank(a, rest) => BuiltinFn::Rank(
537+
Box::new(a.constify_dimensions(scope)),
538+
rest.map(|(b, c)| {
539+
(
540+
Box::new(b.constify_dimensions(scope)),
541+
c.map(|c| Box::new(c.constify_dimensions(scope))),
542+
)
543+
}),
544+
),
545+
BuiltinFn::Size(a) => BuiltinFn::Size(Box::new(a.constify_dimensions(scope))),
546+
BuiltinFn::Stddev(a) => {
547+
BuiltinFn::Stddev(Box::new(a.constify_dimensions(scope)))
548+
}
549+
BuiltinFn::Sum(a) => BuiltinFn::Sum(Box::new(a.constify_dimensions(scope))),
500550
};
501551
Expr::App(func, loc)
502552
}

src/simlin-engine/src/builtins.rs

+92-51
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,11 @@ pub enum BuiltinFn<Expr> {
6565
IsModuleInput(String, Loc),
6666
Ln(Box<Expr>),
6767
Log10(Box<Expr>),
68-
Max(Box<Expr>, Box<Expr>),
68+
// max takes 2 scalar args OR 1-2 args for an array
69+
Max(Box<Expr>, Option<Box<Expr>>),
6970
Mean(Vec<Expr>),
70-
Min(Box<Expr>, Box<Expr>),
71+
// max takes 2 scalar args OR 1-2 args for an array
72+
Min(Box<Expr>, Option<Box<Expr>>),
7173
Pi,
7274
Pulse(Box<Expr>, Box<Expr>, Option<Box<Expr>>),
7375
Ramp(Box<Expr>, Box<Expr>, Option<Box<Expr>>),
@@ -80,38 +82,49 @@ pub enum BuiltinFn<Expr> {
8082
TimeStep,
8183
StartTime,
8284
FinalTime,
85+
// array-only builtins
86+
Rank(Box<Expr>, Option<(Box<Expr>, Option<Box<Expr>>)>),
87+
Size(Box<Expr>),
88+
Stddev(Box<Expr>),
89+
Sum(Box<Expr>),
8390
}
8491

8592
impl<Expr> BuiltinFn<Expr> {
8693
pub fn name(&self) -> &'static str {
94+
use BuiltinFn::*;
8795
match self {
88-
BuiltinFn::Lookup(_, _, _) => "lookup",
89-
BuiltinFn::Abs(_) => "abs",
90-
BuiltinFn::Arccos(_) => "arccos",
91-
BuiltinFn::Arcsin(_) => "arcsin",
92-
BuiltinFn::Arctan(_) => "arctan",
93-
BuiltinFn::Cos(_) => "cos",
94-
BuiltinFn::Exp(_) => "exp",
95-
BuiltinFn::Inf => "inf",
96-
BuiltinFn::Int(_) => "int",
97-
BuiltinFn::IsModuleInput(_, _) => "ismoduleinput",
98-
BuiltinFn::Ln(_) => "ln",
99-
BuiltinFn::Log10(_) => "log10",
100-
BuiltinFn::Max(_, _) => "max",
101-
BuiltinFn::Mean(_) => "mean",
102-
BuiltinFn::Min(_, _) => "min",
103-
BuiltinFn::Pi => "pi",
104-
BuiltinFn::Pulse(_, _, _) => "pulse",
105-
BuiltinFn::Ramp(_, _, _) => "ramp",
106-
BuiltinFn::SafeDiv(_, _, _) => "safediv",
107-
BuiltinFn::Sin(_) => "sin",
108-
BuiltinFn::Sqrt(_) => "sqrt",
109-
BuiltinFn::Step(_, _) => "step",
110-
BuiltinFn::Tan(_) => "tan",
111-
BuiltinFn::Time => "time",
112-
BuiltinFn::TimeStep => "time_step",
113-
BuiltinFn::StartTime => "initial_time",
114-
BuiltinFn::FinalTime => "final_time",
96+
Lookup(_, _, _) => "lookup",
97+
Abs(_) => "abs",
98+
Arccos(_) => "arccos",
99+
Arcsin(_) => "arcsin",
100+
Arctan(_) => "arctan",
101+
Cos(_) => "cos",
102+
Exp(_) => "exp",
103+
Inf => "inf",
104+
Int(_) => "int",
105+
IsModuleInput(_, _) => "ismoduleinput",
106+
Ln(_) => "ln",
107+
Log10(_) => "log10",
108+
Max(_, _) => "max",
109+
Mean(_) => "mean",
110+
Min(_, _) => "min",
111+
Pi => "pi",
112+
Pulse(_, _, _) => "pulse",
113+
Ramp(_, _, _) => "ramp",
114+
SafeDiv(_, _, _) => "safediv",
115+
Sin(_) => "sin",
116+
Sqrt(_) => "sqrt",
117+
Step(_, _) => "step",
118+
Tan(_) => "tan",
119+
Time => "time",
120+
TimeStep => "time_step",
121+
StartTime => "initial_time",
122+
FinalTime => "final_time",
123+
// array only builtins
124+
Rank(_, _) => "rank",
125+
Size(_) => "size",
126+
Stddev(_) => "stddev",
127+
Sum(_) => "sum",
115128
}
116129
}
117130
}
@@ -127,27 +140,33 @@ pub fn is_builtin_fn(name: &str) -> bool {
127140
is_0_arity_builtin_fn(name)
128141
|| matches!(
129142
name,
143+
// scalar builtins
130144
"lookup"
131-
| "abs"
132-
| "arccos"
133-
| "arcsin"
134-
| "arctan"
135-
| "cos"
136-
| "exp"
137-
| "int"
138-
| "ismoduleinput"
139-
| "ln"
140-
| "log10"
141-
| "max"
142-
| "mean"
143-
| "min"
144-
| "pulse"
145-
| "ramp"
146-
| "safediv"
147-
| "sin"
148-
| "sqrt"
149-
| "step"
150-
| "tan"
145+
| "abs"
146+
| "arccos"
147+
| "arcsin"
148+
| "arctan"
149+
| "cos"
150+
| "exp"
151+
| "int"
152+
| "ismoduleinput"
153+
| "ln"
154+
| "log10"
155+
| "max"
156+
| "mean"
157+
| "min"
158+
| "pulse"
159+
| "ramp"
160+
| "safediv"
161+
| "sin"
162+
| "sqrt"
163+
| "step"
164+
| "tan"
165+
// array-only builtins
166+
| "rank"
167+
| "size"
168+
| "stddev"
169+
| "sum"
151170
)
152171
}
153172

@@ -183,21 +202,39 @@ where
183202
| BuiltinFn::Log10(a)
184203
| BuiltinFn::Sin(a)
185204
| BuiltinFn::Sqrt(a)
186-
| BuiltinFn::Tan(a) => cb(BuiltinContents::Expr(a)),
205+
| BuiltinFn::Tan(a)
206+
| BuiltinFn::Size(a)
207+
| BuiltinFn::Stddev(a)
208+
| BuiltinFn::Sum(a) => cb(BuiltinContents::Expr(a)),
187209
BuiltinFn::Mean(args) => {
188210
args.iter().for_each(|a| cb(BuiltinContents::Expr(a)));
189211
}
190-
BuiltinFn::Max(a, b) | BuiltinFn::Min(a, b) | BuiltinFn::Step(a, b) => {
212+
BuiltinFn::Step(a, b) => {
191213
cb(BuiltinContents::Expr(a));
192214
cb(BuiltinContents::Expr(b));
193215
}
216+
BuiltinFn::Max(a, b) | BuiltinFn::Min(a, b) => {
217+
cb(BuiltinContents::Expr(a));
218+
if let Some(b) = b {
219+
cb(BuiltinContents::Expr(b));
220+
}
221+
}
194222
BuiltinFn::Pulse(a, b, c) | BuiltinFn::Ramp(a, b, c) | BuiltinFn::SafeDiv(a, b, c) => {
195223
cb(BuiltinContents::Expr(a));
196224
cb(BuiltinContents::Expr(b));
197225
if let Some(c) = c {
198226
cb(BuiltinContents::Expr(c))
199227
}
200228
}
229+
BuiltinFn::Rank(a, rest) => {
230+
cb(BuiltinContents::Expr(a));
231+
if let Some((b, c)) = rest {
232+
cb(BuiltinContents::Expr(b));
233+
if let Some(c) = c {
234+
cb(BuiltinContents::Expr(c));
235+
}
236+
}
237+
}
201238
}
202239
}
203240

@@ -206,6 +243,10 @@ fn test_is_builtin_fn() {
206243
assert!(is_builtin_fn("lookup"));
207244
assert!(!is_builtin_fn("lookupz"));
208245
assert!(is_builtin_fn("log10"));
246+
assert!(is_builtin_fn("sum"));
247+
assert!(is_builtin_fn("rank"));
248+
assert!(is_builtin_fn("size"));
249+
assert!(is_builtin_fn("stddev"));
209250
}
210251

211252
#[test]

src/simlin-engine/src/bytecode.rs

+6
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,17 @@ pub struct ModuleDeclaration {
8080
pub(crate) off: usize, // offset within the parent module
8181
}
8282

83+
#[derive(Clone, Debug)]
84+
pub struct ArrayDefinition {
85+
pub(crate) dimensions: Vec<usize>,
86+
}
87+
8388
// these are things that will be shared across bytecode runlists
8489
#[derive(Clone, Debug)]
8590
pub struct ByteCodeContext {
8691
pub(crate) graphical_functions: Vec<Vec<(f64, f64)>>,
8792
pub(crate) modules: Vec<ModuleDeclaration>,
93+
pub(crate) arrays: Vec<ArrayDefinition>,
8894
}
8995

9096
#[derive(Clone, Debug, Default)]

src/simlin-engine/src/common.rs

+2
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ pub enum ErrorCode {
7171
TodoWildcard,
7272
TodoStarRange,
7373
TodoRange,
74+
TodoArrayBuiltin,
7475
}
7576

7677
impl fmt::Display for ErrorCode {
@@ -125,6 +126,7 @@ impl fmt::Display for ErrorCode {
125126
TodoWildcard => "todo_wildcard",
126127
TodoStarRange => "todo_star_range",
127128
TodoRange => "todo_range",
129+
TodoArrayBuiltin => "todo_array_builtin",
128130
};
129131

130132
write!(f, "{}", name)

0 commit comments

Comments
 (0)