-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Robodummy #57
base: main
Are you sure you want to change the base?
Robodummy #57
Changes from 11 commits
5b8c573
04e3fd9
73e4c47
2dec996
3d66cbc
7bf8f81
d157d9d
39fb936
52d0114
83316c2
8814e98
eb19bb3
8220b95
35af962
ae97dbd
3b8b7d7
84d61cb
28640d6
997c4eb
d355deb
a9cfecd
fd1eb97
cba7191
f265faf
fcce3f2
4daa134
6be7a82
f38692a
27c4f43
d1a1cdc
064c27c
4c6d9cb
3940c55
5da186d
f45d619
3db31e9
79a77da
6bd9d7e
055d08f
3eec9c8
8e907e3
c6ddc1d
617a4df
39a674d
9050f93
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,4 +7,6 @@ tiktorch/.idea | |
tiktorch/__pycache/ | ||
/#wrapper.py# | ||
/.#wrapper.py# | ||
.py~ | ||
.py~ | ||
*.nn | ||
*.hdf | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
import yaml

# Load the robot configuration used by the tests.
# NOTE(review): this scratch script was flagged for deletion in review; if it
# is kept, use safe_load — yaml.load() without an explicit Loader is unsafe
# and deprecated (it can construct arbitrary Python objects).
with open("tests/data/CREMI_DUNet_pretrained_new/robot_config.yml") as f:
    config_dict = yaml.safe_load(f)
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as f | ||
from sklearn.metrics import mean_squared_error | ||
import zipfile | ||
import h5py | ||
import z5py | ||
from z5py.converter import convert_from_h5 | ||
from scipy.ndimage import convolve | ||
from torch.autograd import Variable | ||
from collections import OrderedDict | ||
import yaml | ||
import logging | ||
from abc import ABC, abstractmethod | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This stuff should be introduced at later stages of the project if at all. It infects code and spreads like a plague, without improving developer experience. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sorry, can you please elaborate a bit on this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What are the benefits of abs usage compared to |
||
from tiktorch.server import TikTorchServer | ||
from tiktorch.rpc import Client, Server, InprocConnConf | ||
from tiktorch.rpc_interface import INeuralNetworkAPI | ||
from tiktorch.types import NDArray, NDArrayBatch | ||
from tests.conftest import nn_sample | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should not be there |
||
from mr_robot.utils import * | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would recommend avoiding |
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe sort the import statements a little don't mix import... and from... too much There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
patch_size = 16 | ||
img_dim = 32 | ||
|
||
|
||
class MrRobot:
    """Robot that iteratively trains a tiktorch model.

    It predicts on a crop of the raw data, ranks the predicted patches with
    the given strategy, and feeds the worst patches (with their labels) back
    to the server as new training data.
    """

    def __init__(self, path_to_config_file, strategy):
        # start the tiktorch server that hosts the model
        self.new_server = TikTorchServer()
        self.strategy = strategy

        with open(path_to_config_file, mode="r") as f:
            # BUG FIX: safe_load instead of yaml.load (no-Loader form is
            # unsafe and deprecated)
            self.base_config = yaml.safe_load(f)

        self.max_robo_iterations = self.base_config["max_robo_iterations"]
        # BUG FIX: load_model() pops these keys from base_config, but add()
        # still needs them afterwards — keep them on the instance
        self.data_file = self.base_config["cremi_data"]
        self.label_path = self.base_config["cremi_path_to_labelled"]

        # number of robo iterations performed so far
        self.counter = 0
        self.logger = logging.getLogger(__name__)

    def load_data(self):
        """Open the raw data volume configured under 'cremi_data_dir'."""
        self.f = z5py.File(self.base_config["cremi_data_dir"])
        # BUG FIX: the logger object is not callable; use .info()
        self.logger.info("data file loaded")

    def load_model(self):
        """Read model code and state from the configured zip archive and
        push both to the tiktorch server.
        """
        archive = zipfile.ZipFile(self.base_config["cremi_dir"]["path_to_zip"], "r")
        model = archive.read(self.base_config["cremi_dir"]["path_in_zip_to_model"])
        binary_state = archive.read(self.base_config["cremi_dir"]["path_in_zip_to_state"])

        # cleaning dictionary before passing to tiktorch
        self.base_config.pop("cremi_dir")
        self.base_config.pop("cremi_data")
        self.base_config.pop("cremi_path_to_labelled")

        # BUG FIX: was the undefined name `base_config`; must be the instance
        # attribute self.base_config
        fut = self.new_server.load_model(self.base_config, model, binary_state, b"", ["cpu"])
        self.logger.info("model loaded")
        return fut

    def resume(self):
        """Resume training on the server."""
        self.new_server.resume_training()
        self.logger.info("training resumed")

    def predict(self):
        """Run a forward pass on a crop of the raw volume and store the
        prediction on self.op as a numpy array.
        """
        self.ip = self.f["volume"][0:1, 0:img_dim, 0:img_dim]
        self.op = self.new_server.forward(self.ip)
        self.op = self.op.result().as_numpy()

    def stop(self):
        """Return True while the iteration budget remains; advances the
        iteration counter as a side effect.
        """
        if self.counter > self.max_robo_iterations:
            return False
        self.counter += 1
        return True

    def run(self):
        """Rank patches with the strategy, then keep feeding the next-worst
        patch to the trainer until the iteration budget is exhausted.
        """
        self.strategy.patch('MSE', self.op)
        while self.stop():
            idx = self.strategy.get_next_patch(self.op)
            self.add(idx)

    def add(self, idx):
        """Fetch the labels for the patch at `idx` and add the (input,
        label) pair to the server's training data.
        """
        # NOTE(review): opening the label file on every call is wasteful;
        # kept here to preserve behavior — consider caching the handle.
        file = z5py.File(self.data_file)
        # BUG FIX: look up the dataset path stored in the config instead of
        # the literal string "cremi_path_to_labelled"
        labels = file[self.label_path][0:1, 0:img_dim, 0:img_dim]

        # BUG FIX: self.ip is already a numpy array (sliced in predict());
        # it has no .as_numpy() method
        new_ip = self.ip[idx[0]:idx[1], idx[2]:idx[3], idx[4]:idx[5]].astype(float)
        new_label = labels[idx[0]:idx[1], idx[2]:idx[3], idx[4]:idx[5]].astype(float)
        # NOTE(review): tile_image() prepends a [start, stop] list as idx[0],
        # so this flat integer indexing looks inconsistent — confirm layout.
        self.new_server.update_training_data(NDArrayBatch([NDArray(new_ip)]), NDArrayBatch([new_label]))

    # annotate worst patch
    def dense_annotate(self, x, y, label, image):
        raise NotImplementedError

    def terminate(self):
        """Shut down the tiktorch server."""
        self.new_server.shutdown()
