Line data Source code
1 : /* SPDX-License-Identifier: GPL-3.0-or-later */
2 : /* Copyright 2026 Peter Csaszar */
3 :
4 : /**
5 : * @file mtproto_rpc.c
6 : * @brief MTProto RPC framework — message framing, encryption, decryption.
7 : */
8 :
#include "mtproto_rpc.h"
#include "mtproto_crypto.h"
#include "tl_serial.h"
#include "crypto.h"
#include "tinf.h"
#include "logger.h"
#include "raii.h"

#include <limits.h>
#include <stdlib.h>
#include <string.h>
19 :
/* Size of each per-call heap scratch buffer (received packet, encrypted
 * frame, decrypted plaintext). It must hold one complete frame, e.g. a
 * 512 KiB saveBigFilePart upload chunk plus the 32-byte inner header
 * (salt, session_id, msg_id, seq_no, length), the MTProto 2.0 padding
 * (12..1024 bytes) and the 24-byte outer header (auth_key_id + msg_key).
 * 1 MiB leaves plenty of headroom. */
#define RPC_BUF_SIZE (1024 * 1024)

/* TL constructor IDs. They are compared against a uint32_t read from the
 * start of a payload, so they are given an explicit unsigned type. */
#define CRC_gzip_packed   0x3072cfa1u
#define CRC_msg_container 0x73f1f8dcu
#define CRC_rpc_result    0xf35c6d01u
#define CRC_rpc_error     0x2144ca19u
30 :
31 : /* ---- Unencrypted messages ---- */
32 :
33 15 : int rpc_send_unencrypted(MtProtoSession *s, Transport *t,
34 : const uint8_t *data, size_t len) {
35 15 : if (!s || !t || !data) return -1;
36 :
37 15 : uint64_t msg_id = mtproto_session_next_msg_id(s);
38 :
39 : /* Wire format: auth_key_id(8) + msg_id(8) + len(4) + data */
40 : TlWriter w;
41 15 : tl_writer_init(&w);
42 15 : tl_write_uint64(&w, 0); /* auth_key_id = 0 */
43 15 : tl_write_uint64(&w, msg_id);
44 15 : tl_write_uint32(&w, (uint32_t)len);
45 15 : tl_write_raw(&w, data, len);
46 :
47 15 : int rc = transport_send(t, w.data, w.len);
48 15 : tl_writer_free(&w);
49 15 : return rc;
50 : }
51 :
52 14 : int rpc_recv_unencrypted(MtProtoSession *s, Transport *t,
53 : uint8_t *out, size_t max_len, size_t *out_len) {
54 14 : if (!s || !t || !out || !out_len) return -1;
55 :
56 : /* Read one transport packet (heap-allocated to avoid stack overflow) */
57 14 : RAII_STRING uint8_t *buf = (uint8_t *)malloc(RPC_BUF_SIZE);
58 14 : if (!buf) return -1;
59 14 : size_t buf_len = 0;
60 14 : if (transport_recv(t, buf, RPC_BUF_SIZE, &buf_len) != 0) return -1;
61 :
62 : /* Parse: auth_key_id(8) + msg_id(8) + len(4) + data */
63 13 : if (buf_len < 20) {
64 0 : logger_log(LOG_ERROR,
65 : "rpc_recv_unencrypted: packet too short (%zu bytes, need ≥20)"
66 : " — likely a transport-level error from the server",
67 : buf_len);
68 0 : return -1;
69 : }
70 :
71 13 : TlReader r = tl_reader_init(buf, buf_len);
72 13 : uint64_t auth_key_id = tl_read_uint64(&r);
73 : (void)auth_key_id; /* should be 0 */
74 13 : uint64_t msg_id = tl_read_uint64(&r);
75 : (void)msg_id;
76 13 : uint32_t data_len = tl_read_uint32(&r);
77 :
78 13 : if (data_len > max_len) return -1;
79 13 : if (data_len > 0) {
80 13 : tl_read_raw(&r, out, data_len);
81 : }
82 13 : *out_len = data_len;
83 13 : return 0;
84 : }
85 :
86 : /* ---- Encrypted messages ---- */
87 :
88 313 : int rpc_send_encrypted(MtProtoSession *s, Transport *t,
89 : const uint8_t *data, size_t len,
90 : int content_related) {
91 313 : if (!s || !t || !data || !s->has_auth_key) return -1;
92 :
93 313 : uint64_t msg_id = mtproto_session_next_msg_id(s);
94 313 : uint32_t seq_no = mtproto_session_next_seq_no(s, content_related);
95 :
96 : /* Build plaintext payload: salt(8) + session_id(8) + msg_id(8) + seq_no(4) + len(4) + data */
97 : TlWriter plain;
98 313 : tl_writer_init(&plain);
99 313 : tl_write_uint64(&plain, s->server_salt);
100 313 : tl_write_uint64(&plain, s->session_id);
101 313 : tl_write_uint64(&plain, msg_id);
102 313 : tl_write_uint32(&plain, seq_no);
103 313 : tl_write_uint32(&plain, (uint32_t)len);
104 313 : tl_write_raw(&plain, data, len);
105 :
106 : /* Encrypt with MTProto crypto. The msg_key is computed over the PADDED
107 : * plaintext inside mtproto_encrypt and returned via msg_key_out — the
108 : * spec requires that the wire msg_key match what was used to derive
109 : * the AES keys, so we cannot compute it independently here (the
110 : * padding bytes are random and generated inside mtproto_encrypt). */
111 313 : RAII_STRING uint8_t *encrypted = (uint8_t *)malloc(RPC_BUF_SIZE);
112 313 : if (!encrypted) { tl_writer_free(&plain); return -1; }
113 313 : size_t enc_len = 0;
114 : uint8_t msg_key[16];
115 313 : mtproto_encrypt(plain.data, plain.len, s->auth_key, 0,
116 : encrypted, &enc_len, msg_key);
117 313 : tl_writer_free(&plain);
118 :
119 : /* Wire format: auth_key_id(8) + msg_key(16) + encrypted_data */
120 : TlWriter wire;
121 313 : tl_writer_init(&wire);
122 :
123 : /* auth_key_id = last 8 bytes of SHA1(auth_key) = SHA1[12:20].
124 : * MTProto spec: "lower 64 bits of the SHA1 fingerprint" (big-endian number,
125 : * so last 8 bytes). Use memcpy to avoid strict-aliasing UB (QA-11). */
126 : uint8_t key_hash[20];
127 313 : crypto_sha1(s->auth_key, 256, key_hash);
128 : uint64_t auth_key_id;
129 313 : memcpy(&auth_key_id, key_hash + 12, 8);
130 313 : tl_write_uint64(&wire, auth_key_id);
131 :
132 313 : tl_write_raw(&wire, msg_key, 16);
133 313 : tl_write_raw(&wire, encrypted, enc_len);
134 : /* encrypted freed automatically by RAII_STRING */
135 :
136 313 : int rc = transport_send(t, wire.data, wire.len);
137 313 : tl_writer_free(&wire);
138 313 : return rc;
139 : }
140 :
141 379 : int rpc_recv_encrypted(MtProtoSession *s, Transport *t,
142 : uint8_t *out, size_t max_len, size_t *out_len) {
143 379 : if (!s || !t || !out || !out_len || !s->has_auth_key) return -1;
144 :
145 : /* Read one transport packet (heap-allocated) */
146 379 : RAII_STRING uint8_t *buf = (uint8_t *)malloc(RPC_BUF_SIZE);
147 379 : if (!buf) return -1;
148 379 : size_t buf_len = 0;
149 379 : if (transport_recv(t, buf, RPC_BUF_SIZE, &buf_len) != 0) return -1;
150 :
151 : /* Parse: auth_key_id(8) + msg_key(16) + encrypted_data */
152 378 : if (buf_len < 24) return -1;
153 :
154 378 : TlReader r = tl_reader_init(buf, buf_len);
155 378 : uint64_t recv_auth_key_id = tl_read_uint64(&r);
156 :
157 : uint8_t msg_key[16];
158 378 : tl_read_raw(&r, msg_key, 16);
159 :
160 378 : size_t cipher_len = buf_len - 24;
161 378 : const uint8_t *cipher = buf + 24;
162 :
163 : /* Verify auth_key_id: must equal SHA1(auth_key)[12:20]. */
164 : uint8_t key_hash[20];
165 378 : crypto_sha1(s->auth_key, 256, key_hash);
166 : uint64_t expected_auth_key_id;
167 378 : memcpy(&expected_auth_key_id, key_hash + 12, 8);
168 378 : if (recv_auth_key_id != expected_auth_key_id) {
169 0 : logger_log(LOG_ERROR,
170 : "rpc_recv_encrypted: auth_key_id mismatch "
171 : "(got %016llx, expected %016llx) — dropping frame",
172 : (unsigned long long)recv_auth_key_id,
173 : (unsigned long long)expected_auth_key_id);
174 0 : return -1;
175 : }
176 :
177 : /* Decrypt (heap-allocated) */
178 378 : RAII_STRING uint8_t *decrypted = (uint8_t *)malloc(RPC_BUF_SIZE);
179 378 : if (!decrypted) return -1;
180 378 : size_t dec_len = 0;
181 378 : int rc = mtproto_decrypt(cipher, cipher_len, s->auth_key, msg_key, 8,
182 : decrypted, &dec_len);
183 378 : if (rc != 0) return -1;
184 :
185 : /* Parse plaintext: salt(8) + session_id(8) + msg_id(8) + seq_no(4) + len(4) + data */
186 378 : if (dec_len < 32) return -1;
187 :
188 378 : TlReader pr = tl_reader_init(decrypted, dec_len);
189 378 : tl_read_uint64(&pr); /* salt */
190 378 : uint64_t recv_session_id = tl_read_uint64(&pr);
191 :
192 : /* Verify session_id against the local session. */
193 378 : if (recv_session_id != s->session_id) {
194 1 : logger_log(LOG_ERROR,
195 : "rpc_recv_encrypted: session_id mismatch "
196 : "(got %016llx, expected %016llx) — dropping frame",
197 : (unsigned long long)recv_session_id,
198 1 : (unsigned long long)s->session_id);
199 1 : return -1;
200 : }
201 377 : tl_read_uint64(&pr); /* msg_id */
202 377 : tl_read_uint32(&pr); /* seq_no */
203 377 : uint32_t data_len = tl_read_uint32(&pr);
204 :
205 377 : if (data_len > max_len) return -1;
206 377 : if (data_len > 0) {
207 377 : tl_read_raw(&pr, out, data_len);
208 : }
209 377 : *out_len = data_len;
210 : /* buf and decrypted freed automatically by RAII_STRING */
211 377 : return 0;
212 : }
213 :
214 : /* ---- gzip_packed unwrap ---- */
215 :
216 298 : int rpc_unwrap_gzip(const uint8_t *data, size_t len,
217 : uint8_t *out, size_t max_len, size_t *out_len) {
218 298 : if (!data || !out || !out_len) return -1;
219 298 : if (len < 4) {
220 : /* Too short to contain a constructor — copy as-is */
221 0 : if (len > max_len) return -1;
222 0 : memcpy(out, data, len);
223 0 : *out_len = len;
224 0 : return 0;
225 : }
226 :
227 : /* Check for gzip_packed constructor */
228 : uint32_t constructor;
229 298 : memcpy(&constructor, data, 4);
230 :
231 298 : if (constructor != CRC_gzip_packed) {
232 : /* Not gzip_packed — copy unchanged */
233 295 : if (len > max_len) return -1;
234 295 : memcpy(out, data, len);
235 295 : *out_len = len;
236 295 : return 0;
237 : }
238 :
239 : /* Parse: constructor(4) + bytes(TL-encoded compressed data) */
240 3 : TlReader r = tl_reader_init(data, len);
241 3 : tl_read_uint32(&r); /* skip constructor */
242 :
243 3 : size_t gz_len = 0;
244 6 : RAII_STRING uint8_t *gz_data = tl_read_bytes(&r, &gz_len);
245 3 : if (!gz_data || gz_len == 0) return -1;
246 :
247 : /* Decompress with tinf */
248 3 : unsigned int dest_len = (unsigned int)max_len;
249 3 : int rc = tinf_gzip_uncompress(out, &dest_len,
250 : gz_data, (unsigned int)gz_len);
251 : /* gz_data freed automatically by RAII_STRING */
252 :
253 3 : if (rc != TINF_OK) return -1;
254 :
255 2 : *out_len = dest_len;
256 2 : return 0;
257 : }
258 :
259 : /* ---- msg_container parse ---- */
260 :
261 7 : int rpc_parse_container(const uint8_t *data, size_t len,
262 : RpcContainerMsg *msgs, size_t max_msgs,
263 : size_t *count) {
264 7 : if (!data || !msgs || !count) return -1;
265 7 : if (len < 4) return -1;
266 :
267 : uint32_t constructor;
268 7 : memcpy(&constructor, data, 4);
269 :
270 7 : if (constructor != CRC_msg_container) {
271 : /* Not a container — return as single message */
272 1 : if (max_msgs < 1) return -1;
273 1 : msgs[0].msg_id = 0;
274 1 : msgs[0].seqno = 0;
275 1 : msgs[0].body_len = (uint32_t)len;
276 1 : msgs[0].body = data;
277 1 : *count = 1;
278 1 : return 0;
279 : }
280 :
281 : /* Parse: constructor(4) + count(4) + messages[] */
282 6 : TlReader r = tl_reader_init(data, len);
283 6 : tl_read_uint32(&r); /* skip constructor */
284 6 : uint32_t msg_count = tl_read_uint32(&r);
285 :
286 6 : if (msg_count > max_msgs) return -1;
287 :
288 13 : for (uint32_t i = 0; i < msg_count; i++) {
289 8 : if (!tl_reader_ok(&r)) return -1;
290 :
291 8 : msgs[i].msg_id = tl_read_uint64(&r);
292 8 : msgs[i].seqno = tl_read_uint32(&r);
293 8 : msgs[i].body_len = tl_read_uint32(&r);
294 :
295 : /* TL-serialized bodies are always 4-byte aligned. An unaligned
296 : * body_len would misalign subsequent message reads in the container,
297 : * silently producing garbage data. Reject malformed containers. */
298 8 : if (msgs[i].body_len % 4 != 0) {
299 1 : logger_log(LOG_WARN,
300 : "rpc_parse_container: unaligned body_len=%u "
301 : "at message %u (must be multiple of 4)",
302 1 : msgs[i].body_len, i);
303 1 : return -1;
304 : }
305 :
306 7 : if (msgs[i].body_len > len - r.pos) return -1;
307 :
308 7 : msgs[i].body = data + r.pos;
309 7 : tl_read_skip(&r, msgs[i].body_len);
310 : }
311 :
312 5 : *count = msg_count;
313 5 : return 0;
314 : }
315 :
316 : /* ---- rpc_result unwrap ---- */
317 :
318 301 : int rpc_unwrap_result(const uint8_t *data, size_t len,
319 : uint64_t *req_msg_id,
320 : const uint8_t **inner, size_t *inner_len) {
321 301 : if (!data || !req_msg_id || !inner || !inner_len) return -1;
322 301 : if (len < 12) return -1; /* constructor(4) + msg_id(8) minimum */
323 :
324 : uint32_t constructor;
325 301 : memcpy(&constructor, data, 4);
326 301 : if (constructor != CRC_rpc_result) return -1;
327 :
328 301 : memcpy(req_msg_id, data + 4, 8);
329 301 : *inner = data + 12;
330 301 : *inner_len = len - 12;
331 301 : return 0;
332 : }
333 :
334 : /* ---- rpc_error parse ---- */
335 :
/**
 * Extract the trailing decimal integer from an error message
 * (e.g. "FLOOD_WAIT_30" -> 30).
 *
 * @param msg NUL-terminated string; must not be NULL.
 * @return The value of the trailing run of ASCII digits. Returns 0 if the
 *         string does not end in a digit, or INT_MAX if the value does not
 *         fit in an int.
 */
