diff --git a/include/common.h b/include/common.h index e284781..cbe5e33 100644 --- a/include/common.h +++ b/include/common.h @@ -62,4 +62,9 @@ void buffer_append_bytes(char *buffer, size_t buf_size, size_t *pos, void buffer_appendf(char *buffer, size_t buf_size, size_t *pos, const char *fmt, ...); +/* Parse an integer from `getenv(name)`, clamping accepted values to + * [min_val, max_val]. Returns `fallback` when the variable is unset, empty, + * non-numeric, or out of range. */ +int env_int(const char *name, int fallback, int min_val, int max_val); + #endif /* COMMON_H */ diff --git a/include/ratelimit.h b/include/ratelimit.h new file mode 100644 index 0000000..711a189 --- /dev/null +++ b/include/ratelimit.h @@ -0,0 +1,27 @@ +#ifndef RATELIMIT_H +#define RATELIMIT_H + +#include + +/* Read TNT_MAX_CONNECTIONS / TNT_MAX_CONN_PER_IP / TNT_MAX_CONN_RATE_PER_IP / + * TNT_RATE_LIMIT from the environment. Idempotent, call once at startup. */ +void ratelimit_init(void); + +/* Per-IP entry point: returns false if the IP has hit any limit (concurrent, + * rate, or block). On success, increments the IP's active counter — caller + * MUST pair with ratelimit_release_ip() when the connection ends. */ +bool ratelimit_check_ip(const char *ip); +void ratelimit_release_ip(const char *ip); + +/* Auth-failure ledger. After enough failures within the window the IP is + * blocked for a fixed duration. */ +void ratelimit_record_auth_failure(const char *ip); + +/* Global active-connection cap (separate from per-IP). Pair them. */ +bool ratelimit_check_and_increment_total(void); +void ratelimit_decrement_total(void); + +/* Read-only accessor for stats subcommand. */ +int ratelimit_get_active_total(void); + +#endif /* RATELIMIT_H */ diff --git a/src/common.c b/src/common.c index 2cc8b5f..6b5a407 100644 --- a/src/common.c +++ b/src/common.c @@ -126,3 +126,12 @@ void buffer_appendf(char *buffer, size_t buf_size, size_t *pos, *pos += (size_t)written; } } + +int env_int(const char *name, int fallback, int min_val, int max_val) { + const char *env = getenv(name); + if (!env || env[0] == '\0') return fallback; + char *end; + long val = strtol(env, &end, 10); + if (*end != '\0' || val < min_val || val > max_val) return fallback; + return (int)val; +} diff --git a/src/ratelimit.c b/src/ratelimit.c new file mode 100644 index 0000000..5f2ff7f --- /dev/null +++ b/src/ratelimit.c @@ -0,0 +1,211 @@ +#include "ratelimit.h" +#include "common.h" +#include +#include +#include +#include +#include +#include + +#define MAX_TRACKED_IPS 256 +#define RATE_LIMIT_WINDOW 60 /* seconds */ +#define MAX_AUTH_FAILURES 5 /* auth failures before block */ +#define BLOCK_DURATION 300 /* seconds to block after too many failures */ + +typedef struct { + char ip[INET6_ADDRSTRLEN]; + time_t window_start; + int recent_connection_count; + int active_connections; + int auth_failure_count; + bool is_blocked; + time_t block_until; +} ip_rate_limit_t; + +static ip_rate_limit_t g_rate_limits[MAX_TRACKED_IPS]; +static pthread_mutex_t g_rate_limit_lock = PTHREAD_MUTEX_INITIALIZER; +static int g_total_connections = 0; +static pthread_mutex_t g_conn_count_lock = PTHREAD_MUTEX_INITIALIZER; + +static int g_max_connections = 64; +static int g_max_conn_per_ip = 5; +static int g_max_conn_rate_per_ip = 10; +static int g_rate_limit_enabled = 1; + +void ratelimit_init(void) { + g_max_connections = env_int("TNT_MAX_CONNECTIONS", 64, 1, 1024); + g_max_conn_per_ip = env_int("TNT_MAX_CONN_PER_IP", 5, 1, 1024); + g_max_conn_rate_per_ip = env_int("TNT_MAX_CONN_RATE_PER_IP", 10, 1, 1024); + g_rate_limit_enabled = env_int("TNT_RATE_LIMIT", 1, 0, 1); +} + +/* Caller MUST hold g_rate_limit_lock. */ +static ip_rate_limit_t* get_rate_limit_entry(const char *ip) { + /* Look for existing entry */ + for (int i = 0; i < MAX_TRACKED_IPS; i++) { + if (strcmp(g_rate_limits[i].ip, ip) == 0) { + return &g_rate_limits[i]; + } + } + + /* Find empty slot */ + for (int i = 0; i < MAX_TRACKED_IPS; i++) { + if (g_rate_limits[i].ip[0] == '\0') { + strncpy(g_rate_limits[i].ip, ip, sizeof(g_rate_limits[i].ip) - 1); + g_rate_limits[i].window_start = time(NULL); + g_rate_limits[i].recent_connection_count = 0; + g_rate_limits[i].active_connections = 0; + g_rate_limits[i].auth_failure_count = 0; + g_rate_limits[i].is_blocked = false; + g_rate_limits[i].block_until = 0; + return &g_rate_limits[i]; + } + } + + /* Reuse the oldest inactive entry first so active IP accounting stays intact. */ + int oldest_idx = -1; + time_t oldest_time = 0; + for (int i = 0; i < MAX_TRACKED_IPS; i++) { + if (g_rate_limits[i].active_connections != 0) { + continue; + } + if (oldest_idx < 0 || g_rate_limits[i].window_start < oldest_time) { + oldest_time = g_rate_limits[i].window_start; + oldest_idx = i; + } + } + + if (oldest_idx < 0) { + /* All slots have active connections — evicting will corrupt their + * concurrency accounting. Pick the oldest entry but warn. */ + oldest_idx = 0; + oldest_time = g_rate_limits[0].window_start; + for (int i = 1; i < MAX_TRACKED_IPS; i++) { + if (g_rate_limits[i].window_start < oldest_time) { + oldest_time = g_rate_limits[i].window_start; + oldest_idx = i; + } + } + fprintf(stderr, "Warning: rate-limit table full, evicting active IP %s " + "(%d active connections lost)\n", + g_rate_limits[oldest_idx].ip, + g_rate_limits[oldest_idx].active_connections); + } + + /* Reset and reuse */ + strncpy(g_rate_limits[oldest_idx].ip, ip, sizeof(g_rate_limits[oldest_idx].ip) - 1); + g_rate_limits[oldest_idx].ip[sizeof(g_rate_limits[oldest_idx].ip) - 1] = '\0'; + g_rate_limits[oldest_idx].window_start = time(NULL); + g_rate_limits[oldest_idx].recent_connection_count = 0; + g_rate_limits[oldest_idx].active_connections = 0; + g_rate_limits[oldest_idx].auth_failure_count = 0; + g_rate_limits[oldest_idx].is_blocked = false; + g_rate_limits[oldest_idx].block_until = 0; + return &g_rate_limits[oldest_idx]; +} + +bool ratelimit_check_ip(const char *ip) { + time_t now = time(NULL); + + pthread_mutex_lock(&g_rate_limit_lock); + ip_rate_limit_t *entry = get_rate_limit_entry(ip); + + if (entry->active_connections >= g_max_conn_per_ip) { + pthread_mutex_unlock(&g_rate_limit_lock); + fprintf(stderr, "Concurrent IP limit reached for %s\n", ip); + return false; + } + + if (g_rate_limit_enabled && entry->is_blocked && now < entry->block_until) { + pthread_mutex_unlock(&g_rate_limit_lock); + fprintf(stderr, "Blocked IP %s (blocked until %ld)\n", ip, (long)entry->block_until); + return false; + } + + if (g_rate_limit_enabled && entry->is_blocked && now >= entry->block_until) { + entry->is_blocked = false; + entry->auth_failure_count = 0; + } + + if (g_rate_limit_enabled) { + if (now - entry->window_start >= RATE_LIMIT_WINDOW) { + entry->window_start = now; + entry->recent_connection_count = 0; + } + + entry->recent_connection_count++; + if (entry->recent_connection_count >= g_max_conn_rate_per_ip) { + entry->is_blocked = true; + entry->block_until = now + BLOCK_DURATION; + pthread_mutex_unlock(&g_rate_limit_lock); + fprintf(stderr, "Rate limit exceeded for IP %s\n", ip); + return false; + } + } + + entry->active_connections++; + pthread_mutex_unlock(&g_rate_limit_lock); + return true; +} + +void ratelimit_record_auth_failure(const char *ip) { + time_t now = time(NULL); + + if (!g_rate_limit_enabled) { + return; + } + + pthread_mutex_lock(&g_rate_limit_lock); + ip_rate_limit_t *entry = get_rate_limit_entry(ip); + + entry->auth_failure_count++; + if (entry->auth_failure_count >= MAX_AUTH_FAILURES) { + entry->is_blocked = true; + entry->block_until = now + BLOCK_DURATION; + fprintf(stderr, "IP %s blocked due to %d auth failures\n", ip, entry->auth_failure_count); + } + + pthread_mutex_unlock(&g_rate_limit_lock); +} + +void ratelimit_release_ip(const char *ip) { + if (!ip || ip[0] == '\0') { + return; + } + + pthread_mutex_lock(&g_rate_limit_lock); + ip_rate_limit_t *entry = get_rate_limit_entry(ip); + if (entry->active_connections > 0) { + entry->active_connections--; + } + pthread_mutex_unlock(&g_rate_limit_lock); +} + +bool ratelimit_check_and_increment_total(void) { + pthread_mutex_lock(&g_conn_count_lock); + + if (g_total_connections >= g_max_connections) { + pthread_mutex_unlock(&g_conn_count_lock); + return false; + } + + g_total_connections++; + pthread_mutex_unlock(&g_conn_count_lock); + return true; +} + +void ratelimit_decrement_total(void) { + pthread_mutex_lock(&g_conn_count_lock); + if (g_total_connections > 0) { + g_total_connections--; + } + pthread_mutex_unlock(&g_conn_count_lock); +} + +int ratelimit_get_active_total(void) { + int count; + pthread_mutex_lock(&g_conn_count_lock); + count = g_total_connections; + pthread_mutex_unlock(&g_conn_count_lock); + return count; +} diff --git a/src/ssh_server.c b/src/ssh_server.c index b070288..af167f3 100644 --- a/src/ssh_server.c +++ b/src/ssh_server.c @@ -1,4 +1,5 @@ #include "ssh_server.h" +#include "ratelimit.h" #include "tui.h" #include "utf8.h" #include @@ -38,33 +39,11 @@ typedef struct { char client_ip[INET6_ADDRSTRLEN]; } accepted_session_t; -/* Rate limiting and connection tracking */ -#define MAX_TRACKED_IPS 256 -#define RATE_LIMIT_WINDOW 60 /* seconds */ -#define MAX_AUTH_FAILURES 5 /* auth failures before block */ -#define BLOCK_DURATION 300 /* seconds to block after too many failures */ - -typedef struct { - char ip[INET6_ADDRSTRLEN]; - time_t window_start; - int recent_connection_count; - int active_connections; - int auth_failure_count; - bool is_blocked; - time_t block_until; -} ip_rate_limit_t; - -static ip_rate_limit_t g_rate_limits[MAX_TRACKED_IPS]; -static pthread_mutex_t g_rate_limit_lock = PTHREAD_MUTEX_INITIALIZER; -static int g_total_connections = 0; -static pthread_mutex_t g_conn_count_lock = PTHREAD_MUTEX_INITIALIZER; static time_t g_server_start_time = 0; -/* Configuration from environment variables */ -static int g_max_connections = 64; -static int g_max_conn_per_ip = 5; -static int g_max_conn_rate_per_ip = 10; -static int g_rate_limit_enabled = 1; +/* Configuration from environment variables. Rate-limiting / connection-count + * config has moved to ratelimit.{c,h}; the two below stay here until the auth + * and input modules are extracted in later PR2 steps. */ static char g_access_token[256] = ""; static int g_idle_timeout = DEFAULT_IDLE_TIMEOUT; @@ -88,200 +67,6 @@ static bool constant_time_strcmp(const char *a, const char *b) { return length_diff == 0 && byte_diff == 0; } -/* Safe integer parse from environment variable; returns fallback on error. */ -static int env_int(const char *name, int fallback, int min_val, int max_val) { - const char *env = getenv(name); - if (!env || env[0] == '\0') return fallback; - char *end; - long val = strtol(env, &end, 10); - if (*end != '\0' || val < min_val || val > max_val) return fallback; - return (int)val; -} - -/* Initialize rate limit configuration from environment */ -static void init_rate_limit_config(void) { - const char *env; - - g_max_connections = env_int("TNT_MAX_CONNECTIONS", 64, 1, 1024); - g_max_conn_per_ip = env_int("TNT_MAX_CONN_PER_IP", 5, 1, 1024); - g_max_conn_rate_per_ip = env_int("TNT_MAX_CONN_RATE_PER_IP", 10, 1, 1024); - g_rate_limit_enabled = env_int("TNT_RATE_LIMIT", 1, 0, 1); - - g_idle_timeout = env_int("TNT_IDLE_TIMEOUT", DEFAULT_IDLE_TIMEOUT, 0, 86400); - - if ((env = getenv("TNT_ACCESS_TOKEN")) != NULL) { - strncpy(g_access_token, env, sizeof(g_access_token) - 1); - g_access_token[sizeof(g_access_token) - 1] = '\0'; - } -} - -/* Get or create rate limit entry for an IP */ -static ip_rate_limit_t* get_rate_limit_entry(const char *ip) { - /* Look for existing entry */ - for (int i = 0; i < MAX_TRACKED_IPS; i++) { - if (strcmp(g_rate_limits[i].ip, ip) == 0) { - return &g_rate_limits[i]; - } - } - - /* Find empty slot */ - for (int i = 0; i < MAX_TRACKED_IPS; i++) { - if (g_rate_limits[i].ip[0] == '\0') { - strncpy(g_rate_limits[i].ip, ip, sizeof(g_rate_limits[i].ip) - 1); - g_rate_limits[i].window_start = time(NULL); - g_rate_limits[i].recent_connection_count = 0; - g_rate_limits[i].active_connections = 0; - g_rate_limits[i].auth_failure_count = 0; - g_rate_limits[i].is_blocked = false; - g_rate_limits[i].block_until = 0; - return &g_rate_limits[i]; - } - } - - /* Reuse the oldest inactive entry first so active IP accounting stays intact. */ - int oldest_idx = -1; - time_t oldest_time = 0; - for (int i = 0; i < MAX_TRACKED_IPS; i++) { - if (g_rate_limits[i].active_connections != 0) { - continue; - } - if (oldest_idx < 0 || g_rate_limits[i].window_start < oldest_time) { - oldest_time = g_rate_limits[i].window_start; - oldest_idx = i; - } - } - - if (oldest_idx < 0) { - /* All slots have active connections — evicting will corrupt their - * concurrency accounting. Pick the oldest entry but warn. */ - oldest_idx = 0; - oldest_time = g_rate_limits[0].window_start; - for (int i = 1; i < MAX_TRACKED_IPS; i++) { - if (g_rate_limits[i].window_start < oldest_time) { - oldest_time = g_rate_limits[i].window_start; - oldest_idx = i; - } - } - fprintf(stderr, "Warning: rate-limit table full, evicting active IP %s " - "(%d active connections lost)\n", - g_rate_limits[oldest_idx].ip, - g_rate_limits[oldest_idx].active_connections); - } - - /* Reset and reuse */ - strncpy(g_rate_limits[oldest_idx].ip, ip, sizeof(g_rate_limits[oldest_idx].ip) - 1); - g_rate_limits[oldest_idx].ip[sizeof(g_rate_limits[oldest_idx].ip) - 1] = '\0'; - g_rate_limits[oldest_idx].window_start = time(NULL); - g_rate_limits[oldest_idx].recent_connection_count = 0; - g_rate_limits[oldest_idx].active_connections = 0; - g_rate_limits[oldest_idx].auth_failure_count = 0; - g_rate_limits[oldest_idx].is_blocked = false; - g_rate_limits[oldest_idx].block_until = 0; - return &g_rate_limits[oldest_idx]; -} - -/* Check rate and concurrency limits for an IP */ -static bool check_ip_connection_policy(const char *ip) { - time_t now = time(NULL); - - pthread_mutex_lock(&g_rate_limit_lock); - ip_rate_limit_t *entry = get_rate_limit_entry(ip); - - if (entry->active_connections >= g_max_conn_per_ip) { - pthread_mutex_unlock(&g_rate_limit_lock); - fprintf(stderr, "Concurrent IP limit reached for %s\n", ip); - return false; - } - - if (g_rate_limit_enabled && entry->is_blocked && now < entry->block_until) { - pthread_mutex_unlock(&g_rate_limit_lock); - fprintf(stderr, "Blocked IP %s (blocked until %ld)\n", ip, (long)entry->block_until); - return false; - } - - if (g_rate_limit_enabled && entry->is_blocked && now >= entry->block_until) { - entry->is_blocked = false; - entry->auth_failure_count = 0; - } - - if (g_rate_limit_enabled) { - if (now - entry->window_start >= RATE_LIMIT_WINDOW) { - entry->window_start = now; - entry->recent_connection_count = 0; - } - - entry->recent_connection_count++; - if (entry->recent_connection_count >= g_max_conn_rate_per_ip) { - entry->is_blocked = true; - entry->block_until = now + BLOCK_DURATION; - pthread_mutex_unlock(&g_rate_limit_lock); - fprintf(stderr, "Rate limit exceeded for IP %s\n", ip); - return false; - } - } - - entry->active_connections++; - pthread_mutex_unlock(&g_rate_limit_lock); - return true; -} - -/* Record authentication failure */ -static void record_auth_failure(const char *ip) { - time_t now = time(NULL); - - if (!g_rate_limit_enabled) { - return; - } - - pthread_mutex_lock(&g_rate_limit_lock); - ip_rate_limit_t *entry = get_rate_limit_entry(ip); - - entry->auth_failure_count++; - if (entry->auth_failure_count >= MAX_AUTH_FAILURES) { - entry->is_blocked = true; - entry->block_until = now + BLOCK_DURATION; - fprintf(stderr, "IP %s blocked due to %d auth failures\n", ip, entry->auth_failure_count); - } - - pthread_mutex_unlock(&g_rate_limit_lock); -} - -static void release_ip_connection(const char *ip) { - if (!ip || ip[0] == '\0') { - return; - } - - pthread_mutex_lock(&g_rate_limit_lock); - ip_rate_limit_t *entry = get_rate_limit_entry(ip); - if (entry->active_connections > 0) { - entry->active_connections--; - } - pthread_mutex_unlock(&g_rate_limit_lock); -} - -/* Check and increment total connection count */ -static bool check_and_increment_connections(void) { - pthread_mutex_lock(&g_conn_count_lock); - - if (g_total_connections >= g_max_connections) { - pthread_mutex_unlock(&g_conn_count_lock); - return false; - } - - g_total_connections++; - pthread_mutex_unlock(&g_conn_count_lock); - return true; -} - -/* Decrement connection count */ -static void decrement_connections(void) { - pthread_mutex_lock(&g_conn_count_lock); - if (g_total_connections > 0) { - g_total_connections--; - } - pthread_mutex_unlock(&g_conn_count_lock); -} - /* Get client IP address */ static void get_client_ip(ssh_session session, char *ip_buf, size_t buf_size) { int fd = ssh_get_fd(session); @@ -815,9 +600,7 @@ static int exec_command_stats(client_t *client, bool json) { client_capacity = g_room->client_capacity; pthread_rwlock_unlock(&g_room->lock); - pthread_mutex_lock(&g_conn_count_lock); - active_connections = g_total_connections; - pthread_mutex_unlock(&g_conn_count_lock); + active_connections = ratelimit_get_active_total(); uptime_seconds = (g_server_start_time > 0 && now >= g_server_start_time) ? (long)(now - g_server_start_time) @@ -1810,7 +1593,7 @@ cleanup: message_save(&leave_msg); } - release_ip_connection(client->client_ip); + ratelimit_release_ip(client->client_ip); /* Remove channel callbacks before releasing refs to prevent use-after-free * if a callback fires between the two releases. */ @@ -1825,7 +1608,7 @@ cleanup: client_release(client); /* Decrement connection count */ - decrement_connections(); + ratelimit_decrement_total(); return NULL; } @@ -1846,7 +1629,7 @@ static int auth_password(ssh_session session, const char *user, /* Limit auth attempts */ if (ctx->auth_attempts > 3) { - record_auth_failure(ctx->client_ip); + ratelimit_record_auth_failure(ctx->client_ip); fprintf(stderr, "Too many auth attempts from %s\n", ctx->client_ip); ssh_disconnect(session); return SSH_AUTH_DENIED; @@ -1861,7 +1644,7 @@ static int auth_password(ssh_session session, const char *user, } else { /* Wrong token — IP blocking handles brute force, no sleep needed here * (sleeping in a libssh callback blocks the entire accept loop). */ - record_auth_failure(ctx->client_ip); + ratelimit_record_auth_failure(ctx->client_ip); return SSH_AUTH_DENIED; } } else { @@ -1948,10 +1731,10 @@ static void cleanup_failed_session(ssh_session session, session_context_t *ctx) } if (ctx) { - release_ip_connection(ctx->client_ip); + ratelimit_release_ip(ctx->client_ip); } destroy_session_context(ctx); - decrement_connections(); + ratelimit_decrement_total(); } static void setup_session_channel_callbacks(ssh_channel channel, @@ -2168,10 +1951,10 @@ static void *bootstrap_client_session(void *arg) { ctx = calloc(1, sizeof(session_context_t)); if (!ctx) { - release_ip_connection(accepted_ip); + ratelimit_release_ip(accepted_ip); ssh_disconnect(session); ssh_free(session); - decrement_connections(); + ratelimit_decrement_total(); return NULL; } @@ -2315,8 +2098,16 @@ static void *bootstrap_client_session(void *arg) { /* Initialize SSH server */ int ssh_server_init(int port) { - /* Initialize rate limiting configuration */ - init_rate_limit_config(); + /* Initialize rate-limit / connection-count subsystem */ + ratelimit_init(); + + /* Auth / session config (will move into auth.c and input.c in later PR2 steps) */ + g_idle_timeout = env_int("TNT_IDLE_TIMEOUT", DEFAULT_IDLE_TIMEOUT, 0, 86400); + const char *token_env = getenv("TNT_ACCESS_TOKEN"); + if (token_env != NULL) { + strncpy(g_access_token, token_env, sizeof(g_access_token) - 1); + g_access_token[sizeof(g_access_token) - 1] = '\0'; + } g_listen_port = port; g_server_start_time = time(NULL); @@ -2392,15 +2183,15 @@ int ssh_server_start(int unused) { get_client_ip(session, client_ip, sizeof(client_ip)); /* Check total connection limit */ - if (!check_and_increment_connections()) { + if (!ratelimit_check_and_increment_total()) { fprintf(stderr, "Max connections reached, rejecting %s\n", client_ip); ssh_disconnect(session); ssh_free(session); continue; } - if (!check_ip_connection_policy(client_ip)) { - decrement_connections(); + if (!ratelimit_check_ip(client_ip)) { + ratelimit_decrement_total(); ssh_disconnect(session); ssh_free(session); continue; @@ -2408,8 +2199,8 @@ int ssh_server_start(int unused) { accepted = calloc(1, sizeof(*accepted)); if (!accepted) { - release_ip_connection(client_ip); - decrement_connections(); + ratelimit_release_ip(client_ip); + ratelimit_decrement_total(); ssh_disconnect(session); ssh_free(session); continue; @@ -2422,8 +2213,8 @@ int ssh_server_start(int unused) { if (pthread_create(&thread, &attr, bootstrap_client_session, accepted) != 0) { fprintf(stderr, "Thread creation failed: %s\n", strerror(errno)); free(accepted); - release_ip_connection(client_ip); - decrement_connections(); + ratelimit_release_ip(client_ip); + ratelimit_decrement_total(); ssh_disconnect(session); ssh_free(session); continue;