Commit 1494e56

Merge branch 'datawhalechina:main' into main
2 parents 74b10a1 + a6c367b commit 1494e56

7 files changed: +23 -15 lines changed

docs/ch03/ch03.md

Lines changed: 1 addition & 1 deletion
@@ -193,7 +193,7 @@ $$
 $$
 where $h_{ii}$ is the corresponding diagonal element of the Hessian matrix, defined as:
 $$
-h_{ii} = \frac{\partial^2 L}{\partial w_i \partial w_j}
+h_{i i}=\frac{\partial^2 L}{\partial w_i^2}
 $$

 - **Pruning principle**
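
The fix in this hunk replaces a mixed partial derivative with the second derivative with respect to a single weight, which is what a diagonal Hessian entry is. As a minimal sketch, assuming a toy quadratic loss (`loss` and `hessian_diagonal` are hypothetical names, not part of the repository), the diagonal can be estimated with central finite differences:

```python
def loss(w):
    # Hypothetical toy loss: L(w) = 3*w0^2 + w0*w1 + 0.5*w1^2
    return 3.0 * w[0] ** 2 + w[0] * w[1] + 0.5 * w[1] ** 2


def hessian_diagonal(f, w, eps=1e-3):
    """Estimate h_ii = d^2 f / d w_i^2 with central differences."""
    base = f(w)
    diag = []
    for i in range(len(w)):
        plus = list(w)
        minus = list(w)
        plus[i] += eps   # perturb only w_i upward
        minus[i] -= eps  # perturb only w_i downward
        diag.append((f(plus) - 2.0 * base + f(minus)) / eps ** 2)
    return diag


h = hessian_diagonal(loss, [1.0, -2.0])
print(h)  # analytically h_00 = 6 and h_11 = 1 for this loss
```

The cross term `w0*w1` contributes to the off-diagonal entry only, so it does not appear in either value of `h`.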

docs/ch04/ch04.md

Lines changed: 2 additions & 2 deletions
@@ -183,7 +183,7 @@ $$
 Z = round(q_{max}-\frac{r_{max}}{S})
 $$

-where $r_{max}$ and $r_{max}$ denote the minimum and maximum floating-point values, and $q_{max}$ and $q_{min}$ denote the minimum and maximum fixed-point values, respectively.
+where $r_{min}$ and $r_{max}$ denote the minimum and maximum floating-point values, and $q_{min}$ and $q_{max}$ denote the minimum and maximum fixed-point values, respectively.

 ![Figure 4-10 Linear quantization](images/linear.png)

@@ -547,4 +547,4 @@ $$

 - [Model Quantization 1: Basic Concepts](https://medium.com/@florian_algo/model-quantization-1-basic-concepts-860547ec6aa9)
 - [Model Quantization 3: Timing and Granularity](https://blog.gopenai.com/model-quantization-3-timing-and-granularity-a0978c6e58d4)
-- [A Visual Guide to Quantization](https://newsletter.maartengrootendorst.com/p/a-visual-guide-to-quantization#%C2%A7symmetric-quantization)
+- [A Visual Guide to Quantization](https://newsletter.maartengrootendorst.com/p/a-visual-guide-to-quantization#%C2%A7symmetric-quantization)
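
The zero-point formula in the first `docs/ch04/ch04.md` hunk, $Z = round(q_{max}-\frac{r_{max}}{S})$, can be checked with a small worked example. This is a sketch under one assumption: the scale is taken as $S = (r_{max}-r_{min})/(q_{max}-q_{min})$, the usual choice for asymmetric linear quantization, which is not shown in this hunk. The function names are hypothetical.

```python
def linear_quant_params(r_min, r_max, q_min, q_max):
    # Assumed scale definition (not visible in the diff hunk)
    scale = (r_max - r_min) / (q_max - q_min)
    # Zero point as given in the hunk: Z = round(q_max - r_max / S)
    zero_point = round(q_max - r_max / scale)
    return scale, zero_point


def quantize(r, scale, zero_point, q_min, q_max):
    # Map a float to the integer grid, then clamp to the valid range
    q = round(r / scale) + zero_point
    return max(q_min, min(q_max, q))


# Floats in [-1.0, 3.0] mapped onto signed 8-bit integers
scale, zero_point = linear_quant_params(-1.0, 3.0, -128, 127)
q_low = quantize(-1.0, scale, zero_point, -128, 127)
q_high = quantize(3.0, scale, zero_point, -128, 127)
print(scale, zero_point, q_low, q_high)
```

With these inputs the endpoints $r_{min}$ and $r_{max}$ land on $q_{min}$ and $q_{max}$, which is the pairing the corrected sentence describes.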

docs/ch07/ch07.md

Lines changed: 1 addition & 1 deletion
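@@ -1,7 +1,7 @@
 # Chapter 7 Project Practice

 In the previous chapters, we studied model compression techniques such as pruning, quantization, neural architecture search, and knowledge distillation. Can you combine two or more of these techniques to compress a model?
-> Task: Based on what you have learned so far, choose a real-world application scenario, compress a model using two or more techniques, and compare the results.
+> Task: Based on the knowledge covered so far, choose a real-world application scenario and, with no restriction on framework or method, compress a model using two or more techniques and compare the results before and after.

 ## 7.1 Summary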

docs/notebook/INSTALL.md

Lines changed: 1 addition & 1 deletion
@@ -5,6 +5,6 @@
 ```
 conda create -n compression python=3.10
 conda activate compression
-pip install - r requirements.txt
+pip install -r requirements.txt
 ```

docs/notebook/ch02/1.mnist_classify.ipynb

Lines changed: 1 addition & 1 deletion
@@ -511,7 +511,7 @@
 " transforms.Grayscale(num_output_channels=1), # Convert to grayscale if needed\n",
 " transforms.Resize((28, 28)), # Resize to match MNIST dimensions\n",
 " transforms.ToTensor(), # Convert image to tensor\n",
-" transforms.Normalize((0.5,), (0.5,)) # Normalize as per model's training\n",
+" transforms.Normalize((0.1307,), (0.3081,)) # Normalize as per model's training\n",
 " ])\n",
 " image = Image.open(image_path)\n",
 " image = transform(image).unsqueeze(0) # Add batch dimension\n",

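`transforms.Normalize(mean, std)` computes `(x - mean) / std` per channel, and 0.1307 and 0.3081 are the commonly quoted mean and standard deviation of MNIST pixels. Assuming the model was trained with those statistics, as the code comment implies, the sketch below shows how the same pixel maps to different inputs under the old and new settings. The `normalize` helper is a hypothetical stand-in for the arithmetic, not a torchvision call.

```python
def normalize(pixel, mean, std):
    """Same arithmetic as per-channel normalization: (x - mean) / std."""
    return (pixel - mean) / std


background = 0.0  # a black pixel after ToTensor() scales values to [0, 1]

old_value = normalize(background, 0.5, 0.5)        # settings before this commit
new_value = normalize(background, 0.1307, 0.3081)  # settings after this commit

print(old_value, round(new_value, 4))
```

If inference uses different statistics from training, every input is shifted and rescaled relative to what the network saw, which can degrade predictions on otherwise valid images.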
docs/notebook/ch03/2.pruning_criteria.ipynb

Lines changed: 15 additions & 7 deletions
@@ -389,11 +389,11 @@
 },
 {
 "cell_type": "code",
-"execution_count": 49,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
-"# Define a LeNet network\n",
+"# To keep earlier operations from affecting later results, redefine a LeNet network identical to the one above\n",
 "class LeNet(nn.Module):\n",
 " def __init__(self, num_classes=10):\n",
 " super(LeNet, self).__init__()\n",
@@ -451,13 +451,20 @@
 "model.load_state_dict(checkpoint)"
 ]
 },
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Pruning criterion based on gradient magnitude"
+]
+},
 {
 "cell_type": "code",
-"execution_count": 52,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
-"# Pruning criterion based on gradient magnitude\n",
+"# Prune the weights of the whole model; takes the entire model\n",
 "def gradient_magnitude_pruning(model, percentile):\n",
 " for name, param in model.named_parameters():\n",
 " if 'weight' in name:\n",
@@ -467,10 +474,11 @@
 },
 {
 "cell_type": "code",
-"execution_count": 53,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
+"# Prune part of the model's weights; takes the weights of a single layer\n",
 "@torch.no_grad()\n",
 "def gradient_magnitude_pruning(weight, gradient, percentile=0.5):\n",
 " num_elements = weight.numel()\n",
@@ -491,7 +499,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 54,
+"execution_count": null,
 "metadata": {},
 "outputs": [
 {
@@ -514,7 +522,7 @@
 }
 ],
 "source": [
-"# Usage example\n",
+"# Usage example, using the fc2 layer's weights here\n",
 "percentile = 0.5\n",
 "gradient_magnitude_pruning(model.fc2.weight, gradients['fc2.weight'], percentile)"
 ]
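
The body of the per-layer `gradient_magnitude_pruning(weight, gradient, percentile=0.5)` is cut off in the diff, so the following is only a framework-free sketch of one plausible reading: zero out the `percentile` fraction of weights whose gradients have the smallest absolute value. The helper name `gradient_magnitude_mask`, the flat-list representation, and the tie-breaking rule are assumptions for illustration, not the notebook's code.

```python
def gradient_magnitude_mask(gradients, percentile=0.5):
    """Return a 0/1 mask that drops the `percentile` fraction of entries
    with the smallest absolute gradient (ties broken by position)."""
    num_elements = len(gradients)
    num_pruned = int(num_elements * percentile)
    # Indices ordered by ascending |gradient|
    order = sorted(range(num_elements), key=lambda i: abs(gradients[i]))
    pruned = set(order[:num_pruned])
    return [0 if i in pruned else 1 for i in range(num_elements)]


# Hypothetical flattened layer weights and their gradients
weights = [0.8, -0.3, 0.5, 0.1]
grads = [0.02, -0.40, 0.01, 0.30]

mask = gradient_magnitude_mask(grads, percentile=0.5)
pruned_weights = [w * m for w, m in zip(weights, mask)]
print(mask, pruned_weights)
```

Note that the large weights 0.8 and 0.5 are removed here because their gradients are small, which is the difference between a gradient-based criterion and a weight-magnitude criterion.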

docs/notebook/requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-numpy==2.1.1
+numpy==1.24.3
 matplotlib==3.9.2
 tqdm==4.66.5
 jupyter==1.1.1
@@ -8,4 +8,4 @@ torchprofile==0.0.4
 torchsummary==1.5.1
 fast-pytorch-kmeans
 scipy
-datasets
+datasets
