|
| 1 | +import hydra |
| 2 | +from graphium.cli.train_finetune import run_training_finetuning |
| 3 | +import pytest |
| 4 | + |
| 5 | +FINETUNING_CONFIG_KEY = "finetuning" |
| 6 | + |
| 7 | + |
| 8 | +@pytest.mark.parametrize("acc_type, acc_prec", [("cpu", 32), ("ipu", 16)]) |
| 9 | +@pytest.mark.ipu |
| 10 | +def test_cli(acc_type, acc_prec) -> None: |
| 11 | + """ |
| 12 | + The main CLI endpoint for training and fine-tuning Graphium models. |
| 13 | + """ |
| 14 | + with hydra.initialize(version_base=None, config_path="../expts/hydra-configs"): |
| 15 | + # config is relative to a module |
| 16 | + cfg = hydra.compose( |
| 17 | + config_name="main", |
| 18 | + overrides=[ |
| 19 | + f"accelerator={acc_type}", |
| 20 | + "tasks=toymix", |
| 21 | + "training=toymix", |
| 22 | + # Reducing number of parameters in the toymix architecture |
| 23 | + "architecture=toymix", |
| 24 | + "architecture.pe_encoders.encoders.la_pos.hidden_dim=16", |
| 25 | + "architecture.pe_encoders.encoders.la_pos.num_layers=1", |
| 26 | + "architecture.pe_encoders.encoders.rw_pos.hidden_dim=16", |
| 27 | + "architecture.pe_encoders.encoders.rw_pos.num_layers=1", |
| 28 | + "architecture.pre_nn.hidden_dims=32", |
| 29 | + "architecture.pre_nn.depth=1", |
| 30 | + "architecture.pre_nn.out_dim=16", |
| 31 | + "architecture.gnn.in_dim=16", |
| 32 | + "architecture.gnn.out_dim=16", |
| 33 | + "architecture.gnn.depth=2", |
| 34 | + "architecture.task_heads.qm9.depth=1", |
| 35 | + "architecture.task_heads.tox21.depth=1", |
| 36 | + "architecture.task_heads.zinc.depth=1", |
| 37 | + # Set the number of epochs |
| 38 | + "constants.max_epochs=2", |
| 39 | + "+datamodule.args.task_specific_args.qm9.sample_size=1000", |
| 40 | + "+datamodule.args.task_specific_args.tox21.sample_size=1000", |
| 41 | + "+datamodule.args.task_specific_args.zinc.sample_size=1000", |
| 42 | + "trainer.trainer.check_val_every_n_epoch=1", |
| 43 | + f"trainer.trainer.precision={acc_prec}", # perhaps you can make this 32 for CPU and 16 for IPU |
| 44 | + ], |
| 45 | + ) |
| 46 | + if acc_type == "ipu": |
| 47 | + cfg["accelerator"]["ipu_config"].append("useIpuModel(True)") |
| 48 | + cfg["accelerator"]["ipu_inference_config"].append("useIpuModel(True)") |
| 49 | + |
| 50 | + run_training_finetuning(cfg) |
0 commit comments