|
||
|
||
class BaseStrategy(ABC):
    """Base class for patch-selection strategies.

    Computes a per-patch loss over a prediction and lets subclasses decide
    which patch should be annotated next.
    """

    def __init__(self, path_to_config_file):
        with open(path_to_config_file, mode="r") as f:
            # BUG FIX: safe_load instead of yaml.load (no-Loader form is
            # unsafe and deprecated)
            self.base_config = yaml.safe_load(f)
        self.logger = logging.getLogger(__name__)

    def loss(self, tile, label, loss_fn):
        """Return the loss between a predicted tile and its label.

        NOTE(review): `loss_fn` is currently ignored and mean squared error
        is always used — confirm whether other criteria are intended.
        """
        label = label[0]
        tile = tile[0]
        return mean_squared_error(label, tile)  # CHECK THIS

    def base_patch(self, loss_fn, op):
        """Tile the prediction `op` and record (loss, tile_index) for every
        tile in self.patch_data.
        """
        idx = tile_image(op.shape, patch_size)

        file = z5py.File(self.base_config["cremi_data"])
        # BUG FIX: look up the dataset path stored in the config instead of
        # using the literal string "cremi_path_to_labelled" as a key
        labels = file[self.base_config["cremi_path_to_labelled"]][0:1, 0:img_dim, 0:img_dim]

        self.patch_data = []
        # NOTE(review): tile_image() prepends a [start, stop] list as the
        # first element of each index, so this flat integer slicing looks
        # inconsistent with its output — confirm the index layout.
        for i, tile_idx in enumerate(idx):
            curr_loss = self.loss(
                op[tile_idx[0]:tile_idx[1], tile_idx[2]:tile_idx[3], tile_idx[4]:tile_idx[5]],
                labels[tile_idx[0]:tile_idx[1], tile_idx[2]:tile_idx[3], tile_idx[4]:tile_idx[5]],
                loss_fn,
            )
            self.patch_data.append((curr_loss, tile_idx))
            # lazy %-style logging instead of eager string formatting
            self.logger.info("loss for patch %d: %d", i, curr_loss)

    @abstractmethod
    def get_next_patch(self):
        """Return the index of the next patch to annotate."""
        pass
||
|
||
|
||
class Strategy1(BaseStrategy):
    """Serve patches in descending order of loss (worst first)."""

    def __init__(self, path_to_config_file):
        super().__init__(path_to_config_file)
        # index into patch_data; starts at -1 so the first call to
        # get_next_patch() returns element 0
        self.counter = -1

    def patch(self, loss_fn, op):
        """Compute all patch losses and sort them, worst first."""
        super().base_patch(loss_fn, op)
        # BUG FIX: sort on the loss only — sorting the raw (loss, idx)
        # tuples falls back to comparing the index lists on loss ties,
        # which can raise TypeError
        self.patch_data.sort(key=lambda entry: entry[0], reverse=True)

    def get_next_patch(self, op=None):
        """Return the index of the next-worst patch.

        Accepts an optional `op` because MrRobot.run() calls this with the
        current prediction; the argument is unused by this strategy.
        """
        self.counter += 1
        return self.patch_data[self.counter][1]
