from __future__ import annotations
from math import trunc
from typing import (
List,
Set,
Dict,
Generic,
)
from arcade.sprite.base import BasicSprite
from arcade.types import Point, IPoint, Rect
from arcade.sprite import SpriteType
class SpatialHash(Generic[SpriteType]):
    """
    Structure for fast collision checking with sprites.

    The world is divided into square cells ``cell_size`` units wide. Each
    sprite is stored in every cell its bounding box overlaps, so a query
    only has to look at the cells covering the query area instead of at
    every sprite.

    See: https://www.gamedev.net/articles/programming/general-and-gameplay-programming/spatial-hashing-r2697/

    :param cell_size: Size (width and height) of the cells in the spatial hash.
                      Must be a positive integer.
    :raises ValueError: If ``cell_size`` is not an integer or is not
                        greater than 0.
    """

    def __init__(self, cell_size: int) -> None:
        # Check the type before comparing: ``"abc" <= 0`` would raise a
        # TypeError instead of the documented ValueError.
        if not isinstance(cell_size, int):
            raise ValueError("cell_size must be an integer")
        if cell_size <= 0:
            raise ValueError("cell_size must be greater than 0")

        self.cell_size = cell_size
        # Buckets of sprites per cell, keyed by cell coordinates.
        self.contents: Dict[IPoint, Set[SpriteType]] = {}
        # All the buckets a sprite is in.
        # This is used to remove a sprite from the spatial hash without
        # re-hashing it (its position may have changed since it was added).
        self.buckets_for_sprite: Dict[SpriteType, List[Set[SpriteType]]] = {}

    def hash(self, point: IPoint) -> IPoint:
        """
        Convert world coordinates to cell coordinates.

        :param point: An integer ``(x, y)`` position in world coordinates.
        :return: The ``(x, y)`` coordinates of the cell containing ``point``.
        """
        return (
            point[0] // self.cell_size,
            point[1] // self.cell_size,
        )

    def _cells_in_area(
        self, left: float, right: float, bottom: float, top: float
    ) -> List[IPoint]:
        """
        Return the coordinates of every cell overlapped by an axis-aligned area.

        Bounds are truncated to integers before hashing. This is the single
        place that maps an area to cells, so insertions and queries always
        agree on which cells an area covers.
        """
        min_x, min_y = self.hash((trunc(left), trunc(bottom)))
        max_x, max_y = self.hash((trunc(right), trunc(top)))
        return [
            (i, j)
            for i in range(min_x, max_x + 1)
            for j in range(min_y, max_y + 1)
        ]

    def _sprites_in_cells(self, cells: List[IPoint]) -> Set[SpriteType]:
        """
        Return the union of the sprites stored in the given cells.

        Uses ``dict.get`` so that a read-only query never inserts empty
        buckets into ``self.contents``.
        """
        found: Set[SpriteType] = set()
        for cell in cells:
            bucket = self.contents.get(cell)
            if bucket:
                found.update(bucket)
        return found

    def reset(self):
        """
        Clear all the sprites from the spatial hash.
        """
        self.contents.clear()
        self.buckets_for_sprite.clear()

    def add(self, sprite: SpriteType) -> None:
        """
        Add a sprite to the spatial hash.

        The sprite is placed in every cell overlapped by its bounding box
        (``left``, ``right``, ``bottom``, ``top``). If the sprite is already
        in the spatial hash, it is removed first. Without this, the sprite
        would be left behind in the cells it previously occupied.

        :param sprite: The sprite to add
        """
        if sprite in self.buckets_for_sprite:
            self.remove(sprite)

        cells = self._cells_in_area(
            sprite.left, sprite.right, sprite.bottom, sprite.top
        )
        buckets: List[Set[SpriteType]] = []
        for cell in cells:
            # Create the bucket on demand and add the sprite to it.
            bucket = self.contents.setdefault(cell, set())
            bucket.add(sprite)
            buckets.append(bucket)

        # Keep track of which buckets the sprite is in.
        self.buckets_for_sprite[sprite] = buckets

    def move(self, sprite: SpriteType) -> None:
        """
        Shortcut to remove and re-add a sprite.

        :param sprite: The sprite to move
        """
        self.remove(sprite)
        self.add(sprite)

    def remove(self, sprite: SpriteType) -> None:
        """
        Remove a sprite from the spatial hash.

        :param sprite: The sprite to remove
        :raises KeyError: If the sprite is not in the spatial hash.
        """
        # pop() both fetches and forgets the tracked buckets, raising
        # KeyError for an unknown sprite.
        for bucket in self.buckets_for_sprite.pop(sprite):
            bucket.remove(sprite)

    def get_sprites_near_sprite(self, sprite: BasicSprite) -> Set[SpriteType]:
        """
        Get all the sprites that are in the same buckets as the given sprite.

        :param sprite: The sprite to check
        :return: A new set of close-by sprites
        """
        cells = self._cells_in_area(
            sprite.left, sprite.right, sprite.bottom, sprite.top
        )
        return self._sprites_in_cells(cells)

    def get_sprites_near_point(self, point: Point) -> Set[SpriteType]:
        """
        Return sprites in the same bucket as the given point.

        :param point: The point to check
        :return: A new set of close-by sprites
        """
        cell = self.hash((trunc(point[0]), trunc(point[1])))
        # Return a copy so callers can't mutate the internal bucket, and use
        # get() so the query doesn't insert an empty bucket.
        return set(self.contents.get(cell, ()))

    def get_sprites_near_rect(self, rect: Rect) -> Set[SpriteType]:
        """
        Return sprites in the same buckets as the given rectangle.

        :param rect: The rectangle to check (left, right, bottom, top)
        :return: A new set of sprites in the buckets covered by the rectangle
        """
        left, right, bottom, top = rect
        return self._sprites_in_cells(self._cells_in_area(left, right, bottom, top))

    @property
    def count(self) -> int:
        """Return the number of sprites in the spatial hash"""
        # NOTE: We should really implement __len__, but that would change the
        # truthiness of instances: ``if spatial_hash:`` would be False when
        # the hash is empty. For backwards compatibility, keep this property.
        return len(self.buckets_for_sprite)