|
389 | 389 | }, |
390 | 390 | { |
391 | 391 | "cell_type": "code", |
392 | | - "execution_count": 49, |
| 392 | + "execution_count": null, |
393 | 393 | "metadata": {}, |
394 | 394 | "outputs": [], |
395 | 395 | "source": [ |
396 | | - "# 定义一个LeNet网络\n", |
| 396 | + "# 为避免前面的操作影响后续结果,重新定义一个LeNet网络,和前面一致\n", |
397 | 397 | "class LeNet(nn.Module):\n", |
398 | 398 | " def __init__(self, num_classes=10):\n", |
399 | 399 | " super(LeNet, self).__init__()\n", |
|
451 | 451 | "model.load_state_dict(checkpoint)" |
452 | 452 | ] |
453 | 453 | }, |
| 454 | + { |
| 455 | + "cell_type": "markdown", |
| 456 | + "metadata": {}, |
| 457 | + "source": [ |
| 458 | + "基于梯度幅度的修剪标准" |
| 459 | + ] |
| 460 | + }, |
454 | 461 | { |
455 | 462 | "cell_type": "code", |
456 | | - "execution_count": 52, |
| 463 | + "execution_count": null, |
457 | 464 | "metadata": {}, |
458 | 465 | "outputs": [], |
459 | 466 | "source": [ |
460 | | - "# 基于梯度幅度的修剪标准\n", |
| 467 | + "# 修剪整个模型的权重,传入整个模型\n", |
461 | 468 | "def gradient_magnitude_pruning(model, percentile):\n", |
462 | 469 | " for name, param in model.named_parameters():\n", |
463 | 470 | " if 'weight' in name:\n", |
|
467 | 474 | }, |
468 | 475 | { |
469 | 476 | "cell_type": "code", |
470 | | - "execution_count": 53, |
| 477 | + "execution_count": null, |
471 | 478 | "metadata": {}, |
472 | 479 | "outputs": [], |
473 | 480 | "source": [ |
| 481 | + "# 修剪局部模型权重,传入某一层的权重\n", |
474 | 482 | "@torch.no_grad()\n", |
475 | 483 | "def gradient_magnitude_pruning(weight, gradient, percentile=0.5):\n", |
476 | 484 | " num_elements = weight.numel()\n", |
|
491 | 499 | }, |
492 | 500 | { |
493 | 501 | "cell_type": "code", |
494 | | - "execution_count": 54, |
| 502 | + "execution_count": null, |
495 | 503 | "metadata": {}, |
496 | 504 | "outputs": [ |
497 | 505 | { |
|
514 | 522 | } |
515 | 523 | ], |
516 | 524 | "source": [ |
517 | | - "# 使用示例\n", |
| 525 | + "# 使用示例,这里以fc2层的权重为例\n", |
518 | 526 | "percentile = 0.5\n", |
519 | 527 | "gradient_magnitude_pruning(model.fc2.weight, gradients['fc2.weight'], percentile)" |
520 | 528 | ] |
|
0 commit comments