1// Copyright 2014 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package runner 6 7import ( 8 "encoding/binary" 9 "fmt" 10 "io" 11 "math" 12 "net" 13 "slices" 14 "time" 15) 16 17// opcodePacket signals a packet, encoded with a 32-bit length prefix, followed 18// by the payload. 19const opcodePacket = byte('P') 20 21// opcodeTimeout signals a read timeout, encoded by a 64-bit number of 22// nanoseconds. On receipt, the peer should reply with 23// opcodeTimeoutAck. opcodeTimeout may only be sent by the Go side. 24const opcodeTimeout = byte('T') 25 26// opcodeTimeoutAck acknowledges a read timeout. This opcode has no payload and 27// may only be sent by the C side. Timeout ACKs act as a synchronization point 28// at the timeout, to bracket one flight of messages from C. 29const opcodeTimeoutAck = byte('t') 30 31// opcodeMTU updates the shim's MTU, encoded as a 32-bit number of bytes. 32const opcodeMTU = byte('M') 33 34// opcodeExpectNextTimeout indicates that the shim should report a specified timeout 35// to the calling application. The timeout is encoded as in opcodeTimeout, but 36// MaxUint64 indicates there should be no timeout. 37const opcodeExpectNextTimeout = byte('E') 38 39type packetAdaptor struct { 40 net.Conn 41 debug *recordingConn 42} 43 44// newPacketAdaptor wraps a reliable streaming net.Conn into a reliable 45// packet-based net.Conn. The stream contains packets and control commands, 46// distinguished by a one byte opcode. 47func newPacketAdaptor(conn net.Conn) *packetAdaptor { 48 return &packetAdaptor{conn, nil} 49} 50 51func (p *packetAdaptor) log(message string, data []byte) { 52 if p.debug == nil { 53 return 54 } 55 56 p.debug.LogSpecial(message, data) 57} 58 59func (p *packetAdaptor) readOpcode() (byte, error) { 60 out := make([]byte, 1) 61 if _, err := io.ReadFull(p.Conn, out); err != nil { 62 return 0, err 63 } 64 return out[0], nil 65} 66 67func (p *packetAdaptor) readPacketBody() ([]byte, error) { 68 var length uint32 69 if err := binary.Read(p.Conn, binary.BigEndian, &length); err != nil { 70 return nil, err 71 } 72 out := make([]byte, length) 73 if _, err := io.ReadFull(p.Conn, out); err != nil { 74 return nil, err 75 } 76 return out, nil 77} 78 79func (p *packetAdaptor) Read(b []byte) (int, error) { 80 opcode, err := p.readOpcode() 81 if err != nil { 82 return 0, err 83 } 84 if opcode != opcodePacket { 85 return 0, fmt.Errorf("unexpected opcode '%d'", opcode) 86 } 87 out, err := p.readPacketBody() 88 if err != nil { 89 return 0, err 90 } 91 return copy(b, out), nil 92} 93 94func (p *packetAdaptor) Write(b []byte) (int, error) { 95 payload := make([]byte, 1+4+len(b)) 96 payload[0] = opcodePacket 97 binary.BigEndian.PutUint32(payload[1:5], uint32(len(b))) 98 copy(payload[5:], b) 99 if _, err := p.Conn.Write(payload); err != nil { 100 return 0, err 101 } 102 return len(b), nil 103} 104 105// SendReadTimeout instructs the peer to simulate a read timeout. It then waits 106// for acknowledgement of the timeout, buffering any packets received since 107// then. The packets are then returned. 108func (p *packetAdaptor) SendReadTimeout(d time.Duration) ([][]byte, error) { 109 p.log("Simulating read timeout: "+d.String(), nil) 110 111 payload := make([]byte, 1+8) 112 payload[0] = opcodeTimeout 113 binary.BigEndian.PutUint64(payload[1:], uint64(d.Nanoseconds())) 114 if _, err := p.Conn.Write(payload); err != nil { 115 return nil, err 116 } 117 118 var packets [][]byte 119 for { 120 opcode, err := p.readOpcode() 121 if err != nil { 122 return nil, err 123 } 124 switch opcode { 125 case opcodeTimeoutAck: 126 p.log("Received timeout ACK", nil) 127 // Done! Return the packets buffered and continue. 128 return packets, nil 129 case opcodePacket: 130 // Buffer the packet for the caller to process. 131 packet, err := p.readPacketBody() 132 if err != nil { 133 return nil, err 134 } 135 p.log("Simulating dropped packet", packet) 136 packets = append(packets, packet) 137 default: 138 return nil, fmt.Errorf("unexpected opcode '%d'", opcode) 139 } 140 } 141} 142 143// SetPeerMTU instructs the peer to set the MTU to the specified value. 144func (p *packetAdaptor) SetPeerMTU(mtu int) error { 145 p.log(fmt.Sprintf("Setting MTU to %d", mtu), nil) 146 147 payload := make([]byte, 1+4) 148 payload[0] = opcodeMTU 149 binary.BigEndian.PutUint32(payload[1:], uint32(mtu)) 150 _, err := p.Conn.Write(payload) 151 return err 152} 153 154// ExpectNextTimeout indicates the peer's next timeout should be d from now. 155func (p *packetAdaptor) ExpectNextTimeout(d time.Duration) error { 156 payload := make([]byte, 1+8) 157 payload[0] = opcodeExpectNextTimeout 158 binary.BigEndian.PutUint64(payload[1:], uint64(d.Nanoseconds())) 159 _, err := p.Conn.Write(payload) 160 return err 161} 162 163// ExpectNoNext indicates the peer should not have a next timeout. 164func (p *packetAdaptor) ExpectNoNextTimeout() error { 165 payload := make([]byte, 1+8) 166 payload[0] = opcodeExpectNextTimeout 167 binary.BigEndian.PutUint64(payload[1:], math.MaxUint64) 168 _, err := p.Conn.Write(payload) 169 return err 170} 171 172type replayAdaptor struct { 173 net.Conn 174 prevWrite []byte 175} 176 177// newReplayAdaptor wraps a packeted net.Conn. It transforms it into 178// one which, after writing a packet, always replays the previous 179// write. 180func newReplayAdaptor(conn net.Conn) net.Conn { 181 return &replayAdaptor{Conn: conn} 182} 183 184func (r *replayAdaptor) Write(b []byte) (int, error) { 185 n, err := r.Conn.Write(b) 186 187 // Replay the previous packet and save the current one to 188 // replay next. 189 if r.prevWrite != nil { 190 r.Conn.Write(r.prevWrite) 191 } 192 r.prevWrite = append(r.prevWrite[:0], b...) 193 194 return n, err 195} 196 197type damageAdaptor struct { 198 net.Conn 199 damage bool 200} 201 202// newDamageAdaptor wraps a packeted net.Conn. It transforms it into one which 203// optionally damages the final byte of every Write() call. 204func newDamageAdaptor(conn net.Conn) *damageAdaptor { 205 return &damageAdaptor{Conn: conn} 206} 207 208func (d *damageAdaptor) setDamage(damage bool) { 209 d.damage = damage 210} 211 212func (d *damageAdaptor) Write(b []byte) (int, error) { 213 if d.damage && len(b) > 0 { 214 b = slices.Clone(b) 215 b[len(b)-1]++ 216 } 217 return d.Conn.Write(b) 218} 219