Skip to content

Commit 6c79259

Browse files
fullynet code review and update with small improvement
1 parent ae581a6 commit 6c79259

File tree

1 file changed

+50
-10
lines changed

1 file changed

+50
-10
lines changed

Diff for: ML/Pytorch/Basics/pytorch_simple_fullynet.py

+50-10
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
* 2020-04-08: Initial coding
1010
* 2021-03-24: Added more detailed comments also removed part of
1111
check_accuracy which would only work specifically on MNIST.
12-
12+
* 2022-09-23: Updated with more detailed comments, docstrings to functions, and checked code still functions as intended.
1313
"""
1414

1515
# Imports
@@ -27,9 +27,19 @@
2727
# inheriting from nn.Module, this is the most general way to create your networks and
2828
# allows for more flexibility. I encourage you to also check out nn.Sequential which
2929
# would be easier to use in this scenario but I wanted to show you something that
30-
# "always" works.
30+
# "always" works and is a general approach.
3131
class NN(nn.Module):
3232
def __init__(self, input_size, num_classes):
33+
"""
34+
Here we define the layers of the network. We create two fully connected layers.
35+
36+
Parameters:
37+
input_size: the size of the input, in this case 784 (28x28)
38+
num_classes: the number of classes we want to predict, in this case 10 (0-9)
39+
40+
Returns:
41+
None
42+
"""
3343
super(NN, self).__init__()
3444
# Our first linear layer takes input_size, in this case 784 nodes, to 50
3545
# and our second linear layer takes 50 to the num_classes we have, in
@@ -42,6 +52,12 @@ def forward(self, x):
4252
x here is the mnist images and we run it through fc1, fc2 that we created above.
4353
we also add a ReLU activation function in between and for that (since it has no parameters)
4454
I recommend using nn.functional (F)
55+
56+
Parameters:
57+
x: mnist images
58+
59+
Returns:
60+
out: the output of the network
4561
"""
4662

4763
x = F.relu(self.fc1(x))
@@ -52,15 +68,14 @@ def forward(self, x):
5268
# Set device cuda for GPU if it's available otherwise run on the CPU
5369
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5470

55-
# Hyperparameters of our neural network which depends on the dataset, and
56-
# also just experimenting to see what works well (learning rate for example).
71+
# Hyperparameters
5772
input_size = 784
5873
num_classes = 10
5974
learning_rate = 0.001
6075
batch_size = 64
6176
num_epochs = 3
6277

63-
# Load Training and Test data
78+
# Load Data
6479
train_dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
6580
test_dataset = datasets.MNIST(root="dataset/", train=False, transform=transforms.ToTensor(), download=True)
6681
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
@@ -83,38 +98,63 @@ def forward(self, x):
8398
# Get to correct shape
8499
data = data.reshape(data.shape[0], -1)
85100

86-
# forward
101+
# Forward
87102
scores = model(data)
88103
loss = criterion(scores, targets)
89104

90-
# backward
105+
# Backward
91106
optimizer.zero_grad()
92107
loss.backward()
93108

94-
# gradient descent or adam step
109+
# Gradient descent or adam step
95110
optimizer.step()
96111

97112

98113
# Check accuracy on training & test to see how good our model is
99114
def check_accuracy(loader, model):
    """
    Check accuracy of our trained model given a loader and a model

    The model is switched to evaluation mode while predictions are computed
    and is afterwards put back into whichever mode (train or eval) it was in
    when this function was called. Batches are moved to the module-level
    `device` and flattened to (batch_size, -1) before the forward pass.

    Parameters:
        loader: torch.utils.data.DataLoader
            A loader for the dataset you want to check accuracy on. Any
            iterable yielding (inputs, labels) batches works.
        model: nn.Module
            The model you want to check accuracy on

    Returns:
        acc: float
            The accuracy of the model on the dataset given by the loader,
            as a fraction between 0 and 1 (multiply by 100 for percent)

    Raises:
        ZeroDivisionError: if the loader yields no samples
    """
    num_correct = 0
    num_samples = 0

    # Remember the current mode so that evaluating a model that is already in
    # eval mode does not silently flip it into training mode afterwards.
    was_training = model.training
    model.eval()

    try:
        # We don't need to keep track of gradients here so we wrap it in torch.no_grad()
        with torch.no_grad():
            # Loop through the data
            for x, y in loader:
                # Move data to device
                x = x.to(device=device)
                y = y.to(device=device)

                # Get to correct shape
                x = x.reshape(x.shape[0], -1)

                # Forward pass; the predicted class is the index of the largest score
                scores = model(x)
                _, predictions = scores.max(1)

                # Check how many we got correct
                num_correct += (predictions == y).sum()

                # Keep track of number of samples
                num_samples += predictions.size(0)
    finally:
        # Restore the original mode even if a batch raised an exception
        model.train(was_training)

    # num_correct is a 0-dim tensor at this point; convert it so the function
    # returns a plain Python float as documented.
    return float(num_correct) / num_samples
117157

118-
158+
# Check accuracy on training & test to see how good our model is
119159
print(f"Accuracy on training set: {check_accuracy(train_loader, model)*100:.2f}")
120-
print(f"Accuracy on test set: {check_accuracy(test_loader, model)*100:.2f}")
160+
print(f"Accuracy on test set: {check_accuracy(test_loader, model)*100:.2f}")

0 commit comments

Comments
 (0)