1// Copyright 2023 The BoringSSL Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package runner
16
17import (
18	"context"
19	"encoding/binary"
20	"fmt"
21	"io"
22	"net"
23	"os"
24	"sync"
25	"time"
26)
27
// shimDispatcher multiplexes connections from many shim processes over a
// single TCP listener. Each incoming connection starts with an 8-byte
// little-endian shim ID, which dispatch uses to route the connection to the
// matching shimListener.
type shimDispatcher struct {
	// lock guards nextShimID, shims, closedShims and err.
	lock        sync.Mutex
	// nextShimID is the ID NewShim will assign to the next shim.
	nextShimID  uint64
	// listener is the shared loopback listener. It is set once at
	// construction and read without holding lock.
	listener    *net.TCPListener
	// shims maps shim IDs to listeners that are still open.
	shims       map[uint64]*shimListener
	// closedShims records IDs of shims unregistered via shimListener.Close,
	// so late connections for them are rejected silently. It is never pruned.
	closedShims map[uint64]struct{}
	// err is non-nil once the dispatcher has shut down; NewShim returns it.
	err         error
}
36
37func newShimDispatcher() (*shimDispatcher, error) {
38	listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv6loopback})
39	if err != nil {
40		listener, err = net.ListenTCP("tcp4", &net.TCPAddr{IP: net.IP{127, 0, 0, 1}})
41	}
42
43	if err != nil {
44		return nil, err
45	}
46	d := &shimDispatcher{listener: listener, shims: make(map[uint64]*shimListener), closedShims: make(map[uint64]struct{})}
47	go d.acceptLoop()
48	return d, nil
49}
50
51func (d *shimDispatcher) NewShim() (*shimListener, error) {
52	d.lock.Lock()
53	defer d.lock.Unlock()
54	if d.err != nil {
55		return nil, d.err
56	}
57
58	l := &shimListener{dispatcher: d, shimID: d.nextShimID, connChan: make(chan net.Conn, 1)}
59	d.shims[l.shimID] = l
60	d.nextShimID++
61	return l, nil
62}
63
64func (d *shimDispatcher) unregisterShim(l *shimListener) {
65	d.lock.Lock()
66	delete(d.shims, l.shimID)
67	d.closedShims[l.shimID] = struct{}{}
68	d.lock.Unlock()
69}
70
71func (d *shimDispatcher) acceptLoop() {
72	for {
73		conn, err := d.listener.Accept()
74		if err != nil {
75			// Something went wrong. Shut down the listener.
76			d.closeWithError(err)
77			return
78		}
79
80		go func() {
81			if err := d.dispatch(conn); err != nil {
82				// To be robust against port scanners, etc., we log a warning
83				// but otherwise treat undispatchable connections as non-fatal.
84				fmt.Fprintf(os.Stderr, "Error dispatching connection: %s\n", err)
85				conn.Close()
86			}
87		}()
88	}
89}
90
91func (d *shimDispatcher) dispatch(conn net.Conn) error {
92	conn.SetReadDeadline(time.Now().Add(*idleTimeout))
93	var buf [8]byte
94	if _, err := io.ReadFull(conn, buf[:]); err != nil {
95		return err
96	}
97	conn.SetReadDeadline(time.Time{})
98
99	shimID := binary.LittleEndian.Uint64(buf[:])
100	d.lock.Lock()
101	shim, ok := d.shims[shimID]
102	_, closed := d.closedShims[shimID]
103	d.lock.Unlock()
104	if !ok {
105		// If the shim is known but already closed, just silently reject the
106		// connection. This may happen if runner fails the test at the shim's
107		// first connection, but the shim tries to make a second connection
108		// before it is killed.
109		if closed {
110			conn.Close()
111			return nil
112		}
113		return fmt.Errorf("shim ID %d not found", shimID)
114	}
115
116	shim.connChan <- conn
117	return nil
118}
119
120func (d *shimDispatcher) Close() error {
121	return d.closeWithError(net.ErrClosed)
122}
123
124func (d *shimDispatcher) closeWithError(err error) error {
125	closeErr := d.listener.Close()
126
127	d.lock.Lock()
128	shims := d.shims
129	d.shims = make(map[uint64]*shimListener)
130	d.err = err
131	d.lock.Unlock()
132
133	for _, shim := range shims {
134		shim.closeWithError(err)
135	}
136	return closeErr
137}
138
// shimListener receives the connections that the shared shimDispatcher
// routes to one shim ID. It is created by shimDispatcher.NewShim.
type shimListener struct {
	// dispatcher is the owning dispatcher; Port and IsIPv6 read its listener.
	dispatcher *shimDispatcher
	// shimID is the ID the shim must send first on each connection.
	shimID     uint64
	// connChan contains connections from the dispatcher. On fatal error, it is
	// closed, with the error available in err.
	connChan chan net.Conn
	// err is written once, under lock, in the same critical section that
	// closes connChan. Accept reads it only after observing that close.
	err      error
	// lock serializes closeWithError so connChan is closed at most once.
	lock     sync.Mutex
}
148
149func (l *shimListener) Port() int {
150	return l.dispatcher.listener.Addr().(*net.TCPAddr).Port
151}
152
153func (l *shimListener) IsIPv6() bool {
154	return len(l.dispatcher.listener.Addr().(*net.TCPAddr).IP) == net.IPv6len
155}
156
157func (l *shimListener) ShimID() uint64 {
158	return l.shimID
159}
160
161func (l *shimListener) Close() error {
162	l.dispatcher.unregisterShim(l)
163	l.closeWithError(net.ErrClosed)
164	return nil
165}
166
167func (l *shimListener) closeWithError(err error) {
168	// Multiple threads may close the listener at once, so protect closing with
169	// a lock.
170	l.lock.Lock()
171	if l.err == nil {
172		l.err = err
173		close(l.connChan)
174	}
175	l.lock.Unlock()
176}
177
178func (l *shimListener) Accept(deadline time.Time) (net.Conn, error) {
179	var timerChan <-chan time.Time
180	if !deadline.IsZero() {
181		remaining := time.Until(deadline)
182		if remaining < 0 {
183			return nil, context.DeadlineExceeded
184		}
185		timer := time.NewTimer(remaining)
186		defer timer.Stop()
187		timerChan = timer.C
188	}
189
190	select {
191	case <-timerChan:
192		return nil, context.DeadlineExceeded
193	case conn, ok := <-l.connChan:
194		if !ok {
195			return nil, l.err
196		}
197		return conn, nil
198	}
199}
200