|
29 | 29 |
|
30 | 30 |   下图是一个可视化示例,实现过程具体可参考[CNN Explainer](https://poloclub.github.io/cnn-explainer): |
31 | 31 |
|
32 | | - |
| 32 | + |
33 | 33 |
|
34 | 34 | ## 2.2 相关术语解读 |
35 | 35 |
|
|
49 | 49 |
|
50 | 50 | 下图是对一个3通道的图片做卷积操作: |
51 | 51 |
|
52 | | - |
| 52 | + |
53 | 53 |
|
54 | 54 |   其中,卷积核(也被称为滤波器)包含三个通道,维度是 `3 × 3 × 3`,分别代表卷积核的高度、宽度及深度(即通道数)。该卷积操作首先对三个输入通道分别做卷积操作,然后将卷积的结果相加,最后输出一个特征图。 |
55 | 55 |
|
56 | 56 |   下面来看一个例子,因为3D数据难以可视化,所以所有的数据(输入数据体是蓝色,权重数据体是红色,输出数据体是绿色)都采取将深度切片按照列的方式排列展现。 |
57 | 57 |
|
58 | | - |
| 58 | + |
59 | 59 |
|
60 | 60 |   卷积运算本质上就是在滤波器和输入数据的局部区域间做点积。卷积层的常用实现方式就是利用这一点,将卷积层的前向传播变成一个巨大的矩阵乘法。 |
61 | 61 |
|
62 | 62 | 下面一起动手实践一个简单的CNN例子——[MNIST手写数字识别](https://github.com/datawhalechina/awesome-compression/blob/main/docs/notebook/ch02/1.mnist_classify.ipynb),通过这个例子来加深对CNN的理解。 |
63 | 63 |
|
| 64 | +## 2.3 实践 |
| 65 | + |
| 66 | +首先导入必要的包,并加载数据集 |
| 67 | + |
| 68 | +```python |
| 69 | +import copy |
| 70 | +import math |
| 71 | +import time |
| 72 | +import random |
| 73 | +from collections import OrderedDict, defaultdict |
| 74 | +from typing import Union, List |
| 75 | + |
| 76 | +import numpy as np |
| 77 | +import torch |
| 78 | +from matplotlib import pyplot as plt |
| 79 | +from torch import nn |
| 80 | +from torch.optim import * |
| 81 | +from torch.optim.lr_scheduler import * |
| 82 | +from torch.utils.data import DataLoader |
| 83 | +from torchvision.transforms import * |
| 84 | +from tqdm.auto import tqdm |
| 85 | +import torch.nn.functional as F |
| 86 | +from torchvision import datasets |
| 87 | + |
| 88 | +random.seed(0) |
| 89 | +np.random.seed(0) |
| 90 | +torch.manual_seed(0) |
| 91 | + |
| 92 | +# 设置归一化 |
| 93 | +transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) |
| 94 | + |
| 95 | +# 获取数据集 |
| 96 | +train_dataset = datasets.MNIST(root='./data/mnist', train=True, download=True, transform=transform) |
| 97 | +test_dataset = datasets.MNIST(root='./data/mnist', train=False, download=True, transform=transform) # train=True训练集,=False测试集 |
| 98 | + |
| 99 | +# 设置DataLoader |
| 100 | +batch_size = 64 |
| 101 | +train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) |
| 102 | +test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) |
| 103 | +``` |
| 104 | + |
| 105 | +展示数据集,如下图所示: |
| 106 | +```python |
# Preview the first 12 training digits in a 3 x 4 grid.
# `data` / `targets` are used instead of the deprecated `train_data` /
# `train_labels` aliases, which emit a warning on every access.
fig = plt.figure()
for i in range(12):
    plt.subplot(3, 4, i + 1)
    plt.imshow(train_dataset.data[i], cmap='gray', interpolation='none')
    plt.title("Labels: {}".format(train_dataset.targets[i]))
    plt.xticks([])
    plt.yticks([])
# Adjust the spacing once, after every subplot exists, rather than per iteration.
plt.tight_layout()
plt.show()
| 117 | +``` |
| 118 | + |
| 119 | + |
| 120 | + |
| 121 | +定义一个LeNet网络,代码如下: |
| 122 | + |
| 123 | +```python |
| 124 | +# 定义一个LeNet网络 |
class LeNet(nn.Module):
    """LeNet-style classifier: two conv + max-pool stages, then three linear layers.

    ``fc1`` expects ``16 * 4 * 4`` flattened features, which is what the two
    5x5 convolutions and 2x2 poolings produce for a 1 x 28 x 28 input.

    Args:
        num_classes: Number of output logits produced by the last layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Feature extractor; one pooling module is shared by both stages.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier head.
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return the raw class logits for a batch of images."""
        for conv in (self.conv1, self.conv2):
            x = self.maxpool(F.relu(conv(x)))

        # Collapse everything except the batch dimension.
        features = x.view(x.size(0), -1)
        hidden = F.relu(self.fc1(features))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
| 145 | +device = torch.device("cpu") |
| 146 | +model = LeNet().to(device=device) |
| 147 | +``` |
| 148 | +定义训练函数: |
| 149 | + |
| 150 | +```python |
| 151 | + |
def train(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    optimizer: Optimizer,
    callbacks = None
) -> None:
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Network to optimise; it is switched to training mode.
        dataloader: Yields ``(inputs, targets)`` batches.
        criterion: Loss applied to the model outputs and the targets.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        callbacks: Optional iterable of zero-argument callables, each invoked
            after every optimizer step.
    """
    model.train()

    for inputs, targets in tqdm(dataloader, desc='train', leave=False):
        # Move the batch to the module-level ``device``.
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Reset the gradients (from the last iteration)
        optimizer.zero_grad()

        # Forward inference. The outputs stay on ``device`` so that they live
        # on the same device as ``targets`` when the loss is computed; moving
        # them to the CPU would break training on any non-CPU device.
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Backward propagation
        loss.backward()

        # Update optimizer
        optimizer.step()

        if callbacks is not None:
            for callback in callbacks:
                callback()
| 181 | +``` |
| 182 | + |
| 183 | +定义评估函数: |
| 184 | + |
| 185 | +```python |
@torch.inference_mode()
def evaluate(
    model: nn.Module,
    dataloader: DataLoader,
    verbose=True,
) -> float:
    """Compute the top-1 accuracy of ``model`` over ``dataloader``.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        dataloader: Yields ``(inputs, targets)`` batches.
        verbose: When false, the progress bar is disabled.

    Returns:
        Accuracy as a percentage (``correct / total * 100``).

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    model.eval()

    num_samples = 0
    num_correct = 0

    for inputs, targets in tqdm(dataloader, desc="eval", leave=False,
                                disable=not verbose):
        # Move the batch to the module-level ``device``.
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Inference. Predictions stay on ``device`` so the comparison with
        # ``targets`` below never mixes tensors from two devices.
        outputs = model(inputs)

        # Convert logits to class indices
        predictions = outputs.argmax(dim=1)

        # Update metrics
        num_samples += targets.size(0)
        num_correct += (predictions == targets).sum()

    return (num_correct / num_samples * 100).item()
| 213 | +``` |
| 214 | + |
| 215 | +训练模型,并保存最好的模型和梯度,并输出预测准确率,代码如下: |
| 216 | + |
| 217 | +```python |
| 218 | +lr = 0.01 |
| 219 | +momentum = 0.5 |
| 220 | +num_epoch = 5 |
| 221 | + |
| 222 | +optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum) # lr: learning rate, momentum: momentum factor |
| 223 | +criterion = nn.CrossEntropyLoss() # cross-entropy loss |
| 224 | + |
| 225 | + |
| 226 | +best_accuracy = 0 |
| 227 | +best_checkpoint = dict() |
| 228 | +gradients = dict() |
| 229 | +for epoch in range(num_epoch): |
| 230 | + train(model, train_loader, criterion, optimizer) |
| 231 | + accuracy = evaluate(model, test_loader) |
| 232 | + is_best = accuracy > best_accuracy |
| 233 | + if is_best: |
| 234 | + best_checkpoint['state_dict'] = copy.deepcopy(model.state_dict()) |
| 235 | + best_accuracy = accuracy |
| 236 | + |
| 237 | + # Store each parameter's gradient in the dict |
| 238 | + for name, parameter in model.named_parameters(): |
| 239 | + if parameter.grad is not None: |
| 240 | + # .clone() ensures we keep a copy of the gradient rather than a reference |
| 241 | + gradients[name] = parameter.grad.clone() |
| 242 | + |
| 243 | + print(f'Epoch{epoch+1:>2d} Accuracy {accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%') |
| 244 | + |
| 245 | + |
| 246 | +torch.save(best_checkpoint['state_dict'], './model.pt') |
| 247 | +torch.save(gradients, './model_gradients.pt') |
| 248 | + |
| 249 | +print(f"=> loading best checkpoint") |
| 250 | +model.load_state_dict(best_checkpoint['state_dict']) |
| 251 | +model_accuracy = evaluate(model, test_loader) |
| 252 | +print(f"Model has accuracy={model_accuracy:.2f}%") |
| 253 | +``` |
| 254 | + |
| 255 | +最后,加载模型并预测单张图像的标签: |
| 256 | + |
| 257 | +```python |
| 258 | +from torchvision import transforms |
| 259 | +from PIL import Image |
| 260 | +# Load the saved model |
| 261 | +model = LeNet() # Replace MyModel with your model's class |
| 262 | +model.load_state_dict(torch.load('./model.pt')) |
| 263 | +model.eval() # Set the model to evaluation mode |
| 264 | + |
| 265 | +# Preprocess the image (assuming input is grayscale 28x28 as in MNIST) |
def preprocess_image(image_path):
    """Load an image file and turn it into a normalised single-image batch.

    The image is converted to one grayscale channel, resized to 28 x 28,
    converted to a tensor and normalised with mean 0.1307 / std 0.3081.

    Args:
        image_path: Location of the image, passed straight to ``Image.open``.

    Returns:
        The transformed image with a leading batch dimension added via
        ``unsqueeze(0)``.
    """
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),  # Convert to grayscale if needed
        transforms.Resize((28, 28)),  # Resize to match MNIST dimensions
        transforms.ToTensor(),  # Convert image to tensor
        transforms.Normalize((0.1307,), (0.3081,)),  # Normalize as per model's training
    ])
    # The context manager closes the underlying file as soon as the transform
    # has produced the tensor, instead of leaving the handle open.
    with Image.open(image_path) as image:
        return transform(image).unsqueeze(0)  # Add batch dimension
| 276 | + |
| 277 | +# Perform prediction on a single image |
def predict_image(image_path):
    """Return the class index the module-level ``model`` predicts for one image file."""
    batch = preprocess_image(image_path)
    with torch.no_grad():
        logits = model(batch)
    # The batch holds a single image, so the arg-max reduces to one scalar.
    return logits.argmax(dim=1, keepdim=True).item()
| 284 | + |
| 285 | +# Example usage |
| 286 | +image_path = 'test.png' # Replace with the actual image path |
| 287 | +predicted_label = predict_image(image_path) |
| 288 | +print(f'Predicted label: {predicted_label}') |
| 289 | +``` |
| 290 | + |
| 291 | +同时预测多张图片的代码如下: |
| 292 | + |
| 293 | +```python |
def show_images(images, labels, preds, num_rows=4, num_cols=4):
    """Plot images in a grid, titled with their true and predicted labels.

    Args:
        images: Indexable collection of image tensors.
        labels: Indexable collection of scalar tensors holding the true labels.
        preds: Indexable collection of scalar tensors holding the predictions.
        num_rows: Number of grid rows.
        num_cols: Number of grid columns.

    At most ``num_rows * num_cols`` images are drawn; grid cells without an
    image are left blank.
    """
    # squeeze=False makes subplots always return a 2-D array of axes, so the
    # flatten below also works for a 1 x 1 (or single-row / single-column) grid.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(10, 10), squeeze=False)
    axes = axes.flatten()  # Flatten axes array for easy iteration
    for idx, ax in enumerate(axes):
        # Turn off axis decorations for every cell, including unused ones,
        # so that empty cells do not show a bare frame with ticks.
        ax.axis('off')
        if idx >= len(images):
            continue
        # Convert tensor to numpy and remove unnecessary dimensions
        img = images[idx].cpu().numpy().squeeze()
        ax.imshow(img, cmap='gray')
        ax.set_title(f'True: {labels[idx].item()}\nPred: {preds[idx].item()}')
    plt.tight_layout()
    plt.show()
| 307 | + |
| 308 | +# Load the saved model |
| 309 | +model = LeNet() # Replace MyModel with your model's class |
| 310 | +model.load_state_dict(torch.load('./model.pt')) |
| 311 | +model.eval() # Set the model to evaluation mode |
| 312 | +# Get a batch of test data |
| 313 | +test_iter = iter(test_loader) |
| 314 | +images, labels = next(test_iter) |
| 315 | + |
| 316 | +# Run the model to predict labels |
| 317 | +with torch.no_grad(): |
| 318 | + outputs = model(images) |
| 319 | + _, preds = torch.max(outputs, 1) # Get the predicted labels |
| 320 | + |
| 321 | +# Show images with true and predicted labels |
| 322 | +show_images(images.cpu(), labels.cpu(), preds.cpu()) |
| 323 | +``` |
| 324 | + |
| 325 | + |
64 | 326 | ## 引用资料 |
65 | 327 |
|
66 | 328 | - [卷积核(kernel)和过滤器(filter)的区别](https://blog.csdn.net/weixin_38481963/article/details/109906338) |
|
0 commit comments