Skip to content

Commit 18274de

Browse files
author
Anton Hosgood
committed
Update demo notebook
1 parent d728666 commit 18274de

File tree

2 files changed

+134
-23
lines changed

2 files changed

+134
-23
lines changed

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,7 @@ Standalone utility scripts.
5555
2. Create a **Python 3.13 or higher** virtual environment (optional but recommended):
5656

5757
```bash
58-
python3 --version # Ensure you have Python 3.13+ installed
59-
python3 -m venv .venv
58+
python3.13 -m venv .venv
6059
source .venv/bin/activate
6160
```
6261

notebooks/demo.ipynb

Lines changed: 133 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
{
22
"cells": [
3+
{
4+
"metadata": {},
5+
"cell_type": "markdown",
6+
"source": "In this demo notebook, we train a simple model from scratch to add colour to a black-and-white image.",
7+
"id": "1a2f9b3cabfd33ef"
8+
},
39
{
410
"cell_type": "code",
511
"id": "initial_id",
612
"metadata": {
713
"collapsed": true,
814
"ExecuteTime": {
9-
"end_time": "2025-05-16T17:01:44.790969Z",
10-
"start_time": "2025-05-16T17:01:43.494761Z"
15+
"end_time": "2025-05-20T14:46:02.736595Z",
16+
"start_time": "2025-05-20T14:46:01.432375Z"
1117
}
1218
},
1319
"source": [
@@ -21,21 +27,108 @@
2127
"outputs": [],
2228
"execution_count": 1
2329
},
30+
{
31+
"metadata": {},
32+
"cell_type": "markdown",
33+
"source": "Run the notebook as if it were in the project root.",
34+
"id": "55935725ab2aca96"
35+
},
36+
{
37+
"metadata": {
38+
"ExecuteTime": {
39+
"end_time": "2025-05-20T14:46:04.058777Z",
40+
"start_time": "2025-05-20T14:46:04.055123Z"
41+
}
42+
},
43+
"cell_type": "code",
44+
"source": "%cd ..",
45+
"id": "f479db5d9e61248a",
46+
"outputs": [
47+
{
48+
"name": "stdout",
49+
"output_type": "stream",
50+
"text": [
51+
"/Users/antonhosgood/image-colorizer\n"
52+
]
53+
}
54+
],
55+
"execution_count": 2
56+
},
57+
{
58+
"metadata": {},
59+
"cell_type": "markdown",
60+
"source": "Run the following cell to generate a sample dataset to train on.",
61+
"id": "1f147f68cf7943c4"
62+
},
63+
{
64+
"metadata": {},
65+
"cell_type": "code",
66+
"source": "!python3 -m scripts.generate_dataset data 128 128",
67+
"id": "dc10fd7b102cbc4a",
68+
"outputs": [],
69+
"execution_count": null
70+
},
2471
{
2572
"metadata": {
2673
"ExecuteTime": {
27-
"end_time": "2025-05-16T17:01:45.070538Z",
28-
"start_time": "2025-05-16T17:01:44.847741Z"
74+
"end_time": "2025-05-20T14:46:05.601372Z",
75+
"start_time": "2025-05-20T14:46:05.596731Z"
2976
}
3077
},
3178
"cell_type": "code",
3279
"source": [
33-
"base_dir = Path.cwd().parent\n",
80+
"# To build paths relative to project root\n",
81+
"base_dir = Path.cwd()\n",
3482
"\n",
83+
"# If loading a model checkpoint\n",
3584
"checkpoint_path = base_dir / \"checkpoints\" / \"unet\" / \"model_epoch_10.pth\"\n",
85+
"# Add your own mage to colourise\n",
86+
"input_image_path = base_dir / \"jetty.jpg\"\n",
3687
"\n",
3788
"device = get_device()\n",
38-
"\n",
89+
"print(f\"Using device: {device}\")"
90+
],
91+
"id": "cca55360ef40f85d",
92+
"outputs": [
93+
{
94+
"name": "stdout",
95+
"output_type": "stream",
96+
"text": [
97+
"Using device: mps\n"
98+
]
99+
}
100+
],
101+
"execution_count": 3
102+
},
103+
{
104+
"metadata": {},
105+
"cell_type": "markdown",
106+
"source": "Run the following cell to train a model from scratch (or skip it if you already have a pretrained model checkpoint).",
107+
"id": "f5150439b9764d18"
108+
},
109+
{
110+
"metadata": {},
111+
"cell_type": "code",
112+
"source": "!python3 -m src.train.train src/train/config.yaml",
113+
"id": "14c76e35011106b2",
114+
"outputs": [],
115+
"execution_count": null
116+
},
117+
{
118+
"metadata": {},
119+
"cell_type": "markdown",
120+
"source": "Load the trained model checkpoint.",
121+
"id": "db30850aeb846bc2"
122+
},
123+
{
124+
"metadata": {
125+
"ExecuteTime": {
126+
"end_time": "2025-05-20T14:46:09.248453Z",
127+
"start_time": "2025-05-20T14:46:09.009257Z"
128+
}
129+
},
130+
"cell_type": "code",
131+
"source": [
39132
"model = UNet().to(device)\n",
40133
"load_checkpoint(model, checkpoint_path, device)"
41134
],
@@ -45,25 +138,52 @@
45138
"name": "stdout",
46139
"output_type": "stream",
47140
"text": [
48-
"Loaded checkpoint from /Users/antonhosgood/hue-restorer/checkpoints/unet/model_epoch_10.pth\n"
141+
"Loaded checkpoint from /Users/antonhosgood/image-colorizer/checkpoints/unet/model_epoch_10.pth\n"
49142
]
50143
}
51144
],
52-
"execution_count": 2
145+
"execution_count": 4
53146
},
54147
{
55148
"metadata": {
56149
"ExecuteTime": {
57-
"end_time": "2025-05-16T17:02:02.655557Z",
58-
"start_time": "2025-05-16T17:02:02.393082Z"
150+
"end_time": "2025-05-20T14:46:10.126381Z",
151+
"start_time": "2025-05-20T14:46:10.121870Z"
59152
}
60153
},
61154
"cell_type": "code",
62155
"source": [
63-
"input_image_path = base_dir / \"jetty.jpg\"\n",
156+
"num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
64157
"\n",
65-
"inference(model, device, input_image_path, resize=(256, 512))"
158+
"print(f\"Number of trainable parameters: {num_trainable_params}\")"
66159
],
160+
"id": "ae9d1d24bd032427",
161+
"outputs": [
162+
{
163+
"name": "stdout",
164+
"output_type": "stream",
165+
"text": [
166+
"Number of trainable parameters: 31042499\n"
167+
]
168+
}
169+
],
170+
"execution_count": 5
171+
},
172+
{
173+
"metadata": {},
174+
"cell_type": "markdown",
175+
"source": "Run inference. If no output path is provided, a side-by-side comparison will be generated in the notebook.",
176+
"id": "42bb909ebdcdca05"
177+
},
178+
{
179+
"metadata": {
180+
"ExecuteTime": {
181+
"end_time": "2025-05-20T14:46:14.007293Z",
182+
"start_time": "2025-05-20T14:46:13.755714Z"
183+
}
184+
},
185+
"cell_type": "code",
186+
"source": "inference(model, device, input_image_path, resize=(256, 512))",
67187
"id": "d29f4ab39f99f5e1",
68188
"outputs": [
69189
{
@@ -77,15 +197,7 @@
77197
"output_type": "display_data"
78198
}
79199
],
80-
"execution_count": 3
81-
},
82-
{
83-
"metadata": {},
84-
"cell_type": "code",
85-
"outputs": [],
86-
"execution_count": null,
87-
"source": "",
88-
"id": "b17fa37143ceba14"
200+
"execution_count": 6
89201
}
90202
],
91203
"metadata": {

0 commit comments

Comments
 (0)