Skip to content

Commit 513f025

Browse files
committed
add imjoy plugin draft
1 parent 70a1949 commit 513f025

File tree

1 file changed

+29
-174
lines changed

1 file changed

+29
-174
lines changed

tiktorch/imjoy.py

+29-174
Original file line number · Diff line number · Diff line change
@@ -1,17 +1,14 @@
11
import asyncio
22
import base64
33
import logging
4-
import os
5-
from asyncio import Future
4+
import zipfile
5+
from pathlib import Path
66

77
import numpy
8-
import yaml
8+
import torch
9+
from imageio import imread
910

10-
from typing import List, Optional, Tuple, Awaitable
11-
12-
from tiktorch.types import SetDeviceReturnType, NDArray
13-
from tiktorch.server import TikTorchServer
14-
from tiktorch.rpc import Shutdown, RPCFuture
11+
from tiktorch.server.reader import eval_model
1512

1613
logger = logging.getLogger(__name__)
1714

@@ -34,13 +31,32 @@ async def showDialog(self, *args, **kwargs) -> None:
3431

3532
class ImJoyPlugin:
3633
def setup(self) -> None:
37-
self.server = TikTorchServer()
34+
with zipfile.ZipFile("/g/kreshuk/beuttenm/Desktop/unet2d.model.zip", "r") as model_zip: # todo: configure path
35+
self.exemplum = eval_model(
36+
model_file=model_zip, devices=[f"cuda:{i}" for i in range(torch.cuda.device_count())] + ["cpu"]
37+
)
38+
3839
self.window = None
3940
api.log("initialized")
4041

4142
async def run(self, ctx) -> None:
42-
ctx.config.image_path = "/Users/fbeut/Downloads/chair.png"
43-
with open(ctx.config.image_path, "rb") as f:
43+
image_path = Path("/g/kreshuk/beuttenm/data/cremi/sneak.png") # todo: configure path
44+
try:
45+
await self.show_png(image_path)
46+
except Exception as e:
47+
logger.error(e)
48+
49+
assert image_path.exists()
50+
img = imread(str(image_path))
51+
assert img.shape[2] == 4
52+
batch = img[None, :512, :512, 0] # cyx
53+
54+
prediction = self.exemplum.forward(batch)
55+
56+
self.show_numpy(prediction)
57+
58+
async def show_png(self, png_path: Path):
59+
with png_path.open("rb") as f:
4460
data = f.read()
4561
result = base64.b64encode(data).decode("ascii")
4662

@@ -54,168 +70,8 @@ async def run(self, ctx) -> None:
5470
self.window = await api.createWindow(data_plot)
5571
print(f"Window created")
5672

