Skip to content

Commit a93e7a8

Browse files
author
biluo.shen
committed
Fix eval
1 parent 29e1a24 commit a93e7a8

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

scripts/eval.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class Args:
4141
"""the language to use"""
4242
max_options: int = 24
4343
"""the maximum number of options"""
44-
n_history_actions: int = 8
44+
n_history_actions: int = 16
4545
"""the number of history actions to use"""
4646

4747
player: int = -1
@@ -71,7 +71,7 @@ class Args:
7171
"""the number of channels for the agent"""
7272
checkpoint: str = "checkpoints/agent.pt"
7373
"""the checkpoint to load"""
74-
embedding_file: str = "embeddings_en.npy"
74+
embedding_file: Optional[str] = "embeddings_en.npy"
7575
"""the embedding file for card embeddings"""
7676

7777
compile: bool = False
@@ -130,9 +130,13 @@ class Args:
130130
envs = RecordEpisodeStatistics(envs)
131131

132132
if args.agent:
133-
embeddings = np.load(args.embedding_file)
133+
if args.embedding_file:
134+
embeddings = np.load(args.embedding_file)
135+
embedding_shape = embeddings.shape
136+
else:
137+
embedding_shape = None
134138
L = args.num_layers
135-
agent = Agent(args.num_channels, L, L, 1, embeddings.shape).to(device)
139+
agent = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
136140
agent = agent.eval()
137141
state_dict = torch.load(args.checkpoint, map_location=device)
138142

0 commit comments

Comments
 (0)