diff --git a/camel/interpreters/internal_python_interpreter.py b/camel/interpreters/internal_python_interpreter.py index c75556abe3..5c0b47e0ad 100644 --- a/camel/interpreters/internal_python_interpreter.py +++ b/camel/interpreters/internal_python_interpreter.py @@ -294,6 +294,10 @@ def _execute_ast(self, expression: ast.AST) -> Any: # update the state. We return the variable assigned as it may # be used to determine the final result. return self._execute_assign(expression) + elif isinstance(expression, ast.AugAssign): + # Augmented assignment (+=, -=, *=, etc.) -> update the state + # and return the new value + return self._execute_augassign(expression) elif isinstance(expression, ast.Attribute): value = self._execute_ast(expression.value) return getattr(value, expression.attr) @@ -396,6 +400,55 @@ def _assign(self, target: ast.expr, value: Any): f"{target.__class__.__name__} instead." ) + def _execute_augassign(self, augassign: ast.AugAssign) -> Any: + # Get the current value of the target variable + target = augassign.target + if not isinstance(target, ast.Name): + raise InterpreterError( + f"Unsupported target for augmented assignment. " + f"Expected ast.Name, got {target.__class__.__name__} instead." + ) + + current_value = self._get_value_from_state(target.id) + operator = augassign.op + right_value = self._execute_ast(augassign.value) + + # Apply the operation based on the operator type + if isinstance(operator, ast.Add): + result = current_value + right_value + elif isinstance(operator, ast.Sub): + result = current_value - right_value + elif isinstance(operator, ast.Mult): + result = current_value * right_value + elif isinstance(operator, ast.Div): + result = current_value / right_value + elif isinstance(operator, ast.FloorDiv): + result = current_value // right_value + elif isinstance(operator, ast.Mod): + result = current_value % right_value + elif isinstance(operator, ast.Pow): + result = current_value ** right_value + elif isinstance(operator, ast.LShift): + result = current_value << right_value + elif isinstance(operator, ast.RShift): + result = current_value >> right_value + elif isinstance(operator, ast.BitOr): + result = current_value | right_value + elif isinstance(operator, ast.BitXor): + result = current_value ^ right_value + elif isinstance(operator, ast.BitAnd): + result = current_value & right_value + elif isinstance(operator, ast.MatMult): + result = current_value @ right_value + else: + raise InterpreterError( + f"Unsupported augmented assignment operator: {operator}" + ) + + # Update the state with the new value + self.state[target.id] = result + return result + def _execute_call(self, call: ast.Call) -> Any: callable_func = self._execute_ast(call.func) diff --git a/test/interpreters/test_python_interpreter.py b/test/interpreters/test_python_interpreter.py index f92cbfef37..bd00ffd9a3 100644 --- a/test/interpreters/test_python_interpreter.py +++ b/test/interpreters/test_python_interpreter.py @@ -286,16 +286,42 @@ def test_joined_str(interpreter: InternalPythonInterpreter): assert execution_res == "2,3,5,7,11" -def test_expression_not_support(interpreter: InternalPythonInterpreter): - code = """x = 1 -x += 1""" - with pytest.raises(InterpreterError) as e: - interpreter.execute(code, keep_state=False) - exec_msg = e.value.args[0] - assert exec_msg == ( - "Evaluation of the code stopped at node 1. See:" - "\nAugAssign is not supported." - ) +def test_augassign_operations(interpreter: InternalPythonInterpreter): + # Test various augmented assignment operations + code = """x = 10 +x += 5""" + execution_res = interpreter.execute(code, keep_state=False) + assert execution_res == 15 + + code = """x = 10 +x -= 3""" + execution_res = interpreter.execute(code, keep_state=False) + assert execution_res == 7 + + code = """x = 4 +x *= 3""" + execution_res = interpreter.execute(code, keep_state=False) + assert execution_res == 12 + + code = """x = 20 +x /= 4""" + execution_res = interpreter.execute(code, keep_state=False) + assert execution_res == 5.0 + + code = """x = 17 +x //= 5""" + execution_res = interpreter.execute(code, keep_state=False) + assert execution_res == 3 + + code = """x = 17 +x %= 5""" + execution_res = interpreter.execute(code, keep_state=False) + assert execution_res == 2 + + code = """x = 2 +x **= 3""" + execution_res = interpreter.execute(code, keep_state=False) + assert execution_res == 8 def test_allow_builtins(interpreter: InternalPythonInterpreter):