LCOV - code coverage report
Current view: top level - src/infrastructure - mtproto_rpc.c (source / functions) Coverage Total Hit
Test: coverage-functional.info Lines: 94.1 % 186 175
Test Date: 2026-04-20 19:54:24 Functions: 100.0 % 9 9

            Line data    Source code
       1              : /* SPDX-License-Identifier: GPL-3.0-or-later */
       2              : /* Copyright 2026 Peter Csaszar */
       3              : 
       4              : /**
       5              :  * @file mtproto_rpc.c
       6              :  * @brief MTProto RPC framework — message framing, encryption, decryption.
       7              :  */
       8              : 
#include "mtproto_rpc.h"
#include "mtproto_crypto.h"
#include "tl_serial.h"
#include "crypto.h"
#include "tinf.h"
#include "logger.h"
#include "raii.h"

#include <limits.h>
#include <stdlib.h>
#include <string.h>
      19              : 
      20              : /* Must be large enough to hold a full encrypted frame, including
      21              :  * the 512 KiB upload chunk used by saveBigFilePart + the ~56-byte
      22              :  * MTProto plaintext header + the outer auth_key_id / msg_key.
      23              :  * 1 MiB leaves plenty of headroom. */
      24              : #define RPC_BUF_SIZE (1024 * 1024)
      25              : 
      26              : #define CRC_gzip_packed    0x3072cfa1
      27              : #define CRC_msg_container  0x73f1f8dc
      28              : #define CRC_rpc_result     0xf35c6d01
      29              : #define CRC_rpc_error      0x2144ca19
      30              : 
      31              : /* ---- Unencrypted messages ---- */
      32              : 
      33           15 : int rpc_send_unencrypted(MtProtoSession *s, Transport *t,
      34              :                          const uint8_t *data, size_t len) {
      35           15 :     if (!s || !t || !data) return -1;
      36              : 
      37           15 :     uint64_t msg_id = mtproto_session_next_msg_id(s);
      38              : 
      39              :     /* Wire format: auth_key_id(8) + msg_id(8) + len(4) + data */
      40              :     TlWriter w;
      41           15 :     tl_writer_init(&w);
      42           15 :     tl_write_uint64(&w, 0);            /* auth_key_id = 0 */
      43           15 :     tl_write_uint64(&w, msg_id);
      44           15 :     tl_write_uint32(&w, (uint32_t)len);
      45           15 :     tl_write_raw(&w, data, len);
      46              : 
      47           15 :     int rc = transport_send(t, w.data, w.len);
      48           15 :     tl_writer_free(&w);
      49           15 :     return rc;
      50              : }
      51              : 
      52           14 : int rpc_recv_unencrypted(MtProtoSession *s, Transport *t,
      53              :                          uint8_t *out, size_t max_len, size_t *out_len) {
      54           14 :     if (!s || !t || !out || !out_len) return -1;
      55              : 
      56              :     /* Read one transport packet (heap-allocated to avoid stack overflow) */
      57           14 :     RAII_STRING uint8_t *buf = (uint8_t *)malloc(RPC_BUF_SIZE);
      58           14 :     if (!buf) return -1;
      59           14 :     size_t buf_len = 0;
      60           14 :     if (transport_recv(t, buf, RPC_BUF_SIZE, &buf_len) != 0) return -1;
      61              : 
      62              :     /* Parse: auth_key_id(8) + msg_id(8) + len(4) + data */
      63           13 :     if (buf_len < 20) return -1;
      64              : 
      65           13 :     TlReader r = tl_reader_init(buf, buf_len);
      66           13 :     uint64_t auth_key_id = tl_read_uint64(&r);
      67              :     (void)auth_key_id; /* should be 0 */
      68           13 :     uint64_t msg_id = tl_read_uint64(&r);
      69              :     (void)msg_id;
      70           13 :     uint32_t data_len = tl_read_uint32(&r);
      71              : 
      72           13 :     if (data_len > max_len) return -1;
      73           13 :     if (data_len > 0) {
      74           13 :         tl_read_raw(&r, out, data_len);
      75              :     }
      76           13 :     *out_len = data_len;
      77           13 :     return 0;
      78              : }
      79              : 
      80              : /* ---- Encrypted messages ---- */
      81              : 
      82          313 : int rpc_send_encrypted(MtProtoSession *s, Transport *t,
      83              :                        const uint8_t *data, size_t len,
      84              :                        int content_related) {
      85          313 :     if (!s || !t || !data || !s->has_auth_key) return -1;
      86              : 
      87          313 :     uint64_t msg_id = mtproto_session_next_msg_id(s);
      88          313 :     uint32_t seq_no = mtproto_session_next_seq_no(s, content_related);
      89              : 
      90              :     /* Build plaintext payload: salt(8) + session_id(8) + msg_id(8) + seq_no(4) + len(4) + data */
      91              :     TlWriter plain;
      92          313 :     tl_writer_init(&plain);
      93          313 :     tl_write_uint64(&plain, s->server_salt);
      94          313 :     tl_write_uint64(&plain, s->session_id);
      95          313 :     tl_write_uint64(&plain, msg_id);
      96          313 :     tl_write_uint32(&plain, seq_no);
      97          313 :     tl_write_uint32(&plain, (uint32_t)len);
      98          313 :     tl_write_raw(&plain, data, len);
      99              : 
     100              :     /* Encrypt with MTProto crypto. The msg_key is computed over the PADDED
     101              :      * plaintext inside mtproto_encrypt and returned via msg_key_out — the
     102              :      * spec requires that the wire msg_key match what was used to derive
     103              :      * the AES keys, so we cannot compute it independently here (the
     104              :      * padding bytes are random and generated inside mtproto_encrypt). */
     105          313 :     RAII_STRING uint8_t *encrypted = (uint8_t *)malloc(RPC_BUF_SIZE);
     106          313 :     if (!encrypted) { tl_writer_free(&plain); return -1; }
     107          313 :     size_t enc_len = 0;
     108              :     uint8_t msg_key[16];
     109          313 :     mtproto_encrypt(plain.data, plain.len, s->auth_key, 0,
     110              :                     encrypted, &enc_len, msg_key);
     111          313 :     tl_writer_free(&plain);
     112              : 
     113              :     /* Wire format: auth_key_id(8) + msg_key(16) + encrypted_data */
     114              :     TlWriter wire;
     115          313 :     tl_writer_init(&wire);
     116              : 
     117              :     /* auth_key_id = last 8 bytes of SHA256(auth_key). Use memcpy rather than
     118              :      * a pointer cast to avoid strict-aliasing UB, and let tl_write_uint64
     119              :      * handle little-endian encoding so we stay correct on big-endian hosts
     120              :      * (QA-11). */
     121              :     uint8_t key_hash[32];
     122          313 :     crypto_sha256(s->auth_key, 256, key_hash);
     123              :     uint64_t auth_key_id;
     124          313 :     memcpy(&auth_key_id, key_hash + 24, 8);
     125          313 :     tl_write_uint64(&wire, auth_key_id);
     126              : 
     127          313 :     tl_write_raw(&wire, msg_key, 16);
     128          313 :     tl_write_raw(&wire, encrypted, enc_len);
     129              :     /* encrypted freed automatically by RAII_STRING */
     130              : 
     131          313 :     int rc = transport_send(t, wire.data, wire.len);
     132          313 :     tl_writer_free(&wire);
     133          313 :     return rc;
     134              : }
     135              : 
     136          323 : int rpc_recv_encrypted(MtProtoSession *s, Transport *t,
     137              :                        uint8_t *out, size_t max_len, size_t *out_len) {
     138          323 :     if (!s || !t || !out || !out_len || !s->has_auth_key) return -1;
     139              : 
     140              :     /* Read one transport packet (heap-allocated) */
     141          323 :     RAII_STRING uint8_t *buf = (uint8_t *)malloc(RPC_BUF_SIZE);
     142          323 :     if (!buf) return -1;
     143          323 :     size_t buf_len = 0;
     144          323 :     if (transport_recv(t, buf, RPC_BUF_SIZE, &buf_len) != 0) return -1;
     145              : 
     146              :     /* Parse: auth_key_id(8) + msg_key(16) + encrypted_data */
     147          322 :     if (buf_len < 24) return -1;
     148              : 
     149          322 :     TlReader r = tl_reader_init(buf, buf_len);
     150          322 :     uint64_t recv_auth_key_id = tl_read_uint64(&r);
     151              : 
     152              :     uint8_t msg_key[16];
     153          322 :     tl_read_raw(&r, msg_key, 16);
     154              : 
     155          322 :     size_t cipher_len = buf_len - 24;
     156          322 :     const uint8_t *cipher = buf + 24;
     157              : 
     158              :     /* Verify auth_key_id: must equal SHA256(auth_key)[24:32]. */
     159              :     uint8_t key_hash[32];
     160          322 :     crypto_sha256(s->auth_key, 256, key_hash);
     161              :     uint64_t expected_auth_key_id;
     162          322 :     memcpy(&expected_auth_key_id, key_hash + 24, 8);
     163          322 :     if (recv_auth_key_id != expected_auth_key_id) {
     164            0 :         logger_log(LOG_ERROR,
     165              :                    "rpc_recv_encrypted: auth_key_id mismatch "
     166              :                    "(got %016llx, expected %016llx) — dropping frame",
     167              :                    (unsigned long long)recv_auth_key_id,
     168              :                    (unsigned long long)expected_auth_key_id);
     169            0 :         return -1;
     170              :     }
     171              : 
     172              :     /* Decrypt (heap-allocated) */
     173          322 :     RAII_STRING uint8_t *decrypted = (uint8_t *)malloc(RPC_BUF_SIZE);
     174          322 :     if (!decrypted) return -1;
     175          322 :     size_t dec_len = 0;
     176          322 :     int rc = mtproto_decrypt(cipher, cipher_len, s->auth_key, msg_key, 8,
     177              :                              decrypted, &dec_len);
     178          322 :     if (rc != 0) return -1;
     179              : 
     180              :     /* Parse plaintext: salt(8) + session_id(8) + msg_id(8) + seq_no(4) + len(4) + data */
     181          322 :     if (dec_len < 32) return -1;
     182              : 
     183          322 :     TlReader pr = tl_reader_init(decrypted, dec_len);
     184          322 :     tl_read_uint64(&pr); /* salt */
     185          322 :     uint64_t recv_session_id = tl_read_uint64(&pr);
     186              : 
     187              :     /* Verify session_id against the local session. */
     188          322 :     if (recv_session_id != s->session_id) {
     189            1 :         logger_log(LOG_ERROR,
     190              :                    "rpc_recv_encrypted: session_id mismatch "
     191              :                    "(got %016llx, expected %016llx) — dropping frame",
     192              :                    (unsigned long long)recv_session_id,
     193            1 :                    (unsigned long long)s->session_id);
     194            1 :         return -1;
     195              :     }
     196          321 :     tl_read_uint64(&pr); /* msg_id */
     197          321 :     tl_read_uint32(&pr); /* seq_no */
     198          321 :     uint32_t data_len = tl_read_uint32(&pr);
     199              : 
     200          321 :     if (data_len > max_len) return -1;
     201          321 :     if (data_len > 0) {
     202          321 :         tl_read_raw(&pr, out, data_len);
     203              :     }
     204          321 :     *out_len = data_len;
     205              :     /* buf and decrypted freed automatically by RAII_STRING */
     206          321 :     return 0;
     207              : }
     208              : 
     209              : /* ---- gzip_packed unwrap ---- */
     210              : 
     211          298 : int rpc_unwrap_gzip(const uint8_t *data, size_t len,
     212              :                     uint8_t *out, size_t max_len, size_t *out_len) {
     213          298 :     if (!data || !out || !out_len) return -1;
     214          298 :     if (len < 4) {
     215              :         /* Too short to contain a constructor — copy as-is */
     216            0 :         if (len > max_len) return -1;
     217            0 :         memcpy(out, data, len);
     218            0 :         *out_len = len;
     219            0 :         return 0;
     220              :     }
     221              : 
     222              :     /* Check for gzip_packed constructor */
     223              :     uint32_t constructor;
     224          298 :     memcpy(&constructor, data, 4);
     225              : 
     226          298 :     if (constructor != CRC_gzip_packed) {
     227              :         /* Not gzip_packed — copy unchanged */
     228          295 :         if (len > max_len) return -1;
     229          295 :         memcpy(out, data, len);
     230          295 :         *out_len = len;
     231          295 :         return 0;
     232              :     }
     233              : 
     234              :     /* Parse: constructor(4) + bytes(TL-encoded compressed data) */
     235            3 :     TlReader r = tl_reader_init(data, len);
     236            3 :     tl_read_uint32(&r); /* skip constructor */
     237              : 
     238            3 :     size_t gz_len = 0;
     239            6 :     RAII_STRING uint8_t *gz_data = tl_read_bytes(&r, &gz_len);
     240            3 :     if (!gz_data || gz_len == 0) return -1;
     241              : 
     242              :     /* Decompress with tinf */
     243            3 :     unsigned int dest_len = (unsigned int)max_len;
     244            3 :     int rc = tinf_gzip_uncompress(out, &dest_len,
     245              :                                    gz_data, (unsigned int)gz_len);
     246              :     /* gz_data freed automatically by RAII_STRING */
     247              : 
     248            3 :     if (rc != TINF_OK) return -1;
     249              : 
     250            2 :     *out_len = dest_len;
     251            2 :     return 0;
     252              : }
     253              : 
     254              : /* ---- msg_container parse ---- */
     255              : 
     256            7 : int rpc_parse_container(const uint8_t *data, size_t len,
     257              :                         RpcContainerMsg *msgs, size_t max_msgs,
     258              :                         size_t *count) {
     259            7 :     if (!data || !msgs || !count) return -1;
     260            7 :     if (len < 4) return -1;
     261              : 
     262              :     uint32_t constructor;
     263            7 :     memcpy(&constructor, data, 4);
     264              : 
     265            7 :     if (constructor != CRC_msg_container) {
     266              :         /* Not a container — return as single message */
     267            1 :         if (max_msgs < 1) return -1;
     268            1 :         msgs[0].msg_id = 0;
     269            1 :         msgs[0].seqno = 0;
     270            1 :         msgs[0].body_len = (uint32_t)len;
     271            1 :         msgs[0].body = data;
     272            1 :         *count = 1;
     273            1 :         return 0;
     274              :     }
     275              : 
     276              :     /* Parse: constructor(4) + count(4) + messages[] */
     277            6 :     TlReader r = tl_reader_init(data, len);
     278            6 :     tl_read_uint32(&r); /* skip constructor */
     279            6 :     uint32_t msg_count = tl_read_uint32(&r);
     280              : 
     281            6 :     if (msg_count > max_msgs) return -1;
     282              : 
     283           13 :     for (uint32_t i = 0; i < msg_count; i++) {
     284            8 :         if (!tl_reader_ok(&r)) return -1;
     285              : 
     286            8 :         msgs[i].msg_id = tl_read_uint64(&r);
     287            8 :         msgs[i].seqno = tl_read_uint32(&r);
     288            8 :         msgs[i].body_len = tl_read_uint32(&r);
     289              : 
     290              :         /* TL-serialized bodies are always 4-byte aligned. An unaligned
     291              :          * body_len would misalign subsequent message reads in the container,
     292              :          * silently producing garbage data. Reject malformed containers. */
     293            8 :         if (msgs[i].body_len % 4 != 0) {
     294            1 :             logger_log(LOG_WARN,
     295              :                        "rpc_parse_container: unaligned body_len=%u "
     296              :                        "at message %u (must be multiple of 4)",
     297            1 :                        msgs[i].body_len, i);
     298            1 :             return -1;
     299              :         }
     300              : 
     301            7 :         if (msgs[i].body_len > len - r.pos) return -1;
     302              : 
     303            7 :         msgs[i].body = data + r.pos;
     304            7 :         tl_read_skip(&r, msgs[i].body_len);
     305              :     }
     306              : 
     307            5 :     *count = msg_count;
     308            5 :     return 0;
     309              : }
     310              : 
     311              : /* ---- rpc_result unwrap ---- */
     312              : 
     313          301 : int rpc_unwrap_result(const uint8_t *data, size_t len,
     314              :                       uint64_t *req_msg_id,
     315              :                       const uint8_t **inner, size_t *inner_len) {
     316          301 :     if (!data || !req_msg_id || !inner || !inner_len) return -1;
     317          301 :     if (len < 12) return -1; /* constructor(4) + msg_id(8) minimum */
     318              : 
     319              :     uint32_t constructor;
     320          301 :     memcpy(&constructor, data, 4);
     321          301 :     if (constructor != CRC_rpc_result) return -1;
     322              : 
     323          301 :     memcpy(req_msg_id, data + 4, 8);
     324          301 :     *inner = data + 12;
     325          301 :     *inner_len = len - 12;
     326          301 :     return 0;
     327              : }
     328              : 
     329              : /* ---- rpc_error parse ---- */
     330              : 
     331              : /** Extract trailing integer from error message (e.g. "FLOOD_WAIT_30" → 30). */
     332           18 : static int extract_trailing_int(const char *msg) {
     333           18 :     const char *p = msg + strlen(msg);
     334           39 :     while (p > msg && p[-1] >= '0' && p[-1] <= '9') p--;
     335           18 :     if (*p == '\0') return 0;
     336           18 :     return atoi(p);
     337              : }
     338              : 
     339           48 : int rpc_parse_error(const uint8_t *data, size_t len, RpcError *err) {
     340           48 :     if (!data || !err) return -1;
     341           48 :     if (len < 4) return -1;
     342              : 
     343              :     uint32_t constructor;
     344           48 :     memcpy(&constructor, data, 4);
     345           48 :     if (constructor != CRC_rpc_error) return -1;
     346              : 
     347           48 :     TlReader r = tl_reader_init(data, len);
     348           48 :     tl_read_uint32(&r); /* skip constructor */
     349              : 
     350           48 :     err->error_code = tl_read_int32(&r);
     351              : 
     352           96 :     RAII_STRING char *msg = tl_read_string(&r);
     353           48 :     if (!msg) {
     354            0 :         memset(err->error_msg, 0, sizeof(err->error_msg));
     355            0 :         err->migrate_dc = -1;
     356            0 :         err->flood_wait_secs = 0;
     357            0 :         return 0;
     358              :     }
     359              : 
     360              :     /* Copy message (truncate if needed) */
     361           48 :     size_t msg_len = strlen(msg);
     362           48 :     if (msg_len >= sizeof(err->error_msg))
     363            0 :         msg_len = sizeof(err->error_msg) - 1;
     364           48 :     memcpy(err->error_msg, msg, msg_len);
     365           48 :     err->error_msg[msg_len] = '\0';
     366              : 
     367              :     /* Parse derived fields */
     368           48 :     err->migrate_dc = -1;
     369           48 :     err->flood_wait_secs = 0;
     370              : 
     371           48 :     if (strncmp(msg, "PHONE_MIGRATE_", 14) == 0 ||
     372           42 :         strncmp(msg, "FILE_MIGRATE_", 13) == 0 ||
     373           36 :         strncmp(msg, "NETWORK_MIGRATE_", 16) == 0 ||
     374           34 :         strncmp(msg, "USER_MIGRATE_", 13) == 0) {
     375           15 :         err->migrate_dc = extract_trailing_int(msg);
     376           33 :     } else if (strncmp(msg, "FLOOD_WAIT_", 11) == 0) {
     377            3 :         err->flood_wait_secs = extract_trailing_int(msg);
     378              :     }
     379              : 
     380              :     /* msg freed automatically by RAII_STRING */
     381           48 :     return 0;
     382              : }
        

Generated by: LCOV version 2.0-1