Skip to content

Commit 3423b94

Browse files
committed
fix: [Rust] Fix curried application
1 parent e9616db commit 3423b94

3 files changed

Lines changed: 94 additions & 10 deletions

File tree

src/Fable.Transforms/Rust/Fable2Rust.fs

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3041,18 +3041,52 @@ module Util =
30413041
let transformCurry (com: IRustCompiler) (ctx: Context) arity (expr: Fable.Expr) : Rust.Expr =
30423042
com.TransformExpr(ctx, Replacements.Api.curryExprAtRuntime com arity expr)
30433043

3044+
let tryGetCurriedApplyArgAndReturnTypes typ =
3045+
match typ with
3046+
| Fable.LambdaType(argType, returnType) -> Some(argType, returnType)
3047+
| Fable.DelegateType([ argType ], returnType) -> Some(argType, returnType)
3048+
| _ -> None
3049+
3050+
let isErasedUnitClosureType typ =
3051+
match typ with
3052+
| Fable.LambdaType(Fable.Unit, _)
3053+
| Fable.DelegateType([ Fable.Unit ], _) -> true
3054+
| _ -> false
3055+
30443056
let transformCurriedApply (com: IRustCompiler) ctx r typ calleeExpr args =
30453057
match ctx.TailCallOpportunity with
30463058
| Some tc when tc.IsRecursiveRef(calleeExpr) && List.length tc.Args = List.length args ->
30473059
optimizeTailCall com ctx r tc args
30483060
| _ ->
30493061
let callee = transformCallee com ctx calleeExpr
30503062

3051-
(callee, args)
3052-
||> List.fold (fun expr arg ->
3053-
let args = FSharp2Fable.Util.dropUnitCallArg com [ arg ] [] None
3054-
callFunction com ctx r expr args
3063+
((callee, Some calleeExpr, calleeExpr.Type), args)
3064+
||> List.fold (fun (expr, currentExpr, currentType) arg ->
3065+
let expectedArgType, nextType =
3066+
match tryGetCurriedApplyArgAndReturnTypes currentType with
3067+
| Some(argType, returnType) -> argType, returnType
3068+
| None -> arg.Type, typ
3069+
3070+
if arg.Type = Fable.Unit then
3071+
let appliedExpr =
3072+
match currentExpr with
3073+
| Some(Fable.IdentExpr ident) when isFuncScoped ctx ident.Name -> mkCallExpr expr []
3074+
| _ -> makeLibCall com ctx None "Native" "applyUnit" [ expr |> makeClone ]
3075+
3076+
appliedExpr, None, nextType
3077+
else
3078+
let argExpr =
3079+
transformCallArgs com ctx [ arg ] [ expectedArgType ] [] |> List.exactlyOne
3080+
3081+
let argExpr =
3082+
if isErasedUnitClosureType expectedArgType then
3083+
makeLibCall com ctx None "Native" "eraseUnitArg" [ argExpr |> makeClone ]
3084+
else
3085+
argExpr
3086+
3087+
mkCallExpr expr [ argExpr ], None, nextType
30553088
)
3089+
|> fun (expr, _, _) -> expr
30563090

30573091
let makeUnionCasePat unionCaseName fields =
30583092
if List.isEmpty fields then

src/fable-library-rust/src/Native.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,56 @@ pub mod Native_ {
319319
}
320320
}
321321

322+
// -----------------------------------------------------------
323+
// Unit arg curried applications
324+
// -----------------------------------------------------------
325+
326+
pub trait ApplyUnit<R> {
327+
fn apply_unit(self) -> R;
328+
}
329+
330+
pub trait EraseUnitArg<R> {
331+
fn erase_unit_arg(self) -> Func0<R>;
332+
}
333+
334+
impl<R: 'static> ApplyUnit<R> for Func0<R> {
335+
#[inline]
336+
fn apply_unit(self) -> R {
337+
self()
338+
}
339+
}
340+
341+
impl<R: 'static> ApplyUnit<R> for Func1<(), R> {
342+
#[inline]
343+
fn apply_unit(self) -> R {
344+
self(())
345+
}
346+
}
347+
348+
impl<R: 'static> EraseUnitArg<R> for Func0<R> {
349+
#[inline]
350+
fn erase_unit_arg(self) -> Func0<R> {
351+
self
352+
}
353+
}
354+
355+
impl<R: 'static> EraseUnitArg<R> for Func1<(), R> {
356+
#[inline]
357+
fn erase_unit_arg(self) -> Func0<R> {
358+
Func0::new(move || self(()))
359+
}
360+
}
361+
362+
#[inline]
363+
pub fn applyUnit<R, F: ApplyUnit<R>>(f: F) -> R {
364+
f.apply_unit()
365+
}
366+
367+
#[inline]
368+
pub fn eraseUnitArg<R, F: EraseUnitArg<R>>(f: F) -> Func0<R> {
369+
f.erase_unit_arg()
370+
}
371+
322372
// -----------------------------------------------------------
323373
// Fixed-point combinators
324374
// -----------------------------------------------------------

tests/Rust/tests/src/ApplicativeTests.fs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,12 +1116,12 @@ let ``SRTP with ActivePattern works`` () =
11161116
// equal (1,5) baz
11171117
// equal (1,5) baz2
11181118

1119-
// [<Fact>]
1120-
// let ``Applying to a function returned by a local function works`` () =
1121-
// let foo a b c d = a , b + c d
1122-
// let bar a = foo 1 a
1123-
// let baz = bar 2 (fun _ -> 3) ()
1124-
// equal (1,5) baz
1119+
[<Fact>]
1120+
let ``Applying to a function returned by a local function works`` () =
1121+
let foo a b c d = a , b + c d
1122+
let bar a = foo 1 a
1123+
let baz = bar 2 (fun _ -> 3) ()
1124+
equal (1,5) baz
11251125

11261126
[<Fact>]
11271127
let ``Partially applied functions don't duplicate side effects`` () = // See #1156

0 commit comments

Comments
 (0)