Skip to content

Commit 98074f6

Browse files
authored
Merge pull request #36 from simular-ai/no_torch
fix box_iou, remove unnecessary libraries: torch,torchvision,transformers
2 parents 9098898 + 606a421 commit 98074f6

File tree

7 files changed

+257
-127
lines changed

7 files changed

+257
-127
lines changed

gui_agents/aci/LinuxOSACI.py

Lines changed: 215 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44
import time
55
import xml.etree.ElementTree as ET
66
from typing import Dict, List, Optional, Tuple, Any, Sequence
7-
7+
import numpy as np
88
import requests
9-
import torch
10-
import torchvision
119

1210
from gui_agents.aci.ACI import ACI
11+
from gui_agents.utils.common_utils import box_iou
1312

1413
import platform
1514

@@ -28,41 +27,21 @@
2827
import lxml.etree
2928
import concurrent.futures
3029

31-
logger = logging.getLogger("desktopenv.agent")
32-
33-
install_tmux_cmd = """import subprocess
34-
35-
install_command = f"echo 'password' | sudo -S apt install -y tmux"
36-
subprocess.run(install_command, shell=True, check=True)
37-
"""
30+
_accessibility_ns_map_ubuntu = {
31+
"st": "https://accessibility.ubuntu.example.org/ns/state",
32+
"attr": "https://accessibility.ubuntu.example.org/ns/attributes",
33+
"cp": "https://accessibility.ubuntu.example.org/ns/component",
34+
"doc": "https://accessibility.ubuntu.example.org/ns/document",
35+
"docattr": "https://accessibility.ubuntu.example.org/ns/document/attributes",
36+
"txt": "https://accessibility.ubuntu.example.org/ns/text",
37+
"val": "https://accessibility.ubuntu.example.org/ns/value",
38+
"act": "https://accessibility.ubuntu.example.org/ns/action",
39+
}
3840

39-
create_tmux_session_cmd = """import subprocess
41+
MAX_DEPTH = 50
42+
MAX_WIDTH = 1024
4043

41-
cmd = "tmux new-session -d -s my_background_session"
42-
subprocess.run(cmd, shell=True, check=True)
43-
"""
44-
45-
kill_tmux_session_cmd = """import subprocess
46-
47-
cmd = "tmux kill-session -t my_background_session"
48-
subprocess.run(cmd, shell=True, check=True)
49-
"""
50-
51-
run_tmux_cmd = """import subprocess, time
52-
53-
try:
54-
tmux_cmd = "tmux send-keys -t my_background_session '{cmd}' C-m"
55-
process = subprocess.Popen(
56-
tmux_cmd,
57-
shell=True,
58-
text=True
59-
)
60-
time.sleep(0.5)
61-
capture_cmd = "tmux capture-pane -t my_background_session -p"
62-
output = subprocess.check_output(capture_cmd, shell=True, text=True, timeout={timeout})
63-
except subprocess.TimeoutExpired:
64-
output = f"Command '{cmd}' timed out after {timeout} seconds"
65-
"""
44+
logger = logging.getLogger("desktopenv.agent")
6645

6746

6847
# Agent action decorator
@@ -291,13 +270,10 @@ def add_ocr_elements(
291270
int(box.get("right", 0)),
292271
int(box.get("bottom", 0)),
293272
)
294-
iou = (
295-
torchvision.ops.box_iou(
296-
torch.tensor(tree_bboxes), torch.tensor([[x1, y1, x2, y2]])
297-
)
298-
.numpy()
299-
.flatten()
300-
)
273+
iou = box_iou(
274+
np.array(tree_bboxes, dtype=np.float32),
275+
np.array([[x1, y1, x2, y2]], dtype=np.float32),
276+
).flatten()
301277

302278
if max(iou) < 0.1:
303279
# Add the element to the linearized accessibility tree
@@ -674,6 +650,184 @@ def fail(self):
674650
return """FAIL"""
675651

676652

