1 /*
2  * Copyright (C) 2018-2022 Intel Corporation.
3  * Copyright (C) 2014, 2015 IBM Corporation.
4  *
5  * SPDX-License-Identifier: BSD-3-Clause
6  */
7 
8 #include <sys/un.h>
9 #include <errno.h>
10 #include <sys/wait.h>
11 #include <signal.h>
12 #include <sys/socket.h>
13 #include <stdio.h>
14 #include <stdlib.h>
15 #include <unistd.h>
16 #include <strings.h>
17 #include <string.h>
18 #include <stdbool.h>
19 #include <sys/stat.h>
20 #include <fcntl.h>
21 
22 #include "vmmapi.h"
23 #include "tpm_internal.h"
24 #include "log.h"
25 
/*
 * TPM wire constants used to recognise a self-test command and to build a
 * fatal error response.
 * NOTE(review): 0x53, 0xc4 and 9 look like the TPM 1.2 values of
 * TPM_ORD_ContinueSelfTest, TPM_TAG_RSP_COMMAND and TPM_FAIL rather than
 * TPM 2.0 ones -- confirm against the spec.
 */
#define TPM_ORD_ContinueSelfTest	0x53
#define TPM_TAG_RSP_COMMAND		0xc4
#define TPM_FAIL		9
/* Bit for the init_flags field of the CMD_INIT request (struct ptm_init) */
#define PTM_INIT_FLAG_DELETE_VOLATILE	(1 << 0)
31 
/* Result code carried in swtpm control-channel responses (mirrors SWTPM) */
typedef uint32_t ptm_res;
34 
35 struct ptm_est {
36 	union {
37 		struct {
38 			ptm_res tpm_result;
39 			unsigned char bit; /* TPM established bit */
40 		} resp; /* response */
41 	} u;
42 };
43 
44 struct ptm_reset_est {
45 	union {
46 		struct {
47 			uint8_t loc; /* locality to use */
48 		} req; /* request */
49 		struct {
50 			ptm_res tpm_result;
51 		} resp; /* response */
52 	} u;
53 };
54 
55 struct ptm_init {
56 	union {
57 		struct {
58 			uint32_t init_flags; /* see definitions below */
59 		} req; /* request */
60 		struct {
61 			ptm_res tpm_result;
62 		} resp; /* response */
63 	} u;
64 };
65 
66 struct ptm_loc {
67 	union {
68 		struct {
69 			uint8_t loc; /* locality to set */
70 		} req; /* request */
71 		struct {
72 			ptm_res tpm_result;
73 		} resp; /* response */
74 	} u;
75 };
76 
77 struct ptm_getconfig {
78 	union {
79 		struct {
80 			ptm_res tpm_result;
81 			uint32_t flags;
82 		} resp; /* response */
83 	} u;
84 };
85 
86 struct ptm_setbuffersize {
87 	union {
88 		struct {
89 			uint32_t buffersize; /* 0 to query for current buffer size */
90 		} req; /* request */
91 		struct {
92 			ptm_res tpm_result;
93 			uint32_t buffersize; /* buffer size in use */
94 			uint32_t minsize; /* min. supported buffer size */
95 			uint32_t maxsize; /* max. supported buffer size */
96 		} resp; /* response */
97 	} u;
98 };
99 
/* Aliases so the ptm_* payloads can be declared without the struct keyword */
typedef struct ptm_est ptm_est;
typedef struct ptm_reset_est ptm_reset_est;
typedef struct ptm_init ptm_init;
typedef struct ptm_loc ptm_loc;
typedef struct ptm_getconfig ptm_getconfig;
typedef struct ptm_setbuffersize ptm_setbuffersize;
106 
/*
 * Headers at the start of a TPM command and of a TPM response.
 * Both are packed to 1-byte alignment, which puts length at offset 2 and
 * ordinal/return_code at offset 6 -- the offsets the tpm_cmd_* accessors use.
 */
#pragma pack(push, 1)
typedef struct {
	uint16_t tag;
	uint32_t length;
	uint32_t ordinal;
} tpm_input_header;

typedef struct {
	uint16_t tag;
	uint32_t length;
	uint32_t return_code;
} tpm_output_header;
#pragma pack(pop)
121 
/*
 * Main state of the TPM emulator backend. It works with one SWTPM instance
 * to provide TPM functionality to the User VM.
 *
 * ctrl_chan_fd:            fd of the SWTPM control channel
 * cmd_chan_fd:             fd of the SWTPM command channel
 * cur_locty_number:        locality most recently set
 * established_flag:        cached value of the TPM established flag
 * established_flag_cached: set once established_flag holds a value fetched
 *                          by swtpm_get_tpm_established_flag
 */