|
||
|
||
class Strategy2(BaseStrategy):
    """Placeholder strategy — not implemented yet."""

    # BUG FIX: __init__ was missing `self` and the config-path argument
    # required by BaseStrategy.__init__
    def __init__(self, path_to_config_file):
        super().__init__(path_to_config_file)
|
||
|
||
class Strategy3(BaseStrategy):
    """Placeholder strategy — not implemented yet."""

    # BUG FIX: __init__ was missing `self` and the config-path argument
    # required by BaseStrategy.__init__
    def __init__(self, path_to_config_file):
        super().__init__(path_to_config_file)
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#base config for robot | ||
max_robo_iterations: 10 | ||
model_class_name: DUNet2D | ||
model_init_kwargs: {in_channels: 1, out_channels: 1} | ||
training: { | ||
training_shape: [1, 32, 32], | ||
batch_size: 1, | ||
loss_criterion_config: {"method": "MSELoss"}, | ||
optimizer_config: {"method": "Adam"}, | ||
num_iterations_done: 1 | ||
} | ||
validation: {} | ||
dry_run: { | ||
"skip": True, | ||
"shrinkage": [0, 0, 0] | ||
} | ||
|
||
cremi_dir: { | ||
path_to_zip: D:/Machine Learning/tiktorch/tests/data/CREMI_DUNet_pretrained_new.zip, | ||
path_in_zip_to_model: CREMI_DUNet_pretrained_new/model.py, | ||
path_in_zip_to_state: CREMI_DUNet_pretrained_new/state.nn | ||
} | ||
|
||
cremi_data: D:/Machine Learning/tiktorch/mr_robot/train.n5 | ||
cremi_path_to_labelled: volumes/labels/neuron_ids |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
## utility functions for the robot ## | ||
# | ||
def summary(model, input_size, batch_size=-1, device="cuda"):
    """Print a Keras-style layer-by-layer summary of a torch model.

    Registers forward hooks on every leaf module, runs a dummy forward pass
    with random input, then prints per-layer output shapes and parameter
    counts plus an estimate of memory usage.

    Args:
        model: the torch.nn.Module to summarize.
        input_size: tuple (or list of tuples) giving the per-sample input
            shape(s), without the batch dimension.
        batch_size: batch size to display in the shape column.
        device: "cuda" or "cpu"; selects the dtype of the dummy input.
    """

    def register_hook(module):
        def hook(module, input, output):
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(summary)

            m_key = "%s-%i" % (class_name, module_idx + 1)
            summary[m_key] = OrderedDict()
            summary[m_key]["input_shape"] = list(input[0].size())
            summary[m_key]["input_shape"][0] = batch_size
            if isinstance(output, (list, tuple)):
                summary[m_key]["output_shape"] = [[-1] + list(o.size())[1:] for o in output]
            else:
                summary[m_key]["output_shape"] = list(output.size())
                summary[m_key]["output_shape"][0] = batch_size

            # BUG FIX: accumulate plain ints instead of LongTensors — the
            # original mixed tensors into the integer bookkeeping, which
            # printed "tensor(N)" in the table and crashed on `.numpy()`
            # for models without any parameters.
            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += int(torch.prod(torch.LongTensor(list(module.weight.size()))))
                summary[m_key]["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += int(torch.prod(torch.LongTensor(list(module.bias.size()))))
            summary[m_key]["nb_params"] = params

        # only hook leaf-ish modules, and never the root model itself
        if not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList) and not (module == model):
            hooks.append(module.register_forward_hook(hook))

    device = device.lower()
    assert device in ["cuda", "cpu"], "Input device is not valid, please specify 'cuda' or 'cpu'"

    if device == "cuda" and torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    # multiple inputs to the network
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # batch_size of 2 for batchnorm
    x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]

    # create properties
    summary = OrderedDict()
    hooks = []

    # register hook on every submodule
    model.apply(register_hook)

    # make a forward pass
    model(*x)

    # remove these hooks
    for h in hooks:
        h.remove()

    print("----------------------------------------------------------------")
    line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
    print(line_new)
    print("================================================================")
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer in summary:
        # input_shape, output_shape, trainable, nb_params
        line_new = "{:>20} {:>25} {:>15}".format(
            layer, str(summary[layer]["output_shape"]), "{0:,}".format(summary[layer]["nb_params"])
        )
        total_params += summary[layer]["nb_params"]
        total_output += np.prod(summary[layer]["output_shape"])
        # BUG FIX: truthiness test instead of `== True`
        if summary[layer].get("trainable"):
            trainable_params += summary[layer]["nb_params"]
        print(line_new)

    # assume 4 bytes/number (float on cuda).
    total_input_size = abs(np.prod(input_size) * batch_size * 4.0 / (1024 ** 2.0))
    total_output_size = abs(2.0 * total_output * 4.0 / (1024 ** 2.0))  # x2 for gradients
    # BUG FIX: total_params is a plain int now, no `.numpy()` call needed
    total_params_size = abs(total_params * 4.0 / (1024 ** 2.0))
    total_size = total_params_size + total_output_size + total_input_size

    print("================================================================")
    print("Total params: {0:,}".format(total_params))
    print("Trainable params: {0:,}".format(trainable_params))
    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
    print("----------------------------------------------------------------")
    print("Input size (MB): %0.2f" % total_input_size)
    print("Forward/backward pass size (MB): %0.2f" % total_output_size)
    print("Params size (MB): %0.2f" % total_params_size)
    print("Estimated Total Size (MB): %0.2f" % total_size)
    print("----------------------------------------------------------------")
