Skip to content

Commit 5aba91a

Browse files
committed
Add pytorch file format loading functionality for ndarrays.
1 parent 34d4447 commit 5aba91a

10 files changed

Lines changed: 261 additions & 36 deletions

File tree

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ if(APPLE)
2929
set(CMAKE_C_COMPILER "/usr/bin/clang")
3030
set(CMAKE_CXX_COMPILER "/usr/bin/clang++")
3131
endif()
32-
set(CMAKE_CXX_STANDARD 17)
32+
set(CMAKE_CXX_STANDARD 23)
3333

3434

3535
include(LibTorchDL)

bridge/include/bridge.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ extern "C" {
1010
typedef float float32_t;
1111
typedef double float64_t;
1212
typedef char bool_t;
13+
typedef unsigned char uint8_t;
14+
typedef unsigned int uint32_t;
1315

1416
typedef struct bridge_tensor_t {
1517
float* data;
@@ -27,6 +29,10 @@ typedef struct nil_scalar_tensor_t {
2729
bool_t is_tensor;
2830
} nil_scalar_tensor_t;
2931

32+
float* unsafe(const float* arr);
33+
bridge_tensor_t load_tensor_from_file(const uint8_t* file_path);
34+
35+
3036
int baz(void);
3137

3238
void wrHello(void);
@@ -65,6 +71,7 @@ bridge_tensor_t max_pool2d(
6571
int dilation
6672
);
6773

74+
6875
// bridge_tensor_t conv2d(
6976
// bridge_tensor_t input,
7077
// bridge_tensor_t kernel,
@@ -73,7 +80,6 @@ bridge_tensor_t max_pool2d(
7380
// nil_scalar_tensor_t padding
7481
// );
7582

76-
float* unsafe(const float* arr);
7783

7884

7985
#ifdef __cplusplus

bridge/lib/bridge.cpp

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@
44
// #include <torch/script.h>
55
// #include <Aten/ATen.h>
66
#include <iostream>
7+
#include <fstream>
8+
#include <string>
9+
#include <cstring>
10+
#include <sstream>
11+
#include <cstdlib>
712
#include <vector>
8-
913
#include <cstdint>
1014

1115

1216

13-
extern "C" float32_t* unsafe(const float32_t* arr) {
14-
return const_cast<float32_t*>(arr);
15-
}
1617

1718
int bridge_tensor_elements(bridge_tensor_t &bt) {
1819
int size = 1;
@@ -51,6 +52,62 @@ torch::Tensor bridge_to_torch(bridge_tensor_t &bt) {
5152
return torch::from_blob(bt.data, shape, torch::kFloat);
5253
}
5354

55+
56+
57+
58+
59+
60+
61+
62+
63+
64+
65+
66+
67+
/// Strips the const qualifier from a float32 array pointer.
///
/// @param arr Pointer to read-only float32 data.
/// @return The same address as `arr`, typed as a mutable pointer.
///
/// Only the pointer's type changes; no data is copied. Writing through the
/// result is valid only if the underlying storage was not declared const.
extern "C" float32_t* unsafe(const float32_t* arr) {
    float32_t* const writable = const_cast<float32_t*>(arr);
    return writable;
}
70+
71+
/// Reads the entire contents of a file into memory as raw bytes.
///
/// @param filename Path of the file to read; it is opened in binary mode.
/// @return Every byte of the file, in order (empty for an empty file).
/// @throws std::ios_base::failure if the file cannot be opened, so callers can
///         tell a missing or unreadable file apart from an empty one.
std::vector<char> get_the_bytes(const std::string& filename) {
    std::ifstream input(filename, std::ios::binary);
    if (!input.is_open()) {
        throw std::ios_base::failure("get_the_bytes: cannot open file '" + filename + "'");
    }

    // Iterate over the stream buffer directly so that no byte is skipped as
    // whitespace. The stream is closed by its destructor on return.
    return std::vector<char>(std::istreambuf_iterator<char>(input),
                             std::istreambuf_iterator<char>());
}
77+
78+
/// Loads a tensor stored in PyTorch's pickle format and converts it to a
/// bridge tensor.
///
/// @param file_path NUL-terminated path of the file to load; must not be null.
/// @return The loaded tensor, converted with torch_to_bridge.
///
/// Failures (null path, unreadable file, malformed pickle data, or a payload
/// that is not a tensor) surface as C++ exceptions.
/// NOTE(review): this function has C linkage, so a non-C++ caller cannot catch
/// those exceptions — confirm that terminating is the intended failure mode.
extern "C" bridge_tensor_t load_tensor_from_file(const uint8_t* file_path) {
    // Constructing a std::string from a null pointer is undefined behaviour.
    TORCH_CHECK(file_path != nullptr, "load_tensor_from_file: file_path is null");

    const std::string path(reinterpret_cast<const char*>(file_path));
    std::cout << "File path: " << path << std::endl;

    const std::vector<char> bytes = get_the_bytes(path);
    std::cout << "File size: " << bytes.size() << std::endl;

    // NOTE(review): the file contents are unpickled as-is; only load files
    // that come from a trusted source.
    const torch::IValue value = torch::pickle_load(bytes);

    // bridge_tensor_t carries its payload as float*, so hand over float32 data
    // in contiguous layout. torch_to_bridge is not visible from here; if it
    // already normalizes the tensor these calls are no-ops.
    torch::Tensor t = value.toTensor().to(torch::kFloat).contiguous();
    std::cout << "Tensor loaded from IValue: " << t.sizes() << std::endl;

    return torch_to_bridge(t);
}
102+
103+
104+
105+
106+
107+
108+
109+
110+
54111
extern "C" bridge_tensor_t increment3(bridge_tensor_t arr) {
55112
auto t = bridge_to_torch(arr);
56113
// Increment the tensor

examples/vgg/images/my_tensor.pt

1.16 KB
Binary file not shown.

examples/vgg/mktensor.ipynb

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"id": "873dd3b8",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"import torch"
11+
]
12+
},
13+
{
14+
"cell_type": "code",
15+
"execution_count": 6,
16+
"id": "a07c23ff",
17+
"metadata": {},
18+
"outputs": [],
19+
"source": [
20+
"def find_factors(num):\n",
21+
" \"\"\"\n",
22+
" Finds all factors of a given number.\n",
23+
"\n",
24+
" Args:\n",
25+
" num: An integer.\n",
26+
"\n",
27+
" Returns:\n",
28+
" A list of integers representing the factors of num.\n",
29+
" \"\"\"\n",
30+
" factors = []\n",
31+
" for i in range(1, num + 1):\n",
32+
" if num % i == 0:\n",
33+
" factors.append(i)\n",
34+
" return factors"
35+
]
36+
},
37+
{
38+
"cell_type": "code",
39+
"execution_count": 19,
40+
"id": "131adc46",
41+
"metadata": {},
42+
"outputs": [
43+
{
44+
"name": "stdout",
45+
"output_type": "stream",
46+
"text": [
47+
"f1: 10000, f2: 5\n"
48+
]
49+
}
50+
],
51+
"source": [
52+
"num_elt = 50000\n",
53+
"f1 = find_factors(num_elt)[-4]\n",
54+
"f2 = num_elt // f1\n",
55+
"print(f\"f1: {f1}, f2: {f2}\")"
56+
]
57+
},
58+
{
59+
"cell_type": "code",
60+
"execution_count": 20,
61+
"id": "d4aed442",
62+
"metadata": {},
63+
"outputs": [],
64+
"source": [
65+
"x = torch.arange(0,num_elt)\n",
66+
"x = x.reshape(f1,f2).to(torch.float32)"
67+
]
68+
},
69+
{
70+
"cell_type": "code",
71+
"execution_count": 21,
72+
"id": "481d4709",
73+
"metadata": {},
74+
"outputs": [],
75+
"source": [
76+
"torch.save(x, 'my_tensor.pt')"
77+
]
78+
},
79+
{
80+
"cell_type": "code",
81+
"execution_count": null,
82+
"id": "1a17ec15",
83+
"metadata": {},
84+
"outputs": [],
85+
"source": []
86+
}
87+
],
88+
"metadata": {
89+
"kernelspec": {
90+
"display_name": ".venv",
91+
"language": "python",
92+
"name": "python3"
93+
},
94+
"language_info": {
95+
"codemirror_mode": {
96+
"name": "ipython",
97+
"version": 3
98+
},
99+
"file_extension": ".py",
100+
"mimetype": "text/x-python",
101+
"name": "python",
102+
"nbconvert_exporter": "python",
103+
"pygments_lexer": "ipython3",
104+
"version": "3.12.9"
105+
}
106+
},
107+
"nbformat": 4,
108+
"nbformat_minor": 5
109+
}

examples/vgg/my_tensor.pt

196 KB
Binary file not shown.

examples/vgg/test.chpl

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -53,29 +53,32 @@ proc run(model: shared VGG16(real(32)), file: string) {
5353
return (topPredictions.data, percentTopk);
5454
}
5555

56+
import Path;
57+
58+
5659
proc main(args: [] string) {
57-
writeln("Loading labels from ", labelFile);
58-
const labels = getLabels();
59-
writeln("Loaded ", labels.size, " labels.");
60+
writeln("Loading labels from ", labelFile);
61+
const labels = getLabels();
62+
writeln("Loaded ", labels.size, " labels.");
6063

61-
writeln("Constructing VGG16 model.");
62-
const vgg = new shared VGG16(real(32));
63-
writeln("Constructed VGG16 model.");
64+
writeln("Constructing VGG16 model.");
65+
const vgg = new shared VGG16(real(32));
66+
writeln("Constructed VGG16 model.");
6467

65-
writeln("Loading VGG16 model weights.");
66-
vgg.loadPyTorchDump(modelDir, false);
67-
writeln("Loaded VGG16 model.");
68+
writeln("Loading VGG16 model weights.");
69+
vgg.loadPyTorchDump(modelDir, false);
70+
writeln("Loaded VGG16 model.");
6871

6972

70-
var files = args[1..];
73+
var files = args[1..];
7174

72-
for f in files {
73-
var (topArr, percent) = run(vgg, f);
74-
writeln("For '", f, "' the top ", k, " predictions are: ");
75-
for i in 0..<k {
76-
writef(" %?: label=%?; confidence=%2.2r%%\n", i, labels[topArr[i]], percent[i]);
75+
for f in files {
76+
var (topArr, percent) = run(vgg, f);
77+
writeln("For '", f, "' the top ", k, " predictions are: ");
78+
for i in 0..<k {
79+
writef(" %?: label=%?; confidence=%2.2r%%\n", i, labels[topArr[i]], percent[i]);
80+
}
81+
writeln();
7782
}
78-
writeln();
79-
}
8083

8184
}

lib/Bridge.chpl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ module Bridge {
66
use Utilities.Standard;
77
use Allocators;
88

9+
use CTypes;
10+
911

1012
extern record bridge_tensor_t {
1113
var data: c_ptr(real(32));
@@ -31,6 +33,17 @@ module Bridge {
3133
}
3234
}
3335

36+
extern proc unsafe(const ref arr: [] real(32)): c_ptr(real(32));
37+
38+
// extern proc load_tensor_from_file(file_path: c_ptrConst(u_char)): bridge_tensor_t; // Working
39+
40+
// extern proc load_tensor_from_file(const ref file_path: uint(8)): bridge_tensor_t;
41+
// extern proc load_tensor_from_file(file_path: c_ptrConst(c_uchar)): bridge_tensor_t;
42+
43+
// extern proc load_tensor_from_file(const file_path: c_ptr(uint(8))): bridge_tensor_t; // also working
44+
45+
extern proc load_tensor_from_file(const file_path: c_ptr(uint(8))): bridge_tensor_t;
46+
3447

3548
extern proc convolve2d(
3649
in input: bridge_tensor_t,
@@ -56,8 +69,6 @@ module Bridge {
5669
in dilation: int(32)): bridge_tensor_t;
5770

5871

59-
extern proc unsafe(const ref arr: [] real(32)): c_ptr(real(32));
60-
6172

6273
proc getSizeArray(const ref arr: [] ?eltType): [] int(32) {
6374
var sizes: [0..<arr.rank] int(32);

lib/NDArray.chpl

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import ChapelArray;
44
import Math;
55
import Random;
66
import IO;
7+
import Path;
78

89
use Env;
910

@@ -2284,6 +2285,27 @@ proc ref ndarray.saveImage(imagePath: string) where rank == 3 {
22842285
Image.writeImage(imagePath,format=imgType,pixels=pixelData);
22852286
}
22862287

2288+
// Reads a tensor from `fr` into this array: first the rank, then one extent
// per dimension, then the element data.
// Calls err(...) when the rank stored in the file differs from this array's rank.
proc ref ndarray.loadChData(fr: IO.fileReader(?)) throws {
    const storedRank = fr.read(int);
    if storedRank != rank then
        err("Error reading tensor: rank mismatch.", storedRank , " != this." , rank);

    // Start from the current shape so `extents` has the right tuple type,
    // then overwrite every entry with the value read from the file.
    var extents = this.shape;
    for dim in 0..<rank do
        extents[dim] = fr.read(int);

    // Resize to the stored shape before reading the elements in bulk.
    this._domain = util.domainFromShape((...extents));
    fr.read(this.data);
}
2301+
2302+
// Loads the tensor stored in the file at `filePath` through
// Bridge.load_tensor_from_file and returns it as an ndarray(rank, eltType).
proc type ndarray.loadPyTorchTensor(param rank: int,in filePath: string,type eltType = defaultEltType): ndarray(rank,eltType) {
    use CTypes;
    // The bridge receives the path as a pointer to the string's bytes.
    const pathBytes: c_ptr(uint(8)) = c_ptrTo(filePath);
    var bridgeTensor = Bridge.load_tensor_from_file(pathBytes);
    const loaded = ndarray.fromBridgeTensor(rank,bridgeTensor);
    return loaded : ndarray(rank,eltType);
}
2308+
22872309
// For printing.
22882310
proc ndarray.serialize(writer: IO.fileWriter(locking=false, IO.defaultSerializer),ref serializer: IO.defaultSerializer) throws {
22892311

@@ -2300,17 +2322,26 @@ proc ndarray.serialize(writer: IO.fileWriter(locking=false, IO.defaultSerializer
23002322
}
23012323

23022324
proc ref ndarray.read(fr: IO.fileReader(?)) throws {
2303-
var r = fr.read(int);
2304-
if r != rank then
2305-
err("Error reading tensor: rank mismatch.", r , " != this." , rank);
2306-
var s = this.shape;
2307-
for i in 0..#rank do
2308-
s[i] = fr.read(int);
2309-
var d = util.domainFromShape((...s));
2310-
this._domain = d;
2311-
// for i in d do
2312-
// this.data[i] = fr.read(eltType);
2313-
fr.read(this.data);
2325+
2326+
const file = fr.getFile();
2327+
const filePath: string = file.path;
2328+
const (_,fileName,fileExt) = util.splitPathParts(filePath);
2329+
2330+
select fileExt {
2331+
when "chdata" do
2332+
this.loadChData(fr);
2333+
when "png" do
2334+
this = ndarray.loadImage(filePath,eltType);
2335+
when "jpg" do
2336+
this = ndarray.loadImage(filePath,eltType);
2337+
when "jpeg" do
2338+
this = ndarray.loadImage(filePath,eltType);
2339+
when "bmp" do
2340+
this = ndarray.loadImage(filePath,eltType);
2341+
}
2342+
2343+
2344+
23142345
}
23152346

23162347

0 commit comments

Comments
 (0)