Skip to content

Commit ef1f26c

Browse files
authored
Merge pull request #67 from OmkarPathak/dev
Completed Quadtree implementation
2 parents 14e6b0f + ddc9f57 commit ef1f26c

File tree

2 files changed

+368
-134
lines changed

2 files changed

+368
-134
lines changed

pygorithm/data_structures/quadtree.py

+219-22
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
depth and bucket size.
77
"""
88
import inspect
9+
import math
10+
from collections import deque
911

1012
from pygorithm.geometry import (vector2, polygon2, rect2)
1113

@@ -25,7 +27,7 @@ def __init__(self, aabb):
2527
:param aabb: axis-aligned bounding box
2628
:type aabb: :class:`pygorithm.geometry.rect2.Rect2`
2729
"""
28-
pass
30+
self.aabb = aabb
2931

3032
def __repr__(self):
3133
"""
@@ -46,7 +48,7 @@ def __repr__(self):
4648
:returns: unambiguous representation of this quad tree entity
4749
:rtype: string
4850
"""
49-
pass
51+
return "quadtreeentity(aabb={})".format(repr(self.aabb))
5052

5153
def __str__(self):
5254
"""
@@ -67,7 +69,7 @@ def __str__(self):
6769
:returns: human readable representation of this entity
6870
:rtype: string
6971
"""
70-
pass
72+
return "entity(at {})".format(str(self.aabb))
7173

7274
class QuadTree(object):
7375
"""
@@ -129,7 +131,12 @@ def __init__(self, bucket_size, max_depth, location, depth = 0, entities = None)
129131
:param entities: the entities to initialize this quadtree with
130132
:type entities: list of :class:`.QuadTreeEntity` or None for empty list
131133
"""
132-
pass
134+
self.bucket_size = bucket_size
135+
self.max_depth = max_depth
136+
self.location = location
137+
self.depth = depth
138+
self.entities = entities if entities is not None else []
139+
self.children = None
133140

134141
def think(self, recursive = False):
135142
"""
@@ -145,7 +152,13 @@ def think(self, recursive = False):
145152
:param recursive: if `think(True)` should be called on :py:attr:`.children` (if there are any)
146153
:type recursive: bool
147154
"""
148-
pass
155+
if not self.children and self.depth < self.max_depth and len(self.entities) > self.bucket_size:
156+
self.split()
157+
158+
if recursive:
159+
if self.children:
160+
for child in self.children:
161+
child.think(True)
149162

150163
def split(self):
151164
"""
@@ -164,12 +177,43 @@ def split(self):
164177
165178
:raises ValueError: if :py:attr:`.children` is not empty
166179
"""
167-
pass
180+
if self.children:
181+
raise ValueError("cannot split twice")
182+
183+
_cls = type(self)
184+
def _cstr(r):
185+
return _cls(self.bucket_size, self.max_depth, r, self.depth + 1)
186+
187+
_halfwidth = self.location.width / 2
188+
_halfheight = self.location.height / 2
189+
_x = self.location.mincorner.x
190+
_y = self.location.mincorner.y
191+
192+
self.children = [
193+
_cstr(rect2.Rect2(_halfwidth, _halfheight, vector2.Vector2(_x, _y))),
194+
_cstr(rect2.Rect2(_halfwidth, _halfheight, vector2.Vector2(_x + _halfwidth, _y))),
195+
_cstr(rect2.Rect2(_halfwidth, _halfheight, vector2.Vector2(_x + _halfwidth, _y + _halfheight))),
196+
_cstr(rect2.Rect2(_halfwidth, _halfheight, vector2.Vector2(_x, _y + _halfheight))) ]
197+
198+
_newents = []
199+
for ent in self.entities:
200+
quad = self.get_quadrant(ent)
201+
202+
if quad < 0:
203+
_newents.append(ent)
204+
else:
205+
self.children[quad].entities.append(ent)
206+
self.entities = _newents
207+
208+
168209

