Skip to content

Commit 5a3054a

Browse files
authored
Feature/orphaned screen cleanup (#49)
* Clean up orphaned screen wcgw sessions automatically
* Multiple commands parsing
* Simplified bash parsing
* Fixed the backup case for bash statement parsing
* Bump up patch version
* Reverted bad changes in tools.py
1 parent 77effe0 commit 5a3054a

File tree

5 files changed

+412
-4
lines changed

5 files changed

+412
-4
lines changed

Diff for: pyproject.toml

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
authors = [{ name = "Aman Rusia", email = "[email protected]" }]
33
name = "wcgw"
4-
version = "4.1.0"
4+
version = "4.1.1"
55
description = "Shell and coding agent on claude and chatgpt"
66
readme = "README.md"
77
requires-python = ">=3.11"
@@ -23,6 +23,7 @@ dependencies = [
2323
"tokenizers>=0.21.0",
2424
"pygit2>=1.16.0",
2525
"syntax-checker>=0.3.0",
26+
"psutil>=7.0.0",
2627
]
2728

2829
[project.urls]
@@ -58,6 +59,9 @@ dev-dependencies = [
5859
"line-profiler>=4.2.0",
5960
"pytest-asyncio>=0.25.3",
6061
"types-pexpect>=4.9.0.20241208",
62+
"types-psutil>=7.0.0.20250218",
63+
"tree-sitter>=0.24.0",
64+
"tree-sitter-bash>=0.23.3",
6165
]
6266

6367
[tool.pytest.ini_options]

Diff for: src/wcgw/client/bash_state/bash_state.py

+134-3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
)
1717

1818
import pexpect
19+
import psutil
1920
import pyte
2021

2122
from ...types_ import (
@@ -30,6 +31,7 @@
3031
)
3132
from ..encoder import EncoderDecoder
3233
from ..modes import BashCommandMode, FileEditMode, WriteIfEmptyMode
34+
from .parser.bash_statement_parser import BashStatementParser
3335

3436
PROMPT_CONST = "wcgw→" + " "
3537
PROMPT_STATEMENT = "export GIT_PAGER=cat PAGER=cat PROMPT_COMMAND= PS1='wcgw→'' '"
@@ -87,6 +89,116 @@ def check_if_screen_command_available() -> bool:
8789
return False
8890

8991

92+
def get_wcgw_screen_sessions() -> list[str]:
    """
    Collect the IDs of all screen sessions that belong to WCGW.

    Runs `screen -ls` and keeps the first whitespace-separated token of every
    listing line that starts with a digit and contains ".wcgw."
    (e.g. "1234.wcgw.123456").

    Returns:
        The matching session IDs in listing order. The list is empty when
        `screen -ls` cannot be run, times out, or reports no such session.
    """
    found: list[str] = []

    try:
        listing = subprocess.run(
            ["screen", "-ls"],
            capture_output=True,
            text=True,
            check=False,  # a non-zero exit code is not treated as an error
            timeout=0.5,
        )
        # Use stdout when it has content, otherwise fall back to stderr.
        raw_listing = listing.stdout or listing.stderr or ""

        for entry in (raw.strip() for raw in raw_listing.splitlines()):
            # Session entries start with the numeric PID; everything else
            # (headers, footers, blank lines) is ignored.
            if not (entry and entry[0].isdigit()):
                continue

            # First token looks like "1234.wcgw.123456"; the rest is status
            # text such as "(Detached)".
            candidate = entry.split()[0]
            if ".wcgw." in candidate:
                found.append(candidate)
    except Exception:
        # Best effort: any failure simply yields whatever was gathered so far
        # (normally nothing).
        pass

    return found
133+
134+
135+
def get_orphaned_wcgw_screens() -> list[str]:
    """
    Identify orphaned WCGW screen sessions.

    A session is reported as orphaned when the process whose PID prefixes the
    session ID either no longer exists or has PID 1 as its parent.

    Each session is examined independently: a failure while inspecting one
    session (for example psutil.AccessDenied) skips only that session instead
    of aborting the scan of the remaining ones.

    Returns:
        List of screen session IDs that are orphaned and match the wcgw pattern.
    """
    orphaned_screens: list[str] = []

    for session_id in get_wcgw_screen_sessions():
        # The session ID has the form "<pid>.wcgw.<suffix>".
        try:
            pid = int(session_id.split(".")[0])
        except ValueError:
            # Couldn't parse PID, skip this session.
            continue

        try:
            parent_pid = psutil.Process(pid).ppid()
        except psutil.NoSuchProcess:
            # Process doesn't exist anymore, consider it orphaned.
            orphaned_screens.append(session_id)
            continue
        except Exception:
            # Best effort: the process could not be inspected, so it cannot be
            # shown to be orphaned. Leave it alone and keep scanning.
            continue

        if parent_pid == 1:
            # Parent is PID 1: treat the session as orphaned.
            orphaned_screens.append(session_id)

    return orphaned_screens
