# A simple ray tracer
# MIT license; Copyright (c) 2019 Damien P. George

INF = 1e30  # distance returned by intersect() when a ray misses
EPS = 1e-6  # minimum hit distance / origin offset used to avoid self-intersection


class Vec:
    """A 3-component vector; also used (via the RGB alias) for colours."""

    def __init__(self, x, y, z):
        self.x, self.y, self.z = x, y, z

    def __repr__(self):
        return "Vec(%r, %r, %r)" % (self.x, self.y, self.z)

    def __neg__(self):
        return Vec(-self.x, -self.y, -self.z)

    def __add__(self, rhs):
        return Vec(self.x + rhs.x, self.y + rhs.y, self.z + rhs.z)

    def __sub__(self, rhs):
        return Vec(self.x - rhs.x, self.y - rhs.y, self.z - rhs.z)

    def __mul__(self, rhs):
        """Scale every component by the scalar ``rhs``."""
        return Vec(self.x * rhs, self.y * rhs, self.z * rhs)

    def length(self):
        return (self.x ** 2 + self.y ** 2 + self.z ** 2) ** 0.5

    def normalise(self):
        """Return a unit-length copy (raises ZeroDivisionError for the zero vector)."""
        l = self.length()
        return Vec(self.x / l, self.y / l, self.z / l)

    def dot(self, rhs):
        return self.x * rhs.x + self.y * rhs.y + self.z * rhs.z


RGB = Vec


class Ray:
    """A ray with origin ``p`` and direction ``d``."""

    def __init__(self, p, d):
        self.p, self.d = p, d


class View:
    """Camera: a view window of ``width`` x ``height`` at distance ``depth``
    from ``pos``, oriented by the basis vectors ``xdir``/``ydir``/``zdir``."""

    def __init__(self, width, height, depth, pos, xdir, ydir, zdir):
        self.width = width
        self.height = height
        self.depth = depth
        self.pos = pos
        self.xdir = xdir
        self.ydir = ydir
        self.zdir = zdir

    def calc_dir(self, dx, dy):
        """Return the unit direction from ``pos`` through view-window offset (dx, dy)."""
        return (self.xdir * dx + self.ydir * dy + self.zdir * self.depth).normalise()


class Light:
    """A point light; when ``casts_shadows`` is false, objects do not occlude it."""

    def __init__(self, pos, colour, casts_shadows):
        self.pos = pos
        self.colour = colour
        self.casts_shadows = casts_shadows


class Surface:
    """Material coefficients used by trace_ray().

    diffuse/specular scale the light contribution, spec_idx is the specular
    exponent, reflect/transp weight the recursive reflected/transmitted rays,
    and transp also attenuates shadow rays in trace_to_light().
    """

    def __init__(self, diffuse, specular, spec_idx, reflect, transp, colour):
        self.diffuse = diffuse
        self.specular = specular
        self.spec_idx = spec_idx
        self.reflect = reflect
        self.transp = transp
        self.colour = colour

    @staticmethod
    def dull(colour):
        return Surface(0.7, 0.0, 1, 0.0, 0.0, colour * 0.6)

    @staticmethod
    def shiny(colour):
        return Surface(0.2, 0.9, 32, 0.8, 0.0, colour * 0.3)

    @staticmethod
    def transparent(colour):
        return Surface(0.2, 0.9, 32, 0.0, 0.8, colour * 0.3)


class Sphere:
    def __init__(self, surface, centre, radius):
        self.surface = surface
        self.centre = centre
        self.radsq = radius ** 2

    def intersect(self, ray):
        """Return the distance to the nearest hit beyond EPS, or INF on a miss.

        The simplified quadratic below assumes ``ray.d`` has unit length.
        """
        v = self.centre - ray.p
        b = v.dot(ray.d)
        det = b ** 2 - v.dot(v) + self.radsq
        if det > 0:
            det **= 0.5
            t1 = b - det
            if t1 > EPS:
                return t1
            # Near root is behind the origin (e.g. ray starts inside the sphere)
            t2 = b + det
            if t2 > EPS:
                return t2
        return INF

    def surface_at(self, v):
        """Return (surface, outward unit normal) at point ``v`` on the sphere."""
        return self.surface, (v - self.centre).normalise()


class Plane:
    def __init__(self, surface, centre, normal):
        self.surface = surface
        self.normal = normal.normalise()
        # Plane offset must use the *normalised* normal, matching intersect();
        # using the raw normal misplaces the plane when |normal| != 1.
        self.cdotn = centre.dot(self.normal)

    def intersect(self, ray):
        """Return the positive distance along ``ray`` to the plane, or INF."""
        ddotn = ray.d.dot(self.normal)
        if abs(ddotn) > EPS:  # skip rays (nearly) parallel to the plane
            t = (self.cdotn - ray.p.dot(self.normal)) / ddotn
            if t > 0:
                return t
        return INF

    def surface_at(self, p):
        return self.surface, self.normal


class Scene:
    """Ambient light level, a single Light, and the list of objects."""

    def __init__(self, ambient, light, objs):
        self.ambient = ambient
        self.light = light
        self.objs = objs


def trace_scene(canvas, view, scene, max_depth):
    """Render ``scene`` through ``view`` into every pixel of ``canvas``."""
    for v in range(canvas.height):
        # Map pixel row to a view-window y coordinate (row 0 is the top)
        y = (-v + 0.5 * (canvas.height - 1)) * view.height / canvas.height
        for u in range(canvas.width):
            x = (u - 0.5 * (canvas.width - 1)) * view.width / canvas.width
            ray = Ray(view.pos, view.calc_dir(x, y))
            c = trace_ray(scene, ray, max_depth)
            canvas.put_pix(u, v, c)


