|
139 | 139 | }, |
140 | 140 | { |
141 | 141 | "cell_type": "code", |
142 | | - "execution_count": null, |
| 142 | + "execution_count": 22, |
143 | 143 | "id": "51739d61", |
144 | 144 | "metadata": {}, |
145 | 145 | "outputs": [], |
| 146 | + "source": [ |
| 147 | + "# sm = torch.jit.script(style_model.to(torch.float32))\n", |
| 148 | + "# sm.save(f\"models/{model_name}_float32.pt\")\n", |
| 149 | + "\n", |
| 150 | + "# sm = torch.jit.script(style_model.to(torch.float16))\n", |
| 151 | + "# sm.save(f\"models/{model_name}_float16.pt\")\n" |
| 152 | + ] |
| 153 | + }, |
| 154 | + { |
| 155 | + "cell_type": "code", |
| 156 | + "execution_count": null, |
| 157 | + "id": "b0173e2e", |
| 158 | + "metadata": {}, |
| 159 | + "outputs": [], |
| 160 | + "source": [ |
| 161 | + "# torch::Tensor sobel_dx = torch::tensor({{-1, 0, 1},\n", |
| 162 | + "# {-2, 0, 2},\n", |
| 163 | + "# {-1, 0, 1}}).to(torch::kFloat32);\n", |
| 164 | + "# torch::Tensor sobel_dy = torch::tensor({{-1, -2, -1},\n", |
| 165 | + "# {0, 0, 0},\n", |
| 166 | + "# {1, 2, 1}}).to(torch::kFloat32);\n", |
| 167 | + "\n", |
| 168 | + "# torch::Tensor sobel_kernel = torch::cat({sobel_dx, sobel_dy}, 0).unsqueeze(0).unsqueeze(0);\n", |
| 169 | + "\n" |
| 170 | + ] |
| 171 | + }, |
| 172 | + { |
| 173 | + "cell_type": "code", |
| 174 | + "execution_count": 32, |
| 175 | + "id": "c09f3a28", |
| 176 | + "metadata": {}, |
| 177 | + "outputs": [ |
| 178 | + { |
| 179 | + "ename": "RuntimeError", |
| 180 | + "evalue": "Given groups=1, weight of size [1, 1, 6, 3], expected input[1, 3, 1428, 1904] to have 1 channels, but got 3 channels instead", |
| 181 | + "output_type": "error", |
| 182 | + "traceback": [ |
| 183 | + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", |
| 184 | + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", |
| 185 | + "Cell \u001b[0;32mIn[32], line 20\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msobel_cnn(x)\n\u001b[1;32m 19\u001b[0m sobel \u001b[38;5;241m=\u001b[39m Sobel()\u001b[38;5;241m.\u001b[39mto(torch\u001b[38;5;241m.\u001b[39mfloat16)\n\u001b[0;32m---> 20\u001b[0m \u001b[43msobel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrandn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1428\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1904\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfloat16\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 21\u001b[0m sm \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mjit\u001b[38;5;241m.\u001b[39mscript(sobel)\n\u001b[1;32m 22\u001b[0m sm\u001b[38;5;241m.\u001b[39msave(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/sobel.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", |
| 186 | + "File \u001b[0;32m~/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", |
| 187 | + "File \u001b[0;32m~/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n", |
| 188 | + "Cell \u001b[0;32mIn[32], line 17\u001b[0m, in \u001b[0;36mSobel.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m---> 17\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msobel_cnn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n", |
| 189 | + "File \u001b[0;32m~/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", |
| 190 | + "File \u001b[0;32m~/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n", |
| 191 | + "File \u001b[0;32m~/.venv/lib/python3.12/site-packages/torch/nn/modules/conv.py:554\u001b[0m, in \u001b[0;36mConv2d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 553\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 554\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_conv_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n", |
| 192 | + "File \u001b[0;32m~/.venv/lib/python3.12/site-packages/torch/nn/modules/conv.py:549\u001b[0m, in \u001b[0;36mConv2d._conv_forward\u001b[0;34m(self, input, weight, bias)\u001b[0m\n\u001b[1;32m 537\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mzeros\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 538\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mconv2d(\n\u001b[1;32m 539\u001b[0m F\u001b[38;5;241m.\u001b[39mpad(\n\u001b[1;32m 540\u001b[0m \u001b[38;5;28minput\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reversed_padding_repeated_twice, mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 547\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgroups,\n\u001b[1;32m 548\u001b[0m )\n\u001b[0;32m--> 549\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv2d\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 550\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstride\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpadding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdilation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgroups\u001b[49m\n\u001b[1;32m 551\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", |
| 193 | + "\u001b[0;31mRuntimeError\u001b[0m: Given groups=1, weight of size [1, 1, 6, 3], expected input[1, 3, 1428, 1904] to have 1 channels, but got 3 channels instead" |
| 194 | + ] |
| 195 | + } |
| 196 | + ], |
| 197 | + "source": [ |
| 198 | + "class Sobel(torch.nn.Module):\n", |
| 199 | + " def __init__(self):\n", |
| 200 | + " super(Sobel, self).__init__()\n", |
| 201 | + " sobel_dx = torch.tensor([[-1, 0, 1],\n", |
| 202 | + " [-2, 0, 2],\n", |
| 203 | + " [-1, 0, 1]]).to(torch.float16)\n", |
| 204 | + " sobel_dy = torch.tensor([[-1, -2, -1],\n", |
| 205 | + " [0, 0, 0],\n", |
| 206 | + " [1, 2, 1]]).to(torch.float16)\n", |
| 207 | + " sobel_kernel = torch.cat((sobel_dx, sobel_dy), 0).unsqueeze(0).unsqueeze(0)\n", |
| 208 | + " sobel_kernel = sobel_kernel.to(torch.float16)\n", |
| 209 | + " self.sobel_kernel = torch.nn.Parameter(sobel_kernel, requires_grad=False)\n", |
| 210 | + " self.sobel_cnn = torch.nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False)\n", |
| 211 | + " self.sobel_cnn.weight = torch.nn.Parameter(sobel_kernel, requires_grad=False)\n", |
| 212 | + "\n", |
| 213 | + " def forward(self, x):\n", |
| 214 | + " return self.sobel_cnn(x)\n", |
| 215 | + "\n", |
| 216 | + "sobel = Sobel().to(torch.float16)\n", |
| 217 | + "sobel(torch.randn(3, 1428, 1904).to(torch.float16))\n", |
| 218 | + "sm = torch.jit.script(sobel)\n", |
| 219 | + "sm.save(\"models/sobel.pt\")\n" |
| 220 | + ] |
| 221 | + }, |
| 222 | + { |
| 223 | + "cell_type": "code", |
| 224 | + "execution_count": 57, |
| 225 | + "id": "3507a1fb", |
| 226 | + "metadata": {}, |
| 227 | + "outputs": [ |
| 228 | + { |
| 229 | + "name": "stdout", |
| 230 | + "output_type": "stream", |
| 231 | + "text": [ |
| 232 | + "torch.Size([3, 3])\n", |
| 233 | + "torch.Size([3, 3])\n", |
| 234 | + "torch.Size([2, 3, 3])\n" |
| 235 | + ] |
| 236 | + }, |
| 237 | + { |
| 238 | + "ename": "RuntimeError", |
| 239 | + "evalue": "expected stride to be a single integer value or a list of 1 values to match the convolution dimensions, but got stride=[1, 1]", |
| 240 | + "output_type": "error", |
| 241 | + "traceback": [ |
| 242 | + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", |
| 243 | + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", |
| 244 | + "Cell \u001b[0;32mIn[57], line 18\u001b[0m\n\u001b[1;32m 11\u001b[0m X \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandn(\u001b[38;5;241m1\u001b[39m,\u001b[38;5;241m3\u001b[39m, \u001b[38;5;241m1428\u001b[39m, \u001b[38;5;241m1904\u001b[39m)\u001b[38;5;241m.\u001b[39mto(torch\u001b[38;5;241m.\u001b[39mfloat16)\n\u001b[1;32m 13\u001b[0m \u001b[38;5;66;03m# sobel_cnn = torch.nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False)\u001b[39;00m\n\u001b[1;32m 14\u001b[0m \u001b[38;5;66;03m# sobel_cnn.weight = torch.nn.Parameter(sobel_kernel, requires_grad=False)\u001b[39;00m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;66;03m# sobel_cnn = sobel_cnn.to(torch.float16)\u001b[39;00m\n\u001b[1;32m 16\u001b[0m \u001b[38;5;66;03m# sobel_cnn(X)\u001b[39;00m\n\u001b[0;32m---> 18\u001b[0m Y \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfunctional\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv2d\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msobel_kernel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstride\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpadding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n", |
| 245 | + "\u001b[0;31mRuntimeError\u001b[0m: expected stride to be a single integer value or a list of 1 values to match the convolution dimensions, but got stride=[1, 1]" |
| 246 | + ] |
| 247 | + } |
| 248 | + ], |
| 249 | + "source": [ |
| 250 | + "sobel_dx = torch.tensor([[-1, 0, 1],\n", |
| 251 | + " [-2, 0, 2],\n", |
| 252 | + " [-1, 0, 1]]).to(torch.float16)\n", |
| 253 | + "print(sobel_dx.shape)\n", |
| 254 | + "sobel_dy = torch.tensor([[-1, -2, -1],\n", |
| 255 | + " [0, 0, 0],\n", |
| 256 | + " [1, 2, 1]]).to(torch.float16)\n", |
| 257 | + "print(sobel_dy.shape)\n", |
| 258 | + "sobel_kernel = torch.cat([sobel_dx.unsqueeze(0), sobel_dy.unsqueeze(0)], 0)\n", |
| 259 | + "print(sobel_kernel.shape)\n", |
| 260 | + "X = torch.randn(1,3, 1428, 1904).to(torch.float16)\n", |
| 261 | + "\n", |
| 262 | + "# sobel_cnn = torch.nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False)\n", |
| 263 | + "# sobel_cnn.weight = torch.nn.Parameter(sobel_kernel, requires_grad=False)\n", |
| 264 | + "# sobel_cnn = sobel_cnn.to(torch.float16)\n", |
| 265 | + "# sobel_cnn(X)\n", |
| 266 | + "\n", |
| 267 | + "Y = torch.nn.functional.conv2d(X, sobel_kernel, stride=1, padding=1)" |
| 268 | + ] |
| 269 | + }, |
| 270 | + { |
| 271 | + "cell_type": "code", |
| 272 | + "execution_count": null, |
| 273 | + "id": "59d85573", |
| 274 | + "metadata": {}, |
| 275 | + "outputs": [], |
146 | 276 | "source": [] |
147 | 277 | } |
148 | 278 | ], |
|
0 commit comments