-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Robodummy #57
base: main
Are you sure you want to change the base?
Robodummy #57
Changes from 1 commit
5b8c573
04e3fd9
73e4c47
2dec996
3d66cbc
7bf8f81
d157d9d
39fb936
52d0114
83316c2
8814e98
eb19bb3
8220b95
35af962
ae97dbd
3b8b7d7
84d61cb
28640d6
997c4eb
d355deb
a9cfecd
fd1eb97
cba7191
f265faf
fcce3f2
4daa134
6be7a82
f38692a
27c4f43
d1a1cdc
064c27c
4c6d9cb
3940c55
5da186d
f45d619
3db31e9
79a77da
6bd9d7e
055d08f
3eec9c8
8e907e3
c6ddc1d
617a4df
39a674d
9050f93
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,8 +4,10 @@ | |
import torch.nn as nn | ||
import torch.nn.functional as f | ||
from sklearn.metrics import mean_squared_error | ||
from model import DUNet2D | ||
#from model import DUNet2D | ||
import h5py | ||
import z5py | ||
from z5py.converter import convert_from_h5 | ||
from scipy.ndimage import convolve | ||
from torch.autograd import Variable | ||
from collections import OrderedDict | ||
|
@@ -17,17 +19,21 @@ | |
from utils import * | ||
|
||
patch_size = 16 | ||
|
||
img_dim = 32 | ||
|
||
class MrRobot: | ||
def __init__(self): | ||
# start the server | ||
self.new_server = TikTorchServer() | ||
|
||
|
||
def load_data(self): | ||
with h5py.File("train.hdf", "r") as f: | ||
x = np.array(f.get("volumes/labels/neuron_ids")) | ||
y = np.array(f.get("volumes/raw")) | ||
self.f = z5py.File('train.n5') | ||
return self.f | ||
""" | ||
#with h5py.File("train.hdf", "r") as f: | ||
# x = np.array(f.get("volumes/labels/neuron_ids")) | ||
# y = np.array(f.get("volumes/raw")) | ||
|
||
self.labels = [] | ||
self.ip = [] | ||
|
@@ -40,6 +46,7 @@ def load_data(self): | |
self.ip = NDArray(np.asarray(self.ip)[:, :, 0:patch_size, 0:patch_size]) | ||
print("data loaded") | ||
return (ip, labels) | ||
""" | ||
|
||
def load_model(self): | ||
# load the model | ||
|
@@ -60,10 +67,12 @@ def resume(self): | |
print("training resumed") | ||
|
||
def predict(self): | ||
self.op = new_server.forward(self.ip) | ||
self.ip = np.expand_dims(self.f['volume'][0,0:img_dim, 0:img_dim], axis = 0) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In general, variable names need some polishing. They should be descriptive and have a clear scope. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. instead of taking the first slice [0, ...] and then expanding the resulting array, you should simplify to take a slice right away: |
||
#self.label = np.expand_dims(self.f['volumes/labels/neuron_ids'][0,0:img_dim,0:img_dim], axis=0) | ||
self.op = self.new_server.forward(self.ip) | ||
self.op = op.result().as_numpy() | ||
print("prediction run") | ||
return (self.op, self.labels) | ||
return self.op | ||
|
||
def add(self, row, column): | ||
self.ip = self.ip.as_numpy()[ | ||
|
@@ -79,32 +88,38 @@ def add(self, row, column): | |
def dense_annotate(self, x, y, label, image): | ||
raise NotImplementedError | ||
|
||
def terminate(): | ||
new_server.shutdown() | ||
def terminate(self): | ||
self.new_server.shutdown() | ||
|
||
|
||
class BaseStrategy: | ||
def __init__(): | ||
raise NotImplementedError | ||
def __init__(self, file, op): | ||
self.f = file | ||
self.op = op | ||
|
||
# compute loss for a given patch | ||
def base_loss(self, patch, label): | ||
label = label[0][0] | ||
patch = patch[0][0] | ||
result = mean_squared_error(label, patch) # CHECK THIS | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the criterion should be configurable |
||
return result | ||
|
||
|
||
class Strategy1(BaseStrategy): | ||
def __init__(self, op, labels): | ||
super().__init__() | ||
pred_idx = tile_image2D(op[0, 0].shape, 16) | ||
actual_idx = tile_image2D(labels[0, 0].shape, 16) | ||
w, h, self.row, self.column = 32, 32, -1, -1 | ||
def __init__(self, file, op): | ||
super().__init__(file,op) | ||
|
||
def run(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would prefer the robot class to perform the 'run', not the strategy. The strategy should effectively implement a sampling strategy. I see this analog to the pytorch sampler. |
||
idx = tile_image(self.op.shape, patch_size) | ||
label = np.expand_dims(self.f['volumes/labels/neuron_ids'][0,0:img_dim,0:img_dim], axis=0) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same indexing as in predict method |
||
#idx = tile_image(label.shape, patch_size) | ||
w, h, self.row, self.column = img_dim, img_dim, -1, -1 | ||
error = 1e7 | ||
for i in range(len(pred_patches)): | ||
for i in range(len(idx)): | ||
# print(pred_patches[i].shape, actual_patches[i].shape) | ||
curr_loss = self.loss( | ||
op[0, 0, pred_idx[i][0] : pred_idx[i][1], pred_idx[i][2] : pred_idx[i][3]], | ||
labels[0, 0, actual_idx[i][0] : actual_idx[i][1], actual_idx[i][2] : actual_idx[i][3]], | ||
curr_loss = super().base_loss( | ||
self.op[idx[i][0]: idx[i][1], idx[i][2]:idx[i][3], idx[i][4] : idx[i][5], idx[i][6] : idx[i][7]], | ||
labels[idx[i][0]: idx[i][1], idx[i][2]:idx[i][3], idx[i][4] : idx[i][5], idx[i][6] : idx[i][7]], | ||
) | ||
print(curr_loss) | ||
if error > curr_loss: | ||
|
@@ -128,18 +143,20 @@ def __init__(): | |
|
||
|
||
if __name__ == "__main__": | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the following code should be inside of MrRobot. Currently you mirror parts of the tiktorch api in MrRobot (methods: resume, predict, add). This is fine for convenience, etc, but in it's core MrRobot should implement the way of running a 'user simulation' |
||
robo = MrRobot() | ||
robo.load_data() | ||
file = robo.load_data() | ||
robo.load_model() | ||
robo.resume() # resume training | ||
|
||
# run prediction | ||
op, label = robo.predict() | ||
op = robo.predict() | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here I think algorithm should be read as follows: # Step 1. Intialization
robo = MrRobot('/home/user/config.yaml') # Here robot loads all required data
robo.use_strategy(StrategyRandom())
# or even
robo = MrRobot('/home/user/config.yaml', StrategyRandom)
# Step 2. Start
robo.start() # Start tiktorch server
# Step 3. Prediction loop
while robo.should_stop():
robo.predict()
# def robo.predict
# 1. labels? = self.strategy.get_next_patch(<relevant data>)
# 2. self.update_training(labels, ...)
# Step 4. Termination
robo.terminate()
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, I'd vote for
|
||
metric = Strategy1(op, label) | ||
metric = Strategy1(file, op) | ||
metric.run() | ||
row, column = metric.get_patch() | ||
robo.add(row, column) | ||
|
||
# shut down server | ||
robo.terminate() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe sort the import statements a little don't mix import... and from... too much
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe https://github.com/thijsdezoete/sublime-text-isort-plugin/blob/master/README.md