def trace_ray(scene, ray, depth):
    """Return the colour seen along ``ray``; black if nothing is hit.

    ``depth`` is the number of further reflection/transparency bounces allowed.
    """
    # Find closest intersecting object
    hit_t = INF
    hit_obj = None
    for obj in scene.objs:
        t = obj.intersect(ray)
        if t < hit_t:
            hit_t = t
            hit_obj = obj

    # Check if any objects hit
    if hit_obj is None:
        return RGB(0, 0, 0)

    # Compute location of ray intersection
    point = ray.p + ray.d * hit_t
    surf, surf_norm = hit_obj.surface_at(point)
    # Flip the normal to face the incoming ray so either side can be lit
    if ray.d.dot(surf_norm) > 0:
        surf_norm = -surf_norm

    # Compute reflected ray
    reflected = ray.d - surf_norm * (surf_norm.dot(ray.d) * 2)

    # Ambient light
    col = surf.colour * scene.ambient

    # Diffuse, specular and shadow from light source
    light_vec = scene.light.pos - point
    light_dist = light_vec.length()
    light_vec = light_vec.normalise()
    ndotl = surf_norm.dot(light_vec)
    ldotv = light_vec.dot(reflected)
    if ndotl > 0 or ldotv > 0:
        # Nudge the origin by EPS so the shadow ray does not re-hit this surface
        light_ray = Ray(point + light_vec * EPS, light_vec)
        light_col = trace_to_light(scene, light_ray, light_dist)
        if ndotl > 0:
            col += light_col * surf.diffuse * ndotl
        if ldotv > 0:
            col += light_col * surf.specular * ldotv ** surf.spec_idx

    # Reflections
    if depth > 0 and surf.reflect > 0:
        col += trace_ray(scene, Ray(point + reflected * EPS, reflected), depth - 1) * surf.reflect

    # Transparency
    if depth > 0 and surf.transp > 0:
        col += trace_ray(scene, Ray(point + ray.d * EPS, ray.d), depth - 1) * surf.transp

    return col


def trace_to_light(scene, ray, light_dist):
    """Return the light colour reaching ``ray.p`` along ``ray``.

    When the light casts shadows, its colour is attenuated by the
    transparency of every object hit closer than ``light_dist``; otherwise
    the light colour is returned unchanged.
    """
    col = scene.light.colour
    if not scene.light.casts_shadows:
        return col
    for obj in scene.objs:
        t = obj.intersect(ray)
        if t < light_dist:
            col *= obj.surface.transp
    return col


class Canvas:
    """A width x height image stored as packed 8-bit RGB bytes, row-major."""

    def __init__(self, width, height):
        self.width = width
        self.height = height
        self.data = bytearray(3 * width * height)

    def put_pix(self, x, y, c):
        """Store colour ``c`` (components nominally 0..1) at (x, y), clamped to 0..255."""
        off = 3 * (y * self.width + x)
        self.data[off] = min(255, max(0, int(255 * c.x)))
        self.data[off + 1] = min(255, max(0, int(255 * c.y)))
        self.data[off + 2] = min(255, max(0, int(255 * c.z)))

    def write_ppm(self, filename):
        """Write the image as a binary (P6) PPM file."""
        with open(filename, "wb") as f:
            f.write(bytes("P6 %d %d 255\n" % (self.width, self.height), "ascii"))
            f.write(self.data)


def main(w, h, d):
    """Render the built-in demo scene at ``w`` x ``h`` pixels with bounce depth ``d``."""
    canvas = Canvas(w, h)
    view = View(32, 32, 64, Vec(0, 0, 50), Vec(1, 0, 0), Vec(0, 1, 0), Vec(0, 0, -1))
    scene = Scene(
        0.5,
        Light(Vec(0, 8, 0), RGB(1, 1, 1), True),
        [
            Plane(Surface.dull(RGB(1, 0, 0)), Vec(-10, 0, 0), Vec(1, 0, 0)),
            Plane(Surface.dull(RGB(0, 1, 0)), Vec(10, 0, 0), Vec(-1, 0, 0)),
            Plane(Surface.dull(RGB(1, 1, 1)), Vec(0, 0, -10), Vec(0, 0, 1)),
            Plane(Surface.dull(RGB(1, 1, 1)), Vec(0, -10, 0), Vec(0, 1, 0)),
            Plane(Surface.dull(RGB(1, 1, 1)), Vec(0, 10, 0), Vec(0, -1, 0)),
            Sphere(Surface.shiny(RGB(1, 1, 1)), Vec(-5, -4, 3), 4),
            Sphere(Surface.dull(RGB(0, 0, 1)), Vec(4, -5, 0), 4),
            Sphere(Surface.transparent(RGB(0.2, 0.2, 0.2)), Vec(6, -1, 8), 4),
        ],
    )
    trace_scene(canvas, view, scene, d)
    return canvas


# For testing
# main(256, 256, 4).write_ppm('rt.ppm')

###########################################################################
# Benchmark interface

bm_params = {
    (100, 100): (5, 5, 2),
    (1000, 100): (18, 18, 3),
    (5000, 100): (40, 40, 3),
}


def bm_setup(params):
    # Returns (run, result): run() renders; result() reports the work size
    return lambda: main(*params), lambda: (params[0] * params[1] * params[2], None)