57-
assert False
58-
# todo: remove this (set through ui)
59-
ctx.config.config_folder = "/repos/tiktorch/tests/data/CREMI_DUNet_pretrained_new"
60-
available_devices = self.server.get_available_devices()
61-
api.log(f"available devices: {available_devices}")
62-
self.config = ctx.config
63-
await self._choose_devices(available_devices)
64-
65-
async def _choose_devices(self, available_devices) -> None:
66-
device_switch_template = {
67-
"type": "switch",
68-
"label": "Device",
69-
"model": "status",
70-
"multi": True,
71-
"readonly": False,
72-
"featured": False,
73-
"disabled": False,
74-
"default": False,
75-
"textOn": "Selected",
76-
"textOff": "Not Selected",
77-
}
78-
79-
def fill_template(update: dict):
80-
ret = dict(device_switch_template)
81-
ret.update(update)
82-
return ret
83-
84-
choose_devices_schema = {"fields": [fill_template({"model": d[0], "label": d[1]}) for d in available_devices]}
85-
self.device_dialog = await api.showDialog(
86-
{
87-
"name": "Select from available devices",
88-
"type": "SchemaIO",
89-
"w": 20,
90-
"h": 3 * len(available_devices),
91-
"data": {
92-
"title": f"Select devices for TikTorch server",
93-
"schema": choose_devices_schema,
94-
"model": {},
95-
"callback": self._choose_devices_callback,
96-
"show": True,
97-
"formOptions": {"validateAfterLoad": True, "validateAfterChanged": True},
98-
"id": 0,
99-
},
100-
}
101-
)
102-
# self.device_dialog.onClose(self._choose_devices_close_callback)
103-
104-
# def _choose_devices_close_callback(self) -> None:
105-
# api.log("select device dialog closed")
106-
# self._chosen_devices = []
107-
@staticmethod
108-
async def _on_upload_change(model, schema, event):
109-
api.log(str((model, schema, event)))
110-
111-
async def _choose_devices_callback(self, data) -> None:
112-
api.log("before chosen devices callback")
113-
chosen_devices = [d for d, selected in data.items() if selected]
114-
api.log(f"chosen devices callback: {chosen_devices}")
115-
self.device_dialog.close()
116-
self.server_devices = self._load_model(chosen_devices)
117-
forward_schema = {
118-
"fields": [
119-
{
120-
"type": "upload",
121-
"label": "Photo",
122-
"model": "photo",
123-
"inputName": "photo",
124-
"onChanged": self._on_upload_change,
125-
},
126-
# {
127-
# "type": "switch",
128-
# "label": "image",
129-
# "model": "path",
130-
# "multi": True,
131-
# "readonly": False,
132-
# "featured": False,
133-
# "disabled": False,
134-
# "default": False,
135-
# "textOn": "Selected",
136-
# "textOff": "Not Selected",
137-
# },
138-
]
139-
}
140-
self.data_dialog = await api.showDialog(
141-
{
142-
"name": "Inference",
143-
"type": "SchemaIO",
144-
"w": 40,
145-
"h": 15,
146-
"data": {
147-
"title": "Inference",
148-
"schema": forward_schema,
149-
"model": {},
150-
"callback": self._new_user_input,
151-
"show": True,
152-
"formOptions": {"validateAfterLoad": True, "validateAfterChanged": True},
153-
"id": 0,
154-
},
155-
}
156-
)
157-
158-
def _load_model(self, chosen_devices) -> Awaitable[SetDeviceReturnType]:
159-
# todo: select individual files through gui
160-
# load config
161-
config_file_name = os.path.join(self.config.config_folder, "tiktorch_config.yml")
162-
if not os.path.exists(config_file_name):
163-
raise FileNotFoundError(f"Config file not found at: {config_file_name}.")
164-
165-
with open(config_file_name, "r") as f:
166-
tiktorch_config = yaml.load(f, Loader=yaml.SafeLoader)
167-
168-
# Read model.py
169-
file_name = os.path.join(self.config.config_folder, "model.py")
170-
if not os.path.exists(file_name):
171-
raise FileNotFoundError(f"Model file not found at: {file_name}.")
172-
173-
with open(file_name, "rb") as f:
174-
binary_model_file = f.read()
175-
176-
# Read model and optimizer states if they exist
177-
binary_states = []
178-
for file_name in ["state.nn", "optimizer.nn"]:
179-
file_name = os.path.join(self.config.config_folder, file_name)
180-
if os.path.exists(file_name):
181-
with open(file_name, "rb") as f:
182-
binary_states.append(f.read())
183-
else:
184-
binary_states.append(b"")
185-
186-
return asyncio.wrap_future(
187-
self.server.load_model(tiktorch_config, binary_model_file, *binary_states, devices=chosen_devices),
188-
loop=asyncio.get_event_loop(),
189-
)
190-
191-
async def _new_user_input(self, data):
192-
api.log(str(data))
193-
# data_plot = {
194-
# 'name':'Plot charts: show png',
195-
# 'type':'imjoy/image',
196-
# 'w':12, 'h':15,
197-
# 'data':data}
198-
#
199-
# ## Check if window was defined
200-
# if self.window is None:
201-
# self.window = await api.createWindow(data_plot)
202-
# print(f'Window created')
203-
204-
async def forward(
205-
self, data: numpy.ndarray, id_: Optional[Tuple] = None
206-
) -> Awaitable[Tuple[numpy.ndarray, Optional[Tuple]]]:
207-
await self.server_devices
208-
tikfut = self.server.forward(NDArray(data, id_=id_))
209-
return asyncio.wrap_future(tikfut.map(lambda x: (x.as_numpy(), id_)))
210-
211-
async def exit(self):
212-
api.log("shutting down...")
213-
try:
214-
self.server.shutdown()
215-
except Shutdown:
216-
api.log("shutdown successful")
217-
else:
218-
api.log("shutdown failed")
73+
def show_numpy(self, data: numpy.ndarray):
74+
print(data)
21975

22076

22177
if __name__ == "__main__":
@@ -236,4 +92,3 @@ class Ctx:
23692
plugin = ImJoyPlugin()
23793
plugin.setup()
23894
loop.run_until_complete(plugin.run(ctx))
239-
loop.run_until_complete(plugin.exit())

0 commit comments

Comments
 (0)