Skip to content

Commit 2de1f56

Browse files
committed
Update 2.pruning_criteria.ipynb
1 parent 460dba8 commit 2de1f56

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

docs/notebook/ch03/2.pruning_criteria.ipynb

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -389,11 +389,11 @@
389389
},
390390
{
391391
"cell_type": "code",
392-
"execution_count": 49,
392+
"execution_count": null,
393393
"metadata": {},
394394
"outputs": [],
395395
"source": [
396-
"# 定义一个LeNet网络\n",
396+
"# 为避免前面的操作影响后续结果,重新定义一个LeNet网络,和前面一致\n",
397397
"class LeNet(nn.Module):\n",
398398
" def __init__(self, num_classes=10):\n",
399399
" super(LeNet, self).__init__()\n",
@@ -451,13 +451,20 @@
451451
"model.load_state_dict(checkpoint)"
452452
]
453453
},
454+
{
455+
"cell_type": "markdown",
456+
"metadata": {},
457+
"source": [
458+
"基于梯度幅度的修剪标准"
459+
]
460+
},
454461
{
455462
"cell_type": "code",
456-
"execution_count": 52,
463+
"execution_count": null,
457464
"metadata": {},
458465
"outputs": [],
459466
"source": [
460-
"# 基于梯度幅度的修剪标准\n",
467+
"# 修剪整个模型的权重,传入整个模型\n",
461468
"def gradient_magnitude_pruning(model, percentile):\n",
462469
" for name, param in model.named_parameters():\n",
463470
" if 'weight' in name:\n",
@@ -467,10 +474,11 @@
467474
},
468475
{
469476
"cell_type": "code",
470-
"execution_count": 53,
477+
"execution_count": null,
471478
"metadata": {},
472479
"outputs": [],
473480
"source": [
481+
"# 修剪局部模型权重,传入某一层的权重\n",
474482
"@torch.no_grad()\n",
475483
"def gradient_magnitude_pruning(weight, gradient, percentile=0.5):\n",
476484
" num_elements = weight.numel()\n",
@@ -491,7 +499,7 @@
491499
},
492500
{
493501
"cell_type": "code",
494-
"execution_count": 54,
502+
"execution_count": null,
495503
"metadata": {},
496504
"outputs": [
497505
{
@@ -514,7 +522,7 @@
514522
}
515523
],
516524
"source": [
517-
"# 使用示例\n",
525+
"# 使用示例,这里以fc2层的权重为例\n",
518526
"percentile = 0.5\n",
519527
"gradient_magnitude_pruning(model.fc2.weight, gradients['fc2.weight'], percentile)"
520528
]

0 commit comments

Comments
 (0)