
Dimension mismatch on AlphaTensor #356

@ywsslr

I am trying to replicate AlphaTensor (based on AlphaZero). To be honest, this repo and its accompanying article, https://www.kdnuggets.com/2023/03/first-open-source-implementation-deepmind-alphatensor.html, have helped me a lot. However, when I try to train a model for 2×2 matrix multiplication, the program raises a dimension mismatch error.

The snippet below shows my config:

    cardinality_vector = 5  # The actions can have values in range [-2, 2]
    N_bar = 100  # parameter for smoothing the temperature while adjusting the probability distribution
    matrix_size = 2
    input_size = matrix_size**2
    n_steps = 3
    n_actions = cardinality_vector ** (3 * input_size // n_steps)
    action_memory = 5

    train_alpha_tensor(
        tensor_length=action_memory + 1,
        input_size=input_size,
        scalars_size=1,
        emb_dim=512,
        n_steps=n_steps,
        n_logits=n_actions,
        n_samples=32,
        device="cuda",
        len_data=512,
        n_synth_data=10000,
        pct_synth=0.9,
        batch_size=16,
        epochs=6000,
        lr=1e-4,
        lr_decay_factor=0.1,
        lr_decay_steps=50,  ## change
        weight_decay=1e-5,
        optimizer_name="adamw",
        loss_params=(1, 1),
        limit_rank=8,
        checkpoint_dir="github/nebuly/optimization/open_alpha_tensor/result/Checkpoint",
        checkpoint_data_dir="github/nebuly/optimization/open_alpha_tensor/result/Data",
        n_actors=1,
        mc_n_sim=200,
        n_cob=1000,
        cob_prob=0.9983,
        data_augmentation=True,
        N_bar=N_bar,
        random_seed=42,
        extra_devices=None,
        save_dir="github/nebuly/optimization/open_alpha_tensor/result/model",
    )
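As a sanity check, here is a minimal sketch of the sizes this config derives. `derived_sizes` is a hypothetical helper, not part of the repo, and the reading that each step covers `3 * input_size // n_steps` factor entries is an assumption based on the `n_actions` formula above. It does not call into the training code, so it only confirms the arithmetic, not where the mismatch comes from.

```python
def derived_sizes(matrix_size, n_steps, cardinality_vector=5, action_memory=5):
    """Recompute the sizes derived in the config above (hypothetical helper)."""
    input_size = matrix_size ** 2
    total_entries = 3 * input_size  # assumed: entries of the three factor vectors
    if total_entries % n_steps != 0:
        # Integer division would silently drop entries in this case.
        raise ValueError(
            f"3 * input_size = {total_entries} is not divisible by n_steps = {n_steps}"
        )
    entries_per_step = total_entries // n_steps
    return {
        "input_size": input_size,
        "entries_per_step": entries_per_step,
        "n_actions": cardinality_vector ** entries_per_step,
        "tensor_length": action_memory + 1,
    }


sizes = derived_sizes(matrix_size=2, n_steps=3)
print(sizes)
```

For `matrix_size = 2` and `n_steps = 3` the division is exact, so `n_logits` is 5**4 = 625 and `tensor_length` is 6.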

And the error is:
[image: error output]

I don't know what else I need to set that the README doesn't mention. I don't think this config should raise this error.
