LCOV - code coverage report
Current view: top level - tests/functional/pty - pty_tel_stub.c (source / functions) Coverage Total Hit
Test: coverage.info Lines: 90.6 % 234 212
Test Date: 2026-04-20 19:54:22 Functions: 100.0 % 14 14

            Line data    Source code
       1              : /**
       2              :  * @file pty_tel_stub.c
       3              :  * @brief Minimal TCP MTProto stub server implementation — see pty_tel_stub.h.
       4              :  */
       5              : 
       6              : #define _DEFAULT_SOURCE
       7              : #define _POSIX_C_SOURCE 200809L
       8              : 
       9              : #include "pty_tel_stub.h"
      10              : 
      11              : /* Project headers (available because the CMakeLists links tg-proto + tg-app) */
      12              : #include "crypto.h"
      13              : #include "mtproto_crypto.h"
      14              : #include "mtproto_session.h"
      15              : #include "tl_serial.h"
      16              : #include "app/session_store.h"
      17              : 
      18              : #include <sys/socket.h>
      19              : #include <netinet/in.h>
      20              : #include <arpa/inet.h>
      21              : #include <unistd.h>
      22              : #include <errno.h>
      23              : #include <stdio.h>
      24              : #include <stdlib.h>
      25              : #include <string.h>
      26              : #include <time.h>
      27              : 
      28              : /* ── TL / MTProto constants ─────────────────────────────────────────── */
      29              : 
/* TL constructor / method identifiers recognised or emitted by the stub.
 * They are read and written as little-endian 32-bit words on the wire. */
#define CRC_invokeWithLayer       0xda9b0d0dU
#define CRC_initConnection        0xc1cd5ea9U
#define CRC_rpc_result            0xf35c6d01U
#define CRC_rpc_error             0x2144ca19U
#define CRC_messages_getDialogs   0xa0f4cb4fU
#define CRC_updates_getState      0xedd4882aU
#define CRC_updates_getDifference 0x19c2f763U
#define TL_messages_dialogs       0x15ba6c40U
#define TL_vector                 0x1cb5c415U

