1 /*
2  * Copyright (c) 2018 Nordic Semiconductor ASA
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
/** @file mqtt_transport_socket_tls.c
8  *
9  * @brief Internal functions to handle transport over TLS socket.
10  */
11 
12 #include <zephyr/logging/log.h>
13 LOG_MODULE_REGISTER(net_mqtt_sock_tls, CONFIG_MQTT_LOG_LEVEL);
14 
15 #include <errno.h>
16 #include <zephyr/net/socket.h>
17 #include <zephyr/net/mqtt.h>
18 
19 #include "mqtt_os.h"
20 
mqtt_client_tls_connect(struct mqtt_client * client)21 int mqtt_client_tls_connect(struct mqtt_client *client)
22 {
23 	const struct sockaddr *broker = client->broker;
24 	struct mqtt_sec_config *tls_config = &client->transport.tls.config;
25 	int ret;
26 
27 	client->transport.tls.sock = zsock_socket(broker->sa_family,
28 						  SOCK_STREAM, IPPROTO_TLS_1_2);
29 	if (client->transport.tls.sock < 0) {
30 		return -errno;
31 	}
32 
33 	NET_DBG("Created socket %d", client->transport.tls.sock);
34 
35 	if (client->transport.if_name != NULL) {
36 		struct ifreq ifname = { 0 };
37 
38 		strncpy(ifname.ifr_name, client->transport.if_name,
39 			sizeof(ifname.ifr_name) - 1);
40 
41 		ret = zsock_setsockopt(client->transport.tls.sock, SOL_SOCKET,
42 				       SO_BINDTODEVICE, &ifname,
43 				       sizeof(struct ifreq));
44 		if (ret < 0) {
45 			NET_ERR("Failed to bind ot interface %s error (%d)",
46 				ifname.ifr_name, -errno);
47 			goto error;
48 		}
49 
50 		NET_DBG("Bound to interface %s", ifname.ifr_name);
51 	}
52 
53 #if defined(CONFIG_SOCKS)
54 	if (client->transport.proxy.addrlen != 0) {
55 		ret = setsockopt(client->transport.tls.sock,
56 				 SOL_SOCKET, SO_SOCKS5,
57 				 &client->transport.proxy.addr,
58 				 client->transport.proxy.addrlen);
59 		if (ret < 0) {
60 			goto error;
61 		}
62 	}
63 #endif
64 	/* Set secure socket options. */
65 	ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS, TLS_PEER_VERIFY,
66 			       &tls_config->peer_verify,
67 			       sizeof(tls_config->peer_verify));
68 	if (ret < 0) {
69 		goto error;
70 	}
71 
72 	if (tls_config->cipher_list != NULL && tls_config->cipher_count > 0) {
73 		ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS,
74 				       TLS_CIPHERSUITE_LIST, tls_config->cipher_list,
75 				       sizeof(int) * tls_config->cipher_count);
76 		if (ret < 0) {
77 			goto error;
78 		}
79 	}
80 
81 	if (tls_config->sec_tag_list != NULL && tls_config->sec_tag_count > 0) {
82 		ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS,
83 				       TLS_SEC_TAG_LIST, tls_config->sec_tag_list,
84 				       sizeof(sec_tag_t) * tls_config->sec_tag_count);
85 		if (ret < 0) {
86 			goto error;
87 		}
88 	}
89 
90 #if defined(CONFIG_MQTT_LIB_TLS_USE_ALPN)
91 	if (tls_config->alpn_protocol_name_list != NULL &&
92 		tls_config->alpn_protocol_name_count > 0) {
93 		ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS,
94 				TLS_ALPN_LIST, tls_config->alpn_protocol_name_list,
95 				sizeof(const char *) * tls_config->alpn_protocol_name_count);
96 		if (ret < 0) {
97 			goto error;
98 		}
99 	}
100 
101 #endif
102 
103 	if (tls_config->hostname) {
104 		ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS,
105 				       TLS_HOSTNAME, tls_config->hostname,
106 				       strlen(tls_config->hostname) + 1);
107 		if (ret < 0) {
108 			goto error;
109 		}
110 	}
111 
112 	if (tls_config->cert_nocopy != TLS_CERT_NOCOPY_NONE) {
113 		ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS,
114 				       TLS_CERT_NOCOPY, &tls_config->cert_nocopy,
115 				       sizeof(tls_config->cert_nocopy));
116 		if (ret < 0) {
117 			goto error;
118 		}
119 	}
120 
121 	size_t peer_addr_size = sizeof(struct sockaddr_in6);
122 
123 	if (broker->sa_family == AF_INET) {
124 		peer_addr_size = sizeof(struct sockaddr_in);
125 	}
126 
127 	ret = zsock_connect(client->transport.tls.sock, client->broker,
128 			    peer_addr_size);
129 	if (ret < 0) {
130 		goto error;
131 	}
132 
133 	NET_DBG("Connect completed");
134 	return 0;
135 
136 error:
137 	(void) zsock_close(client->transport.tls.sock);
138 	return -errno;
139 }
140 
mqtt_client_tls_write(struct mqtt_client * client,const uint8_t * data,uint32_t datalen)141 int mqtt_client_tls_write(struct mqtt_client *client, const uint8_t *data,
142 			  uint32_t datalen)
143 {
144 	uint32_t offset = 0U;
145 	int ret;
146 
147 	while (offset < datalen) {
148 		ret = zsock_send(client->transport.tls.sock, data + offset,
149 				 datalen - offset, 0);
150 		if (ret < 0) {
151 			return -errno;
152 		}
153 
154 		offset += ret;
155 	}
156 
157 	return 0;
158 }
159 
mqtt_client_tls_write_msg(struct mqtt_client * client,const struct msghdr * message)160 int mqtt_client_tls_write_msg(struct mqtt_client *client,
161 			      const struct msghdr *message)
162 {
163 	int ret, i;
164 	size_t offset = 0;
165 	size_t total_len = 0;
166 
167 	for (i = 0; i < message->msg_iovlen; i++) {
168 		total_len += message->msg_iov[i].iov_len;
169 	}
170 
171 	while (offset < total_len) {
172 		ret = zsock_sendmsg(client->transport.tls.sock, message, 0);
173 		if (ret < 0) {
174 			return -errno;
175 		}
176 
177 		offset += ret;
178 		if (offset >= total_len) {
179 			break;
180 		}
181 
182 		/* Update msghdr for the next iteration. */
183 		for (i = 0; i < message->msg_iovlen; i++) {
184 			if (ret < message->msg_iov[i].iov_len) {
185 				message->msg_iov[i].iov_len -= ret;
186 				message->msg_iov[i].iov_base =
187 					(uint8_t *)message->msg_iov[i].iov_base + ret;
188 				break;
189 			}
190 
191 			ret -= message->msg_iov[i].iov_len;
192 			message->msg_iov[i].iov_len = 0;
193 		}
194 	}
195 
196 	return 0;
197 }
198 
mqtt_client_tls_read(struct mqtt_client * client,uint8_t * data,uint32_t buflen,bool shall_block)199 int mqtt_client_tls_read(struct mqtt_client *client, uint8_t *data, uint32_t buflen,
200 			 bool shall_block)
201 {
202 	int flags = 0;
203 	int ret;
204 
205 	if (!shall_block) {
206 		flags |= ZSOCK_MSG_DONTWAIT;
207 	}
208 
209 	ret = zsock_recv(client->transport.tls.sock, data, buflen, flags);
210 	if (ret < 0) {
211 		return -errno;
212 	}
213 
214 	return ret;
215 }
216 
mqtt_client_tls_disconnect(struct mqtt_client * client)217 int mqtt_client_tls_disconnect(struct mqtt_client *client)
218 {
219 	int ret;
220 
221 	NET_INFO("Closing socket %d", client->transport.tls.sock);
222 	ret = zsock_close(client->transport.tls.sock);
223 	if (ret < 0) {
224 		return -errno;
225 	}
226 
227 	return 0;
228 }
229