LCOV - code coverage report
Current view: top level - src/infrastructure - mtproto_rpc.c (source / functions) Coverage Total Hit
Test: coverage.info Lines: 97.3 % 186 181
Test Date: 2026-04-20 19:54:22 Functions: 100.0 % 9 9

            Line data    Source code
       1              : /* SPDX-License-Identifier: GPL-3.0-or-later */
       2              : /* Copyright 2026 Peter Csaszar */
       3              : 
       4              : /**
       5              :  * @file mtproto_rpc.c
       6              :  * @brief MTProto RPC framework — message framing, encryption, decryption.
       7              :  */
       8              : 
       9              : #include "mtproto_rpc.h"
      10              : #include "mtproto_crypto.h"
      11              : #include "tl_serial.h"
      12              : #include "crypto.h"
      13              : #include "tinf.h"
      14              : #include "logger.h"
      15              : #include "raii.h"
      16              : 
#include <limits.h>
#include <stdlib.h>
#include <string.h>
      19              : 
      20              : /* Must be large enough to hold a full encrypted frame, including
      21              :  * the 512 KiB upload chunk used by saveBigFilePart + the ~56-byte
      22              :  * MTProto plaintext header + the outer auth_key_id / msg_key.
      23              :  * 1 MiB leaves plenty of headroom. */
      24              : #define RPC_BUF_SIZE (1024 * 1024)
      25              : 
      26              : #define CRC_gzip_packed    0x3072cfa1
      27              : #define CRC_msg_container  0x73f1f8dc
      28              : #define CRC_rpc_result     0xf35c6d01
      29              : #define CRC_rpc_error      0x2144ca19
      30              : 
      31              : /* ---- Unencrypted messages ---- */
      32              : 
      33           53 : int rpc_send_unencrypted(MtProtoSession *s, Transport *t,
      34              :                          const uint8_t *data, size_t len) {
      35           53 :     if (!s || !t || !data) return -1;
      36              : 
      37           50 :     uint64_t msg_id = mtproto_session_next_msg_id(s);
      38              : 
      39              :     /* Wire format: auth_key_id(8) + msg_id(8) + len(4) + data */
      40              :     TlWriter w;
      41           50 :     tl_writer_init(&w);
      42           50 :     tl_write_uint64(&w, 0);            /* auth_key_id = 0 */
      43           50 :     tl_write_uint64(&w, msg_id);
      44           50 :     tl_write_uint32(&w, (uint32_t)len);
      45           50 :     tl_write_raw(&w, data, len);
      46              : 
      47           50 :     int rc = transport_send(t, w.data, w.len);
      48           50 :     tl_writer_free(&w);
      49           50 :     return rc;
      50              : }
      51              : 
      52           54 : int rpc_recv_unencrypted(MtProtoSession *s, Transport *t,
      53              :                          uint8_t *out, size_t max_len, size_t *out_len) {
      54           54 :     if (!s || !t || !out || !out_len) return -1;
      55              : 
      56              :     /* Read one transport packet (heap-allocated to avoid stack overflow) */
      57           54 :     RAII_STRING uint8_t *buf = (uint8_t *)malloc(RPC_BUF_SIZE);
      58           54 :     if (!buf) return -1;
      59           54 :     size_t buf_len = 0;
      60           54 :     if (transport_recv(t, buf, RPC_BUF_SIZE, &buf_len) != 0) return -1;
      61              : 
      62              :     /* Parse: auth_key_id(8) + msg_id(8) + len(4) + data */
      63           50 :     if (buf_len < 20) return -1;
      64              : 
      65           50 :     TlReader r = tl_reader_init(buf, buf_len);
      66           50 :     uint64_t auth_key_id = tl_read_uint64(&r);
      67              :     (void)auth_key_id; /* should be 0 */
      68           50 :     uint64_t msg_id = tl_read_uint64(&r);
      69              :     (void)msg_id;
      70           50 :     uint32_t data_len = tl_read_uint32(&r);
      71              : 
      72           50 :     if (data_len > max_len) return -1;
      73           50 :     if (data_len > 0) {
      74           50 :         tl_read_raw(&r, out, data_len);
      75              :     }
      76           50 :     *out_len = data_len;
      77           50 :     return 0;
      78              : }
      79              : 
      80              : /* ---- Encrypted messages ---- */
      81              : 
      82          763 : int rpc_send_encrypted(MtProtoSession *s, Transport *t,
      83              :                        const uint8_t *data, size_t len,
      84              :                        int content_related) {
      85          763 :     if (!s || !t || !data || !s->has_auth_key) return -1;
      86              : 
      87          763 :     uint64_t msg_id = mtproto_session_next_msg_id(s);
      88          763 :     uint32_t seq_no = mtproto_session_next_seq_no(s, content_related);
      89              : 
      90              :     /* Build plaintext payload: salt(8) + session_id(8) + msg_id(8) + seq_no(4) + len(4) + data */
      91              :     TlWriter plain;
      92          763 :     tl_writer_init(&plain);
      93          763 :     tl_write_uint64(&plain, s->server_salt);
      94          763 :     tl_write_uint64(&plain, s->session_id);
      95          763 :     tl_write_uint64(&plain, msg_id);
      96          763 :     tl_write_uint32(&plain, seq_no);
      97          763 :     tl_write_uint32(&plain, (uint32_t)len);
      98          763 :     tl_write_raw(&plain, data, len);
      99              : 
     100              :     /* Encrypt with MTProto crypto. The msg_key is computed over the PADDED
     101              :      * plaintext inside mtproto_encrypt and returned via msg_key_out — the
     102              :      * spec requires that the wire msg_key match what was used to derive
     103              :      * the AES keys, so we cannot compute it independently here (the
     104              :      * padding bytes are random and generated inside mtproto_encrypt). */
     105          763 :     RAII_STRING uint8_t *encrypted = (uint8_t *)malloc(RPC_BUF_SIZE);
     106          763 :     if (!encrypted) { tl_writer_free(&plain); return -1; }
     107          763 :     size_t enc_len = 0;
     108              :     uint8_t msg_key[16];
     109          763 :     mtproto_encrypt(plain.data, plain.len, s->auth_key, 0,
     110              :                     encrypted, &enc_len, msg_key);
     111          763 :     tl_writer_free(&plain);
     112              : 
     113              :     /* Wire format: auth_key_id(8) + msg_key(16) + encrypted_data */
     114              :     TlWriter wire;
     115          763 :     tl_writer_init(&wire);
     116              : 
     117              :     /* auth_key_id = last 8 bytes of SHA256(auth_key). Use memcpy rather than
     118              :      * a pointer cast to avoid strict-aliasing UB, and let tl_write_uint64
     119              :      * handle little-endian encoding so we stay correct on big-endian hosts
     120              :      * (QA-11). */
     121              :     uint8_t key_hash[32];
     122          763 :     crypto_sha256(s->auth_key, 256, key_hash);
     123              :     uint64_t auth_key_id;
     124          763 :     memcpy(&auth_key_id, key_hash + 24, 8);
     125          763 :     tl_write_uint64(&wire, auth_key_id);
     126              : 
     127          763 :     tl_write_raw(&wire, msg_key, 16);
     128          763 :     tl_write_raw(&wire, encrypted, enc_len);
     129              :     /* encrypted freed automatically by RAII_STRING */
     130              : 
     131          763 :     int rc = transport_send(t, wire.data, wire.len);
     132          763 :     tl_writer_free(&wire);
     133          763 :     return rc;
     134              : }
     135              : 
     136          787 : int rpc_recv_encrypted(MtProtoSession *s, Transport *t,
     137              :                        uint8_t *out, size_t max_len, size_t *out_len) {
     138          787 :     if (!s || !t || !out || !out_len || !s->has_auth_key) return -1;
     139              : 
     140              :     /* Read one transport packet (heap-allocated) */
     141          787 :     RAII_STRING uint8_t *buf = (uint8_t *)malloc(RPC_BUF_SIZE);
     142          787 :     if (!buf) return -1;
     143          787 :     size_t buf_len = 0;
     144          787 :     if (transport_recv(t, buf, RPC_BUF_SIZE, &buf_len) != 0) return -1;
     145              : 
     146              :     /* Parse: auth_key_id(8) + msg_key(16) + encrypted_data */
     147          782 :     if (buf_len < 24) return -1;
     148              : 
     149          782 :     TlReader r = tl_reader_init(buf, buf_len);
     150          782 :     uint64_t recv_auth_key_id = tl_read_uint64(&r);
     151              : 
     152              :     uint8_t msg_key[16];
     153          782 :     tl_read_raw(&r, msg_key, 16);
     154              : 
     155          782 :     size_t cipher_len = buf_len - 24;
     156          782 :     const uint8_t *cipher = buf + 24;
     157              : 
     158              :     /* Verify auth_key_id: must equal SHA256(auth_key)[24:32]. */
     159              :     uint8_t key_hash[32];
     160          782 :     crypto_sha256(s->auth_key, 256, key_hash);
     161              :     uint64_t expected_auth_key_id;
     162          782 :     memcpy(&expected_auth_key_id, key_hash + 24, 8);
     163          782 :     if (recv_auth_key_id != expected_auth_key_id) {
     164            1 :         logger_log(LOG_ERROR,
     165              :                    "rpc_recv_encrypted: auth_key_id mismatch "
     166              :                    "(got %016llx, expected %016llx) — dropping frame",
     167              :                    (unsigned long long)recv_auth_key_id,
     168              :                    (unsigned long long)expected_auth_key_id);
     169            1 :         return -1;
     170              :     }
     171              : 
     172              :     /* Decrypt (heap-allocated) */
     173          781 :     RAII_STRING uint8_t *decrypted = (uint8_t *)malloc(RPC_BUF_SIZE);
     174          781 :     if (!decrypted) return -1;
     175          781 :     size_t dec_len = 0;
     176          781 :     int rc = mtproto_decrypt(cipher, cipher_len, s->auth_key, msg_key, 8,
     177              :                              decrypted, &dec_len);
     178          781 :     if (rc != 0) return -1;
     179              : 
     180              :     /* Parse plaintext: salt(8) + session_id(8) + msg_id(8) + seq_no(4) + len(4) + data */
     181          781 :     if (dec_len < 32) return -1;
     182              : 
     183          781 :     TlReader pr = tl_reader_init(decrypted, dec_len);
     184          781 :     tl_read_uint64(&pr); /* salt */
     185          781 :     uint64_t recv_session_id = tl_read_uint64(&pr);
     186              : 
     187              :     /* Verify session_id against the local session. */
     188          781 :     if (recv_session_id != s->session_id) {
     189            3 :         logger_log(LOG_ERROR,
     190              :                    "rpc_recv_encrypted: session_id mismatch "
     191              :                    "(got %016llx, expected %016llx) — dropping frame",
     192              :                    (unsigned long long)recv_session_id,
     193            3 :                    (unsigned long long)s->session_id);
     194            3 :         return -1;
     195              :     }
     196          778 :     tl_read_uint64(&pr); /* msg_id */
     197          778 :     tl_read_uint32(&pr); /* seq_no */
     198          778 :     uint32_t data_len = tl_read_uint32(&pr);
     199              : 
     200          778 :     if (data_len > max_len) return -1;
     201          778 :     if (data_len > 0) {
     202          778 :         tl_read_raw(&pr, out, data_len);
     203              :     }
     204          778 :     *out_len = data_len;
     205              :     /* buf and decrypted freed automatically by RAII_STRING */
     206          778 :     return 0;
     207              : }
     208              : 
     209              : /* ---- gzip_packed unwrap ---- */
     210              : 
     211          736 : int rpc_unwrap_gzip(const uint8_t *data, size_t len,
     212              :                     uint8_t *out, size_t max_len, size_t *out_len) {
     213          736 :     if (!data || !out || !out_len) return -1;
     214          733 :     if (len < 4) {
     215              :         /* Too short to contain a constructor — copy as-is */
     216            1 :         if (len > max_len) return -1;
     217            1 :         memcpy(out, data, len);
     218            1 :         *out_len = len;
     219            1 :         return 0;
     220              :     }
     221              : 
     222              :     /* Check for gzip_packed constructor */
     223              :     uint32_t constructor;
     224          732 :     memcpy(&constructor, data, 4);
     225              : 
     226          732 :     if (constructor != CRC_gzip_packed) {
     227              :         /* Not gzip_packed — copy unchanged */
     228          723 :         if (len > max_len) return -1;
     229          722 :         memcpy(out, data, len);
     230          722 :         *out_len = len;
     231          722 :         return 0;
     232              :     }
     233              : 
     234              :     /* Parse: constructor(4) + bytes(TL-encoded compressed data) */
     235            9 :     TlReader r = tl_reader_init(data, len);
     236            9 :     tl_read_uint32(&r); /* skip constructor */
     237              : 
     238            9 :     size_t gz_len = 0;
     239           18 :     RAII_STRING uint8_t *gz_data = tl_read_bytes(&r, &gz_len);
     240            9 :     if (!gz_data || gz_len == 0) return -1;
     241              : 
     242              :     /* Decompress with tinf */
     243            8 :     unsigned int dest_len = (unsigned int)max_len;
     244            8 :     int rc = tinf_gzip_uncompress(out, &dest_len,
     245              :                                    gz_data, (unsigned int)gz_len);
     246              :     /* gz_data freed automatically by RAII_STRING */
     247              : 
     248            8 :     if (rc != TINF_OK) return -1;
     249              : 
     250            5 :     *out_len = dest_len;
     251            5 :     return 0;
     252              : }
     253              : 
     254              : /* ---- msg_container parse ---- */
     255              : 
     256           22 : int rpc_parse_container(const uint8_t *data, size_t len,
     257              :                         RpcContainerMsg *msgs, size_t max_msgs,
     258              :                         size_t *count) {
     259           22 :     if (!data || !msgs || !count) return -1;
     260           19 :     if (len < 4) return -1;
     261              : 
     262              :     uint32_t constructor;
     263           19 :     memcpy(&constructor, data, 4);
     264              : 
     265           19 :     if (constructor != CRC_msg_container) {
     266              :         /* Not a container — return as single message */
     267            3 :         if (max_msgs < 1) return -1;
     268            3 :         msgs[0].msg_id = 0;
     269            3 :         msgs[0].seqno = 0;
     270            3 :         msgs[0].body_len = (uint32_t)len;
     271            3 :         msgs[0].body = data;
     272            3 :         *count = 1;
     273            3 :         return 0;
     274              :     }
     275              : 
     276              :     /* Parse: constructor(4) + count(4) + messages[] */
     277           16 :     TlReader r = tl_reader_init(data, len);
     278           16 :     tl_read_uint32(&r); /* skip constructor */
     279           16 :     uint32_t msg_count = tl_read_uint32(&r);
     280              : 
     281           16 :     if (msg_count > max_msgs) return -1;
     282              : 
     283           33 :     for (uint32_t i = 0; i < msg_count; i++) {
     284           21 :         if (!tl_reader_ok(&r)) return -1;
     285              : 
     286           21 :         msgs[i].msg_id = tl_read_uint64(&r);
     287           21 :         msgs[i].seqno = tl_read_uint32(&r);
     288           21 :         msgs[i].body_len = tl_read_uint32(&r);
     289              : 
     290              :         /* TL-serialized bodies are always 4-byte aligned. An unaligned
     291              :          * body_len would misalign subsequent message reads in the container,
     292              :          * silently producing garbage data. Reject malformed containers. */
     293           21 :         if (msgs[i].body_len % 4 != 0) {
     294            3 :             logger_log(LOG_WARN,
     295              :                        "rpc_parse_container: unaligned body_len=%u "
     296              :                        "at message %u (must be multiple of 4)",
     297            3 :                        msgs[i].body_len, i);
     298            3 :             return -1;
     299              :         }
     300              : 
     301           18 :         if (msgs[i].body_len > len - r.pos) return -1;
     302              : 
     303           18 :         msgs[i].body = data + r.pos;
     304           18 :         tl_read_skip(&r, msgs[i].body_len);
     305              :     }
     306              : 
     307           12 :     *count = msg_count;
     308           12 :     return 0;
     309              : }
     310              : 
     311              : /* ---- rpc_result unwrap ---- */
     312              : 
     313          734 : int rpc_unwrap_result(const uint8_t *data, size_t len,
     314              :                       uint64_t *req_msg_id,
     315              :                       const uint8_t **inner, size_t *inner_len) {
     316          734 :     if (!data || !req_msg_id || !inner || !inner_len) return -1;
     317          734 :     if (len < 12) return -1; /* constructor(4) + msg_id(8) minimum */
     318              : 
     319              :     uint32_t constructor;
     320          695 :     memcpy(&constructor, data, 4);
     321          695 :     if (constructor != CRC_rpc_result) return -1;
     322              : 
     323          611 :     memcpy(req_msg_id, data + 4, 8);
     324          611 :     *inner = data + 12;
     325          611 :     *inner_len = len - 12;
     326          611 :     return 0;
     327              : }
     328              : 
     329              : /* ---- rpc_error parse ---- */
     330              : 
     331              : /** Extract trailing integer from error message (e.g. "FLOOD_WAIT_30" → 30). */
     332           42 : static int extract_trailing_int(const char *msg) {
     333           42 :     const char *p = msg + strlen(msg);
     334           93 :     while (p > msg && p[-1] >= '0' && p[-1] <= '9') p--;
     335           42 :     if (*p == '\0') return 0;
     336           42 :     return atoi(p);
     337              : }
     338              : 
     339          125 : int rpc_parse_error(const uint8_t *data, size_t len, RpcError *err) {
     340          125 :     if (!data || !err) return -1;
     341          123 :     if (len < 4) return -1;
     342              : 
     343              :     uint32_t constructor;
     344          123 :     memcpy(&constructor, data, 4);
     345          123 :     if (constructor != CRC_rpc_error) return -1;
     346              : 
     347          122 :     TlReader r = tl_reader_init(data, len);
     348          122 :     tl_read_uint32(&r); /* skip constructor */
     349              : 
     350          122 :     err->error_code = tl_read_int32(&r);
     351              : 
     352          244 :     RAII_STRING char *msg = tl_read_string(&r);
     353          122 :     if (!msg) {
     354            0 :         memset(err->error_msg, 0, sizeof(err->error_msg));
     355            0 :         err->migrate_dc = -1;
     356            0 :         err->flood_wait_secs = 0;
     357            0 :         return 0;
     358              :     }
     359              : 
     360              :     /* Copy message (truncate if needed) */
     361          122 :     size_t msg_len = strlen(msg);
     362          122 :     if (msg_len >= sizeof(err->error_msg))
     363            0 :         msg_len = sizeof(err->error_msg) - 1;
     364          122 :     memcpy(err->error_msg, msg, msg_len);
     365          122 :     err->error_msg[msg_len] = '\0';
     366              : 
     367              :     /* Parse derived fields */
     368          122 :     err->migrate_dc = -1;
     369          122 :     err->flood_wait_secs = 0;
     370              : 
     371          122 :     if (strncmp(msg, "PHONE_MIGRATE_", 14) == 0 ||
     372          109 :         strncmp(msg, "FILE_MIGRATE_", 13) == 0 ||
     373           95 :         strncmp(msg, "NETWORK_MIGRATE_", 16) == 0 ||
     374           91 :         strncmp(msg, "USER_MIGRATE_", 13) == 0) {
     375           33 :         err->migrate_dc = extract_trailing_int(msg);
     376           89 :     } else if (strncmp(msg, "FLOOD_WAIT_", 11) == 0) {
     377            9 :         err->flood_wait_secs = extract_trailing_int(msg);
     378              :     }
     379              : 
     380              :     /* msg freed automatically by RAII_STRING */
     381          122 :     return 0;
     382              : }
        

Generated by: LCOV version 2.0-1