Skip to content

Commit 015766e

Browse files
[ysh] Fix crash when comparing non-comparable types (#1754)
1 parent e4257e0 commit 015766e

File tree

6 files changed

+73
-9
lines changed

6 files changed

+73
-9
lines changed

osh/cmd_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1470,7 +1470,7 @@ def _DoCase(self, node):
14701470
expr_val = self.expr_ev.EvalExpr(
14711471
pat_expr, case_arm.left)
14721472

1473-
if val_ops.ExactlyEqual(expr_val, to_match):
1473+
if val_ops.ExactlyEqual(expr_val, to_match, case_arm.left):
14741474
status = self._ExecuteList(case_arm.action)
14751475
matched = True
14761476
break

spec/ysh-expr-compare.test.sh

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,3 +343,40 @@ y=42
343343
x=42
344344
## END
345345

346+
#### Undefined comparisons
347+
shopt -s ysh:all
348+
349+
func f() { true }
350+
var mydict = {}
351+
var myexpr = ^[123]
352+
353+
var unimpl = [
354+
/ [a-z]+ /, # Eggex
355+
myexpr, # Expr
356+
^(echo hello), # Block
357+
f, # Func
358+
mydict->keys, # BoundFunc
359+
# These cannot be constructed
360+
# - Proc
361+
# - Slice
362+
# - Range
363+
]
364+
365+
for val in (unimpl) {
366+
try { :: val === val }
367+
if (_status !== 3) {
368+
exit 1
369+
}
370+
}
371+
## STDOUT:
372+
## END
373+
374+
#### Non-comparable types in case arms
375+
var myexpr = ^[123]
376+
377+
case (myexpr) {
378+
(myexpr) { echo 123; }
379+
}
380+
## status: 3
381+
## STDOUT:
382+
## END

test/parse-errors.sh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,6 +1011,14 @@ ysh_case() {
10111011
}
10121012
'
10131013

1014+
_ysh-should-parse '
1015+
var myexpr = ^[123]
1016+
1017+
case (123) {
1018+
(myexpr) { echo 1 }
1019+
}
1020+
'
1021+
10141022
_ysh-should-parse '
10151023
case (x) {
10161024
(else) { echo 1 }

test/ysh-runtime-errors.sh

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,23 @@ test-read-builtin() {
743743
_error-case-X 2 'echo hi | read --line x y'
744744
}
745745

746+
test-equality() {
747+
_expr-error-case '
748+
= ^[42] === ^[43]
749+
'
750+
751+
_expr-error-case '
752+
= ^(echo hi) === ^(echo yo)
753+
'
754+
755+
return
756+
757+
# Hm it's kind of weird you can do this -- it's False
758+
_expr-error-case '
759+
= ^[42] === "hi"
760+
'
761+
}
762+
746763
soil-run() {
747764
# This is like run-test-funcs, except errexit is off here
748765
run-test-funcs

ysh/expr_eval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -622,12 +622,12 @@ def _EvalCompare(self, node):
622622
if left.tag() != right.tag():
623623
result = False
624624
else:
625-
result = val_ops.ExactlyEqual(left, right)
625+
result = val_ops.ExactlyEqual(left, right, op)
626626
elif op.id == Id.Expr_NotDEqual:
627627
if left.tag() != right.tag():
628628
result = True
629629
else:
630-
result = not val_ops.ExactlyEqual(left, right)
630+
result = not val_ops.ExactlyEqual(left, right, op)
631631

632632
elif op.id == Id.Expr_In:
633633
result = val_ops.Contains(left, right)

ysh/val_ops.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from _devbuild.gen.syntax_asdl import loc, loc_t, command_t
88
from _devbuild.gen.value_asdl import (value, value_e, value_t)
99
from core import error
10+
from core import ui
1011
from mycpp.mylib import tagswitch
1112
from ysh import regex_translate
1213

@@ -323,8 +324,8 @@ def ToBool(val):
323324
return True # all other types are Truthy
324325

325326

326-
def ExactlyEqual(left, right):
327-
# type: (value_t, value_t) -> bool
327+
def ExactlyEqual(left, right, blame_loc):
328+
# type: (value_t, value_t, loc_t) -> bool
328329
if left.tag() != right.tag():
329330
return False
330331

@@ -351,7 +352,7 @@ def ExactlyEqual(left, right):
351352
# Note: could provide floatEquals(), and suggest it
352353
# Suggested idiom is abs(f1 - f2) < 0.1
353354
raise error.TypeErrVerbose("Equality isn't defined on Float",
354-
loc.Missing)
355+
blame_loc)
355356

356357
elif case(value_e.Str):
357358
left = cast(value.Str, UP_left)
@@ -377,7 +378,7 @@ def ExactlyEqual(left, right):
377378
return False
378379

379380
for i in xrange(0, len(left.items)):
380-
if not ExactlyEqual(left.items[i], right.items[i]):
381+
if not ExactlyEqual(left.items[i], right.items[i], blame_loc):
381382
return False
382383

383384
return True
@@ -401,12 +402,13 @@ def ExactlyEqual(left, right):
401402
return False
402403

403404
for k in left.d.keys():
404-
if k not in right.d or not ExactlyEqual(right.d[k], left.d[k]):
405+
if k not in right.d or not ExactlyEqual(right.d[k], left.d[k], blame_loc):
405406
return False
406407

407408
return True
408409

409-
raise NotImplementedError(left)
410+
raise error.TypeErrVerbose(
411+
"Can't compare two values of type %s" % ui.ValType(left), blame_loc)
410412

411413

412414
def Contains(needle, haystack):

0 commit comments

Comments
 (0)