169210
def get_quadrant(self, entity):
170211
"""
171212
Calculate the quadrant that the specified entity belongs to.
172213
214+
Touching a line is considered overlapping a line. Touching is
215+
determined using :py:meth:`math.isclose`
216+
173217
Quadrants are:
174218
175219
- -1: None (it overlaps 2 or more quadrants)
@@ -189,7 +233,48 @@ def get_quadrant(self, entity):
189233
:returns: quadrant
190234
:rtype: int
191235
"""
192-
pass
236+
237+
_aabb = entity.aabb
238+
_halfwidth = self.location.width / 2
239+
_halfheight = self.location.height / 2
240+
_x = self.location.mincorner.x
241+
_y = self.location.mincorner.y
242+
243+
if math.isclose(_aabb.mincorner.x, _x + _halfwidth):
244+
return -1
245+
if math.isclose(_aabb.mincorner.x + _aabb.width, _x + _halfwidth):
246+
return -1
247+
if math.isclose(_aabb.mincorner.y, _y + _halfheight):
248+
return -1
249+
if math.isclose(_aabb.mincorner.y + _aabb.height, _y + _halfheight):
250+
return -1
251+
252+
_leftside_isleft = _aabb.mincorner.x < _x + _halfwidth
253+
_rightside_isleft = _aabb.mincorner.x + _aabb.width < _x + _halfwidth
254+
255+
if _leftside_isleft != _rightside_isleft:
256+
return -1
257+
258+
_topside_istop = _aabb.mincorner.y < _y + _halfheight
259+
_botside_istop = _aabb.mincorner.y + _aabb.height < _y + _halfheight
260+
261+
if _topside_istop != _botside_istop:
262+
return -1
263+
264+
_left = _leftside_isleft
265+
_top = _topside_istop
266+
267+
if _left:
268+
if _top:
269+
return 0
270+
else:
271+
return 3
272+
else:
273+
if _top:
274+
return 1
275+
else:
276+
return 2
277+
193278

194279
def insert_and_think(self, entity):
195280
"""
@@ -204,7 +289,14 @@ def insert_and_think(self, entity):
204289
:param entity: the entity to insert
205290
:type entity: :class:`.QuadTreeEntity`
206291
"""
207-
pass
292+
if not self.children and len(self.entities) == self.bucket_size and self.depth < self.max_depth:
293+
self.split()
294+
295+
quad = self.get_quadrant(entity) if self.children else -1
296+
if quad < 0:
297+
self.entities.append(entity)
298+
else:
299+
self.children[quad].insert_and_think(entity)
208300

209301
def retrieve_collidables(self, entity, predicate = None):
210302
"""
@@ -227,19 +319,71 @@ def retrieve_collidables(self, entity, predicate = None):
227319
:returns: potential collidables (never `None)
228320
:rtype: list of :class:`.QuadTreeEntity`
229321
"""
230-
pass
322+
result = list(filter(predicate, self.entities))
323+
quadrant = self.get_quadrant(entity) if self.children else -1
324+
325+
if quadrant >= 0:
326+
result.extend(self.children[quadrant].retrieve_collidables(entity, predicate))
327+
elif self.children:
328+
for child in self.children:
329+
touching, overlapping, alwaysNone = rect2.Rect2.find_intersection(entity.aabb, child.location, find_mtv=False)
330+
if touching or overlapping:
331+
result.extend(child.retrieve_collidables(entity, predicate))
332+
333+
return result
334+
335+
def _iter_helper(self, pred):
336+
"""
337+
Calls pred on each child and childs child, iteratively.
338+
339+
pred takes one positional argument (the child).
340+
341+
:param pred: function to call
342+
:type pred: `types.FunctionType`
343+
"""
344+
345+
_stack = deque()
346+
_stack.append(self)
231347

348+
while _stack:
349+
curr = _stack.pop()
350+
if curr.children:
351+
for child in curr.children:
352+
_stack.append(child)
353+
354+
pred(curr)
355+
232356
def find_entities_per_depth(self):
233357
"""
234358
Calculate the number of nodes and entities at each depth level in this
235359
quad tree. Only returns for depth levels at or equal to this node.
236360
237361
This is implemented iteratively. See :py:meth:`.__str__` for usage example.
238362
239-
:returns: dict of depth level to (number of nodes, number of entities)
240-
:rtype: dict int: (int, int)
363+
:returns: dict of depth level to number of entities
364+
:rtype: dict int: int
365+
"""
366+
367+
container = { 'result': {} }
368+
def handler(curr, container=container):
369+
container['result'][curr.depth] = container['result'].get(curr.depth, 0) + len(curr.entities)
370+
self._iter_helper(handler)
371+
372+
return container['result']
373+
374+
def find_nodes_per_depth(self):
375+
"""
376+
Calculate the number of nodes at each depth level.
377+
378+
This is implemented iteratively. See :py:meth:`.__str__` for usage example.
379+
380+
:returns: dict of depth level to number of nodes
381+
:rtype: dict int: int
241382
"""
242-
pass
383+
384+
nodes_per_depth = {}
385+
self._iter_helper(lambda curr, d=nodes_per_depth: d.update({ (curr.depth, d.get(curr.depth, 0) + 1) }))
386+
return nodes_per_depth
243387

