Line data Source code
1 : /* SPDX-License-Identifier: GPL-3.0-or-later */
2 : /* Copyright 2026 Peter Csaszar */
3 :
4 : /**
5 : * @file mtproto_rpc.c
6 : * @brief MTProto RPC framework — message framing, encryption, decryption.
7 : */
8 :
9 : #include "mtproto_rpc.h"
10 : #include "mtproto_crypto.h"
11 : #include "tl_serial.h"
12 : #include "crypto.h"
13 : #include "tinf.h"
14 : #include "logger.h"
15 : #include "raii.h"
16 :
#include <limits.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
19 :
/* Size of every scratch buffer allocated in this file (receive, encrypt and
 * decrypt buffers). Must be large enough to hold a full encrypted frame,
 * including the 512 KiB upload chunk used by saveBigFilePart, the 32-byte
 * MTProto plaintext header (salt, session_id, msg_id, seq_no, length), the
 * random padding added inside mtproto_encrypt, and the 24-byte outer
 * auth_key_id / msg_key. 1 MiB leaves plenty of headroom. */
#define RPC_BUF_SIZE (1024 * 1024)

/* TL constructor IDs, compared against the first 4 bytes of a payload. */
#define CRC_gzip_packed 0x3072cfa1
#define CRC_msg_container 0x73f1f8dc
#define CRC_rpc_result 0xf35c6d01
#define CRC_rpc_error 0x2144ca19
30 :
31 : /* ---- Unencrypted messages ---- */
32 :
33 53 : int rpc_send_unencrypted(MtProtoSession *s, Transport *t,
34 : const uint8_t *data, size_t len) {
35 53 : if (!s || !t || !data) return -1;
36 :
37 50 : uint64_t msg_id = mtproto_session_next_msg_id(s);
38 :
39 : /* Wire format: auth_key_id(8) + msg_id(8) + len(4) + data */
40 : TlWriter w;
41 50 : tl_writer_init(&w);
42 50 : tl_write_uint64(&w, 0); /* auth_key_id = 0 */
43 50 : tl_write_uint64(&w, msg_id);
44 50 : tl_write_uint32(&w, (uint32_t)len);
45 50 : tl_write_raw(&w, data, len);
46 :
47 50 : int rc = transport_send(t, w.data, w.len);
48 50 : tl_writer_free(&w);
49 50 : return rc;
50 : }
51 :
52 54 : int rpc_recv_unencrypted(MtProtoSession *s, Transport *t,
53 : uint8_t *out, size_t max_len, size_t *out_len) {
54 54 : if (!s || !t || !out || !out_len) return -1;
55 :
56 : /* Read one transport packet (heap-allocated to avoid stack overflow) */
57 54 : RAII_STRING uint8_t *buf = (uint8_t *)malloc(RPC_BUF_SIZE);
58 54 : if (!buf) return -1;
59 54 : size_t buf_len = 0;
60 54 : if (transport_recv(t, buf, RPC_BUF_SIZE, &buf_len) != 0) return -1;
61 :
62 : /* Parse: auth_key_id(8) + msg_id(8) + len(4) + data */
63 50 : if (buf_len < 20) return -1;
64 :
65 50 : TlReader r = tl_reader_init(buf, buf_len);
66 50 : uint64_t auth_key_id = tl_read_uint64(&r);
67 : (void)auth_key_id; /* should be 0 */
68 50 : uint64_t msg_id = tl_read_uint64(&r);
69 : (void)msg_id;
70 50 : uint32_t data_len = tl_read_uint32(&r);
71 :
72 50 : if (data_len > max_len) return -1;
73 50 : if (data_len > 0) {
74 50 : tl_read_raw(&r, out, data_len);
75 : }
76 50 : *out_len = data_len;
77 50 : return 0;
78 : }
79 :
80 : /* ---- Encrypted messages ---- */
81 :
82 763 : int rpc_send_encrypted(MtProtoSession *s, Transport *t,
83 : const uint8_t *data, size_t len,
84 : int content_related) {
85 763 : if (!s || !t || !data || !s->has_auth_key) return -1;
86 :
87 763 : uint64_t msg_id = mtproto_session_next_msg_id(s);
88 763 : uint32_t seq_no = mtproto_session_next_seq_no(s, content_related);
89 :
90 : /* Build plaintext payload: salt(8) + session_id(8) + msg_id(8) + seq_no(4) + len(4) + data */
91 : TlWriter plain;
92 763 : tl_writer_init(&plain);
93 763 : tl_write_uint64(&plain, s->server_salt);
94 763 : tl_write_uint64(&plain, s->session_id);
95 763 : tl_write_uint64(&plain, msg_id);
96 763 : tl_write_uint32(&plain, seq_no);
97 763 : tl_write_uint32(&plain, (uint32_t)len);
98 763 : tl_write_raw(&plain, data, len);
99 :
100 : /* Encrypt with MTProto crypto. The msg_key is computed over the PADDED
101 : * plaintext inside mtproto_encrypt and returned via msg_key_out — the
102 : * spec requires that the wire msg_key match what was used to derive
103 : * the AES keys, so we cannot compute it independently here (the
104 : * padding bytes are random and generated inside mtproto_encrypt). */
105 763 : RAII_STRING uint8_t *encrypted = (uint8_t *)malloc(RPC_BUF_SIZE);
106 763 : if (!encrypted) { tl_writer_free(&plain); return -1; }
107 763 : size_t enc_len = 0;
108 : uint8_t msg_key[16];
109 763 : mtproto_encrypt(plain.data, plain.len, s->auth_key, 0,
110 : encrypted, &enc_len, msg_key);
111 763 : tl_writer_free(&plain);
112 :
113 : /* Wire format: auth_key_id(8) + msg_key(16) + encrypted_data */
114 : TlWriter wire;
115 763 : tl_writer_init(&wire);
116 :
117 : /* auth_key_id = last 8 bytes of SHA256(auth_key). Use memcpy rather than
118 : * a pointer cast to avoid strict-aliasing UB, and let tl_write_uint64
119 : * handle little-endian encoding so we stay correct on big-endian hosts
120 : * (QA-11). */
121 : uint8_t key_hash[32];
122 763 : crypto_sha256(s->auth_key, 256, key_hash);
123 : uint64_t auth_key_id;
124 763 : memcpy(&auth_key_id, key_hash + 24, 8);
125 763 : tl_write_uint64(&wire, auth_key_id);
126 :
127 763 : tl_write_raw(&wire, msg_key, 16);
128 763 : tl_write_raw(&wire, encrypted, enc_len);
129 : /* encrypted freed automatically by RAII_STRING */
130 :
131 763 : int rc = transport_send(t, wire.data, wire.len);
132 763 : tl_writer_free(&wire);
133 763 : return rc;
134 : }
135 :
136 787 : int rpc_recv_encrypted(MtProtoSession *s, Transport *t,
137 : uint8_t *out, size_t max_len, size_t *out_len) {
138 787 : if (!s || !t || !out || !out_len || !s->has_auth_key) return -1;
139 :
140 : /* Read one transport packet (heap-allocated) */
141 787 : RAII_STRING uint8_t *buf = (uint8_t *)malloc(RPC_BUF_SIZE);
142 787 : if (!buf) return -1;
143 787 : size_t buf_len = 0;
144 787 : if (transport_recv(t, buf, RPC_BUF_SIZE, &buf_len) != 0) return -1;
145 :
146 : /* Parse: auth_key_id(8) + msg_key(16) + encrypted_data */
147 782 : if (buf_len < 24) return -1;
148 :
149 782 : TlReader r = tl_reader_init(buf, buf_len);
150 782 : uint64_t recv_auth_key_id = tl_read_uint64(&r);
151 :
152 : uint8_t msg_key[16];
153 782 : tl_read_raw(&r, msg_key, 16);
154 :
155 782 : size_t cipher_len = buf_len - 24;
156 782 : const uint8_t *cipher = buf + 24;
157 :
158 : /* Verify auth_key_id: must equal SHA256(auth_key)[24:32]. */
159 : uint8_t key_hash[32];
160 782 : crypto_sha256(s->auth_key, 256, key_hash);
161 : uint64_t expected_auth_key_id;
162 782 : memcpy(&expected_auth_key_id, key_hash + 24, 8);
163 782 : if (recv_auth_key_id != expected_auth_key_id) {
164 1 : logger_log(LOG_ERROR,
165 : "rpc_recv_encrypted: auth_key_id mismatch "
166 : "(got %016llx, expected %016llx) — dropping frame",
167 : (unsigned long long)recv_auth_key_id,
168 : (unsigned long long)expected_auth_key_id);
169 1 : return -1;
170 : }
171 :
172 : /* Decrypt (heap-allocated) */
173 781 : RAII_STRING uint8_t *decrypted = (uint8_t *)malloc(RPC_BUF_SIZE);
174 781 : if (!decrypted) return -1;
175 781 : size_t dec_len = 0;
176 781 : int rc = mtproto_decrypt(cipher, cipher_len, s->auth_key, msg_key, 8,
177 : decrypted, &dec_len);
178 781 : if (rc != 0) return -1;
179 :
180 : /* Parse plaintext: salt(8) + session_id(8) + msg_id(8) + seq_no(4) + len(4) + data */
181 781 : if (dec_len < 32) return -1;
182 :
183 781 : TlReader pr = tl_reader_init(decrypted, dec_len);
184 781 : tl_read_uint64(&pr); /* salt */
185 781 : uint64_t recv_session_id = tl_read_uint64(&pr);
186 :
187 : /* Verify session_id against the local session. */
188 781 : if (recv_session_id != s->session_id) {
189 3 : logger_log(LOG_ERROR,
190 : "rpc_recv_encrypted: session_id mismatch "
191 : "(got %016llx, expected %016llx) — dropping frame",
192 : (unsigned long long)recv_session_id,
193 3 : (unsigned long long)s->session_id);
194 3 : return -1;
195 : }
196 778 : tl_read_uint64(&pr); /* msg_id */
197 778 : tl_read_uint32(&pr); /* seq_no */
198 778 : uint32_t data_len = tl_read_uint32(&pr);
199 :
200 778 : if (data_len > max_len) return -1;
201 778 : if (data_len > 0) {
202 778 : tl_read_raw(&pr, out, data_len);
203 : }
204 778 : *out_len = data_len;
205 : /* buf and decrypted freed automatically by RAII_STRING */
206 778 : return 0;
207 : }
208 :
209 : /* ---- gzip_packed unwrap ---- */
210 :
211 736 : int rpc_unwrap_gzip(const uint8_t *data, size_t len,
212 : uint8_t *out, size_t max_len, size_t *out_len) {
213 736 : if (!data || !out || !out_len) return -1;
214 733 : if (len < 4) {
215 : /* Too short to contain a constructor — copy as-is */
216 1 : if (len > max_len) return -1;
217 1 : memcpy(out, data, len);
218 1 : *out_len = len;
219 1 : return 0;
220 : }
221 :
222 : /* Check for gzip_packed constructor */
223 : uint32_t constructor;
224 732 : memcpy(&constructor, data, 4);
225 :
226 732 : if (constructor != CRC_gzip_packed) {
227 : /* Not gzip_packed — copy unchanged */
228 723 : if (len > max_len) return -1;
229 722 : memcpy(out, data, len);
230 722 : *out_len = len;
231 722 : return 0;
232 : }
233 :
234 : /* Parse: constructor(4) + bytes(TL-encoded compressed data) */
235 9 : TlReader r = tl_reader_init(data, len);
236 9 : tl_read_uint32(&r); /* skip constructor */
237 :
238 9 : size_t gz_len = 0;
239 18 : RAII_STRING uint8_t *gz_data = tl_read_bytes(&r, &gz_len);
240 9 : if (!gz_data || gz_len == 0) return -1;
241 :
242 : /* Decompress with tinf */
243 8 : unsigned int dest_len = (unsigned int)max_len;
244 8 : int rc = tinf_gzip_uncompress(out, &dest_len,
245 : gz_data, (unsigned int)gz_len);
246 : /* gz_data freed automatically by RAII_STRING */
247 :
248 8 : if (rc != TINF_OK) return -1;
249 :
250 5 : *out_len = dest_len;
251 5 : return 0;
252 : }
253 :
254 : /* ---- msg_container parse ---- */
255 :
256 22 : int rpc_parse_container(const uint8_t *data, size_t len,
257 : RpcContainerMsg *msgs, size_t max_msgs,
258 : size_t *count) {
259 22 : if (!data || !msgs || !count) return -1;
260 19 : if (len < 4) return -1;
261 :
262 : uint32_t constructor;
263 19 : memcpy(&constructor, data, 4);
264 :
265 19 : if (constructor != CRC_msg_container) {
266 : /* Not a container — return as single message */
267 3 : if (max_msgs < 1) return -1;
268 3 : msgs[0].msg_id = 0;
269 3 : msgs[0].seqno = 0;
270 3 : msgs[0].body_len = (uint32_t)len;
271 3 : msgs[0].body = data;
272 3 : *count = 1;
273 3 : return 0;
274 : }
275 :
276 : /* Parse: constructor(4) + count(4) + messages[] */
277 16 : TlReader r = tl_reader_init(data, len);
278 16 : tl_read_uint32(&r); /* skip constructor */
279 16 : uint32_t msg_count = tl_read_uint32(&r);
280 :
281 16 : if (msg_count > max_msgs) return -1;
282 :
283 33 : for (uint32_t i = 0; i < msg_count; i++) {
284 21 : if (!tl_reader_ok(&r)) return -1;
285 :
286 21 : msgs[i].msg_id = tl_read_uint64(&r);
287 21 : msgs[i].seqno = tl_read_uint32(&r);
288 21 : msgs[i].body_len = tl_read_uint32(&r);
289 :
290 : /* TL-serialized bodies are always 4-byte aligned. An unaligned
291 : * body_len would misalign subsequent message reads in the container,
292 : * silently producing garbage data. Reject malformed containers. */
293 21 : if (msgs[i].body_len % 4 != 0) {
294 3 : logger_log(LOG_WARN,
295 : "rpc_parse_container: unaligned body_len=%u "
296 : "at message %u (must be multiple of 4)",
297 3 : msgs[i].body_len, i);
298 3 : return -1;
299 : }
300 :
301 18 : if (msgs[i].body_len > len - r.pos) return -1;
302 :
303 18 : msgs[i].body = data + r.pos;
304 18 : tl_read_skip(&r, msgs[i].body_len);
305 : }
306 :
307 12 : *count = msg_count;
308 12 : return 0;
309 : }
310 :
311 : /* ---- rpc_result unwrap ---- */
312 :
313 734 : int rpc_unwrap_result(const uint8_t *data, size_t len,
314 : uint64_t *req_msg_id,
315 : const uint8_t **inner, size_t *inner_len) {
316 734 : if (!data || !req_msg_id || !inner || !inner_len) return -1;
317 734 : if (len < 12) return -1; /* constructor(4) + msg_id(8) minimum */
318 :
319 : uint32_t constructor;
320 695 : memcpy(&constructor, data, 4);
321 695 : if (constructor != CRC_rpc_result) return -1;
322 :
323 611 : memcpy(req_msg_id, data + 4, 8);
324 611 : *inner = data + 12;
325 611 : *inner_len = len - 12;
326 611 : return 0;
327 : }
328 :
329 : /* ---- rpc_error parse ---- */
330 :
/**
 * Extract the decimal number at the end of an error message
 * (e.g. "FLOOD_WAIT_30" → 30).
 *
 * @param msg NUL-terminated string; must not be NULL.
 * @return The trailing number, 0 if @p msg does not end in a digit, or
 *         INT_MAX if the digits overflow an int. The text comes from the
 *         server, so overflow saturates instead of being passed to atoi(),
 *         whose behavior on out-of-range input is undefined.
 */
