-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprojection.py
More file actions
49 lines (39 loc) · 1.67 KB
/
projection.py
File metadata and controls
49 lines (39 loc) · 1.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from typing import List
import re
def webshop_projection(actions: List[str]):
    """
    Extract the action from each raw model response and flag well-formed ones.

    Each response is expected to look like::

        <think>some reasoning...</think><action>...</action>

    Args:
        actions: raw response strings. NOTE: the list is modified in place --
            entry ``i`` is replaced by the lower-cased, whitespace-stripped text
            between the first ``<action>`` and the next ``</action>`` after it;
            if no such block exists it is replaced by the last 20 characters
            of the lower-cased response.

    Returns:
        A tuple ``(actions, valids)``: ``actions`` is the same list object that
        was passed in (after mutation), and ``valids[i]`` is 1 when response
        ``i`` has an ``<action>...</action>`` block (tags matched
        case-insensitively), a ``<think>...</think>`` block (tags matched
        case-sensitively, closing tag after the opening one), and no CJK
        Unified Ideographs (U+4E00-U+9FFF); otherwise 0.
    """
    start_tag = "<action>"
    end_tag = "</action>"
    think_start_tag = "<think>"
    think_end_tag = "</think>"

    valids = [0] * len(actions)
    # Only the entry at index i is overwritten during step i, so assigning
    # by index while enumerating is safe (the list length never changes).
    for i, original_str in enumerate(actions):
        lowered = original_str.lower()

        start_idx = lowered.find(start_tag)
        # Look for the closing tag only *after* the opening tag; otherwise a
        # stray "</action>" earlier in the text (e.g. quoted in the reasoning)
        # would produce an empty/garbled slice that was still marked valid.
        end_idx = -1
        if start_idx != -1:
            end_idx = lowered.find(end_tag, start_idx + len(start_tag))
        if end_idx == -1:
            # No well-formed <action>...</action> block: keep a short tail of
            # the response as the (invalid) action.
            actions[i] = lowered[-20:]
            continue

        actions[i] = lowered[start_idx + len(start_tag):end_idx].strip()

        # Require a <think>...</think> block with the closing tag after the
        # opening one. This check runs on the original (non-lower-cased) text.
        think_start_idx = original_str.find(think_start_tag)
        has_think = (
            think_start_idx != -1
            and original_str.find(
                think_end_tag, think_start_idx + len(think_start_tag)
            ) != -1
        )
        # Reject responses containing CJK Unified Ideographs (e.g. Chinese
        # characters).
        has_cjk = re.search(r"[\u4e00-\u9fff]", original_str) is not None

        valids[i] = int(has_think and not has_cjk)

    return actions, valids