Skip to content

Commit 2938c2b

Browse files
fix incorrect comment
1 parent 917ad10 commit 2938c2b

4 files changed

Lines changed: 5 additions & 5 deletions

File tree

examples/indexBatching/DCRNN/chicken_pox_main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def train(train_dataloader, val_dataloader, edge_index,edge_weight, epochs, seq_
7070
# Forward pass
7171
outputs = model(X_batch, edge_index, edge_weight) # Shape: (batch_size, seq_length, num_nodes, out_channels)
7272

73-
# Calculate loss (use only the first output channel, assuming it's the target)
73+
# Calculate loss
7474
loss = masked_mae_loss(outputs,y_batch )
7575

7676
# Backward pass

examples/indexBatching/DCRNN/pems_allLA_main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ def train(train_dataloader, val_dataloader, mean, std, edges, edge_weights, epoc
6868
# Forward pass
6969
outputs = model(X_batch, edge_index, edge_weight) # Shape: (batch_size, seq_length, num_nodes, out_channels)
7070

71-
# Calculate loss (use only the first output channel, assuming it's the target)
71+
72+
# Calculate loss
7273
loss = masked_mae_loss((outputs * std) + mean, (y_batch * std) + mean)
7374

7475
# Backward pass

examples/indexBatching/DCRNN/pems_ddp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def train(args=None, epochs=None, batch_size=None, allGPU=False, debug=False, lo
112112
# Forward pass
113113
outputs = model(X_batch, edge_index, edge_weight) # Shape: (batch_size, seq_length, num_nodes, out_channels)
114114

115-
# Calculate loss (use only the first output channel, assuming it's the target)
115+
# Calculate loss
116116
loss = masked_mae_loss((outputs * std) + mean, (y_batch * std) + mean)
117117

118118
# Backward pass

examples/indexBatching/DCRNN/pems_main.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,13 @@ def train(train_dataloader, val_dataloader, mean, std, edges, edge_weights, epoc
6161
X_batch, y_batch = batch
6262

6363
if allGPU == False:
64-
# print("casting")
6564
X_batch = X_batch.to(device).float()
6665
y_batch = y_batch.to(device).float()
6766

6867
# Forward pass
6968
outputs = model(X_batch, edge_index, edge_weight) # Shape: (batch_size, seq_length, num_nodes, out_channels)
7069

71-
# Calculate loss (use only the first output channel, assuming it's the target)
70+
# Calculate loss
7271
loss = masked_mae_loss((outputs * std) + mean, (y_batch * std) + mean)
7372

7473
# Backward pass

0 commit comments

Comments
 (0)