1 /* PSA firmware framework client API */
2 
3 /*
4  *  Copyright The Mbed TLS Contributors
5  *  SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
6  */
7 
8 #include <stdint.h>
9 #include <stdlib.h>
10 #include <stddef.h>
11 #include <assert.h>
12 #include <stdio.h>
13 #include <string.h>
14 #include <strings.h>
15 #include <inttypes.h>
16 #include <sys/types.h>
17 #include <sys/ipc.h>
18 #include <sys/msg.h>
19 
20 #include "client.h"
21 #include "common.h"
22 #include "error_ext.h"
23 #include "util.h"
24 
/* One slot of the client-side connection table. */
struct internal_handle {
    int server_qid;          /* the RoT service's public message queue */
    int client_qid;          /* private reply queue created by this client */
    int internal_server_qid; /* queue used to acknowledge data transfers, set on connect */
    int valid;               /* 1 while the slot holds an open connection, 0 when free */
};
typedef struct internal_handle internal_handle_t;
31 
32 typedef struct vectors {
33     const psa_invec *in_vec;
34     size_t in_len;
35     psa_outvec *out_vec;
36     size_t out_len;
37 } vectors_t;
38 
39 /* Note that this implementation is functional and not secure */
40 int __psa_ff_client_security_state = NON_SECURE;
41 
42 /* Access to this global is not thread safe */
43 #define MAX_HANDLES 32
44 static internal_handle_t handles[MAX_HANDLES] = { { 0 } };
45 
get_next_free_handle()46 static int get_next_free_handle()
47 {
48     /* Never return handle 0 as it's a special null handle */
49     for (int i = 1; i < MAX_HANDLES; i++) {
50         if (handles[i].valid == 0) {
51             return i;
52         }
53     }
54     return -1;
55 }
56 
/*
 * Return 1 when handle is within the table and refers to an in-use slot.
 * Otherwise log an error and return 0. The null handle (0) is never valid.
 */
