1# MicroPython uasyncio module
2# MIT license; Copyright (c) 2019-2020 Damien P. George
3
4from . import core
5
6
7class Stream:
8    def __init__(self, s, e={}):
9        self.s = s
10        self.e = e
11        self.out_buf = b""
12
13    def get_extra_info(self, v):
14        return self.e[v]
15
16    async def __aenter__(self):
17        return self
18
19    async def __aexit__(self, exc_type, exc, tb):
20        await self.close()
21
22    def close(self):
23        pass
24
25    async def wait_closed(self):
26        # TODO yield?
27        self.s.close()
28
29    async def read(self, n):
30        yield core._io_queue.queue_read(self.s)
31        return self.s.read(n)
32
33    async def readinto(self, buf):
34        yield core._io_queue.queue_read(self.s)
35        return self.s.readinto(buf)
36
37    async def readexactly(self, n):
38        r = b""
39        while n:
40            yield core._io_queue.queue_read(self.s)
41            r2 = self.s.read(n)
42            if r2 is not None:
43                if not len(r2):
44                    raise EOFError
45                r += r2
46                n -= len(r2)
47        return r
48
49    async def readline(self):
50        l = b""
51        while True:
52            yield core._io_queue.queue_read(self.s)
53            l2 = self.s.readline()  # may do multiple reads but won't block
54            l += l2
55            if not l2 or l[-1] == 10:  # \n (check l in case l2 is str)
56                return l
57
58    def write(self, buf):
59        self.out_buf += buf
60
61    async def drain(self):
62        mv = memoryview(self.out_buf)
63        off = 0
64        while off < len(mv):
65            yield core._io_queue.queue_write(self.s)
66            ret = self.s.write(mv[off:])
67            if ret is not None:
68                off += ret
69        self.out_buf = b""
70
71
# A single Stream class plays both the reader and the writer role, which keeps
# code size down: StreamReader and StreamWriter are just aliases for it.
StreamReader = StreamWriter = Stream
75
76
77# Create a TCP stream connection to a remote host
78async def open_connection(host, port):
79    from uerrno import EINPROGRESS
80    import usocket as socket
81
82    ai = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)[0]  # TODO this is blocking!
83    s = socket.socket(ai[0], ai[1], ai[2])
84    s.setblocking(False)
85    ss = Stream(s)
86    try:
87        s.connect(ai[-1])
88    except OSError as er:
89        if er.errno != EINPROGRESS:
90            raise er
91    yield core._io_queue.queue_write(s)
92    return ss, ss
93
94
95# Class representing a TCP stream server, can be closed and used in "async with"
class Server:
    # Handle for a listening socket's accept loop.  The `task` attribute is
    # assigned externally after construction (start_server() does this).

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Leaving "async with" stops the accept loop and waits for it to end.
        self.close()
        await self.wait_closed()

    def close(self):
        # Request shutdown by cancelling the accept-loop task; _serve() closes
        # the listening socket when it sees the cancellation.
        self.task.cancel()

    async def wait_closed(self):
        # Wait for the accept-loop task to finish.
        await self.task

    async def _serve(self, s, cb):
        # Accept loop: wait for the listening socket `s` to become readable,
        # accept the connection, wrap it in a non-blocking Stream and run
        # cb(reader, writer) as a new task.
        while True:
            try:
                yield core._io_queue.queue_read(s)
            except core.CancelledError:
                # Shutdown server
                s.close()
                return
            try:
                s2, addr = s.accept()
            except Exception:
                # Ignore a failed accept (best effort), but let BaseException-only
                # signals such as KeyboardInterrupt propagate.
                continue
            s2.setblocking(False)
            s2s = Stream(s2, {"peername": addr})
            core.create_task(cb(s2s, s2s))
127
128
129# Helper function to start a TCP stream server, running as a new task
130# TODO could use an accept-callback on socket read activity instead of creating a task
async def start_server(cb, host, port, backlog=5):
    # Start a TCP server listening on (host, port).  For each accepted
    # connection, cb(reader, writer) is run as a new task.  Returns a Server
    # whose `task` runs the accept loop.
    import usocket as socket

    # Resolve the address the same way open_connection() does, so the socket
    # is created with the resolved family/type/proto (not just the default
    # family).  `ai` avoids shadowing the `host` parameter.
    ai = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)[0]  # TODO this is blocking!
    s = socket.socket(ai[0], ai[1], ai[2])
    try:
        s.setblocking(False)
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(ai[-1])
        s.listen(backlog)
    except BaseException:
        # Don't leak the socket if setup fails (e.g. address already in use).
        s.close()
        raise

    # Create and return server object and task.
    srv = Server()
    srv.task = core.create_task(srv._serve(s, cb))
    return srv
146
147
148################################################################################
149# Legacy uasyncio compatibility
150
151
async def stream_awrite(self, buf, off=0, sz=-1):
    # Legacy uasyncio API: queue a slice of buf and wait until it is flushed.
    # With the defaults the whole buffer is written as-is; otherwise a
    # zero-copy memoryview of buf[off:off + sz] is used, where sz == -1 means
    # "use len(buf) as the byte count".
    sliced = off != 0 or sz != -1
    if sliced:
        view = memoryview(buf)
        count = len(view) if sz == -1 else sz
        buf = view[off : off + count]
    self.write(buf)
    await self.drain()
160
161
# Map the legacy uasyncio method names onto the current Stream API.
Stream.aclose = Stream.wait_closed
Stream.awrite = Stream.awritestr = stream_awrite  # TODO awritestr: explicitly convert to bytes?
165