653+
def _create_atspi_node(
    node: Accessible, depth: int = 0, flag: Optional[str] = None
) -> _Element:
    """Convert an AT-SPI accessible node and its subtree into an lxml element.

    The element tag is the node's role name (spaces replaced by ``-``, or
    ``"unknown"`` when the role name is empty). States, attributes, component
    extents, value and action information are stored as namespaced XML
    attributes using ``_accessibility_ns_map_ubuntu``; the node's text (if
    any) becomes the element text.

    Args:
        node: The accessible object to convert.
        depth: Current recursion depth; children are not traversed once it
            equals ``MAX_DEPTH``.
        flag: Application hint inherited from an ancestor. It is set to
            ``"calc"`` below a "document spreadsheet" node and to
            ``"thunderbird"`` below the Thunderbird application node.

    Returns:
        The lxml element describing ``node`` with its converted children
        appended.
    """

    def qname(prefix: str, local_name: str) -> str:
        # Build a Clark-notation name: "{namespace-uri}local_name".
        return f"{{{_accessibility_ns_map_ubuntu[prefix]}}}{local_name}"

    attribute_dict: Dict[str, Any] = {"name": node.name}

    # States
    for st in node.getState().get_states():
        state_name = StateType._enum_lookup[st]
        # Drop the prefix before the first underscore and lower-case the rest.
        state_name = state_name.split("_", maxsplit=1)[1].lower()
        if not state_name:
            continue
        attribute_dict[qname("st", state_name)] = "true"

    # Attributes
    for attribute_name, attribute_value in node.get_attributes().items():
        if not attribute_name:
            continue
        attribute_dict[qname("attr", attribute_name)] = attribute_value

    # Component: only nodes that are both visible and showing get extents.
    is_visible = attribute_dict.get(qname("st", "visible"), "false") == "true"
    is_showing = attribute_dict.get(qname("st", "showing"), "false") == "true"
    if is_visible and is_showing:
        try:
            component = node.queryComponent()
        except NotImplementedError:
            pass
        else:
            bbox = component.getExtents(pyatspi.XY_SCREEN)
            attribute_dict[qname("cp", "screencoord")] = str(tuple(bbox[0:2]))
            attribute_dict[qname("cp", "size")] = str(tuple(bbox[2:]))

    # Text
    text = ""
    try:
        text_obj = node.queryText()
        # Only text shown on the current screen is available.
        text = text_obj.getText(0, text_obj.characterCount)
        # U+FFFC ("Object Replacement Character", seen in Thunderbird and
        # elsewhere) and U+FFFD ("Replacement Character") are placeholders
        # for otherwise unspecified objects; strip both.
        text = text.replace("\ufffc", "").replace("\ufffd", "")
    except NotImplementedError:
        pass

    # Image, Selection, Value, Action
    try:
        node.queryImage()
        attribute_dict["image"] = "true"
    except NotImplementedError:
        pass

    try:
        node.querySelection()
        attribute_dict["selection"] = "true"
    except NotImplementedError:
        pass

    try:
        value = node.queryValue()
    except NotImplementedError:
        pass
    else:
        for attr_name, property_name in (
            ("value", "currentValue"),
            ("min", "minimumValue"),
            ("max", "maximumValue"),
            ("step", "minimumIncrement"),
        ):
            try:
                attribute_dict[qname("val", attr_name)] = str(
                    getattr(value, property_name)
                )
            except Exception:
                # Best effort: a value property that cannot be read is
                # skipped, but KeyboardInterrupt/SystemExit still propagate.
                pass

    try:
        action = node.queryAction()
        for i in range(action.nActions):
            action_name = action.getName(i).replace(" ", "-")
            attribute_dict[qname("act", f"{action_name}_desc")] = (
                action.getDescription(i)
            )
            attribute_dict[qname("act", f"{action_name}_kb")] = (
                action.getKeyBinding(i)
            )
    except NotImplementedError:
        pass

    # Add from here if we need more attributes in the future...

    raw_role_name = node.getRoleName().strip()
    node_role_name = (raw_role_name or "unknown").replace(" ", "-")

    if not flag:
        if raw_role_name == "document spreadsheet":
            flag = "calc"
        if raw_role_name == "application" and node.name == "Thunderbird":
            flag = "thunderbird"

    xml_node = lxml.etree.Element(
        node_role_name, attrib=attribute_dict, nsmap=_accessibility_ns_map_ubuntu
    )

    if text:
        xml_node.text = text

    if depth == MAX_DEPTH:
        logger.warning("Max depth reached")
        return xml_node

    if flag == "calc" and node_role_name == "table":
        # LibreOffice Calc limits:
        #   maximum columns: 1024 if version <= 7.3 else 16384
        #   maximum rows:    1,048,576
        #   maximum sheets:  10,000
        max_column = 1024 if libreoffice_version_tuple < (7, 4) else 16384
        max_row = 1_048_576

        # Cells are addressed through a flat index: row * max_column + column.
        index_base = 0
        first_showing = False
        column_base = None
        for r in range(max_row):
            for clm in range(column_base or 0, max_column):
                child_node = node[index_base + clm]
                if child_node.getState().contains(STATE_SHOWING):
                    child_element = _create_atspi_node(child_node, depth + 1, flag)
                    if not first_showing:
                        # Remember the left-most showing column so later rows
                        # start scanning from there.
                        column_base = clm
                        first_showing = True
                    xml_node.append(child_element)
                elif (first_showing and column_base is not None) or clm >= 500:
                    # Past the showing region of this row, or gave up after
                    # 500 columns without finding any showing cell.
                    break
            if (first_showing and clm == column_base) or (
                not first_showing and r >= 500
            ):
                # The row had no showing cell at the first showing column, or
                # nothing was found in the first 500 rows.
                break
            index_base += max_column
        return xml_node

    try:
        for i, ch in enumerate(node):
            if i == MAX_WIDTH:
                logger.warning("Max width reached")
                break
            xml_node.append(_create_atspi_node(ch, depth + 1, flag))
    except Exception:
        # Best effort: keep the children converted so far, but record why
        # traversal stopped.
        logger.warning(
            "Error occurred during children traversing. Has Ignored. Node: %s",
            lxml.etree.tostring(xml_node, encoding="unicode"),
            exc_info=True,
        )
    return xml_node