typedef struct swtpm_context {
	int ctrl_chan_fd;
	int cmd_chan_fd;
	uint8_t cur_locty_number;
	unsigned int established_flag:1;
	unsigned int established_flag_cached:1;
} swtpm_context;
139 
/* Control-channel command numbers; values are kept aligned with SWTPM */
enum {
	CMD_GET_CAPABILITY       = 0x01,
	CMD_INIT                 = 0x02,
	CMD_SHUTDOWN             = 0x03,
	CMD_GET_TPMESTABLISHED   = 0x04,
	CMD_SET_LOCALITY         = 0x05,
	CMD_HASH_START           = 0x06,
	CMD_HASH_DATA            = 0x07,
	CMD_HASH_END             = 0x08,
	CMD_CANCEL_TPM_CMD       = 0x09,
	CMD_STORE_VOLATILE       = 0x0a,
	CMD_RESET_TPMESTABLISHED = 0x0b,
	CMD_GET_STATEBLOB        = 0x0c,
	CMD_SET_STATEBLOB        = 0x0d,
	CMD_STOP                 = 0x0e,
	CMD_GET_CONFIG           = 0x0f,
	CMD_SET_DATAFD           = 0x10,
	CMD_SET_BUFFERSIZE       = 0x11,
	CMD_GET_INFO             = 0x12,
};
161 
162 static swtpm_context tpm_context;
163 
164 
/*
 * Read the 16-bit tag stored big-endian in bytes 0-1 of a TPM
 * command/response buffer.
 *
 * b: buffer holding at least 2 bytes; no alignment is required.
 * Returns the tag in host byte order.
 *
 * The bytes are assembled one at a time, so the result does not depend on
 * the alignment of b or on the host byte order.
 */
static inline uint16_t tpm_cmd_get_tag(const void *b)
{
	const uint8_t *p = b;

	return (uint16_t)(((uint16_t)p[0] << 8) | p[1]);
}
169 
/*
 * Read the 32-bit length stored big-endian in bytes 2-5 of a TPM
 * command/response buffer.
 *
 * b: buffer holding at least 6 bytes; no alignment is required.
 * Returns the length in host byte order.
 *
 * The field sits at an offset that is not a multiple of 4, so it is read
 * byte by byte rather than through a uint32_t pointer.
 */
static inline uint32_t tpm_cmd_get_size(const void *b)
{
	const uint8_t *p = (const uint8_t *)b + 2;

	return ((uint32_t)p[0] << 24) | ((uint32_t)p[1] << 16) |
	       ((uint32_t)p[2] << 8) | (uint32_t)p[3];
}
174 
/*
 * Read the 32-bit ordinal stored big-endian in bytes 6-9 of a TPM command
 * buffer.
 *
 * b: buffer holding at least 10 bytes; no alignment is required.
 * Returns the ordinal in host byte order.
 *
 * The field sits at an offset that is not a multiple of 4, so it is read
 * byte by byte rather than through a uint32_t pointer.
 */
static inline uint32_t tpm_cmd_get_ordinal(const void *b)
{
	const uint8_t *p = (const uint8_t *)b + 6;

	return ((uint32_t)p[0] << 24) | ((uint32_t)p[1] << 16) |
	       ((uint32_t)p[2] << 8) | (uint32_t)p[3];
}
179 
/*
 * Read the 32-bit return code stored big-endian in bytes 6-9 of a TPM
 * response buffer.
 *
 * b: buffer holding at least 10 bytes; no alignment is required.
 * Returns the return code in host byte order.
 *
 * The field sits at an offset that is not a multiple of 4, so it is read
 * byte by byte rather than through a uint32_t pointer.
 */
static inline uint32_t tpm_cmd_get_errcode(const void *b)
{
	const uint8_t *p = (const uint8_t *)b + 6;

	return ((uint32_t)p[0] << 24) | ((uint32_t)p[1] << 16) |
	       ((uint32_t)p[2] << 8) | (uint32_t)p[3];
}
184 
/*
 * Store a 16-bit tag big-endian into bytes 0-1 of a TPM command/response
 * buffer.
 *
 * b:   buffer with room for at least 2 bytes; no alignment is required.
 * tag: value in host byte order.
 *
 * The bytes are written one at a time, so the stored layout does not depend
 * on the alignment of b or on the host byte order.
 */
