//! `Cord/crates/cord-shader/src/codegen_trig.rs`
//!
//! Generates a complete WGSL raymarching shader directly from a `TrigGraph`.

use cord_trig::{NodeId, TrigGraph, TrigOp};
use std::fmt::Write;
/// Generate a complete WGSL raymarching shader directly from a `TrigGraph`.
///
/// The returned source contains, in this order: the `Uniforms` struct and its
/// `u` binding, a `scene_sdf` function compiled from `graph`, and the fixed
/// raymarcher together with its `vs_main` / `fs_main` entry points.
pub fn generate_wgsl_from_trig(graph: &TrigGraph) -> String {
    // 4 KiB comfortably holds the fixed preamble + raymarcher text.
    let mut shader = String::with_capacity(4096);

    write_preamble(&mut shader);
    write_sdf_from_trig(&mut shader, graph);
    write_raymarcher(&mut shader);

    shader
}
/// WGSL identifier of the `let` binding that holds node `id`'s value
/// (node 7 is bound as `v7`).
fn var_name(id: NodeId) -> String {
    "v".to_owned() + &id.to_string()
}
fn write_sdf_from_trig(out: &mut String, graph: &TrigGraph) {
out.push_str("fn scene_sdf(p: vec3<f32>) -> f32 {\n");
for (i, op) in graph.nodes.iter().enumerate() {
let v = var_name(i as NodeId);
match op {
TrigOp::InputX => writeln!(out, " let {v} = p.x;").unwrap(),
TrigOp::InputY => writeln!(out, " let {v} = p.y;").unwrap(),
TrigOp::InputZ => writeln!(out, " let {v} = p.z;").unwrap(),
TrigOp::Const(c) => {
let f = *c as f32;
if f.is_nan() || f.is_infinite() {
writeln!(out, " let {v} = 0.0;").unwrap()
} else {
writeln!(out, " let {v} = {f:.8};").unwrap()
}
}
TrigOp::Add(a, b) => writeln!(out, " let {v} = {} + {};", var_name(*a), var_name(*b)).unwrap(),
TrigOp::Sub(a, b) => writeln!(out, " let {v} = {} - {};", var_name(*a), var_name(*b)).unwrap(),
TrigOp::Mul(a, b) => writeln!(out, " let {v} = {} * {};", var_name(*a), var_name(*b)).unwrap(),
TrigOp::Div(a, b) => writeln!(out, " let {v} = {} / {};", var_name(*a), var_name(*b)).unwrap(),
TrigOp::Neg(a) => writeln!(out, " let {v} = -{};", var_name(*a)).unwrap(),
TrigOp::Abs(a) => writeln!(out, " let {v} = abs({});", var_name(*a)).unwrap(),
TrigOp::Sin(a) => writeln!(out, " let {v} = sin({});", var_name(*a)).unwrap(),
TrigOp::Cos(a) => writeln!(out, " let {v} = cos({});", var_name(*a)).unwrap(),
TrigOp::Tan(a) => writeln!(out, " let {v} = tan({});", var_name(*a)).unwrap(),
TrigOp::Asin(a) => writeln!(out, " let {v} = asin({});", var_name(*a)).unwrap(),
TrigOp::Acos(a) => writeln!(out, " let {v} = acos({});", var_name(*a)).unwrap(),
TrigOp::Atan(a) => writeln!(out, " let {v} = atan({});", var_name(*a)).unwrap(),
TrigOp::Sinh(a) => writeln!(out, " let {v} = sinh({});", var_name(*a)).unwrap(),
TrigOp::Cosh(a) => writeln!(out, " let {v} = cosh({});", var_name(*a)).unwrap(),
TrigOp::Tanh(a) => writeln!(out, " let {v} = tanh({});", var_name(*a)).unwrap(),
TrigOp::Asinh(a) => writeln!(out, " let {v} = asinh({});", var_name(*a)).unwrap(),
TrigOp::Acosh(a) => writeln!(out, " let {v} = acosh({});", var_name(*a)).unwrap(),
TrigOp::Atanh(a) => writeln!(out, " let {v} = atanh({});", var_name(*a)).unwrap(),
TrigOp::Sqrt(a) => writeln!(out, " let {v} = sqrt({});", var_name(*a)).unwrap(),
TrigOp::Exp(a) => writeln!(out, " let {v} = exp({});", var_name(*a)).unwrap(),
TrigOp::Ln(a) => writeln!(out, " let {v} = log({});", var_name(*a)).unwrap(),
TrigOp::Hypot(a, b) => {
writeln!(out, " let {v} = sqrt({a_v} * {a_v} + {b_v} * {b_v});",
a_v = var_name(*a), b_v = var_name(*b)).unwrap()
}
TrigOp::Atan2(a, b) => {
writeln!(out, " let {v} = atan2({}, {});", var_name(*a), var_name(*b)).unwrap()
}
TrigOp::Min(a, b) => writeln!(out, " let {v} = min({}, {});", var_name(*a), var_name(*b)).unwrap(),
TrigOp::Max(a, b) => writeln!(out, " let {v} = max({}, {});", var_name(*a), var_name(*b)).unwrap(),
TrigOp::Clamp { val, lo, hi } => {
writeln!(out, " let {v} = clamp({}, {}, {});",
var_name(*val), var_name(*lo), var_name(*hi)).unwrap()
}
}
}
writeln!(out, " return {};", var_name(graph.output)).unwrap();
out.push_str("}\n\n");
}
/// Append the WGSL `Uniforms` struct and its `u` binding (group 0, binding 0).
///
/// The field order and types have to agree with the host-side uniform buffer
/// that is bound here (presumably a `#[repr(C)]` struct elsewhere in the crate;
/// confirm before changing either side).
fn write_preamble(out: &mut String) {
    const UNIFORMS_WGSL: &str = r#"struct Uniforms {
resolution: vec2<f32>,
viewport_offset: vec2<f32>,
camera_pos: vec3<f32>,
time: f32,
camera_target: vec3<f32>,
fov: f32,
render_flags: vec4<f32>,
scene_scale: f32,
_pad: vec3<f32>,
};
@group(0) @binding(0) var<uniform> u: Uniforms;
"#;
    *out += UNIFORMS_WGSL;
}
/// Append the fixed WGSL shading code that consumes `scene_sdf`.
///
/// Emits, in order:
/// - `calc_normal`: tetrahedral finite-difference normal of `scene_sdf`.
/// - `soft_shadow` and `ao`.
/// - `grid_aa` and `ground_plane`: an anti-aliased grid on the z = 0 plane,
///   with red/green axis lines through the world origin.
/// - `get_camera_ray` and `shade_ray`.
/// - The `fs_main` / `vs_main` entry points; `vs_main` draws one full-screen
///   triangle.
///
/// Most distances, step sizes and epsilons are multiplied by `u.scene_scale`.
///
/// `grid_aa` measures the distance to the nearest *integer* grid coordinate,
/// so minor and major grid lines pass through the world origin and line up with
/// the axis lines.
///
/// NOTE(review): `ground_plane` calls `fwidth` after early `return`s that
/// depend on the per-pixel ray direction. Strict WGSL uniformity analysis can
/// reject derivatives in non-uniform control flow; confirm the target shader
/// compiler accepts this.
fn write_raymarcher(out: &mut String) {
    out.push_str(
        r#"fn calc_normal(p: vec3<f32>) -> vec3<f32> {
let e = vec2<f32>(0.0005 * u.scene_scale, -0.0005 * u.scene_scale);
return normalize(
e.xyy * scene_sdf(p + e.xyy) +
e.yyx * scene_sdf(p + e.yyx) +
e.yxy * scene_sdf(p + e.yxy) +
e.xxx * scene_sdf(p + e.xxx)
);
}
fn soft_shadow(ro: vec3<f32>, rd: vec3<f32>, mint: f32, maxt: f32, k: f32) -> f32 {
let eps = 0.0002 * u.scene_scale;
let step_lo = 0.001 * u.scene_scale;
let step_hi = 0.5 * u.scene_scale;
var res = 1.0;
var t = mint;
var prev_d = 1e10;
for (var i = 0; i < 64; i++) {
let d = scene_sdf(ro + rd * t);
if d < eps { return 0.0; }
let y = d * d / (2.0 * prev_d);
let x = sqrt(max(d * d - y * y, 0.0));
res = min(res, k * x / max(t - y, 0.0001));
prev_d = d;
t += clamp(d, step_lo, step_hi);
if t > maxt { break; }
}
return clamp(res, 0.0, 1.0);
}
fn ao(p: vec3<f32>, n: vec3<f32>) -> f32 {
let s = u.scene_scale;
var occ = 0.0;
var w = 1.0;
for (var i = 0; i < 5; i++) {
let h = (0.01 + 0.12 * f32(i)) * s;
let d = scene_sdf(p + n * h);
occ += (h - d) * w;
w *= 0.7;
}
return clamp(1.0 - 1.5 * occ / s, 0.0, 1.0);
}
fn grid_aa(x: f32, line_w: f32) -> f32 {
// Distance from x to the nearest integer, so lines sit on whole cells and
// pass through the world origin (where the axis lines are drawn).
let d = abs(fract(x - 0.5) - 0.5);
let fw = fwidth(x);
return smoothstep(0.0, max(fw * 1.5, 0.001), d - line_w);
}
fn ground_plane(ro: vec3<f32>, rd: vec3<f32>) -> vec4<f32> {
if rd.z >= 0.0 { return vec4<f32>(0.0); }
let t = -ro.z / rd.z;
let max_ground = u.scene_scale * 50.0;
if t < 0.0 || t > max_ground { return vec4<f32>(0.0); }
let p = ro + rd * t;
let gs = max(u.scene_scale * 0.5, 1.0);
let gp = p.xy / gs;
// Minor grid (every cell)
let minor = grid_aa(gp.x, 0.02) * grid_aa(gp.y, 0.02);
// Major grid (every 5 cells)
let major = grid_aa(gp.x / 5.0, 0.04) * grid_aa(gp.y / 5.0, 0.04);
// Axis lines at world origin
let aw = gs * 0.08;
let afw_x = fwidth(p.x);
let afw_y = fwidth(p.y);
let ax = 1.0 - smoothstep(aw, aw + max(afw_y * 1.5, 0.001), abs(p.y));
let ay = 1.0 - smoothstep(aw, aw + max(afw_x * 1.5, 0.001), abs(p.x));
let fade_k = 0.3 / (u.scene_scale * u.scene_scale);
let fade = exp(-fade_k * t * t);
let base = vec3<f32>(0.22, 0.24, 0.28);
var col = mix(vec3<f32>(0.30, 0.32, 0.36), base, minor);
col = mix(vec3<f32>(0.38, 0.40, 0.44), col, major);
col = mix(vec3<f32>(0.55, 0.18, 0.18), col, ax * fade);
col = mix(vec3<f32>(0.18, 0.45, 0.18), col, ay * fade);
let sky_horizon = vec3<f32>(0.25, 0.35, 0.50);
col = mix(sky_horizon, col, fade);
let shad_max = u.scene_scale * 10.0;
let shad = soft_shadow(vec3<f32>(p.x, p.y, 0.001 * u.scene_scale), normalize(vec3<f32>(0.5, 0.8, 1.0)), 0.1 * u.scene_scale, shad_max, 8.0);
let shad_faded = mix(1.0, 0.5 + 0.5 * shad, fade);
return vec4<f32>(col * shad_faded, t);
}
fn get_camera_ray(uv: vec2<f32>) -> vec3<f32> {
let forward = normalize(u.camera_target - u.camera_pos);
let right = normalize(cross(forward, vec3<f32>(0.0, 0.0, 1.0)));
let up = cross(right, forward);
return normalize(forward * u.fov + right * uv.x - up * uv.y);
}
fn shade_ray(ro: vec3<f32>, rd: vec3<f32>) -> vec3<f32> {
let hit_eps = 0.0005 * u.scene_scale;
let max_t = u.scene_scale * 20.0;
var t = 0.0;
var min_d = 1e10;
var t_min = 0.0;
var hit = false;
for (var i = 0; i < 128; i++) {
let p = ro + rd * t;
let d = scene_sdf(p);
if d < min_d {
min_d = d;
t_min = t;
}
if d < hit_eps {
hit = true;
break;
}
t += d;
if t > max_t { break; }
}
var bg = vec3<f32>(0.0);
let gp = ground_plane(ro, rd);
if gp.w > 0.0 && u.render_flags.z > 0.5 {
bg = gp.xyz;
} else {
bg = mix(vec3<f32>(0.25, 0.35, 0.50), vec3<f32>(0.05, 0.12, 0.35), clamp(rd.z * 1.5, 0.0, 1.0));
}
bg = pow(bg, vec3<f32>(0.4545));
// SDF-derived coverage — analytical AA from the distance field
let pix = max(t_min, 0.001) / u.resolution.y;
var coverage: f32;
if hit {
coverage = 1.0;
} else {
coverage = 1.0 - smoothstep(0.0, pix * 2.0, min_d);
}
if coverage < 0.005 { return bg; }
let shade_t = select(t_min, t, hit);
let p = ro + rd * shade_t;
let n = calc_normal(p);
let light_dir = normalize(vec3<f32>(0.5, 0.8, 1.0));
let diff = max(dot(n, light_dir), 0.0);
let shad = mix(1.0, soft_shadow(p + n * 0.002 * u.scene_scale, light_dir, 0.02 * u.scene_scale, 15.0 * u.scene_scale, 8.0), u.render_flags.x);
let occ = mix(1.0, ao(p, n), u.render_flags.y);
let half_v = normalize(light_dir - rd);
let spec = pow(max(dot(n, half_v), 0.0), 48.0) * 0.5;
let sky_light = clamp(0.5 + 0.5 * n.z, 0.0, 1.0);
let bounce = clamp(0.5 - 0.5 * n.z, 0.0, 1.0);
let base = vec3<f32>(0.65, 0.67, 0.72);
var lin = vec3<f32>(0.0);
lin += diff * shad * vec3<f32>(1.0, 0.97, 0.9) * 1.5;
lin += sky_light * occ * vec3<f32>(0.30, 0.40, 0.60) * 0.7;
lin += bounce * occ * vec3<f32>(0.15, 0.12, 0.1) * 0.5;
var color = base * lin + spec * shad;
color = color / (color + vec3<f32>(1.0));
color = pow(color, vec3<f32>(0.4545));
return mix(bg, color, coverage);
}
@fragment
fn fs_main(@builtin(position) frag_coord: vec4<f32>) -> @location(0) vec4<f32> {
let ro = u.camera_pos;
let px = 1.0 / u.resolution.y;
let uv = (frag_coord.xy - u.viewport_offset - u.resolution * 0.5) * px;
let rd = get_camera_ray(uv);
return vec4<f32>(shade_ray(ro, rd), 1.0);
}
struct VsOutput {
@builtin(position) position: vec4<f32>,
};
@vertex
fn vs_main(@builtin(vertex_index) idx: u32) -> VsOutput {
var pos = array<vec2<f32>, 3>(
vec2<f32>(-1.0, -1.0),
vec2<f32>(3.0, -1.0),
vec2<f32>(-1.0, 3.0),
);
var out: VsOutput;
out.position = vec4<f32>(pos[idx], 0.0, 1.0);
return out;
}
"#,
    );
}