Coverage Report

Created: 2024-01-26 01:52

/work/toxcore/forwarding.c
Line
Count
Source (jump to first uncovered line)
1
/* SPDX-License-Identifier: GPL-3.0-or-later
2
 * Copyright © 2019-2022 The TokTok team.
3
 */
4
5
#include "forwarding.h"
6
7
#include <assert.h>
8
#include <stdlib.h>
9
#include <string.h>
10
11
#include "DHT.h"
12
#include "ccompat.h"
13
#include "crypto_core.h"
14
#include "logger.h"
15
#include "mono_time.h"
16
#include "network.h"
17
#include "timed_auth.h"
18
19
struct Forwarding {
20
    const Logger *log;
21
    const Random *rng;
22
    DHT *dht;
23
    const Mono_Time *mono_time;
24
    Networking_Core *net;
25
26
    uint8_t hmac_key[CRYPTO_HMAC_KEY_SIZE];
27
28
    forward_reply_cb *forward_reply_callback;
29
    void *forward_reply_callback_object;
30
31
    forwarded_request_cb *forwarded_request_callback;
32
    void *forwarded_request_callback_object;
33
34
    forwarded_response_cb *forwarded_response_callback;
35
    void *forwarded_response_callback_object;
36
};
37
38
DHT *forwarding_get_dht(const Forwarding *forwarding)
39
3.76k
{
40
3.76k
    return forwarding->dht;
41
3.76k
}
42
43
166
/* Timeout passed to generate_timed_auth()/check_timed_auth() for sendback
 * authenticators. NOTE(review): presumably seconds (3600 = one hour) —
 * confirm against timed_auth.h. */
#define SENDBACK_TIMEOUT 3600
44
45
bool send_forward_request(const Networking_Core *net, const IP_Port *forwarder,
46
                          const uint8_t *chain_keys, uint16_t chain_length,
47
                          const uint8_t *data, uint16_t data_length)
48
30
{
49
30
    if (chain_length == 0 || chain_length > MAX_FORWARD_CHAIN_LENGTH
50
30
            || data_length > MAX_FORWARD_DATA_SIZE) {
51
4
        return false;
52
4
    }
53
54
26
    const uint16_t len = forward_chain_packet_size(chain_length, data_length);
55
26
    VLA(uint8_t, packet, len);
56
57
26
    return create_forward_chain_packet(chain_keys, chain_length, data, data_length, packet)
58
26
           && sendpacket(net, forwarder, packet, len) == len;
59
30
}
60
61
uint16_t forward_chain_packet_size(uint16_t chain_length, uint16_t data_length)
62
34
{
63
34
    return chain_length * (1 + CRYPTO_PUBLIC_KEY_SIZE) + data_length;
64
34
}
65
66
bool create_forward_chain_packet(const uint8_t *chain_keys, uint16_t chain_length,
67
                                 const uint8_t *data, uint16_t data_length,
68
                                 uint8_t *packet)
69
34
{
70
34
    if (chain_length == 0 || chain_length > MAX_FORWARD_CHAIN_LENGTH
71
34
            || data_length > MAX_FORWARD_DATA_SIZE) {
72
0
        return false;
73
0
    }
74
75
34
    uint16_t offset = 0;
76
77
123
    for (uint16_t j = 0; j < chain_length; ++j) {
78
89
        packet[offset] = NET_PACKET_FORWARD_REQUEST;
79
89
        ++offset;
80
89
        memcpy(packet + offset, chain_keys + j * CRYPTO_PUBLIC_KEY_SIZE, CRYPTO_PUBLIC_KEY_SIZE);
81
89
        offset += CRYPTO_PUBLIC_KEY_SIZE;
82
89
    }
83
84
34
    memcpy(packet + offset, data, data_length);
85
34
    return true;
86
34
}
87
88
non_null()
89
static uint16_t forwarding_packet_length(uint16_t sendback_data_len, uint16_t data_length)
90
124
{
91
124
    const uint16_t sendback_len = sendback_data_len == 0 ? 0 : TIMED_AUTH_SIZE + sendback_data_len;
92
124
    return 1 + 1 + sendback_len + data_length;
93
124
}
94
95
/* Write a NET_PACKET_FORWARDING packet into `packet`, which must provide
 * forwarding_packet_length(sendback_data_len, length) bytes.
 *
 * Layout: [NET_PACKET_FORWARDING][sendback_len][sendback][data].
 * - When sendback_data_len is 0, sendback_len is 0 and there is no sendback
 *   section.
 * - Otherwise the sendback section is a timed authenticator (TIMED_AUTH_SIZE
 *   bytes, keyed with forwarding->hmac_key) followed by sendback_data.
 *
 * Returns false, leaving everything after the type byte unwritten, if the
 * authenticated sendback would exceed MAX_SENDBACK_SIZE. */
