// SPDX-License-Identifier: GPL-2.0
#include <linux/kernel.h>
#include <linux/errno.h>
#include <linux/dma-map-ops.h>
#include <linux/mm.h>
#include <linux/nospec.h>
#include <linux/io_uring.h>
#include <linux/netdevice.h>
#include <linux/rtnetlink.h>
#include <linux/skbuff_ref.h>

#include <net/page_pool/helpers.h>
#include <net/page_pool/memory_provider.h>
#include <net/netlink.h>
#include <net/netdev_rx_queue.h>
#include <net/tcp.h>
#include <net/rps.h>

#include <trace/events/page_pool.h>

#include <uapi/linux/io_uring.h>

#include "io_uring.h"
#include "kbuf.h"
#include "memmap.h"
#include "zcrx.h"
#include "rsrc.h"

#define IO_DMA_ATTR (DMA_ATTR_SKIP_CPU_SYNC | DMA_ATTR_WEAK_ORDERING)

static inline struct io_zcrx_ifq *io_pp_to_ifq(struct page_pool *pp)
{
        return pp->mp_priv;
}

static inline struct io_zcrx_area *io_zcrx_iov_to_area(const struct net_iov *niov)
{
        struct net_iov_area *owner = net_iov_owner(niov);

        return container_of(owner, struct io_zcrx_area, nia);
}

static inline struct page *io_zcrx_iov_page(const struct net_iov *niov)
{
        struct io_zcrx_area *area = io_zcrx_iov_to_area(niov);

        lockdep_assert(!area->mem.is_dmabuf);

        return area->mem.pages[net_iov_idx(niov)];
}

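/*
 * Walk the DMA-mapped scatterlist and assign a DMA address to each net_iov
 * in the area, skipping the first @off bytes of the mapping. Each niov
 * covers one page worth of the mapping.
 */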
static int io_populate_area_dma(struct io_zcrx_ifq *ifq,
                                struct io_zcrx_area *area,
                                struct sg_table *sgt, unsigned long off)
{
        struct scatterlist *sg;
        unsigned i, niov_idx = 0;

        for_each_sgtable_dma_sg(sgt, sg, i) {
                dma_addr_t dma = sg_dma_address(sg);
                unsigned long sg_len = sg_dma_len(sg);
                unsigned long sg_off = min(sg_len, off);

                off -= sg_off;
                sg_len -= sg_off;
                dma += sg_off;

                while (sg_len && niov_idx < area->nia.num_niovs) {
                        struct net_iov *niov = &area->nia.niovs[niov_idx];

                        if (net_mp_niov_set_dma_addr(niov, dma))
                                return -EFAULT;
                        sg_len -= PAGE_SIZE;
                        dma += PAGE_SIZE;
                        niov_idx++;
                }
        }
        return 0;
}

static void io_release_dmabuf(struct io_zcrx_mem *mem)
{
        if (!IS_ENABLED(CONFIG_DMA_SHARED_BUFFER))
                return;

        if (mem->sgt)
                dma_buf_unmap_attachment_unlocked(mem->attach, mem->sgt,
                                                  DMA_FROM_DEVICE);
        if (mem->attach)
                dma_buf_detach(mem->dmabuf, mem->attach);
        if (mem->dmabuf)
                dma_buf_put(mem->dmabuf);

        mem->sgt = NULL;
        mem->attach = NULL;
        mem->dmabuf = NULL;
}

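/*
 * Import a dmabuf backed area: take a reference to the dmabuf, attach it to
 * the interface queue's device and map it for DMA. The registered length
 * must exactly match the size of the dmabuf mapping.
 */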
static int io_import_dmabuf(struct io_zcrx_ifq *ifq,
                            struct io_zcrx_mem *mem,
                            struct io_uring_zcrx_area_reg *area_reg)
{
        unsigned long off = (unsigned long)area_reg->addr;
        unsigned long len = (unsigned long)area_reg->len;
        unsigned long total_size = 0;
        struct scatterlist *sg;
        int dmabuf_fd = area_reg->dmabuf_fd;
        int i, ret;

        if (off)
                return -EINVAL;
        if (WARN_ON_ONCE(!ifq->dev))
                return -EFAULT;
        if (!IS_ENABLED(CONFIG_DMA_SHARED_BUFFER))
                return -EINVAL;

        mem->is_dmabuf = true;
        mem->dmabuf = dma_buf_get(dmabuf_fd);
        if (IS_ERR(mem->dmabuf)) {
                ret = PTR_ERR(mem->dmabuf);
                mem->dmabuf = NULL;
                goto err;
        }

        mem->attach = dma_buf_attach(mem->dmabuf, ifq->dev);
        if (IS_ERR(mem->attach)) {
                ret = PTR_ERR(mem->attach);
                mem->attach = NULL;
                goto err;
        }

        mem->sgt = dma_buf_map_attachment_unlocked(mem->attach, DMA_FROM_DEVICE);
        if (IS_ERR(mem->sgt)) {
                ret = PTR_ERR(mem->sgt);
                mem->sgt = NULL;
                goto err;
        }

        for_each_sgtable_dma_sg(mem->sgt, sg, i)
                total_size += sg_dma_len(sg);

        if (total_size != len) {
                ret = -EINVAL;
                goto err;
        }

        mem->dmabuf_offset = off;
        mem->size = len;
        return 0;
err:
        io_release_dmabuf(mem);
        return ret;
}

static int io_zcrx_map_area_dmabuf(struct io_zcrx_ifq *ifq, struct io_zcrx_area *area)
{
        if (!IS_ENABLED(CONFIG_DMA_SHARED_BUFFER))
                return -EINVAL;
        return io_populate_area_dma(ifq, area, area->mem.sgt,
                                    area->mem.dmabuf_offset);
}

static unsigned long io_count_account_pages(struct page **pages, unsigned nr_pages)
{
        struct folio *last_folio = NULL;
        unsigned long res = 0;
        int i;

        for (i = 0; i < nr_pages; i++) {
                struct folio *folio = page_folio(pages[i]);

                if (folio == last_folio)
                        continue;
                last_folio = folio;
                res += 1UL << folio_order(folio);
        }
        return res;
}

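/*
 * Import a user memory backed area: pin the user pages and build a
 * scatterlist over them. The pinned memory is charged against the io_uring
 * context's memory accounting.
 */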
static int io_import_umem(struct io_zcrx_ifq *ifq,
                          struct io_zcrx_mem *mem,
                          struct io_uring_zcrx_area_reg *area_reg)
{
        struct page **pages;
        int nr_pages, ret;

        if (area_reg->dmabuf_fd)
                return -EINVAL;
        if (!area_reg->addr)
                return -EFAULT;
        pages = io_pin_pages((unsigned long)area_reg->addr, area_reg->len,
                             &nr_pages);
        if (IS_ERR(pages))
                return PTR_ERR(pages);

        ret = sg_alloc_table_from_pages(&mem->page_sg_table, pages, nr_pages,
                                        0, nr_pages << PAGE_SHIFT,
                                        GFP_KERNEL_ACCOUNT);
        if (ret) {
                unpin_user_pages(pages, nr_pages);
                return ret;
        }

        mem->account_pages = io_count_account_pages(pages, nr_pages);
        ret = io_account_mem(ifq->ctx, mem->account_pages);
        if (ret < 0)
                mem->account_pages = 0;

        mem->pages = pages;
        mem->nr_folios = nr_pages;
        mem->size = area_reg->len;
        return ret;
}

static void io_release_area_mem(struct io_zcrx_mem *mem)
{
        if (mem->is_dmabuf) {
                io_release_dmabuf(mem);
                return;
        }
        if (mem->pages) {
                unpin_user_pages(mem->pages, mem->nr_folios);
                sg_free_table(&mem->page_sg_table);
                kvfree(mem->pages);
        }
}

static int io_import_area(struct io_zcrx_ifq *ifq,
                          struct io_zcrx_mem *mem,
                          struct io_uring_zcrx_area_reg *area_reg)
{
        int ret;

        ret = io_validate_user_buf_range(area_reg->addr, area_reg->len);
        if (ret)
                return ret;
        if (area_reg->addr & ~PAGE_MASK || area_reg->len & ~PAGE_MASK)
                return -EINVAL;

        if (area_reg->flags & IORING_ZCRX_AREA_DMABUF)
                return io_import_dmabuf(ifq, mem, area_reg);
        return io_import_umem(ifq, mem, area_reg);
}

static void io_zcrx_unmap_area(struct io_zcrx_ifq *ifq,
                               struct io_zcrx_area *area)
{
        int i;

        guard(mutex)(&ifq->dma_lock);
        if (!area->is_mapped)
                return;
        area->is_mapped = false;

        for (i = 0; i < area->nia.num_niovs; i++)
                net_mp_niov_set_dma_addr(&area->nia.niovs[i], 0);

        if (area->mem.is_dmabuf) {
                io_release_dmabuf(&area->mem);
        } else {
                dma_unmap_sgtable(ifq->dev, &area->mem.page_sg_table,
                                  DMA_FROM_DEVICE, IO_DMA_ATTR);
        }
}

static int io_zcrx_map_area_umem(struct io_zcrx_ifq *ifq, struct io_zcrx_area *area)
{
        int ret;

        ret = dma_map_sgtable(ifq->dev, &area->mem.page_sg_table,
                              DMA_FROM_DEVICE, IO_DMA_ATTR);
        if (ret < 0)
                return ret;
        return io_populate_area_dma(ifq, area, &area->mem.page_sg_table, 0);
}

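/*
 * DMA map the area for the device, serialised by ifq->dma_lock. It's a
 * no-op if the area is already mapped.
 */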
static int io_zcrx_map_area(struct io_zcrx_ifq *ifq, struct io_zcrx_area *area)
{
        int ret;

        guard(mutex)(&ifq->dma_lock);
        if (area->is_mapped)
                return 0;

        if (area->mem.is_dmabuf)
                ret = io_zcrx_map_area_dmabuf(ifq, area);
        else
                ret = io_zcrx_map_area_umem(ifq, area);

        if (ret == 0)
                area->is_mapped = true;
        return ret;
}

static void io_zcrx_sync_for_device(const struct page_pool *pool,
                                    struct net_iov *niov)
{
#if defined(CONFIG_HAS_DMA) && defined(CONFIG_DMA_NEED_SYNC)
        dma_addr_t dma_addr;

        if (!dma_dev_need_sync(pool->p.dev))
                return;

        dma_addr = page_pool_get_dma_addr_netmem(net_iov_to_netmem(niov));
        __dma_sync_single_for_device(pool->p.dev, dma_addr + pool->p.offset,
                                     PAGE_SIZE, pool->p.dma_dir);
#endif
}

#define IO_RQ_MAX_ENTRIES 32768

#define IO_SKBS_PER_CALL_LIMIT 20

struct io_zcrx_args {
        struct io_kiocb *req;
        struct io_zcrx_ifq *ifq;
        struct socket *sock;
        unsigned nr_skbs;
};

static const struct memory_provider_ops io_uring_pp_zc_ops;

static inline atomic_t *io_get_user_counter(struct net_iov *niov)
{
        struct io_zcrx_area *area = io_zcrx_iov_to_area(niov);

        return &area->user_refs[net_iov_idx(niov)];
}

static bool io_zcrx_put_niov_uref(struct net_iov *niov)
{
        atomic_t *uref = io_get_user_counter(niov);

        if (unlikely(!atomic_read(uref)))
                return false;
        atomic_dec(uref);
        return true;
}

static void io_zcrx_get_niov_uref(struct net_iov *niov)
{
        atomic_inc(io_get_user_counter(niov));
}

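/*
 * Create the mmap'able region backing the refill queue: an io_uring header
 * (head/tail) followed by rq_entries refill queue entries.
 */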
static int io_allocate_rbuf_ring(struct io_zcrx_ifq *ifq,
                                 struct io_uring_zcrx_ifq_reg *reg,
                                 struct io_uring_region_desc *rd,
                                 u32 id)
{
        u64 mmap_offset;
        size_t off, size;
        void *ptr;
        int ret;

        off = sizeof(struct io_uring);
        size = off + sizeof(struct io_uring_zcrx_rqe) * reg->rq_entries;
        if (size > rd->size)
                return -EINVAL;

        mmap_offset = IORING_MAP_OFF_ZCRX_REGION;
        mmap_offset += id << IORING_OFF_PBUF_SHIFT;

        ret = io_create_region(ifq->ctx, &ifq->region, rd, mmap_offset);
        if (ret < 0)
                return ret;

        ptr = io_region_get_ptr(&ifq->region);
        ifq->rq_ring = (struct io_uring *)ptr;
        ifq->rqes = (struct io_uring_zcrx_rqe *)(ptr + off);
        return 0;
}

static void io_free_rbuf_ring(struct io_zcrx_ifq *ifq)
{
        io_free_region(ifq->ctx, &ifq->region);
        ifq->rq_ring = NULL;
        ifq->rqes = NULL;
}

static void io_zcrx_free_area(struct io_zcrx_area *area)
{
        io_zcrx_unmap_area(area->ifq, area);
        io_release_area_mem(&area->mem);

        if (area->mem.account_pages)
                io_unaccount_mem(area->ifq->ctx, area->mem.account_pages);

        kvfree(area->freelist);
        kvfree(area->nia.niovs);
        kvfree(area->user_refs);
        kfree(area);
}

#define IO_ZCRX_AREA_SUPPORTED_FLAGS (IORING_ZCRX_AREA_DMABUF)

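/*
 * Set up a buffer area: import the backing memory (user pages or dmabuf)
 * and allocate one net_iov, freelist slot and user refcount per page.
 */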
static int io_zcrx_create_area(struct io_zcrx_ifq *ifq,
                               struct io_zcrx_area **res,
                               struct io_uring_zcrx_area_reg *area_reg)
{
        struct io_zcrx_area *area;
        unsigned nr_iovs;
        int i, ret;

        if (area_reg->flags & ~IO_ZCRX_AREA_SUPPORTED_FLAGS)
                return -EINVAL;
        if (area_reg->rq_area_token)
                return -EINVAL;
        if (area_reg->__resv2[0] || area_reg->__resv2[1])
                return -EINVAL;

        ret = -ENOMEM;
        area = kzalloc(sizeof(*area), GFP_KERNEL);
        if (!area)
                goto err;
        area->ifq = ifq;

        ret = io_import_area(ifq, &area->mem, area_reg);
        if (ret)
                goto err;

        nr_iovs = area->mem.size >> PAGE_SHIFT;
        area->nia.num_niovs = nr_iovs;

        ret = -ENOMEM;
        area->nia.niovs = kvmalloc_array(nr_iovs, sizeof(area->nia.niovs[0]),
                                         GFP_KERNEL | __GFP_ZERO);
        if (!area->nia.niovs)
                goto err;

        area->freelist = kvmalloc_array(nr_iovs, sizeof(area->freelist[0]),
                                        GFP_KERNEL | __GFP_ZERO);
        if (!area->freelist)
                goto err;

        area->user_refs = kvmalloc_array(nr_iovs, sizeof(area->user_refs[0]),
                                         GFP_KERNEL | __GFP_ZERO);
        if (!area->user_refs)
                goto err;

        for (i = 0; i < nr_iovs; i++) {
                struct net_iov *niov = &area->nia.niovs[i];

                niov->owner = &area->nia;
                area->freelist[i] = i;
                atomic_set(&area->user_refs[i], 0);
                niov->type = NET_IOV_IOURING;
        }

        area->free_count = nr_iovs;
        /* we're only supporting one area per ifq for now */
        area->area_id = 0;
        area_reg->rq_area_token = (u64)area->area_id << IORING_ZCRX_AREA_SHIFT;
        spin_lock_init(&area->freelist_lock);
        *res = area;
        return 0;
err:
        if (area)
                io_zcrx_free_area(area);
        return ret;
}

static struct io_zcrx_ifq *io_zcrx_ifq_alloc(struct io_ring_ctx *ctx)
{
        struct io_zcrx_ifq *ifq;

        ifq = kzalloc(sizeof(*ifq), GFP_KERNEL);
        if (!ifq)
                return NULL;

        ifq->if_rxq = -1;
        ifq->ctx = ctx;
        spin_lock_init(&ifq->lock);
        spin_lock_init(&ifq->rq_lock);
        mutex_init(&ifq->dma_lock);
        return ifq;
}

static void io_zcrx_drop_netdev(struct io_zcrx_ifq *ifq)
{
        spin_lock(&ifq->lock);
        if (ifq->netdev) {
                netdev_put(ifq->netdev, &ifq->netdev_tracker);
                ifq->netdev = NULL;
        }
        spin_unlock(&ifq->lock);
}

static void io_close_queue(struct io_zcrx_ifq *ifq)
{
        struct net_device *netdev;
        netdevice_tracker netdev_tracker;
        struct pp_memory_provider_params p = {
                .mp_ops = &io_uring_pp_zc_ops,
                .mp_priv = ifq,
        };

        if (ifq->if_rxq == -1)
                return;

        spin_lock(&ifq->lock);
        netdev = ifq->netdev;
        netdev_tracker = ifq->netdev_tracker;
        ifq->netdev = NULL;
        spin_unlock(&ifq->lock);

        if (netdev) {
                net_mp_close_rxq(netdev, ifq->if_rxq, &p);
                netdev_put(netdev, &netdev_tracker);
        }
        ifq->if_rxq = -1;
}

static void io_zcrx_ifq_free(struct io_zcrx_ifq *ifq)
{
        io_close_queue(ifq);
        io_zcrx_drop_netdev(ifq);

        if (ifq->area)
                io_zcrx_free_area(ifq->area);
        if (ifq->dev)
                put_device(ifq->dev);

        io_free_rbuf_ring(ifq);
        mutex_destroy(&ifq->dma_lock);
        kfree(ifq);
}

struct io_mapped_region *io_zcrx_get_region(struct io_ring_ctx *ctx,
                                            unsigned int id)
{
        struct io_zcrx_ifq *ifq = xa_load(&ctx->zcrx_ctxs, id);

        lockdep_assert_held(&ctx->mmap_lock);

        return ifq ? &ifq->region : NULL;
}

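/*
 * Register a zero copy rx interface queue: validate the registration
 * arguments, allocate the refill ring and buffer area, bind the queue to
 * the target netdev rx queue via the page pool memory provider, and
 * publish the resulting ifq in ctx->zcrx_ctxs.
 */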
int io_register_zcrx_ifq(struct io_ring_ctx *ctx,
                         struct io_uring_zcrx_ifq_reg __user *arg)
{
        struct pp_memory_provider_params mp_param = {};
        struct io_uring_zcrx_area_reg area;
        struct io_uring_zcrx_ifq_reg reg;
        struct io_uring_region_desc rd;
        struct io_zcrx_ifq *ifq;
        int ret;
        u32 id;

        /*
         * 1. Interface queue allocation.
         * 2. It can observe data destined for sockets of other tasks.
         */
        if (!capable(CAP_NET_ADMIN))
                return -EPERM;

        /* mandatory io_uring features for zc rx */
        if (!(ctx->flags & IORING_SETUP_DEFER_TASKRUN &&
              ctx->flags & IORING_SETUP_CQE32))
                return -EINVAL;
        if (copy_from_user(&reg, arg, sizeof(reg)))
                return -EFAULT;
        if (copy_from_user(&rd, u64_to_user_ptr(reg.region_ptr), sizeof(rd)))
                return -EFAULT;
        if (memchr_inv(&reg.__resv, 0, sizeof(reg.__resv)) ||
            reg.__resv2 || reg.zcrx_id)
                return -EINVAL;
        if (reg.if_rxq == -1 || !reg.rq_entries || reg.flags)
                return -EINVAL;
        if (reg.rq_entries > IO_RQ_MAX_ENTRIES) {
                if (!(ctx->flags & IORING_SETUP_CLAMP))
                        return -EINVAL;
                reg.rq_entries = IO_RQ_MAX_ENTRIES;
        }
        reg.rq_entries = roundup_pow_of_two(reg.rq_entries);

        if (copy_from_user(&area, u64_to_user_ptr(reg.area_ptr), sizeof(area)))
                return -EFAULT;

        ifq = io_zcrx_ifq_alloc(ctx);
        if (!ifq)
                return -ENOMEM;
        ifq->rq_entries = reg.rq_entries;

        scoped_guard(mutex, &ctx->mmap_lock) {
                /* preallocate id */
                ret = xa_alloc(&ctx->zcrx_ctxs, &id, NULL, xa_limit_31b, GFP_KERNEL);
                if (ret)
                        goto ifq_free;
        }

        ret = io_allocate_rbuf_ring(ifq, &reg, &rd, id);
        if (ret)
                goto err;

        ifq->netdev = netdev_get_by_index(current->nsproxy->net_ns, reg.if_idx,
                                          &ifq->netdev_tracker, GFP_KERNEL);
        if (!ifq->netdev) {
                ret = -ENODEV;
                goto err;
        }

        ifq->dev = ifq->netdev->dev.parent;
        if (!ifq->dev) {
                ret = -EOPNOTSUPP;
                goto err;
        }
        get_device(ifq->dev);

        ret = io_zcrx_create_area(ifq, &ifq->area, &area);
        if (ret)
                goto err;

        mp_param.mp_ops = &io_uring_pp_zc_ops;
        mp_param.mp_priv = ifq;
        ret = net_mp_open_rxq(ifq->netdev, reg.if_rxq, &mp_param);
        if (ret)
                goto err;
        ifq->if_rxq = reg.if_rxq;

        reg.offsets.rqes = sizeof(struct io_uring);
        reg.offsets.head = offsetof(struct io_uring, head);
        reg.offsets.tail = offsetof(struct io_uring, tail);
        reg.zcrx_id = id;

        scoped_guard(mutex, &ctx->mmap_lock) {
                /* publish ifq */
                ret = -ENOMEM;
                if (xa_store(&ctx->zcrx_ctxs, id, ifq, GFP_KERNEL))
                        goto err;
        }

        if (copy_to_user(arg, &reg, sizeof(reg)) ||
            copy_to_user(u64_to_user_ptr(reg.region_ptr), &rd, sizeof(rd)) ||
            copy_to_user(u64_to_user_ptr(reg.area_ptr), &area, sizeof(area))) {
                ret = -EFAULT;
                goto err;
        }
        return 0;
err:
        scoped_guard(mutex, &ctx->mmap_lock)
                xa_erase(&ctx->zcrx_ctxs, id);
ifq_free:
        io_zcrx_ifq_free(ifq);
        return ret;
}

void io_unregister_zcrx_ifqs(struct io_ring_ctx *ctx)
{
        struct io_zcrx_ifq *ifq;

        lockdep_assert_held(&ctx->uring_lock);

        while (1) {
                scoped_guard(mutex, &ctx->mmap_lock) {
                        unsigned long id = 0;

                        ifq = xa_find(&ctx->zcrx_ctxs, &id, ULONG_MAX, XA_PRESENT);
                        if (ifq)
                                xa_erase(&ctx->zcrx_ctxs, id);
                }
                if (!ifq)
                        break;
                io_zcrx_ifq_free(ifq);
        }

        xa_destroy(&ctx->zcrx_ctxs);
}

static struct net_iov *__io_zcrx_get_free_niov(struct io_zcrx_area *area)
{
        unsigned niov_idx;

        lockdep_assert_held(&area->freelist_lock);

        niov_idx = area->freelist[--area->free_count];
        return &area->nia.niovs[niov_idx];
}

static void io_zcrx_return_niov_freelist(struct net_iov *niov)
{
        struct io_zcrx_area *area = io_zcrx_iov_to_area(niov);

        spin_lock_bh(&area->freelist_lock);
        area->freelist[area->free_count++] = net_iov_idx(niov);
        spin_unlock_bh(&area->freelist_lock);
}

static void io_zcrx_return_niov(struct net_iov *niov)
{
        netmem_ref netmem = net_iov_to_netmem(niov);

        if (!niov->pp) {
                /* copy fallback allocated niovs */
                io_zcrx_return_niov_freelist(niov);
                return;
        }
        page_pool_put_unrefed_netmem(niov->pp, netmem, -1, false);
}

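/*
 * Reclaim all buffers currently handed out to user space by dropping their
 * user references; niovs whose page pool reference count also drops to zero
 * are returned to the provider.
 */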
static void io_zcrx_scrub(struct io_zcrx_ifq *ifq)
{
        struct io_zcrx_area *area = ifq->area;
        int i;

        if (!area)
                return;

        /* Reclaim back all buffers given to the user space. */
        for (i = 0; i < area->nia.num_niovs; i++) {
                struct net_iov *niov = &area->nia.niovs[i];
                int nr;

                if (!atomic_read(io_get_user_counter(niov)))
                        continue;
                nr = atomic_xchg(io_get_user_counter(niov), 0);
                if (nr && !page_pool_unref_netmem(net_iov_to_netmem(niov), nr))
                        io_zcrx_return_niov(niov);
        }
}

void io_shutdown_zcrx_ifqs(struct io_ring_ctx *ctx)
{
        struct io_zcrx_ifq *ifq;
        unsigned long index;

        lockdep_assert_held(&ctx->uring_lock);

        xa_for_each(&ctx->zcrx_ctxs, index, ifq) {
                io_zcrx_scrub(ifq);
                io_close_queue(ifq);
        }
}

static inline u32 io_zcrx_rqring_entries(struct io_zcrx_ifq *ifq)
{
        u32 entries;

        entries = smp_load_acquire(&ifq->rq_ring->tail) - ifq->cached_rq_head;
        return min(entries, ifq->rq_entries);
}

static struct io_uring_zcrx_rqe *io_zcrx_get_rqe(struct io_zcrx_ifq *ifq,
                                                 unsigned mask)
{
        unsigned int idx = ifq->cached_rq_head++ & mask;

        return &ifq->rqes[idx];
}

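/*
 * Consume refill queue entries posted by user space and feed the
 * corresponding niovs back into the page pool's allocation cache,
 * validating each entry and dropping its user reference first.
 */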
static void io_zcrx_ring_refill(struct page_pool *pp,
                                struct io_zcrx_ifq *ifq)
{
        unsigned int mask = ifq->rq_entries - 1;
        unsigned int entries;
        netmem_ref netmem;

        spin_lock_bh(&ifq->rq_lock);

        entries = io_zcrx_rqring_entries(ifq);
        entries = min_t(unsigned, entries, PP_ALLOC_CACHE_REFILL - pp->alloc.count);
        if (unlikely(!entries)) {
                spin_unlock_bh(&ifq->rq_lock);
                return;
        }

        do {
                struct io_uring_zcrx_rqe *rqe = io_zcrx_get_rqe(ifq, mask);
                struct io_zcrx_area *area;
                struct net_iov *niov;
                unsigned niov_idx, area_idx;

                area_idx = rqe->off >> IORING_ZCRX_AREA_SHIFT;
                niov_idx = (rqe->off & ~IORING_ZCRX_AREA_MASK) >> PAGE_SHIFT;

                if (unlikely(rqe->__pad || area_idx))
                        continue;
                area = ifq->area;

                if (unlikely(niov_idx >= area->nia.num_niovs))
                        continue;
                niov_idx = array_index_nospec(niov_idx, area->nia.num_niovs);

                niov = &area->nia.niovs[niov_idx];
                if (!io_zcrx_put_niov_uref(niov))
                        continue;

                netmem = net_iov_to_netmem(niov);
                if (page_pool_unref_netmem(netmem, 1) != 0)
                        continue;

                if (unlikely(niov->pp != pp)) {
                        io_zcrx_return_niov(niov);
                        continue;
                }

                io_zcrx_sync_for_device(pp, niov);
                net_mp_netmem_place_in_cache(pp, netmem);
        } while (--entries);

        smp_store_release(&ifq->rq_ring->head, ifq->cached_rq_head);
        spin_unlock_bh(&ifq->rq_lock);
}

static void io_zcrx_refill_slow(struct page_pool *pp, struct io_zcrx_ifq *ifq)
{
        struct io_zcrx_area *area = ifq->area;

        spin_lock_bh(&area->freelist_lock);
        while (area->free_count && pp->alloc.count < PP_ALLOC_CACHE_REFILL) {
                struct net_iov *niov = __io_zcrx_get_free_niov(area);
                netmem_ref netmem = net_iov_to_netmem(niov);

                net_mp_niov_set_page_pool(pp, niov);
                io_zcrx_sync_for_device(pp, niov);
                net_mp_netmem_place_in_cache(pp, netmem);
        }
        spin_unlock_bh(&area->freelist_lock);
}

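/*
 * Memory provider allocation callback: refill the page pool cache from the
 * user refill ring first, fall back to the area freelist, then hand out a
 * netmem from the cache.
 */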
static netmem_ref io_pp_zc_alloc_netmems(struct page_pool *pp, gfp_t gfp)
{
        struct io_zcrx_ifq *ifq = io_pp_to_ifq(pp);

        /* pp should already be ensuring that */
        if (unlikely(pp->alloc.count))
                goto out_return;

        io_zcrx_ring_refill(pp, ifq);
        if (likely(pp->alloc.count))
                goto out_return;

        io_zcrx_refill_slow(pp, ifq);
        if (!pp->alloc.count)
                return 0;
out_return:
        return pp->alloc.cache[--pp->alloc.count];
}

static bool io_pp_zc_release_netmem(struct page_pool *pp, netmem_ref netmem)
{
        struct net_iov *niov;

        if (WARN_ON_ONCE(!netmem_is_net_iov(netmem)))
                return false;

        niov = netmem_to_net_iov(netmem);
        net_mp_niov_clear_page_pool(niov);
        io_zcrx_return_niov_freelist(niov);
        return false;
}

static int io_pp_zc_init(struct page_pool *pp)
{
        struct io_zcrx_ifq *ifq = io_pp_to_ifq(pp);
        int ret;

        if (WARN_ON_ONCE(!ifq))
                return -EINVAL;
        if (WARN_ON_ONCE(ifq->dev != pp->p.dev))
                return -EINVAL;
        if (WARN_ON_ONCE(!pp->dma_map))
                return -EOPNOTSUPP;
        if (pp->p.order != 0)
                return -EOPNOTSUPP;
        if (pp->p.dma_dir != DMA_FROM_DEVICE)
                return -EOPNOTSUPP;

        ret = io_zcrx_map_area(ifq, ifq->area);
        if (ret)
                return ret;

        percpu_ref_get(&ifq->ctx->refs);
        return 0;
}

static void io_pp_zc_destroy(struct page_pool *pp)
{
        struct io_zcrx_ifq *ifq = io_pp_to_ifq(pp);

        percpu_ref_put(&ifq->ctx->refs);
}

static int io_pp_nl_fill(void *mp_priv, struct sk_buff *rsp,
                         struct netdev_rx_queue *rxq)
{
        struct nlattr *nest;
        int type;

        type = rxq ? NETDEV_A_QUEUE_IO_URING : NETDEV_A_PAGE_POOL_IO_URING;
        nest = nla_nest_start(rsp, type);
        if (!nest)
                return -EMSGSIZE;
        nla_nest_end(rsp, nest);

        return 0;
}

static void io_pp_uninstall(void *mp_priv, struct netdev_rx_queue *rxq)
{
        struct pp_memory_provider_params *p = &rxq->mp_params;
        struct io_zcrx_ifq *ifq = mp_priv;

        io_zcrx_drop_netdev(ifq);
        if (ifq->area)
                io_zcrx_unmap_area(ifq, ifq->area);

        p->mp_ops = NULL;
        p->mp_priv = NULL;
}

static const struct memory_provider_ops io_uring_pp_zc_ops = {
        .alloc_netmems = io_pp_zc_alloc_netmems,
        .release_netmem = io_pp_zc_release_netmem,
        .init = io_pp_zc_init,
        .destroy = io_pp_zc_destroy,
        .nl_fill = io_pp_nl_fill,
        .uninstall = io_pp_uninstall,
};

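/*
 * Post a zcrx completion: a CQE with IORING_CQE_F_MORE set, followed by an
 * io_uring_zcrx_cqe carrying the area token and the offset of the received
 * data within the area. Relies on the CQE32 format enforced at registration.
 */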
static bool io_zcrx_queue_cqe(struct io_kiocb *req, struct net_iov *niov,
                              struct io_zcrx_ifq *ifq, int off, int len)
{
        struct io_uring_zcrx_cqe *rcqe;
        struct io_zcrx_area *area;
        struct io_uring_cqe *cqe;
        u64 offset;

        if (!io_defer_get_uncommited_cqe(req->ctx, &cqe))
                return false;

        cqe->user_data = req->cqe.user_data;
        cqe->res = len;
        cqe->flags = IORING_CQE_F_MORE;

        area = io_zcrx_iov_to_area(niov);
        offset = off + (net_iov_idx(niov) << PAGE_SHIFT);
        rcqe = (struct io_uring_zcrx_cqe *)(cqe + 1);
        rcqe->off = offset + ((u64)area->area_id << IORING_ZCRX_AREA_SHIFT);
        rcqe->__pad = 0;
        return true;
}

static struct net_iov *io_zcrx_alloc_fallback(struct io_zcrx_area *area)
{
        struct net_iov *niov = NULL;

        spin_lock_bh(&area->freelist_lock);
        if (area->free_count)
                niov = __io_zcrx_get_free_niov(area);
        spin_unlock_bh(&area->freelist_lock);

        if (niov)
                page_pool_fragment_netmem(net_iov_to_netmem(niov), 1);
        return niov;
}

struct io_copy_cache {
        struct page *page;
        unsigned long offset;
        size_t size;
};

static ssize_t io_copy_page(struct io_copy_cache *cc, struct page *src_page,
                            unsigned int src_offset, size_t len)
{
        size_t copied = 0;

        len = min(len, cc->size);

        while (len) {
                void *src_addr, *dst_addr;
                struct page *dst_page = cc->page;
                unsigned dst_offset = cc->offset;
                size_t n = len;

                if (folio_test_partial_kmap(page_folio(dst_page)) ||
                    folio_test_partial_kmap(page_folio(src_page))) {
                        dst_page = nth_page(dst_page, dst_offset / PAGE_SIZE);
                        dst_offset = offset_in_page(dst_offset);
                        src_page = nth_page(src_page, src_offset / PAGE_SIZE);
                        src_offset = offset_in_page(src_offset);
                        n = min(PAGE_SIZE - src_offset, PAGE_SIZE - dst_offset);
                        n = min(n, len);
                }

                dst_addr = kmap_local_page(dst_page) + dst_offset;
                src_addr = kmap_local_page(src_page) + src_offset;

                memcpy(dst_addr, src_addr, n);

                kunmap_local(src_addr);
                kunmap_local(dst_addr);

                cc->size -= n;
                cc->offset += n;
                len -= n;
                copied += n;
        }
        return copied;
}

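/*
 * Copy fallback for data not backed by provider niovs (e.g. skb linear data
 * or regular page frags): copy into freelist niovs one page at a time and
 * post a CQE for each. Not supported for dmabuf areas.
 */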
static ssize_t io_zcrx_copy_chunk(struct io_kiocb *req, struct io_zcrx_ifq *ifq,
                                  struct page *src_page, unsigned int src_offset,
                                  size_t len)
{
        struct io_zcrx_area *area = ifq->area;
        size_t copied = 0;
        int ret = 0;

        if (area->mem.is_dmabuf)
                return -EFAULT;

        while (len) {
                struct io_copy_cache cc;
                struct net_iov *niov;
                size_t n;

                niov = io_zcrx_alloc_fallback(area);
                if (!niov) {
                        ret = -ENOMEM;
                        break;
                }

                cc.page = io_zcrx_iov_page(niov);
                cc.offset = 0;
                cc.size = PAGE_SIZE;

                n = io_copy_page(&cc, src_page, src_offset, len);

                if (!io_zcrx_queue_cqe(req, niov, ifq, 0, n)) {
                        io_zcrx_return_niov(niov);
                        ret = -ENOSPC;
                        break;
                }

                io_zcrx_get_niov_uref(niov);
                src_offset += n;
                len -= n;
                copied += n;
        }

        return copied ? copied : ret;
}

static int io_zcrx_copy_frag(struct io_kiocb *req, struct io_zcrx_ifq *ifq,
                             const skb_frag_t *frag, int off, int len)
{
        struct page *page = skb_frag_page(frag);

        return io_zcrx_copy_chunk(req, ifq, page, off + skb_frag_off(frag), len);
}

static int io_zcrx_recv_frag(struct io_kiocb *req, struct io_zcrx_ifq *ifq,
                             const skb_frag_t *frag, int off, int len)
{
        struct net_iov *niov;

        if (unlikely(!skb_frag_is_net_iov(frag)))
                return io_zcrx_copy_frag(req, ifq, frag, off, len);

        niov = netmem_to_net_iov(frag->netmem);
        if (!niov->pp || niov->pp->mp_ops != &io_uring_pp_zc_ops ||
            io_pp_to_ifq(niov->pp) != ifq)
                return -EFAULT;

        if (!io_zcrx_queue_cqe(req, niov, ifq, off + skb_frag_off(frag), len))
                return -ENOSPC;

        /*
         * Prevent it from being recycled while user is accessing it.
         * It has to be done before grabbing a user reference.
         */
        page_pool_ref_netmem(net_iov_to_netmem(niov));
        io_zcrx_get_niov_uref(niov);
        return len;
}

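/*
 * tcp_read_sock() callback: walk the skb's linear data, frags and frag
 * list, posting a zero copy CQE per provider-owned frag and falling back
 * to copying otherwise. Returns the number of bytes consumed.
 */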
static int
io_zcrx_recv_skb(read_descriptor_t *desc, struct sk_buff *skb,
                 unsigned int offset, size_t len)
{
        struct io_zcrx_args *args = desc->arg.data;
        struct io_zcrx_ifq *ifq = args->ifq;
        struct io_kiocb *req = args->req;
        struct sk_buff *frag_iter;
        unsigned start, start_off = offset;
        int i, copy, end, off;
        int ret = 0;

        len = min_t(size_t, len, desc->count);
        /*
         * __tcp_read_sock() always calls io_zcrx_recv_skb one last time, even
         * if desc->count is already 0. This is caused by the if (offset + 1 !=
         * skb->len) check. Return early in this case to break out of
         * __tcp_read_sock().
         */
        if (!len)
                return 0;
        if (unlikely(args->nr_skbs++ > IO_SKBS_PER_CALL_LIMIT))
                return -EAGAIN;

        if (unlikely(offset < skb_headlen(skb))) {
                ssize_t copied;
                size_t to_copy;

                to_copy = min_t(size_t, skb_headlen(skb) - offset, len);
                copied = io_zcrx_copy_chunk(req, ifq, virt_to_page(skb->data),
                                            offset_in_page(skb->data) + offset,
                                            to_copy);
                if (copied < 0) {
                        ret = copied;
                        goto out;
                }
                offset += copied;
                len -= copied;
                if (!len)
                        goto out;
                if (offset != skb_headlen(skb))
                        goto out;
        }

        start = skb_headlen(skb);

        for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
                const skb_frag_t *frag;

                if (WARN_ON(start > offset + len))
                        return -EFAULT;

                frag = &skb_shinfo(skb)->frags[i];
                end = start + skb_frag_size(frag);

                if (offset < end) {
                        copy = end - offset;
                        if (copy > len)
                                copy = len;

                        off = offset - start;
                        ret = io_zcrx_recv_frag(req, ifq, frag, off, copy);
                        if (ret < 0)
                                goto out;

                        offset += ret;
                        len -= ret;
                        if (len == 0 || ret != copy)
                                goto out;
                }
                start = end;
        }

        skb_walk_frags(skb, frag_iter) {
                if (WARN_ON(start > offset + len))
                        return -EFAULT;

                end = start + frag_iter->len;
                if (offset < end) {
                        copy = end - offset;
                        if (copy > len)
                                copy = len;

                        off = offset - start;
                        ret = io_zcrx_recv_skb(desc, frag_iter, off, copy);
                        if (ret < 0)
                                goto out;

                        offset += ret;
                        len -= ret;
                        if (len == 0 || ret != copy)
                                goto out;
                }
                start = end;
        }

out:
        if (offset == start_off)
                return ret;
        desc->count -= (offset - start_off);
        return offset - start_off;
}

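/*
 * Drain the TCP socket through tcp_read_sock() with io_zcrx_recv_skb() and
 * translate the outcome into an io_uring return code, requeueing multishot
 * requests when the per-call skb budget is exhausted.
 */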
static int io_zcrx_tcp_recvmsg(struct io_kiocb *req, struct io_zcrx_ifq *ifq,
                               struct sock *sk, int flags,
                               unsigned issue_flags, unsigned int *outlen)
{
        unsigned int len = *outlen;
        struct io_zcrx_args args = {
                .req = req,
                .ifq = ifq,
                .sock = sk->sk_socket,
        };
        read_descriptor_t rd_desc = {
                .count = len ? len : UINT_MAX,
                .arg.data = &args,
        };
        int ret;

        lock_sock(sk);
        ret = tcp_read_sock(sk, &rd_desc, io_zcrx_recv_skb);
        if (len && ret > 0)
                *outlen = len - ret;
        if (ret <= 0) {
                if (ret < 0 || sock_flag(sk, SOCK_DONE))
                        goto out;
                if (sk->sk_err)
                        ret = sock_error(sk);
                else if (sk->sk_shutdown & RCV_SHUTDOWN)
                        goto out;
                else if (sk->sk_state == TCP_CLOSE)
                        ret = -ENOTCONN;
                else
                        ret = -EAGAIN;
        } else if (unlikely(args.nr_skbs > IO_SKBS_PER_CALL_LIMIT) &&
                   (issue_flags & IO_URING_F_MULTISHOT)) {
                ret = IOU_REQUEUE;
        } else if (sock_flag(sk, SOCK_DONE)) {
                /* Make it to retry until it finally gets 0. */
                if (issue_flags & IO_URING_F_MULTISHOT)
                        ret = IOU_REQUEUE;
                else
                        ret = -EAGAIN;
        }
out:
        release_sock(sk);
        return ret;
}

int io_zcrx_recv(struct io_kiocb *req, struct io_zcrx_ifq *ifq,
                 struct socket *sock, unsigned int flags,
                 unsigned issue_flags, unsigned int *len)
{
        struct sock *sk = sock->sk;
        const struct proto *prot = READ_ONCE(sk->sk_prot);

        if (prot->recvmsg != tcp_recvmsg)
                return -EPROTONOSUPPORT;

        sock_rps_record_flow(sk);
        return io_zcrx_tcp_recvmsg(req, ifq, sk, flags, issue_flags, len);
}