Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 62 additions & 14 deletions android_world/agents/midscene.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@


from android_world.agents import base_agent
from android_world.env import adb_utils
from android_world.env import interface
from android_world.env import representation_utils

Expand Down Expand Up @@ -150,34 +151,81 @@ def _send_rpc_request(self, method: str, params: dict) -> dict:
return result

def _start_dom_server(self):
"""Starts a background HTTP server that serves the current a11y tree as raw XML."""
"""Starts a background HTTP server that serves page XML."""
agent_ref = self

class DomHandler(BaseHTTPRequestHandler):
def _write_response(self, status: int, content_type: str, body: str):
  """Send a complete HTTP response with a UTF-8 encoded text body.

  Args:
    status: HTTP status code for the response line.
    content_type: Value of the Content-Type header.
    body: Response text; it is encoded as UTF-8 before being written.
  """
  # Encode before anything is sent: an encoding error then surfaces while no
  # status line has been written yet, and Content-Length can report the byte
  # count (not the character count) so clients can detect a truncated body.
  payload = body.encode('utf-8')
  self.send_response(status)
  self.send_header('Content-Type', content_type)
  self.send_header('Content-Length', str(len(payload)))
  self.end_headers()
  self.wfile.write(payload)

def _page_xml_from_adb_dump(self) -> str:
  """Fetch the page XML through adb_utils.uiautomator_dump.

  Returns:
    The dumped XML, or '' when the dump is empty or raises. A failure is
    reported on the agent console instead of being propagated.
  """
  try:
    dumped_xml = adb_utils.uiautomator_dump(agent_ref.env.controller)
  except Exception as err:
    agent_ref._formatted_console(
        'DOM provider adb dump fallback failed: ' + str(err)
    )
    return ''
  return dumped_xml if dumped_xml else ''

def _page_xml_from_state(self, state) -> str:
  """Build page XML from a state's forest, falling back to an adb dump.

  Args:
    state: Object exposing a `forest` attribute, which may be None.

  Returns:
    The XML produced by representation_utils.forest_to_raw_xml when it is
    non-empty; otherwise whatever _page_xml_from_adb_dump returns.
  """
  if state.forest is not None:
    forest_xml = representation_utils.forest_to_raw_xml(state.forest)
    if forest_xml:
      return forest_xml
    fallback_reason = (
        'DOM provider got empty AccessibilityForwarder XML; falling back to adb dump'
    )
  else:
    fallback_reason = (
        'DOM provider got empty AccessibilityForwarder forest; falling back to adb dump'
    )

  # Both failure modes share the same recovery path; only the log differs.
  agent_ref._formatted_console(fallback_reason)
  return self._page_xml_from_adb_dump()

def _get_state_with_retry(self):
  """Read the environment state, retrying once if its forest is missing.

  Returns:
    The first state when its forest is not None; otherwise the state read
    after a 0.2 s pause, even if that one still has no forest (in which case
    a console message is emitted).
  """
  first_state = agent_ref.env.get_state()
  if first_state.forest is None:
    time.sleep(0.2)
    second_state = agent_ref.env.get_state()
    if second_state.forest is None:
      agent_ref._formatted_console(
          'DOM provider retry still returned empty AccessibilityForwarder forest'
      )
    return second_state
  return first_state

def _get_page_xml(self) -> str:
  """Return page XML for a freshly read (and possibly retried) state."""
  snapshot = self._get_state_with_retry()
  return self._page_xml_from_state(snapshot)

def do_GET(self):
  """Serve the current page XML, or a 500 carrying the error text.

  Any exception raised while building or writing the XML is reported to the
  client as a plain-text 500 response.
  """
  try:
    if self.path == '/context':
      # Page structure and cursor metadata come from the same
      # AccessibilityForwarder snapshot. Cursor state is embedded on the
      # focused editable node as XML attributes.
      state = self._get_state_with_retry()
      page_xml = self._page_xml_from_state(state)
    else:
      # Any other path is answered through _get_page_xml.
      page_xml = self._get_page_xml()
    self._write_response(200, 'text/xml; charset=utf-8', page_xml)
  except Exception as e:
    self._write_response(500, 'text/plain', str(e))

def log_message(self, format, *args):
  """Discard the default access-log output for each request."""
  return None

server = HTTPServer(('127.0.0.1', 0), DomHandler)
port = server.server_address[1]
self._dom_server = server
self._dom_server_url = f'http://127.0.0.1:{port}/dom'
self._dom_server_url = f'http://127.0.0.1:{port}/context'
self._formatted_console(f"DOM provider server started at {self._dom_server_url}")

