Skip to content

Commit 955ff2e

Browse files
committed
Fix ndarray reading from chdata float16.
1 parent b0bac60 commit 955ff2e

25 files changed

Lines changed: 624 additions & 146 deletions

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,4 +292,5 @@ add_custom_command(
292292

293293
# ./vgg images/frog.jpg
294294

295-
add_subdirectory(examples)
295+
add_subdirectory(examples)
296+
add_subdirectory("test")

bridge/lib/bridge.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,6 @@ extern "C" bridge_tensor_t load_tensor_from_file(const uint8_t* file_path) {
9292
}
9393

9494
extern "C" bridge_tensor_t load_tensor_dict_from_file(const uint8_t* file_path,const uint8_t* tensor_key) {
95-
96-
97-
std::cout << "Loading tensor from file: " << file_path << std::endl;
98-
std::cout << "Tensor key: " << tensor_key << std::endl;
99-
100-
std::cout.flush();
101-
10295
std::string fp(reinterpret_cast<const char*>(file_path));
10396
std::string tk(reinterpret_cast<const char*>(tensor_key));
10497

examples/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11

22

3-
add_subdirectory(my_example)
3+
add_subdirectory(my_example)
4+
5+
add_subdirectory(torch_model_loading)

examples/my_example/CMakeLists.txt

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11

22

3-
4-
5-
63
add_executable(MyExample
74
${PROJECT_ROOT_DIR}/examples/my_example/my_example.chpl
85
${CHAI_LIB_FILES}
@@ -12,4 +9,12 @@ add_dependencies(MyExample ChAI)
129
target_link_options(MyExample
1310
PRIVATE
1411
${CHAI_LINKER_ARGS}
15-
)
12+
)
13+
14+
15+
# add_custom_command(TARGET MyMyExampleApp POST_BUILD
16+
# COMMAND ${CMAKE_COMMAND} -E copy_directory
17+
# "${CMAKE_CURRENT_SOURCE_DIR}/resources"
18+
# "$<TARGET_FILE_DIR:MyExample>/resources"
19+
# COMMENT "Copying runtime resources"
20+
# )
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
2+
3+
4+
5+
6+
add_executable(TorchLoad
7+
${CMAKE_CURRENT_SOURCE_DIR}/torch_load.chpl
8+
${CHAI_LIB_FILES}
9+
)
10+
add_dependencies(TorchLoad bridge)
11+
add_dependencies(TorchLoad ChAI)
12+
target_link_options(TorchLoad
13+
PRIVATE
14+
${CHAI_LINKER_ARGS}
15+
)
16+
17+
add_custom_command(TARGET TorchLoad POST_BUILD
18+
COMMAND ${CMAKE_COMMAND} -E copy_directory
19+
"${CMAKE_CURRENT_SOURCE_DIR}/models"
20+
"$<TARGET_FILE_DIR:TorchLoad>/models"
21+
COMMENT "Copying model folder"
22+
)
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"id": "873dd3b8",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"import torch"
11+
]
12+
},
13+
{
14+
"cell_type": "code",
15+
"execution_count": 2,
16+
"id": "a07c23ff",
17+
"metadata": {},
18+
"outputs": [],
19+
"source": [
20+
"def find_factors(num):\n",
21+
" \"\"\"\n",
22+
" Finds all factors of a given number.\n",
23+
"\n",
24+
" Args:\n",
25+
" num: An integer.\n",
26+
"\n",
27+
" Returns:\n",
28+
" A list of integers representing the factors of num.\n",
29+
" \"\"\"\n",
30+
" factors = []\n",
31+
" for i in range(1, num + 1):\n",
32+
" if num % i == 0:\n",
33+
" factors.append(i)\n",
34+
" return factors"
35+
]
36+
},
37+
{
38+
"cell_type": "code",
39+
"execution_count": 3,
40+
"id": "131adc46",
41+
"metadata": {},
42+
"outputs": [
43+
{
44+
"name": "stdout",
45+
"output_type": "stream",
46+
"text": [
47+
"f1: 10000, f2: 5\n"
48+
]
49+
}
50+
],
51+
"source": [
52+
"num_elt = 50000\n",
53+
"f1 = find_factors(num_elt)[-4]\n",
54+
"f2 = num_elt // f1\n",
55+
"print(f\"f1: {f1}, f2: {f2}\")"
56+
]
57+
},
58+
{
59+
"cell_type": "code",
60+
"execution_count": 4,
61+
"id": "d4aed442",
62+
"metadata": {},
63+
"outputs": [],
64+
"source": [
65+
"x = torch.arange(0,num_elt)\n",
66+
"x = x.reshape(f1,f2).to(torch.float32)"
67+
]
68+
},
69+
{
70+
"cell_type": "code",
71+
"execution_count": 6,
72+
"id": "481d4709",
73+
"metadata": {},
74+
"outputs": [],
75+
"source": [
76+
"torch.save(x, 'models/my_tensor.pt')"
77+
]
78+
},
79+
{
80+
"cell_type": "code",
81+
"execution_count": 7,
82+
"id": "1a17ec15",
83+
"metadata": {},
84+
"outputs": [],
85+
"source": [
86+
"m = {'a': x, 'b': x + 1}\n",
87+
"torch.save(m, 'models/my_tensor_dict.pt')"
88+
]
89+
},
90+
{
91+
"cell_type": "code",
92+
"execution_count": 8,
93+
"id": "841d4dfe",
94+
"metadata": {},
95+
"outputs": [],
96+
"source": [
97+
"class MyModule(torch.nn.Module):\n",
98+
" def __init__(self):\n",
99+
" super(MyModule, self).__init__()\n",
100+
" self.conv1 = torch.nn.Conv2d(in_channels=3,\n",
101+
" out_channels=16,\n",
102+
" kernel_size=3,\n",
103+
" stride=2)\n",
104+
" self.conv2 = torch.nn.Conv2d(in_channels=16,\n",
105+
" out_channels=32,\n",
106+
" kernel_size=3,\n",
107+
" stride=2)\n",
108+
"\n",
109+
" def forward(self, x):\n",
110+
" return self.linear(x)"
111+
]
112+
},
113+
{
114+
"cell_type": "code",
115+
"execution_count": null,
116+
"id": "ca34dafc",
117+
"metadata": {},
118+
"outputs": [],
119+
"source": []
120+
}
121+
],
122+
"metadata": {
123+
"kernelspec": {
124+
"display_name": ".venv",
125+
"language": "python",
126+
"name": "python3"
127+
},
128+
"language_info": {
129+
"codemirror_mode": {
130+
"name": "ipython",
131+
"version": 3
132+
},
133+
"file_extension": ".py",
134+
"mimetype": "text/x-python",
135+
"name": "python",
136+
"nbconvert_exporter": "python",
137+
"pygments_lexer": "ipython3",
138+
"version": "3.12.9"
139+
}
140+
},
141+
"nbformat": 4,
142+
"nbformat_minor": 5
143+
}
196 KB
Binary file not shown.
392 KB
Binary file not shown.
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
use Tensor;
2+
3+
4+
proc main() {
5+
const a = ndarray.loadPyTorchTensorDictWithKey(2,"models/my_tensor_dict.pt","a");
6+
const b = ndarray.loadPyTorchTensorDictWithKey(2,"models/my_tensor_dict.pt","b");
7+
writeln("a sum: ", a.sum());
8+
writeln("b sum: ", b.sum());
9+
}

