1 #include <stdint.h>
2 #include <string.h>
3 
/*
 * Find the first occurrence of the two-byte needle n in the NUL-terminated
 * haystack h.
 *
 * Preconditions (established by strstr() before it dispatches here):
 * n[0], n[1], h[0] and h[1] are all non-NUL.
 *
 * The needle and the current two-byte haystack window are each packed into a
 * 16-bit word, so every step of the scan is one shift, one OR and one
 * comparison.
 *
 * Returns a pointer to the start of the match, or NULL if there is none.
 */
static char* twobyte_strstr(const unsigned char* h, const unsigned char* n) {
    const uint16_t nw = (uint16_t)(n[0] << 8 | n[1]);
    uint16_t hw = (uint16_t)(h[0] << 8 | h[1]);

    /*
     * h always points at the last byte of the window held in hw.  The
     * conversion back to uint16_t drops the byte that slides out of the window.
     * The scan stops on a match or once the terminating NUL has been pulled
     * into the window (a window containing NUL can never equal the needle).
     */
    for (h++; *h && hw != nw; hw = (uint16_t)(hw << 8 | *++h))
        ;
    return *h ? (char*)h - 1 : NULL;
}
10 
/*
 * Find the first occurrence of the three-byte needle n in the NUL-terminated
 * haystack h.
 *
 * Preconditions (established by strstr() before it dispatches here):
 * n[0..2] and h[0..2] are all non-NUL.
 *
 * The needle and the current three-byte haystack window are kept in the upper
 * 24 bits of a 32-bit word; the low byte is always zero.
 *
 * Returns a pointer to the start of the match, or NULL if there is none.
 */
static char* threebyte_strstr(const unsigned char* h, const unsigned char* n) {
    /*
     * Widen to uint32_t BEFORE shifting: an unsigned char promotes to signed
     * int, and shifting a byte >= 0x80 left by 24 would overflow it (undefined
     * behaviour).
     */
    const uint32_t nw = (uint32_t)n[0] << 24 | (uint32_t)n[1] << 16 | (uint32_t)n[2] << 8;
    uint32_t hw = (uint32_t)h[0] << 24 | (uint32_t)h[1] << 16 | (uint32_t)h[2] << 8;

    /*
     * h points at the last byte of the window.  Each step ORs the next byte
     * into the (zero) low byte and shifts, which also pushes the oldest byte
     * out of the top of the word.
     */
    for (h += 2; *h && hw != nw; hw = (hw | *++h) << 8)
        ;
    return *h ? (char*)h - 2 : NULL;
}
18 
/*
 * Find the first occurrence of the four-byte needle n in the NUL-terminated
 * haystack h.
 *
 * Preconditions (established by strstr() before it dispatches here):
 * n[0..3] and h[0..3] are all non-NUL.
 *
 * The needle and the current four-byte haystack window are each packed into a
 * 32-bit word, first byte in the most significant position.
 *
 * Returns a pointer to the start of the match, or NULL if there is none.
 */
static char* fourbyte_strstr(const unsigned char* h, const unsigned char* n) {
    /*
     * Widen to uint32_t BEFORE shifting: an unsigned char promotes to signed
     * int, and shifting a byte >= 0x80 left by 24 would overflow it (undefined
     * behaviour).
     */
    const uint32_t nw =
        (uint32_t)n[0] << 24 | (uint32_t)n[1] << 16 | (uint32_t)n[2] << 8 | (uint32_t)n[3];
    uint32_t hw =
        (uint32_t)h[0] << 24 | (uint32_t)h[1] << 16 | (uint32_t)h[2] << 8 | (uint32_t)h[3];

    /*
     * h points at the last byte of the window; shifting left by 8 discards the
     * oldest byte and makes room for the next one.
     */
    for (h += 3; *h && hw != nw; hw = hw << 8 | *++h)
        ;
    return *h ? (char*)h - 3 : NULL;
}
26 
/* NOTE: both macros evaluate their arguments more than once. */
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

/*
 * Apply `op` between the word of bit array `a` that holds bit number `b` and
 * the mask selecting that bit: with `|=` this sets the bit, with `&` it yields
 * a nonzero value iff the bit is set.
 */
#define BITOP(a, b, op) \
    ((a)[(size_t)(b) / (8 * sizeof *(a))] op (size_t)1 << ((size_t)(b) % (8 * sizeof *(a))))

/*
 * Find the first occurrence of the NUL-terminated needle n in the
 * NUL-terminated haystack h using a two-way search: the needle is split after
 * index ms (found via its maximal suffixes), the right part is compared
 * left-to-right and then the left part right-to-left.  A last-byte shift
 * table is consulted first to skip alignments that cannot match.
 *
 * The needle must be non-empty (h[l - 1] is read with l the needle length).
 *
 * Returns a pointer to the start of the match, or NULL if there is none.
 */
