Skip to content

Commit

Permalink
[ysh] Fix crash when comparing non-comparable types (#1754)
Browse files Browse the repository at this point in the history
  • Loading branch information
PossiblyAShrub authored Oct 31, 2023
1 parent e4257e0 commit 015766e
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 9 deletions.
2 changes: 1 addition & 1 deletion osh/cmd_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1470,7 +1470,7 @@ def _DoCase(self, node):
expr_val = self.expr_ev.EvalExpr(
pat_expr, case_arm.left)

if val_ops.ExactlyEqual(expr_val, to_match):
if val_ops.ExactlyEqual(expr_val, to_match, case_arm.left):
status = self._ExecuteList(case_arm.action)
matched = True
break
Expand Down
37 changes: 37 additions & 0 deletions spec/ysh-expr-compare.test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -343,3 +343,40 @@ y=42
x=42
## END

#### Undefined comparisons
shopt -s ysh:all

func f() { true }
var mydict = {}
var myexpr = ^[123]

var unimpl = [
/ [a-z]+ /, # Eggex
myexpr, # Expr
^(echo hello), # Block
f, # Func
mydict->keys, # BoundFunc
# These cannot be constructed
# - Proc
# - Slice
# - Range
]

for val in (unimpl) {
try { :: val === val }
if (_status !== 3) {
exit 1
}
}
## STDOUT:
## END

#### Non-comparable types in case arms
var myexpr = ^[123]

case (myexpr) {
(myexpr) { echo 123; }
}
## status: 3
## STDOUT:
## END
8 changes: 8 additions & 0 deletions test/parse-errors.sh
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,14 @@ ysh_case() {
}
'

_ysh-should-parse '
var myexpr = ^[123]
case (123) {
(myexpr) { echo 1 }
}
'

_ysh-should-parse '
case (x) {
(else) { echo 1 }
Expand Down
17 changes: 17 additions & 0 deletions test/ysh-runtime-errors.sh
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,23 @@ test-read-builtin() {
_error-case-X 2 'echo hi | read --line x y'
}

test-equality() {
_expr-error-case '
= ^[42] === ^[43]
'

_expr-error-case '
= ^(echo hi) === ^(echo yo)
'

return

# Hm it's kind of weird you can do this -- it's False
_expr-error-case '
= ^[42] === "hi"
'
}

soil-run() {
# This is like run-test-funcs, except errexit is off here
run-test-funcs
Expand Down
4 changes: 2 additions & 2 deletions ysh/expr_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,12 +622,12 @@ def _EvalCompare(self, node):
if left.tag() != right.tag():
result = False
else:
result = val_ops.ExactlyEqual(left, right)
result = val_ops.ExactlyEqual(left, right, op)
elif op.id == Id.Expr_NotDEqual:
if left.tag() != right.tag():
result = True
else:
result = not val_ops.ExactlyEqual(left, right)
result = not val_ops.ExactlyEqual(left, right, op)

elif op.id == Id.Expr_In:
result = val_ops.Contains(left, right)
Expand Down
14 changes: 8 additions & 6 deletions ysh/val_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from _devbuild.gen.syntax_asdl import loc, loc_t, command_t
from _devbuild.gen.value_asdl import (value, value_e, value_t)
from core import error
from core import ui
from mycpp.mylib import tagswitch
from ysh import regex_translate

Expand Down Expand Up @@ -323,8 +324,8 @@ def ToBool(val):
return True # all other types are Truthy


def ExactlyEqual(left, right):
# type: (value_t, value_t) -> bool
def ExactlyEqual(left, right, blame_loc):
# type: (value_t, value_t, loc_t) -> bool
if left.tag() != right.tag():
return False

Expand All @@ -351,7 +352,7 @@ def ExactlyEqual(left, right):
# Note: could provide floatEquals(), and suggest it
# Suggested idiom is abs(f1 - f2) < 0.1
raise error.TypeErrVerbose("Equality isn't defined on Float",
loc.Missing)
blame_loc)

elif case(value_e.Str):
left = cast(value.Str, UP_left)
Expand All @@ -377,7 +378,7 @@ def ExactlyEqual(left, right):
return False

for i in xrange(0, len(left.items)):
if not ExactlyEqual(left.items[i], right.items[i]):
if not ExactlyEqual(left.items[i], right.items[i], blame_loc):
return False

return True
Expand All @@ -401,12 +402,13 @@ def ExactlyEqual(left, right):
return False

for k in left.d.keys():
if k not in right.d or not ExactlyEqual(right.d[k], left.d[k]):
if k not in right.d or not ExactlyEqual(right.d[k], left.d[k], blame_loc):
return False

return True

raise NotImplementedError(left)
raise error.TypeErrVerbose(
"Can't compare two values of type %s" % ui.ValType(left), blame_loc)


def Contains(needle, haystack):
Expand Down

0 comments on commit 015766e

Please sign in to comment.