244388
def sum_entities(self, entities_per_depth=None):
245389
"""
@@ -254,7 +398,15 @@ def sum_entities(self, entities_per_depth=None):
254398
:returns: number of entities in this and child nodes
255399
:rtype: int
256400
"""
257-
pass
401+
if entities_per_depth is not None:
402+
return sum(entities_per_depth.values())
403+
404+
container = { 'result': 0 }
405+
def handler(curr, container=container):
406+
container['result'] += len(curr.entities)
407+
self._iter_helper(handler)
408+
409+
return container['result']
258410

259411
def calculate_avg_ents_per_leaf(self):
260412
"""
@@ -270,7 +422,13 @@ def calculate_avg_ents_per_leaf(self):
270422
:returns: average number of entities at each leaf node
271423
:rtype: :class:`numbers.Number`
272424
"""
273-
pass
425+
container = { 'leafs': 0, 'total': 0 }
426+
def handler(curr, container=container):
427+
if not curr.children:
428+
container['leafs'] += 1
429+
container['total'] += len(curr.entities)
430+
self._iter_helper(handler)
431+
return container['total'] / container['leafs']
274432

275433
def calculate_weight_misplaced_ents(self, sum_entities=None):
276434
"""
@@ -293,11 +451,40 @@ def calculate_weight_misplaced_ents(self, sum_entities=None):
293451
:returns: weight of misplaced entities
294452
:rtype: :class:`numbers.Number`
295453
"""
296-
pass
297454

455+
# this iteration requires more context than _iter_helper provides.
456+
# we must keep track of parents as well in order to correctly update
457+
# weights
458+
459+
nonleaf_to_max_child_depth_dict = {}
460+
461+
# stack will be (quadtree, list (of parents) or None)
462+
_stack = deque()
463+
_stack.append((self, None))
464+
while _stack:
465+
curr, parents = _stack.pop()
466+
if parents:
467+
for p in parents:
468+
nonleaf_to_max_child_depth_dict[p] = max(nonleaf_to_max_child_depth_dict.get(p, 0), curr.depth)
469+
470+
if curr.children:
471+
new_parents = list(parents) if parents else []
472+
new_parents.append(curr)
473+
for child in curr.children:
474+
_stack.append((child, new_parents))
475+
476+
_weight = 0
477+
for nonleaf, maxchilddepth in nonleaf_to_max_child_depth_dict.items():
478+
_weight += len(nonleaf.entities) * 4 * (maxchilddepth - nonleaf.depth)
479+
480+
_sum = self.sum_entities() if sum_entities is None else sum_entities
481+
return _weight / _sum
482+
298483
def __repr__(self):
299484
"""
300-
Create an unambiguous, recursive representation of this quad tree.
485+
Create an unambiguous representation of this quad tree.
486+
487+
This is implemented iteratively.
301488
302489
Example:
303490
@@ -308,19 +495,18 @@ def __repr__(self):
308495
309496
# create a tree with a up to 2 entities in a bucket that
310497
# can have a depth of up to 5.
311-
_tree = quadtree.QuadTree(2, 5, rect2.Rect2(100, 100))
498+
_tree = quadtree.QuadTree(1, 5, rect2.Rect2(100, 100))
312499
313500
# add a few entities to the tree
314501
_tree.insert_and_think(quadtree.QuadTreeEntity(rect2.Rect2(2, 2, vector2.Vector2(5, 5))))
315502
_tree.insert_and_think(quadtree.QuadTreeEntity(rect2.Rect2(2, 2, vector2.Vector2(95, 5))))
316503
317-
# prints quadtree(bucket_size=2, max_depth=5, location=rect2(width=100, height=100, mincorner=vector2(x=0, y=0)), depth=0, entities=[], children=[ quadtree(bucket_size=2, max_depth=5, location=rect2(width=50, height=50, mincorner=vector2(x=0, y=0)), depth=1, entities=[ quadtreeentity(aabb=rect2(width=2, height=2, mincorner=vector2(x=5, y=5))) ], children=[]), quadtree(bucket_size=2, max_depth=5, location=rect2(width=50, height=50, mincorner=vector2(x=50, y=0)), depth=1, entities=[ quadtreeentity(aabb=rect2(width=2, height=2, mincorner=vector2(x=95, y=5))) ], children=[]), quadtree(bucket_size=2, max_depth=5, location=rect2(width=50, height=50, mincorner=vector2(x=50, y=50)), depth=1, entities=[], children=[]), quadtree(bucket_size=2, max_depth=5, location=rect2(width=50, height=50, mincorner=vector2(x=0, y=50)), depth=1, entities=[], children=[]) ])
318-
print(repr(_tree))
504+
# prints quadtree(bucket_size=1, max_depth=5, location=rect2(width=100, height=100, mincorner=vector2(x=0, y=0)), depth=0, entities=[], children=[quadtree(bucket_size=1, max_depth=5, location=rect2(width=50.0, height=50.0, mincorner=vector2(x=0, y=0)), depth=1, entities=[quadtreeentity(aabb=rect2(width=2, height=2, mincorner=vector2(x=5, y=5)))], children=None), quadtree(bucket_size=1, max_depth=5, location=rect2(width=50.0, height=50.0, mincorner=vector2(x=50.0, y=0)), depth=1, entities=[quadtreeentity(aabb=rect2(width=2, height=2, mincorner=vector2(x=95, y=5)))], children=None), quadtree(bucket_size=1, max_depth=5, location=rect2(width=50.0, height=50.0, mincorner=vector2(x=50.0, y=50.0)), depth=1, entities=[], children=None), quadtree(bucket_size=1, max_depth=5, location=rect2(width=50.0, height=50.0, mincorner=vector2(x=0, y=50.0)), depth=1, entities=[], children=None)])
319505
320506
:returns: unambiguous, recursive representation of this quad tree
321507
:rtype: string
322508
"""
323-
pass
509+
return "quadtree(bucket_size={}, max_depth={}, location={}, depth={}, entities={}, children={})".format(self.bucket_size, self.max_depth, repr(self.location), self.depth, self.entities, self.children)
324510