static inline void tpm_cmd_set_tag(void *b, uint16_t tag)
{
	uint8_t *p = b;

	p[0] = (uint8_t)(tag >> 8);
	p[1] = (uint8_t)(tag & 0xff);
}
189 
/*
 * Store a 32-bit length big-endian into bytes 2-5 of a TPM command/response
 * buffer.
 *
 * b:    buffer with room for at least 6 bytes; no alignment is required.
 * size: value in host byte order.
 *
 * The field sits at an offset that is not a multiple of 4, so it is written
 * byte by byte rather than through a uint32_t pointer.
 */
static inline void tpm_cmd_set_size(void *b, uint32_t size)
{
	uint8_t *p = (uint8_t *)b + 2;

	p[0] = (uint8_t)(size >> 24);
	p[1] = (uint8_t)(size >> 16);
	p[2] = (uint8_t)(size >> 8);
	p[3] = (uint8_t)size;
}
194 
/*
 * Store a 32-bit return code big-endian into bytes 6-9 of a TPM response
 * buffer.
 *
 * b:     buffer with room for at least 10 bytes; no alignment is required.
 * error: value in host byte order.
 *
 * The field sits at an offset that is not a multiple of 4, so it is written
 * byte by byte rather than through a uint32_t pointer.
 */
static inline void tpm_cmd_set_error(void *b, uint32_t error)
{
	uint8_t *p = (uint8_t *)b + 6;

	p[0] = (uint8_t)(error >> 24);
	p[1] = (uint8_t)(error >> 16);
	p[2] = (uint8_t)(error >> 8);
	p[3] = (uint8_t)error;
}
199 
tpm_is_selftest(const uint8_t * in,uint32_t in_len)200 static bool tpm_is_selftest(const uint8_t *in, uint32_t in_len)
201 {
202 	if (in_len >= sizeof(tpm_input_header))
203 		return tpm_cmd_get_ordinal(in) == TPM_ORD_ContinueSelfTest;
204 
205 	return false;
206 }
207 
/*
 * Connect to the UNIX stream socket at path servername.
 *
 * servername: socket path; must be shorter than sockaddr_un.sun_path.
 * Returns the connected fd, or -1 on failure.
 *
 * On failure errno describes the cause: EINVAL for a NULL path,
 * ENAMETOOLONG for an over-long path, otherwise the value left by
 * socket()/connect(). The value is saved across the logging and close()
 * calls so that callers can still report it.
 */
static int ctrl_chan_conn(const char *servername)
{
	struct sockaddr_un servaddr;
	size_t path_len;
	int saved_errno;
	int clifd;

	if (!servername) {
		pr_err("%s error, invalid input\n", __func__);
		errno = EINVAL;
		return -1;
	}

	path_len = strnlen(servername, sizeof(servaddr.sun_path));
	if (path_len == sizeof(servaddr.sun_path)) {
		pr_err("%s error, length of servername is too long\n", __func__);
		errno = ENAMETOOLONG;
		return -1;
	}

	clifd = socket(AF_UNIX, SOCK_STREAM, 0);
	if (clifd < 0) {
		saved_errno = errno;
		pr_err("socket failed.\n");
		errno = saved_errno;
		return -1;
	}

	/* Zero the whole address: sun_path stays NUL-terminated after the copy */
	memset(&servaddr, 0, sizeof(servaddr));
	servaddr.sun_family = AF_UNIX;
	memcpy(servaddr.sun_path, servername, path_len);

	if (connect(clifd, (struct sockaddr *)&servaddr, sizeof(servaddr)) < 0) {
		saved_errno = errno;
		pr_err("connect failed.\n");
		close(clifd);
		errno = saved_errno;
		return -1;
	}

	return clifd;
}
244 
/*
 * Send one message on the control channel, optionally passing a file
 * descriptor along with it as SCM_RIGHTS ancillary data.
 *
 * ctrl_chan_fd: control channel socket
 * buf, len:     payload to send; len must not be negative
 * pdatafd:      fd to pass; must be non-NULL when fd_num is 1 and NULL when
 *               fd_num is 0
 * fd_num:       number of fds to pass, 0 or 1
 *
 * Returns the number of bytes reported by sendmsg(), or -1 on invalid
 * arguments or send failure. sendmsg() is retried when interrupted (EINTR).
 */