/* Upper bound (bytes) on a single abridged frame accepted by read_frame(). */
#define FRAME_MAX  (256 * 1024)
/* Auth key length, taken from the public header. */
#define AUTH_KEY_SIZE PTY_STUB_AUTH_KEY_SIZE
      42              : 
      43              : /* ── Internal per-connection context ────────────────────────────────── */
      44              : 
      45              : typedef struct {
      46              :     PtyTelStub *stub;
      47              :     int         fd;
      48              :     uint64_t    next_msg_id;
      49              :     uint32_t    seq_no;
      50              : } ConnCtx;
      51              : 
      52              : /* ── Helpers ─────────────────────────────────────────────────────────── */
      53              : 
      54           12 : static uint64_t derive_auth_key_id(const uint8_t *key) {
      55              :     uint8_t hash[32];
      56           12 :     crypto_sha256(key, AUTH_KEY_SIZE, hash);
      57           12 :     uint64_t id = 0;
      58          108 :     for (int i = 0; i < 8; ++i) id |= ((uint64_t)hash[24 + i]) << (i * 8);
      59           12 :     return id;
      60              : }
      61              : 
      62           10 : static uint64_t make_server_msg_id(ConnCtx *c) {
      63           10 :     uint64_t now = (uint64_t)time(NULL) << 32;
      64           10 :     if (now <= c->next_msg_id) now = c->next_msg_id + 4;
      65           10 :     now &= ~((uint64_t)3);
      66           10 :     now |= 1;
      67           10 :     c->next_msg_id = now;
      68           10 :     return now;
      69              : }
      70              : 
      71              : /* Blocking full-read from fd */
      72           34 : static int read_all(int fd, uint8_t *buf, size_t len) {
      73           34 :     size_t done = 0;
      74           61 :     while (done < len) {
      75           34 :         ssize_t r = recv(fd, buf + done, len - done, 0);
      76           34 :         if (r <= 0) return -1;
      77           27 :         done += (size_t)r;
      78              :     }
      79           27 :     return 0;
      80              : }
      81              : 
      82              : /* Blocking full-write to fd */
      83           20 : static int write_all(int fd, const uint8_t *buf, size_t len) {
      84           20 :     size_t done = 0;
      85           40 :     while (done < len) {
      86           20 :         ssize_t r = send(fd, buf + done, len - done, 0);
      87           20 :         if (r <= 0) return -1;
      88           20 :         done += (size_t)r;
      89              :     }
      90           20 :     return 0;
      91              : }
      92              : 
      93              : /* ── Frame I/O ───────────────────────────────────────────────────────── */
      94              : 
      95              : /** Read one abridged-framed MTProto payload from fd into *buf_out (malloc'd). */
      96           17 : static int read_frame(int fd, uint8_t **buf_out, size_t *len_out) {
      97              :     uint8_t first;
      98           17 :     if (read_all(fd, &first, 1) != 0) return -1;
      99              : 
     100              :     size_t units;
     101           10 :     if (first < 0x7F) {
     102           10 :         units = first;
     103              :     } else {
     104              :         uint8_t extra[3];
     105            0 :         if (read_all(fd, extra, 3) != 0) return -1;
     106            0 :         units = (size_t)extra[0] | ((size_t)extra[1] << 8) | ((size_t)extra[2] << 16);
     107              :     }
     108           10 :     size_t len = units * 4;
     109           10 :     if (len == 0 || len > FRAME_MAX) return -1;
     110              : 
     111           10 :     uint8_t *buf = (uint8_t *)malloc(len);
     112           10 :     if (!buf) return -1;
     113           10 :     if (read_all(fd, buf, len) != 0) { free(buf); return -1; }
     114              : 
     115           10 :     *buf_out = buf;
     116           10 :     *len_out = len;
     117           10 :     return 0;
     118              : }
     119              : 
     120              : /** Write a plaintext body as an abridged-framed encrypted MTProto message. */
     121           10 : static int send_frame(ConnCtx *c, const uint8_t *body, size_t body_len) {
     122              :     /* Build plaintext wrapper */
     123              :     TlWriter plain;
     124           10 :     tl_writer_init(&plain);
     125           10 :     tl_write_uint64(&plain, c->stub->server_salt);
     126           10 :     tl_write_uint64(&plain, c->stub->session_id);
     127           10 :     tl_write_uint64(&plain, make_server_msg_id(c));
     128           10 :     c->seq_no += 2;
     129           10 :     tl_write_uint32(&plain, c->seq_no - 1);
     130           10 :     tl_write_uint32(&plain, (uint32_t)body_len);
     131           10 :     tl_write_raw(&plain, body, body_len);
     132              : 
     133              :     /* Encrypt */
     134              :     uint8_t msg_key[16];
     135           10 :     size_t enc_len = 0;
     136           10 :     uint8_t *enc = (uint8_t *)malloc(plain.len + 1024);
     137           10 :     if (!enc) { tl_writer_free(&plain); return -1; }
     138           10 :     mtproto_encrypt(plain.data, plain.len, c->stub->auth_key, 8,
     139              :                     enc, &enc_len, msg_key);
     140           10 :     tl_writer_free(&plain);
     141              : 
     142              :     /* Build wire frame: auth_key_id(8) + msg_key(16) + enc */
     143           10 :     size_t wire_len = 8 + 16 + enc_len;
     144           10 :     uint8_t *wire = (uint8_t *)malloc(wire_len);
     145           10 :     if (!wire) { free(enc); return -1; }
     146           90 :     for (int i = 0; i < 8; ++i)
     147           80 :         wire[i] = (uint8_t)((c->stub->auth_key_id >> (i * 8)) & 0xFF);
     148           10 :     memcpy(wire + 8, msg_key, 16);
     149           10 :     memcpy(wire + 24, enc, enc_len);
     150           10 :     free(enc);
     151              : 
     152              :     /* Abridged length prefix */
     153           10 :     size_t units = wire_len / 4;
     154              :     int rc;
     155           10 :     if (units < 0x7F) {
     156           10 :         uint8_t p = (uint8_t)units;
     157           10 :         rc = write_all(c->fd, &p, 1);
     158              :     } else {
     159            0 :         uint8_t p[4] = { 0x7F,
     160            0 :             (uint8_t)(units & 0xFF),
     161            0 :             (uint8_t)((units >> 8) & 0xFF),
     162            0 :             (uint8_t)((units >> 16) & 0xFF) };
     163            0 :         rc = write_all(c->fd, p, 4);
     164              :     }
     165           10 :     if (rc == 0) rc = write_all(c->fd, wire, wire_len);
     166           10 :     free(wire);
     167           10 :     return rc;
     168              : }
     169              : 
     170              : /* ── RPC helpers ─────────────────────────────────────────────────────── */
     171              : 
     172           10 : static void reply_rpc_result(ConnCtx *c, uint64_t req_msg_id,
     173              :                               const uint8_t *result, size_t result_len) {
     174           10 :     size_t total = 4 + 8 + result_len;
     175           10 :     uint8_t *wrapped = (uint8_t *)malloc(total);
     176           10 :     if (!wrapped) return;
     177           10 :     wrapped[0] = (uint8_t)(CRC_rpc_result);
     178           10 :     wrapped[1] = (uint8_t)(CRC_rpc_result >> 8);
     179           10 :     wrapped[2] = (uint8_t)(CRC_rpc_result >> 16);
     180           10 :     wrapped[3] = (uint8_t)(CRC_rpc_result >> 24);
     181           90 :     for (int i = 0; i < 8; ++i)
     182           80 :         wrapped[4 + i] = (uint8_t)((req_msg_id >> (i * 8)) & 0xFF);
     183           10 :     memcpy(wrapped + 12, result, result_len);
     184           10 :     send_frame(c, wrapped, total);
     185           10 :     free(wrapped);
     186              : }
     187              : 
     188            5 : static void reply_error(ConnCtx *c, uint64_t req_msg_id,
     189              :                          int32_t code, const char *msg) {
     190              :     TlWriter w;
     191            5 :     tl_writer_init(&w);
     192            5 :     tl_write_uint32(&w, CRC_rpc_error);
     193            5 :     tl_write_int32(&w, code);
     194            5 :     tl_write_string(&w, msg ? msg : "");
     195            5 :     reply_rpc_result(c, req_msg_id, w.data, w.len);
     196            5 :     tl_writer_free(&w);
     197            5 : }
     198              : 
     199              : /** Respond with empty messages.dialogs. */
     200            5 : static void reply_empty_dialogs(ConnCtx *c, uint64_t req_msg_id) {
     201              :     TlWriter w;
     202            5 :     tl_writer_init(&w);
     203            5 :     tl_write_uint32(&w, TL_messages_dialogs);
     204              :     /* dialogs: Vector<Dialog> — empty */
     205            5 :     tl_write_uint32(&w, TL_vector);
     206            5 :     tl_write_uint32(&w, 0);
     207              :     /* messages: Vector<Message> — empty */
     208            5 :     tl_write_uint32(&w, TL_vector);
     209            5 :     tl_write_uint32(&w, 0);
     210              :     /* chats: Vector<Chat> — empty */
     211            5 :     tl_write_uint32(&w, TL_vector);
     212            5 :     tl_write_uint32(&w, 0);
     213              :     /* users: Vector<User> — empty */
     214            5 :     tl_write_uint32(&w, TL_vector);
     215            5 :     tl_write_uint32(&w, 0);
     216            5 :     reply_rpc_result(c, req_msg_id, w.data, w.len);
     217            5 :     tl_writer_free(&w);
     218            5 : }
     219              : 
     220              : /* ── Frame dispatcher ────────────────────────────────────────────────── */
     221              : 
     222           10 : static void dispatch(ConnCtx *c, uint64_t req_msg_id,
     223              :                      const uint8_t *body, size_t body_len) {
     224           10 :     if (body_len < 4) return;
     225              : 
     226              :     /* Skip invokeWithLayer / initConnection wrappers (up to 3 levels deep) */
     227           10 :     const uint8_t *cur = body;
     228           10 :     size_t remaining = body_len;
     229           30 :     for (int depth = 0; depth < 3; ++depth) {
     230           30 :         if (remaining < 4) return;
     231           30 :         uint32_t crc = (uint32_t)cur[0] | ((uint32_t)cur[1] << 8)
     232           30 :                      | ((uint32_t)cur[2] << 16) | ((uint32_t)cur[3] << 24);
     233           30 :         if (crc == CRC_invokeWithLayer) {
     234           10 :             if (remaining < 8) return;
     235           10 :             cur += 8; remaining -= 8; /* CRC(4) + layer(4) */
     236           10 :             continue;
     237              :         }
     238           20 :         if (crc == CRC_initConnection) {
     239              :             /* Skip: CRC(4) flags(4) api_id(4) 6×string query */
     240           10 :             if (remaining < 16) return;
     241           10 :             size_t skip = 12; /* CRC + flags + api_id */
     242              :             /* Skip 6 TL strings (device_model, sys_version, app_version,
     243              :              * system_lang_code, lang_pack, lang_code) */
     244           70 :             for (int s = 0; s < 6 && skip < remaining; ++s) {
     245           60 :                 uint8_t slen = cur[skip++];
     246           60 :                 if (slen == 0xFE) {
     247            0 :                     if (skip + 3 > remaining) return;
     248            0 :                     size_t n3 = (size_t)cur[skip] | ((size_t)cur[skip+1] << 8)
     249            0 :                               | ((size_t)cur[skip+2] << 16);
     250            0 :                     skip += 3 + n3;
     251              :                     /* align to 4 bytes from start of string prefix */
     252            0 :                     size_t total_str = 4 + n3;
     253            0 :                     size_t pad = (4 - (total_str % 4)) % 4;
     254            0 :                     skip += pad;
     255              :                 } else {
     256           60 :                     skip += (size_t)slen;
     257              :                     /* align: total = 1 + slen, pad to multiple of 4 */
     258           60 :                     size_t total_str = 1 + (size_t)slen;
     259           60 :                     size_t pad = (4 - (total_str % 4)) % 4;
     260           60 :                     skip += pad;
     261              :                 }
     262              :             }
     263           10 :             if (skip > remaining) return;
     264           10 :             cur += skip; remaining -= skip;
     265           10 :             continue;
     266              :         }
     267           10 :         break;
     268              :     }
     269              : 
     270           10 :     if (remaining < 4) return;
     271           10 :     uint32_t inner_crc = (uint32_t)cur[0] | ((uint32_t)cur[1] << 8)
     272           10 :                        | ((uint32_t)cur[2] << 16) | ((uint32_t)cur[3] << 24);
     273              : 
     274           10 :     if (inner_crc == CRC_messages_getDialogs) {
     275            5 :         reply_empty_dialogs(c, req_msg_id);
     276            5 :     } else if (inner_crc == CRC_updates_getState) {
     277            5 :         reply_error(c, req_msg_id, 400, "STUB_NOT_SUPPORTED");
     278            0 :     } else if (inner_crc == CRC_updates_getDifference) {
     279            0 :         reply_error(c, req_msg_id, 400, "STUB_NOT_SUPPORTED");
     280              :     } else {
     281              :         /* Unknown RPC — return an error so the client doesn't hang on recv. */
     282              :         char msg[64];
     283            0 :         snprintf(msg, sizeof(msg), "STUB_UNKNOWN_%08x", inner_crc);
     284            0 :         reply_error(c, req_msg_id, 500, msg);
     285              :     }
     286              : }
     287              : 
     288              : /* ── Connection handler ──────────────────────────────────────────────── */
     289              : 
     290            7 : static void handle_connection(ConnCtx *c) {
     291              :     /* 1. Read the 0xEF abridged marker. */
     292              :     uint8_t marker;
     293            7 :     if (read_all(c->fd, &marker, 1) != 0 || marker != 0xEF) return;
     294              : 
     295              :     /* 2. Read encrypted frames until the connection closes or we decide to stop. */
     296           10 :     for (;;) {
     297           17 :         uint8_t *frame = NULL;
     298           17 :         size_t frame_len = 0;
     299           17 :         if (read_frame(c->fd, &frame, &frame_len) != 0) break;
     300              : 
     301              :         /* Validate auth_key_id */
     302           10 :         if (frame_len < 24) { free(frame); break; }
     303           10 :         uint64_t key_id = 0;
     304           90 :         for (int i = 0; i < 8; ++i) key_id |= ((uint64_t)frame[i]) << (i * 8);
     305           10 :         if (key_id != c->stub->auth_key_id) { free(frame); continue; }
     306              : 
     307              :         /* Decrypt */
     308              :         uint8_t msg_key[16];
     309           10 :         memcpy(msg_key, frame + 8, 16);
     310           10 :         const uint8_t *cipher = frame + 24;
     311           10 :         size_t cipher_len = frame_len - 24;
     312           10 :         uint8_t *plain = (uint8_t *)malloc(cipher_len);
     313           10 :         if (!plain) { free(frame); break; }
     314           10 :         size_t plain_len = 0;
     315           10 :         int rc = mtproto_decrypt(cipher, cipher_len,
     316           10 :                                   c->stub->auth_key, msg_key, 0,
     317              :                                   plain, &plain_len);
     318           10 :         free(frame);
     319           10 :         if (rc != 0) { free(plain); continue; }
     320              : 
     321              :         /* Extract msg_id and body from plaintext header:
     322              :          * salt(8) + session_id(8) + msg_id(8) + seq_no(4) + body_len(4) + body */
     323           10 :         if (plain_len < 32) { free(plain); continue; }
     324           10 :         uint64_t msg_id = 0;
     325           90 :         for (int i = 0; i < 8; ++i) msg_id |= ((uint64_t)plain[16 + i]) << (i * 8);
     326           10 :         uint32_t body_len = 0;
     327           50 :         for (int i = 0; i < 4; ++i) body_len |= ((uint32_t)plain[28 + i]) << (i * 8);
     328           10 :         if (32 + body_len > plain_len) { free(plain); continue; }
     329              : 
     330           10 :         dispatch(c, msg_id, plain + 32, body_len);
     331           10 :         free(plain);
     332              :     }
     333              : }
     334              : 
     335              : /* ── Server thread ───────────────────────────────────────────────────── */
     336              : 
     337           12 : static void *server_thread(void *arg) {
     338           12 :     PtyTelStub *s = (PtyTelStub *)arg;
     339              : 
     340           12 :     struct timeval tv = { .tv_sec = 6, .tv_usec = 0 };
     341           12 :     int client_fd = accept(s->listen_fd, NULL, NULL);
     342            7 :     if (client_fd < 0) return NULL;
     343              : 
     344              :     /* 6-second read timeout on the client socket so the thread exits cleanly
     345              :      * when the binary quits and stops sending. */
     346            7 :     setsockopt(client_fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
     347              : 
     348            7 :     ConnCtx ctx = {
     349              :         .stub        = s,
     350              :         .fd          = client_fd,
     351              :         .next_msg_id = 0,
     352              :         .seq_no      = 0,
     353              :     };
     354            7 :     handle_connection(&ctx);
     355            7 :     close(client_fd);
     356            7 :     return NULL;
     357              : }
     358              : 
     359              : /* ── Public API ──────────────────────────────────────────────────────── */
     360              : 
     361           12 : int pty_tel_stub_start(PtyTelStub *s) {
     362           12 :     if (!s) return -1;
     363           12 :     memset(s, 0, sizeof(*s));
     364           12 :     s->listen_fd = -1;
     365              : 
     366              :     /* Deterministic auth_key — identical to mt_server_seed_session(). */
     367         3084 :     for (int i = 0; i < AUTH_KEY_SIZE; ++i)
     368         3072 :         s->auth_key[i] = (uint8_t)((i * 31 + 7) & 0xFF);
     369           12 :     s->auth_key_id = derive_auth_key_id(s->auth_key);
     370           12 :     s->server_salt = 0xABCDEF0123456789ULL;
     371           12 :     s->session_id  = 0x1122334455667788ULL;
     372              : 
     373              :     /* Seed session.bin — caller must have set HOME to a tmp dir first. */
     374              :     MtProtoSession ms;
     375           12 :     mtproto_session_init(&ms);
     376           12 :     mtproto_session_set_auth_key(&ms, s->auth_key);
     377           12 :     mtproto_session_set_salt(&ms, s->server_salt);
     378           12 :     ms.session_id = s->session_id;
     379           12 :     if (session_store_save(&ms, 2) != 0) return -1;
     380              : 
     381              :     /* Bind to an OS-assigned port on localhost. */
     382           12 :     s->listen_fd = socket(AF_INET, SOCK_STREAM, 0);
     383           12 :     if (s->listen_fd < 0) return -1;
     384              : 
     385           12 :     int opt = 1;
     386           12 :     setsockopt(s->listen_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));
     387              : 
     388              :     struct sockaddr_in addr;
     389           12 :     memset(&addr, 0, sizeof(addr));
     390           12 :     addr.sin_family      = AF_INET;
     391           12 :     addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
     392           12 :     addr.sin_port        = 0; /* OS assigns */
     393              : 
     394           12 :     if (bind(s->listen_fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
     395            0 :         close(s->listen_fd); s->listen_fd = -1; return -1;
     396              :     }
     397           12 :     if (listen(s->listen_fd, 1) < 0) {
     398            0 :         close(s->listen_fd); s->listen_fd = -1; return -1;
     399              :     }
     400              : 
     401              :     /* Discover the assigned port. */
     402           12 :     socklen_t len = sizeof(addr);
     403           12 :     if (getsockname(s->listen_fd, (struct sockaddr *)&addr, &len) < 0) {
     404            0 :         close(s->listen_fd); s->listen_fd = -1; return -1;
     405              :     }
     406           12 :     s->port = ntohs(addr.sin_port);
     407              : 
     408              :     /* Set accept() timeout so the thread doesn't block forever if no client
     409              :      * ever connects. */
     410           12 :     struct timeval tv = { .tv_sec = 8, .tv_usec = 0 };
     411           12 :     setsockopt(s->listen_fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
     412              : 
     413           12 :     s->running = 1;
     414           12 :     if (pthread_create(&s->thread, NULL, server_thread, s) != 0) {
     415            0 :         close(s->listen_fd); s->listen_fd = -1; return -1;
     416              :     }
     417           12 :     return 0;
     418              : }
     419              : 
     420            7 : void pty_tel_stub_stop(PtyTelStub *s) {
     421            7 :     if (!s || !s->running) return;
     422            7 :     s->running = 0;
     423            7 :     if (s->listen_fd >= 0) {
     424            7 :         close(s->listen_fd);
     425            7 :         s->listen_fd = -1;
     426              :     }
     427            7 :     pthread_join(s->thread, NULL);
     428              : }
        

Generated by: LCOV version 2.0-1