325511
def __str__(self):
326512
"""
@@ -347,12 +533,23 @@ def __str__(self):
347533
_tree.insert_and_think(quadtree.QuadTreeEntity(rect2.Rect2(2, 2, vector2.Vector2(5, 5))))
348534
_tree.insert_and_think(quadtree.QuadTreeEntity(rect2.Rect2(2, 2, vector2.Vector2(95, 5))))
349535
350-
# prints quadtree(at rect(100x100 at <0, 0>) with 0 entities here (2 in total); (nodes, entities) per depth: [ 0: (1, 0), 1: (4, 2) ] (max depth: 5), avg ent/leaf: 0.5 (target 2), misplaced weight = 0 (0 best, >1 bad))
536+
# prints quadtree(at rect(100x100 at <0, 0>) with 0 entities here (2 in total); (nodes, entities) per depth: [ 0: (1, 0), 1: (4, 2) ] (allowed max depth: 5, actual: 1), avg ent/leaf: 0.5 (target 1), misplaced weight 0.0 (0 best, >1 bad)
537+
print(_tree)
351538
352539
:returns: human-readable representation of this quad tree
353540
:rtype: string
354541
"""
355-
pass
542+
543+
nodes_per_depth = self.find_nodes_per_depth()
544+
_ents_per_depth = self.find_entities_per_depth()
545+
546+
_nodes_ents_per_depth_str = "[ {} ]".format(', '.join("{}: ({}, {})".format(dep, nodes_per_depth[dep], _ents_per_depth[dep]) for dep in nodes_per_depth.keys()))
547+
548+
_sum = self.sum_entities(entities_per_depth=_ents_per_depth)
549+
_max_depth = max(_ents_per_depth.keys())
550+
_avg_ent_leaf = self.calculate_avg_ents_per_leaf()
551+
_mispl_weight = self.calculate_weight_misplaced_ents(sum_entities=_sum)
552+
return "quadtree(at {} with {} entities here ({} in total); (nodes, entities) per depth: {} (allowed max depth: {}, actual: {}), avg ent/leaf: {} (target {}), misplaced weight {} (0 best, >1 bad)".format(self.location, len(self.entities), _sum, _nodes_ents_per_depth_str, self.max_depth, _max_depth, _avg_ent_leaf, self.bucket_size, _mispl_weight)
356553

357554
@staticmethod
358555
def get_code():

0 commit comments

Comments
 (0)