Skip to content

Commit 14769af

Browse files
committed
fix: make ST.Ref.ptrEq behave as stated in the docs
1 parent 1ce05b2 commit 14769af

File tree

2 files changed

+87
-4
lines changed

2 files changed

+87
-4
lines changed

src/runtime/io.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1495,10 +1495,8 @@ extern "C" LEAN_EXPORT obj_res lean_st_ref_swap(b_obj_arg ref, obj_arg a) {
14951495
}
14961496
}
14971497

1498-
extern "C" LEAN_EXPORT obj_res lean_st_ref_ptr_eq(b_obj_arg ref1, b_obj_arg ref2) {
1499-
// TODO(Leo): ref_maybe_mt
1500-
bool r = lean_to_ref(ref1)->m_value == lean_to_ref(ref2)->m_value;
1501-
return box(r);
1498+
extern "C" LEAN_EXPORT uint8_t lean_st_ref_ptr_eq(b_obj_arg ref1, b_obj_arg ref2) {
1499+
return lean_to_ref(ref1) == lean_to_ref(ref2);
15021500
}
15031501

15041502
/* {α : Type} (act : BaseIO α) : α */

tests/lean/run/st_test.lean

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/-!
2+
Some basic tests for the ST monad.
3+
-/
4+
5+
namespace STTest
6+
7+
def ptrEq : IO Unit := do
8+
let ref1 ← IO.mkRef 0
9+
let ref2 ← IO.mkRef 0
10+
IO.println (← ref1.ptrEq ref1)
11+
IO.println (← ref1.ptrEq ref2)
12+
13+
/--
14+
info: true
15+
false
16+
-/
17+
#guard_msgs in
18+
#eval ptrEq
19+
20+
def readWriteRegister : IO Unit := do
21+
let ref1 ← IO.mkRef 0
22+
IO.println (← ref1.get)
23+
ref1.set 1
24+
IO.println (← ref1.get)
25+
26+
/--
27+
info: 0
28+
1
29+
-/
30+
#guard_msgs in
31+
#eval readWriteRegister
32+
33+
def swapRegister : IO Unit := do
34+
let ref1 ← IO.mkRef 0
35+
IO.println (← ref1.swap 5)
36+
IO.println (← ref1.get)
37+
38+
/--
39+
info: 0
40+
5
41+
-/
42+
#guard_msgs in
43+
#eval swapRegister
44+
45+
unsafe def takeRegister : IO Unit := do
46+
let ref1 ← IO.mkRef 0
47+
IO.println (← ref1.take)
48+
ref1.set 5
49+
IO.println (← ref1.get)
50+
51+
/--
52+
info: 0
53+
5
54+
-/
55+
#guard_msgs in
56+
#eval takeRegister
57+
58+
def modifyRegister : IO Unit := do
59+
let ref1 ← IO.mkRef 1
60+
IO.println (← ref1.get)
61+
ref1.modify (fun x => 2 * x)
62+
IO.println (← ref1.get)
63+
64+
/--
65+
info: 1
66+
2
67+
-/
68+
#guard_msgs in
69+
#eval modifyRegister
70+
71+
def modifyGetRegister : IO Unit := do
72+
let ref1 ← IO.mkRef 1
73+
IO.println (← ref1.get)
74+
IO.println (← ref1.modifyGet (fun x => (x, 2 * x)))
75+
IO.println (← ref1.get)
76+
77+
/--
78+
info: 1
79+
1
80+
2
81+
-/
82+
#guard_msgs in
83+
#eval modifyGetRegister
84+
85+
end STTest

0 commit comments

Comments
 (0)