static int handle_is_valid(psa_handle_t handle)
{
    /* The range test short-circuits, so handles[] is only indexed in bounds. */
    int in_range = (handle > 0) && (handle < MAX_HANDLES);

    if (in_range && handles[handle].valid == 1) {
        return 1;
    }

    ERROR("ERROR: Invalid handle");
    return 0;
}
67 
get_queue_info(char * path,int * cqid,int * sqid)68 static int get_queue_info(char *path, int *cqid, int *sqid)
69 {
70     key_t server_queue_key;
71     int rx_qid, server_qid;
72 
73     INFO("Attempting to contact a RoT service queue");
74 
75     if ((rx_qid = msgget(IPC_PRIVATE, 0660)) == -1) {
76         ERROR("msgget: rx_qid");
77         return -1;
78     }
79 
80     if ((server_queue_key = ftok(path, PROJECT_ID)) == -1) {
81         ERROR("ftok");
82         return -2;
83     }
84 
85     if ((server_qid = msgget(server_queue_key, 0)) == -1) {
86         ERROR("msgget: server_qid");
87         return -3;
88     }
89 
90     *cqid = rx_qid;
91     *sqid = server_qid;
92 
93     return 0;
94 }
95 
process_response(int rx_qid,vectors_t * vecs,int type,int * internal_server_qid)96 static psa_status_t process_response(int rx_qid, vectors_t *vecs, int type,
97                                      int *internal_server_qid)
98 {
99     struct message response, request;
100     psa_status_t ret = PSA_ERROR_CONNECTION_REFUSED;
101     size_t invec_seek[4] = { 0 };
102     size_t data_size;
103     psa_status_t invec, outvec; /* TODO: Should these be size_t ? */
104 
105     assert(internal_server_qid > 0);
106 
107     while (1) {
108         data_size = 0;
109         invec = 0;
110         outvec = 0;
111 
112         /* read response from server */
113         if (msgrcv(rx_qid, &response, sizeof(struct message_text), 0, 0) == -1) {
114             ERROR("   msgrcv failed");
115             return ret;
116         }
117 
118         /* process return message from server */
119         switch (response.message_type) {
120             case PSA_REPLY:
121                 memcpy(&ret, response.message_text.buf, sizeof(psa_status_t));
122                 INFO("   Message received from server: %d", ret);
123                 if (type == PSA_IPC_CONNECT && ret > 0) {
124                     *internal_server_qid = ret;
125                     INFO("   ASSSIGNED q ID %d", *internal_server_qid);
126                     ret = PSA_SUCCESS;
127                 }
128                 return ret;
129                 break;
130             case READ_REQUEST:
131                 /* read data request */
132                 request.message_type = READ_RESPONSE;
133 
134                 assert(vecs != 0);
135 
136                 memcpy(&invec, response.message_text.buf, sizeof(psa_status_t));
137                 memcpy(&data_size, response.message_text.buf+sizeof(size_t), sizeof(size_t));
138                 INFO("   Partition asked for %lu bytes from invec %d", data_size, invec);
139 
140                 /* need to add more checks here */
141                 assert(invec >= 0 && invec < PSA_MAX_IOVEC);
142 
143                 if (data_size > MAX_FRAGMENT_SIZE) {
144                     data_size = MAX_FRAGMENT_SIZE;
145                 }
146 
147                 /* send response */
148                 INFO("   invec_seek[invec] is %lu", invec_seek[invec]);
149                 INFO("   Reading from offset %p", vecs->in_vec[invec].base + invec_seek[invec]);
150                 memcpy(request.message_text.buf,
151                        (vecs->in_vec[invec].base + invec_seek[invec]),
152                        data_size);
153 
154                 /* update invec base TODO: check me */
155                 invec_seek[invec] = invec_seek[invec] + data_size;
156 
157                 INFO("   Sending message of type %li", request.message_type);
158                 INFO("       with content %s", request.message_text.buf);
159 
160                 if (msgsnd(*internal_server_qid, &request,
161                            sizeof(int) + sizeof(uint32_t) + data_size, 0) == -1) {
162                     ERROR("Internal error: failed to respond to read request");
163                 }
164                 break;
165             case WRITE_REQUEST:
166                 assert(vecs != 0);
167 
168                 request.message_type = WRITE_RESPONSE;
169 
170                 memcpy(&outvec, response.message_text.buf, sizeof(psa_status_t));
171                 memcpy(&data_size, response.message_text.buf + sizeof(size_t), sizeof(size_t));
172                 INFO("   Partition wants to write %lu bytes to outvec %d", data_size, outvec);
173 
174                 assert(outvec >= 0 && outvec < PSA_MAX_IOVEC);
175 
176                 /* copy memory into message and send back amount written */
177                 size_t sofar = vecs->out_vec[outvec].len;
178                 memcpy(vecs->out_vec[outvec].base + sofar,
179                        response.message_text.buf+(sizeof(size_t)*2), data_size);
180                 INFO("   Data size is %lu", data_size);
181                 vecs->out_vec[outvec].len += data_size;
182 
183                 INFO("   Sending message of type %li", request.message_type);
184 
185                 /* send response */
186                 if (msgsnd(*internal_server_qid, &request, sizeof(int) + data_size, 0) == -1) {
187                     ERROR("Internal error: failed to respond to write request");
188                 }
189                 break;
190             case SKIP_REQUEST:
191                 memcpy(&invec, response.message_text.buf, sizeof(psa_status_t));
192                 memcpy(&data_size, response.message_text.buf+sizeof(size_t), sizeof(size_t));
193                 INFO("   Partition asked to skip %lu bytes in invec %d", data_size, invec);
194                 assert(invec >= 0 && invec < PSA_MAX_IOVEC);
195                 /* update invec base TODO: check me */
196                 invec_seek[invec] = invec_seek[invec] + data_size;
197                 break;
198 
199             default:
200                 FATAL("   ERROR: unknown internal message type: %ld",
201                       response.message_type);
202         }
203     }
204 }
205 
send(int rx_qid,int server_qid,int * internal_server_qid,int32_t type,uint32_t minor_version,vectors_t * vecs)206 static psa_status_t send(int rx_qid, int server_qid, int *internal_server_qid,
207                          int32_t type, uint32_t minor_version, vectors_t *vecs)
208 {
209     psa_status_t ret = PSA_ERROR_CONNECTION_REFUSED;
210     size_t request_msg_size = (sizeof(int) + sizeof(long)); /* msg type plus queue id */
211     struct message request;
212     request.message_type = 1; /* TODO: change this */
213     request.message_text.psa_type = type;
214     vector_sizes_t vec_sizes;
215 
216     /* If the client is non-secure then set the NS bit */
217     if (__psa_ff_client_security_state != 0) {
218         request.message_type |= NON_SECURE;
219     }
220 
221     assert(request.message_type >= 0);
222 
223     INFO("SEND: Sending message of type %ld with psa_type %d", request.message_type, type);
224     INFO("     internal_server_qid = %i", *internal_server_qid);
225 
226     request.message_text.qid = rx_qid;
227 
228     if (type == PSA_IPC_CONNECT) {
229         memcpy(request.message_text.buf, &minor_version, sizeof(minor_version));
230         request_msg_size = request_msg_size + sizeof(minor_version);
231         INFO("   Request msg size is %lu", request_msg_size);
232     } else {
233         assert(internal_server_qid > 0);
234     }
235 
236     if (vecs != NULL && type >= PSA_IPC_CALL) {
237 
238         memset(&vec_sizes, 0, sizeof(vec_sizes));
239 
240         /* Copy invec sizes */
241         for (size_t i = 0; i < (vecs->in_len); i++) {
242             vec_sizes.invec_sizes[i] = vecs->in_vec[i].len;
243             INFO("   Client sending vector %lu: %lu", i, vec_sizes.invec_sizes[i]);
244         }
245 
246         /* Copy outvec sizes */
247         for (size_t i = 0; i < (vecs->out_len); i++) {
248             vec_sizes.outvec_sizes[i] = vecs->out_vec[i].len;
249 
250             /* Reset to 0 since we need to eventually fill in with bytes written */
251             vecs->out_vec[i].len = 0;
252         }
253 
254         memcpy(request.message_text.buf, &vec_sizes, sizeof(vec_sizes));
255         request_msg_size = request_msg_size + sizeof(vec_sizes);
256     }
257 
258     INFO("   Sending and then waiting");
259 
260     /* send message to server */
261     if (msgsnd(server_qid, &request, request_msg_size, 0) == -1) {
262         ERROR("   msgsnd failed");
263         return ret;
264     }
265 
266     return process_response(rx_qid, vecs, type, internal_server_qid);
267 }
268 
269 
psa_framework_version(void)270 uint32_t psa_framework_version(void)
271 {
272     return PSA_FRAMEWORK_VERSION;
273 }
274 
psa_connect(uint32_t sid,uint32_t minor_version)275 psa_handle_t psa_connect(uint32_t sid, uint32_t minor_version)
276 {
277     int idx;
278     psa_status_t ret;
279     char pathname[PATHNAMESIZE] = { 0 };
280 
281     idx = get_next_free_handle();
282 
283     /* if there's a free handle available */
284     if (idx >= 0) {
285         snprintf(pathname, PATHNAMESIZE - 1, TMP_FILE_BASE_PATH "psa_service_%u", sid);
286         INFO("Attempting to contact RoT service at %s", pathname);
287 
288         /* if communication is possible */
289         if (get_queue_info(pathname, &handles[idx].client_qid, &handles[idx].server_qid) >= 0) {
290 
291             ret = send(handles[idx].client_qid,
292                        handles[idx].server_qid,
293                        &handles[idx].internal_server_qid,
294                        PSA_IPC_CONNECT,
295                        minor_version,
296                        NULL);
297 
298             /* if connection accepted by RoT service */
299             if (ret >= 0) {
300                 handles[idx].valid = 1;
301                 return idx;
302             } else {
303                 ERROR("Server didn't like you");
304             }
305         } else {
306             ERROR("Couldn't contact RoT service. Does it exist?");
307 
308             if (__psa_ff_client_security_state == 0) {
309                 ERROR("Invalid SID");
310             }
311         }
312     }
313 
314     INFO("Couldn't obtain a free handle");
315     return PSA_ERROR_CONNECTION_REFUSED;
316 }
317 
psa_version(uint32_t sid)318 uint32_t psa_version(uint32_t sid)
319 {
320     int idx;
321     psa_status_t ret;
322     char pathname[PATHNAMESIZE] = { 0 };
323 
324     idx = get_next_free_handle();
325 
326     if (idx >= 0) {
327         snprintf(pathname, PATHNAMESIZE, TMP_FILE_BASE_PATH "psa_service_%u", sid);
328         if (get_queue_info(pathname, &handles[idx].client_qid, &handles[idx].server_qid) >= 0) {
329             ret = send(handles[idx].client_qid,
330                        handles[idx].server_qid,
331                        &handles[idx].internal_server_qid,
332                        VERSION_REQUEST,
333                        0,
334                        NULL);
335             INFO("psa_version: Recieved from server %d", ret);
336             if (ret > 0) {
337                 return ret;
338             }
339         }
340     }
341     ERROR("psa_version failed: does the service exist?");
342     return PSA_VERSION_NONE;
343 }
344 
psa_call(psa_handle_t handle,int32_t type,const psa_invec * in_vec,size_t in_len,psa_outvec * out_vec,size_t out_len)345 psa_status_t psa_call(psa_handle_t handle,
346                       int32_t type,
347                       const psa_invec *in_vec,
348                       size_t in_len,
349                       psa_outvec *out_vec,
350                       size_t out_len)
351 {
352     handle_is_valid(handle);
353 
354     if ((in_len + out_len) > PSA_MAX_IOVEC) {
355         ERROR("Too many iovecs: %lu + %lu", in_len, out_len);
356     }
357 
358     vectors_t vecs = { 0 };
359     vecs.in_vec = in_vec;
360     vecs.in_len = in_len;
361     vecs.out_vec = out_vec;
362     vecs.out_len = out_len;
363 
364     return send(handles[handle].client_qid,
365                 handles[handle].server_qid,
366                 &handles[handle].internal_server_qid,
367                 type,
368                 0,
369                 &vecs);
370 }
371 
/*
 * Close a connection opened by psa_connect().
 *
 * Closing the null handle (0) has no effect. Any other invalid handle is
 * logged and ignored.
 *
 * Otherwise the function:
 *   - sends PSA_IPC_DISCONNECT to the service,
 *   - removes the private reply queue (even if the disconnect could not be
 *     delivered, so the queue is not leaked),
 *   - frees the handle slot.
 */
void psa_close(psa_handle_t handle)
{
    /* Handle 0 is the special null handle; closing it is a no-op. */
    if (handle == 0) {
        return;
    }

    /* handles[handle] is only in bounds for a valid handle. */
    if (!handle_is_valid(handle)) {
        return;
    }

    if (send(handles[handle].client_qid, handles[handle].server_qid,
             &handles[handle].internal_server_qid, PSA_IPC_DISCONNECT, 0, NULL)) {
        ERROR("ERROR: Couldn't send disconnect msg");
    }

    /* The slot is released below, so nothing will read this queue again. */
    if (msgctl(handles[handle].client_qid, IPC_RMID, NULL) != 0) {
        ERROR("ERROR: Failed to delete msg queue");
    }

    INFO("Closing handle %u", handle);
    handles[handle].valid = 0;
}
386