Skip to content

Commit 56daf03

Browse files
reid-spencerclaude
andcommitted
Add require statement to the RIDDL language
Adds require "condition" syntax as a precondition statement that generates an error if the boolean expression is false. Updated parser, EBNF grammar, AST, BAST serialization, validation, and prettify support. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 3610dc0 commit 56daf03

11 files changed

Lines changed: 123 additions & 19 deletions

File tree

language/shared/src/main/resources/riddl/grammar/ebnf-grammar.ebnf

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ message_origins = inlet_ref | processor_ref | user_ref | epic_ref ;
191191
(* Core statements available in all contexts *)
192192
statement = when_statement | match_statement | send_statement | tell_statement |
193193
the_set_statement | let_statement | prompt_statement | code_statement |
194-
error_statement | morph_statement | become_statement | comment ;
194+
error_statement | require_statement | morph_statement | become_statement | comment ;
195195
196196
(* Control flow *)
197197
when_statement = "when" literal_string "then" pseudo_code_block ["else" pseudo_code_block] "end" ;
@@ -203,14 +203,15 @@ send_statement = "send" message_ref "to" (outlet_ref | inlet_ref) ;
203203
tell_statement = "tell" message_ref "to" processor_ref ;
204204
205205
(* Variable operations *)
206-
the_set_statement = "set" field_ref "to" literal_string ;
206+
the_set_statement = "set" (field_ref | state_ref) "to" literal_string ;
207207
let_statement = "let" identifier [":" type_ref] "=" literal_string ;
208208
209209
(* General statements *)
210210
prompt_statement = ("prompt" | "do") literal_string ;
211211
code_statement = "```" ("scala" | "java" | "python" | "mojo") code_contents "```" ;
212212
code_contents = {any_char_except_triple_backtick} ;
213213
error_statement = "error" literal_string ;
214+
require_statement = "require" literal_string ;
214215
215216
(* Entity state transitions *)
216217
morph_statement = "morph" entity_ref "to" state_ref "with" message_ref ;

language/shared/src/main/scala/com/ossuminc/riddl/language/AST.scala

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2342,6 +2342,23 @@ object AST:
23422342
def format: String = s"error ${message.format}"
23432343
}
23442344

2345+
/** A statement that requires a boolean condition to be true for execution
2346+
* to continue. If the condition is false, an error is generated.
2347+
*
2348+
* @param loc
2349+
* The location where the statement occurs in the source
2350+
* @param condition
2351+
* The boolean expression (as a string) that must be true
2352+
*/
2353+
@JSExportTopLevel("RequireStatement")
2354+
case class RequireStatement(
2355+
loc: At,
2356+
condition: LiteralString
2357+
) extends Statement {
2358+
override def kind: String = "Require Statement"
2359+
def format: String = s"require ${condition.format}"
2360+
}
2361+
23452362
/** A statement that sets a value of a field
23462363
*
23472364
* @param loc
@@ -2354,7 +2371,7 @@ object AST:
23542371
@JSExportTopLevel("SetStatement")
23552372
case class SetStatement(
23562373
loc: At,
2357-
field: FieldRef,
2374+
field: FieldRef | StateRef,
23582375
value: LiteralString
23592376
) extends Statement {
23602377
override def kind: String = "Set Statement"

language/shared/src/main/scala/com/ossuminc/riddl/language/bast/BASTReader.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -784,8 +784,10 @@ class BASTReader(bytes: Array[Byte])(using pc: PlatformContext) {
784784
val message = readLiteralString()
785785
ErrorStatement(loc, message)
786786

787-
case 3 => // Set
788-
val field = readFieldRef()
787+
case 3 => // Set (field ref or state ref)
788+
val field: FieldRef | StateRef = reader.peekU8() match
789+
case tag if tag == NODE_FIELD_REF => readFieldRef()
790+
case tag if tag == NODE_STATE_REF => readStateRef()
789791
val value = readLiteralString()
790792
SetStatement(loc, field, value)
791793

@@ -871,6 +873,10 @@ class BASTReader(bytes: Array[Byte])(using pc: PlatformContext) {
871873
val body = readString()
872874
CodeStatement(loc, language, body)
873875

876+
case 14 => // Require
877+
val condition = readLiteralString()
878+
RequireStatement(loc, condition)
879+
874880
case _ =>
875881
PromptStatement(loc, LiteralString(loc, s"<unknown statement $stmtType>"))
876882
}

language/shared/src/main/scala/com/ossuminc/riddl/language/bast/BASTWriter.scala

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,9 +235,10 @@ class BASTWriter(val writer: ByteBufferWriter, val stringTable: StringTable) {
235235
case input: Input => writeInput(input)
236236
case output: Output => writeOutput(output)
237237

238-
// Statements (10 declarative statements per riddlsim spec)
238+
// Statements (11 declarative statements)
239239
case s: PromptStatement => writePromptStatement(s)
240240
case s: ErrorStatement => writeErrorStatement(s)
241+
case s: RequireStatement => writeRequireStatement(s)
241242
case s: SetStatement => writeSetStatement(s)
242243
case s: SendStatement => writeSendStatement(s)
243244
case s: MorphStatement => writeMorphStatement(s)
@@ -864,11 +865,21 @@ class BASTWriter(val writer: ByteBufferWriter, val stringTable: StringTable) {
864865
writeLiteralString(s.message)
865866
}
866867

868+
def writeRequireStatement(s: RequireStatement): Unit = {
869+
writer.writeU8(NODE_STATEMENT)
870+
writer.writeU8(14) // Require statement
871+
writeLocation(s.loc)
872+
writeLiteralString(s.condition)
873+
}
874+
867875
def writeSetStatement(s: SetStatement): Unit = {
868876
writer.writeU8(NODE_STATEMENT)
869877
writer.writeU8(3) // Set statement
870878
writeLocation(s.loc)
871-
writeFieldRef(s.field)
879+
s.field match {
880+
case fr: FieldRef => writeFieldRef(fr)
881+
case sr: StateRef => writeStateRef(sr)
882+
}
872883
writeLiteralString(s.value)
873884
}
874885

language/shared/src/main/scala/com/ossuminc/riddl/language/bast/package.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ package object bast {
5858
* with revision 0 (pre-check era) will be rejected with a
5959
* clear message.
6060
*/
61-
val FORMAT_REVISION: Short = 4
61+
val FORMAT_REVISION: Short = 5
6262

