# Source code for arcade.sprite_list.spatial_hash

from __future__ import annotations

from math import trunc
from typing import Generic

from arcade.sprite import SpriteType
from arcade.sprite.base import BasicSprite
from arcade.types import IPoint, Point
from arcade.types.rect import Rect


class SpatialHash(Generic[SpriteType]):
    """A data structure best for collision checks with non-moving sprites.

    It subdivides space into a grid of squares, each with sides of length
    :py:attr:`cell_size`. Moving a sprite from one place to another is the
    same as removing and adding it. Although moving a few can be okay, it
    can quickly add up and slow down a game.

    Args:
        cell_size: The width and height of each square in the grid.
    """

    def __init__(self, cell_size: int) -> None:
        # Sanity check the cell size
        if not isinstance(cell_size, int):
            raise TypeError("cell_size must be an int (integer)")
        if cell_size <= 0:
            raise ValueError("cell_size must be greater than 0")

        self.cell_size: int = cell_size
        """How big each grid cell is on each side.

        .. warning:: Do not change this after creation!

        Since each cell is a square, they're used as both the width and height.
        """

        # Buckets of sprites per cell. Only add() creates entries here;
        # the query methods never insert buckets.
        self.contents: dict[IPoint, set[SpriteType]] = {}
        # All the buckets a sprite is in.
        # This is used to remove a sprite from the spatial hash.
        self.buckets_for_sprite: dict[SpriteType, list[set[SpriteType]]] = {}

    def hash(self, point: IPoint) -> IPoint:
        """Convert world coordinates to cell coordinates.

        Args:
            point: An integer ``(x, y)`` position in world space.

        Returns:
            The ``(column, row)`` of the cell containing ``point``.
        """
        return (
            point[0] // self.cell_size,
            point[1] // self.cell_size,
        )

    def _cell_ranges(
        self, left: float, bottom: float, right: float, top: float
    ) -> tuple[range, range]:
        """Return the column and row ranges of cells covering a world-space box.

        Coordinates are truncated to ints before hashing. Each returned range
        runs from the cell of the min corner to the cell of the max corner,
        inclusive.
        """
        min_col, min_row = self.hash((trunc(left), trunc(bottom)))
        max_col, max_row = self.hash((trunc(right), trunc(top)))
        return range(min_col, max_col + 1), range(min_row, max_row + 1)

    def _sprites_in_region(
        self, left: float, bottom: float, right: float, top: float
    ) -> set[SpriteType]:
        """Collect every sprite stored in the cells covering a world-space box.

        Uses ``dict.get`` instead of ``setdefault`` so that probing empty
        cells does not insert new empty buckets into :py:attr:`contents`.
        """
        columns, rows = self._cell_ranges(left, bottom, right, top)
        contents = self.contents
        found: set[SpriteType] = set()
        for i in columns:
            for j in rows:
                bucket = contents.get((i, j))
                if bucket:
                    found.update(bucket)
        return found

    def reset(self) -> None:
        """Clear all the sprites from the spatial hash."""
        self.contents.clear()
        self.buckets_for_sprite.clear()

    def add(self, sprite: SpriteType) -> None:
        """
        Add a sprite to the spatial hash.

        The sprite is placed in every cell its bounding box touches. If the
        sprite is already in the hash, it is first removed from its old cells.

        Args:
            sprite: The sprite to add
        """
        # Without this, re-adding a tracked sprite would overwrite its bucket
        # list and leave it stranded in its old cells, unreachable by remove().
        if sprite in self.buckets_for_sprite:
            self.remove(sprite)

        columns, rows = self._cell_ranges(
            sprite.left, sprite.bottom, sprite.right, sprite.top
        )
        contents = self.contents
        buckets: list[set[SpriteType]] = []

        # Iterate over the rectangular region adding the sprite to each cell
        for i in columns:
            for j in rows:
                bucket = contents.get((i, j))
                if bucket is None:
                    bucket = contents[(i, j)] = set()
                bucket.add(sprite)
                # Collect all the buckets we added to
                buckets.append(bucket)

        # Keep track of which buckets the sprite is in
        self.buckets_for_sprite[sprite] = buckets

    def move(self, sprite: SpriteType) -> None:
        """
        Shortcut to remove and re-add a sprite.

        Args:
            sprite: The sprite to move

        Raises:
            KeyError: If the sprite is not in the spatial hash.
        """
        self.remove(sprite)
        self.add(sprite)

    def remove(self, sprite: SpriteType) -> None:
        """
        Remove a Sprite.

        Args:
            sprite: The sprite to remove

        Raises:
            KeyError: If the sprite is not in the spatial hash.
        """
        # Remove the sprite from all the buckets it is in
        for bucket in self.buckets_for_sprite[sprite]:
            bucket.remove(sprite)

        # Delete the sprite from the bucket tracker
        del self.buckets_for_sprite[sprite]

    def get_sprites_near_sprite(self, sprite: BasicSprite) -> set[SpriteType]:
        """
        Get all the sprites that are in the same buckets as the given sprite.

        Args:
            sprite: The sprite to check

        Returns:
            A new set of nearby sprites. It includes ``sprite`` itself if that
            sprite is in the hash.
        """
        return self._sprites_in_region(
            sprite.left, sprite.bottom, sprite.right, sprite.top
        )

    def get_sprites_near_point(self, point: Point) -> set[SpriteType]:
        """
        Return sprites in the same bucket as the given point.

        Args:
            point: The point to check

        Returns:
            A new set (a copy, safe to mutate) of the sprites in that cell.
        """
        hash_point = self.hash((trunc(point[0]), trunc(point[1])))
        bucket = self.contents.get(hash_point)
        # Return a copy of the set; never insert an empty bucket for a miss.
        return set(bucket) if bucket else set()

    def get_sprites_near_rect(self, rect: Rect) -> set[SpriteType]:
        """
        Return sprites in the same buckets as the given rectangle.

        Args:
            rect: The rectangle to check (left, right, bottom, top)

        Returns:
            A new set of the sprites in every cell the rectangle covers.
        """
        left, right, bottom, top = rect.lrbt
        return self._sprites_in_region(left, bottom, right, top)

    @property
    def count(self) -> int:
        """Return the number of sprites in the spatial hash"""
        # NOTE: We should really implement __len__ but this means
        # changing the truthiness of the class instance.
        # if spatial_hash will be False if it is empty.
        # For backwards compatibility, we'll keep it as a property.
        return len(self.buckets_for_sprite)