1 // SPDX-License-Identifier: GPL-2.0
2 
3 #include <linux/ptrace.h>
4 #include <stddef.h>
5 #include <linux/bpf.h>
6 #include <bpf/bpf_helpers.h>
7 #include <bpf/bpf_tracing.h>
8 
9 char _license[] SEC("license") = "GPL";
10 
/* Cap on the number of scatterlist arrays walked per direction;
 * virtio scsi typically has a max of 6 SGs.
 */
#define VIRTIO_MAX_SGS	6

/* The verifier fails with SG_MAX = 128. The failure can be worked
 * around with a smaller SG_MAX, e.g. 10, which is what defining
 * WORKAROUND selects below.
 */
#define WORKAROUND
#ifdef WORKAROUND
#define SG_MAX		10
#else
/* virtio blk typically has a max SEG of 128 */
#define SG_MAX		128
#endif
24 
/* Marker bits kept in the two lowest bits of scatterlist::page_link. */
#define SG_CHAIN		0x01UL
#define SG_END			0x02UL
#define SG_PAGE_LINK_MASK	(SG_CHAIN | SG_END)

/* Minimal local definition of a scatterlist entry as read by this program.
 * NOTE(review): presumably mirrors the leading fields of the kernel's
 * struct scatterlist -- confirm against the target kernel.
 */
struct scatterlist {
	unsigned long	page_link;	/* pointer value with SG_CHAIN/SG_END bits folded in */
	unsigned int	offset;
	unsigned int	length;
};

/* Non-zero when the entry carries the chain marker. */
#define sg_is_chain(sg)		((sg)->page_link & SG_CHAIN)
/* Non-zero when the entry carries the end marker. */
#define sg_is_last(sg)		((sg)->page_link & SG_END)
/* page_link with both marker bits cleared, viewed as a scatterlist pointer. */
#define sg_chain_ptr(sg)	\
	((struct scatterlist *) ((sg)->page_link & ~SG_PAGE_LINK_MASK))
38 
__sg_next(struct scatterlist * sgp)39 static inline struct scatterlist *__sg_next(struct scatterlist *sgp)
40 {
41 	struct scatterlist sg;
42 
43 	bpf_probe_read_kernel(&sg, sizeof(sg), sgp);
44 	if (sg_is_last(&sg))
45 		return NULL;
46 
47 	sgp++;
48 
49 	bpf_probe_read_kernel(&sg, sizeof(sg), sgp);
50 	if (sg_is_chain(&sg))
51 		sgp = sg_chain_ptr(&sg);
52 
53 	return sgp;
54 }
55 
/*
 * Fetch element @i of the pointer array @sgs.
 *
 * The slot is copied out with bpf_probe_read_kernel() rather than being
 * dereferenced directly; the helper's return value is not checked.
 */
static inline struct scatterlist *get_sgp(struct scatterlist **sgs, int i)
{
	struct scatterlist *head;

	bpf_probe_read_kernel(&head, sizeof(head), &sgs[i]);
	return head;
}
63 
/* Run-once latch: the probe does its work only while this is 0 and sets it to 1 afterwards. */
int config = 0;
/* Output of the probe: length2 - length1, stored as an int. */
int result = 0;
66 
67 SEC("kprobe/virtqueue_add_sgs")
BPF_KPROBE(trace_virtqueue_add_sgs,void * unused,struct scatterlist ** sgs,unsigned int out_sgs,unsigned int in_sgs)68 int BPF_KPROBE(trace_virtqueue_add_sgs, void *unused, struct scatterlist **sgs,
69 	       unsigned int out_sgs, unsigned int in_sgs)
70 {
71 	struct scatterlist *sgp = NULL;
72 	__u64 length1 = 0, length2 = 0;
73 	unsigned int i, n, len;
74 
75 	if (config != 0)
76 		return 0;
77 
78 	for (i = 0; (i < VIRTIO_MAX_SGS) && (i < out_sgs); i++) {
79 		for (n = 0, sgp = get_sgp(sgs, i); sgp && (n < SG_MAX);
80 		     sgp = __sg_next(sgp)) {
81 			bpf_probe_read_kernel(&len, sizeof(len), &sgp->length);
82 			length1 += len;
83 			n++;
84 		}
85 	}
86 
87 	for (i = 0; (i < VIRTIO_MAX_SGS) && (i < in_sgs); i++) {
88 		for (n = 0, sgp = get_sgp(sgs, i); sgp && (n < SG_MAX);
89 		     sgp = __sg_next(sgp)) {
90 			bpf_probe_read_kernel(&len, sizeof(len), &sgp->length);
91 			length2 += len;
92 			n++;
93 		}
94 	}
95 
96 	config = 1;
97 	result = length2 - length1;
98 	return 0;
99 }
100