6363
/** Magic bytes for BAST file identification: "BAST" */
6464
val MAGIC_BYTES: Array[Byte] = Array('B'.toByte, 'A'.toByte, 'S'.toByte, 'T'.toByte)

language/shared/src/main/scala/com/ossuminc/riddl/language/parsing/Keywords.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,8 @@ object Keywords {
258258

259259
def repository[u: P]: P[Unit] = keyword(Keyword.repository)
260260

261+
def require[u: P]: P[Unit] = keyword(Keyword.require_)
262+
261263
def requires[u: P]: P[Unit] = keyword(Keyword.requires)
262264

263265
def required[u: P]: P[Unit] = keyword(Keyword.required)
@@ -434,6 +436,7 @@ object Keywords {
434436
Keyword.replica,
435437
Keyword.reply,
436438
Keyword.repository,
439+
Keyword.require_,
437440
Keyword.requires,
438441
Keyword.required,
439442
Keyword.record,
@@ -583,6 +586,7 @@ object Keyword {
583586
final val replica = "replica"
584587
final val reply = "reply"
585588
final val repository = "repository"
589+
final val require_ = "require"
586590
final val requires = "requires"
587591
final val required = "required"
588592
final val record = "record"

language/shared/src/main/scala/com/ossuminc/riddl/language/parsing/StatementParser.scala

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,19 @@ private[parsing] trait StatementParser {
3030
)./.map { case (start, str, end) => ErrorStatement(at(start, end), str) }
3131
}
3232

33+
private def requireStatement[u: P]: P[RequireStatement] = {
34+
P(
35+
Index ~ Keywords.require ~ literalString ~/ Index
36+
)./.map { case (start, str, end) => RequireStatement(at(start, end), str) }
37+
}
38+
3339
private def theSetStatement[u: P]: P[SetStatement] = {
3440
P(
35-
Index ~ Keywords.set ~/ fieldRef ~ to ~/ literalString ~/ Index
36-
)./.map { (start, ref, str, end) => SetStatement(at(start, end), ref, str) }
41+
Index ~ Keywords.set ~/ (fieldRef|stateRef) ~ to ~/ literalString ~/ Index
42+
)./.map {
43+
case (start, ref: FieldRef, str, end) => SetStatement(at(start, end), ref, str)
44+
case (start, ref: StateRef, str, end) => SetStatement(at(start,end), ref, str)
45+
}
3746
}
3847

3948
private def sendStatement[u: P]: P[SendStatement] = {
@@ -121,11 +130,11 @@ private[parsing] trait StatementParser {
121130
}
122131
}
123132

124-
private def backTickElipsis[u: P]: P[Unit] = { P("```") }
133+
private def backTickEllipsis[u: P]: P[Unit] = { P("```") }
125134

126135
private def codeStatement[u: P]: P[CodeStatement] = {
127136
P(
128-
Index ~ backTickElipsis ~ Index ~ StringIn("scala", "java", "python", "mojo").! ~ Index ~
137+
Index ~ backTickEllipsis ~ Index ~ StringIn("scala", "java", "python", "mojo").! ~ Index ~
129138
until3('`', '`', '`') ~ Index
130139
).map { case (at1, at2, lang, at3, contents, at4) =>
131140
CodeStatement(at(at1, at4), LiteralString(at(at2, at3), lang), contents)
@@ -142,8 +151,8 @@ private[parsing] trait StatementParser {
142151
theSetStatement | letStatement |
143152
// GROUP 4: General statements
144153
promptStatement | codeStatement |
145-
// GROUP 5: Error handling
146-
errorStatement | comment
154+
// GROUP 5: Error handling and preconditions
155+
errorStatement | requireStatement | comment
147156
).asInstanceOf[P[Statements]]
148157
}
149158

language/shared/src/test/scala/com/ossuminc/riddl/language/parsing/HandlerTest.scala

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
package com.ossuminc.riddl.language.parsing
88

9-
import com.ossuminc.riddl.language.AST.{Context, Entity}
9+
import com.ossuminc.riddl.language.AST.{Context, Entity, RequireStatement}
10+
import com.ossuminc.riddl.language.Finder
1011
import com.ossuminc.riddl.language.parsing.AbstractParsingTest
1112
import com.ossuminc.riddl.utils.PlatformContext
1213
import org.scalatest.TestData
@@ -225,5 +226,44 @@ abstract class HandlerTest(using PlatformContext) extends AbstractParsingTest {
225226
case Right(_) => succeed
226227
}
227228
}
229+
"accept require statements" in { (td: TestData) =>
230+
val input = RiddlParserInput(
231+
"""entity Account is {
232+
| type AccountState is { balance: Number }
233+
| state Active of Account.AccountState
234+
| handler Transactions is {
235+
| on command Withdraw {
236+
| require "balance >= amount"
237+
| set field Account.balance to "balance - amount"
238+
| }
239+
| on command Transfer {
240+
| require "balance >= amount"
241+
| require "recipient != sender"
242+
| prompt "execute transfer"
243+
| }
244+
| }
245+
|}
246+
|""".stripMargin,
247+
td
248+
)
249+
parseDefinition[Entity](input) match {
250+
case Left(errors) =>
251+
val msg = errors.map(_.format).mkString("\n")
252+
fail(msg)
253+
case Right((entity, _)) =>
254+
val handler = entity.handlers.head
255+
val clause = handler.clauses.head
256+
val finder = Finder(clause.contents)
257+
val requires = finder.findByType[RequireStatement]
258+
requires.size must be(1)
259+
requires.head.condition.s must be("balance >= amount")
260+
// Second clause has two requires
261+
val clause2 = handler.clauses(1)
262+
val finder2 = Finder(clause2.contents)
263+
val requires2 = finder2.findByType[RequireStatement]
264+
requires2.size must be(2)
265+
succeed
266+
}
267+
}
228268
}
229269
}

