Skip to content

Commit 1db5743

Browse files
authored
Merge pull request #377 from caspervdw/refine-python
FIX Major bugfix in python refinement
2 parents 45d0ceb + bbbebad commit 1db5743

File tree

5 files changed

+271
-435
lines changed

5 files changed

+271
-435
lines changed

doc/releases/v0.3.1.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ Enhancements
1616
Bug Fixes
1717
~~~~~~~~~
1818

19+
- Bug in the python refinement code was solved: feature finding with `engine='python'` is now more accurate. (:issue:`#377`)
20+
1921
- Error in `subtract_drift` is solved (:issue:`#351`)
2022

2123
- Legends are disabled by default in plotting (:issue:`#357`)

trackpy/artificial.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,43 @@ def draw_spots(shape, positions, diameter, noise_level=0, bitdepth=8,
176176
else:
177177
raise ValueError('Bitdepth should be <= 32')
178178
np.random.seed(0)
179-
image = np.random.randint(0, noise_level + 1, shape).astype(internaldtype)
179+
image = np.zeros(shape, dtype=internaldtype)
180+
if noise_level > 0:
181+
image += np.random.randint(0, noise_level + 1,
182+
shape).astype(internaldtype)
180183
for pos in positions:
181184
draw_feature(image, pos, diameter, max_value=2**bitdepth - 1,
182185
feat_func=feat_func, ecc=ecc, **kwargs)
183186
return image.clip(0, 2**bitdepth - 1).astype(dtype)
187+
188+
189+
def draw_array(N, diameter, separation=None, ndim=2, **kwargs):
190+
""" Generates an image with an array of features. Each feature has a random
191+
offset of +- 0.5 pixel.
192+
193+
Parameters
194+
----------
195+
N : int
196+
the number of features
197+
diameter : number or tuple
198+
the sizes of the box that will be used per feature. The actual feature
199+
'size' is determined by feat_func and kwargs given to feat_func.
200+
separation : number or tuple
201+
the desired separation between features
202+
kwargs : see draw_spots
203+
204+
See also
205+
--------
206+
draw_spots
207+
"""
208+
diameter = validate_tuple(diameter, ndim)
209+
if separation is None:
210+
separation = tuple([d * 2 for d in diameter])
211+
margin = separation
212+
Nsqrt = int(N**(1/ndim) + 0.9999)
213+
pos = np.meshgrid(*[np.arange(0, s * Nsqrt, s) for s in separation],
214+
indexing='ij')
215+
pos = np.array([p.ravel() for p in pos], dtype=np.float).T[:N] + margin
216+
pos += (np.random.random(pos.shape) - 0.5) #randomize subpixel location
217+
shape = tuple(np.max(pos, axis=0).astype(np.int) + margin)
218+
return pos, draw_spots(shape, pos, diameter, **kwargs)

trackpy/feature.py

Lines changed: 40 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def _safe_center_of_mass(x, radius, grids):
155155

156156

157157
def refine(raw_image, image, radius, coords, separation=0, max_iterations=10,
158-
engine='auto', shift_thresh=0.6, break_thresh=0.005,
158+
engine='auto', shift_thresh=0.6, break_thresh=None,
159159
characterize=True, walkthrough=False):
160160
"""Find the center of mass of a bright feature starting from an estimate.
161161
@@ -184,14 +184,16 @@ def refine(raw_image, image, radius, coords, separation=0, max_iterations=10,
184184
shift mask to neighboring pixel. The new mask will be used for any
185185
remaining iterations.
186186
break_thresh : float, optional
187-
Default: 0.005 (unit is pixels).
188-
When the subpixel refinement along all dimensions is less than this
189-
number, declare victory and stop refinement.
187+
Deprecated
190188
characterize : boolean, True by default
191189
Compute and return mass, size, eccentricity, signal.
192190
walkthrough : boolean, False by default
193191
Print the offset on each loop and display final neighborhood image.
194192
"""
193+
if break_thresh is not None:
194+
warnings.warn("break_threshold will be deprecated: shift_threshold is"
195+
"the only parameter that determines when to shift the"
196+
"mask.")
195197
# ensure that radius is tuple of integers, for direct calls to refine()
196198
radius = validate_tuple(radius, image.ndim)
197199
separation = validate_tuple(separation, image.ndim)
@@ -201,10 +203,13 @@ def refine(raw_image, image, radius, coords, separation=0, max_iterations=10,
201203
engine = 'numba'
202204
else:
203205
engine = 'python'
206+
207+
# In here, coord is an integer. Make a copy, will not modify inplace.
208+
coords = np.round(coords).astype(np.int)
209+
204210
if engine == 'python':
205-
coords = np.array(coords) # a copy, will not modify in place
206211
results = _refine(raw_image, image, radius, coords, max_iterations,
207-
shift_thresh, break_thresh, characterize, walkthrough)
212+
shift_thresh, characterize, walkthrough)
208213
elif engine == 'numba':
209214
if not NUMBA_AVAILABLE:
210215
warnings.warn("numba could not be imported. Without it, the "
@@ -217,7 +222,6 @@ def refine(raw_image, image, radius, coords, separation=0, max_iterations=10,
217222
if walkthrough:
218223
raise ValueError("walkthrough is not availabe in the numba engine")
219224
# Do some extra prep in pure Python that can't be done in numba.
220-
coords = np.array(coords, dtype=np.float64)
221225
N = coords.shape[0]
222226
mask = binary_mask(radius, image.ndim)
223227
if image.ndim == 3:
@@ -238,34 +242,33 @@ def refine(raw_image, image, radius, coords, separation=0, max_iterations=10,
238242
dtype=np.int16)
239243
_numba_refine_3D(np.asarray(raw_image), np.asarray(image),
240244
radius[0], radius[1], radius[2], coords, N,
241-
int(max_iterations), shift_thresh, break_thresh,
245+
int(max_iterations), shift_thresh,
242246
characterize,
243247
image.shape[0], image.shape[1], image.shape[2],
244248
maskZ, maskY, maskX, maskX.shape[0],
245249
r2_mask, z2_mask, y2_mask, x2_mask, results)
246250
elif not characterize:
247-
mask_coordsY, mask_coordsX = np.asarray(mask.nonzero(), dtype=np.int16)
251+
mask_coordsY, mask_coordsX = np.asarray(mask.nonzero(), np.int16)
248252
results = np.empty((N, 3), dtype=np.float64)
249-
_numba_refine_2D(np.asarray(raw_image), np.asarray(image),
250-
radius[0], radius[1], coords, N,
251-
int(max_iterations), shift_thresh, break_thresh,
253+
_numba_refine_2D(np.asarray(image), radius[0], radius[1], coords, N,
254+
int(max_iterations), shift_thresh,
252255
image.shape[0], image.shape[1],
253256
mask_coordsY, mask_coordsX, mask_coordsY.shape[0],
254257
results)
255258
elif radius[0] == radius[1]:
256-
mask_coordsY, mask_coordsX = np.asarray(mask.nonzero(), dtype=np.int16)
259+
mask_coordsY, mask_coordsX = np.asarray(mask.nonzero(), np.int16)
257260
results = np.empty((N, 7), dtype=np.float64)
258261
r2_mask = r_squared_mask(radius, image.ndim)[mask]
259262
cmask = cosmask(radius)[mask]
260263
smask = sinmask(radius)[mask]
261264
_numba_refine_2D_c(np.asarray(raw_image), np.asarray(image),
262265
radius[0], radius[1], coords, N,
263-
int(max_iterations), shift_thresh, break_thresh,
264-
image.shape[0], image.shape[1],
265-
mask_coordsY, mask_coordsX, mask_coordsY.shape[0],
266+
int(max_iterations), shift_thresh,
267+
image.shape[0], image.shape[1], mask_coordsY,
268+
mask_coordsX, mask_coordsY.shape[0],
266269
r2_mask, cmask, smask, results)
267270
else:
268-
mask_coordsY, mask_coordsX = np.asarray(mask.nonzero(), dtype=np.int16)
271+
mask_coordsY, mask_coordsX = np.asarray(mask.nonzero(), np.int16)
269272
results = np.empty((N, 8), dtype=np.float64)
270273
x2_masks = x_squared_masks(radius, image.ndim)
271274
y2_mask = image.ndim * x2_masks[0][mask]
@@ -275,8 +278,8 @@ def refine(raw_image, image, radius, coords, separation=0, max_iterations=10,
275278
_numba_refine_2D_c_a(np.asarray(raw_image), np.asarray(image),
276279
radius[0], radius[1], coords, N,
277280
int(max_iterations), shift_thresh,
278-
break_thresh, image.shape[0], image.shape[1],
279-
mask_coordsY, mask_coordsX, mask_coordsY.shape[0],
281+
image.shape[0], image.shape[1], mask_coordsY,
282+
mask_coordsX, mask_coordsY.shape[0],
280283
y2_mask, x2_mask, cmask, smask, results)
281284
else:
282285
raise ValueError("Available engines are 'python' and 'numba'")
@@ -312,13 +315,12 @@ def refine(raw_image, image, radius, coords, separation=0, max_iterations=10,
312315

313316
# (This is pure Python. A numba variant follows below.)
314317
def _refine(raw_image, image, radius, coords, max_iterations,
315-
shift_thresh, break_thresh, characterize, walkthrough):
316-
318+
shift_thresh, characterize, walkthrough):
319+
if not np.issubdtype(coords.dtype, np.int):
320+
raise ValueError('The coords array should be of integer datatype')
317321
ndim = image.ndim
318322
isotropic = np.all(radius[1:] == radius[:-1])
319-
mask = binary_mask(radius, ndim)
320-
slices = [[slice(c - rad, c + rad + 1) for c, rad in zip(coord, radius)]
321-
for coord in coords]
323+
mask = binary_mask(radius, ndim).astype(np.uint8)
322324

323325
# Declare arrays that we will fill iteratively through loop.
324326
N = coords.shape[0]
@@ -336,47 +338,25 @@ def _refine(raw_image, image, radius, coords, max_iterations,
336338
ogrid = np.ogrid[[slice(0, i) for i in mask.shape]] # for center of mass
337339
ogrid = [g.astype(float) for g in ogrid]
338340

339-
for feat in range(N):
340-
coord = coords[feat]
341-
342-
# Define the circular neighborhood of (x, y).
343-
rect = slices[feat]
344-
neighborhood = mask*image[rect]
345-
cm_n = _safe_center_of_mass(neighborhood, radius, ogrid)
346-
cm_i = cm_n - radius + coord # image coords
347-
allow_moves = True
341+
for feat, coord in enumerate(coords):
348342
for iteration in range(max_iterations):
343+
# Define the circular neighborhood of (x, y).
344+
rect = [slice(c - r, c + r + 1) for c, r in zip(coord, radius)]
345+
neighborhood = mask*image[rect]
346+
cm_n = _safe_center_of_mass(neighborhood, radius, ogrid)
347+
cm_i = cm_n - radius + coord # image coords
348+
349349
off_center = cm_n - radius
350350
logger.debug('off_center: %f', off_center)
351-
if np.all(np.abs(off_center) < break_thresh):
351+
if np.all(np.abs(off_center) < shift_thresh):
352352
break # Accurate enough.
353+
# If we're off by more than half a pixel in any direction, move..
354+
coord[off_center > shift_thresh] += 1
355+
coord[off_center < -shift_thresh] -= 1
356+
# Don't move outside the image!
357+
upper_bound = np.array(image.shape) - 1 - radius
358+
coord = np.clip(coord, radius, upper_bound).astype(int)
353359

354-
# If we're off by more than half a pixel in any direction, move.
355-
elif np.any(np.abs(off_center) > shift_thresh) & allow_moves:
356-
# In here, coord is an integer.
357-
new_coord = coord
358-
new_coord[off_center > shift_thresh] += 1
359-
new_coord[off_center < -shift_thresh] -= 1
360-
# Don't move outside the image!
361-
upper_bound = np.array(image.shape) - 1 - radius
362-
new_coord = np.clip(new_coord, radius, upper_bound).astype(int)
363-
# Update slice to shifted position.
364-
rect = [slice(c - rad, c + rad + 1)
365-
for c, rad in zip(new_coord, radius)]
366-
neighborhood = mask*image[rect]
367-
368-
# If we're off by less than half a pixel, interpolate.
369-
else:
370-
# Here, coord is a float. We are off the grid.
371-
neighborhood = ndimage.shift(neighborhood, -off_center,
372-
order=2, mode='constant', cval=0)
373-
new_coord = coord + off_center
374-
# Disallow any whole-pixels moves on future iterations.
375-
allow_moves = False
376-
377-
cm_n = _safe_center_of_mass(neighborhood, radius, ogrid) # neighborhood
378-
cm_i = cm_n - radius + new_coord # image coords
379-
coord = new_coord
380360
# matplotlib and ndimage have opposite conventions for xy <-> yx.
381361
final_coords[feat] = cm_i[..., ::-1]
382362

0 commit comments

Comments
 (0)