-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmemory.py
More file actions
85 lines (73 loc) · 3.04 KB
/
memory.py
File metadata and controls
85 lines (73 loc) · 3.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from typing import List, Dict, Any, Tuple
from .base import BaseMemory
class SimpleMemory(BaseMemory):
    """
    Memory manager: responsible for storing & fetching per-environment history records.

    History is held as one list per environment instance; every entry of such
    a list is a dict with the values recorded for a single step. `reset` must
    be called to allocate the per-environment lists before records are stored.
    """

    def __init__(self):
        # Per-environment history lists; stays None until `reset` is called.
        self._data = None
        # Record keys, fixed by the first `store` call after a reset.
        self.keys = None
        self.batch_size = 0

    def __len__(self):
        """Return the number of environment histories (0 before `reset`)."""
        # `_data` is None until `reset` runs; report an empty memory instead
        # of failing with a TypeError from `len(None)`.
        if self._data is None:
            return 0
        return len(self._data)

    def __getitem__(self, idx):
        """Return the list of stored step records for environment `idx`."""
        return self._data[idx]

    def reset(self, batch_size: int):
        """
        Drop all stored history and allocate empty histories.

        Args:
            batch_size (int):
                Number of environment instances to track.
        """
        if self._data is not None:
            # Empty the previous container in place before replacing it.
            self._data.clear()
        self._data = [[] for _ in range(batch_size)]
        self.batch_size = batch_size
        # The key set is re-learned from the first record stored after a reset.
        self.keys = None

    def store(self, record: Dict[str, List[Any]]):
        """
        Store a new record (one step of history) for each environment instance.

        Args:
            record (Dict[str, List[Any]]):
                A dictionary where each key corresponds to a type of data
                (e.g., 'text_obs', 'action'), and each value is a list of
                length `batch_size`, containing the data for each environment.

        Raises:
            AssertionError: If the keys of `record` (including their order)
                differ from those of the first record stored since the last
                reset.
        """
        record_keys = list(record.keys())
        if self.keys is None:
            self.keys = record_keys
        assert self.keys == record_keys, (
            f"record keys {record_keys} do not match stored keys {self.keys}"
        )
        for env_idx in range(self.batch_size):
            self._data[env_idx].append({k: record[k][env_idx] for k in self.keys})

    def fetch(
        self,
        history_length: int,
        obs_key: str = "text_obs",
        action_key: str = "action",
    ) -> Tuple[List[str], List[int]]:
        """
        Fetch and format recent interaction history for each environment instance.

        Args:
            history_length (int):
                Maximum number of past steps to retrieve per environment.
                A value of zero or less retrieves no steps.
            obs_key (str, default="text_obs"):
                The key name used to access the observation in stored records.
                For example: "text_obs" or "Observation", depending on the environment.
            action_key (str, default="action"):
                The key name used to access the action in stored records.
                For example: "action" or "Action".

        Returns:
            memory_contexts : List[str]
                A list of formatted action history strings for each environment.
            valid_lengths : List[int]
                A list of the actual number of valid history steps per environment.
        """
        memory_contexts, valid_lengths = [], []
        for env_idx in range(self.batch_size):
            history = self._data[env_idx]
            # `history[-0:]` is the whole list, and a negative length would
            # slice from the front, so a non-positive window must be handled
            # explicitly as "no history".
            recent = history[-history_length:] if history_length > 0 else []
            valid_len = len(recent)
            # Step numbers are 1-based positions within the full history,
            # not within the returned window.
            start_idx = len(history) - valid_len

            lines = []
            for j, rec in enumerate(recent):
                step_num = start_idx + j + 1
                act = rec[action_key]
                obs = rec[obs_key]
                lines.append(
                    f"[Observation {step_num}: '{obs}', Action {step_num}: '{act}']"
                )

            memory_contexts.append("\n".join(lines))
            valid_lengths.append(valid_len)

        return memory_contexts, valid_lengths