@@ -41,13 +41,7 @@ PyTorch Geometric
4141 from torch_geometric.loader import DataLoader
4242
4343 # Initialize model
44- model = PyGLightningCrystalGraphClassifier(
45- node_features = 12 ,
46- edge_features = 6 ,
47- global_features = 0 ,
48- output_features = 1 ,
49- learning_rate = 0.001 ,
50- )
44+ model = PyGLightningCrystalGraphClassifier()
5145
5246 # Train
5347 trainer = pyl.Trainer(max_epochs = 50 )
@@ -56,29 +50,38 @@ PyTorch Geometric
5650 # Test
5751 trainer.test(model, test_loader)
5852
59- # Predict
60- predictions = trainer.predict(model, pred_loader)
61-
6253 Spektral
6354~~~~~~~~
6455
6556.. code-block :: python
6657
6758 from unravel.classifiers import CrystalGraphClassifier
6859
69- # Initialize model
70- model = CrystalGraphClassifier(
71- node_features = 12 ,
72- edge_features = 6 ,
73- output_features = 1 ,
74- )
60+ from tensorflow.keras.metrics import AUC , BinaryAccuracy
61+ from tensorflow.keras.losses import BinaryCrossentropy
62+ from tensorflow.keras.optimizers import Adam
63+ from tensorflow.keras.callbacks import EarlyStopping
64+
65+ model = CrystalGraphClassifier( )
7566
76- # Compile
7767 model.compile(
78- optimizer = ' adam' ,
79- loss = ' binary_crossentropy' ,
80- metrics = [' accuracy' ]
68+ loss = BinaryCrossentropy(), optimizer = Adam(), metrics = [AUC(), BinaryAccuracy()]
8169 )
8270
83- # Train
84- model.fit(x = train_data, y = train_labels, epochs = 50 , validation_data = (val_data, val_labels))
71+ model.fit(
72+ loader_tr.load(),
73+ steps_per_epoch = loader_tr.steps_per_epoch,
74+ epochs = 5 ,
75+ use_multiprocessing = True ,
76+ validation_data = loader_va.load(),
77+ callbacks = [EarlyStopping(monitor = " loss" , patience = 5 , restore_best_weights = True )],
78+ )
79+
80+ from tensorflow.keras.models import load_model
81+
82+ model_path = " models/my-first-graph-classifier"
83+ model.save(model_path)
84+ loaded_model = load_model(model_path)
85+
86+ loader_te = DisjointLoader(test, epochs = 1 , shuffle = False , batch_size = batch_size)
87+ results = model.evaluate(loader_te.load())
0 commit comments