Skip to content

Commit 6b33353

Browse files
committed
Fixed torch tensor loading.
1 parent 38da4dc commit 6b33353

12 files changed

Lines changed: 51 additions & 146 deletions

File tree

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ add_dependencies(TorchBridge ChAI)
216216
add_dependencies(TorchBridge bridge_objs)
217217
target_link_options(TorchBridge
218218
PRIVATE
219+
-M ${PROJECT_ROOT_DIR}/lib
219220
${BRIDGE_DIR}/include/bridge.h
220221
${BRIDGE_OBJECT_FILES}
221222
-L ${LIBTORCH_DIR}/lib
@@ -225,6 +226,7 @@ target_link_options(TorchBridge
225226
"-ltorch_global_deps"
226227
${LIBTORCH_LIBS_LINKER_ARGS}
227228
--ldflags "-Wl,-rpath,${LIBTORCH_DIR}/lib"
229+
${CHAI_LINKER_ARGS}
228230
)
229231

230232

bridge/lib/bridge.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ void store_tensor(at::Tensor &input, float32_t* dest) {
9494
std::memcpy(dest,data,bytes_size);
9595
}
9696

97-
bridge_tensor_t torch_to_bridge(at::Tensor &tensor) {
97+
bridge_tensor_t torch_to_bridge(at::Tensor tensor_) {
98+
at::Tensor tensor = tensor_.contiguous().to(torch::kCPU,torch::kFloat32,false,false);
9899
bridge_tensor_t result;
99100
result.created_by_c = true;
100101
result.was_freed = false;
@@ -170,7 +171,9 @@ extern "C" bridge_tensor_t load_tensor_from_file(const uint8_t* file_path) {
170171
std::string fp(reinterpret_cast<const char*>(file_path));
171172
std::vector<char> f = get_the_bytes(fp);
172173
torch::IValue x = torch::pickle_load(f);
173-
torch::Tensor t = x.toTensor();
174+
at::Tensor t = x.toTensor();
175+
std::cout << "Tensor loaded from file: " << t.sizes() << std::endl;
176+
std::cout.flush();
174177
return torch_to_bridge(t);
175178
}
176179

examples/CMakeLists.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,3 @@
33
add_subdirectory(my_example)
44

55
add_subdirectory(torch_model_loading)
6-
7-
8-
add_subdirectory(split_loop)

examples/split_loop/CMakeLists.txt

Lines changed: 0 additions & 86 deletions
This file was deleted.

examples/split_loop/split_loop.chpl

Lines changed: 0 additions & 27 deletions
This file was deleted.

examples/torch_model_loading/CMakeLists.txt

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,21 @@ add_executable(TorchLoad
99
)
1010
add_dependencies(TorchLoad bridge)
1111
add_dependencies(TorchLoad ChAI)
12+
add_dependencies(TorchLoad bridge_objs)
1213
target_link_options(TorchLoad
1314
PRIVATE
14-
${CHAI_LINKER_ARGS}
15+
--main-module torch_load.chpl
16+
-M ${PROJECT_ROOT_DIR}/lib
17+
${CHAI_LINKER_ARGS}
1518
)
1619

1720
add_custom_command(TARGET TorchLoad POST_BUILD
1821
COMMAND ${CMAKE_COMMAND} -E copy_directory
1922
"${CMAKE_CURRENT_SOURCE_DIR}/models"
20-
"$<TARGET_FILE_DIR:TorchLoad>/models"
23+
"${CMAKE_BINARY_DIR}/TorchLoad_models"
2124
COMMENT "Copying model folder"
25+
)
26+
27+
set_target_properties(TorchLoad PROPERTIES
28+
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}
2229
)

examples/torch_model_loading/mktensor.ipynb

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 2,
5+
"execution_count": 1,
66
"id": "873dd3b8",
77
"metadata": {},
88
"outputs": [],
@@ -12,7 +12,7 @@
1212
},
1313
{
1414
"cell_type": "code",
15-
"execution_count": 3,
15+
"execution_count": 2,
1616
"id": "a07c23ff",
1717
"metadata": {},
1818
"outputs": [],
@@ -36,7 +36,7 @@
3636
},
3737
{
3838
"cell_type": "code",
39-
"execution_count": 4,
39+
"execution_count": 3,
4040
"id": "131adc46",
4141
"metadata": {},
4242
"outputs": [
@@ -57,39 +57,49 @@
5757
},
5858
{
5959
"cell_type": "code",
60-
"execution_count": 5,
60+
"execution_count": 11,
6161
"id": "d4aed442",
6262
"metadata": {},
6363
"outputs": [],
6464
"source": [
65-
"x = torch.arange(0,num_elt)\n",
66-
"x = x.reshape(f1,f2).to(torch.float32)"
65+
"x = torch.arange(0,(500 ** 2)*3)\n",
66+
"x = x.reshape(3,500,500).to(torch.float32)"
6767
]
6868
},
6969
{
7070
"cell_type": "code",
71-
"execution_count": 6,
72-
"id": "481d4709",
71+
"execution_count": 12,
72+
"id": "c45c09fc",
7373
"metadata": {},
74-
"outputs": [],
74+
"outputs": [
75+
{
76+
"data": {
77+
"text/plain": [
78+
"torch.Size([3, 500, 500])"
79+
]
80+
},
81+
"execution_count": 12,
82+
"metadata": {},
83+
"output_type": "execute_result"
84+
}
85+
],
7586
"source": [
76-
"torch.save(x, 'models/my_tensor.pt')"
87+
"x.shape"
7788
]
7889
},
7990
{
8091
"cell_type": "code",
81-
"execution_count": 7,
82-
"id": "1a17ec15",
92+
"execution_count": 13,
93+
"id": "481d4709",
8394
"metadata": {},
8495
"outputs": [],
8596
"source": [
86-
"m = {'a': x, 'b': x + 1}\n",
87-
"torch.save(m, 'models/my_tensor_dict.pt')"
97+
"torch.save(x, 'models/input.pt')"
8898
]
8999
},
90100
{
91101
"cell_type": "code",
92-
"execution_count": 8,
102+
"execution_count": 9,
93103
"id": "841d4dfe",
94104
"metadata": {},
95105
"outputs": [],
@@ -112,7 +122,7 @@
112122
},
113123
{
114124
"cell_type": "code",
115-
"execution_count": 9,
125+
"execution_count": 10,
116126
"id": "ca34dafc",
117127
"metadata": {},
118128
"outputs": [
@@ -122,7 +132,7 @@
122132
"True"
123133
]
124134
},
125-
"execution_count": 9,
135+
"execution_count": 10,
126136
"metadata": {},
127137
"output_type": "execute_result"
128138
}
2.86 MB
Binary file not shown.
781 KB
Binary file not shown.
1.53 MB
Binary file not shown.

0 commit comments

Comments
 (0)