# A simple ray tracer
# MIT license; Copyright (c) 2019 Damien P. George
3
# Distance reported by intersect() methods when a ray misses an object.
INF = 1e30
# Small distance used to nudge secondary rays off a surface and to reject
# hits too close to a ray's origin (guards against self-intersection).
EPS = 1e-6
6
7
class Vec:
    """Three-component vector, also used (as ``RGB``) for colours.

    Arithmetic operators and ``normalise`` build a new Vec; none of the
    methods modify ``self``.
    """

    def __init__(self, x, y, z):
        self.x = x
        self.y = y
        self.z = z

    def __neg__(self):
        """Return the component-wise negation."""
        return Vec(-self.x, -self.y, -self.z)

    def __add__(self, rhs):
        """Return the component-wise sum with vector ``rhs``."""
        return Vec(self.x + rhs.x, self.y + rhs.y, self.z + rhs.z)

    def __sub__(self, rhs):
        """Return the component-wise difference ``self - rhs``."""
        return Vec(self.x - rhs.x, self.y - rhs.y, self.z - rhs.z)

    def __mul__(self, rhs):
        """Scale every component by the scalar ``rhs``."""
        return Vec(self.x * rhs, self.y * rhs, self.z * rhs)

    def length(self):
        """Return the Euclidean length of the vector."""
        squared = self.x ** 2 + self.y ** 2 + self.z ** 2
        return squared ** 0.5

    def normalise(self):
        """Return a unit-length copy (raises ZeroDivisionError for a zero vector)."""
        norm = self.length()
        return Vec(self.x / norm, self.y / norm, self.z / norm)

    def dot(self, rhs):
        """Return the scalar (dot) product with vector ``rhs``."""
        return self.x * rhs.x + self.y * rhs.y + self.z * rhs.z


# Colours reuse the vector type: x, y, z hold the red, green, blue channels.
RGB = Vec
36
37
class Ray:
    """Half-line defined by an origin point and a direction."""

    def __init__(self, p, d):
        self.p = p  # origin
        self.d = d  # direction; the intersection maths expects unit length
41
42
class View:
    """Pinhole camera: eye position, image-plane size/distance and axis vectors."""

    def __init__(self, width, height, depth, pos, xdir, ydir, zdir):
        # Size of the image plane in world units.
        self.width, self.height = width, height
        # Distance from the eye to the image plane along ``zdir``.
        self.depth = depth
        self.pos = pos
        self.xdir, self.ydir, self.zdir = xdir, ydir, zdir

    def calc_dir(self, dx, dy):
        """Return the unit direction from the eye through image-plane offset (dx, dy)."""
        target = self.xdir * dx + self.ydir * dy + self.zdir * self.depth
        return target.normalise()
55
56
class Light:
    """Point light source with a position and colour."""

    def __init__(self, pos, colour, casts_shadows):
        # casts_shadows: flag meant to control whether occluders shadow this
        # light. NOTE(review): confirm how the tracer consumes it.
        self.pos, self.colour, self.casts_shadows = pos, colour, casts_shadows
62
63
class Surface:
    """Material properties of an object.

    Args:
        diffuse: weight of the diffuse term (light colour times cos(angle to light)).
        specular: weight of the specular highlight term.
        spec_idx: exponent applied to the specular cosine; larger gives a
            tighter highlight.
        reflect: weight of the recursively traced mirror reflection.
        transp: weight of the ray traced straight through the surface; the
            shadow test also uses it to scale light passing through the object.
        colour: surface colour, multiplied by the scene's ambient level.
    """

    def __init__(self, diffuse, specular, spec_idx, reflect, transp, colour):
        self.diffuse = diffuse
        self.specular = specular
        self.spec_idx = spec_idx
        self.reflect = reflect
        self.transp = transp
        self.colour = colour

    # The presets are classmethods so that subclasses of Surface get instances
    # of their own type back; Surface.dull(...) etc. work exactly as before.

    @classmethod
    def dull(cls, colour):
        """Matte material: diffuse only, no highlight, reflection or transparency."""
        return cls(0.7, 0.0, 1, 0.0, 0.0, colour * 0.6)

    @classmethod
    def shiny(cls, colour):
        """Mirror-like material with a tight specular highlight."""
        return cls(0.2, 0.9, 32, 0.8, 0.0, colour * 0.3)

    @classmethod
    def transparent(cls, colour):
        """Glass-like material: highlight plus mostly see-through."""
        return cls(0.2, 0.9, 32, 0.0, 0.8, colour * 0.3)