static int extract_trailing_int(const char *msg) {
    const char *end = msg + strlen(msg);
    const char *p = end;
    while (p > msg && p[-1] >= '0' && p[-1] <= '9') p--;

    int value = 0;
    for (; p < end; p++) {
        int digit = *p - '0';
        if (value > (INT_MAX - digit) / 10) return INT_MAX;
        value = value * 10 + digit;
    }
    return value;
}
338 :
339 125 : int rpc_parse_error(const uint8_t *data, size_t len, RpcError *err) {
340 125 : if (!data || !err) return -1;
341 123 : if (len < 4) return -1;
342 :
343 : uint32_t constructor;
344 123 : memcpy(&constructor, data, 4);
345 123 : if (constructor != CRC_rpc_error) return -1;
346 :
347 122 : TlReader r = tl_reader_init(data, len);
348 122 : tl_read_uint32(&r); /* skip constructor */
349 :
350 122 : err->error_code = tl_read_int32(&r);
351 :
352 244 : RAII_STRING char *msg = tl_read_string(&r);
353 122 : if (!msg) {
354 0 : memset(err->error_msg, 0, sizeof(err->error_msg));
355 0 : err->migrate_dc = -1;
356 0 : err->flood_wait_secs = 0;
357 0 : return 0;
358 : }
359 :
360 : /* Copy message (truncate if needed) */
361 122 : size_t msg_len = strlen(msg);
362 122 : if (msg_len >= sizeof(err->error_msg))
363 0 : msg_len = sizeof(err->error_msg) - 1;
364 122 : memcpy(err->error_msg, msg, msg_len);
365 122 : err->error_msg[msg_len] = '\0';
366 :
367 : /* Parse derived fields */
368 122 : err->migrate_dc = -1;
369 122 : err->flood_wait_secs = 0;
370 :
371 122 : if (strncmp(msg, "PHONE_MIGRATE_", 14) == 0 ||
372 109 : strncmp(msg, "FILE_MIGRATE_", 13) == 0 ||
373 95 : strncmp(msg, "NETWORK_MIGRATE_", 16) == 0 ||
374 91 : strncmp(msg, "USER_MIGRATE_", 13) == 0) {
375 33 : err->migrate_dc = extract_trailing_int(msg);
376 89 : } else if (strncmp(msg, "FLOOD_WAIT_", 11) == 0) {
377 9 : err->flood_wait_secs = extract_trailing_int(msg);
378 : }
379 :
380 : /* msg freed automatically by RAII_STRING */
381 122 : return 0;
382 : }
|