1// Copyright 2014 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package runner
6
7import (
8	"encoding/binary"
9	"fmt"
10	"io"
11	"math"
12	"net"
13	"slices"
14	"time"
15)
16
// Opcodes of the framed stream exchanged with the shim. Every message starts
// with one of these single-byte opcodes; any multi-byte field that follows is
// encoded big-endian.
const (
	// opcodePacket signals a packet, encoded with a 32-bit length prefix,
	// followed by the payload.
	opcodePacket = byte('P')

	// opcodeTimeout signals a read timeout, encoded by a 64-bit number of
	// nanoseconds. On receipt, the peer should reply with opcodeTimeoutAck.
	// opcodeTimeout may only be sent by the Go side.
	opcodeTimeout = byte('T')

	// opcodeTimeoutAck acknowledges a read timeout. It has no payload and may
	// only be sent by the C side. Timeout ACKs act as a synchronization point
	// at the timeout, bracketing one flight of messages from C.
	opcodeTimeoutAck = byte('t')

	// opcodeMTU updates the shim's MTU, encoded as a 32-bit number of bytes.
	opcodeMTU = byte('M')

	// opcodeExpectNextTimeout tells the shim which timeout it should report
	// to the calling application. The timeout is encoded as for opcodeTimeout,
	// except that MaxUint64 means there should be no timeout at all.
	opcodeExpectNextTimeout = byte('E')
)
38
// packetAdaptor carries opcode-tagged messages (packets and control commands)
// over the embedded stream net.Conn. newPacketAdaptor constructs it, and the
// opcode* constants define the wire encoding.
type packetAdaptor struct {
	net.Conn
	// debug, when non-nil, receives diagnostic messages through log.
	// newPacketAdaptor leaves it nil; presumably callers assign it to enable
	// logging — TODO confirm against call sites.
	debug *recordingConn
}
43
44// newPacketAdaptor wraps a reliable streaming net.Conn into a reliable
45// packet-based net.Conn. The stream contains packets and control commands,
46// distinguished by a one byte opcode.
47func newPacketAdaptor(conn net.Conn) *packetAdaptor {
48	return &packetAdaptor{conn, nil}
49}
50
51func (p *packetAdaptor) log(message string, data []byte) {
52	if p.debug == nil {
53		return
54	}
55
56	p.debug.LogSpecial(message, data)
57}
58
59func (p *packetAdaptor) readOpcode() (byte, error) {
60	out := make([]byte, 1)
61	if _, err := io.ReadFull(p.Conn, out); err != nil {
62		return 0, err
63	}
64	return out[0], nil
65}
66
67func (p *packetAdaptor) readPacketBody() ([]byte, error) {
68	var length uint32
69	if err := binary.Read(p.Conn, binary.BigEndian, &length); err != nil {
70		return nil, err
71	}
72	out := make([]byte, length)
73	if _, err := io.ReadFull(p.Conn, out); err != nil {
74		return nil, err
75	}
76	return out, nil
77}
78
79func (p *packetAdaptor) Read(b []byte) (int, error) {
80	opcode, err := p.readOpcode()
81	if err != nil {
82		return 0, err
83	}
84	if opcode != opcodePacket {
85		return 0, fmt.Errorf("unexpected opcode '%d'", opcode)
86	}
87	out, err := p.readPacketBody()
88	if err != nil {
89		return 0, err
90	}
91	return copy(b, out), nil
92}
93
94func (p *packetAdaptor) Write(b []byte) (int, error) {
95	payload := make([]byte, 1+4+len(b))
96	payload[0] = opcodePacket
97	binary.BigEndian.PutUint32(payload[1:5], uint32(len(b)))
98	copy(payload[5:], b)
99	if _, err := p.Conn.Write(payload); err != nil {
100		return 0, err
101	}
102	return len(b), nil
103}
104
105// SendReadTimeout instructs the peer to simulate a read timeout. It then waits
106// for acknowledgement of the timeout, buffering any packets received since
107// then. The packets are then returned.
108func (p *packetAdaptor) SendReadTimeout(d time.Duration) ([][]byte, error) {
109	p.log("Simulating read timeout: "+d.String(), nil)
110
111	payload := make([]byte, 1+8)
112	payload[0] = opcodeTimeout
113	binary.BigEndian.PutUint64(payload[1:], uint64(d.Nanoseconds()))
114	if _, err := p.Conn.Write(payload); err != nil {
115		return nil, err
116	}
117
118	var packets [][]byte
119	for {
120		opcode, err := p.readOpcode()
121		if err != nil {
122			return nil, err
123		}
124		switch opcode {
125		case opcodeTimeoutAck:
126			p.log("Received timeout ACK", nil)
127			// Done! Return the packets buffered and continue.
128			return packets, nil
129		case opcodePacket:
130			// Buffer the packet for the caller to process.
131			packet, err := p.readPacketBody()
132			if err != nil {
133				return nil, err
134			}
135			p.log("Simulating dropped packet", packet)
136			packets = append(packets, packet)
137		default:
138			return nil, fmt.Errorf("unexpected opcode '%d'", opcode)
139		}
140	}
141}
142
143// SetPeerMTU instructs the peer to set the MTU to the specified value.
144func (p *packetAdaptor) SetPeerMTU(mtu int) error {
145	p.log(fmt.Sprintf("Setting MTU to %d", mtu), nil)
146
147	payload := make([]byte, 1+4)
148	payload[0] = opcodeMTU
149	binary.BigEndian.PutUint32(payload[1:], uint32(mtu))
150	_, err := p.Conn.Write(payload)
151	return err
152}
153
154// ExpectNextTimeout indicates the peer's next timeout should be d from now.
155func (p *packetAdaptor) ExpectNextTimeout(d time.Duration) error {
156	payload := make([]byte, 1+8)
157	payload[0] = opcodeExpectNextTimeout
158	binary.BigEndian.PutUint64(payload[1:], uint64(d.Nanoseconds()))
159	_, err := p.Conn.Write(payload)
160	return err
161}
162
163// ExpectNoNext indicates the peer should not have a next timeout.
164func (p *packetAdaptor) ExpectNoNextTimeout() error {
165	payload := make([]byte, 1+8)
166	payload[0] = opcodeExpectNextTimeout
167	binary.BigEndian.PutUint64(payload[1:], math.MaxUint64)
168	_, err := p.Conn.Write(payload)
169	return err
170}
171
// replayAdaptor is a net.Conn wrapper whose Write sends each packet and then
// re-sends the packet from the previous Write.
type replayAdaptor struct {
	net.Conn
	// prevWrite holds a private copy of the most recent packet passed to
	// Write; it stays nil until a non-empty packet has been written.
	prevWrite []byte
}

