1// Copyright 2023 The BoringSSL Authors 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// https://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package runner 16 17import ( 18 "context" 19 "encoding/binary" 20 "fmt" 21 "io" 22 "net" 23 "os" 24 "sync" 25 "time" 26) 27 28type shimDispatcher struct { 29 lock sync.Mutex 30 nextShimID uint64 31 listener *net.TCPListener 32 shims map[uint64]*shimListener 33 closedShims map[uint64]struct{} 34 err error 35} 36 37func newShimDispatcher() (*shimDispatcher, error) { 38 listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv6loopback}) 39 if err != nil { 40 listener, err = net.ListenTCP("tcp4", &net.TCPAddr{IP: net.IP{127, 0, 0, 1}}) 41 } 42 43 if err != nil { 44 return nil, err 45 } 46 d := &shimDispatcher{listener: listener, shims: make(map[uint64]*shimListener), closedShims: make(map[uint64]struct{})} 47 go d.acceptLoop() 48 return d, nil 49} 50 51func (d *shimDispatcher) NewShim() (*shimListener, error) { 52 d.lock.Lock() 53 defer d.lock.Unlock() 54 if d.err != nil { 55 return nil, d.err 56 } 57 58 l := &shimListener{dispatcher: d, shimID: d.nextShimID, connChan: make(chan net.Conn, 1)} 59 d.shims[l.shimID] = l 60 d.nextShimID++ 61 return l, nil 62} 63 64func (d *shimDispatcher) unregisterShim(l *shimListener) { 65 d.lock.Lock() 66 delete(d.shims, l.shimID) 67 d.closedShims[l.shimID] = struct{}{} 68 d.lock.Unlock() 69} 70 71func (d *shimDispatcher) acceptLoop() { 72 for { 73 conn, err := d.listener.Accept() 74 if err != nil { 75 // Something went wrong. Shut down the listener. 76 d.closeWithError(err) 77 return 78 } 79 80 go func() { 81 if err := d.dispatch(conn); err != nil { 82 // To be robust against port scanners, etc., we log a warning 83 // but otherwise treat undispatchable connections as non-fatal. 84 fmt.Fprintf(os.Stderr, "Error dispatching connection: %s\n", err) 85 conn.Close() 86 } 87 }() 88 } 89} 90 91func (d *shimDispatcher) dispatch(conn net.Conn) error { 92 conn.SetReadDeadline(time.Now().Add(*idleTimeout)) 93 var buf [8]byte 94 if _, err := io.ReadFull(conn, buf[:]); err != nil { 95 return err 96 } 97 conn.SetReadDeadline(time.Time{}) 98 99 shimID := binary.LittleEndian.Uint64(buf[:]) 100 d.lock.Lock() 101 shim, ok := d.shims[shimID] 102 _, closed := d.closedShims[shimID] 103 d.lock.Unlock() 104 if !ok { 105 // If the shim is known but already closed, just silently reject the 106 // connection. This may happen if runner fails the test at the shim's 107 // first connection, but the shim tries to make a second connection 108 // before it is killed. 109 if closed { 110 conn.Close() 111 return nil 112 } 113 return fmt.Errorf("shim ID %d not found", shimID) 114 } 115 116 shim.connChan <- conn 117 return nil 118} 119 120func (d *shimDispatcher) Close() error { 121 return d.closeWithError(net.ErrClosed) 122} 123 124func (d *shimDispatcher) closeWithError(err error) error { 125 closeErr := d.listener.Close() 126 127 d.lock.Lock() 128 shims := d.shims 129 d.shims = make(map[uint64]*shimListener) 130 d.err = err 131 d.lock.Unlock() 132 133 for _, shim := range shims { 134 shim.closeWithError(err) 135 } 136 return closeErr 137} 138 139type shimListener struct { 140 dispatcher *shimDispatcher 141 shimID uint64 142 // connChan contains connections from the dispatcher. On fatal error, it is 143 // closed, with the error available in err. 144 connChan chan net.Conn 145 err error 146 lock sync.Mutex 147} 148 149func (l *shimListener) Port() int { 150 return l.dispatcher.listener.Addr().(*net.TCPAddr).Port 151} 152 153func (l *shimListener) IsIPv6() bool { 154 return len(l.dispatcher.listener.Addr().(*net.TCPAddr).IP) == net.IPv6len 155} 156 157func (l *shimListener) ShimID() uint64 { 158 return l.shimID 159} 160 161func (l *shimListener) Close() error { 162 l.dispatcher.unregisterShim(l) 163 l.closeWithError(net.ErrClosed) 164 return nil 165} 166 167func (l *shimListener) closeWithError(err error) { 168 // Multiple threads may close the listener at once, so protect closing with 169 // a lock. 170 l.lock.Lock() 171 if l.err == nil { 172 l.err = err 173 close(l.connChan) 174 } 175 l.lock.Unlock() 176} 177 178func (l *shimListener) Accept(deadline time.Time) (net.Conn, error) { 179 var timerChan <-chan time.Time 180 if !deadline.IsZero() { 181 remaining := time.Until(deadline) 182 if remaining < 0 { 183 return nil, context.DeadlineExceeded 184 } 185 timer := time.NewTimer(remaining) 186 defer timer.Stop() 187 timerChan = timer.C 188 } 189 190 select { 191 case <-timerChan: 192 return nil, context.DeadlineExceeded 193 case conn, ok := <-l.connChan: 194 if !ok { 195 return nil, l.err 196 } 197 return conn, nil 198 } 199} 200