diff --git a/compiler/ast/ast_types.nim b/compiler/ast/ast_types.nim index 493b379e06..ac5298fbb6 100644 --- a/compiler/ast/ast_types.nim +++ b/compiler/ast/ast_types.nim @@ -820,6 +820,7 @@ type mSymIsInstantiationOf, mNodeId, mPrivateAccess mEvalToAst + mSuspend # magics only used internally: mStrToCStr diff --git a/compiler/backend/cgirgen.nim b/compiler/backend/cgirgen.nim index 08e19fa4c4..c1061c7c79 100644 --- a/compiler/backend/cgirgen.nim +++ b/compiler/backend/cgirgen.nim @@ -620,7 +620,7 @@ proc stmtToIr(tree: MirBody, env: MirEnv, cl: var TranslateCl, scopeToIr(tree, env, cl, cr, stmts) of mnkDestroy: unreachable("a 'destroy' that wasn't lowered") - of AllNodeKinds - StmtNodes + {mnkEndScope}: + of AllNodeKinds - StmtNodes + {mnkEndScope, mnkFork, mnkLand}: unreachable(n.kind) proc setElementToIr(tree: MirBody, cl: var TranslateCl, diff --git a/compiler/front/condsyms.nim b/compiler/front/condsyms.nim index 3d88114bee..4c324485cd 100644 --- a/compiler/front/condsyms.nim +++ b/compiler/front/condsyms.nim @@ -78,3 +78,4 @@ proc initDefines*(symbols: StringTableRef) = defineSymbol("nimskullNoNkStmtListTypeAndNkBlockType") defineSymbol("nimskullNoNkNone") defineSymbol("nimskullHasSupportsZeroMem") + defineSymbol("nimskullHasSuspend") diff --git a/compiler/mir/continuations.nim b/compiler/mir/continuations.nim new file mode 100644 index 0000000000..42f2966f49 --- /dev/null +++ b/compiler/mir/continuations.nim @@ -0,0 +1,43 @@ +## Implements the lowering of `fork`/`land` pairs. For every static `fork`, the +## continuation of the fork (i.e., the code following the land) is reified +## into a standalone procedure -- the context (i.e., all active locals) are +## saved into a separate context object, which must be passed to the reified +## procedure. +## +## A rough summary of the reification process is that the procedure body is +## duplicated and all basic-blocks not reachable from the resumption point +## are removed. +## +## Forks in loops make this a bit trickier in practice. Consider: +## +## loop: +## def _1 = ... +## if cond: +## goto L2 +## fork a, b, L1 +## ... +## land L1: +## ... +## L2: +## ... +## +## Here, the loop must also be part of the reified procedure, but a naive +## removal of all unused basic blocks would leave it at the start, which is +## wrong. +## +## The solution is to split the continuation into multiple subroutines, one +## for each such problematic join point. +## +## There are two possible strategies to implement subroutines: +## 1. via separate procedure that tail call each other, passing along extra +## locals crossing the border as parameter +## 2. or, use a case statement dispatcher in a loop, with each target +## corresponding to a subroutine. Invoking a subroutine means changing the +## selector value and jumping back to the loop start -- local variables +## living across subroutine boundaries are lifted to the top-level scope +## +## Number 2 is chosen because it's slightly simpler to implement and doesn't +## put pressure on the available tailcall argument sizes (and thus the size +## of the context object). +## +## A reified continuation is created for each fork/land pair. diff --git a/compiler/mir/mirgen.nim b/compiler/mir/mirgen.nim index 5858f416a6..ad315b395e 100644 --- a/compiler/mir/mirgen.nim +++ b/compiler/mir/mirgen.nim @@ -1027,6 +1027,53 @@ proc genMagic(c: var TCtx, n: PNode; m: TMagic) = c.buildMagicCall m, rtyp: # skip the surrounding typedesc c.emitByVal typeLit(c.typeToMir(n[1].typ.skipTypes({tyTypeDesc}))) + of mSuspend: + let label = c.allocLabel() + # emit a definition of the local storing the continuation: + discard c.addLocal(n[2].sym) + let tmp = c.nameNode(n[2].sym) + + # treat the code in the suspend context as if was at the top level of the + # procedure + let saved = c.blocks.saveContext() + withFront c.builder: + c.buildStmt mnkScope: discard + discard c.blocks.startScope() + + c.buildStmt mnkDef: + c.add tmp + c.add MirNode(kind: mnkNone) + c.buildStmt mnkFork: + c.add tmp + c.add labelNode(label) + + if c.owner.typ[0].isEmptyType() or n[3].typ == c.graph.noreturnType: + c.genCall(n[3]) + else: + let v = c.wrapTemp c.typeToMir(n.typ): + c.genCall(n[3]) + c.buildStmt mnkAsgn: + c.add nameNode(c, c.owner.ast[resultPos].sym) + c.use v + + # emit a return, close the scope, and restore the original context + blockExit(c.blocks, c.graph, c.env, c.builder, 0) + c.blocks.closeScope(c.builder, 0, false) + c.buildStmt mnkEndScope: discard + c.blocks.restoreContext(saved) + + if rtyp == VoidType: + # the land receives nothing + c.buildStmt mnkLand: + c.add labelNode(label) + else: + # the land receives an owning value + let res = c.allocTemp(rtyp) + c.buildStmt mnkLand: + c.add labelNode(label) + c.use res + c.buildTree mnkMove, rtyp: + c.use res # arithmetic operations: of mAddI, mSubI, mMulI, mDivI, mModI, mPred, mSucc: diff --git a/compiler/mir/mirgen_blocks.nim b/compiler/mir/mirgen_blocks.nim index f6d6416621..09691c34b6 100644 --- a/compiler/mir/mirgen_blocks.nim +++ b/compiler/mir/mirgen_blocks.nim @@ -184,6 +184,20 @@ proc tailExit*(c; bu) = bu.subTree mnkGoto: bu.add labelNode(bu.requestLabel(c.blocks[0])) +proc saveContext*(c): BlockCtx = + ## Saves the current context and replaces it with one that only contains the + ## top-level block. + result = BlockCtx(blocks: @[c.blocks[0]]) + swap(c, result) + +proc restoreContext*(c; with: sink BlockCtx) = + ## Restores the block context `with`. If the top-level block had a label + ## registered since the `saveContext` `with` was saved with, the label is + ## kept. + swap(c, with) + if c.blocks[0].id.isNone and with.blocks[0].id.isSome: + c.blocks[0].id = with.blocks[0].id + template add*(c: var BlockCtx; b: Block) = c.blocks.add b diff --git a/compiler/mir/mirtrees.nim b/compiler/mir/mirtrees.nim index fda07027d0..c4a4f97917 100644 --- a/compiler/mir/mirtrees.nim +++ b/compiler/mir/mirtrees.nim @@ -164,6 +164,12 @@ type ## kind is deliberately named "tail call" for the sake of ## discoverability + mnkFork ## saves the current context and creates a delimited continuation + ## starting at the specified `mnkLand` + mnkLand ## a special join that marks the start of a continuation. May only + ## be targeted by a `mnkFork` + # TODO: rename to resume + # unary arithmetic operations: mnkNeg ## signed integer and float negation (for ints, overflow is UB) # binary arithmetic operations: @@ -346,7 +352,7 @@ const StmtNodes* = {mnkScope, mnkGoto, mnkIf, mnkCase, mnkLoop, mnkJoin, mnkLoopJoin, mnkExcept, mnkFinally, mnkContinue, mnkEndStruct, mnkInit, mnkAsgn, mnkSwitch, mnkVoid, mnkRaise, mnkDestroy, - mnkEmit, mnkAsm, mnkEndScope} + DefNodes + mnkEmit, mnkAsm, mnkEndScope, mnkFork, mnkLand} + DefNodes ## Nodes that are treated like statements, in terms of syntax. # --- semantics-focused sets: diff --git a/compiler/mir/utils.nim b/compiler/mir/utils.nim index 66512d0c26..bea0bba5d5 100644 --- a/compiler/mir/utils.nim +++ b/compiler/mir/utils.nim @@ -633,6 +633,19 @@ proc stmtToStr(nodes: MirTree, i: var int, indent: var int, result: var string, tree "": labelToStr(nodes, i, result) result.add ":\n" + of mnkFork: + tree "fork ": + valueToStr() + result.add " " + targetToStr() + result.add "\n" + of mnkLand: + tree "land ": + labelToStr(nodes, i, result) + if n.len == 2: + result.add " " + valueToStr() + result.add ":\n" of mnkContinue: tree "continue ": targetToStr(nodes, i, result) diff --git a/compiler/sem/mirexec.nim b/compiler/sem/mirexec.nim index a23f0cbd13..54d13deedd 100644 --- a/compiler/sem/mirexec.nim +++ b/compiler/sem/mirexec.nim @@ -439,6 +439,12 @@ func computeDfg*(tree: MirTree): DataFlowGraph = # emit a join at the end of an 'if' if ifs.len > 0 and tree[i, 0].label == ifs[^1]: join i, ifs.pop() + of mnkFork: + fork i, tree[i, 1].label + emitLvalueOp(env, opMutate, tree, i, OpValue tree.child(i, 0)) + of mnkLand: + join i, tree[i, 0].label + emitLvalueOp(env, opDef, tree, i, OpValue tree.child(i, 1)) of mnkDef, mnkDefCursor, mnkAsgn, mnkInit: emitForDef(env, tree, i) diff --git a/compiler/sem/semexprs.nim b/compiler/sem/semexprs.nim index 1985fa97f4..ab6c09a0c8 100644 --- a/compiler/sem/semexprs.nim +++ b/compiler/sem/semexprs.nim @@ -2898,6 +2898,8 @@ proc semMagic(c: PContext, n: PNode, s: PSym, flags: TExprFlags): PNode = of mSizeOf: markUsed(c, n.info, s) result = semSizeOf(c, setMs(n, s)) + of mSuspend: + result = semSuspend(c, n, s, flags) else: result = semDirectOp(c, n, flags) diff --git a/compiler/sem/semmagic.nim b/compiler/sem/semmagic.nim index 3412a2210d..262a2ed9eb 100644 --- a/compiler/sem/semmagic.nim +++ b/compiler/sem/semmagic.nim @@ -410,6 +410,100 @@ proc semPrivateAccess(c: PContext, n: PNode): PNode = c.currentScope.allowPrivateAccess.add t.sym result = newNodeIT(nkEmpty, n.info, getSysType(c.graph, n.info, tyVoid)) +proc semSuspend(c: PContext, n: PNode, s: PSym, flags: TExprFlags): PNode = + ## Analyzes a 'suspend' magic call, producing a typed AST or an error. If + ## the call doesn't have the right shape, analysis fall back to overload + ## resolution. + addInNimDebugUtils(c.config, "semSuspend", n, result) + if n.len != 4: + # could be some other call + return semDirectOp(c, n, flags) + + result = shallowCopy(n) + result[0] = newSymNode(s, n[0].info) + result[1] = semExprWithType(c, n[1]) + + var paramType = result[1].typ + if paramType.kind != tyError: + if paramType.kind == tyTypeDesc: + paramType = paramType.lastSon + else: + result[1] = c.config.newError(result[1], PAstDiag(kind: adSemTypeExpected)) + paramType = result[1].typ + + let hasResult = paramType.skipTypes({tyAlias}).kind != tyVoid + + # create an new object for the context. It's populated at a (much) later stage + let objSym = newSym(skType, c.cache.getIdent("Ctx"), nextSymId(c.idgen), + getCurrOwner(c), n.info) + # enable special name mangling: + objSym.flags.incl sfFromGeneric + + let obj = newTypeS(tyObject, c) + obj.rawAddSon(nil) # the base type + obj.size = szUnknownSize + obj.align = szUnknownSize + obj.n = newTree(nkRecList) + obj.flags.incl tfHasAsgn # the object has custom copy logic + objSym.linkTo(obj) + + proc addParam(prc: PType, name: string, typ: PType, info: TLineInfo, + c: PContext) = + let p = newSym(skParam, c.cache.getIdent(name), nextSymId(c.idgen), + getCurrOwner(c), info) + p.typ = typ + prc.rawAddSon(typ, propagateHasAsgn=false) + prc.n.add newSymNode(p) + + # create the type of the continuation procedure: + let prc = newProcType(n.info, nextTypeId(c.idgen), getCurrOwner(c)) + prc.callConv = ccNimCall # TODO: use tailcall + # TODO: handle the "unresolved auto return type" case. The easiest solution + # is just reporting an error + prc[0] = c.p.owner.typ[0] # use the enclosing routine's return type + if hasResult: + prc.addParam("arg", newTypeWithSons(c, tySink, @[paramType]), n.info, c) + prc.addParam("c", newTypeWithSons(c, tySink, @[obj]), n.info, c) + + # set up the type to use for the local: + let tup = newTypeS(tyTuple, c) + tup.rawAddSon(obj) + tup.rawAddSon(prc) + + c.openScope() + # create a let section and type that. This makes sure the symbol is properly + # registered everywhere, and retyping is also taken care. The initializer + # needs to be some well-formed, non empty expression for the analysis to + # succeed -- we use a correctly typed but gramatically incorrect node as + # the expression + let cons = newNodeIT(nkType, n.info, tup) + cons.flags.incl nfSem # prevent the expression from being analyzed + + let + ls = nkLetSection.newTree( + nkIdentDefs.newTree(n[2], newNodeIT(nkType, n.info, tup), cons)) + tmp = semNormalizedLetOrVar(c, ls, skLet) + if tmp.kind == nkError: + # place the erroneous identifier node back into the call + result[2] = tmp.diag.wrongNode[0][0] + else: + result[2] = tmp[0][0] + + var call = semExprWithType(c, n[3]) + # TODO: noreturn handling... + call = fitNode(c, c.p.owner.typ[0], call, n[3].info) + c.closeScope() + + result[3] = call + if hasResult: + result.typ = paramType + + if nkError in {result[1].kind, result[2].kind, result[3].kind}: + result = c.config.wrapError(result) + elif ecfStatic in c.executionCons[^1].flags: + # TODO: report an error + discard + proc magicsAfterOverloadResolution(c: PContext, n: PNode, flags: TExprFlags): PNode = ## This is the preferred code point to implement magics. diff --git a/compiler/sem/tailcall_analysis.nim b/compiler/sem/tailcall_analysis.nim index e5589f2c5c..16253e1856 100644 --- a/compiler/sem/tailcall_analysis.nim +++ b/compiler/sem/tailcall_analysis.nim @@ -97,6 +97,10 @@ proc verifyTailCalls(g: ModuleGraph, owner: PSym, n, next, problem: PNode) = # the body was not typed properly, nor is it relevant for the # analysis; skip return + elif n[0].kind == nkSym and n[0].sym.magic == mSuspend: + # the suspended-to-expression appears in a tailing position + recurse(n[3], nil, nil) + return for it in n.items: recurse(it) # arguments are not tailing expressions diff --git a/doc/manual.rst b/doc/manual.rst index 87ea379522..79338a2f76 100644 --- a/doc/manual.rst +++ b/doc/manual.rst @@ -4166,6 +4166,85 @@ For the above `openArray` and `var` rules, whether the expressions refers to a location derived from a parameter or global must be visible directly from the argument expression, no indirection through locals is allowed. +Delimited Continuations +----------------------- + +A delimited continuation is a continuation extending only up to a certain +point. In the context of NimSkull, a delimited continuation extends to the +end of the enclosing routine in which the continuation is created. + +A continuation is created via the `system.suspend` procedure, which +accepts three arguments. + +Syntactically, the second argument must be an identifier. + +1. let `T` be the type of the first argument +2. let `R` be the return type of the caller +3. let `Ctx` be a unique type such that: + * `Ctx is object` evaluates to true + * `supportsCopyMem(Ctx)` evaluates to true + * the size and alignment of `Ctx` are not queriable in a compile-time context + * an instance of `Ctx` is always copyable +4. let `P` be a type: + * if `T is void`, then let `P` be `proc(e: sink Ctx): R` + * if `T isnot void`, then let `P` be `proc(p: sink T, e: sink Ctx): R` +5. let `cont` refer to the *continuation* (a procedure of type `P`) +6. let `local` refer to the identifier appearing as the second argument +7. let `ctx` refer to the saved context (a value of type `Ctx`) +8. let `call` refer to the third argument expression + +`call` is typed as if would the expression were the following: + +.. code-block:: nim + + block: + let local: (Ctx, P) # has the + call + + +The following requirements must be met for the `suspend` call: +* `R2` must be equal to `R` +* `call` must be (after template/macro expansion) a call expression such that: + * the callee is a static procedure not using the `.closure` calling + convention + * each run-time argument expression is a value identifier, or a built-in + projection thereof. The same goes for index operands in projections +* no local with a disabled copy operator must *potentially* store a value +* no `var`:idx: or `openArray`:idx: parameter (or a borrow thereof) must be live +* the `suspend` call is not statically located within an `except`:idx: or + `finally`:idx: clause +* the `suspend` call does not appear in top-level code + +Evaluation of `suspend`:idx: works as follows: +1. the current local context is saved +2. `local` is intialized with the tuple `(ctx, prc)` +3. all local variables (except `local`) or sink parameters of the caller used + in `call` are copied -- the usages within `call` refer to the copies form + here on +4. `call` is evaluated as if it were the only expression within the caller + +Resume +~~~~~~ + +Upon evaluating a call `x(y)` where `x` dynamically evaluates to `cont` and `y` +dynamically evaluates to `ctx` (or a copy thereof), execution resumes in the +suspended caller as if `suspend` returned. + +Upon evaluating a call `x(v, y)` where `x` dynamically evaluates to `cont` and +`y` dynamically evaluates to `ctx` (or a copy thereof), execution resumes in +the suspended caller as if `suspend` returned with value `v`. + +The value returned by a resumed caller is returned by the continuation +invocation -- the same goes for raised exceptions. + +Cleanup +~~~~~~~ + +If `ctx` (or a copy thereof) goes out scope without having been consumed by a +call to the continuation, cleanup happens as if unwinding would take place +right the `suspend` call, but without `finally`:idx: clauses being visited. + + Methods ============= @@ -6246,7 +6325,6 @@ does not export these identifiers. The `import` statement is only allowed at the top level. - Include statement ----------------- diff --git a/lib/system.nim b/lib/system.nim index 1effd48c5f..228a539b59 100644 --- a/lib/system.nim +++ b/lib/system.nim @@ -3070,6 +3070,12 @@ when defined(nimDebugUtils): proc `not`*[T: ref or ptr](a: typedesc[T], b: typeof(nil)): typedesc {.magic: "TypeTrait", noSideEffect.} ## Constructs a `not nil` type. +when defined(nimskullHasSuspend): + proc suspend*[T](with: typedesc[T], name, call: untyped): T {.magic: "Suspend", noSideEffect.} = + ## Saves the current local context, stores context + continuation in a + ## local named `name`, and invokes `call` as if it were the last remaining + ## expression in the caller's body. + type ParamBlob = object ## The type to use for the storage of parameters. diff --git a/showcase.nim b/showcase.nim new file mode 100644 index 0000000000..00dfd8e718 --- /dev/null +++ b/showcase.nim @@ -0,0 +1,178 @@ +## Provides some examples of possible use cases for delimited continuations. + +import std/[strutils, macros] + +{.experimental: "callOperator".} + +# some helper routines to allow for "erasing" the context parameter, so that +# different continuation with the same shape can be stored in the same location + +type + CellBase = ref object of RootObj + copy: proc(x: CellBase): CellBase {.nimcall, raises: [].} + Cell[T] {.final.} = ref object of CellBase + val: T + CellPtr = object + ## A unique managed pointer. + cell: CellBase + + Cont[P, R] = tuple + prc: proc(p: sink P, cell: sink CellPtr): R {.nimcall.} + env: CellPtr + +proc `=copy`(x: var CellPtr, y: CellPtr) = + if y.cell.isNil: + x.cell = nil + elif x.cell != y.cell: + x.cell = y.cell.copy(y.cell) + +proc copyCell[T](x: CellBase): CellBase = + Cell[T](val: Cell[T](x).val, copy: x.copy) + +proc newCell[T](x: sink T): CellPtr = + CellPtr(cell: Cell[T](val: x, copy: copyCell[T])) + +proc take[T](x: sink CellPtr): T = + move Cell[T](x.cell).val + +proc newCont[R, T](env: sink T, prc: proc(x: sink T): R): Cont[void, R] = + (proc(env: sink CellPtr): auto {.nimcall.} = + let (prc, env) = take[(typeof(prc), T)](env) + prc(env) + , newCell((prc, env))) + +proc newCont[P, R, T](env: sink T, prc: proc(x: sink P, y: sink T): R): Cont[P, R] = + (proc(param: sink P, env: sink CellPtr): auto {.nimcall.} = + let (prc, env) = take[(typeof(prc), T)](env) + prc(param, env) + , newCell((prc, env))) + +proc newCont[T](x: sink T): auto {.inline.} = + let (env, prc) = x + newCont(env, prc) + +template newContP[P, R](body: proc(p: sink P): R): Cont[P, R] = + (proc (p: sink P, env: sink CellPtr): R {.nimcall.} = body(p), + CellPtr()) + +proc `()`[P, R](c: sink Cont[P, R], a: sink P): R = + # TODO: make this a tailcall procedure + let (prc, env) = c + prc(a, env) + +# ----- interpreter example ------ + +type + NodeKind = enum + nkAdd, nkEq, nkIf, nkConst + + Node = object + case kind: NodeKind + of nkAdd, nkEq, nkIf: + sub: seq[Node] + of nkConst: + val: int + +template `[]`(n: Node, i: untyped): Node = + n.sub[i] + +template tree(k: NodeKind, args: varargs[Node]): Node = + Node(kind: k, sub: @args) + +template cnst(v: int): Node = + Node(kind: nkConst, val: v) + +proc interpret(n: Node, then: sink Cont[Node, Node]): Node {.tailcall.} = + template eval(n: Node): Node = + suspend(Node, cont, interpret(n, newCont(cont))) + + proc take(x: sink Node): int = x.val + + case n.kind + of nkConst: + then n + of nkAdd: + then cnst(take(eval(n[0])) + take(eval(n[1]))) + of nkEq: + then cnst(ord(take(eval(n[0])) == take(eval(n[1])))) + of nkIf: + if take(eval n[0]) == 0: + eval(n[2]) # else branch + else: + eval(n[1]) # then branch + +proc interpret(n: Node): Node {.tailcall.} = + interpret(n, newContP(proc(x: sink Node): Node = x)) + +assert interpret( + tree(nkIf, + tree(nkEq, cnst(1), cnst(2)), + cnst(0), + cnst(1))).val == 1 + +# ----- async/await ----- + +type + Future[T] = ref object + content: ref T + callback: proc(x: Future[T]) + +proc awaitImpl[T, U](f: Future[U], with: sink Cont[U, Future[T]], + on: sink Future[T]): Future[T] = + f.callback = proc(x: Future[U]) = + discard with(x.content[]) + return on + +macro async(p: untyped) = + let prev = p.body + p.body = quote do: + result = typeof(result)() + template await[T](x: Future[T]): T {.used.} = + suspend(T, cont, awaitImpl(x, newCont(cont), result)) + `prev` + result = p + +proc resolve[T](x: Future[T], val: sink T) = + x.content = new T + x.content[] = val + +var inner: Future[string] + ## an unresolved future + +proc jield[T, U](on: sink Future[T], with: sink Cont[T, Future[U]]): Future[T] = + inner = Future[T]() + inner.callback = proc(x: Future[T]) = + discard with(x.content[]) + return on + +proc readFileAsync(): Future[string] {.async.} = + # simulate a wait: + let val = suspend(string, cont, jield(result, newCont(cont))) + result.resolve val + +proc parseFileContent(): Future[int] {.async.} = + let content = await readFileAsync() + result.resolve parseInt(content) + +var res = parseFileContent() +assert res.content == nil +# simulate the read file operation finishing: +inner.resolve("123") +# now the future is done: +assert res.content != nil +assert res.content[] == 123 + +# ------ fun things ------ + +# some non-sensical things that are fun to write + +proc repeat(): int = + var i = suspend(int, cont, (proc(cont: Cont[int, int]): int = + var val = 0 + for _ in 0..<10: + val = cont(val) + val + )(newCont(cont))) + return i + 1 + +assert repeat() == 10