static int ctrl_chan_write(int ctrl_chan_fd, const uint8_t *buf, int len,
			int *pdatafd, int fd_num)
{
	int ret;
	struct msghdr msg;
	struct iovec iov[1];
	union {
		struct cmsghdr cm;
		char control[CMSG_SPACE(sizeof(int))];
	} control_un;
	struct cmsghdr *pcmsg;

	if (!buf || len < 0 || (!pdatafd && fd_num)) {
		pr_err("%s error, invalid input\n", __func__);
		return -1;
	}

	/* Start from a fully defined header and control buffer */
	memset(&msg, 0, sizeof(msg));
	memset(&control_un, 0, sizeof(control_un));

	iov[0].iov_base = (void *)buf;
	iov[0].iov_len = (size_t)len;
	msg.msg_iov = iov;
	msg.msg_iovlen = 1;

	if (fd_num == 0) {
		/* an fd was supplied although none should be passed */
		if (pdatafd)
			return -1;
	} else if (fd_num == 1) {
		msg.msg_control = control_un.control;
		msg.msg_controllen = sizeof(control_un.control);

		pcmsg = CMSG_FIRSTHDR(&msg);
		pcmsg->cmsg_len = CMSG_LEN(sizeof(int));
		pcmsg->cmsg_level = SOL_SOCKET;
		pcmsg->cmsg_type = SCM_RIGHTS;
		/* CMSG_DATA() is not guaranteed to be aligned for int */
		memcpy(CMSG_DATA(pcmsg), pdatafd, sizeof(int));
	} else {
		pr_err("fd_num failed.\n");
		return -1;
	}

	do {
		ret = sendmsg(ctrl_chan_fd, &msg, 0);
	} while (ret < 0 && errno == EINTR);

	if (ret < 0) {
		pr_err("Failed to send msg, reason: %s\n", strerror(errno));
	}

	return ret;
}
299 
/*
 * Receive exactly len bytes from the control channel into buf.
 *
 * The first chunk is fetched with recvmsg() (no ancillary data is
 * requested); any remainder is fetched with read(). Interrupted calls
 * (EINTR) are retried, matching ctrl_chan_write().
 *
 * Returns len once everything has arrived, 0 if len is not positive or the
 * peer closed the connection before len bytes were received, and -1 on
 * error or NULL buf.
 */
static int ctrl_chan_read(int ctrl_chan_fd, uint8_t *buf, int len)
{
	struct msghdr msg;
	struct iovec iov[1];
	int recvd = 0;
	ssize_t n;

	if (!buf) {
		pr_err("%s error, invalid input\n", __func__);
		return -1;
	}

	/* msg_name/msg_control stay NULL: no peer address, no fd expected */
	memset(&msg, 0, sizeof(msg));
	iov[0].iov_base = buf;
	iov[0].iov_len = (len > 0) ? (size_t)len : 0;
	msg.msg_iov = iov;
	msg.msg_iovlen = 1;

	while (recvd < len) {
		if (recvd == 0)
			n = recvmsg(ctrl_chan_fd, &msg, 0);
		else
			n = read(ctrl_chan_fd, buf + recvd, (size_t)(len - recvd));

		if (n < 0 && errno == EINTR)
			continue;
		if (n <= 0)
			return (int)n;

		recvd += (int)n;
	}

	return recvd;
}
334 
/*
 * Write all len bytes of buf to the command channel, looping over short
 * writes and retrying writes interrupted by a signal (EINTR).
 *
 * Returns the number of bytes written (len, or 0 when len is not positive),
 * or -1 on NULL buf or write error.
 */
static int cmd_chan_write(int cmd_chan_fd, const uint8_t *buf, int len)
{
	ssize_t nwritten;
	int remaining = len;

	if (!buf) {
		pr_err("%s error, invalid input\n", __func__);
		return -1;
	}

	while (remaining > 0) {
		nwritten = write(cmd_chan_fd, buf, (size_t)remaining);
		if (nwritten < 0) {
			/* interrupted before anything was written: try again */
			if (errno == EINTR)
				continue;

			pr_err("cmd_chan_write: Error, write() %d %s\n",
					  errno, strerror(errno));
			return -1;
		}

		remaining -= (int)nwritten;
		buf += nwritten;
	}

	return (len - remaining);
}
360 
/*
 * Read exactly len bytes from the command channel into buf, looping over
 * short reads and retrying reads interrupted by a signal (EINTR).
 *
 * A negative len is rejected: converted to an unsigned count it would make
 * the loop read far past the end of buf.
 *
 * Returns len on success, or -1 on invalid arguments, read error, or end of
 * file before len bytes were received.
 */
