1
1
import asyncio
2
2
import base64
3
3
import logging
4
- import os
5
- from asyncio import Future
4
+ import zipfile
5
+ from pathlib import Path
6
6
7
7
import numpy
8
- import yaml
8
+ import torch
9
+ from imageio import imread
9
10
10
- from typing import List , Optional , Tuple , Awaitable
11
-
12
- from tiktorch .types import SetDeviceReturnType , NDArray
13
- from tiktorch .server import TikTorchServer
14
- from tiktorch .rpc import Shutdown , RPCFuture
11
+ from tiktorch .server .reader import eval_model
15
12
16
13
logger = logging .getLogger (__name__ )
17
14
@@ -34,13 +31,32 @@ async def showDialog(self, *args, **kwargs) -> None:
34
31
35
32
class ImJoyPlugin :
36
33
def setup (self ) -> None :
37
- self .server = TikTorchServer ()
34
+ with zipfile .ZipFile ("/g/kreshuk/beuttenm/Desktop/unet2d.model.zip" , "r" ) as model_zip : # todo: configure path
35
+ self .exemplum = eval_model (
36
+ model_file = model_zip , devices = [f"cuda:{ i } " for i in range (torch .cuda .device_count ())] + ["cpu" ]
37
+ )
38
+
38
39
self .window = None
39
40
api .log ("initialized" )
40
41
41
42
async def run (self , ctx ) -> None :
42
- ctx .config .image_path = "/Users/fbeut/Downloads/chair.png"
43
- with open (ctx .config .image_path , "rb" ) as f :
43
+ image_path = Path ("/g/kreshuk/beuttenm/data/cremi/sneak.png" ) # todo: configure path
44
+ try :
45
+ await self .show_png (image_path )
46
+ except Exception as e :
47
+ logger .error (e )
48
+
49
+ assert image_path .exists ()
50
+ img = imread (str (image_path ))
51
+ assert img .shape [2 ] == 4
52
+ batch = img [None , :512 , :512 , 0 ] # cyx
53
+
54
+ prediction = self .exemplum .forward (batch )
55
+
56
+ self .show_numpy (prediction )
57
+
58
+ async def show_png (self , png_path : Path ):
59
+ with png_path .open ("rb" ) as f :
44
60
data = f .read ()
45
61
result = base64 .b64encode (data ).decode ("ascii" )
46
62
@@ -54,168 +70,8 @@ async def run(self, ctx) -> None:
54
70
self .window = await api .createWindow (data_plot )
55
71
print (f"Window created" )
56
72
57
- assert False
58
- # todo: remvoe this (set through ui)
59
- ctx .config .config_folder = "/repos/tiktorch/tests/data/CREMI_DUNet_pretrained_new"
60
- available_devices = self .server .get_available_devices ()
61
- api .log (f"available devices: { available_devices } " )
62
- self .config = ctx .config
63
- await self ._choose_devices (available_devices )
64
-
65
- async def _choose_devices (self , available_devices ) -> None :
66
- device_switch_template = {
67
- "type" : "switch" ,
68
- "label" : "Device" ,
69
- "model" : "status" ,
70
- "multi" : True ,
71
- "readonly" : False ,
72
- "featured" : False ,
73
- "disabled" : False ,
74
- "default" : False ,
75
- "textOn" : "Selected" ,
76
- "textOff" : "Not Selected" ,
77
- }
78
-
79
- def fill_template (update : dict ):
80
- ret = dict (device_switch_template )
81
- ret .update (update )
82
- return ret
83
-
84
- choose_devices_schema = {"fields" : [fill_template ({"model" : d [0 ], "label" : d [1 ]}) for d in available_devices ]}
85
- self .device_dialog = await api .showDialog (
86
- {
87
- "name" : "Select from available devices" ,
88
- "type" : "SchemaIO" ,
89
- "w" : 20 ,
90
- "h" : 3 * len (available_devices ),
91
- "data" : {
92
- "title" : f"Select devices for TikTorch server" ,
93
- "schema" : choose_devices_schema ,
94
- "model" : {},
95
- "callback" : self ._choose_devices_callback ,
96
- "show" : True ,
97
- "formOptions" : {"validateAfterLoad" : True , "validateAfterChanged" : True },
98
- "id" : 0 ,
99
- },
100
- }
101
- )
102
- # self.device_dialog.onClose(self._choose_devices_close_callback)
103
-
104
- # def _choose_devices_close_callback(self) -> None:
105
- # api.log("select device dialog closed")
106
- # self._chosen_devices = []
107
- @staticmethod
108
- async def _on_upload_change (model , schema , event ):
109
- api .log (str ((model , schema , event )))
110
-
111
- async def _choose_devices_callback (self , data ) -> None :
112
- api .log ("before chosen devices callback" )
113
- chosen_devices = [d for d , selected in data .items () if selected ]
114
- api .log (f"chosen devices callback: { chosen_devices } " )
115
- self .device_dialog .close ()
116
- self .server_devices = self ._load_model (chosen_devices )
117
- forward_schema = {
118
- "fields" : [
119
- {
120
- "type" : "upload" ,
121
- "label" : "Photo" ,
122
- "model" : "photo" ,
123
- "inputName" : "photo" ,
124
- "onChanged" : self ._on_upload_change ,
125
- },
126
- # {
127
- # "type": "switch",
128
- # "label": "image",
129
- # "model": "path",
130
- # "multi": True,
131
- # "readonly": False,
132
- # "featured": False,
133
- # "disabled": False,
134
- # "default": False,
135
- # "textOn": "Selected",
136
- # "textOff": "Not Selected",
137
- # },
138
- ]
139
- }
140
- self .data_dialog = await api .showDialog (
141
- {
142
- "name" : "Inference" ,
143
- "type" : "SchemaIO" ,
144
- "w" : 40 ,
145
- "h" : 15 ,
146
- "data" : {
147
- "title" : "Inference" ,
148
- "schema" : forward_schema ,
149
- "model" : {},
150
- "callback" : self ._new_user_input ,
151
- "show" : True ,
152
- "formOptions" : {"validateAfterLoad" : True , "validateAfterChanged" : True },
153
- "id" : 0 ,
154
- },
155
- }
156
- )
157
-
158
- def _load_model (self , chosen_devices ) -> Awaitable [SetDeviceReturnType ]:
159
- # todo: select individual files through gui
160
- # load config
161
- config_file_name = os .path .join (self .config .config_folder , "tiktorch_config.yml" )
162
- if not os .path .exists (config_file_name ):
163
- raise FileNotFoundError (f"Config file not found at: { config_file_name } ." )
164
-
165
- with open (config_file_name , "r" ) as f :
166
- tiktorch_config = yaml .load (f , Loader = yaml .SafeLoader )
167
-
168
- # Read model.py
169
- file_name = os .path .join (self .config .config_folder , "model.py" )
170
- if not os .path .exists (file_name ):
171
- raise FileNotFoundError (f"Model file not found at: { file_name } ." )
172
-
173
- with open (file_name , "rb" ) as f :
174
- binary_model_file = f .read ()
175
-
176
- # Read model and optimizer states if they exist
177
- binary_states = []
178
- for file_name in ["state.nn" , "optimizer.nn" ]:
179
- file_name = os .path .join (self .config .config_folder , file_name )
180
- if os .path .exists (file_name ):
181
- with open (file_name , "rb" ) as f :
182
- binary_states .append (f .read ())
183
- else :
184
- binary_states .append (b"" )
185
-
186
- return asyncio .wrap_future (
187
- self .server .load_model (tiktorch_config , binary_model_file , * binary_states , devices = chosen_devices ),
188
- loop = asyncio .get_event_loop (),
189
- )
190
-
191
- async def _new_user_input (self , data ):
192
- api .log (str (data ))
193
- # data_plot = {
194
- # 'name':'Plot charts: show png',
195
- # 'type':'imjoy/image',
196
- # 'w':12, 'h':15,
197
- # 'data':data}
198
- #
199
- # ## Check if window was defined
200
- # if self.window is None:
201
- # self.window = await api.createWindow(data_plot)
202
- # print(f'Window created')
203
-
204
- async def forward (
205
- self , data : numpy .ndarray , id_ : Optional [Tuple ] = None
206
- ) -> Awaitable [Tuple [numpy .ndarray , Optional [Tuple ]]]:
207
- await self .server_devices
208
- tikfut = self .server .forward (NDArray (data , id_ = id_ ))
209
- return asyncio .wrap_future (tikfut .map (lambda x : (x .as_numpy (), id_ )))
210
-
211
- async def exit (self ):
212
- api .log ("shutting down..." )
213
- try :
214
- self .server .shutdown ()
215
- except Shutdown :
216
- api .log ("shutdown successful" )
217
- else :
218
- api .log ("shutdown failed" )
73
+ def show_numpy (self , data : numpy .ndarray ):
74
+ print (data )
219
75
220
76
221
77
if __name__ == "__main__" :
@@ -236,4 +92,3 @@ class Ctx:
236
92
plugin = ImJoyPlugin ()
237
93
plugin .setup ()
238
94
loop .run_until_complete (plugin .run (ctx ))
239
- loop .run_until_complete (plugin .exit ())
0 commit comments