1 /*
2  * Copyright (c) 2024 Endress+Hauser AG
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <zephyr/net/dns_resolve.h>
8 #include <zephyr/net/net_ip.h>
9 #include "dns_cache.h"
10 
LOG_MODULE_REGISTER(net_dns_cache, CONFIG_DNS_RESOLVER_LOG_LEVEL);

/* Purges expired entries; the caller must hold the cache lock. */
static void dns_cache_clean(const struct dns_cache *cache);
14 
dns_cache_flush(struct dns_cache * cache)15 int dns_cache_flush(struct dns_cache *cache)
16 {
17 	k_mutex_lock(cache->lock, K_FOREVER);
18 	for (size_t i = 0; i < cache->size; i++) {
19 		cache->entries[i].in_use = false;
20 	}
21 	k_mutex_unlock(cache->lock);
22 
23 	return 0;
24 }
25 
dns_cache_add(struct dns_cache * cache,char const * query,struct dns_addrinfo const * addrinfo,uint32_t ttl)26 int dns_cache_add(struct dns_cache *cache, char const *query, struct dns_addrinfo const *addrinfo,
27 		  uint32_t ttl)
28 {
29 	k_timepoint_t closest_to_expiry = sys_timepoint_calc(K_FOREVER);
30 	size_t index_to_replace = 0;
31 	bool found_empty = false;
32 
33 	if (cache == NULL || query == NULL || addrinfo == NULL || ttl == 0) {
34 		return -EINVAL;
35 	}
36 
37 	if (strlen(query) >= CONFIG_DNS_RESOLVER_MAX_QUERY_LEN) {
38 		NET_WARN("Query string to big to be processed %u >= "
39 			 "CONFIG_DNS_RESOLVER_MAX_QUERY_LEN",
40 			 strlen(query));
41 		return -EINVAL;
42 	}
43 
44 	k_mutex_lock(cache->lock, K_FOREVER);
45 
46 	NET_DBG("Add \"%s\" with TTL %" PRIu32, query, ttl);
47 
48 	dns_cache_clean(cache);
49 
50 	for (size_t i = 0; i < cache->size; i++) {
51 		if (!cache->entries[i].in_use) {
52 			index_to_replace = i;
53 			found_empty = true;
54 			break;
55 		} else if (sys_timepoint_cmp(closest_to_expiry, cache->entries[i].expiry) > 0) {
56 			index_to_replace = i;
57 			closest_to_expiry = cache->entries[i].expiry;
58 		}
59 	}
60 
61 	if (!found_empty) {
62 		NET_DBG("Overwrite \"%s\"", cache->entries[index_to_replace].query);
63 	}
64 
65 	strncpy(cache->entries[index_to_replace].query, query,
66 		CONFIG_DNS_RESOLVER_MAX_QUERY_LEN - 1);
67 	cache->entries[index_to_replace].data = *addrinfo;
68 	cache->entries[index_to_replace].expiry = sys_timepoint_calc(K_SECONDS(ttl));
69 	cache->entries[index_to_replace].in_use = true;
70 
71 	k_mutex_unlock(cache->lock);
72 
73 	return 0;
74 }
75 
dns_cache_remove(struct dns_cache * cache,char const * query)76 int dns_cache_remove(struct dns_cache *cache, char const *query)
77 {
78 	if (cache == NULL || query == NULL) {
79 		return -EINVAL;
80 	}
81 
82 	NET_DBG("Remove all entries with query \"%s\"", query);
83 	if (strlen(query) >= CONFIG_DNS_RESOLVER_MAX_QUERY_LEN) {
84 		NET_WARN("Query string to big to be processed %u >= "
85 			 "CONFIG_DNS_RESOLVER_MAX_QUERY_LEN",
86 			 strlen(query));
87 		return -EINVAL;
88 	}
89 
90 	k_mutex_lock(cache->lock, K_FOREVER);
91 
92 	dns_cache_clean(cache);
93 
94 	for (size_t i = 0; i < cache->size; i++) {
95 		if (cache->entries[i].in_use && strcmp(cache->entries[i].query, query) == 0) {
96 			cache->entries[i].in_use = false;
97 		}
98 	}
99 
100 	k_mutex_unlock(cache->lock);
101 
102 	return 0;
103 }
104 
dns_cache_find(struct dns_cache const * cache,const char * query,enum dns_query_type type,struct dns_addrinfo * addrinfo,size_t addrinfo_array_len)105 int dns_cache_find(struct dns_cache const *cache, const char *query, enum dns_query_type type,
106 		   struct dns_addrinfo *addrinfo, size_t addrinfo_array_len)
107 {
108 	size_t found = 0;
109 	sa_family_t family;
110 
111 	NET_DBG("Find \"%s\"", query);
112 	if (cache == NULL || query == NULL || addrinfo == NULL || addrinfo_array_len <= 0) {
113 		return -EINVAL;
114 	}
115 	if (type == DNS_QUERY_TYPE_A) {
116 		family = AF_INET;
117 	} else if (type == DNS_QUERY_TYPE_AAAA) {
118 		family = AF_INET6;
119 	} else {
120 		return -EINVAL;
121 	}
122 	if (strlen(query) >= CONFIG_DNS_RESOLVER_MAX_QUERY_LEN) {
123 		NET_WARN("Query string to big to be processed %u >= "
124 			 "CONFIG_DNS_RESOLVER_MAX_QUERY_LEN",
125 			 strlen(query));
126 		return -EINVAL;
127 	}
128 
129 	k_mutex_lock(cache->lock, K_FOREVER);
130 
131 	dns_cache_clean(cache);
132 
133 	for (size_t i = 0; i < cache->size; i++) {
134 		if (!cache->entries[i].in_use) {
135 			continue;
136 		}
137 		if (strcmp(cache->entries[i].query, query) != 0) {
138 			continue;
139 		}
140 		if (cache->entries[i].data.ai_family != family) {
141 			continue;
142 		}
143 		if (found >= addrinfo_array_len) {
144 			NET_WARN("Found \"%s\" but not enough space in provided buffer.", query);
145 			found++;
146 		} else {
147 			addrinfo[found] = cache->entries[i].data;
148 			found++;
149 			NET_DBG("Found \"%s\"", query);
150 		}
151 	}
152 
153 	k_mutex_unlock(cache->lock);
154 
155 	if (found > addrinfo_array_len) {
156 		return -ENOSR;
157 	}
158 
159 	if (found == 0) {
160 		NET_DBG("Could not find \"%s\"", query);
161 	}
162 	return found;
163 }
164 
165 /* Needs to be called when lock is already acquired */
dns_cache_clean(struct dns_cache const * cache)166 static void dns_cache_clean(struct dns_cache const *cache)
167 {
168 	for (size_t i = 0; i < cache->size; i++) {
169 		if (!cache->entries[i].in_use) {
170 			continue;
171 		}
172 
173 		if (sys_timepoint_expired(cache->entries[i].expiry)) {
174 			NET_DBG("Remove \"%s\"", cache->entries[i].query);
175 			cache->entries[i].in_use = false;
176 		}
177 	}
178 }
179