Description
I recently tried to replicate AlphaTensor, which is based on AlphaZero. To tell the truth, this repo and its accompanying materials (https://www.kdnuggets.com/2023/03/first-open-source-implementation-deepmind-alphatensor.html) helped me a lot. But when I tried to train a model on 2x2 matrix multiplication, the program raised a dimension mismatch error.
The snippet below shows my configuration:
```python
cardinality_vector = 5  # The actions can have values in range [-2, 2]
N_bar = 100  # parameter for smoothing the temperature while adjusting the probability distribution
matrix_size = 2
input_size = matrix_size**2
n_steps = 3
n_actions = cardinality_vector ** (3 * input_size // n_steps)
action_memory = 5
train_alpha_tensor(
    tensor_length=action_memory + 1,
    input_size=input_size,
    scalars_size=1,
    emb_dim=512,
    n_steps=n_steps,
    n_logits=n_actions,
    n_samples=32,
    device="cuda",
    len_data=512,
    n_synth_data=10000,
    pct_synth=0.9,
    batch_size=16,
    epochs=6000,
    lr=1e-4,
    lr_decay_factor=0.1,
    lr_decay_steps=50,  # changed
    weight_decay=1e-5,
    optimizer_name="adamw",
    loss_params=(1, 1),
    limit_rank=8,
    checkpoint_dir="github/nebuly/optimization/open_alpha_tensor/result/Checkpoint",
    checkpoint_data_dir="github/nebuly/optimization/open_alpha_tensor/result/Data",
    n_actors=1,
    mc_n_sim=200,
    n_cob=1000,
    cob_prob=0.9983,
    data_augmentation=True,
    N_bar=N_bar,
    random_seed=42,
    extra_devices=None,
    save_dir="github/nebuly/optimization/open_alpha_tensor/result/model",
)
```
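For reference, the `n_actions` line relies on an assumption that is easy to break silently: one action is taken to be the concatenated factor triplet (u, v, w), with `3 * input_size` coefficients in total, split evenly into `n_steps` tokens. Each token then has `cardinality_vector ** (3 * input_size // n_steps)` possible values. If `3 * input_size` is not divisible by `n_steps`, the floor division drops coefficients and the logit count no longer matches the factor length. The helpers below (`action_vocab_size` and `decode_token`) are hypothetical and not part of open_alpha_tensor. The base-5 digit order in `decode_token` is also an assumption and may differ from how the repo actually encodes tokens. This is a minimal sanity-check sketch, not the library's own logic.

```python
def action_vocab_size(matrix_size, n_steps, cardinality):
    """Logits per action token, assuming (u, v, w) is split evenly into n_steps tokens."""
    input_size = matrix_size ** 2   # length of each flattened factor
    n_entries = 3 * input_size      # u, v and w concatenated
    if n_entries % n_steps != 0:
        raise ValueError(
            f"3 * input_size = {n_entries} is not divisible by n_steps = {n_steps}; "
            "floor division would silently drop coefficients"
        )
    entries_per_token = n_entries // n_steps
    return cardinality ** entries_per_token


def decode_token(token, entries_per_token, cardinality):
    """Hypothetical base-`cardinality` decoding of one token into coefficients."""
    offset = cardinality // 2  # maps digits 0..4 to values -2..2 when cardinality is 5
    values = []
    for _ in range(entries_per_token):
        token, digit = divmod(token, cardinality)
        values.append(digit - offset)
    return values[::-1]  # most significant digit first


print(action_vocab_size(2, 3, 5))  # 625, matching n_actions in the config above
print(decode_token(5, 4, 5))       # [-2, -2, -1, -2]
```

With `matrix_size = 2` and `n_steps = 3`, each token covers 4 coefficients, so `n_logits = 5**4 = 625`. That part of the configuration is self-consistent, so the mismatch probably comes from somewhere else.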
I don't know what else I should pay attention to that the README doesn't mention, and I don't think this configuration should raise this error.