|
||
|
||
# ref: https://github.com/constantinpape/vis_tools/blob/master/vis_tools/edges.py#L5 | ||
# ref: https://github.com/constantinpape/vis_tools/blob/master/vis_tools/edges.py#L5
def make_edges3d(segmentation):
    """Return a boolean 3d edge volume from a 3d segmentation: True wherever
    the label gradient along any axis is non-zero.
    """
    kernel = np.array([-1.0, 0.0, 1.0])
    # shift labels by one so that label 0 cannot hide an edge
    shifted = segmentation + 1
    gz, gy, gx = (
        convolve(shifted, kernel.reshape(shape))
        for shape in ((3, 1, 1), (1, 3, 1), (1, 1, 3))
    )
    return (gx ** 2 + gy ** 2 + gz ** 2) > 0
|
||
|
||
# create patches | ||
# create patches
def tile_image(image_shape, tile_size):
    """Tile the last two dimensions of `image_shape` into square tiles.

    Interior tiles are laid on a regular grid of stride `tile_size`; when a
    dimension is not evenly divisible, extra tiles flush with the far edge
    are appended (they overlap their neighbours). Each returned entry is a
    list of the form [[0, d0], ..., w_start, w_stop, h_start, h_stop], where
    the leading [0, d] pairs cover every dimension before the last two.
    """
    w, h = image_shape[-2], image_shape[-1]
    step = int(tile_size)

    tiles = [
        [ws, ws + tile_size, hs, hs + tile_size]
        for ws in range(0, w - tile_size + 1, step)
        for hs in range(0, h - tile_size + 1, step)
    ]

    # edge tiles flush with the bottom / right / corner when not divisible
    if h % tile_size:
        tiles.extend([ws, ws + tile_size, h - tile_size, h] for ws in range(0, w - tile_size + 1, step))
    if w % tile_size:
        tiles.extend([w - tile_size, w, hs, hs + tile_size] for hs in range(0, h - tile_size + 1, step))
    if w % tile_size and h % tile_size:
        tiles.append([w - tile_size, w, h - tile_size, h])

    # full-extent ranges for every leading (non-spatial) dimension
    leading = [[0, image_shape[d]] for d in range(len(image_shape) - 2)]
    return [leading + tile for tile in tiles]
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import numpy as np | ||
from mr_robot.mr_robot import MrRobot | ||
from tiktorch.server import TikTorchServer | ||
from tiktorch.rpc import RPCFuture | ||
from tiktorch.types import SetDeviceReturnType | ||
|
||
import z5py | ||
|
||
def test_MrRobot():
    """Smoke test: construct the robot, load data and model, and run one
    prediction round against the pretrained CREMI sample.
    """
    # local import to avoid widening the module-level import surface
    from mr_robot.mr_robot import Strategy1

    config_path = "tests/data/CREMI_DUNet_pretrained_new/robot_config.yml"
    # BUG FIX: MrRobot.__init__ requires a config path and a strategy; the
    # original called MrRobot() with no arguments and could never run
    robo = MrRobot(config_path, Strategy1(config_path))
    assert isinstance(robo, MrRobot)
    assert isinstance(robo.new_server, TikTorchServer)

    # BUG FIX: load_data() stores the handle on robo.f and returns None,
    # so assert on the attribute rather than the return value
    robo.load_data()
    assert isinstance(robo.f, z5py.file.File)

    robo.load_model()
    robo.resume()

    # predict() stores the prediction on robo.op
    robo.predict()
    assert isinstance(robo.op, np.ndarray)
    assert robo.op.shape == (1, 1, 32, 32)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there is no need to ignore
.nn
and .hdf
files (as there are none in the repo). Please remove them.