1// Copyright 2021 The BoringSSL Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//go:build ignore
16
17package main
18
19import (
20	"errors"
21	"flag"
22	"fmt"
23	"log"
24	"net"
25	"os"
26	"path"
27	"strings"
28
29	"golang.org/x/crypto/cryptobyte"
30	"golang.org/x/net/dns/dnsmessage"
31)
32
// httpsType is the DNS RRTYPE code for HTTPS resource records.
const httpsType = 65

// SvcParamKey code points defined in draft-ietf-dnsop-svcb-https-06. They are
// consecutive starting at zero, so iota assigns them in declaration order.
const (
	httpsKeyMandatory     = iota // 0: mandatory
	httpsKeyALPN                 // 1: alpn
	httpsKeyNoDefaultALPN        // 2: no-default-alpn
	httpsKeyPort                 // 3: port
	httpsKeyIPV4Hint             // 4: ipv4hint
	httpsKeyECH                  // 5: ech
	httpsKeyIPV6Hint             // 6: ipv6hint
)
45
// Command-line flags. Only -name is required.
var (
	name   = flag.String("name", "", "The name to look up in DNS. Required.")
	server = flag.String("server", "8.8.8.8:53", "Colon-separated host and UDP port that defines the DNS server to query.")
	outDir = flag.String("out-dir", "", "The directory where ECHConfigList values will be written. If unspecified, bytes are hexdumped to stdout.")
)
51
// httpsRecord holds the decoded contents of one HTTPS resource record: its
// SvcPriority, its TargetName in dotted form, and its SvcParams. hasPort
// reports whether port was set by a "port" SvcParam. ech holds the raw
// ECHConfigList bytes. unknownParams maps SvcParamKeys this tool does not
// interpret to their raw values.
type httpsRecord struct {
	priority   uint16
	targetName string

	// SvcParams:
	mandatory     []uint16
	alpn          []string
	noDefaultALPN bool
	hasPort       bool
	port          uint16
	ipv4hint      []net.IP
	ech           []byte
	ipv6hint      []net.IP
	unknownParams map[uint16][]byte
}