non_null(1, 4, 6) nullable(2)
static bool create_forwarding_packet(const Forwarding *forwarding,
                                     const uint8_t *sendback_data, uint16_t sendback_data_len,
                                     const uint8_t *data, uint16_t length,
                                     uint8_t *packet)
{
    packet[0] = NET_PACKET_FORWARDING;

    if (sendback_data_len == 0) {
        packet[1] = 0;
        memcpy(packet + 1 + 1, data, length);
        return true;
    }

    /* Bound-check in int arithmetic *before* narrowing to uint16_t. Otherwise
     * a huge sendback_data_len wraps around, passes the check, and the memcpy
     * below overruns `packet`. */
    if (TIMED_AUTH_SIZE + sendback_data_len > MAX_SENDBACK_SIZE) {
        return false;
    }

    const uint16_t sendback_len = TIMED_AUTH_SIZE + sendback_data_len;

    /* sendback_data is nullable only when sendback_data_len is 0; it is read
     * by generate_timed_auth() below, so check it first. */
    assert(sendback_data != nullptr);

    packet[1] = (uint8_t)sendback_len;
    generate_timed_auth(forwarding->mono_time, SENDBACK_TIMEOUT, forwarding->hmac_key, sendback_data,
                        sendback_data_len, packet + 1 + 1);
    memcpy(packet + 1 + 1 + TIMED_AUTH_SIZE, sendback_data, sendback_data_len);
    memcpy(packet + 1 + 1 + sendback_len, data, length);

    return true;
}
127
128
bool send_forwarding(const Forwarding *forwarding, const IP_Port *dest,
129
                     const uint8_t *sendback_data, uint16_t sendback_data_len,
130
                     const uint8_t *data, uint16_t length)
131
24
{
132
24
    if (length > MAX_FORWARD_DATA_SIZE) {
133
0
        return false;
134
0
    }
135
136
24
    const uint16_t len = forwarding_packet_length(sendback_data_len, length);
137
24
    VLA(uint8_t, packet, len);
138
24
    create_forwarding_packet(forwarding, sendback_data, sendback_data_len, data, length, packet);
139
24
    return sendpacket(forwarding->net, dest, packet, len) == len;
140
24
}
141
142
114
#define FORWARD_REQUEST_MIN_PACKET_SIZE (1 + CRYPTO_PUBLIC_KEY_SIZE)
143
144
non_null(1) nullable(2, 4)
145
static bool handle_forward_request_dht(const Forwarding *forwarding,
146
                                       const uint8_t *sendback_data, uint16_t sendback_data_len,
147
                                       const uint8_t *packet, uint16_t length)