// newReplayAdaptor wraps a packeted net.Conn. The result writes each packet
// and then, after it, replays the packet from the previous Write.
func newReplayAdaptor(conn net.Conn) net.Conn {
	adaptor := new(replayAdaptor)
	adaptor.Conn = conn
	return adaptor
}

// Write sends b, then re-sends the packet saved by the previous call (if any),
// and finally saves a copy of b for the next call. The returned count and
// error are those of writing b.
func (r *replayAdaptor) Write(b []byte) (int, error) {
	written, writeErr := r.Conn.Write(b)

	// Re-send the previous packet. Its result is intentionally discarded: the
	// caller only learns the outcome of writing b.
	if previous := r.prevWrite; previous != nil {
		_, _ = r.Conn.Write(previous)
	}

	// Keep a private copy, reusing the old buffer's storage; under the
	// io.Writer contract the caller may modify b once Write returns.
	r.prevWrite = append(r.prevWrite[:0], b...)
	return written, writeErr
}
196
// damageAdaptor is a net.Conn wrapper that, while damage is set, corrupts the
// final byte of every non-empty Write.
type damageAdaptor struct {
	net.Conn
	damage bool
}

// newDamageAdaptor wraps a packeted net.Conn. The result can optionally damage
// the final byte of every Write() call; damage starts disabled.
func newDamageAdaptor(conn net.Conn) *damageAdaptor {
	adaptor := &damageAdaptor{}
	adaptor.Conn = conn
	return adaptor
}

// setDamage turns corruption of outgoing writes on or off.
func (d *damageAdaptor) setDamage(enabled bool) {
	d.damage = enabled
}

// Write forwards b to the wrapped Conn. When damage is enabled and b is
// non-empty, it sends a copy of b whose last byte is incremented (wrapping
// 0xff to 0x00); the caller's slice is never modified.
func (d *damageAdaptor) Write(b []byte) (int, error) {
	if !d.damage || len(b) == 0 {
		return d.Conn.Write(b)
	}
	corrupted := slices.Clone(b)
	corrupted[len(corrupted)-1]++
	return d.Conn.Write(corrupted)
}
219