Compute Shader Tutorial
With a compute shader, you can use the GPU to perform highly parallel calculations thousands of times faster than with the CPU alone.
In this example, we will simulate a star field using an ‘N-Body simulation’. Each star is affected by every other star’s gravity. For 1,000 stars, this means we have 1,000 x 1,000 = 1,000,000 (one million) calculations to perform for each frame. The video has 65,000 stars, requiring about 4.2 billion gravity force calculations per frame. On high-end hardware it can still run at 60 fps!
How does this work? There are three major parts to this program:
The Python code, which glues everything together.
The visualization shaders, which let us see the data.
The compute shader, which moves everything.
Visualization Shaders
There are multiple visualization shaders, which operate in this order:
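The Python program creates a shader storage buffer object (SSBO) of floating point numbers. For each star, this buffer holds the x, y, z and radius, which the vertex shader reads as `in_vertex`, and the color, which it reads as `in_color`. The velocity values stored between them are skipped by the vertex format, because they are not needed for drawing.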
The vertex shader doesn’t do much more than separate out the radius variable from the group of floats used to store position.
#version 330

// Per-star attributes, read directly from the storage buffer:
//   in_vertex = (x, y, z, radius)
//   in_color  = (r, g, b, a)
in vec4 in_vertex;
in vec4 in_color;

// Values passed on to the geometry shader, which expands each point into a quad.
out vec2 vertex_pos;
out float vertex_radius;
out vec4 vertex_color;

void main()
{
    // Colour passes straight through.
    vertex_color = in_color;

    // Unpack the radius from the 4th component of the position vector.
    vertex_radius = in_vertex.w;

    // Only x and y are forwarded; z is not used for drawing.
    vertex_pos = vec2(in_vertex.x, in_vertex.y);
}
The geometry shader converts each single point (which on its own would draw as just a dot) into a square that we can draw the star into. It turns one input point into the four corner vertices of a quad, emitted as a triangle strip.
#version 330

layout (points) in;
layout (triangle_strip, max_vertices = 4) out;

// Use arcade's global projection UBO
uniform Projection {
    uniform mat4 matrix;
} proj;

in vec2 vertex_pos[];
in vec4 vertex_color[];
in float vertex_radius[];

out vec2 g_uv;
out vec3 g_color;

// Emit one corner of the quad around `center`.
// uv is the corner's texture coordinate in [0, 1]; it is mapped to an offset
// in [-hsize, +hsize] from the center.
// Every output (g_uv, g_color, gl_Position) is written before EmitVertex()
// because the GLSL spec leaves all output variables undefined after
// EmitVertex() returns. Writing g_color only once would leave later
// vertices with an undefined color.
void emit_corner(vec2 center, vec2 hsize, vec2 uv)
{
    vec2 offset = (uv * 2.0 - 1.0) * hsize;
    gl_Position = proj.matrix * vec4(center + offset, 0.0, 1.0);
    g_uv = uv;
    g_color = vertex_color[0].rgb;
    EmitVertex();
}

// Expand a single input point into a square whose half-size is the star's radius.
void main() {
    vec2 center = vertex_pos[0];
    vec2 hsize = vec2(vertex_radius[0]);

    // Triangle-strip order; offsets are (-x,+y), (-x,-y), (+x,+y), (+x,-y).
    emit_corner(center, hsize, vec2(0.0, 1.0));
    emit_corner(center, hsize, vec2(0.0, 0.0));
    emit_corner(center, hsize, vec2(1.0, 1.0));
    emit_corner(center, hsize, vec2(1.0, 0.0));

    EndPrimitive();
}
The fragment shader runs for each pixel. It produces the soft glow effect of the star, and rounds off the quad into a circle.
#version 330

in vec2 g_uv;
in vec3 g_color;

out vec4 out_color;

// Draw one star: a circle inscribed in the quad from the geometry shader,
// most opaque near the middle and fading towards the edge.
void main()
{
    // Distance of this pixel from the center of the quad, in uv units.
    float l = length(vec2(0.5, 0.5) - g_uv.xy);

    // Round the quad off into a circle of radius 0.5.
    if (l > 0.5)
    {
        discard;
    }

    // Linear falloff: 0.6 near the center, reaching 0.0 at l = 0.3.
    // For 0.3 < l <= 0.5 the raw value (0.6 - 2l) is negative. Clamp it into
    // the valid [0, 1] alpha range rather than relying on the render target
    // to clamp; float targets would otherwise blend with a negative alpha.
    float alpha;
    if (l == 0.0)
        alpha = 1.0;   // the exact center sample is drawn fully opaque
    else
        alpha = clamp(0.6 - l * 2.0, 0.0, 1.0);

    out_color = vec4(g_color.rgb, alpha);
}
Compute Shaders
This program uses two buffers. The input buffer holds all of our current data. We perform calculations on that data and write the results to the output buffer. We then swap the two buffers, so that the output of one frame becomes the input of the next frame.
#version 430

// Set up our compute groups.
// COMPUTE_SIZE_X / COMPUTE_SIZE_Y are placeholders; the Python program
// replaces them with numbers (plain text search/replace) before compiling.
layout(local_size_x=COMPUTE_SIZE_X, local_size_y=COMPUTE_SIZE_Y) in;

// Input uniforms go here if you need them.
// Some examples:
//uniform vec2 screen_size;
//uniform vec2 force;
//uniform float frame_time;

// Structure of the ball data: three vec4s (48 bytes) per ball in std430.
// This must match the Python buffer layout of 12 floats per ball.
struct Ball
{
    vec4 pos;    // x, y, z, radius
    vec4 vel;    // x, y velocity; z, w are padding
    vec4 color;  // r, g, b, a
};

// Input buffer
layout(std430, binding=0) buffer balls_in
{
    Ball balls[];
} In;

// Output buffer
layout(std430, binding=1) buffer balls_out
{
    Ball balls[];
} Out;

void main()
{
    // One invocation per ball; only the x dimension is used.
    int curBallIndex = int(gl_GlobalInvocationID.x);
    int ballCount = In.balls.length();

    // The dispatch is rounded up to whole work groups, so the last group may
    // contain invocations past the end of the buffer. Accessing the buffers
    // there would be out of bounds, so those invocations do nothing.
    if (curBallIndex >= ballCount)
        return;

    Ball in_ball = In.balls[curBallIndex];

    vec4 p = in_ball.pos;
    vec4 v = in_ball.vel;

    // Move the ball according to the current velocity
    p.xy += v.xy;

    // Simulation tuning constants (loop-invariant, so declared once).
    // maxForce caps the per-pair pull: without it, stars that get very close
    // are flung into never-never land.
    const float maxForce = 0.02;
    const float gravityStrength = 0.3;
    const float simulationSpeed = 0.002;

    // Accumulate the velocity change caused by every body
    for (int i = 0; i < ballCount; i++) {
        // If enabled, this keeps the star from calculating gravity on itself.
        // However, the extra check slows the loop down.
        // if (i == curBallIndex)
        //     continue;

        // Offset from the other body to this one, and its squared length
        vec2 diff = p.xy - In.balls[i].pos.xy;
        float distanceSquared = dot(diff, diff);

        float force = min(maxForce, gravityStrength / distanceSquared) * -simulationSpeed;

        // diff is intentionally not normalized; the original author notes
        // that normalizing it "doesn't work" with these constants.
        v.xy += diff * force;
    }

    Ball out_ball;
    out_ball.pos = p;
    out_ball.vel = v;
    out_ball.color = in_ball.color;

    Out.balls[curBallIndex] = out_ball;
}
Python Program
Read through the code here; I’ve tried hard to explain all the parts in the comments.
"""
Compute shader with buffers
"""
import random
from array import array

import arcade
from arcade.gl import BufferDescription

# Window dimensions
WINDOW_WIDTH = 2300
WINDOW_HEIGHT = 1300

# Size of performance graphs
GRAPH_WIDTH = 200
GRAPH_HEIGHT = 120
GRAPH_MARGIN = 5


def load_shader_source(path):
    """Return the text of the shader file at ``path``.

    The path is resolved relative to the current working directory.
    The file is closed as soon as it has been read.
    """
    with open(path) as file:
        return file.read()


class MyWindow(arcade.Window):

    def __init__(self):
        # Call parent constructor
        # Ask for OpenGL 4.3 context, as we need that for compute shader support.
        super().__init__(WINDOW_WIDTH, WINDOW_HEIGHT,
                         "Compute Shader",
                         gl_version=(4, 3),
                         resizable=True)
        self.center_window()

        # --- Class instance variables

        # Number of balls to move
        self.num_balls = 40000

        # Work-group (local) size of the compute shader. These values are
        # written into the shader's layout(local_size_x=..., local_size_y=...)
        # line; each work group runs group_x * group_y invocations in parallel.
        self.group_x = 256
        self.group_y = 1

        # Number of work groups to dispatch. One invocation handles one ball,
        # so round up to cover every ball. When num_balls is not a multiple
        # of group_x, the last group has a few extra invocations past the end
        # of the buffer, which the compute shader should ignore.
        self.num_groups_x = (self.num_balls + self.group_x - 1) // self.group_x
        self.num_groups_y = 1

        # --- Create buffers

        # Format of the buffer data.
        # 4f  = position and size -> x, y, z, radius
        # 4x4 = skip four 4-byte values (the velocity). Not needed for visualization.
        # 4f  = color -> rgba
        buffer_format = "4f 4x4 4f"
        # Generate the initial data that we will put in buffer 1.
        initial_data = self.gen_initial_data()

        # Create data buffers for the compute shader
        # We ping-pong render between these two buffers
        # ssbo = shader storage buffer object
        self.ssbo_1 = self.ctx.buffer(data=array('f', initial_data))
        self.ssbo_2 = self.ctx.buffer(reserve=self.ssbo_1.size)

        # Attribute variable names for the vertex shader
        attributes = ["in_vertex", "in_color"]
        self.vao_1 = self.ctx.geometry(
            [BufferDescription(self.ssbo_1, buffer_format, attributes)],
            mode=self.ctx.POINTS,
        )
        self.vao_2 = self.ctx.geometry(
            [BufferDescription(self.ssbo_2, buffer_format, attributes)],
            mode=self.ctx.POINTS,
        )

        # --- Create shaders

        # Load in the shader source code (each file is closed after reading)
        compute_shader_source = load_shader_source("shaders/compute_shader.glsl")
        vertex_shader_source = load_shader_source("shaders/vertex_shader.glsl")
        fragment_shader_source = load_shader_source("shaders/fragment_shader.glsl")
        geometry_shader_source = load_shader_source("shaders/geometry_shader.glsl")

        # Create our compute shader.
        # Search/replace to set up our compute groups
        compute_shader_source = compute_shader_source.replace("COMPUTE_SIZE_X",
                                                              str(self.group_x))
        compute_shader_source = compute_shader_source.replace("COMPUTE_SIZE_Y",
                                                              str(self.group_y))
        self.compute_shader = self.ctx.compute_shader(source=compute_shader_source)

        # Program for visualizing the balls
        self.program = self.ctx.program(
            vertex_shader=vertex_shader_source,
            geometry_shader=geometry_shader_source,
            fragment_shader=fragment_shader_source,
        )

        # --- Create FPS graph

        # Enable timings for the performance graph
        arcade.enable_timings()

        # Create a sprite list to put the performance graph into
        self.perf_graph_list = arcade.SpriteList()

        # Create the FPS performance graph
        graph = arcade.PerfGraph(GRAPH_WIDTH, GRAPH_HEIGHT, graph_data="FPS")
        graph.center_x = GRAPH_WIDTH / 2
        graph.center_y = self.height - GRAPH_HEIGHT / 2
        self.perf_graph_list.append(graph)

    def on_draw(self):
        # Clear the screen
        self.clear()
        # Enable blending so our alpha channel works
        self.ctx.enable(self.ctx.BLEND)

        # Bind buffers: ssbo_1 is the compute shader's input, ssbo_2 its output
        self.ssbo_1.bind_to_storage_buffer(binding=0)
        self.ssbo_2.bind_to_storage_buffer(binding=1)

        # Set input variables for compute shader
        # These are examples, although this example doesn't use them
        # self.compute_shader["screen_size"] = self.get_size()
        # self.compute_shader["force"] = force
        # self.compute_shader["frame_time"] = self.run_time

        # Run compute shader. run() takes the number of work groups, not the
        # work-group size, so dispatch enough groups to cover every ball.
        self.compute_shader.run(group_x=self.num_groups_x, group_y=self.num_groups_y)

        # Draw the balls from the buffer the compute shader just wrote
        self.vao_2.render(self.program)

        # Swap the buffers around (we are ping-pong rendering between two buffers)
        self.ssbo_1, self.ssbo_2 = self.ssbo_2, self.ssbo_1
        # Swap what geometry we draw
        self.vao_1, self.vao_2 = self.vao_2, self.vao_1

        # Draw the graphs
        self.perf_graph_list.draw()

    def gen_initial_data(self):
        """Yield the initial buffer contents: 12 floats per ball.

        For each ball: x, y, z, radius; then velocity x, y and two padding
        values; then the r, g, b, a color. Positions are random integers
        within the window.
        """
        for i in range(self.num_balls):
            # Position/radius
            yield random.randrange(0, self.width)
            yield random.randrange(0, self.height)
            yield 0.0  # z (padding)
            yield 6.0

            # Velocity
            yield 0.0
            yield 0.0
            yield 0.0  # vz (padding)
            yield 0.0  # vw (padding)

            # Color
            yield 1.0  # r
            yield 1.0  # g
            yield 1.0  # b
            yield 1.0  # a


app = MyWindow()
arcade.run()
An expanded version of this, with support for 3D, is available at: https://github.com/pvcraven/n-body