829+
830+
677831
class UIElement(object):
678832
def __init__(self, node):
679833
self.node = node
@@ -683,12 +837,24 @@ def getAttributeNames(self):
683837

684838
@staticmethod
def systemWideElement() -> str:
    """Serialize the accessibility tree of every desktop application to XML.

    Each application found under ``pyatspi.Registry.getDesktop(0)`` is
    converted with ``_create_atspi_node`` in a worker thread, and the
    resulting subtrees are attached to a ``desktop-frame`` root element.

    Returns:
        The XML document as a unicode string.
    """
    desktop = pyatspi.Registry.getDesktop(0)
    xml_node = lxml.etree.Element(
        "desktop-frame", nsmap=_accessibility_ns_map_ubuntu
    )
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(_create_atspi_node, app_node, 1) for app_node in desktop
        ]
        # Collect in submission order (not completion order) so the children
        # of the root follow the desktop's application order and the output
        # is stable between runs.
        for future in futures:
            xml_node.append(future.result())
    return lxml.etree.tostring(xml_node, encoding="unicode")
692858

693859
@property
694860
def states(self):

gui_agents/aci/MacOSACI.py

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import requests
77
import platform
8+
from gui_agents.utils.common_utils import box_iou
89

910
if platform.system() == "Darwin":
1011
from AppKit import *
@@ -33,35 +34,6 @@ def list_apps_in_directories(directories):
3334
return apps
3435

3536