passes/shared/src/main/scala/com/ossuminc/riddl/passes/prettify/RiddlFileEmitter.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,8 @@ case class RiddlFileEmitter(url: URL)(using PlatformContext) extends FileBuilder
352352
addLine(s"tell ${msg.format} to ${to.format}")
353353
case CodeStatement(_, lang, body) =>
354354
addIndent(s"```${lang.s}").add(body).nl.addIndent("```")
355+
case RequireStatement(_, condition) =>
356+
addLine(s"require ${condition.format}")
355357
case statement: Statement => addLine(statement.format)
356358
case comment: Comment => emitComment(comment)
357359
end match

passes/shared/src/main/scala/com/ossuminc/riddl/passes/resolve/ResolutionPass.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,9 @@ case class ResolutionPass(input: PassInput, outputs: PassesOutput)(using io: Pla
234234
private def resolveStatement(statement: Statement, parents: Parents): Unit = {
235235
statement match {
236236
case SetStatement(_, field, _) =>
237-
associateUsage[Field](parents.head, resolveARef[Field](field, parents))
237+
field match
238+
case fr: FieldRef => associateUsage[Field](parents.head, resolveARef[Field](fr, parents))
239+
case sr: StateRef => associateUsage[State](parents.head, resolveARef[State](sr, parents))
238240
case BecomeStatement(_, entity, handler) =>
239241
associateUsage[Entity](parents.head, resolveARef[Entity](entity, parents))
240242
associateUsage[Handler](parents.head, resolveARef[Handler](handler, parents))
@@ -248,8 +250,9 @@ case class ResolutionPass(input: PassInput, outputs: PassesOutput)(using io: Pla
248250
case TellStatement(_, msg, processorRef) =>
249251
associateUsage[Type](parents.head, resolveARef[Type](msg, parents))
250252
associateUsage(parents.head, resolveARef[Processor[?]](processorRef, parents))
251-
case _: PromptStatement => () // no references
252-
case _: ErrorStatement => () // no references
253+
case _: PromptStatement => () // no references
254+
case _: ErrorStatement => () // no references
255+
case _: RequireStatement => () // no references
253256
case _: WhenStatement => () // no references (condition is a literal string)
254257
case _: MatchStatement => () // no references (expression/patterns are literal strings)
255258
case ls: LetStatement =>

0 commit comments

Comments
 (0)