diff --git a/android_world/agents/midscene.py b/android_world/agents/midscene.py index 8d0acf37..13a25eeb 100644 --- a/android_world/agents/midscene.py +++ b/android_world/agents/midscene.py @@ -3,6 +3,7 @@ from android_world.agents import base_agent +from android_world.env import adb_utils from android_world.env import interface from android_world.env import representation_utils @@ -150,26 +151,73 @@ def _send_rpc_request(self, method: str, params: dict) -> dict: return result def _start_dom_server(self): - """Starts a background HTTP server that serves the current a11y tree as raw XML.""" + """Starts a background HTTP server that serves page XML.""" agent_ref = self class DomHandler(BaseHTTPRequestHandler): + def _write_response(self, status: int, content_type: str, body: str): + self.send_response(status) + self.send_header('Content-Type', content_type) + self.end_headers() + self.wfile.write(body.encode('utf-8')) + + def _page_xml_from_adb_dump(self) -> str: + try: + page_xml = adb_utils.uiautomator_dump(agent_ref.env.controller) + if page_xml: + return page_xml + except Exception as e: + agent_ref._formatted_console( + 'DOM provider adb dump fallback failed: ' + str(e) + ) + return '' + + def _page_xml_from_state(self, state) -> str: + if state.forest is None: + agent_ref._formatted_console( + 'DOM provider got empty AccessibilityForwarder forest; falling back to adb dump' + ) + return self._page_xml_from_adb_dump() + + page_xml = representation_utils.forest_to_raw_xml(state.forest) + if page_xml: + return page_xml + + agent_ref._formatted_console( + 'DOM provider got empty AccessibilityForwarder XML; falling back to adb dump' + ) + return self._page_xml_from_adb_dump() + + def _get_state_with_retry(self): + state = agent_ref.env.get_state() + if state.forest is not None: + return state + + time.sleep(0.2) + retry_state = agent_ref.env.get_state() + if retry_state.forest is None: + agent_ref._formatted_console( + 'DOM provider retry still returned empty AccessibilityForwarder forest' + ) + return retry_state + + def _get_page_xml(self) -> str: + return self._page_xml_from_state(self._get_state_with_retry()) + def do_GET(self): try: - state = agent_ref.env.get_state() - if state.forest is not None: - raw_xml = representation_utils.forest_to_raw_xml(state.forest) + if self.path == '/context': + # Page structure and cursor metadata come from the same + # AccessibilityForwarder snapshot. Cursor state is embedded on the + # focused editable node as XML attributes. + state = self._get_state_with_retry() + page_xml = self._page_xml_from_state(state) + self._write_response(200, 'text/xml; charset=utf-8', page_xml) else: - raw_xml = '' - self.send_response(200) - self.send_header('Content-Type', 'text/xml; charset=utf-8') - self.end_headers() - self.wfile.write(raw_xml.encode('utf-8')) + page_xml = self._get_page_xml() + self._write_response(200, 'text/xml; charset=utf-8', page_xml) except Exception as e: - self.send_response(500) - self.send_header('Content-Type', 'text/plain') - self.end_headers() - self.wfile.write(str(e).encode('utf-8')) + self._write_response(500, 'text/plain', str(e)) def log_message(self, format, *args): pass # Suppress default access logs @@ -177,7 +225,7 @@ def log_message(self, format, *args): server = HTTPServer(('127.0.0.1', 0), DomHandler) port = server.server_address[1] self._dom_server = server - self._dom_server_url = f'http://127.0.0.1:{port}/dom' + self._dom_server_url = f'http://127.0.0.1:{port}/context' self._formatted_console(f"DOM provider server started at {self._dom_server_url}") thread = threading.Thread(target=server.serve_forever, daemon=True) diff --git a/android_world/env/android_world_controller.py b/android_world/env/android_world_controller.py index f044253e..135169d8 100644 --- a/android_world/env/android_world_controller.py +++ b/android_world/env/android_world_controller.py @@ -42,6 +42,14 @@ # Throttle check_airplane_mode: at most once per interval to reduce ADB load _last_airplane_check: dict[int, float] = {} _AIRPLANE_CHECK_INTERVAL = 30.0 +_A11Y_FORWARDER_SERVICE = ( + 'com.google.androidenv.accessibilityforwarder/' + 'com.google.androidenv.accessibilityforwarder.AccessibilityForwarder' +) +_A11Y_FORWARDER_FLAGS_RECEIVER = ( + 'com.google.androidenv.accessibilityforwarder/' + 'com.google.androidenv.accessibilityforwarder.FlagsBroadcastReceiver' +) def _has_wrapper( @@ -109,7 +117,7 @@ def get_a11y_tree( try: forest = env.accumulate_new_extras()['accessibility_tree'][-1] # pytype:disable=attribute-error return forest - except KeyError: + except (KeyError, IndexError): logging.warning('Could not get a11y tree, retrying.') time.sleep(sleep_duration) @@ -454,18 +462,34 @@ def _restart_a11y_forwarder(self) -> bool: server_port = str(self._adb_server_port) device_args = ['-s', self._device_name] if self._device_name else [] - # Step 1: Re-enable the accessibility service - cmd = [adb_path, '-P', server_port] + device_args + [ - 'shell', 'settings', 'put', 'secure', - 'enabled_accessibility_services', - 'com.google.androidenv.accessibilityforwarder/' - 'com.google.androidenv.accessibilityforwarder.AccessibilityForwarder', - ] - result = subprocess.run(cmd, capture_output=True, text=True, timeout=30) - if result.returncode != 0: - logging.warning('Failed to re-enable a11y service: %s', result.stderr) + if self._is_remote and not self.ensure_adb_connection(): return False + def run_shell_command(args: list[str]) -> subprocess.CompletedProcess: + cmd = [adb_path, '-P', server_port] + device_args + ['shell'] + args + return subprocess.run(cmd, capture_output=True, text=True, timeout=30) + + # Step 1: Re-enable Android accessibility and the forwarder service. + # Some failures leave the service listed but global accessibility off. + for shell_args in ( + ['settings', 'put', 'secure', 'accessibility_enabled', '1'], + [ + 'settings', + 'put', + 'secure', + 'enabled_accessibility_services', + _A11Y_FORWARDER_SERVICE, + ], + ): + result = run_shell_command(shell_args) + if result.returncode != 0: + logging.warning( + 'Failed to update a11y setting %s: %s', + ' '.join(shell_args), + result.stderr, + ) + return False + logging.info('Re-enabled AccessibilityForwarder service') time.sleep(2.0) # Give the service time to start @@ -480,8 +504,7 @@ def _restart_a11y_forwarder(self) -> bool: 'shell', 'am', 'broadcast', '-a', 'accessibility_forwarder.intent.action.SET_GRPC', '--ei', 'port', str(self._a11y_port), - '-n', 'com.google.androidenv.accessibilityforwarder/' - 'com.google.androidenv.accessibilityforwarder.FlagsBroadcastReceiver', + '-n', _A11Y_FORWARDER_FLAGS_RECEIVER, ] subprocess.run(cmd, capture_output=True, text=True, timeout=30) @@ -501,6 +524,8 @@ def get_a11y_forest( 2. Restart AccessibilityForwarder service (handles uiautomator disruption) 3. Full environment refresh (handles ADB disconnection / deep failures) """ + self.ensure_adb_connection() + try: return self._get_a11y_forest() except RuntimeError: @@ -535,25 +560,51 @@ def get_ui_elements(self) -> list[representation_utils.UIElement]: self.ensure_adb_connection() if self._a11y_method == A11yMethod.A11Y_FORWARDER_APP: - return representation_utils.forest_to_ui_elements( - self.get_a11y_forest(), - exclude_invisible_elements=True, - ) + try: + return representation_utils.forest_to_ui_elements( + self.get_a11y_forest(), + exclude_invisible_elements=True, + ) + except RuntimeError as e: + logging.warning( + 'A11y tree unavailable after recovery; falling back to ' + 'uiautomator UI elements: %s', + e, + ) + return self._get_uiautomator_ui_elements() elif self._a11y_method == A11yMethod.UIAUTOMATOR: + return self._get_uiautomator_ui_elements() + else: + return [] + + def _get_uiautomator_ui_elements(self) -> list[representation_utils.UIElement]: + """Returns UI elements from uiautomator, or an empty list if it fails.""" + try: return representation_utils.xml_dump_to_ui_elements( adb_utils.uiautomator_dump(self._env) ) - else: + except Exception as e: + logging.warning('Failed to get UI elements via uiautomator: %s', e) return [] def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: """Adds a11y tree info to the observation.""" if self._a11y_method == A11yMethod.A11Y_FORWARDER_APP: - forest = self.get_a11y_forest() - ui_elements = representation_utils.forest_to_ui_elements( - forest, - exclude_invisible_elements=True, - ) + try: + forest = self.get_a11y_forest() + except RuntimeError as e: + logging.warning( + 'A11y tree unavailable after recovery; falling back to ' + 'uiautomator UI elements: %s', + e, + ) + forest = None + ui_elements = self._get_uiautomator_ui_elements() + else: + ui_elements = representation_utils.forest_to_ui_elements( + forest, + exclude_invisible_elements=True, + ) else: forest = None ui_elements = self.get_ui_elements() diff --git a/android_world/env/android_world_controller_test.py b/android_world/env/android_world_controller_test.py index e065dc71..022dae35 100644 --- a/android_world/env/android_world_controller_test.py +++ b/android_world/env/android_world_controller_test.py @@ -130,6 +130,36 @@ def test_process_timestep( exclude_invisible_elements=True, ) + @mock.patch.object(adb_utils, 'check_airplane_mode') + @mock.patch.object(android_world_controller, 'get_controller') + @mock.patch.object(android_world_controller, '_has_wrapper') + @mock.patch.object(representation_utils, 'forest_to_ui_elements') + def test_process_timestep_continues_without_a11y_tree( + self, + mock_forest_to_ui, + mock_has_wrapper, + mock_get_controller, + mock_check_airplane_mode, + ): + del mock_has_wrapper, mock_get_controller, mock_check_airplane_mode + mock_base_env = mock.Mock(spec=env_interface.AndroidEnvInterface) + env = android_world_controller.AndroidWorldController(mock_base_env) + timestep = dm_env.TimeStep( + observation={}, reward=None, discount=None, step_type=None + ) + + with mock.patch.object( + env, 'get_a11y_forest', side_effect=RuntimeError('a11y unavailable') + ), mock.patch.object( + env, '_get_uiautomator_ui_elements', return_value=['fallback'] + ) as mock_uiautomator_fallback: + processed_timestep = env._process_timestep(timestep) + + self.assertIsNone(processed_timestep.observation['forest']) + self.assertEqual(processed_timestep.observation['ui_elements'], ['fallback']) + mock_uiautomator_fallback.assert_called_once() + mock_forest_to_ui.assert_not_called() + @mock.patch.object(adb_utils, 'check_airplane_mode') @mock.patch.object(android_world_controller, 'get_controller') @mock.patch.object(android_world_controller, '_has_wrapper') diff --git a/android_world/env/representation_utils.py b/android_world/env/representation_utils.py index f425e9d5..bececc5e 100644 --- a/android_world/env/representation_utils.py +++ b/android_world/env/representation_utils.py @@ -70,6 +70,8 @@ class UIElement: is_scrollable: Optional[bool] = None is_selected: Optional[bool] = None is_visible: Optional[bool] = None + text_selection_start: Optional[int] = None + text_selection_end: Optional[int] = None package_name: Optional[str] = None resource_name: Optional[str] = None tooltip: Optional[str] = None @@ -115,6 +117,8 @@ def text_or_none(text: Optional[str]) -> Optional[str]: is_scrollable=node.is_scrollable, is_selected=node.is_selected, is_visible=node.is_visible_to_user, + text_selection_start=node.text_selection_start, + text_selection_end=node.text_selection_end, package_name=text_or_none(node.package_name), resource_name=text_or_none(node.view_id_resource_name), ) @@ -191,6 +195,35 @@ def _bounds_str(node: Any) -> str: return f'[{b.left},{b.top}][{b.right},{b.bottom}]' +def _selection_attrs(node: Any) -> str: + start = getattr(node, 'text_selection_start', 0) + end = getattr(node, 'text_selection_end', 0) + + if start < 0 or end < 0: + return '' + + # For proto3 fields, 0 is also the default for nodes without a selection. + # Emit a zero offset only when an editable node is focused; otherwise require + # Android to report a non-default range. + if ( + start == 0 + and end == 0 + and not ( + getattr(node, 'is_focused', False) + and getattr(node, 'is_editable', False) + ) + ): + return '' + + attrs = ( + f' text-selection-start="{start}"' + f' text-selection-end="{end}"' + ) + if start == end and start >= 0: + attrs += f' cursor-position="{start}"' + return attrs + + def _raw_xml_node(node: Any, children_by_id: dict[int, list[Any]], indent: int) -> str: """Format a protobuf node as uiautomator-dump-compatible XML.""" @@ -204,7 +237,12 @@ def _raw_xml_node(node: Any, children_by_id: dict[int, list[Any]], f'scrollable="{_bool_str(node.is_scrollable)}" ' f'selected="{_bool_str(node.is_selected)}" ' f'checked="{_bool_str(node.is_checked)}" ' + f'enabled="{_bool_str(node.is_enabled)}" ' + f'focusable="{_bool_str(node.is_focusable)}" ' + f'focused="{_bool_str(node.is_focused)}" ' + f'editable="{_bool_str(node.is_editable)}" ' f'bounds="{_bounds_str(node)}"' + f'{_selection_attrs(node)}' ) kids = children_by_id.get(node.unique_id, []) @@ -222,10 +260,9 @@ def forest_to_raw_xml( ) -> str: """Convert accessibility forest to uiautomator-dump-compatible raw XML. - Produces XML with the same attribute names as uiautomator dump - (text, resource-id, class, content-desc, clickable, scrollable, selected, - checked, bounds), so midscene's parseXmlToFormatTree can process it - identically. + Produces XML with the same core attribute names as uiautomator dump + plus Android AccessibilityForwarder cursor attributes on reportable text + fields. Args: forest: The accessibility forest protobuf. @@ -294,6 +331,7 @@ def process_node(node, is_root): is_checked=node.get('checked') == 'true', is_checkable=node.get('checkable') == 'true', is_clickable=node.get('clickable') == 'true', + is_editable=node.get('editable') == 'true', is_enabled=node.get('enabled') == 'true', is_focused=node.get('focused') == 'true', is_focusable=node.get('focusable') == 'true', @@ -303,6 +341,16 @@ def process_node(node, is_root): package_name=node.get('package'), resource_id=node.get('resource-id'), is_visible=True, + text_selection_start=( + int(node['text-selection-start']) + if node.get('text-selection-start') + else None + ), + text_selection_end=( + int(node['text-selection-end']) + if node.get('text-selection-end') + else None + ), ) if not is_root: ui_elements.append(ui_element) diff --git a/android_world/env/representation_utils_test.py b/android_world/env/representation_utils_test.py index d8cf5465..afe296a0 100644 --- a/android_world/env/representation_utils_test.py +++ b/android_world/env/representation_utils_test.py @@ -28,6 +28,39 @@ class BoundsInScreen: bottom: int +@dataclasses.dataclass +class FakeNode: + text: str = '' + view_id_resource_name: str = '' + class_name: str = '' + content_description: str = '' + bounds_in_screen: BoundsInScreen = dataclasses.field( + default_factory=lambda: BoundsInScreen(1, 3, 2, 4) + ) + unique_id: int = 1 + child_ids: list[int] = dataclasses.field(default_factory=list) + is_clickable: bool = False + is_scrollable: bool = False + is_selected: bool = False + is_checked: bool = False + is_enabled: bool = True + is_focusable: bool = False + is_focused: bool = False + is_editable: bool = False + text_selection_start: int = 0 + text_selection_end: int = 0 + + +def _fake_forest(nodes: list[FakeNode]): + tree = mock.MagicMock() + tree.nodes = nodes + window = mock.MagicMock() + window.tree = tree + forest = mock.MagicMock() + forest.windows = [window] + return forest + + class TestAccessibilityNodeToUIElement(parameterized.TestCase): @parameterized.named_parameters( @@ -115,5 +148,65 @@ def test_normalize_bboxes( self.assertEqual(ui_element.bbox, expected_normalized_bbox) +class TestForestToRawXml(absltest.TestCase): + + def test_focused_editable_zero_cursor_is_embedded_on_node(self): + forest = _fake_forest([ + FakeNode( + text='hello & bye', + view_id_resource_name='pkg:id/editor', + class_name='android.widget.EditText', + is_focusable=True, + is_focused=True, + is_editable=True, + text_selection_start=0, + text_selection_end=0, + ) + ]) + + self.assertEqual( + representation_utils.forest_to_raw_xml(forest), + '\n' + ' \n' + '', + ) + + def test_unfocused_default_selection_is_not_embedded(self): + forest = _fake_forest([ + FakeNode( + text='hello', + class_name='android.widget.TextView', + text_selection_start=0, + text_selection_end=0, + ) + ]) + + xml = representation_utils.forest_to_raw_xml(forest) + + self.assertIn('bounds="[1,2][3,4]"', xml) + self.assertNotIn('text-selection-start', xml) + self.assertNotIn('cursor-position', xml) + + def test_negative_selection_is_not_embedded(self): + forest = _fake_forest([ + FakeNode( + text='hello', + class_name='android.widget.TextView', + text_selection_start=-1, + text_selection_end=-1, + ) + ]) + + xml = representation_utils.forest_to_raw_xml(forest) + + self.assertIn('bounds="[1,2][3,4]"', xml) + self.assertNotIn('text-selection-start', xml) + self.assertNotIn('cursor-position', xml) + + if __name__ == '__main__': absltest.main()