examples/vgg/dump_weights.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,19 @@
66
import vgg
77

88

9-
# from vgg import vgg16
9+
from vgg import vgg16
1010

11-
# model = vgg16(pretrained=True)
11+
model = vgg16(pretrained=True)
1212

13-
# model = model.to(torch.float16)
13+
model = model.to(torch.float16)
1414

15-
# # Add the scripts directory to the sys.path
16-
# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'scripts')))
15+
# Add the scripts directory to the sys.path
16+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'scripts')))
1717

18-
# import chai
18+
import chai
1919

20-
# os.makedirs('models/vgg16', exist_ok=True)
21-
# model.chai_dump('models/vgg16','vgg16', with_json=False, verbose=True)
20+
os.makedirs('models/vgg16', exist_ok=True)
21+
model.chai_dump('models/vgg16','vgg16', with_json=False, verbose=True)
2222

2323
# # print(model.state_dict().keys())
2424
# # print([(n,w.dtype) for (n,w) in model.state_dict().items()])
@@ -27,10 +27,11 @@
2727
for vggxx in vgg.model_urls.keys():
2828
print(vggxx)
2929
model = vgg.__dict__[vggxx](pretrained=True)
30-
model = model.to(torch.float32)
30+
model = model.to(torch.float16)
3131
model.eval()
3232

3333
# # Add the scripts directory to the sys.path
34+
3435
# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'scripts')))
3536

3637
# import chai
@@ -43,25 +44,25 @@
4344
# traced_model = torch.jit.trace(model, faux_input)
4445
# traced_model.save(f'models/traced_{vggxx}.pt')
4546

46-
print(f'models/{vggxx}.pt')
47-
sd = model.state_dict()
48-
torch.save(sd, f'models/{vggxx}.pt')
47+
# print(f'models/{vggxx}.pt')
48+
# sd = model.state_dict()
49+
# torch.save(sd, f'models/{vggxx}.pt')
4950

50-
# print(f'models/traced_{vggxx}.pt')
51-
# faux_input = torch.randn(3,720,1280).to(torch.float16)
52-
# traced_model = torch.jit.trace(model, faux_input)
53-
# torch.save(traced_model,f'models/traced_{vggxx}.pt')
51+
# # print(f'models/traced_{vggxx}.pt')
52+
# # faux_input = torch.randn(3,720,1280).to(torch.float16)
53+
# # traced_model = torch.jit.trace(model, faux_input)
54+
# # torch.save(traced_model,f'models/traced_{vggxx}.pt')
5455

55-
print(f'models/script_{vggxx}.pt')
56-
script_model = torch.jit.script(model)
57-
script_model.save(f'models/script_{vggxx}.pt')
56+
# print(f'models/script_{vggxx}.pt')
57+
# script_model = torch.jit.script(model)
58+
# script_model.save(f'models/script_{vggxx}.pt')
5859

59-
test = torch.rand(1,3,720,1280).to(torch.float32)
60+
# test = torch.rand(1,3,720,1280).to(torch.float32)
6061
# print(model(test).shape)
6162

6263
model.eval()
6364
print(f'models/trace_{vggxx}.pt')
64-
faux_input = torch.rand(1,3,720,1280).to(torch.float32)
65+
faux_input = torch.rand(1,3,720,1280).to(torch.float16)
6566
print(model(faux_input).shape)
6667
trace_model = torch.jit.trace(model,faux_input)
6768
trace_model.save(f'models/trace_{vggxx}.pt')

0 commit comments

Comments
 (0)