thread = threading.Thread(target=server.serve_forever, daemon=True)
Expand Down
97 changes: 74 additions & 23 deletions android_world/env/android_world_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@
# Throttle check_airplane_mode: at most once per interval to reduce ADB load
_AIRPLANE_CHECK_INTERVAL = 30.0
_last_airplane_check: dict[int, float] = {}

# Component names ("<package>/<class>") inside the AccessibilityForwarder app.
_A11Y_FORWARDER_PACKAGE = 'com.google.androidenv.accessibilityforwarder'
_A11Y_FORWARDER_SERVICE = (
    f'{_A11Y_FORWARDER_PACKAGE}/{_A11Y_FORWARDER_PACKAGE}'
    '.AccessibilityForwarder'
)
_A11Y_FORWARDER_FLAGS_RECEIVER = (
    f'{_A11Y_FORWARDER_PACKAGE}/{_A11Y_FORWARDER_PACKAGE}'
    '.FlagsBroadcastReceiver'
)


def _has_wrapper(
Expand Down Expand Up @@ -109,7 +117,7 @@ def get_a11y_tree(
try:
forest = env.accumulate_new_extras()['accessibility_tree'][-1] # pytype:disable=attribute-error
return forest
except KeyError:
except (KeyError, IndexError):
logging.warning('Could not get a11y tree, retrying.')
time.sleep(sleep_duration)

Expand Down Expand Up @@ -454,18 +462,34 @@ def _restart_a11y_forwarder(self) -> bool:
server_port = str(self._adb_server_port)
device_args = ['-s', self._device_name] if self._device_name else []

# Step 1: Re-enable the accessibility service
cmd = [adb_path, '-P', server_port] + device_args + [
'shell', 'settings', 'put', 'secure',
'enabled_accessibility_services',
'com.google.androidenv.accessibilityforwarder/'
'com.google.androidenv.accessibilityforwarder.AccessibilityForwarder',
]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
if result.returncode != 0:
logging.warning('Failed to re-enable a11y service: %s', result.stderr)
if self._is_remote and not self.ensure_adb_connection():
return False

def run_shell_command(args: list[str]) -> subprocess.CompletedProcess:
cmd = [adb_path, '-P', server_port] + device_args + ['shell'] + args
return subprocess.run(cmd, capture_output=True, text=True, timeout=30)

# Step 1: Re-enable Android accessibility and the forwarder service.
# Some failures leave the service listed but global accessibility off.
for shell_args in (
['settings', 'put', 'secure', 'accessibility_enabled', '1'],
[
'settings',
'put',
'secure',
'enabled_accessibility_services',
_A11Y_FORWARDER_SERVICE,
],
):
result = run_shell_command(shell_args)
if result.returncode != 0:
logging.warning(
'Failed to update a11y setting %s: %s',
' '.join(shell_args),
result.stderr,
)
return False

logging.info('Re-enabled AccessibilityForwarder service')
time.sleep(2.0) # Give the service time to start

Expand All @@ -480,8 +504,7 @@ def _restart_a11y_forwarder(self) -> bool:
'shell', 'am', 'broadcast',
'-a', 'accessibility_forwarder.intent.action.SET_GRPC',
'--ei', 'port', str(self._a11y_port),
'-n', 'com.google.androidenv.accessibilityforwarder/'
'com.google.androidenv.accessibilityforwarder.FlagsBroadcastReceiver',
'-n', _A11Y_FORWARDER_FLAGS_RECEIVER,
]
subprocess.run(cmd, capture_output=True, text=True, timeout=30)

Expand All @@ -501,6 +524,8 @@ def get_a11y_forest(
2. Restart AccessibilityForwarder service (handles uiautomator disruption)
3. Full environment refresh (handles ADB disconnection / deep failures)
"""
self.ensure_adb_connection()

try:
return self._get_a11y_forest()
except RuntimeError:
Expand Down Expand Up @@ -535,25 +560,51 @@ def get_ui_elements(self) -> list[representation_utils.UIElement]:
self.ensure_adb_connection()

if self._a11y_method == A11yMethod.A11Y_FORWARDER_APP:
return representation_utils.forest_to_ui_elements(
self.get_a11y_forest(),
exclude_invisible_elements=True,
)
try:
return representation_utils.forest_to_ui_elements(
self.get_a11y_forest(),
exclude_invisible_elements=True,
)
except RuntimeError as e:
logging.warning(
'A11y tree unavailable after recovery; falling back to '
'uiautomator UI elements: %s',
e,
)
return self._get_uiautomator_ui_elements()
elif self._a11y_method == A11yMethod.UIAUTOMATOR:
return self._get_uiautomator_ui_elements()
else:
return []

def _get_uiautomator_ui_elements(self) -> list[representation_utils.UIElement]:
  """Returns UI elements from uiautomator, or an empty list if it fails.

  Any exception raised while dumping or parsing is logged as a warning and
  swallowed, so callers always receive a list.
  """
  try:
    return representation_utils.xml_dump_to_ui_elements(
        adb_utils.uiautomator_dump(self._env)
    )
  except Exception as e:
    # Deliberately broad: this path is a best-effort fallback and must not
    # raise into the caller.
    logging.warning('Failed to get UI elements via uiautomator: %s', e)
    return []

def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
"""Adds a11y tree info to the observation."""
if self._a11y_method == A11yMethod.A11Y_FORWARDER_APP:
forest = self.get_a11y_forest()
ui_elements = representation_utils.forest_to_ui_elements(
forest,
exclude_invisible_elements=True,
)
try:
forest = self.get_a11y_forest()
except RuntimeError as e:
logging.warning(
'A11y tree unavailable after recovery; falling back to '
'uiautomator UI elements: %s',
e,
)
forest = None
ui_elements = self._get_uiautomator_ui_elements()
else:
ui_elements = representation_utils.forest_to_ui_elements(
forest,
exclude_invisible_elements=True,
)
else:
forest = None
ui_elements = self.get_ui_elements()
Expand Down
30 changes: 30 additions & 0 deletions android_world/env/android_world_controller_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,36 @@ def test_process_timestep(
exclude_invisible_elements=True,
)

@mock.patch.object(adb_utils, 'check_airplane_mode')
@mock.patch.object(android_world_controller, 'get_controller')
@mock.patch.object(android_world_controller, '_has_wrapper')
@mock.patch.object(representation_utils, 'forest_to_ui_elements')
def test_process_timestep_continues_without_a11y_tree(
    self,
    mock_forest_to_ui,
    mock_has_wrapper,
    mock_get_controller,
    mock_check_airplane_mode,
):
  """_process_timestep uses the uiautomator fallback when the forest fails."""
  # These patches only need to be active; the mocks themselves are unused.
  del mock_has_wrapper, mock_get_controller, mock_check_airplane_mode
  controller = android_world_controller.AndroidWorldController(
      mock.Mock(spec=env_interface.AndroidEnvInterface)
  )
  empty_timestep = dm_env.TimeStep(
      step_type=None, reward=None, discount=None, observation={}
  )

  failing_forest = mock.patch.object(
      controller, 'get_a11y_forest', side_effect=RuntimeError('a11y unavailable')
  )
  stub_fallback = mock.patch.object(
      controller, '_get_uiautomator_ui_elements', return_value=['fallback']
  )
  with failing_forest, stub_fallback as fallback_mock:
    result = controller._process_timestep(empty_timestep)

  observation = result.observation
  self.assertIsNone(observation['forest'])
  self.assertEqual(observation['ui_elements'], ['fallback'])
  fallback_mock.assert_called_once()
  mock_forest_to_ui.assert_not_called()

@mock.patch.object(adb_utils, 'check_airplane_mode')
@mock.patch.object(android_world_controller, 'get_controller')
@mock.patch.object(android_world_controller, '_has_wrapper')
Expand Down
56 changes: 52 additions & 4 deletions android_world/env/representation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class UIElement:
is_scrollable: Optional[bool] = None
is_selected: Optional[bool] = None
is_visible: Optional[bool] = None
text_selection_start: Optional[int] = None
text_selection_end: Optional[int] = None
package_name: Optional[str] = None
resource_name: Optional[str] = None
tooltip: Optional[str] = None
Expand Down Expand Up @@ -115,6 +117,8 @@ def text_or_none(text: Optional[str]) -> Optional[str]:
is_scrollable=node.is_scrollable,
is_selected=node.is_selected,
is_visible=node.is_visible_to_user,
text_selection_start=node.text_selection_start,
text_selection_end=node.text_selection_end,
Comment on lines +120 to +121
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Normalize unset selection offsets to None

accessibility_node_to_ui_element now copies node.text_selection_start/end directly, but for proto3 scalar fields an unset value is indistinguishable from 0. This causes most nodes (including non-editable ones) to appear as if they have a cursor/selection at offset 0, which can mislead any downstream logic that interprets non-None selection metadata as real cursor state. Please apply the same default-handling logic used in _selection_attrs (or equivalent presence checks) before populating these fields.

Useful? React with 👍 / 👎.

package_name=text_or_none(node.package_name),
resource_name=text_or_none(node.view_id_resource_name),
)
Expand Down Expand Up @@ -191,6 +195,35 @@ def _bounds_str(node: Any) -> str:
return f'[{b.left},{b.top}][{b.right},{b.bottom}]'


def _selection_attrs(node: Any) -> str:
  """Render a node's text-selection offsets as extra XML attributes.

  Args:
    node: Object that may expose text_selection_start, text_selection_end,
      is_focused and is_editable; missing attributes count as 0 / False.

  Returns:
    '' when either offset is negative, or when both offsets are 0 and the
    node is not both focused and editable. Otherwise
    ' text-selection-start="S" text-selection-end="E"', followed by
    ' cursor-position="S"' when the two offsets are equal.
  """
  sel_start = getattr(node, 'text_selection_start', 0)
  sel_end = getattr(node, 'text_selection_end', 0)

  if min(sel_start, sel_end) < 0:
    return ''

  # For proto3 fields, 0 is also the default for nodes without a selection.
  # Emit a zero offset only when an editable node is focused; otherwise require
  # Android to report a non-default range.
  if sel_start == 0 and sel_end == 0:
    focused_editable = (
        getattr(node, 'is_focused', False)
        and getattr(node, 'is_editable', False)
    )
    if not focused_editable:
      return ''

  pieces = [
      f' text-selection-start="{sel_start}"',
      f' text-selection-end="{sel_end}"',
  ]
  if sel_start == sel_end:
    # A collapsed range is a plain cursor.
    pieces.append(f' cursor-position="{sel_start}"')
  return ''.join(pieces)


def _raw_xml_node(node: Any, children_by_id: dict[int, list[Any]],
indent: int) -> str:
"""Format a protobuf node as uiautomator-dump-compatible <node> XML."""
Expand All @@ -204,7 +237,12 @@ def _raw_xml_node(node: Any, children_by_id: dict[int, list[Any]],
f'scrollable="{_bool_str(node.is_scrollable)}" '
f'selected="{_bool_str(node.is_selected)}" '
f'checked="{_bool_str(node.is_checked)}" '
f'enabled="{_bool_str(node.is_enabled)}" '
f'focusable="{_bool_str(node.is_focusable)}" '
f'focused="{_bool_str(node.is_focused)}" '
f'editable="{_bool_str(node.is_editable)}" '
f'bounds="{_bounds_str(node)}"'
f'{_selection_attrs(node)}'
)

kids = children_by_id.get(node.unique_id, [])
Expand All @@ -222,10 +260,9 @@ def forest_to_raw_xml(
) -> str:
"""Convert accessibility forest to uiautomator-dump-compatible raw XML.

Produces <node> XML with the same attribute names as uiautomator dump
(text, resource-id, class, content-desc, clickable, scrollable, selected,
checked, bounds), so midscene's parseXmlToFormatTree can process it
identically.
Produces <node> XML with the same core attribute names as uiautomator dump
plus Android AccessibilityForwarder cursor attributes on reportable text
fields.

Args:
forest: The accessibility forest protobuf.
Expand Down Expand Up @@ -294,6 +331,7 @@ def process_node(node, is_root):
is_checked=node.get('checked') == 'true',
is_checkable=node.get('checkable') == 'true',
is_clickable=node.get('clickable') == 'true',
is_editable=node.get('editable') == 'true',
is_enabled=node.get('enabled') == 'true',
is_focused=node.get('focused') == 'true',
is_focusable=node.get('focusable') == 'true',
Expand All @@ -303,6 +341,16 @@ def process_node(node, is_root):
package_name=node.get('package'),
resource_id=node.get('resource-id'),
is_visible=True,
text_selection_start=(
int(node['text-selection-start'])
if node.get('text-selection-start')
else None
),
text_selection_end=(
int(node['text-selection-end'])
if node.get('text-selection-end')
else None
),
)
if not is_root:
ui_elements.append(ui_element)
Expand Down
Loading