@@ -41,7 +41,7 @@ class Args:
41
41
"""the language to use"""
42
42
max_options : int = 24
43
43
"""the maximum number of options"""
44
- n_history_actions : int = 8
44
+ n_history_actions : int = 16
45
45
"""the number of history actions to use"""
46
46
47
47
player : int = - 1
@@ -71,7 +71,7 @@ class Args:
71
71
"""the number of channels for the agent"""
72
72
checkpoint : str = "checkpoints/agent.pt"
73
73
"""the checkpoint to load"""
74
- embedding_file : str = "embeddings_en.npy"
74
+ embedding_file : Optional [ str ] = "embeddings_en.npy"
75
75
"""the embedding file for card embeddings"""
76
76
77
77
compile : bool = False
@@ -130,9 +130,13 @@ class Args:
130
130
envs = RecordEpisodeStatistics (envs )
131
131
132
132
if args .agent :
133
- embeddings = np .load (args .embedding_file )
133
+ if args .embedding_file :
134
+ embeddings = np .load (args .embedding_file )
135
+ embedding_shape = embeddings .shape
136
+ else :
137
+ embedding_shape = None
134
138
L = args .num_layers
135
- agent = Agent (args .num_channels , L , L , 1 , embeddings . shape ).to (device )
139
+ agent = Agent (args .num_channels , L , L , 1 , embedding_shape ).to (device )
136
140
agent = agent .eval ()
137
141
state_dict = torch .load (args .checkpoint , map_location = device )
138
142
0 commit comments