Compute Shader Tutorial

Using a compute shader, you can have the GPU run thousands of calculations in parallel. For highly parallel workloads like this one, that can be far faster than doing the calculations one after another on the CPU.

In this example, we will simulate a star field using an ‘N-Body simulation’. Each star is affected by every other star’s gravity. For 1,000 stars, this means we have 1,000 x 1,000 = 1,000,000 calculations to perform for each frame. The video has 65,000 stars, requiring about 4.2 billion gravity force calculations per frame. On high-end hardware it can still run at 60 fps!
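The quadratic growth is easy to check with a few lines of Python. This is only a sketch of the arithmetic, assuming a brute-force loop where every star visits every star (itself included), as the compute shader in this tutorial does. The function name is hypothetical.

```python
def gravity_calculations_per_frame(num_stars: int) -> int:
    """Force evaluations per frame when every star visits every star."""
    # Each of the num_stars shader invocations loops over all num_stars
    # entries, so the work grows with the square of the star count.
    return num_stars * num_stars


for stars in (1_000, 40_000, 65_000):
    per_frame = gravity_calculations_per_frame(stars)
    print(f"{stars:,} stars: {per_frame:,} per frame, "
          f"{per_frame * 60:,} per second at 60 fps")
```

At 65,000 stars this gives 4,225,000,000 evaluations per frame, or about 253.5 billion per second at 60 fps.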

How does this work? There are three major parts to this program:

  • The Python code, which glues everything together.

  • The visualization shaders, which let us see the data.

  • The compute shader, which moves everything.

Visualization Shaders

There are three visualization shaders (vertex, geometry and fragment), which run in this order:

[Figure: shaders.svg, the order in which the visualization shaders run]

The Python program creates a shader storage buffer object (SSBO) of floating point numbers. For each star, the buffer holds its x, y, z and radius, which the vertex shader reads as in_vertex, then its velocity, which drawing skips, and then its color, which the vertex shader reads as in_color.
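To make the layout concrete, here is a hedged sketch using Python's standard struct module. It assumes each star is 12 native-endian 32-bit floats, as array('f') produces from gen_initial_data() in the Python program, and that the drawing format "4f 4x4 4f" reads the first four floats as in_vertex, skips the next four (the velocity), and reads the last four as in_color. The helper names are hypothetical. The resulting 48-byte stride also matches the compute shader's std430 Ball struct of three vec4 values.

```python
import struct

# One star = three groups of four 32-bit floats, native byte order like array('f').
STAR_FORMAT = "=4f4f4f"
STAR_STRIDE = struct.calcsize(STAR_FORMAT)  # bytes per star

# Byte offset of each group of four floats inside one star record
POSITION_OFFSET = 0   # x, y, z, radius -> in_vertex
VELOCITY_OFFSET = 16  # vx, vy, vz, vw  -> skipped ("4x4") when drawing
COLOR_OFFSET = 32     # r, g, b, a      -> in_color


def pack_star(x, y, radius, color=(1.0, 1.0, 1.0, 1.0)):
    """Hypothetical helper: pack one star in the order gen_initial_data() yields."""
    return struct.pack(STAR_FORMAT, x, y, 0.0, radius, 0.0, 0.0, 0.0, 0.0, *color)


def read_attributes(buffer, index):
    """Hypothetical helper: what the vertex shader receives for star `index`."""
    base = index * STAR_STRIDE
    in_vertex = struct.unpack_from("=4f", buffer, base + POSITION_OFFSET)
    in_color = struct.unpack_from("=4f", buffer, base + COLOR_OFFSET)
    return in_vertex, in_color


data = pack_star(100.0, 200.0, 6.0) + pack_star(300.0, 50.0, 6.0, (1.0, 0.5, 0.0, 1.0))
print(STAR_STRIDE)               # 48
print(read_attributes(data, 1))  # ((300.0, 50.0, 0.0, 6.0), (1.0, 0.5, 0.0, 1.0))
```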

The vertex shader doesn’t do much more than separate the radius (the fourth float, w) from the x and y position, and pass both, along with the color, on to the geometry shader.

shaders/vertex_shader.glsl
#version 330

in vec4 in_vertex;
in vec4 in_color;

out vec2 vertex_pos;
out float vertex_radius;
out vec4 vertex_color;

void main()
{
    vertex_pos = in_vertex.xy;
    vertex_radius = in_vertex.w;
    vertex_color = in_color;
}

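The geometry shader converts each single point into a square, which gives the fragment shader an area to draw the star on. It turns the one point into the four corners of a quad, each offset from the star’s center by its radius, and emits them as a triangle strip along with UV coordinates running from (0, 0) to (1, 1).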

shaders/geometry_shader.glsl
#version 330

layout (points) in;
layout (triangle_strip, max_vertices = 4) out;

// Use arcade's global projection UBO
uniform Projection {
    uniform mat4 matrix;
} proj;

in vec2 vertex_pos[];
in vec4 vertex_color[];
in float vertex_radius[];

out vec2 g_uv;
out vec3 g_color;

void main() {
    vec2 center = vertex_pos[0];
    vec2 hsize = vec2(vertex_radius[0]);

    g_color = vertex_color[0].rgb;

    gl_Position = proj.matrix * vec4(vec2(-hsize.x, hsize.y) + center, 0.0, 1.0);
    g_uv = vec2(0, 1);
    EmitVertex();

    gl_Position = proj.matrix * vec4(vec2(-hsize.x, -hsize.y) + center, 0.0, 1.0);
    g_uv = vec2(0, 0);
    EmitVertex();

    gl_Position = proj.matrix * vec4(vec2(hsize.x, hsize.y) + center, 0.0, 1.0);
    g_uv = vec2(1, 1);
    EmitVertex();

    gl_Position = proj.matrix * vec4(vec2(hsize.x, -hsize.y) + center, 0.0, 1.0);
    g_uv = vec2(1, 0);
    EmitVertex();

    EndPrimitive();
}
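The corner arithmetic in the geometry shader can be mirrored on the CPU to see exactly what it emits. This Python sketch uses hypothetical helper names and stops before the multiplication by proj.matrix, so positions stay in the same units as the buffer. Corners are listed in the shader’s EmitVertex() order. In a triangle strip every three consecutive vertices form a triangle, so four vertices give the two triangles of the square.

```python
def quad_corners(center, radius):
    """Mirror the geometry shader: four (position, uv) pairs for one star."""
    cx, cy = center
    h = radius  # hsize in the shader: half the side length of the square
    return [
        ((cx - h, cy + h), (0, 1)),  # top-left
        ((cx - h, cy - h), (0, 0)),  # bottom-left
        ((cx + h, cy + h), (1, 1)),  # top-right
        ((cx + h, cy - h), (1, 0)),  # bottom-right
    ]


def strip_triangles(vertices):
    """Split a triangle strip into its individual triangles."""
    return [tuple(vertices[i:i + 3]) for i in range(len(vertices) - 2)]


corners = quad_corners((100.0, 200.0), 6.0)
for position, uv in corners:
    print(position, uv)
print(len(strip_triangles(corners)))  # 2
```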

The fragment shader runs once for each pixel the quad covers. It produces the soft glow effect of the star, and discards pixels outside a circle in UV space, which rounds the quad off into a circle.

shaders/fragment_shader.glsl
#version 330

in vec2 g_uv;
in vec3 g_color;

out vec4 out_color;

void main()
{
    // Distance of this pixel from the center of the quad, in UV space
    float l = length(vec2(0.5, 0.5) - g_uv.xy);
    // Round the quad off into a circle
    if ( l > 0.5)
    {
        discard;
    }
    float alpha;
    if (l == 0.0)
        alpha = 1.0;
    else
        // Fade out linearly from the center; clamp so alpha never goes negative
        alpha = clamp(0.6 - l * 2.0, 0.0, 1.0);

    vec3 c = g_color.rgb;
    // c.xy += v_uv.xy * 0.05;
    // c.xy += v_pos.xy * 0.75;
    out_color = vec4(c, alpha);
}
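A CPU version of the fragment shader’s alpha calculation shows the shape of the glow. This sketch assumes a negative alpha ends up as 0 (fully transparent): the shader can clamp it explicitly, and a normal 8-bit color buffer clamps it before blending anyway. The function name is hypothetical, and None stands in for discard.

```python
import math


def star_alpha(u, v):
    """CPU sketch of the fragment shader's alpha at UV coordinate (u, v)."""
    l = math.hypot(0.5 - u, 0.5 - v)  # distance from the quad's center in UV space
    if l > 0.5:
        return None  # outside the circle: the shader discards this pixel
    if l == 0.0:
        return 1.0  # exact center
    # Linear fade from the center, clamped into [0, 1]
    return max(0.0, min(1.0, 0.6 - l * 2.0))


for u in (0.5, 0.6, 0.7, 0.8, 0.9, 1.0):
    print(u, star_alpha(u, 0.5))
```

Because 0.6 - 2l reaches zero at l = 0.3, the visible glow covers only the inner 60% of the quad’s half-width (0.3 of 0.5 in UV units); the discard at 0.5 trims corner pixels that are already transparent. Away from the exact center, alpha never exceeds 0.6.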

Compute Shaders

This program uses two buffers. The input buffer holds all of our current data; we perform calculations on that data and write the results to the output buffer. We then swap the two buffers, so the output of one frame becomes the input of the next.
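The reason for two buffers is that every star reads every other star’s position; updating a single buffer in place would let some stars see positions that were already moved this frame. The swap can be sketched in plain Python, with lists standing in for the SSBOs and a hypothetical step() standing in for one compute shader dispatch (its +1 update rule is made up for illustration).

```python
def step(stars_in):
    """Stand-in for one dispatch: read the input buffer, produce the next state.

    Hypothetical update rule for illustration: every value moves by +1.
    """
    return [value + 1 for value in stars_in]


buffer_1 = [0, 10, 20]  # current data: this frame's input
buffer_2 = [0] * 3      # same-sized output buffer, like ctx.buffer(reserve=...)

for frame in range(3):
    buffer_2[:] = step(buffer_1)             # the compute step writes only the output
    # ... draw from buffer_2 here ...
    buffer_1, buffer_2 = buffer_2, buffer_1  # swap: output becomes next input

print(buffer_1)  # [3, 13, 23]
```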

shaders/compute_shader.glsl
#version 430

// Set up our compute groups
// (the Python program replaces COMPUTE_SIZE_X / COMPUTE_SIZE_Y with numbers)
layout(local_size_x=COMPUTE_SIZE_X, local_size_y=COMPUTE_SIZE_Y) in;

// Input uniforms go here if you need them.
// Some examples:
//uniform vec2 screen_size;
//uniform vec2 force;
//uniform float frame_time;

// Structure of the ball data
struct Ball
{
    vec4 pos;
    vec4 vel;
    vec4 color;
};

// Input buffer
layout(std430, binding=0) buffer balls_in
{
    Ball balls[];
} In;

// Output buffer
layout(std430, binding=1) buffer balls_out
{
    Ball balls[];
} Out;

void main()
{
    int curBallIndex = int(gl_GlobalInvocationID.x);

    // More invocations can be launched than there are balls. Skip the extras
    // so we never read or write past the end of the buffers.
    if (curBallIndex >= In.balls.length())
        return;

    Ball in_ball = In.balls[curBallIndex];

    vec4 p = in_ball.pos.xyzw;
    vec4 v = in_ball.vel.xyzw;

    // Move the ball according to its current velocity
    p.xy += v.xy;

    // Calculate the new force based on all the other bodies
    for (int i=0; i < In.balls.length(); i++) {
        // If enabled, this will keep the star from calculating gravity on itself.
        // However, doing this check slows down the calculations.
        //  if (i == curBallIndex)
        //      continue;

        // Calculate distance squared
        float dist = distance(In.balls[i].pos.xyzw.xy, p.xy);
        float distanceSquared = dist * dist;

        // If stars get too close, they fling into never-never land.
        // So cap the force term. The cap is reached below a distance of
        // sqrt(gravityStrength / maxForce), about 3.9 pixels.
        float maxForce = 0.02;
        float gravityStrength = 0.3;
        float simulationSpeed = 0.002;
        float force = min(maxForce, gravityStrength / distanceSquared) * -simulationSpeed;

        vec2 diff = p.xy - In.balls[i].pos.xyzw.xy;
        // Normalizing diff would give a true inverse-square law. One likely
        // reason it "doesn't work": for the star's own entry, diff is the zero
        // vector when the star isn't moving, and normalize() of a zero vector
        // is undefined (typically NaN). It would need the self check above.
        //  diff = normalize(diff);
        vec2 delta_v = diff * force;
        v.xy += delta_v;
    }


    Ball out_ball;
    out_ball.pos.xyzw = p.xyzw;
    out_ball.vel.xyzw = v.xyzw;

    vec4 c = in_ball.color.xyzw;
    out_ball.color.xyzw = c.xyzw;

    Out.balls[curBallIndex] = out_ball;
}
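For debugging, it can help to have a slow CPU reference of the same update rule to compare GPU output against. The Python sketch below is a hedged, hypothetical port of the compute shader’s main() for 2D positions only. It keeps the shader’s constants, includes each star’s own entry in the loop, and does not normalize diff. Because diff has length d while the force term falls off as 1/d², the resulting pull actually falls off roughly as 1/d rather than as an inverse-square law.

```python
import math

# Constants copied from the compute shader
FORCE_CAP = 0.02          # the shader's 0.02 cap on the force term
GRAVITY_STRENGTH = 0.3
SIMULATION_SPEED = 0.002


def simulate_step(positions, velocities):
    """CPU reference for one dispatch of the compute shader (2D, O(n^2)).

    positions / velocities: lists of (x, y) tuples from the input buffer.
    Returns new lists, like writing to the output buffer.
    """
    new_positions, new_velocities = [], []
    for (px, py), (vx, vy) in zip(positions, velocities):
        # Move the star by its current velocity
        px, py = px + vx, py + vy
        # Accumulate the pull of every star in the input buffer, itself included
        for ox, oy in positions:
            dx, dy = px - ox, py - oy
            dist = math.hypot(dx, dy)
            dist_sq = dist * dist
            # On the GPU 0.3 / 0.0 gives +inf and min() picks the cap;
            # diff is (0, 0) in that case, so the contribution is zero anyway.
            strength = FORCE_CAP if dist_sq == 0.0 else min(FORCE_CAP, GRAVITY_STRENGTH / dist_sq)
            force = strength * -SIMULATION_SPEED
            vx += dx * force
            vy += dy * force
        new_positions.append((px, py))
        new_velocities.append((vx, vy))
    return new_positions, new_velocities


pos, vel = simulate_step([(0.0, 0.0), (10.0, 0.0)], [(0.0, 0.0), (0.0, 0.0)])
print(pos)
print(vel)  # the two stars start accelerating toward each other
```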

Python Program

Read through the code here; I’ve tried hard to explain all the parts in the comments.

main.py
"""
Compute shader with buffers
"""
import math
import random
from array import array
from pathlib import Path

import arcade
from arcade.gl import BufferDescription

# Window dimensions
WINDOW_WIDTH = 2300
WINDOW_HEIGHT = 1300

# Size of performance graphs
GRAPH_WIDTH = 200
GRAPH_HEIGHT = 120
GRAPH_MARGIN = 5


class MyWindow(arcade.Window):

    def __init__(self):
        # Call parent constructor
        # Ask for an OpenGL 4.3 context, as we need that for compute shader support.
        super().__init__(WINDOW_WIDTH, WINDOW_HEIGHT,
                         "Compute Shader",
                         gl_version=(4, 3),
                         resizable=True)
        self.center_window()

        # --- Class instance variables

        # Number of balls to move
        self.num_balls = 40000

        # Work group size: how many shader invocations run together in one group.
        # These are substituted into local_size_x / local_size_y in the shader.
        self.group_x = 256
        self.group_y = 1

        # Number of work groups to dispatch. Round up so every ball gets an
        # invocation (groups x group size may slightly exceed num_balls).
        self.num_groups_x = math.ceil(self.num_balls / self.group_x)
        self.num_groups_y = 1

        # --- Create buffers

        # Format of the buffer data.
        # 4f = position and size -> x, y, z, radius
        # 4x4 = skip four 4-byte floats (the velocity). The compute shader
        #       uses them, but drawing doesn't.
        # 4f = color -> rgba
        buffer_format = "4f 4x4 4f"
        # Generate the initial data that we will put in buffer 1.
        initial_data = self.gen_initial_data()

        # Create data buffers for the compute shader
        # We ping-pong render between these two buffers
        # ssbo = shader storage buffer object
        self.ssbo_1 = self.ctx.buffer(data=array('f', initial_data))
        self.ssbo_2 = self.ctx.buffer(reserve=self.ssbo_1.size)

        # Attribute variable names for the vertex shader
        attributes = ["in_vertex", "in_color"]
        self.vao_1 = self.ctx.geometry(
            [BufferDescription(self.ssbo_1, buffer_format, attributes)],
            mode=self.ctx.POINTS,
        )
        self.vao_2 = self.ctx.geometry(
            [BufferDescription(self.ssbo_2, buffer_format, attributes)],
            mode=self.ctx.POINTS,
        )

        # --- Create shaders

        # Load in the shader source code (read_text() closes each file for us)
        compute_shader_source = Path("shaders/compute_shader.glsl").read_text()
        vertex_shader_source = Path("shaders/vertex_shader.glsl").read_text()
        fragment_shader_source = Path("shaders/fragment_shader.glsl").read_text()
        geometry_shader_source = Path("shaders/geometry_shader.glsl").read_text()

        # Create our compute shader.
        # Search/replace to set up our compute groups
        compute_shader_source = compute_shader_source.replace("COMPUTE_SIZE_X",
                                                              str(self.group_x))
        compute_shader_source = compute_shader_source.replace("COMPUTE_SIZE_Y",
                                                              str(self.group_y))
        self.compute_shader = self.ctx.compute_shader(source=compute_shader_source)

        # Program for visualizing the balls
        self.program = self.ctx.program(
            vertex_shader=vertex_shader_source,
            geometry_shader=geometry_shader_source,
            fragment_shader=fragment_shader_source,
        )

        # --- Create FPS graph

        # Enable timings for the performance graph
        arcade.enable_timings()

        # Create a sprite list to put the performance graph into
        self.perf_graph_list = arcade.SpriteList()

        # Create the FPS performance graph
        graph = arcade.PerfGraph(GRAPH_WIDTH, GRAPH_HEIGHT, graph_data="FPS")
        graph.center_x = GRAPH_WIDTH / 2
        graph.center_y = self.height - GRAPH_HEIGHT / 2
        self.perf_graph_list.append(graph)

    def on_draw(self):
        # Clear the screen
        self.clear()
        # Enable blending so our alpha channel works
        self.ctx.enable(self.ctx.BLEND)

        # Bind buffers: ssbo_1 is the input (binding 0), ssbo_2 the output (binding 1)
        self.ssbo_1.bind_to_storage_buffer(binding=0)
        self.ssbo_2.bind_to_storage_buffer(binding=1)

        # Set input variables for compute shader
        # These are examples, although this example doesn't use them
        # self.compute_shader["screen_size"] = self.get_size()
        # self.compute_shader["force"] = force
        # self.compute_shader["frame_time"] = self.run_time

        # Run compute shader. Here group_x / group_y are the *number* of
        # work groups to launch, not the size of each group.
        self.compute_shader.run(group_x=self.num_groups_x, group_y=self.num_groups_y)

        # Draw the balls from the output buffer
        self.vao_2.render(self.program)

        # Swap the buffers around (we are ping-pong rendering between two buffers)
        self.ssbo_1, self.ssbo_2 = self.ssbo_2, self.ssbo_1
        # Swap what geometry we draw
        self.vao_1, self.vao_2 = self.vao_2, self.vao_1

        # Draw the graphs
        self.perf_graph_list.draw()

    def gen_initial_data(self):
        for _ in range(self.num_balls):
            # Position/radius
            yield random.randrange(0, self.width)   # x
            yield random.randrange(0, self.height)  # y
            yield 0.0  # z (padding)
            yield 6.0  # radius

            # Velocity
            yield 0.0
            yield 0.0
            yield 0.0  # vz (padding)
            yield 0.0  # vw (padding)

            # Color
            yield 1.0  # r
            yield 1.0  # g
            yield 1.0  # b
            yield 1.0  # a


app = MyWindow()
arcade.run()
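A compute dispatch launches (number of work groups) × (local_size_x) invocations, so the number of groups passed to run() should be chosen from the star count. For example, 256 groups of 256 launches 65,536 invocations, more than 40,000 stars, so the extra invocations must not touch the buffers. A minimal sketch of the round-up calculation, with a hypothetical helper name:

```python
import math


def work_groups_needed(num_items: int, local_size: int) -> int:
    """Work groups needed so that at least num_items invocations run.

    Each group runs `local_size` invocations (local_size_x in the shader),
    so divide and round up; the last group may be partly unused.
    """
    return math.ceil(num_items / local_size)


local_size_x = 256
print(work_groups_needed(40_000, local_size_x))  # 157
print(256 * local_size_x)                        # 65536
```

Rounding up to 157 groups launches 40,192 invocations, so the shader still has to ignore the last 192.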

An expanded version of this, with support for 3D, is available at: https://github.com/pvcraven/n-body