static int cmd_chan_read(int cmd_chan_fd, uint8_t *buf, int len)
{
	ssize_t nread;
	int nleft = len;

	if (!buf || len < 0) {
		pr_err("%s error, invalid input\n", __func__);
		return -1;
	}

	while (nleft > 0) {
		nread = read(cmd_chan_fd, buf, (size_t)nleft);
		if (nread > 0) {
			nleft -= (int)nread;
			buf += nread;
			continue;
		}

		if (nread < 0) {
			/* interrupted before any data arrived: try again */
			if (errno == EINTR)
				continue;

			pr_err("cmd_chan_read: Error, read() error %d %s\n",
				   errno, strerror(errno));
			return -1;
		}

		/* nread == 0: the peer closed the channel */
		pr_err("cmd_chan_read: Error, read EOF, read %lu bytes\n",
			   (unsigned long)(len - nleft));
		return -1;
	}

	return (len - nleft);
}
391 
/*
 * Send one command on the swtpm control channel and read back its response.
 *
 * ctrl_chan_fd: control channel socket
 * cmd:          CMD_* number; sent as a byte-swapped 32-bit value
 * msg:          in/out buffer. Its first msg_len_in bytes are sent as the
 *               request body; the response then overwrites its first
 *               msg_len_out bytes, so it must hold the larger of the two.
 * msg_len_in:   request body length; may be 0
 * msg_len_out:  response length to wait for. Callers pass at least
 *               sizeof(ptm_res) so that the result code is returned.
 * pdatafd:      fd to pass with the request, required when fd_num != 0
 * fd_num:       number of fds to pass (see ctrl_chan_write)
 *
 * Returns 0 when the request was fully written and the full response was
 * read, -1 otherwise. The result code inside the response is NOT examined
 * here; that is left to the caller.
 */