173+
174+
175+
def cleanup_orphaned_wcgw_screens(console: Console) -> None:
    """
    Quit every WCGW screen session reported by get_orphaned_wcgw_screens().

    Args:
        console: Console used to log how many sessions were found and any
            failure to quit one of them.
    """
    stale_sessions = get_orphaned_wcgw_screens()

    if stale_sessions:
        console.log(
            f"Found {len(stale_sessions)} orphaned WCGW screen sessions to clean up"
        )

        for session in stale_sessions:
            quit_command = ["screen", "-S", session, "-X", "quit"]
            try:
                # A non-zero exit status is ignored; only exceptions such as a
                # timeout are logged.
                subprocess.run(quit_command, check=False, timeout=CONFIG.timeout)
            except Exception as e:
                console.log(f"Failed to kill orphaned screen session: {session}\n{e}")
200+
201+
90202
def cleanup_all_screens_with_name(name: str, console: Console) -> None:
91203
"""
92204
There could be in worst case multiple screens with same name, clear them if any.
@@ -269,6 +381,8 @@ def send(self, s: str | bytes, set_as_command: Optional[str]) -> int:
269381
self.close_bg_expect_thread()
270382
if set_as_command is not None:
271383
self._last_command = set_as_command
384+
# if s == "\n":
385+
# return self._shell.sendcontrol("m")
272386
output = self._shell.send(s)
273387
return output
274388

@@ -387,6 +501,11 @@ def _init_shell(self) -> None:
387501
self._last_command = ""
388502
# Ensure self._cwd exists
389503
os.makedirs(self._cwd, exist_ok=True)
504+
505+
# Clean up orphaned WCGW screen sessions
506+
if check_if_screen_command_available():
507+
cleanup_orphaned_wcgw_screens(self.console)
508+
390509
try:
391510
self._shell, self._shell_id = start_shell(
392511
self._bash_command_mode.bash_mode == "restricted_mode",
@@ -817,10 +936,22 @@ def _execute_bash(
817936
raise ValueError(WAITING_INPUT_MESSAGE)
818937

819938
command = command_data.command.strip()
939+
940+
# Check for multiple statements using the bash statement parser
820941
if "\n" in command:
821-
raise ValueError(
822-
"Command should not contain newline character in middle. Run only one command at a time."
823-
)
942+
try:
943+
parser = BashStatementParser()
944+
statements = parser.parse_string(command)
945+
if len(statements) > 1:
946+
return (
947+
"Error: Command contains multiple statements. Please run only one bash statement at a time.",
948+
0.0,
949+
)
950+
except Exception:
951+
# Fall back to simple newline check if something goes wrong
952+
raise ValueError(
953+
"Command should not contain newline character in middle. Run only one command at a time."
954+
)
824955

825956
for i in range(0, len(command), 128):
826957
bash_state.send(command[i : i + 128], set_as_command=None)

Diff for: src/wcgw/client/bash_state/parser/__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""
2+
Parser for bash statements using tree-sitter.
3+
4+
This module provides functionality to parse and identify individual bash statements.
5+
"""
6+
7+
from .bash_statement_parser import BashStatementParser, Statement
Diff for: src/wcgw/client/bash_state/parser/bash_statement_parser.py