148
114
{
149
114
    if (length < FORWARD_REQUEST_MIN_PACKET_SIZE) {
150
12
        return false;
151
12
    }
152
153
102
    const uint8_t *const public_key = packet + 1;
154
102
    const uint8_t *const forward_data = packet + (1 + CRYPTO_PUBLIC_KEY_SIZE);
155
102
    const uint16_t forward_data_len = length - (1 + CRYPTO_PUBLIC_KEY_SIZE);
156
157
102
    if (TIMED_AUTH_SIZE + sendback_data_len > MAX_SENDBACK_SIZE ||
158
102
            forward_data_len > MAX_FORWARD_DATA_SIZE) {
159
2
        return false;
160
2
    }
161
162
100
    const uint16_t len = forwarding_packet_length(sendback_data_len, forward_data_len);
163
100
    VLA(uint8_t, forwarding_packet, len);
164
165
100
    create_forwarding_packet(forwarding, sendback_data, sendback_data_len, forward_data, forward_data_len,
166
100
                             forwarding_packet);
167
168
100
    return route_packet(forwarding->dht, public_key, forwarding_packet, len) == len;
169
102
}
170
171
/* Network handler for a NET_PACKET_FORWARD_REQUEST received directly from
 * `source`.
 *
 * The sender's address is packed as SENDBACK_IPPORT sendback data, and the
 * request is then passed on to handle_forward_request_dht().
 *
 * Returns 0 on success, 1 if the address cannot be packed or forwarding fails. */
non_null(1, 2) nullable(3, 5)
static int handle_forward_request(void *object, const IP_Port *source, const uint8_t *packet, uint16_t length,
                                  void *userdata)
{
    const Forwarding *fwd = (const Forwarding *)object;

    uint8_t sendback[1 + MAX_PACKED_IPPORT_SIZE];
    sendback[0] = SENDBACK_IPPORT;

    const int packed_len = pack_ip_port(fwd->log, &sendback[1], MAX_PACKED_IPPORT_SIZE, source);

    if (packed_len == -1) {
        return 1;
    }

    const bool forwarded = handle_forward_request_dht(fwd, sendback, 1 + packed_len, packet, length);
    return forwarded ? 0 : 1;
}
188
189
80
#define MIN_NONEMPTY_SENDBACK_SIZE TIMED_AUTH_SIZE
190
80
#define FORWARD_REPLY_MIN_PACKET_SIZE (1 + 1 + MIN_NONEMPTY_SENDBACK_SIZE)
191
192
non_null(1, 2) nullable(3, 5)
193
static int handle_forward_reply(void *object, const IP_Port *source, const uint8_t *packet, uint16_t length,
194
                                void *userdata)
195
80
{
196
80
    const Forwarding *forwarding = (const Forwarding *)object;
197
198
80
    if (length < FORWARD_REPLY_MIN_PACKET_SIZE) {
199
5
        return 1;
200
5
    }
201
202
75
    const uint8_t sendback_len = packet[1];
203
75
    const uint8_t *const sendback_auth = packet + 1 + 1;
204
75
    const uint8_t *const sendback_data = sendback_auth + TIMED_AUTH_SIZE;
205
206
75
    if (sendback_len > MAX_SENDBACK_SIZE) {
207
        /* value 0xff is reserved for possible future expansion */
208
3
        return 1;
209
3
    }
210
211
72
    if (sendback_len < TIMED_AUTH_SIZE + 1) {
212
10
        return 1;
213
10
    }
214
215
62
    const uint16_t sendback_data_len = sendback_len - TIMED_AUTH_SIZE;
216
217
62
    if (length < 1 + 1 + sendback_len) {
218
5
        return 1;
219
5
    }
220
221
57
    const uint8_t *const to_forward = packet + (1 + 1 + sendback_len);
222
57
    const uint16_t to_forward_len = length - (1 + 1 + sendback_len);
223
224
57
    if (!check_timed_auth(forwarding->mono_time, SENDBACK_TIMEOUT, forwarding->hmac_key, sendback_data, sendback_data_len,
225
57
                          sendback_auth)) {
226
3
        return 1;
227
3
    }
228
229
54
    if (sendback_data[0] == SENDBACK_IPPORT) {
230
15
        IP_Port dest;
231
232
15
        if (unpack_ip_port(&dest, sendback_data + 1, sendback_data_len - 1, false)
233
15
                != sendback_data_len - 1) {
234
0
            return 1;
235
0
        }
236
237
15
        return send_forwarding(forwarding, &dest, nullptr, 0, to_forward, to_forward_len) ? 0 : 1;
238
15
    }
239
240
39
    if (sendback_data[0] == SENDBACK_FORWARD) {
241
34
        IP_Port forwarder;
242
34
        const int ipport_length = unpack_ip_port(&forwarder, sendback_data + 1, sendback_data_len - 1, false);
243
244
34
        if (ipport_length == -1) {
245
0
            return 1;
246
0
        }
247
248
34
        const uint8_t *const forward_sendback = sendback_data + (1 + ipport_length);
249
34
        const uint16_t forward_sendback_len = sendback_data_len - (1 + ipport_length);
250
251
34
        return forward_reply(forwarding->net, &forwarder, forward_sendback, forward_sendback_len, to_forward,
252
34
                             to_forward_len) ? 0 : 1;
253
34
    }
254
255
5
    if (forwarding->forward_reply_callback == nullptr) {
256
0
        return 1;
257
0
    }
258
259
5
    return forwarding->forward_reply_callback(forwarding->forward_reply_callback_object,
260
5
            sendback_data, sendback_data_len,
261
5
            to_forward, to_forward_len) ? 0 : 1;
262
5
}
263
264
322
#define FORWARDING_MIN_PACKET_SIZE (1 + 1)
265
266
non_null(1, 2) nullable(3, 5)
267
static int handle_forwarding(void *object, const IP_Port *source, const uint8_t *packet, uint16_t length,
268
                             void *userdata)