static int swtpm_ctrlcmd(int ctrl_chan_fd, unsigned long cmd, void *msg,
			size_t msg_len_in, size_t msg_len_out,
			int *pdatafd, int fd_num)
{
	uint32_t cmd_no = __builtin_bswap32((uint32_t)cmd);
	size_t total = sizeof(cmd_no) + msg_len_in;
	/* the channel helpers take int lengths; validated below */
	int send_len = (int)total;
	int recv_len = (int)msg_len_out;
	uint8_t *buf = NULL;
	int ret = -1;
	int send_num;
	int recv_num;

	if (!msg || (!pdatafd && fd_num)) {
		pr_err("%s error, invalid input\n", __func__);
		return -1;
	}

	/*
	 * Reject lengths that wrapped in the addition above or that do not
	 * survive the round trip through int.
	 */
	if (total < msg_len_in ||
	    send_len < 0 || (size_t)send_len != total ||
	    recv_len < 0 || (size_t)recv_len != msg_len_out) {
		pr_err("%s error, message length out of range\n", __func__);
		return -1;
	}

	buf = calloc(total, sizeof(*buf));
	if (!buf)
		return -1;

	/* request = command number followed by the request body */
	memcpy(buf, &cmd_no, sizeof(cmd_no));
	memcpy(buf + sizeof(cmd_no), msg, msg_len_in);

	send_num = ctrl_chan_write(ctrl_chan_fd, buf, send_len, pdatafd, fd_num);
	if (send_num != send_len) {
		pr_err("%s failed to write %d != %d\n", __func__, send_num, send_len);
		goto end;
	}

	if (msg_len_out != 0) {
		recv_num = ctrl_chan_read(ctrl_chan_fd, msg, recv_len);
		if (recv_num != recv_len) {
			pr_err("%s failed to read %d != %d\n", __func__, recv_num, recv_len);
			goto end;
		}
	}

	ret = 0;

end:
	free(buf);
	return ret;
}
444 
445 /*
446  * Send command to swtpm cmd channel.
447  * Note: out_len should be needed.
448  * Currently swtpm_cmdcmd will only be called by swtpm_handle_request
449  * to deliver the real tpm2 commands. And in crb_reg_write, we can
450  * find that out_len is set as (4096-0x80) which is the maximum size
451  * according to TPM2 spec. So inside function swtpm_cmdcmd,
452  * it need to Check" tpm_cmd_get_size(out)>out_len".
453  */
swtpm_cmdcmd(int cmd_chan_fd,const uint8_t * in,uint32_t in_len,uint8_t * out,uint32_t out_len,bool * selftest_done)454 static int swtpm_cmdcmd(int cmd_chan_fd,
455 			const uint8_t *in, uint32_t in_len,
456 			uint8_t *out, uint32_t out_len, bool *selftest_done)
457 {
458 	ssize_t ret;
459 	bool is_selftest = false;
460 
461 	if (!in || !out) {
462 		pr_err("%s error, invalid input\n", __func__);
463 		return -1;
464 	}
465 
466 	if (selftest_done) {
467 		*selftest_done = false;
468 		is_selftest = tpm_is_selftest(in, in_len);
469 	}
470 
471 	ret = cmd_chan_write(cmd_chan_fd, (uint8_t *)in, in_len);
472 	if ((ret == -1) || (ret != in_len)) {
473 		pr_err("%s failed to write %ld != %d\n", __func__, ret, in_len);
474 		return -1;
475 	}
476 
477 	ret = cmd_chan_read(cmd_chan_fd, (uint8_t *)out,
478 			  sizeof(tpm_output_header));
479 	if (ret == -1) {
480 		pr_err("%s failed to read size\n", __func__);
481 		return -1;
482 	}
483 
484 	if (tpm_cmd_get_size(out) > out_len) {
485 		pr_err("%s error, get out size is larger than out_len\n", __func__);
486 		return -1;
487 	}
488 
489 	ret = cmd_chan_read(cmd_chan_fd,
490 				(uint8_t *)out + sizeof(tpm_output_header),
491 				tpm_cmd_get_size(out) - sizeof(tpm_output_header));
492 	if (ret == -1) {
493 		pr_err("%s failed to read data\n", __func__);
494 		return -1;
495 	}
496 
497 	if (is_selftest) {
498 		*selftest_done = tpm_cmd_get_errcode(out) == 0;
499 	}
500 
501 	return 0;
502 }
503 
504 /*
505  * Create ctrl channel.
506  */
swtpm_ctrlchan_create(const char * arg_path)507 static int swtpm_ctrlchan_create(const char *arg_path)
508 {
509 	int connfd;
510 
511 	if (!arg_path) {
512 		pr_err("%s error, invalid input\n", __func__);
513 		return -1;
514 	}
515 
516 	connfd = ctrl_chan_conn(arg_path);
517 	if(connfd < 0)
518 	{
519 		pr_err("Error[%d] when connecting...", errno);
520 		return -1;
521 	}
522 
523 	tpm_context.ctrl_chan_fd = connfd;
524 
525 	return connfd;
526 }
527 
528 /*
529  * Create cmd channel.
530  */
swtpm_cmdchan_create(void)531 static int swtpm_cmdchan_create(void)
532 {
533 	ptm_res res = 0;
534 	int sv[2] = {-1, -1};
535 
536 	if (socketpair(AF_UNIX, SOCK_STREAM, 0, sv) < 0)
537 	{
538 		pr_err("socketpair failed!\n");
539 		return -1;
540 	}
541 	if (swtpm_ctrlcmd(tpm_context.ctrl_chan_fd, CMD_SET_DATAFD, &res, 0,
542 				 sizeof(res), &sv[1], 1) < 0 || res != 0) {
543 		pr_err("swtpm: Failed to send CMD_SET_DATAFD: %s", strerror(errno));
544 		goto err_exit;
545 	}
546 	tpm_context.cmd_chan_fd = sv[0];
547 	close(sv[1]);
548 
549 	return 0;
550 
551 err_exit:
552 	close(sv[0]);
553 	close(sv[1]);
554 	return -1;
555 }
556 
/*
 * Send CMD_INIT with the request held in *init. The response is written
 * back into the same structure.
 *
 * Returns 0 when the command round trip succeeded and the byte-swapped
 * result code is 0, otherwise -1.
 */
