YrXtals/shaders/fft.wgsl

55 lines
1.4 KiB
WebGPU Shading Language

// Radix-2 Cooley-Tukey 1D complex FFT: one bit_reverse pass copies data into scratch in bit-reversed order, then log2(n) butterfly passes run in place over scratch.
// Uniform parameters read by the bit_reverse and butterfly entry points.
struct Args {
// Number of complex elements; bit_reverse processes indices below n and
// butterfly processes pair indices below n / 2.
n: u32,
// Number of low-order index bits that bit_reverse reverses.
// Presumably log2(n) -- confirm against the host code that fills this struct.
log2_n: u32,
// Distance between the two elements combined by one butterfly, i.e. half the
// width of the sub-transform handled in the current butterfly pass.
stride: u32,
// Zero selects a negative twiddle-angle sign in butterfly; any non-zero value
// selects a positive sign.
inverse: u32,
};
// Per-dispatch parameters (see struct Args).
@group(0) @binding(0) var<uniform> args: Args;
// Complex samples stored as vec2 (x and y are combined as real and imaginary
// parts by butterfly's complex multiply). The entry points in this file only
// read this buffer, although it is bound read_write.
@group(0) @binding(1) var<storage, read_write> data: array<vec2<f32>>;
// Working buffer: bit_reverse writes the permuted samples here and butterfly
// updates them in place.
@group(0) @binding(2) var<storage, read_write> scratch: array<vec2<f32>>;
// Permutation pass: invocation `gid.x` copies data[gid.x] to the scratch slot
// whose index is gid.x with its lowest args.log2_n bits reversed.
// Invocations at or beyond args.n do nothing.
@compute @workgroup_size(64)
fn bit_reverse(@builtin(global_invocation_id) gid: vec3<u32>) {
    let src = gid.x;
    if (src < args.n) {
        var dst = 0u;
        var remaining = src;
        var bits_left = args.log2_n;
        // Peel bits off the low end of `remaining` and push them onto the low
        // end of `dst`, so the first bit taken ends up as the highest bit.
        while (bits_left > 0u) {
            dst = (dst << 1u) | (remaining & 1u);
            remaining >>= 1u;
            bits_left -= 1u;
        }
        scratch[dst] = data[src];
    }
}
// Pi as an f32 constant; butterfly uses it to build twiddle-factor angles.
const PI: f32 = 3.14159265358979;
// One in-place radix-2 butterfly pass over `scratch`.
//
// Invocation k (k < args.n / 2) combines the pair (ia, ib), where
// ib = ia + args.stride, using the twiddle factor
//   w = exp(i * direction * PI * j / args.stride)
// with j = k % args.stride; direction is -1 when args.inverse == 0 and +1
// otherwise. It then stores a + b*w to scratch[ia] and a - b*w to scratch[ib].
// No 1/n scaling is applied here for the inverse direction.
//
// Invocations return without touching memory when k is out of range, when
// args.stride is zero, or when the pair would extend past args.n.
@compute @workgroup_size(64)
fn butterfly(@builtin(global_invocation_id) gid: vec3<u32>) {
    let k = gid.x;
    let n = args.n;
    if (k >= n / 2u) { return; }

    let half_span = args.stride;
    // WGSL defines x / 0 as x and x % 0 as 0, so a zero stride would send every
    // invocation to scratch[0] with a NaN angle (0 / 0). Bail out instead.
    if (half_span == 0u) { return; }

    let span = half_span * 2u;
    let pair_group = k / half_span; // which span-wide sub-transform
    let j = k % half_span;          // position inside its lower half
    let ia = pair_group * span + j;
    let ib = ia + half_span;
    // A stride that does not match n would place the upper element outside
    // the transform; skip rather than access past n.
    if (ib >= n) { return; }

    let direction = select(-1.0, 1.0, args.inverse != 0u);
    let angle = direction * PI * f32(j) / f32(half_span);
    let w = vec2<f32>(cos(angle), sin(angle));

    let a = scratch[ia];
    let b = scratch[ib];
    // Complex product b * w, with x as the real part and y as the imaginary part.
    let bw = vec2<f32>(b.x * w.x - b.y * w.y, b.x * w.y + b.y * w.x);
    scratch[ia] = a + bw;
    scratch[ib] = a - bw;
}