Skip to content

Commit dbd1752

Browse files
committed
Support comparison of heterogenous values
Fixes #94.
1 parent 8151c60 commit dbd1752

File tree

2 files changed

+91
-42
lines changed

2 files changed

+91
-42
lines changed

src/main/java/org/squiddev/cobalt/OperationHelper.java

Lines changed: 34 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -215,54 +215,46 @@ public static LuaValue concatNonStrings(LuaState state, LuaValue left, LuaValue
215215
//endregion
216216

217217
//region Compare
218+
private static LuaValue getComparisonMetatable(LuaState state, LuaValue tag, LuaValue left, LuaValue right) throws LuaError {
219+
LuaValue h = left.metatag(state, tag);
220+
if (!h.isNil()) return h;
221+
222+
return right.metatag(state, tag);
223+
}
224+
218225
public static boolean lt(LuaState state, LuaValue left, LuaValue right) throws LuaError, UnwindThrowable {
219-
int tLeft = left.type();
220-
if (tLeft != right.type()) {
221-
throw ErrorFactory.compareError(left, right);
222-
}
223-
switch (tLeft) {
224-
case TNUMBER:
225-
return left.toDouble() < right.toDouble();
226-
case TSTRING:
227-
return left.checkLuaString().compareTo(right.checkLuaString()) < 0;
228-
default:
229-
LuaValue h = left.metatag(state, Constants.LT);
230-
if (!h.isNil() && h == right.metatag(state, Constants.LT)) {
231-
return Dispatch.call(state, h, left, right).toBoolean();
232-
} else {
233-
throw ErrorFactory.compareError(left, right);
234-
}
235-
}
226+
int tLeft = left.type(), tRight = right.type();
227+
228+
if (tLeft == TNUMBER && tRight == TNUMBER) return left.toDouble() < right.toDouble();
229+
if (tLeft == TSTRING && tRight == TSTRING) return left.checkLuaString().compareTo(right.checkLuaString()) < 0;
230+
231+
var mt = getComparisonMetatable(state, LT, left, right);
232+
if (mt.isNil()) throw ErrorFactory.compareError(left, right);
233+
234+
return Dispatch.call(state, mt, left, right).toBoolean();
236235
}
237236

238237
public static boolean le(LuaState state, LuaValue left, LuaValue right) throws LuaError, UnwindThrowable {
239-
int tLeft = left.type();
240-
if (tLeft != right.type()) {
241-
throw ErrorFactory.compareError(left, right);
238+
int tLeft = left.type(), tRight = right.type();
239+
240+
if (tLeft == TNUMBER && tRight == TNUMBER) return left.toDouble() <= right.toDouble();
241+
if (tLeft == TSTRING && tRight == TSTRING) return left.checkLuaString().compareTo(right.checkLuaString()) <= 0;
242+
243+
{ // Prefer __le.
244+
var leMt = getComparisonMetatable(state, LE, left, right);
245+
if (!leMt.isNil()) return Dispatch.call(state, leMt, left, right).toBoolean();
242246
}
243-
switch (tLeft) {
244-
case TNUMBER:
245-
return left.toDouble() <= right.toDouble();
246-
case TSTRING:
247-
return left.checkLuaString().compareTo(right.checkLuaString()) <= 0;
248-
default:
249-
LuaValue h = left.metatag(state, Constants.LE);
250-
if (h.isNil()) {
251-
h = left.metatag(state, Constants.LT);
252-
if (!h.isNil() && h == right.metatag(state, Constants.LT)) {
253-
DebugFrame frame = DebugState.get(state).getStackUnsafe();
254-
255-
frame.flags |= FLAG_LEQ;
256-
boolean result = !Dispatch.call(state, h, right, left).toBoolean();
257-
frame.flags ^= FLAG_LEQ;
258-
259-
return result;
260-
}
261-
} else if (h == right.metatag(state, Constants.LE)) {
262-
return Dispatch.call(state, h, left, right).toBoolean();
263-
}
264247

265-
throw ErrorFactory.compareError(left, right);
248+
{ // If unavailable, fall back to __lt.
249+
var ltMt = getComparisonMetatable(state, LT, left, right);
250+
if (ltMt.isNil()) throw ErrorFactory.compareError(left, right);
251+
DebugFrame frame = DebugState.get(state).getStackUnsafe();
252+
253+
frame.flags |= FLAG_LEQ;
254+
boolean result = !Dispatch.call(state, ltMt, right, left).toBoolean();
255+
frame.flags ^= FLAG_LEQ;
256+
257+
return result;
266258
}
267259
}
268260

src/test/resources/spec/operation_spec.lua

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,61 @@ describe("Lua's base operators", function()
4040
expect.error(adder, "hello"):strip_context():eq("attempt to add a 'number' with a 'string'")
4141
end)
4242
end)
43+
44+
describe("comparison", function()
45+
local function value(x) if type(x) == "number" then return x else return x.value end end
46+
local comparable_mt = {
47+
__lt = function(x, y) return value(x) < value(y) end,
48+
__le = function(x, y) return value(x) <= value(y) end,
49+
}
50+
51+
local function mk(x) return setmetatable({ value = x }, comparable_mt) end
52+
53+
it("compare homogenous values", function()
54+
expect(mk(1) < mk(2)):eq(true)
55+
expect(mk(2) < mk(2)):eq(false)
56+
expect(mk(3) < mk(2)):eq(false)
57+
58+
expect(mk(1) <= mk(2)):eq(true)
59+
expect(mk(2) <= mk(2)):eq(true)
60+
expect(mk(3) <= mk(2)):eq(false)
61+
end)
62+
63+
it("cannot compare heterogenous values :lua<=5.1 :!cobalt", function()
64+
expect.error(function() return mk(1) < 2 end)
65+
:str_match(": attempt to compare table with number$")
66+
end)
67+
68+
it("cannot compare heterogenous values", function()
69+
expect.error(function() return "2.0" < 2 end)
70+
:str_match(": attempt to compare string with number$")
71+
end)
72+
73+
it("compare heterogenous values :lua>=5.2", function()
74+
expect(mk(1) < 2):eq(true)
75+
expect(mk(2) < 2):eq(false)
76+
expect(mk(3) < 2):eq(false)
77+
78+
expect(1 < mk(2)):eq(true)
79+
expect(2 < mk(2)):eq(false)
80+
expect(3 < mk(2)):eq(false)
81+
82+
expect(mk(1) <= 2):eq(true)
83+
expect(mk(2) <= 2):eq(true)
84+
expect(mk(3) <= 2):eq(false)
85+
86+
expect(1 <= mk(2)):eq(true)
87+
expect(2 <= mk(2)):eq(true)
88+
expect(3 <= mk(2)):eq(false)
89+
end)
90+
91+
it("<= falls back to __lt", function()
92+
local comparable_mt = { __lt = function(x, y) return value(x) < value(y) end }
93+
local function mk(x) return setmetatable({ value = x }, comparable_mt) end
94+
95+
expect(mk(1) <= mk(2)):eq(true)
96+
expect(mk(2) <= mk(2)):eq(true)
97+
expect(mk(3) <= mk(2)):eq(false)
98+
end)
99+
end)
43100
end)

0 commit comments

Comments
 (0)