// String pretty-prints |h| as a multi-line string with bullet points. Only the
// SvcParams that are present are listed. Unknown SvcParams are listed in
// ascending key order, so the output is deterministic.
func (h httpsRecord) String() string {
	var b strings.Builder
	fmt.Fprintf(&b, "HTTPS SvcPriority:%d TargetName:%q", h.priority, h.targetName)

	if len(h.mandatory) != 0 {
		fmt.Fprintf(&b, "\n  * mandatory: %v", h.mandatory)
	}
	if len(h.alpn) != 0 {
		fmt.Fprintf(&b, "\n  * alpn: %q", h.alpn)
	}
	if h.noDefaultALPN {
		b.WriteString("\n  * no-default-alpn")
	}
	if h.hasPort {
		fmt.Fprintf(&b, "\n  * port: %d", h.port)
	}
	if len(h.ipv4hint) != 0 {
		b.WriteString("\n  * ipv4hint:")
		for _, address := range h.ipv4hint {
			fmt.Fprintf(&b, "\n    - %s", address)
		}
	}
	if len(h.ech) != 0 {
		fmt.Fprintf(&b, "\n  * ech: %x", h.ech)
	}
	if len(h.ipv6hint) != 0 {
		b.WriteString("\n  * ipv6hint:")
		for _, address := range h.ipv6hint {
			fmt.Fprintf(&b, "\n    - %s", address)
		}
	}
	if len(h.unknownParams) != 0 {
		b.WriteString("\n  * unknown SvcParams:")
		// Go randomizes map iteration order, so collect and sort the keys to
		// keep the output stable from run to run.
		keys := make([]uint16, 0, len(h.unknownParams))
		for key := range h.unknownParams {
			keys = append(keys, key)
		}
		for i := 1; i < len(keys); i++ {
			for j := i; j > 0 && keys[j] < keys[j-1]; j-- {
				keys[j], keys[j-1] = keys[j-1], keys[j]
			}
		}
		for _, key := range keys {
			fmt.Fprintf(&b, "\n    - %d: %x", key, h.unknownParams[key])
		}
	}
	return b.String()
}
108
109// dnsQueryForHTTPS queries the DNS server over UDP for any HTTPS records
110// associated with |domain|. It scans the response's answers and returns all the
111// HTTPS records it finds. It returns an error if any connection steps fail.
112func dnsQueryForHTTPS(domain string) ([][]byte, error) {
113	udpAddr, err := net.ResolveUDPAddr("udp", *server)
114	if err != nil {
115		return nil, err
116	}
117	conn, err := net.DialUDP("udp", nil, udpAddr)
118	if err != nil {
119		return nil, fmt.Errorf("failed to dial: %s", err)
120	}
121	defer conn.Close()
122
123	// Domain name must be canonical or message packing will fail.
124	if domain[len(domain)-1] != '.' {
125		domain += "."
126	}
127	dnsName, err := dnsmessage.NewName(domain)
128	if err != nil {
129		return nil, fmt.Errorf("failed to create DNS name from %q: %s", domain, err)
130	}
131	question := dnsmessage.Question{
132		Name:  dnsName,
133		Type:  httpsType,
134		Class: dnsmessage.ClassINET,
135	}
136	msg := dnsmessage.Message{
137		Header: dnsmessage.Header{
138			RecursionDesired: true,
139		},
140		Questions: []dnsmessage.Question{question},
141	}
142	packedMsg, err := msg.Pack()
143	if err != nil {
144		return nil, fmt.Errorf("failed to pack msg: %s", err)
145	}
146
147	if _, err = conn.Write(packedMsg); err != nil {
148		return nil, fmt.Errorf("failed to send the DNS query: %s", err)
149	}
150
151	for {
152		response := make([]byte, 512)
153		n, err := conn.Read(response)
154		if err != nil {
155			return nil, fmt.Errorf("failed to read the DNS response: %s", err)
156		}
157		response = response[:n]
158
159		var p dnsmessage.Parser
160		header, err := p.Start(response)
161		if err != nil {
162			return nil, err
163		}
164		if !header.Response {
165			return nil, errors.New("received DNS message is not a response")
166		}
167		if header.RCode != dnsmessage.RCodeSuccess {
168			return nil, fmt.Errorf("response from DNS has non-success RCode: %s", header.RCode.String())
169		}
170		if header.ID != 0 {
171			return nil, errors.New("received a DNS response with the wrong ID")
172		}
173		if !header.RecursionAvailable {
174			return nil, errors.New("server does not support recursion")
175		}
176		// Verify that this response answers the question that we asked in the
177		// query. If the resolver encountered any CNAMEs, it's not guaranteed
178		// that the response will contain a question with the same QNAME as our
179		// query. However, RFC 8499 Section 4 indicates that in general use, the
180		// response's QNAME should match the query, so we will make that
181		// assumption.
182		q, err := p.Question()
183		if err != nil {
184			return nil, err
185		}
186		if q != question {
187			return nil, fmt.Errorf("response answers the wrong question: %v", q)
188		}
189		if q, err = p.Question(); err != dnsmessage.ErrSectionDone {
190			return nil, fmt.Errorf("response contains an unexpected question: %v", q)
191		}
192
193		var httpsRecords [][]byte
194		for {
195			h, err := p.AnswerHeader()
196			if err == dnsmessage.ErrSectionDone {
197				break
198			}
199			if err != nil {
200				return nil, err
201			}
202
203			switch h.Type {
204			case httpsType:
205				// This should continue to work when golang.org/x/net/dns/dnsmessage
206				// adds support for HTTPS records.
207				r, err := p.UnknownResource()
208				if err != nil {
209					return nil, err
210				}
211				httpsRecords = append(httpsRecords, r.Data)
212			default:
213				if _, err := p.UnknownResource(); err != nil {
214					return nil, err
215				}
216			}
217		}
218		return httpsRecords, nil
219	}
220}
221
222// parseHTTPSRecord parses an HTTPS record (draft-ietf-dnsop-svcb-https-06,
223// Section 2.2) from |raw|. If there are syntax errors, it returns an error.
224func parseHTTPSRecord(raw []byte) (httpsRecord, error) {
225	reader := cryptobyte.String(raw)
226
227	var priority uint16
228	if !reader.ReadUint16(&priority) {
229		return httpsRecord{}, errors.New("failed to parse HTTPS record priority")
230	}
231
232	// Read the TargetName.
233	var dottedDomain string
234	for {
235		var label cryptobyte.String
236		if !reader.ReadUint8LengthPrefixed(&label) {
237			return httpsRecord{}, errors.New("failed to parse HTTPS record TargetName")
238		}
239		if label.Empty() {
240			break
241		}
242		dottedDomain += string(label) + "."
243	}
244
245	if priority == 0 {
246		// TODO(dmcardle) Recursively follow AliasForm records.
247		return httpsRecord{}, fmt.Errorf("received an AliasForm HTTPS record with TargetName=%q", dottedDomain)
248	}
249
250	record := httpsRecord{
251		priority:      priority,
252		targetName:    dottedDomain,
253		unknownParams: make(map[uint16][]byte),
254	}
255
256	// Read the SvcParams.
257	var lastSvcParamKey uint16
258	for svcParamCount := 0; !reader.Empty(); svcParamCount++ {
259		var svcParamKey uint16
260		var svcParamValue cryptobyte.String
261		if !reader.ReadUint16(&svcParamKey) ||
262			!reader.ReadUint16LengthPrefixed(&svcParamValue) {
263			return httpsRecord{}, errors.New("failed to parse HTTPS record SvcParam")
264		}
265		if svcParamCount > 0 && svcParamKey <= lastSvcParamKey {
266			return httpsRecord{}, errors.New("malformed HTTPS record contains out-of-order SvcParamKey")
267		}
268		lastSvcParamKey = svcParamKey
269
270		switch svcParamKey {
271		case httpsKeyMandatory:
272			if svcParamValue.Empty() {
273				return httpsRecord{}, errors.New("malformed mandatory SvcParamValue")
274			}
275			var lastKey uint16
276			for !svcParamValue.Empty() {
277				// |httpsKeyMandatory| may not appear in the mandatory list.
278				// |httpsKeyMandatory| is zero, so checking against the initial
279				// value of |lastKey| handles ordering and the invalid code point.
280				var key uint16
281				if !svcParamValue.ReadUint16(&key) ||
282					key <= lastKey {
283					return httpsRecord{}, errors.New("malformed mandatory SvcParamValue")
284				}
285				lastKey = key
286				record.mandatory = append(record.mandatory, key)
287			}
288		case httpsKeyALPN:
289			if svcParamValue.Empty() {
290				return httpsRecord{}, errors.New("malformed alpn SvcParamValue")
291			}
292			for !svcParamValue.Empty() {
293				var alpn cryptobyte.String
294				if !svcParamValue.ReadUint8LengthPrefixed(&alpn) || alpn.Empty() {
295					return httpsRecord{}, errors.New("malformed alpn SvcParamValue")
296				}
297				record.alpn = append(record.alpn, string(alpn))
298			}
299		case httpsKeyNoDefaultALPN:
300			if !svcParamValue.Empty() {
301				return httpsRecord{}, errors.New("malformed no-default-alpn SvcParamValue")
302			}
303			record.noDefaultALPN = true
304		case httpsKeyPort:
305			if !svcParamValue.ReadUint16(&record.port) ||
306				!svcParamValue.Empty() {
307				return httpsRecord{}, errors.New("malformed port SvcParamValue")
308			}
309			record.hasPort = true
310		case httpsKeyIPV4Hint:
311			if svcParamValue.Empty() {
312				return httpsRecord{}, errors.New("malformed ipv4hint SvcParamValue")
313			}
314			for !svcParamValue.Empty() {
315				var address []byte
316				if !svcParamValue.ReadBytes(&address, 4) {
317					return httpsRecord{}, errors.New("malformed ipv4hint SvcParamValue")
318				}
319				record.ipv4hint = append(record.ipv4hint, address)
320			}
321		case httpsKeyECH:
322			if svcParamValue.Empty() {
323				return httpsRecord{}, errors.New("malformed ech SvcParamValue")
324			}
325			record.ech = svcParamValue
326		case httpsKeyIPV6Hint:
327			if svcParamValue.Empty() {
328				return httpsRecord{}, errors.New("malformed ipv6hint SvcParamValue")
329			}
330			for !svcParamValue.Empty() {
331				var address []byte
332				if !svcParamValue.ReadBytes(&address, 16) {
333					return httpsRecord{}, errors.New("malformed ipv6hint SvcParamValue")
334				}
335				record.ipv6hint = append(record.ipv6hint, address)
336			}
337		default:
338			record.unknownParams[svcParamKey] = svcParamValue
339		}
340	}
341	return record, nil
342}
343
344func main() {
345	flag.Parse()
346	log.SetFlags(log.Lshortfile | log.LstdFlags)
347
348	if len(*name) == 0 {
349		flag.Usage()
350		os.Exit(1)
351	}
352
353	httpsRecords, err := dnsQueryForHTTPS(*name)
354	if err != nil {
355		log.Printf("Error querying %q: %s\n", *name, err)
356		os.Exit(1)
357	}
358	if len(httpsRecords) == 0 {
359		log.Println("No HTTPS records found in DNS response.")
360		os.Exit(1)
361	}
362
363	if len(*outDir) > 0 {
364		if err = os.Mkdir(*outDir, 0755); err != nil && !os.IsExist(err) {
365			log.Printf("Failed to create out directory %q: %s\n", *outDir, err)
366			os.Exit(1)
367		}
368	}
369
370	var echConfigListCount int
371	for _, httpsRecord := range httpsRecords {
372		record, err := parseHTTPSRecord(httpsRecord)
373		if err != nil {
374			log.Printf("Failed to parse HTTPS record: %s", err)
375			os.Exit(1)
376		}
377		fmt.Printf("%s\n", record)
378		if len(*outDir) == 0 {
379			continue
380		}
381
382		outFile := path.Join(*outDir, fmt.Sprintf("ech-config-list-%d", echConfigListCount))
383		if err = os.WriteFile(outFile, record.ech, 0644); err != nil {
384			log.Printf("Failed to write file: %s\n", err)
385			os.Exit(1)
386		}
387		fmt.Printf("Wrote ECHConfigList to %q\n", outFile)
388		echConfigListCount++
389	}
390}
391