84
85
class Sphere:
    """Sphere described by its centre and radius."""

    def __init__(self, surface, centre, radius):
        self.surface = surface
        self.centre = centre
        # The intersection test only ever needs the squared radius.
        self.radsq = radius ** 2

    def intersect(self, ray):
        """Return the nearest hit distance along ``ray`` that exceeds EPS, else INF.

        Solves |p + t*d - centre|^2 = r^2 for t, which in this form assumes
        ``ray.d`` has unit length.
        """
        to_centre = self.centre - ray.p
        b = to_centre.dot(ray.d)
        disc = b ** 2 - to_centre.dot(to_centre) + self.radsq
        if disc <= 0:
            return INF  # ray misses (or only grazes) the sphere
        root = disc ** 0.5
        # Try the near root first, then the far one (origin inside the sphere).
        for t in (b - root, b + root):
            if t > EPS:
                return t
        return INF

    def surface_at(self, v):
        """Return (surface, outward unit normal) at point ``v`` on the sphere."""
        outward = (v - self.centre).normalise()
        return self.surface, outward
108
109
class Plane:
    """Infinite plane through point ``centre`` with normal ``normal``."""

    def __init__(self, surface, centre, normal):
        self.surface = surface
        self.normal = normal.normalise()
        # Plane equation: p . normal == cdotn. Use the normalised normal so this
        # constant agrees with the one intersect() dots against. The raw
        # argument would misplace the plane whenever ``normal`` is not unit length.
        self.cdotn = centre.dot(self.normal)

    def intersect(self, ray):
        """Return the distance along ``ray`` to the plane, or INF if it misses.

        Rays (nearly) parallel to the plane count as misses, as do
        intersections at or behind the ray origin.
        """
        ddotn = ray.d.dot(self.normal)
        if abs(ddotn) <= EPS:
            return INF  # parallel or almost parallel: no usable intersection
        t = (self.cdotn - ray.p.dot(self.normal)) / ddotn
        return t if t > 0 else INF

    def surface_at(self, p):
        """Return (surface, unit normal); a plane's normal is the same everywhere."""
        return self.surface, self.normal
126
127
class Scene:
    """Everything to render: ambient light level, a single light and the objects."""

    def __init__(self, ambient, light, objs):
        self.ambient, self.light, self.objs = ambient, light, objs
133
134
def trace_scene(canvas, view, scene, max_depth):
    """Render ``scene`` through ``view`` into ``canvas``, one primary ray per pixel.

    Pixel centres map onto the view's image plane. Row 0 is the top
    (positive y) and column 0 is the left (negative x).
    """
    half_rows = 0.5 * (canvas.height - 1)
    half_cols = 0.5 * (canvas.width - 1)
    for row in range(canvas.height):
        y = (half_rows - row) * view.height / canvas.height
        for col in range(canvas.width):
            x = (col - half_cols) * view.width / canvas.width
            primary = Ray(view.pos, view.calc_dir(x, y))
            canvas.put_pix(col, row, trace_ray(scene, primary, max_depth))
143
144
def trace_ray(scene, ray, depth):
    """Return the colour seen along ``ray``.

    The colour is the sum of:
    - ambient light;
    - diffuse and specular terms from the scene light, scaled by trace_to_light;
    - while ``depth`` is positive, recursively traced reflection and
      straight-through transparency rays.
    A ray that hits nothing is black.
    """
    # Nearest object along the ray; on ties the earliest object in the list wins.
    nearest_t, nearest = INF, None
    for candidate in scene.objs:
        dist = candidate.intersect(ray)
        if dist < nearest_t:
            nearest_t, nearest = dist, candidate
    if nearest is None:
        return RGB(0, 0, 0)

    hit = ray.p + ray.d * nearest_t
    surf, normal = nearest.surface_at(hit)
    # Flip the normal so it faces back towards the incoming ray.
    if ray.d.dot(normal) > 0:
        normal = -normal
    # Mirror the incoming direction about the normal.
    mirror = ray.d - normal * (normal.dot(ray.d) * 2)

    colour = surf.colour * scene.ambient

    to_light = scene.light.pos - hit
    light_dist = to_light.length()
    to_light = to_light.normalise()
    diffuse_cos = normal.dot(to_light)
    specular_cos = to_light.dot(mirror)
    if diffuse_cos > 0 or specular_cos > 0:
        # Start the shadow ray just off the surface to avoid hitting it again.
        shadow_ray = Ray(hit + to_light * EPS, to_light)
        light_col = trace_to_light(scene, shadow_ray, light_dist)
        if diffuse_cos > 0:
            colour += light_col * surf.diffuse * diffuse_cos
        if specular_cos > 0:
            colour += light_col * surf.specular * specular_cos ** surf.spec_idx

    if depth > 0:
        if surf.reflect > 0:
            bounce = Ray(hit + mirror * EPS, mirror)
            colour += trace_ray(scene, bounce, depth - 1) * surf.reflect
        if surf.transp > 0:
            # Transparency continues the ray unbent (no refraction).
            through = Ray(hit + ray.d * EPS, ray.d)
            colour += trace_ray(scene, through, depth - 1) * surf.transp

    return colour