static char* twoway_strstr(const unsigned char* h, const unsigned char* n) {
    const unsigned char* z;
    size_t l, ip, jp, k, p, ms, p0, mem, mem0;
    size_t byteset[32 / sizeof(size_t)] = {0}; /* 256-bit set of bytes in n */
    size_t shift[256]; /* only entries whose byteset bit is set are valid */

    /*
     * Compute the needle length and fill the shift table.  shift[c] ends up as
     * (index of the last occurrence of c in n) + 1.  The haystack is walked in
     * lock-step so a haystack shorter than the needle is detected here.
     */
    for (l = 0; n[l] && h[l]; l++) {
        BITOP(byteset, n[l], |=);
        shift[n[l]] = l + 1;
    }
    if (n[l])
        return NULL; /* hit the end of h */

    /*
     * Compute maximal suffix.  ip starts at (size_t)-1 so that ip + k wraps
     * around to k - 1; all index arithmetic below relies on unsigned wrap.
     */
    ip = (size_t)-1;
    jp = 0;
    k = p = 1;
    while (jp + k < l) {
        if (n[ip + k] == n[jp + k]) {
            if (k == p) {
                jp += p;
                k = 1;
            } else {
                k++;
            }
        } else if (n[ip + k] > n[jp + k]) {
            jp += k;
            k = 1;
            p = jp - ip;
        } else {
            ip = jp++;
            k = p = 1;
        }
    }
    ms = ip;
    p0 = p;

    /* And with the opposite comparison */
    ip = (size_t)-1;
    jp = 0;
    k = p = 1;
    while (jp + k < l) {
        if (n[ip + k] == n[jp + k]) {
            if (k == p) {
                jp += p;
                k = 1;
            } else {
                k++;
            }
        } else if (n[ip + k] < n[jp + k]) {
            jp += k;
            k = 1;
            p = jp - ip;
        } else {
            ip = jp++;
            k = p = 1;
        }
    }
    /* Keep whichever split point lies further right (+1 maps "-1" to 0). */
    if (ip + 1 > ms + 1)
        ms = ip;
    else
        p = p0;

    /* Periodic needle? */
    if (memcmp(n, n + p, ms + 1)) {
        /* Not periodic: no prefix is remembered and a larger step is used. */
        mem0 = 0;
        p = MAX(ms, l - ms - 1) + 1;
    } else {
        /* Periodic: after a step of p, the first l - p bytes are known. */
        mem0 = l - p;
    }
    mem = 0;

    /* Initialize incremental end-of-haystack pointer */
    z = h;

    /* Search loop */
    for (;;) {
        /*
         * Update incremental end-of-haystack pointer: make sure at least l
         * bytes starting at h are known to lie before the terminator.
         */
        if ((size_t)(z - h) < l) {
            /* Fast estimate for MAX(l,63) */
            size_t grow = l | 63;
            const unsigned char* z2 = memchr(z, 0, grow);
            if (z2) {
                z = z2;
                if ((size_t)(z - h) < l)
                    return NULL; /* fewer than l bytes remain */
            } else {
                z += grow;
            }
        }

        /* Check last byte first; advance by shift on mismatch */
        if (BITOP(byteset, h[l - 1], &)) {
            k = l - shift[h[l - 1]];
            if (k) {
                /*
                 * The remembered prefix is discarded here, so never advance
                 * by less than its length; a larger shift is kept in full.
                 */
                if (k < mem)
                    k = mem;
                h += k;
                mem = 0;
                continue;
            }
        } else {
            /* The byte does not occur in the needle at all. */
            h += l;
            mem = 0;
            continue;
        }

        /* Compare right half */
        for (k = MAX(ms + 1, mem); n[k] && n[k] == h[k]; k++)
            ;
        if (n[k]) {
            h += k - ms;
            mem = 0;
            continue;
        }
        /* Compare left half; bytes below mem are already known to match */
        for (k = ms + 1; k > mem && n[k - 1] == h[k - 1]; k--)
            ;
        if (k <= mem)
            return (char*)h;
        h += p;
        mem = mem0;
    }
}
154 
/*
 * Locate the first occurrence of the string n in the string h.
 *
 * Returns h itself when n is empty, a pointer to the start of the first match
 * otherwise, or NULL when n does not occur in h.
 *
 * Needles of one to four bytes are handled by dedicated routines; each
 * length check below also verifies that the haystack (from the first
 * occurrence of n[0] onwards) is at least as long, which is the precondition
 * those routines rely on.  Longer needles go to the two-way search.
 */
char* strstr(const char* h, const char* n) {
    /* Return immediately on empty needle */
    if (!n[0])
        return (char*)h;

    /* Use faster algorithms for short needles */
    h = strchr(h, *n);
    if (!h || !n[1])
        return (char*)h;
    if (!h[1])
        return NULL;
    if (!n[2])
        return twobyte_strstr((const unsigned char*)h, (const unsigned char*)n);
    if (!h[2])
        return NULL;
    if (!n[3])
        return threebyte_strstr((const unsigned char*)h, (const unsigned char*)n);
    if (!h[3])
        return NULL;
    if (!n[4])
        return fourbyte_strstr((const unsigned char*)h, (const unsigned char*)n);

    return twoway_strstr((const unsigned char*)h, (const unsigned char*)n);
}
179