269
322
{
270
322
    const Forwarding *forwarding = (const Forwarding *)object;
271
272
322
    if (length < FORWARDING_MIN_PACKET_SIZE) {
273
10
        return 1;
274
10
    }
275
276
312
    const uint8_t sendback_len = packet[1];
277
278
312
    if (length < 1 + 1 + sendback_len) {
279
68
        return 1;
280
68
    }
281
282
244
    const uint8_t *const sendback = packet + 1 + 1;
283
284
244
    const uint8_t *const forwarded = sendback + sendback_len;
285
244
    const uint16_t forwarded_len = length - (1 + 1 + sendback_len);
286
287
244
    if (forwarded_len >= 1 && forwarded[0] == NET_PACKET_FORWARD_REQUEST) {
288
48
        VLA(uint8_t, sendback_data, 1 + MAX_PACKED_IPPORT_SIZE + sendback_len);
289
48
        sendback_data[0] = SENDBACK_FORWARD;
290
291
48
        const int ipport_length = pack_ip_port(forwarding->log, sendback_data + 1, MAX_PACKED_IPPORT_SIZE, source);
292
293
48
        if (ipport_length == -1) {
294
0
            return 1;
295
0
        }
296
297
48
        memcpy(sendback_data + 1 + ipport_length, sendback, sendback_len);
298
299
48
        return handle_forward_request_dht(forwarding, sendback_data, 1 + ipport_length + sendback_len, forwarded,
300
48
                                          forwarded_len) ? 0 : 1;
301
48
    }
302
303
196
    if (sendback_len > 0) {
304
27
        if (forwarding->forwarded_request_callback == nullptr) {
305
0
            return 1;
306
0
        }
307
308
27
        forwarding->forwarded_request_callback(forwarding->forwarded_request_callback_object,
309
27
                                               source, sendback, sendback_len,
310
27
                                               forwarded, forwarded_len, userdata);
311
27
        return 0;
312
169
    } else {
313
169
        if (forwarding->forwarded_response_callback == nullptr) {
314
154
            return 1;
315
154
        }
316
317
15
        forwarding->forwarded_response_callback(forwarding->forwarded_response_callback_object,
318
15
                                                forwarded, forwarded_len, userdata);
319
15
        return 0;
320
169
    }
321
196
}
322
323
bool forward_reply(const Networking_Core *net, const IP_Port *forwarder,
324
                   const uint8_t *sendback, uint16_t sendback_length,
325
                   const uint8_t *data, uint16_t length)