194
195
def trace_to_light(scene, ray, light_dist):
    """Return the light colour that reaches the origin of ``ray``.

    Args:
        scene: scene whose ``light`` and ``objs`` are used.
        ray: ray from a surface point towards the light.
        light_dist: distance from the ray origin to the light.

    Returns:
        The light's colour. If the light casts shadows, the colour is first
        multiplied by ``surface.transp`` of every object the ray hits closer
        than ``light_dist`` (so an opaque object, transp 0, blocks it fully).
    """
    col = scene.light.colour
    if not scene.light.casts_shadows:
        # Shadow-free light: occluders are ignored entirely.
        return col
    for obj in scene.objs:
        if obj.intersect(ray) < light_dist:
            col *= obj.surface.transp
    return col
203
204
class Canvas:
    """In-memory RGB image with one byte per channel, stored row by row."""

    def __init__(self, width, height):
        self.width = width
        self.height = height
        # Three bytes (R, G, B) per pixel, initially black.
        self.data = bytearray(3 * width * height)

    def put_pix(self, x, y, c):
        """Store colour ``c`` (channels nominally in [0, 1]) at pixel (x, y).

        Each channel is scaled by 255, truncated with int() and clamped to 0..255.
        """
        base = 3 * (y * self.width + x)
        buf = self.data
        buf[base] = min(255, max(0, int(255 * c.x)))
        buf[base + 1] = min(255, max(0, int(255 * c.y)))
        buf[base + 2] = min(255, max(0, int(255 * c.z)))

    def write_ppm(self, filename):
        """Write the image to ``filename`` as a binary PPM (P6) file."""
        with open(filename, "wb") as out:
            header = ("P6 %d %d 255\n" % (self.width, self.height)).encode("ascii")
            out.write(header)
            out.write(self.data)
221
222
def main(w, h, d):
    """Render the demo room scene at ``w`` x ``h`` pixels, recursion depth ``d``.

    Returns:
        The filled Canvas.
    """
    canvas = Canvas(w, h)
    # Eye at z=50 looking along -z through a 32x32 image plane 64 units away.
    view = View(32, 32, 64, Vec(0, 0, 50), Vec(1, 0, 0), Vec(0, 1, 0), Vec(0, 0, -1))
    # Box open towards the camera: red left wall, green right wall, and white
    # back wall, floor and ceiling.
    walls = [
        Plane(Surface.dull(RGB(1, 0, 0)), Vec(-10, 0, 0), Vec(1, 0, 0)),
        Plane(Surface.dull(RGB(0, 1, 0)), Vec(10, 0, 0), Vec(-1, 0, 0)),
        Plane(Surface.dull(RGB(1, 1, 1)), Vec(0, 0, -10), Vec(0, 0, 1)),
        Plane(Surface.dull(RGB(1, 1, 1)), Vec(0, -10, 0), Vec(0, 1, 0)),
        Plane(Surface.dull(RGB(1, 1, 1)), Vec(0, 10, 0), Vec(0, -1, 0)),
    ]
    # A mirror-like sphere, a matte blue sphere and a see-through sphere.
    spheres = [
        Sphere(Surface.shiny(RGB(1, 1, 1)), Vec(-5, -4, 3), 4),
        Sphere(Surface.dull(RGB(0, 0, 1)), Vec(4, -5, 0), 4),
        Sphere(Surface.transparent(RGB(0.2, 0.2, 0.2)), Vec(6, -1, 8), 4),
    ]
    light = Light(Vec(0, 8, 0), RGB(1, 1, 1), True)
    scene = Scene(0.5, light, walls + spheres)
    trace_scene(canvas, view, scene, d)
    return canvas
242
243
# For manual testing: uncomment to render a 256x256 image (max depth 4) to rt.ppm.
# main(256, 256, 4).write_ppm('rt.ppm')
246
###########################################################################
# Benchmark interface
249
# Benchmark parameter sets. Each value is the (w, h, d) argument tuple passed
# to main(). NOTE(review): the key tuples are interpreted by the external
# benchmark harness; confirm their meaning there.
bm_params = {
    (100, 100): (5, 5, 2),
    (1000, 100): (18, 18, 3),
    (5000, 100): (40, 40, 3),
}
255
256
def bm_setup(params):
    """Return a (run, result) pair of callables for the benchmark harness.

    ``run`` renders the scene via ``main(*params)``. ``result`` returns the
    product of the three parameters (the pixel-times-depth work count) paired
    with ``None``.
    """
    def run():
        return main(*params)

    def result():
        work = params[0] * params[1] * params[2]
        return work, None

    return run, result
259