+181
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Bash Statement Parser
4+
5+
This script parses bash scripts and identifies individual statements using tree-sitter.
6+
It correctly handles multi-line strings, command chains with && and ||, and semicolon-separated statements.
7+
"""
8+
9+
import sys
10+
from dataclasses import dataclass
11+
from typing import Any, List, Optional
12+
13+
import tree_sitter_bash
14+
from tree_sitter import Language, Parser
15+
16+
17+
@dataclass
18+
class Statement:
19+
"""A bash statement with its source code and position information."""
20+
21+
text: str
22+
start_line: int
23+
end_line: int
24+
start_byte: int
25+
end_byte: int
26+
node_type: str
27+
parent_type: Optional[str] = None
28+
29+
def __str__(self) -> str:
30+
return self.text.strip()
31+
32+
33+
class BashStatementParser:
34+
def __init__(self) -> None:
35+
# Use the precompiled bash language
36+
self.language = Language(tree_sitter_bash.language())
37+
self.parser = Parser(self.language)
38+
39+
def parse_file(self, file_path: str) -> List[Statement]:
40+
"""Parse a bash script file and return a list of statements."""
41+
with open(file_path, "r", encoding="utf-8") as f:
42+
content = f.read()
43+
return self.parse_string(content)
44+
45+
def parse_string(self, content: str) -> List[Statement]:
46+
"""Parse a string containing bash script and return a list of statements."""
47+
tree = self.parser.parse(bytes(content, "utf-8"))
48+
root_node = tree.root_node
49+
50+
# For debugging: Uncomment to print the tree structure
51+
# self._print_tree(root_node, content)
52+
53+
statements: List[Statement] = []
54+
self._extract_statements(root_node, content, statements, None)
55+
56+
# Post-process statements to handle multi-line statements correctly
57+
return self._post_process_statements(statements, content)
58+
59+
def _print_tree(self, node: Any, content: str, indent: str = "") -> None:
60+
"""Debug helper to print the entire syntax tree."""
61+
node_text = content[node.start_byte : node.end_byte]
62+
if len(node_text) > 40:
63+
node_text = node_text[:37] + "..."
64+
print(f"{indent}{node.type}: {repr(node_text)}")
65+
for child in node.children:
66+
self._print_tree(child, content, indent + " ")
67+
68+
def _extract_statements(
69+
self,
70+
node: Any,
71+
content: str,
72+
statements: List[Statement],
73+
parent_type: Optional[str],
74+
) -> None:
75+
"""Recursively extract statements from the syntax tree."""
76+
# Node types that represent bash statements
77+
statement_node_types = {
78+
# Basic statements
79+
"command",
80+
"variable_assignment",
81+
"declaration_command",
82+
"unset_command",
83+
# Control flow statements
84+
"for_statement",
85+
"c_style_for_statement",
86+
"while_statement",
87+
"if_statement",
88+
"case_statement",
89+
# Function definition
90+
"function_definition",
91+
# Command chains and groups
92+
"pipeline", # For command chains with | and |&
93+
"list", # For command chains with && and ||
94+
"compound_statement",
95+
"subshell",
96+
"redirected_statement",
97+
}
98+
99+
# Create a Statement object for this node if it's a recognized statement type
100+
if node.type in statement_node_types:
101+
# Get the text of this statement
102+
start_byte = node.start_byte
103+
end_byte = node.end_byte
104+
statement_text = content[start_byte:end_byte]
105+
106+
# Get line numbers
107+
start_line = (
108+
node.start_point[0] + 1
109+
) # tree-sitter uses 0-indexed line numbers
110+
end_line = node.end_point[0] + 1
111+
112+
statements.append(
113+
Statement(
114+
text=statement_text,
115+
start_line=start_line,
116+
end_line=end_line,
117+
start_byte=start_byte,
118+
end_byte=end_byte,
119+
node_type=node.type,
120+
parent_type=parent_type,
121+
)
122+
)
123+
124+
# Update parent type for children
125+
parent_type = node.type
126+
127+
# Recursively process all children
128+
for child in node.children:
129+
self._extract_statements(child, content, statements, parent_type)
130+
131+
def _post_process_statements(
132+
self, statements: List[Statement], content: str
133+
) -> List[Statement]:
134+
if not statements:
135+
return []
136+
137+
# Filter out list statements that have been split
138+
top_statements = []
139+
for stmt in statements:
140+
# Skip statements that are contained within others
141+
is_contained = False
142+
for other in statements:
143+
if other is stmt:
144+
continue
145+
146+
# Check if completely contained (except for lists we've split)
147+
if other.node_type != "list" or ";" not in other.text:
148+
if (
149+
other.start_line <= stmt.start_line
150+
and other.end_line >= stmt.end_line
151+
and len(other.text) > len(stmt.text)
152+
and stmt.text in other.text
153+
):
154+
is_contained = True
155+
break
156+
157+
if not is_contained:
158+
top_statements.append(stmt)
159+
160+
# Sort by position in file for consistent output
161+
top_statements.sort(key=lambda s: (s.start_line, s.text))
162+
163+
return top_statements
164+
165+
166+
def main() -> None:
    """Command-line entry point: print the statements found in a bash script.

    Expects the script path as the first command-line argument; prints a
    usage line and exits with status 1 when it is missing.
    """
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python bash_statement_parser.py <bash_script_file>")
        sys.exit(1)

    found = BashStatementParser().parse_file(cli_args[0])

    print(f"Found {len(found)} statements:")
    for index, statement in enumerate(found, start=1):
        header = (
            f"\n--- Statement {index} "
            f"(Lines {statement.start_line}-{statement.end_line}) ---"
        )
        print(header)
        print(statement)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)