326
63
{
327
63
    if (sendback_length > MAX_SENDBACK_SIZE ||
328
63
            length > MAX_FORWARD_DATA_SIZE) {
329
2
        return false;
330
2
    }
331
332
61
    const uint16_t len = 1 + 1 + sendback_length + length;
333
61
    VLA(uint8_t, packet, len);
334
61
    packet[0] = NET_PACKET_FORWARD_REPLY;
335
61
    packet[1] = (uint8_t) sendback_length;
336
61
    memcpy(packet + 1 + 1, sendback, sendback_length);
337
61
    memcpy(packet + 1 + 1 + sendback_length, data, length);
338
61
    return sendpacket(net, forwarder, packet, len) == len;
339
63
}
340
341
/* Install the handler for forwarding packets with non-empty sendback data.
 * Pass nullptr to clear it; such packets are then rejected. `object` is
 * handed back to the handler on every call. */
void set_callback_forwarded_request(Forwarding *forwarding, forwarded_request_cb *function, void *object)
{
    forwarding->forwarded_request_callback_object = object;
    forwarding->forwarded_request_callback = function;
}
346
347
/* Install the handler for forwarding packets with empty sendback data.
 * Pass nullptr to clear it; such packets are then rejected. `object` is
 * handed back to the handler on every call. */
void set_callback_forwarded_response(Forwarding *forwarding, forwarded_response_cb *function, void *object)
{
    forwarding->forwarded_response_callback_object = object;
    forwarding->forwarded_response_callback = function;
}
352
353
/* Install the handler for authenticated forward replies whose sendback type
 * is neither SENDBACK_IPPORT nor SENDBACK_FORWARD. Pass nullptr to clear it;
 * such replies are then rejected. `object` is handed back to the handler. */
void set_callback_forward_reply(Forwarding *forwarding, forward_reply_cb *function, void *object)
{
    forwarding->forward_reply_callback_object = object;
    forwarding->forward_reply_callback = function;
}
358
359
/* Create a forwarding instance bound to the networking core of `dht`.
 *
 * A fresh HMAC key for sendback authentication is generated from `rng`. The
 * NET_PACKET_FORWARD_REQUEST, NET_PACKET_FORWARD_REPLY and
 * NET_PACKET_FORWARDING handlers are then registered.
 *
 * All dependencies are borrowed; kill_forwarding() does not free them.
 * Returns nullptr if any dependency is nullptr or allocation fails. */
Forwarding *new_forwarding(const Logger *log, const Random *rng, const Mono_Time *mono_time, DHT *dht)
{
    /* rng is handed to new_hmac_key() below, so reject nullptr here like the
     * other dependencies instead of passing it on. */
    if (log == nullptr || rng == nullptr || mono_time == nullptr || dht == nullptr) {
        return nullptr;
    }

    Forwarding *forwarding = (Forwarding *)calloc(1, sizeof(Forwarding));

    if (forwarding == nullptr) {
        return nullptr;
    }

    forwarding->log = log;
    forwarding->rng = rng;
    forwarding->mono_time = mono_time;
    forwarding->dht = dht;
    forwarding->net = dht_get_net(dht);

    /* Generate the key before registering the handlers that authenticate with it. */
    new_hmac_key(forwarding->rng, forwarding->hmac_key);

    networking_registerhandler(forwarding->net, NET_PACKET_FORWARD_REQUEST, &handle_forward_request, forwarding);
    networking_registerhandler(forwarding->net, NET_PACKET_FORWARD_REPLY, &handle_forward_reply, forwarding);
    networking_registerhandler(forwarding->net, NET_PACKET_FORWARDING, &handle_forwarding, forwarding);

    return forwarding;
}
385
386
/* Tear down an instance created by new_forwarding().
 *
 * Unregisters its three packet handlers, wipes the HMAC key and frees the
 * object. Passing nullptr is a no-op. The borrowed dependencies (log, rng,
 * dht, ...) are left untouched. */
void kill_forwarding(Forwarding *forwarding)
{
    if (forwarding != nullptr) {
        Networking_Core *const net = forwarding->net;

        networking_registerhandler(net, NET_PACKET_FORWARD_REQUEST, nullptr, nullptr);
        networking_registerhandler(net, NET_PACKET_FORWARD_REPLY, nullptr, nullptr);
        networking_registerhandler(net, NET_PACKET_FORWARDING, nullptr, nullptr);

        crypto_memzero(forwarding->hmac_key, sizeof(forwarding->hmac_key));
        free(forwarding);
    }
}