diff --git a/include/chat_room.h b/include/chat_room.h index 9b6c7ca..0b50d5f 100644 --- a/include/chat_room.h +++ b/include/chat_room.h @@ -39,8 +39,8 @@ void room_broadcast(chat_room_t *room, const message_t *msg); /* Add message to room history */ void room_add_message(chat_room_t *room, const message_t *msg); -/* Get message by index */ -const message_t* room_get_message(chat_room_t *room, int index); +/* Get message by index (thread-safe value copy) */ +bool room_get_message(chat_room_t *room, int index, message_t *out); /* Get total message count */ int room_get_message_count(chat_room_t *room); diff --git a/include/common.h b/include/common.h index 8859b12..5d10ca0 100644 --- a/include/common.h +++ b/include/common.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include diff --git a/include/ssh_server.h b/include/ssh_server.h index baa222e..bf77b9e 100644 --- a/include/ssh_server.h +++ b/include/ssh_server.h @@ -9,7 +9,6 @@ /* Client connection structure */ typedef struct client { - int fd; /* Socket file descriptor (not used with SSH) */ ssh_session session; /* SSH session */ ssh_channel channel; /* SSH channel */ char username[MAX_USERNAME_LEN]; @@ -25,9 +24,9 @@ typedef struct client { char command_output[2048]; char exec_command[MAX_EXEC_COMMAND_LEN]; char ssh_login[MAX_USERNAME_LEN]; - bool redraw_pending; + atomic_bool redraw_pending; pthread_t thread; - bool connected; + atomic_bool connected; int ref_count; /* Reference count for safe cleanup */ pthread_mutex_t ref_lock; /* Lock for ref_count */ pthread_mutex_t io_lock; /* Serialize SSH channel writes */ diff --git a/src/chat_room.c b/src/chat_room.c index bc1343c..0cd10f4 100644 --- a/src/chat_room.c +++ b/src/chat_room.c @@ -10,12 +10,13 @@ static int room_capacity_from_env(void) { return MAX_CLIENTS; } - int capacity = atoi(env); - if (capacity < 1 || capacity > 1024) { + char *end; + long capacity = strtol(env, &end, 10); + if (*end != '\0' || capacity < 1 || capacity > 1024) { return MAX_CLIENTS; } - return capacity; + return (int)capacity; } /* Initialize chat room */ @@ -111,17 +112,20 @@ void room_add_message(chat_room_t *room, const message_t *msg) { room->messages[room->message_count++] = *msg; } -/* Get message by index */ -const message_t* room_get_message(chat_room_t *room, int index) { +/* Get message by index (thread-safe value copy) */ +bool room_get_message(chat_room_t *room, int index, message_t *out) { + if (!room || !out) return false; + pthread_rwlock_rdlock(&room->lock); - const message_t *msg = NULL; + bool found = false; if (index >= 0 && index < room->message_count) { - msg = &room->messages[index]; + *out = room->messages[index]; + found = true; } pthread_rwlock_unlock(&room->lock); - return msg; + return found; } /* Get total message count */ diff --git a/src/main.c b/src/main.c index 53b14b5..91f72d4 100644 --- a/src/main.c +++ b/src/main.c @@ -21,14 +21,24 @@ int main(int argc, char **argv) { /* Environment provides defaults; command-line flags override it. */ const char *port_env = getenv("PORT"); - if (port_env) { - port = atoi(port_env); + if (port_env && port_env[0] != '\0') { + char *end; + long val = strtol(port_env, &end, 10); + if (*end == '\0' && val > 0 && val <= 65535) { + port = (int)val; + } } /* Parse command line arguments */ for (int i = 1; i < argc; i++) { if (strcmp(argv[i], "-p") == 0 && i + 1 < argc) { - port = atoi(argv[i + 1]); + char *end; + long val = strtol(argv[i + 1], &end, 10); + if (*end != '\0' || val <= 0 || val > 65535) { + fprintf(stderr, "Invalid port: %s\n", argv[i + 1]); + return 1; + } + port = (int)val; i++; } else if ((strcmp(argv[i], "-d") == 0 || strcmp(argv[i], "--state-dir") == 0) && i + 1 < argc) { diff --git a/src/message.c b/src/message.c index 7c5fce5..49a49b5 100644 --- a/src/message.c +++ b/src/message.c @@ -1,3 +1,4 @@ +#define _DEFAULT_SOURCE /* for timegm() on glibc */ #include "message.h" #include "utf8.h" #include @@ -7,44 +8,17 @@ static pthread_mutex_t g_message_file_lock = PTHREAD_MUTEX_INITIALIZER; static time_t parse_rfc3339_utc(const char *timestamp_str) { struct tm tm = {0}; - char *result; - char *old_tz = NULL; - time_t parsed; if (!timestamp_str) { return (time_t)-1; } - result = strptime(timestamp_str, "%Y-%m-%dT%H:%M:%SZ", &tm); + char *result = strptime(timestamp_str, "%Y-%m-%dT%H:%M:%SZ", &tm); if (!result || *result != '\0') { return (time_t)-1; } - const char *tz = getenv("TZ"); - if (tz) { - old_tz = strdup(tz); - if (!old_tz) { - return (time_t)-1; - } - } - - if (setenv("TZ", "UTC0", 1) != 0) { - free(old_tz); - return (time_t)-1; - } - tzset(); - - parsed = mktime(&tm); - - if (old_tz) { - setenv("TZ", old_tz, 1); - free(old_tz); - } else { - unsetenv("TZ"); - } - tzset(); - - return parsed; + return timegm(&tm); } /* Initialize message subsystem */ @@ -260,9 +234,22 @@ void message_format(const message_t *msg, char *buffer, size_t buf_size, int wid char time_str[64]; strftime(time_str, sizeof(time_str), "%Y-%m-%d %H:%M %Z", &tm_info); - snprintf(buffer, buf_size, "[%s] %s: %s", time_str, msg->username, msg->content); + int written = snprintf(buffer, buf_size, "[%s] %s: %s", time_str, msg->username, msg->content); - /* Truncate if too long */ + /* If snprintf truncated, the last UTF-8 character may be incomplete. + * Re-validate and trim any trailing partial sequence. */ + if (written >= (int)buf_size) { + size_t len = strlen(buffer); + while (len > 0 && (buffer[len - 1] & 0xC0) == 0x80) { + len--; /* walk back continuation bytes */ + } + if (len > 0 && (unsigned char)buffer[len - 1] >= 0xC0) { + /* This is a start byte whose sequence was truncated */ + buffer[len - 1] = '\0'; + } + } + + /* Truncate to terminal width */ if (utf8_string_width(buffer) > width) { utf8_truncate(buffer, width); } diff --git a/src/ssh_server.c b/src/ssh_server.c index 5e1a4f9..9f8b3d2 100644 --- a/src/ssh_server.c +++ b/src/ssh_server.c @@ -108,34 +108,38 @@ static void buffer_append_bytes(char *buffer, size_t buf_size, size_t *pos, buffer[*pos] = '\0'; } +/* Constant-time string comparison to prevent timing side-channel attacks */ +static bool constant_time_strcmp(const char *a, const char *b) { + size_t len_a = strlen(a); + size_t len_b = strlen(b); + /* Use len_b (the secret) for iteration to avoid leaking its length + * through early termination. The XOR of lengths catches mismatches. */ + volatile unsigned char result = (unsigned char)(len_a ^ len_b); + size_t len = (len_a < len_b) ? len_a : len_b; + for (size_t i = 0; i < len; i++) { + result |= (unsigned char)((unsigned char)a[i] ^ (unsigned char)b[i]); + } + return result == 0; +} + +/* Safe integer parse from environment variable; returns fallback on error. */ +static int env_int(const char *name, int fallback, int min_val, int max_val) { + const char *env = getenv(name); + if (!env || env[0] == '\0') return fallback; + char *end; + long val = strtol(env, &end, 10); + if (*end != '\0' || val < min_val || val > max_val) return fallback; + return (int)val; +} + /* Initialize rate limit configuration from environment */ static void init_rate_limit_config(void) { const char *env; - if ((env = getenv("TNT_MAX_CONNECTIONS")) != NULL) { - int val = atoi(env); - if (val > 0 && val <= 1024) { - g_max_connections = val; - } - } - - if ((env = getenv("TNT_MAX_CONN_PER_IP")) != NULL) { - int val = atoi(env); - if (val > 0 && val <= 1024) { - g_max_conn_per_ip = val; - } - } - - if ((env = getenv("TNT_MAX_CONN_RATE_PER_IP")) != NULL) { - int val = atoi(env); - if (val > 0 && val <= 1024) { - g_max_conn_rate_per_ip = val; - } - } - - if ((env = getenv("TNT_RATE_LIMIT")) != NULL) { - g_rate_limit_enabled = atoi(env); - } + g_max_connections = env_int("TNT_MAX_CONNECTIONS", 64, 1, 1024); + g_max_conn_per_ip = env_int("TNT_MAX_CONN_PER_IP", 5, 1, 1024); + g_max_conn_rate_per_ip = env_int("TNT_MAX_CONN_RATE_PER_IP", 10, 1, 1024); + g_rate_limit_enabled = env_int("TNT_RATE_LIMIT", 1, 0, 1); if ((env = getenv("TNT_ACCESS_TOKEN")) != NULL) { strncpy(g_access_token, env, sizeof(g_access_token) - 1); @@ -180,6 +184,8 @@ static ip_rate_limit_t* get_rate_limit_entry(const char *ip) { } if (oldest_idx < 0) { + /* All slots have active connections — evicting will corrupt their + * concurrency accounting. Pick the oldest entry but warn. */ oldest_idx = 0; oldest_time = g_rate_limits[0].window_start; for (int i = 1; i < MAX_TRACKED_IPS; i++) { @@ -188,6 +194,8 @@ static ip_rate_limit_t* get_rate_limit_entry(const char *ip) { oldest_idx = i; } } + fprintf(stderr, "Warning: rate-limit table full, evicting active IP %s\n", + g_rate_limits[oldest_idx].ip); } /* Reset and reuse */ @@ -233,7 +241,7 @@ static bool check_ip_connection_policy(const char *ip) { } entry->recent_connection_count++; - if (entry->recent_connection_count > g_max_conn_rate_per_ip) { + if (entry->recent_connection_count >= g_max_conn_rate_per_ip) { entry->is_blocked = true; entry->block_until = now + BLOCK_DURATION; pthread_mutex_unlock(&g_rate_limit_lock); @@ -598,15 +606,14 @@ static int read_username(client_t *client) { } buf[0] = b; if (len > 1) { - int read_bytes = ssh_channel_read(client->channel, &buf[1], len - 1, 0); + int read_bytes = ssh_channel_read_timeout(client->channel, &buf[1], len - 1, 0, 5000); if (read_bytes != len - 1) { - /* Incomplete UTF-8 */ + /* Incomplete or timed-out UTF-8 continuation */ continue; } } /* Validate the complete UTF-8 sequence */ if (!utf8_is_valid_sequence(buf, len)) { - /* Invalid UTF-8 sequence */ continue; } if (pos + len < MAX_USERNAME_LEN - 1) { @@ -1444,9 +1451,9 @@ void* client_handle_session(void *arg) { } buf[0] = b; if (char_len > 1) { - int read_bytes = ssh_channel_read(client->channel, &buf[1], char_len - 1, 0); + int read_bytes = ssh_channel_read_timeout(client->channel, &buf[1], char_len - 1, 0, 5000); if (read_bytes != char_len - 1) { - /* Incomplete UTF-8 sequence */ + /* Incomplete or timed-out UTF-8 continuation */ continue; } } @@ -1493,6 +1500,9 @@ cleanup: release_ip_connection(client->client_ip); + /* Release the callback reference (paired with addref before install_client_channel_callbacks) */ + client_release(client); + /* Release the main reference - client will be freed when all refs are gone */ client_release(client); @@ -1526,7 +1536,7 @@ static int auth_password(ssh_session session, const char *user, /* If access token is configured, require it */ if (g_access_token[0] != '\0') { - if (password && strcmp(password, g_access_token) == 0) { + if (password && constant_time_strcmp(password, g_access_token)) { /* Token matches */ ctx->auth_success = true; return SSH_AUTH_SUCCESS; @@ -1567,9 +1577,8 @@ static int auth_none(ssh_session session, const char *user, void *userdata) { static int auth_pubkey(ssh_session session, const char *user, struct ssh_key_struct *pubkey, char signature_state, void *userdata) { - (void)session; /* Unused */ - (void)pubkey; /* Unused */ - (void)signature_state; /* Unused in anonymous mode */ + (void)session; + (void)pubkey; session_context_t *ctx = (session_context_t *)userdata; if (user && user[0] != '\0') { @@ -1577,10 +1586,17 @@ static int auth_pubkey(ssh_session session, const char *user, ctx->requested_user[sizeof(ctx->requested_user) - 1] = '\0'; } + /* Reject if access token is required (pubkey auth not supported with tokens) */ if (g_access_token[0] != '\0') { return SSH_AUTH_DENIED; } + /* Only accept after the signature has been verified by libssh. + * SSH_PUBLICKEY_STATE_NONE is just a key offer — no proof of possession. */ + if (signature_state != SSH_PUBLICKEY_STATE_VALID) { + return SSH_AUTH_PARTIAL; + } + ctx->auth_success = true; return SSH_AUTH_SUCCESS; } @@ -1926,7 +1942,6 @@ static void *bootstrap_client_session(void *arg) { client->session = session; client->channel = channel; - client->fd = -1; client->width = ctx->pty_width; client->height = ctx->pty_height; sanitize_terminal_size(&client->width, &client->height); @@ -1949,7 +1964,12 @@ static void *bootstrap_client_session(void *arg) { client->exec_command[sizeof(client->exec_command) - 1] = '\0'; } + /* Add a ref for the channel callbacks (eof/close/window_change) so the + * client_t outlives any in-flight callback invocation. */ + client_addref(client); + if (install_client_channel_callbacks(client) < 0) { + client_release(client); /* drop the callback ref */ pthread_mutex_destroy(&client->io_lock); pthread_mutex_destroy(&client->ref_lock); free(client); @@ -1998,14 +2018,7 @@ int ssh_server_init(int port) { ssh_bind_options_set(g_sshbind, SSH_BIND_OPTIONS_BINDADDR, bind_addr); /* Configurable SSH log level (default: SSH_LOG_WARNING=1) */ - int verbosity = SSH_LOG_WARNING; - const char *log_level_env = getenv("TNT_SSH_LOG_LEVEL"); - if (log_level_env) { - int level = atoi(log_level_env); - if (level >= 0 && level <= 4) { - verbosity = level; - } - } + int verbosity = env_int("TNT_SSH_LOG_LEVEL", SSH_LOG_WARNING, 0, 4); ssh_bind_options_set(g_sshbind, SSH_BIND_OPTIONS_LOG_VERBOSITY, &verbosity); if (ssh_bind_listen(g_sshbind) < 0) { @@ -2091,7 +2104,5 @@ int ssh_server_start(int unused) { continue; } } - - pthread_attr_destroy(&attr); - return 0; + /* Unreachable — the while(1) loop only exits via signal/_exit(). */ } diff --git a/src/tui.c b/src/tui.c index 75fabf9..2eb4cdf 100644 --- a/src/tui.c +++ b/src/tui.c @@ -64,10 +64,11 @@ void tui_render_screen(client_t *client) { size_t pos = 0; buffer[0] = '\0'; - /* Acquire all data in one lock to prevent TOCTOU */ + /* First pass under lock: compute indices and counts */ pthread_rwlock_rdlock(&g_room->lock); int online = g_room->client_count; int msg_count = g_room->message_count; + pthread_rwlock_unlock(&g_room->lock); /* Calculate which messages to show */ int msg_height = client->height - 3; @@ -90,19 +91,31 @@ void tui_render_screen(client_t *client) { int end = start + msg_height; if (end > msg_count) end = msg_count; - /* Create snapshot of messages to display */ + /* Allocate snapshot outside the lock to avoid blocking writers */ message_t *msg_snapshot = NULL; int snapshot_count = end - start; if (snapshot_count > 0) { msg_snapshot = calloc(snapshot_count, sizeof(message_t)); - if (msg_snapshot) { - memcpy(msg_snapshot, &g_room->messages[start], - snapshot_count * sizeof(message_t)); - } } - pthread_rwlock_unlock(&g_room->lock); + /* Second pass under lock: copy messages */ + if (msg_snapshot) { + pthread_rwlock_rdlock(&g_room->lock); + /* Re-clamp in case msg_count changed */ + int actual_count = g_room->message_count; + int actual_end = (end <= actual_count) ? end : actual_count; + int actual_start = (start < actual_end) ? start : actual_end; + int actual_snapshot = actual_end - actual_start; + if (actual_snapshot > 0 && actual_snapshot <= snapshot_count) { + memcpy(msg_snapshot, &g_room->messages[actual_start], + actual_snapshot * sizeof(message_t)); + snapshot_count = actual_snapshot; + } else { + snapshot_count = 0; + } + pthread_rwlock_unlock(&g_room->lock); + } /* Now render using snapshot (no lock held) */ diff --git a/src/utf8.c b/src/utf8.c index d0af230..3a77a1b 100644 --- a/src/utf8.c +++ b/src/utf8.c @@ -20,7 +20,8 @@ uint32_t utf8_decode(const char *str, int *bytes_read) { } for (int i = 1; i < len; i++) { - if (s[i] == '\0') { + if (s[i] == '\0' || (s[i] & 0xC0) != 0x80) { + /* Truncated or invalid continuation byte — treat as single byte */ *bytes_read = 1; return s[0]; }