static int swtpm_start(ptm_init *init)
{
	ptm_res result;
	int rc;

	if (init == NULL) {
		pr_err("%s error, invalid input\n", __func__);
		return -1;
	}

	rc = swtpm_ctrlcmd(tpm_context.ctrl_chan_fd, CMD_INIT,
				init, sizeof(*init), sizeof(*init), NULL, 0);
	if (rc < 0) {
		pr_err("swtpm: could not send INIT: %s", strerror(errno));
		return -1;
	}

	result = __builtin_bswap32(init->u.resp.tpm_result);
	if (result != 0) {
		pr_err("swtpm: TPM result for CMD_INIT: 0x%x", result);
		return -1;
	}

	return 0;
}
583 
swtpm_stop(void)584 static int swtpm_stop(void)
585 {
586 	ptm_res res = 0;
587 
588 	if (swtpm_ctrlcmd(tpm_context.ctrl_chan_fd, CMD_STOP, &res, 0, sizeof(res), NULL, 0) < 0) {
589 		pr_err("swtpm: Could not stop TPM: %s", strerror(errno));
590 		return -1;
591 	}
592 
593 	res = __builtin_bswap32(res);
594 	if (res) {
595 		pr_err("swtpm: TPM result for CMD_STOP: 0x%x", res);
596 		return -1;
597 	}
598 
599 	return 0;
600 }
601 
602 /* wanted_size: input, the buffer size which we want to setup.
603  * actual_size: output, the actual buffer size returned after setup.
604  *
605  * Note: To meet swtpm logic, swtpm_stop() must be called before
606  *    swtpm_set_buffer_size()
607  */
swtpm_set_buffer_size(size_t wanted_size,size_t * actual_size)608 static int swtpm_set_buffer_size(size_t wanted_size,
609 					size_t *actual_size)
610 {
611 	ptm_setbuffersize psbs;
612 
613 	if (wanted_size == 0) {
614 		pr_err("%s error, wanted_size is 0\n", __func__);
615 		return -1;
616 	}
617 
618 	psbs.u.req.buffersize = __builtin_bswap32(wanted_size);
619 
620 	if (swtpm_ctrlcmd(tpm_context.ctrl_chan_fd, CMD_SET_BUFFERSIZE, &psbs,
621 			 sizeof(psbs.u.req), sizeof(psbs.u.resp), NULL, 0) < 0) {
622 		pr_err("swtpm: Could not set buffer size: %s", strerror(errno));
623 		return -1;
624 	}
625 
626 	psbs.u.resp.tpm_result = __builtin_bswap32(psbs.u.resp.tpm_result);
627 	if (psbs.u.resp.tpm_result != 0) {
628 		pr_err("swtpm: TPM result for set buffer size : 0x%x", psbs.u.resp.tpm_result);
629 		return -1;
630 	}
631 
632 	if (actual_size) {
633 		*actual_size = __builtin_bswap32(psbs.u.resp.buffersize);
634 	}
635 
636 	return 0;
637 }
638 
swtpm_startup_tpm(size_t buffersize,bool is_resume)639 static int swtpm_startup_tpm(size_t buffersize,
640 				bool is_resume)
641 {
642 	ptm_init init = {
643 		.u.req.init_flags = 0,
644 	};
645 
646 	if (swtpm_stop() < 0) {
647 		pr_err("swtpm_stop() failed!\n");
648 		return -1;
649 	}
650 
651 	if (buffersize != 0 &&
652 		swtpm_set_buffer_size(buffersize, NULL) < 0) {
653 		return -1;
654 	}
655 
656 	if (is_resume) {
657 		init.u.req.init_flags |= __builtin_bswap32(PTM_INIT_FLAG_DELETE_VOLATILE);
658 	}
659 
660 	return swtpm_start(&init);
661 }
662 
swtpm_shutdown(void)663 static void swtpm_shutdown(void)
664 {
665 	ptm_res res = 0;
666 
667 	if (swtpm_ctrlcmd(tpm_context.ctrl_chan_fd, CMD_SHUTDOWN,
668 				&res, 0, sizeof(res), NULL, 0) < 0) {
669 		pr_err("swtpm: Could not cleanly shutdown the TPM: %s", strerror(errno));
670 	} else if (res != 0) {
671 		pr_err("swtpm: TPM result for sutdown: 0x%x", __builtin_bswap32(res));
672 	}
673 }
674 
swtpm_set_locality(uint8_t locty_number)675 static int swtpm_set_locality(uint8_t locty_number)
676 {
677 	ptm_loc loc;
678 
679 	if (tpm_context.cur_locty_number == locty_number)
680 		return 0;
681 
682 	loc.u.req.loc = locty_number;
683 	if (swtpm_ctrlcmd(tpm_context.ctrl_chan_fd, CMD_SET_LOCALITY, &loc,
684 							 sizeof(loc), sizeof(loc), NULL, 0) < 0) {
685 		pr_err("swtpm: could not set locality : %s", strerror(errno));
686 		return -1;
687 	}
688 
689 	loc.u.resp.tpm_result = __builtin_bswap32(loc.u.resp.tpm_result);
690 	if (loc.u.resp.tpm_result != 0) {
691 		pr_err("swtpm: TPM result for set locality : 0x%x", loc.u.resp.tpm_result);
692 		return -1;
693 	}
694 
695 	tpm_context.cur_locty_number = locty_number;
696 
697 	return 0;
698 }
699 
swtpm_write_fatal_error_response(uint8_t * out,uint32_t out_len)700 static void swtpm_write_fatal_error_response(uint8_t *out, uint32_t out_len)
701 {
702 	if (!out) {
703 		pr_err("%s error, invalid input.\n", __func__);
704 		return;
705 	}
706 
707 	if (out_len >= sizeof(tpm_output_header)) {
708 		tpm_cmd_set_tag(out, TPM_TAG_RSP_COMMAND);
709 		tpm_cmd_set_size(out, sizeof(tpm_output_header));
710 		tpm_cmd_set_error(out, TPM_FAIL);
711 	}
712 }
713 
swtpm_cleanup(void)714 static void swtpm_cleanup(void)
715 {
716 	swtpm_shutdown();
717 	close(tpm_context.cmd_chan_fd);
718 	close(tpm_context.ctrl_chan_fd);
719 }
720 
/*
 * Start the TPM with the given buffer size (0 keeps the current one).
 * This is the non-resume variant of swtpm_startup_tpm().
 *
 * Returns 0 on success, -1 on failure.
 */