static int extract_trailing_int(const char *msg) {
    const char *end = msg + strlen(msg);
    const char *p = end;
    while (p > msg && p[-1] >= '0' && p[-1] <= '9') p--;

    /* Accumulate by hand: atoi() has undefined behavior on overflow, and
     * these digits come from a server-supplied string. */
    int value = 0;
    for (; p < end; p++) {
        int digit = *p - '0';
        if (value > (INT_MAX - digit) / 10) return INT_MAX;
        value = value * 10 + digit;
    }
    return value;
}
343 :
344 47 : int rpc_parse_error(const uint8_t *data, size_t len, RpcError *err) {
345 47 : if (!data || !err) return -1;
346 47 : if (len < 4) return -1;
347 :
348 : uint32_t constructor;
349 47 : memcpy(&constructor, data, 4);
350 47 : if (constructor != CRC_rpc_error) return -1;
351 :
352 47 : TlReader r = tl_reader_init(data, len);
353 47 : tl_read_uint32(&r); /* skip constructor */
354 :
355 47 : err->error_code = tl_read_int32(&r);
356 :
357 94 : RAII_STRING char *msg = tl_read_string(&r);
358 47 : if (!msg) {
359 0 : memset(err->error_msg, 0, sizeof(err->error_msg));
360 0 : err->migrate_dc = -1;
361 0 : err->flood_wait_secs = 0;
362 0 : return 0;
363 : }
364 :
365 : /* Copy message (truncate if needed) */
366 47 : size_t msg_len = strlen(msg);
367 47 : if (msg_len >= sizeof(err->error_msg))
368 0 : msg_len = sizeof(err->error_msg) - 1;
369 47 : memcpy(err->error_msg, msg, msg_len);
370 47 : err->error_msg[msg_len] = '\0';
371 :
372 : /* Parse derived fields */
373 47 : err->migrate_dc = -1;
374 47 : err->flood_wait_secs = 0;
375 :
376 47 : if (strncmp(msg, "PHONE_MIGRATE_", 14) == 0 ||
377 41 : strncmp(msg, "FILE_MIGRATE_", 13) == 0 ||
378 35 : strncmp(msg, "NETWORK_MIGRATE_", 16) == 0 ||
379 33 : strncmp(msg, "USER_MIGRATE_", 13) == 0) {
380 15 : err->migrate_dc = extract_trailing_int(msg);
381 32 : } else if (strncmp(msg, "FLOOD_WAIT_", 11) == 0) {
382 3 : err->flood_wait_secs = extract_trailing_int(msg);
383 : }
384 :
385 : /* msg freed automatically by RAII_STRING */
386 47 : return 0;
387 : }
|