Compute Shader Tutorial
With a compute shader, you can use the GPU to run highly parallel calculations far faster than the CPU alone, potentially thousands of times faster for workloads like this one.
In this example, we will simulate a star field using an ‘N-body simulation’. Each star is affected by every other star’s gravity. For 1,000 stars, this means we have 1,000 x 1,000 = 1,000,000 calculations to perform for each frame. The video has 65,000 stars, requiring 4.2 billion gravity force calculations per frame. On high-end hardware it can still run at 60 fps!
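The counts above follow from every star reading every star, so the work grows with the square of the star count. A minimal sketch of that arithmetic, where the function name is made up for illustration and 40,000 is the star count used by the Python program later on this page:

```python
def gravity_calculations_per_frame(num_stars: int) -> int:
    """Each star loops over every star in the buffer, so the count is N * N."""
    return num_stars * num_stars


for n in (1_000, 40_000, 65_000):
    count = gravity_calculations_per_frame(n)
    print(f"{n:>6,} stars -> {count:>13,} calculations per frame")
```

Doubling the number of stars roughly quadruples the work, which is why this kind of simulation benefits so much from running on the GPU.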
How does this work? There are three major parts to this program:
The Python code, which glues everything together.
The visualization shaders, which let us see the data.
The compute shader, which moves everything.
Visualization Shaders
There are multiple visualization shaders, which operate in this order:
The Python program creates a shader storage buffer object (SSBO) of floating point numbers. This buffer has the x, y, z and radius of each star stored in in_vertex. It also stores the color in in_color.
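To make the buffer layout concrete, here is a small sketch that packs two stars the way the Python program further down does: four floats for position and radius, four floats for velocity (which the visualization shaders skip), and four floats for color. The helper name pack_star and the constants are hypothetical and only serve the illustration.

```python
from array import array

FLOATS_PER_STAR = 12  # vec4 position/radius + vec4 velocity + vec4 color
BYTES_PER_FLOAT = array("f").itemsize  # 4 on common platforms


def pack_star(x, y, radius, color=(1.0, 1.0, 1.0, 1.0)):
    """Return the 12 floats for one star, in buffer order."""
    return [
        x, y, 0.0, radius,    # read by the vertex shader as in_vertex (z is padding)
        0.0, 0.0, 0.0, 0.0,   # velocity, only used by the compute shader
        *color,               # read by the vertex shader as in_color
    ]


stars = array("f")
stars.extend(pack_star(100.0, 200.0, 6.0))
stars.extend(pack_star(300.0, 400.0, 6.0))

stride = FLOATS_PER_STAR * BYTES_PER_FLOAT
print(f"{len(stars)} floats total, {stride} bytes per star")

# The second star starts one full stride into the buffer.
second = stars[FLOATS_PER_STAR:2 * FLOATS_PER_STAR]
print("second star position/radius:", list(second[:4]))
```

An array like this is what gets handed to the GPU as the initial contents of the SSBO.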
The vertex shader doesn’t do much more than separate out the radius variable from the group of floats used to store position.
```glsl
#version 330

in vec4 in_vertex;
in vec4 in_color;

out vec2 vertex_pos;
out float vertex_radius;
out vec4 vertex_color;

void main()
{
    vertex_pos = in_vertex.xy;
    vertex_radius = in_vertex.w;
    vertex_color = in_color;
}
```
The geometry shader converts the single point (which we can’t render) to a square, which we can render. It turns the one point into the four corners of a quad.

```glsl
#version 330

layout (points) in;
layout (triangle_strip, max_vertices = 4) out;

// Use arcade's global projection UBO
uniform Projection {
    uniform mat4 matrix;
} proj;

in vec2 vertex_pos[];
in vec4 vertex_color[];
in float vertex_radius[];

out vec2 g_uv;
out vec3 g_color;

void main() {
    vec2 center = vertex_pos[0];
    vec2 hsize = vec2(vertex_radius[0]);

    g_color = vertex_color[0].rgb;

    gl_Position = proj.matrix * vec4(vec2(-hsize.x, hsize.y) + center, 0.0, 1.0);
    g_uv = vec2(0, 1);
    EmitVertex();

    gl_Position = proj.matrix * vec4(vec2(-hsize.x, -hsize.y) + center, 0.0, 1.0);
    g_uv = vec2(0, 0);
    EmitVertex();

    gl_Position = proj.matrix * vec4(vec2(hsize.x, hsize.y) + center, 0.0, 1.0);
    g_uv = vec2(1, 1);
    EmitVertex();

    gl_Position = proj.matrix * vec4(vec2(hsize.x, -hsize.y) + center, 0.0, 1.0);
    g_uv = vec2(1, 0);
    EmitVertex();

    EndPrimitive();
}
```
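The point-to-quad expansion can be mirrored on the CPU to see which corners come out. This is only a sketch: quad_corners is a hypothetical helper, the projection matrix is left out, and the corner names in the comments assume the y axis points up.

```python
def quad_corners(center, radius):
    """Return (position, uv) pairs in the order the geometry shader emits them."""
    cx, cy = center
    return [
        ((cx - radius, cy + radius), (0, 1)),  # upper left
        ((cx - radius, cy - radius), (0, 0)),  # lower left
        ((cx + radius, cy + radius), (1, 1)),  # upper right
        ((cx + radius, cy - radius), (1, 0)),  # lower right
    ]


for position, uv in quad_corners((10.0, 20.0), 6.0):
    print(position, uv)
```

Emitted as a triangle strip, these four vertices form two triangles that cover a square whose side is twice the star's radius.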
The fragment shader runs for each pixel. It produces the soft glow effect of the star, and rounds off the quad into a circle.
```glsl
#version 330

in vec2 g_uv;
in vec3 g_color;

out vec4 out_color;

void main()
{
    float l = length(vec2(0.5, 0.5) - g_uv.xy);
    if (l > 0.5)
    {
        discard;
    }
    float alpha;
    if (l == 0.0)
        alpha = 1.0;
    else
        alpha = min(1.0, 0.60 - l * 2.0);

    vec3 c = g_color.rgb;
    // c.xy += v_uv.xy * 0.05;
    // c.xy += v_pos.xy * 0.75;
    out_color = vec4(c, alpha);
}
```
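To see how the glow falls off, the same alpha formula can be evaluated on the CPU. The sketch below copies the arithmetic from the fragment shader; star_alpha is a hypothetical name, and None stands in for a discarded fragment.

```python
import math


def star_alpha(u, v):
    """Return the alpha the fragment shader computes, or None if it discards."""
    l = math.hypot(0.5 - u, 0.5 - v)  # distance from the quad's center in UV space
    if l > 0.5:
        return None  # outside the circle: the corner of the quad is cut away
    if l == 0.0:
        return 1.0
    return min(1.0, 0.60 - l * 2)


for u in (0.5, 0.6, 0.7, 0.8, 1.0):
    print(f"u={u:.1f} -> alpha={star_alpha(u, 0.5)}")
print("corner ->", star_alpha(0.0, 0.0))
```

The computed alpha drops below zero once the distance exceeds 0.3. I would expect a typical normalized color buffer to clamp that to fully transparent, so the visible glow should be noticeably smaller than the quad itself.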
Compute Shaders
This program uses two buffers. The input buffer holds all our current data. We perform calculations on that data and write the results to the output buffer. We then swap the buffers, so the output of one frame becomes the input to the next.

```glsl
#version 430

// Set up our compute groups.
// The Python program replaces COMPUTE_SIZE_X and COMPUTE_SIZE_Y before compiling.
layout(local_size_x=COMPUTE_SIZE_X, local_size_y=COMPUTE_SIZE_Y) in;

// Input uniforms go here if you need them.
// Some examples:
//uniform vec2 screen_size;
//uniform vec2 force;
//uniform float frame_time;

// Structure of the ball data
struct Ball
{
    vec4 pos;
    vec4 vel;
    vec4 color;
};

// Input buffer
layout(std430, binding=0) buffer balls_in
{
    Ball balls[];
} In;

// Output buffer
layout(std430, binding=1) buffer balls_out
{
    Ball balls[];
} Out;

void main()
{
    int curBallIndex = int(gl_GlobalInvocationID.x);

    // More invocations may be launched than there are balls,
    // so skip any index past the end of the buffer.
    if (curBallIndex >= In.balls.length())
        return;

    Ball in_ball = In.balls[curBallIndex];

    vec4 p = in_ball.pos;
    vec4 v = in_ball.vel;

    // Move the ball according to its current velocity
    p.xy += v.xy;

    // Calculate the new velocity based on all the other bodies
    for (int i = 0; i < In.balls.length(); i++) {
        // If enabled, this will keep the star from calculating gravity on itself.
        // However, it does slow down the calculations to do this check.
        // if (i == curBallIndex)
        //     continue;

        // Calculate distance squared
        float dist = distance(In.balls[i].pos.xy, p.xy);
        float distanceSquared = dist * dist;

        // If stars get too close, they fling into never-never land.
        // So cap the force. Despite its name, minDistance is the upper
        // limit applied to gravityStrength / distanceSquared.
        float minDistance = 0.02;
        float gravityStrength = 0.3;
        float simulationSpeed = 0.002;
        float force = min(minDistance, gravityStrength / distanceSquared) * -simulationSpeed;

        vec2 diff = p.xy - In.balls[i].pos.xy;
        // We should normalize this I think, but it doesn't work.
        // diff = normalize(diff);
        vec2 delta_v = diff * force;
        v.xy += delta_v;
    }

    Ball out_ball;
    out_ball.pos = p;
    out_ball.vel = v;
    out_ball.color = in_ball.color;

    Out.balls[curBallIndex] = out_ball;
}
```
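For comparison, here is a plain-Python sketch of what one invocation of the compute shader does, run in a loop over every ball. It is a slow CPU reference for reading and experimenting, not the code the tutorial runs. The function step and the tuple layout (x, y, vx, vy) are assumptions made for this illustration; the constants are copied from the shader.

```python
MAX_FORCE = 0.02          # called minDistance in the shader
GRAVITY_STRENGTH = 0.3
SIMULATION_SPEED = 0.002


def step(balls_in):
    """Read balls_in, return a new list of (x, y, vx, vy); the input is not modified."""
    balls_out = []
    for x, y, vx, vy in balls_in:
        # Move the ball by its current velocity
        x, y = x + vx, y + vy
        # Accumulate the pull of every ball in the input buffer
        for ox, oy, _, _ in balls_in:
            dx, dy = x - ox, y - oy
            dist_sq = dx * dx + dy * dy
            if dist_sq == 0.0:
                # The difference vector is zero, so this ball contributes nothing.
                # Skipping it also avoids a ZeroDivisionError in Python.
                continue
            force = min(MAX_FORCE, GRAVITY_STRENGTH / dist_sq) * -SIMULATION_SPEED
            vx += dx * force
            vy += dy * force
        balls_out.append((x, y, vx, vy))
    return balls_out


# Two balls at rest, one unit apart. Rebinding `current` each frame plays
# the role of swapping the input and output buffers.
current = [(0.0, 0.0, 0.0, 0.0), (1.0, 0.0, 0.0, 0.0)]
for frame in range(3):
    current = step(current)
    print(frame, current)
```

Because every ball reads only from the input list and writes only to the output list, the order in which balls are processed does not matter. That independence is what lets the GPU run all of them in parallel.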
Python Program
Read through the code here. I’ve tried hard to explain all the parts in the comments.

```python
"""
Compute shader with buffers
"""
import random
from array import array

import arcade
from arcade.gl import BufferDescription

# Window dimensions
WINDOW_WIDTH = 2300
WINDOW_HEIGHT = 1300

# Size of performance graphs
GRAPH_WIDTH = 200
GRAPH_HEIGHT = 120
GRAPH_MARGIN = 5


class MyWindow(arcade.Window):

    def __init__(self):
        # Call parent constructor
        # Ask for OpenGL 4.3 context, as we need that for compute shader support.
        super().__init__(WINDOW_WIDTH, WINDOW_HEIGHT,
                         "Compute Shader",
                         gl_version=(4, 3),
                         resizable=True)
        self.center_window()

        # --- Class instance variables

        # Number of balls to move
        self.num_balls = 40000

        # Size of one compute work group (local_size_x and local_size_y
        # in the shader). The same values are also passed to run() below
        # as the number of work groups to launch.
        self.group_x = 256
        self.group_y = 1

        # --- Create buffers

        # Format of the buffer data.
        # 4f = position and size -> x, y, z, radius
        # 4x4 = Four floats used for calculating velocity. Not needed for visualization.
        # 4f = color -> rgba
        buffer_format = "4f 4x4 4f"
        # Generate the initial data that we will put in buffer 1.
        initial_data = self.gen_initial_data()

        # Create data buffers for the compute shader
        # We ping-pong render between these two buffers
        # ssbo = shader storage buffer object
        self.ssbo_1 = self.ctx.buffer(data=array('f', initial_data))
        self.ssbo_2 = self.ctx.buffer(reserve=self.ssbo_1.size)

        # Attribute variable names for the vertex shader
        attributes = ["in_vertex", "in_color"]
        self.vao_1 = self.ctx.geometry(
            [BufferDescription(self.ssbo_1, buffer_format, attributes)],
            mode=self.ctx.POINTS,
        )
        self.vao_2 = self.ctx.geometry(
            [BufferDescription(self.ssbo_2, buffer_format, attributes)],
            mode=self.ctx.POINTS,
        )

        # --- Create shaders

        # Load in the shader source code
        with open("shaders/compute_shader.glsl") as file:
            compute_shader_source = file.read()
        with open("shaders/vertex_shader.glsl") as file:
            vertex_shader_source = file.read()
        with open("shaders/fragment_shader.glsl") as file:
            fragment_shader_source = file.read()
        with open("shaders/geometry_shader.glsl") as file:
            geometry_shader_source = file.read()

        # Create our compute shader.
        # Search/replace to set up our compute groups
        compute_shader_source = compute_shader_source.replace("COMPUTE_SIZE_X",
                                                              str(self.group_x))
        compute_shader_source = compute_shader_source.replace("COMPUTE_SIZE_Y",
                                                              str(self.group_y))
        self.compute_shader = self.ctx.compute_shader(source=compute_shader_source)

        # Program for visualizing the balls
        self.program = self.ctx.program(
            vertex_shader=vertex_shader_source,
            geometry_shader=geometry_shader_source,
            fragment_shader=fragment_shader_source,
        )

        # --- Create FPS graph

        # Enable timings for the performance graph
        arcade.enable_timings()

        # Create a sprite list to put the performance graph into
        self.perf_graph_list = arcade.SpriteList()

        # Create the FPS performance graph
        graph = arcade.PerfGraph(GRAPH_WIDTH, GRAPH_HEIGHT, graph_data="FPS")
        graph.center_x = GRAPH_WIDTH / 2
        graph.center_y = self.height - GRAPH_HEIGHT / 2
        self.perf_graph_list.append(graph)

    def on_draw(self):
        # Clear the screen
        self.clear()
        # Enable blending so our alpha channel works
        self.ctx.enable(self.ctx.BLEND)

        # Bind buffers: binding 0 is the shader's input, binding 1 its output
        self.ssbo_1.bind_to_storage_buffer(binding=0)
        self.ssbo_2.bind_to_storage_buffer(binding=1)

        # Set input variables for compute shader
        # These are examples, although this example doesn't use them
        # self.compute_shader["screen_size"] = self.get_size()
        # self.compute_shader["force"] = force
        # self.compute_shader["frame_time"] = self.run_time

        # Run compute shader
        self.compute_shader.run(group_x=self.group_x, group_y=self.group_y)

        # Draw the balls from the buffer the compute shader just wrote
        self.vao_2.render(self.program)

        # Swap the buffers around (we are ping-pong rendering between two buffers)
        self.ssbo_1, self.ssbo_2 = self.ssbo_2, self.ssbo_1
        # Swap what geometry we draw
        self.vao_1, self.vao_2 = self.vao_2, self.vao_1

        # Draw the graphs
        self.perf_graph_list.draw()

    def gen_initial_data(self):
        for _ in range(self.num_balls):
            # Position/radius
            yield random.randrange(0, self.width)
            yield random.randrange(0, self.height)
            yield 0.0  # z (padding)
            yield 6.0  # radius

            # Velocity
            yield 0.0
            yield 0.0
            yield 0.0  # vz (padding)
            yield 0.0  # vw (padding)

            # Color
            yield 1.0  # r
            yield 1.0  # g
            yield 1.0  # b
            yield 1.0  # a


app = MyWindow()
arcade.run()
```
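The program passes the same value, 256, both as the work group size in the shader and as the number of work groups to run, so it launches 256 x 256 = 65,536 invocations for 40,000 balls. The sketch below shows one way the group count could be derived from the ball count instead. This is a suggestion rather than what the tutorial does, and work_groups_needed is a hypothetical helper.

```python
def work_groups_needed(num_items: int, local_size: int) -> int:
    """Smallest number of work groups whose invocations cover num_items."""
    return (num_items + local_size - 1) // local_size  # integer ceiling division


num_balls = 40_000
local_size_x = 256

launched = 256 * local_size_x  # what run(group_x=256) launches today
needed = work_groups_needed(num_balls, local_size_x)

print(f"launched now:    {launched:,} invocations")
print(f"groups needed:   {needed}")
print(f"covers:          {needed * local_size_x:,} invocations")
```

Even with an exact group count, the last group usually runs a few invocations past the final ball, so the shader should skip any index beyond the end of the buffer.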
An expanded version of this, with support for 3D, is available at: https://github.com/pvcraven/n-body