@@ -20,9 +20,7 @@ def __init__(self):
2020 self .end1_adj = set () # multiple other nodes, connected or close
2121 self .end2_adj = set ()
2222 self .traversed = set () # for the simplification of the problem
23-
24- self .source = None # the neuron it belongs to
25- self .likelihood = 0 # the likelihood of this fragment belonging to this neuron
23+ self .source = {} # the neuron it belongs to, source id -> likelihood
2624
2725
2826class BaseNode :
@@ -49,13 +47,24 @@ def update(self, **kwargs):
4947
5048
5149class BaseCut :
52- def __init__ (self , swc : ListNeuron , soma : list [int ], verbose = False ):
50+ def __init__ (self , swc : ListNeuron , soma : list [int ], res , likelihood_thr = None , verbose = False ):
51+ """
52+
53+ :param swc: swc tree
54+ :param soma: list of soma id
55+ :param res: resolution in x, y, z
56+ :param likelihood_thr: the minimum likelihood allowed for a fragment to be attached to a neuron, left as None
57+ to attach it to just the biggest. When multiple sources share a common or big enough likelihood, all of them will be considered.
58+ :param verbose:
59+ """
5360 self ._verbose = verbose
5461 self ._swc = dict ([(t [0 ], t ) for t in swc ])
62+ self .res = np .array (res )
5563 self ._soma = soma
5664 self ._fragment : dict [int , BaseFragment ] = {}
5765 self ._fragment_trees : dict [int , dict [int , BaseNode ]] = {}
5866 self ._problem : pulp .LpProblem | None = None
67+ self ._likelihood_thr : float = likelihood_thr
5968
6069 @property
6170 def swc (self ):
@@ -77,36 +86,59 @@ def export_swc(self, partition=True):
7786 :return: an swc or a dict of swc
7887 """
7988 if not partition :
80- tree = [list (t ) for t in self ._swc .values ()]
89+ tree = dict ([(t [0 ], list (t )) for t in self ._swc .values ()])
90+ tag = dict (zip (self ._soma , range (len (self ._soma ))))
8191 for frag in self ._fragment .values ():
8292 for i in frag .nodes :
83- tree [i ][1 ] = frag .source
84- tree = [tuple (t ) for t in tree ]
93+ a = list (frag .source .values ())
94+ b = list (frag .source .keys ())
95+ a = np .argmax (a )
96+ tree [i ][1 ] = tag [b [a ]]
97+ tree = [tuple (t ) for t in tree .values ()]
8598 return tree
8699
87100 trees = dict ([(i , {(- 1 , 1 ): (1 , * self ._swc [i ][1 :6 ], - 1 )}) for i in self ._soma ])
88101 for frag_id , frag in self ._fragment .items ():
89- frag_node = self ._fragment_trees [frag .source ][frag_id ]
90- nodes = self ._fragment [frag_id ].nodes
91- if not frag_node .reverse :
92- nodes = nodes [::- 1 ]
93- par_frag_id = frag_node .parent
94- if par_frag_id == - 1 :
95- last_id = - 1 , 1
96- else :
97- par_frag_node = self ._fragment_trees [frag .source ][par_frag_id ]
98- par_nodes = self ._fragment [par_frag_id ].nodes
99- if par_frag_node .reverse :
100- last_id = par_frag_id , par_nodes [- 1 ]
102+ candid = []
103+ a = list (frag .source .values ())
104+ b = list (frag .source .keys ())
105+ if self ._likelihood_thr is None : # max only mode
106+ m = None
107+ for i in np .argsort (a )[::- 1 ]:
108+ if m is not None and m > a [i ]:
109+ break
110+ candid .append (b [i ])
111+ m = a [i ]
112+ else : # thresholding mode, bigger than this will all be considered
113+ for i in np .argsort (a )[::- 1 ]:
114+ if a [i ] < self ._likelihood_thr :
115+ break
116+ candid .append (b [i ])
117+
118+ # for each candid source, append the frag nodes
119+ for src in candid :
120+ frag_node = self ._fragment_trees [src ][frag_id ]
121+ nodes = self ._fragment [frag_id ].nodes
122+ if not frag_node .reverse :
123+ nodes = nodes [::- 1 ]
124+ par_frag_id = frag_node .parent
125+ if par_frag_id == - 1 :
126+ last_id = - 1 , 1
101127 else :
102- last_id = par_frag_id , par_nodes [0 ]
103- tree = trees [frag .source ]
104- for i in nodes :
105- n = list (self ._swc [i ])
106- n [6 ] = last_id
107- n [0 ] = len (tree ) + 1
108- tree [(frag_id , i )] = tuple (n )
109- last_id = frag_id , i
128+ par_frag_node = self ._fragment_trees [src ][par_frag_id ]
129+ par_nodes = self ._fragment [par_frag_id ].nodes
130+ if par_frag_node .reverse :
131+ last_id = par_frag_id , par_nodes [- 1 ]
132+ else :
133+ last_id = par_frag_id , par_nodes [0 ]
134+
135+ tree = trees [src ]
136+ for i in nodes :
137+ n = list (self ._swc [i ])
138+ n [6 ] = last_id
139+ n [0 ] = len (tree ) + 1
140+ tree [(frag_id , i )] = tuple (n )
141+ last_id = frag_id , i
110142
111143 for s , t in trees .items ():
112144 for k , v in t .items ():
@@ -126,17 +158,24 @@ def _linear_programming(self):
126158 # finding variables for fragment/soma pairs that require solving
127159 scores = {} # var_i_s, i: fragment id, s: soma id
128160 for i , frag in self ._fragment .items ():
129- scores [i ] = {}
130- for s in frag .traversed :
131- scores [i ][s ] = pulp .LpVariable (f'Score_{ i } _{ s } ' , 0 ) # non-negative
161+ if len (frag .traversed ) > 1 : # mixed sources
162+ scores [i ] = {}
163+ for s in frag .traversed :
164+ scores [i ][s ] = pulp .LpVariable (f'Score_{ i } _{ s } ' , 0 ) # non-negative
165+ elif len (frag .traversed ) == 1 :
166+ scores [i ] = {}
167+ for s in frag .traversed :
168+ scores [i ][s ] = pulp .LpVariable (f'Score_{ i } _{ s } ' , 1 , 1 ) # const
169+ else :
170+ pass
171+ # raise ValueError('')
132172
133173 # objective func: cost * score
134174 self ._problem += pulp .lpSum (
135175 pulp .lpSum (
136176 self ._fragment_trees [s ][i ].cost * score for s , score in frag_vars .items ()
137177 ) for i , frag_vars in scores .items ()
138178 ), "Global Penalty"
139-
140179 # constraints
141180 for i , frag_vars in scores .items ():
142181 self ._problem += (pulp .lpSum (score for score in frag_vars .values ()) == 1 ,
@@ -147,20 +186,17 @@ def _linear_programming(self):
147186 self ._problem += score <= scores [p ][s ], \
148187 f"Tree Topology Enforcement for Score_{ i } _{ s } "
149188
150- self ._problem .solve ()
189+ self ._problem .solve (pulp .PULP_CBC_CMD (msg = 0 ))
190+
191+ for frag in self ._fragment .values ():
192+ frag .source = dict .fromkeys (frag .traversed , 1 )
151193
152194 for variable in self ._problem .variables ():
153195 frag_id , src = variable .name .split ('_' )[1 :]
154196 frag_id , src = int (frag_id ), int (src )
155197 frag = self ._fragment [frag_id ]
156- if frag .source is None or frag .likelihood < variable .varValue :
157- frag .source = src
158- frag .likelihood = variable .varValue
159-
160- for frag in self ._fragment .values ():
161- if frag .source is None :
162- frag .source = list (frag .traversed )[0 ]
163- frag .likelihood = 1
198+ assert src in frag .source
199+ frag .source [src ] = variable .varValue
164200
165201 if self ._verbose :
166202 print ("Finished linear programming." )
0 commit comments