36-
def box_iou(boxes1: np.ndarray, boxes2: np.ndarray) -> np.ndarray:
    """Compute pairwise intersection-over-union with NumPy only.

    Args:
        boxes1: ``[N, 4]`` array-like of boxes as ``(x1, y1, x2, y2)``.
        boxes2: ``[M, 4]`` array-like of boxes as ``(x1, y1, x2, y2)``.
            Inputs that are not 2-D (an empty ``np.array([])`` or a single
            flat box) are reshaped to ``[-1, 4]``.

    Returns:
        ``[N, M]`` array of IoU values. Pairs whose union area is not
        positive (e.g. two zero-area boxes) get ``0``.
    """
    boxes1 = np.asarray(boxes1)
    boxes2 = np.asarray(boxes2)
    # An empty list of boxes arrives as shape (0,); make it (0, 4) so the
    # column indexing below works instead of raising IndexError.
    if boxes1.ndim != 2:
        boxes1 = boxes1.reshape(-1, 4)
    if boxes2.ndim != 2:
        boxes2 = boxes2.reshape(-1, 4)

    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

    # Intersection rectangle of every pair via broadcasting.
    lt = np.maximum(boxes1[:, None, :2], boxes2[None, :, :2])  # [N, M, 2]
    rb = np.minimum(boxes1[:, None, 2:], boxes2[None, :, 2:])  # [N, M, 2]
    wh = np.clip(rb - lt, 0, None)  # [N, M, 2]; negative extent means no overlap
    intersection = wh[:, :, 0] * wh[:, :, 1]  # [N, M]

    union = area1[:, None] + area2[None, :] - intersection

    # Divide only where the union is positive. Evaluating the division
    # everywhere and masking afterwards would emit a RuntimeWarning for
    # degenerate boxes.
    out_dtype = union.dtype if np.issubdtype(union.dtype, np.floating) else np.float64
    iou = np.zeros(union.shape, dtype=out_dtype)
    np.true_divide(intersection, union, out=iou, where=union > 0)
    return iou
63-
64-
6537
class MacOSACI(ACI):
6638
def __init__(self, top_app_only: bool = True, ocr: bool = False):
6739
super().__init__(top_app_only=top_app_only, ocr=ocr)

gui_agents/aci/WindowsOSACI.py

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import psutil
88
import requests
9+
from gui_agents.utils.common_utils import box_iou
910

1011
if platform.system() == "Windows":
1112
import pywinauto
@@ -37,35 +38,6 @@ def list_apps_in_directories():
3738
return apps
3839

3940

40-
def box_iou(boxes1: np.ndarray, boxes2: np.ndarray) -> np.ndarray:
    """Compute pairwise intersection-over-union with NumPy only.

    Args:
        boxes1: ``[N, 4]`` array-like of boxes as ``(x1, y1, x2, y2)``.
        boxes2: ``[M, 4]`` array-like of boxes as ``(x1, y1, x2, y2)``.
            Inputs that are not 2-D (an empty ``np.array([])`` or a single
            flat box) are reshaped to ``[-1, 4]``.

    Returns:
        ``[N, M]`` array of IoU values. Pairs whose union area is not
        positive (e.g. two zero-area boxes) get ``0``.
    """
    boxes1 = np.asarray(boxes1)
    boxes2 = np.asarray(boxes2)
    # An empty list of boxes arrives as shape (0,); make it (0, 4) so the
    # column indexing below works instead of raising IndexError.
    if boxes1.ndim != 2:
        boxes1 = boxes1.reshape(-1, 4)
    if boxes2.ndim != 2:
        boxes2 = boxes2.reshape(-1, 4)

    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

    # Intersection rectangle of every pair via broadcasting.
    lt = np.maximum(boxes1[:, None, :2], boxes2[None, :, :2])  # [N, M, 2]
    rb = np.minimum(boxes1[:, None, 2:], boxes2[None, :, 2:])  # [N, M, 2]
    wh = np.clip(rb - lt, 0, None)  # [N, M, 2]; negative extent means no overlap
    intersection = wh[:, :, 0] * wh[:, :, 1]  # [N, M]

    union = area1[:, None] + area2[None, :] - intersection

    # Divide only where the union is positive. Evaluating the division
    # everywhere and masking afterwards would emit a RuntimeWarning for
    # degenerate boxes.
    out_dtype = union.dtype if np.issubdtype(union.dtype, np.floating) else np.float64
    iou = np.zeros(union.shape, dtype=out_dtype)
    np.true_divide(intersection, union, out=iou, where=union > 0)
    return iou
67-
68-
6941
# WindowsACI Class
7042
class WindowsACI(ACI):
7143
def __init__(self, top_app_only: bool = True, ocr: bool = False):

0 commit comments

Comments
 (0)