int swtpm_startup(size_t buffersize)
{
	const bool is_resume = false;

	return swtpm_startup_tpm(buffersize, is_resume);
}
725 
/*
 * Process one TPM request: select the locality given in cmd, then forward
 * the command to swtpm and collect the response in cmd->out.
 *
 * If either step fails, a fatal error response is written into cmd->out
 * instead.
 *
 * Returns 0 on success, -1 on failure or NULL cmd.
 */
int swtpm_handle_request(TPMCommBuffer *cmd)
{
	int rc;

	if (cmd == NULL) {
		pr_err("%s error, invalid input.\n", __func__);
		return -1;
	}

	rc = swtpm_set_locality(cmd->locty);
	if (rc >= 0) {
		/* the command is forwarded only after the locality is set */
		rc = swtpm_cmdcmd(tpm_context.cmd_chan_fd, cmd->in, cmd->in_len,
				cmd->out, cmd->out_len,
				&cmd->selftest_done);
	}

	if (rc < 0) {
		swtpm_write_fatal_error_response(cmd->out, cmd->out_len);
		return -1;
	}

	return 0;
}
743 
swtpm_get_tpm_established_flag(void)744 bool swtpm_get_tpm_established_flag(void)
745 {
746 	ptm_est est;
747 
748 	if (tpm_context.established_flag_cached) {
749 		return tpm_context.established_flag;
750 	}
751 
752 	if (swtpm_ctrlcmd(tpm_context.ctrl_chan_fd, CMD_GET_TPMESTABLISHED, &est,
753 				0, sizeof(est), NULL, 0) < 0) {
754 		pr_err("swtpm: Could not get the TPM established flag: %s", strerror(errno));
755 		return false;
756 	}
757 
758 	tpm_context.established_flag_cached = 1;
759 	tpm_context.established_flag = (est.u.resp.bit != 0);
760 
761 	return tpm_context.established_flag;
762 }
763 
swtpm_cancel_cmd(void)764 void swtpm_cancel_cmd(void)
765 {
766 	ptm_res res = 0;
767 
768 	if (swtpm_ctrlcmd(tpm_context.ctrl_chan_fd, CMD_CANCEL_TPM_CMD, &res, 0,
769 				sizeof(res), NULL, 0) < 0) {
770 		pr_err("swtpm: Could not cancel command: %s", strerror(errno));
771 	} else if (res != 0) {
772 		pr_err("swtpm: Failed to cancel TPM: 0x%x", __builtin_bswap32(res));
773 	}
774 }
775 
init_tpm_emulator(const char * sock_path)776 int init_tpm_emulator(const char *sock_path)
777 {
778 	if (swtpm_ctrlchan_create(sock_path) < 0) {
779 		pr_err("error, failed to create ctrl channel.\n");
780 		return -1;
781 	}
782 
783 	if (swtpm_cmdchan_create() < 0) {
784 		pr_err("error, failed to create cmd channel.\n");
785 		return -1;
786 	}
787 
788 	return 0;
789 }
790 
/*
 * Tear the emulator connection down; counterpart of init_tpm_emulator().
 * All the work is done by swtpm_cleanup().
 */
void deinit_tpm_emulator(void)
{
	swtpm_cleanup();
}
795