1 /*
2 * Copyright (c) 2018 Nordic Semiconductor ASA
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 */
6
7 /** @file mqtt_transport_socket_tls.h
8 *
9 * @brief Internal functions to handle transport over TLS socket.
10 */
11
12 #include <zephyr/logging/log.h>
13 LOG_MODULE_REGISTER(net_mqtt_sock_tls, CONFIG_MQTT_LOG_LEVEL);
14
15 #include <errno.h>
16 #include <zephyr/net/socket.h>
17 #include <zephyr/net/mqtt.h>
18
19 #include "mqtt_os.h"
20
mqtt_client_tls_connect(struct mqtt_client * client)21 int mqtt_client_tls_connect(struct mqtt_client *client)
22 {
23 const struct sockaddr *broker = client->broker;
24 struct mqtt_sec_config *tls_config = &client->transport.tls.config;
25 int ret;
26
27 client->transport.tls.sock = zsock_socket(broker->sa_family,
28 SOCK_STREAM, IPPROTO_TLS_1_2);
29 if (client->transport.tls.sock < 0) {
30 return -errno;
31 }
32
33 NET_DBG("Created socket %d", client->transport.tls.sock);
34
35 if (client->transport.if_name != NULL) {
36 struct ifreq ifname = { 0 };
37
38 strncpy(ifname.ifr_name, client->transport.if_name,
39 sizeof(ifname.ifr_name) - 1);
40
41 ret = zsock_setsockopt(client->transport.tls.sock, SOL_SOCKET,
42 SO_BINDTODEVICE, &ifname,
43 sizeof(struct ifreq));
44 if (ret < 0) {
45 NET_ERR("Failed to bind ot interface %s error (%d)",
46 ifname.ifr_name, -errno);
47 goto error;
48 }
49
50 NET_DBG("Bound to interface %s", ifname.ifr_name);
51 }
52
53 #if defined(CONFIG_SOCKS)
54 if (client->transport.proxy.addrlen != 0) {
55 ret = setsockopt(client->transport.tls.sock,
56 SOL_SOCKET, SO_SOCKS5,
57 &client->transport.proxy.addr,
58 client->transport.proxy.addrlen);
59 if (ret < 0) {
60 goto error;
61 }
62 }
63 #endif
64 /* Set secure socket options. */
65 ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS, TLS_PEER_VERIFY,
66 &tls_config->peer_verify,
67 sizeof(tls_config->peer_verify));
68 if (ret < 0) {
69 goto error;
70 }
71
72 if (tls_config->cipher_list != NULL && tls_config->cipher_count > 0) {
73 ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS,
74 TLS_CIPHERSUITE_LIST, tls_config->cipher_list,
75 sizeof(int) * tls_config->cipher_count);
76 if (ret < 0) {
77 goto error;
78 }
79 }
80
81 if (tls_config->sec_tag_list != NULL && tls_config->sec_tag_count > 0) {
82 ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS,
83 TLS_SEC_TAG_LIST, tls_config->sec_tag_list,
84 sizeof(sec_tag_t) * tls_config->sec_tag_count);
85 if (ret < 0) {
86 goto error;
87 }
88 }
89
90 #if defined(CONFIG_MQTT_LIB_TLS_USE_ALPN)
91 if (tls_config->alpn_protocol_name_list != NULL &&
92 tls_config->alpn_protocol_name_count > 0) {
93 ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS,
94 TLS_ALPN_LIST, tls_config->alpn_protocol_name_list,
95 sizeof(const char *) * tls_config->alpn_protocol_name_count);
96 if (ret < 0) {
97 goto error;
98 }
99 }
100
101 #endif
102
103 if (tls_config->hostname) {
104 ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS,
105 TLS_HOSTNAME, tls_config->hostname,
106 strlen(tls_config->hostname) + 1);
107 if (ret < 0) {
108 goto error;
109 }
110 }
111
112 if (tls_config->cert_nocopy != TLS_CERT_NOCOPY_NONE) {
113 ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS,
114 TLS_CERT_NOCOPY, &tls_config->cert_nocopy,
115 sizeof(tls_config->cert_nocopy));
116 if (ret < 0) {
117 goto error;
118 }
119 }
120
121 size_t peer_addr_size = sizeof(struct sockaddr_in6);
122
123 if (broker->sa_family == AF_INET) {
124 peer_addr_size = sizeof(struct sockaddr_in);
125 }
126
127 ret = zsock_connect(client->transport.tls.sock, client->broker,
128 peer_addr_size);
129 if (ret < 0) {
130 goto error;
131 }
132
133 NET_DBG("Connect completed");
134 return 0;
135
136 error:
137 (void) zsock_close(client->transport.tls.sock);
138 return -errno;
139 }
140
mqtt_client_tls_write(struct mqtt_client * client,const uint8_t * data,uint32_t datalen)141 int mqtt_client_tls_write(struct mqtt_client *client, const uint8_t *data,
142 uint32_t datalen)
143 {
144 uint32_t offset = 0U;
145 int ret;
146
147 while (offset < datalen) {
148 ret = zsock_send(client->transport.tls.sock, data + offset,
149 datalen - offset, 0);
150 if (ret < 0) {
151 return -errno;
152 }
153
154 offset += ret;
155 }
156
157 return 0;
158 }
159
mqtt_client_tls_write_msg(struct mqtt_client * client,const struct msghdr * message)160 int mqtt_client_tls_write_msg(struct mqtt_client *client,
161 const struct msghdr *message)
162 {
163 int ret, i;
164 size_t offset = 0;
165 size_t total_len = 0;
166
167 for (i = 0; i < message->msg_iovlen; i++) {
168 total_len += message->msg_iov[i].iov_len;
169 }
170
171 while (offset < total_len) {
172 ret = zsock_sendmsg(client->transport.tls.sock, message, 0);
173 if (ret < 0) {
174 return -errno;
175 }
176
177 offset += ret;
178 if (offset >= total_len) {
179 break;
180 }
181
182 /* Update msghdr for the next iteration. */
183 for (i = 0; i < message->msg_iovlen; i++) {
184 if (ret < message->msg_iov[i].iov_len) {
185 message->msg_iov[i].iov_len -= ret;
186 message->msg_iov[i].iov_base =
187 (uint8_t *)message->msg_iov[i].iov_base + ret;
188 break;
189 }
190
191 ret -= message->msg_iov[i].iov_len;
192 message->msg_iov[i].iov_len = 0;
193 }
194 }
195
196 return 0;
197 }
198
mqtt_client_tls_read(struct mqtt_client * client,uint8_t * data,uint32_t buflen,bool shall_block)199 int mqtt_client_tls_read(struct mqtt_client *client, uint8_t *data, uint32_t buflen,
200 bool shall_block)
201 {
202 int flags = 0;
203 int ret;
204
205 if (!shall_block) {
206 flags |= ZSOCK_MSG_DONTWAIT;
207 }
208
209 ret = zsock_recv(client->transport.tls.sock, data, buflen, flags);
210 if (ret < 0) {
211 return -errno;
212 }
213
214 return ret;
215 }
216
mqtt_client_tls_disconnect(struct mqtt_client * client)217 int mqtt_client_tls_disconnect(struct mqtt_client *client)
218 {
219 int ret;
220
221 NET_INFO("Closing socket %d", client->transport.tls.sock);
222 ret = zsock_close(client->transport.tls.sock);
223 if (ret < 0) {
224 return -errno;
225 }
226
227 return 0;
228 }
229