mirror of
https://oauth2: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@github.com/m1ngsama/TNT.git
synced 2026-06-26 04:34:38 +08:00
refactor: extract ratelimit module (PR2-M1)
Move IP rate-limiting, auth-failure tracking, and global connection
counting out of ssh_server.c into a dedicated module.
New API (include/ratelimit.h):
- ratelimit_init()
- ratelimit_check_ip() / ratelimit_release_ip()
- ratelimit_record_auth_failure()
- ratelimit_check_and_increment_total() / ratelimit_decrement_total()
- ratelimit_get_active_total() (replaces the direct g_total_connections
read that exec_command_stats was doing under g_conn_count_lock)
env_int() also moves up to common.{c,h} since multiple modules need it.
ssh_server.c drops from 2469 to 2200 lines. Behaviour is preserved:
the new functions are byte-for-byte the same implementations, only the
file boundary moved.
g_idle_timeout and g_access_token reads stay inline in ssh_server_init()
for now; they will follow the auth.c and input.c extractions later.
This commit is contained in:
parent
d9382882d1
commit
562ee5296d
5 changed files with 282 additions and 239 deletions
|
|
@ -62,4 +62,9 @@ void buffer_append_bytes(char *buffer, size_t buf_size, size_t *pos,
|
|||
void buffer_appendf(char *buffer, size_t buf_size, size_t *pos,
|
||||
const char *fmt, ...);
|
||||
|
||||
/* Parse an integer from `getenv(name)`, clamping accepted values to
|
||||
* [min_val, max_val]. Returns `fallback` when the variable is unset, empty,
|
||||
* non-numeric, or out of range. */
|
||||
int env_int(const char *name, int fallback, int min_val, int max_val);
|
||||
|
||||
#endif /* COMMON_H */
|
||||
|
|
|
|||
27
include/ratelimit.h
Normal file
27
include/ratelimit.h
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
#ifndef RATELIMIT_H
|
||||
#define RATELIMIT_H
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
/* Read TNT_MAX_CONNECTIONS / TNT_MAX_CONN_PER_IP / TNT_MAX_CONN_RATE_PER_IP /
|
||||
* TNT_RATE_LIMIT from the environment. Idempotent, call once at startup. */
|
||||
void ratelimit_init(void);
|
||||
|
||||
/* Per-IP entry point: returns false if the IP has hit any limit (concurrent,
|
||||
* rate, or block). On success, increments the IP's active counter — caller
|
||||
* MUST pair with ratelimit_release_ip() when the connection ends. */
|
||||
bool ratelimit_check_ip(const char *ip);
|
||||
void ratelimit_release_ip(const char *ip);
|
||||
|
||||
/* Auth-failure ledger. After enough failures within the window the IP is
|
||||
* blocked for a fixed duration. */
|
||||
void ratelimit_record_auth_failure(const char *ip);
|
||||
|
||||
/* Global active-connection cap (separate from per-IP). Pair them. */
|
||||
bool ratelimit_check_and_increment_total(void);
|
||||
void ratelimit_decrement_total(void);
|
||||
|
||||
/* Read-only accessor for stats subcommand. */
|
||||
int ratelimit_get_active_total(void);
|
||||
|
||||
#endif /* RATELIMIT_H */
|
||||
|
|
@ -126,3 +126,12 @@ void buffer_appendf(char *buffer, size_t buf_size, size_t *pos,
|
|||
*pos += (size_t)written;
|
||||
}
|
||||
}
|
||||
|
||||
int env_int(const char *name, int fallback, int min_val, int max_val) {
|
||||
const char *env = getenv(name);
|
||||
if (!env || env[0] == '\0') return fallback;
|
||||
char *end;
|
||||
long val = strtol(env, &end, 10);
|
||||
if (*end != '\0' || val < min_val || val > max_val) return fallback;
|
||||
return (int)val;
|
||||
}
|
||||
|
|
|
|||
211
src/ratelimit.c
Normal file
211
src/ratelimit.c
Normal file
|
|
@ -0,0 +1,211 @@
|
|||
#include "ratelimit.h"
|
||||
#include "common.h"
|
||||
#include <arpa/inet.h>
|
||||
#include <pthread.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include <time.h>
|
||||
|
||||
#define MAX_TRACKED_IPS 256
|
||||
#define RATE_LIMIT_WINDOW 60 /* seconds */
|
||||
#define MAX_AUTH_FAILURES 5 /* auth failures before block */
|
||||
#define BLOCK_DURATION 300 /* seconds to block after too many failures */
|
||||
|
||||
typedef struct {
|
||||
char ip[INET6_ADDRSTRLEN];
|
||||
time_t window_start;
|
||||
int recent_connection_count;
|
||||
int active_connections;
|
||||
int auth_failure_count;
|
||||
bool is_blocked;
|
||||
time_t block_until;
|
||||
} ip_rate_limit_t;
|
||||
|
||||
static ip_rate_limit_t g_rate_limits[MAX_TRACKED_IPS];
|
||||
static pthread_mutex_t g_rate_limit_lock = PTHREAD_MUTEX_INITIALIZER;
|
||||
static int g_total_connections = 0;
|
||||
static pthread_mutex_t g_conn_count_lock = PTHREAD_MUTEX_INITIALIZER;
|
||||
|
||||
static int g_max_connections = 64;
|
||||
static int g_max_conn_per_ip = 5;
|
||||
static int g_max_conn_rate_per_ip = 10;
|
||||
static int g_rate_limit_enabled = 1;
|
||||
|
||||
void ratelimit_init(void) {
|
||||
g_max_connections = env_int("TNT_MAX_CONNECTIONS", 64, 1, 1024);
|
||||
g_max_conn_per_ip = env_int("TNT_MAX_CONN_PER_IP", 5, 1, 1024);
|
||||
g_max_conn_rate_per_ip = env_int("TNT_MAX_CONN_RATE_PER_IP", 10, 1, 1024);
|
||||
g_rate_limit_enabled = env_int("TNT_RATE_LIMIT", 1, 0, 1);
|
||||
}
|
||||
|
||||
/* Caller MUST hold g_rate_limit_lock. */
|
||||
static ip_rate_limit_t* get_rate_limit_entry(const char *ip) {
|
||||
/* Look for existing entry */
|
||||
for (int i = 0; i < MAX_TRACKED_IPS; i++) {
|
||||
if (strcmp(g_rate_limits[i].ip, ip) == 0) {
|
||||
return &g_rate_limits[i];
|
||||
}
|
||||
}
|
||||
|
||||
/* Find empty slot */
|
||||
for (int i = 0; i < MAX_TRACKED_IPS; i++) {
|
||||
if (g_rate_limits[i].ip[0] == '\0') {
|
||||
strncpy(g_rate_limits[i].ip, ip, sizeof(g_rate_limits[i].ip) - 1);
|
||||
g_rate_limits[i].window_start = time(NULL);
|
||||
g_rate_limits[i].recent_connection_count = 0;
|
||||
g_rate_limits[i].active_connections = 0;
|
||||
g_rate_limits[i].auth_failure_count = 0;
|
||||
g_rate_limits[i].is_blocked = false;
|
||||
g_rate_limits[i].block_until = 0;
|
||||
return &g_rate_limits[i];
|
||||
}
|
||||
}
|
||||
|
||||
/* Reuse the oldest inactive entry first so active IP accounting stays intact. */
|
||||
int oldest_idx = -1;
|
||||
time_t oldest_time = 0;
|
||||
for (int i = 0; i < MAX_TRACKED_IPS; i++) {
|
||||
if (g_rate_limits[i].active_connections != 0) {
|
||||
continue;
|
||||
}
|
||||
if (oldest_idx < 0 || g_rate_limits[i].window_start < oldest_time) {
|
||||
oldest_time = g_rate_limits[i].window_start;
|
||||
oldest_idx = i;
|
||||
}
|
||||
}
|
||||
|
||||
if (oldest_idx < 0) {
|
||||
/* All slots have active connections — evicting will corrupt their
|
||||
* concurrency accounting. Pick the oldest entry but warn. */
|
||||
oldest_idx = 0;
|
||||
oldest_time = g_rate_limits[0].window_start;
|
||||
for (int i = 1; i < MAX_TRACKED_IPS; i++) {
|
||||
if (g_rate_limits[i].window_start < oldest_time) {
|
||||
oldest_time = g_rate_limits[i].window_start;
|
||||
oldest_idx = i;
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "Warning: rate-limit table full, evicting active IP %s "
|
||||
"(%d active connections lost)\n",
|
||||
g_rate_limits[oldest_idx].ip,
|
||||
g_rate_limits[oldest_idx].active_connections);
|
||||
}
|
||||
|
||||
/* Reset and reuse */
|
||||
strncpy(g_rate_limits[oldest_idx].ip, ip, sizeof(g_rate_limits[oldest_idx].ip) - 1);
|
||||
g_rate_limits[oldest_idx].ip[sizeof(g_rate_limits[oldest_idx].ip) - 1] = '\0';
|
||||
g_rate_limits[oldest_idx].window_start = time(NULL);
|
||||
g_rate_limits[oldest_idx].recent_connection_count = 0;
|
||||
g_rate_limits[oldest_idx].active_connections = 0;
|
||||
g_rate_limits[oldest_idx].auth_failure_count = 0;
|
||||
g_rate_limits[oldest_idx].is_blocked = false;
|
||||
g_rate_limits[oldest_idx].block_until = 0;
|
||||
return &g_rate_limits[oldest_idx];
|
||||
}
|
||||
|
||||
bool ratelimit_check_ip(const char *ip) {
|
||||
time_t now = time(NULL);
|
||||
|
||||
pthread_mutex_lock(&g_rate_limit_lock);
|
||||
ip_rate_limit_t *entry = get_rate_limit_entry(ip);
|
||||
|
||||
if (entry->active_connections >= g_max_conn_per_ip) {
|
||||
pthread_mutex_unlock(&g_rate_limit_lock);
|
||||
fprintf(stderr, "Concurrent IP limit reached for %s\n", ip);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (g_rate_limit_enabled && entry->is_blocked && now < entry->block_until) {
|
||||
pthread_mutex_unlock(&g_rate_limit_lock);
|
||||
fprintf(stderr, "Blocked IP %s (blocked until %ld)\n", ip, (long)entry->block_until);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (g_rate_limit_enabled && entry->is_blocked && now >= entry->block_until) {
|
||||
entry->is_blocked = false;
|
||||
entry->auth_failure_count = 0;
|
||||
}
|
||||
|
||||
if (g_rate_limit_enabled) {
|
||||
if (now - entry->window_start >= RATE_LIMIT_WINDOW) {
|
||||
entry->window_start = now;
|
||||
entry->recent_connection_count = 0;
|
||||
}
|
||||
|
||||
entry->recent_connection_count++;
|
||||
if (entry->recent_connection_count >= g_max_conn_rate_per_ip) {
|
||||
entry->is_blocked = true;
|
||||
entry->block_until = now + BLOCK_DURATION;
|
||||
pthread_mutex_unlock(&g_rate_limit_lock);
|
||||
fprintf(stderr, "Rate limit exceeded for IP %s\n", ip);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
entry->active_connections++;
|
||||
pthread_mutex_unlock(&g_rate_limit_lock);
|
||||
return true;
|
||||
}
|
||||
|
||||
void ratelimit_record_auth_failure(const char *ip) {
|
||||
time_t now = time(NULL);
|
||||
|
||||
if (!g_rate_limit_enabled) {
|
||||
return;
|
||||
}
|
||||
|
||||
pthread_mutex_lock(&g_rate_limit_lock);
|
||||
ip_rate_limit_t *entry = get_rate_limit_entry(ip);
|
||||
|
||||
entry->auth_failure_count++;
|
||||
if (entry->auth_failure_count >= MAX_AUTH_FAILURES) {
|
||||
entry->is_blocked = true;
|
||||
entry->block_until = now + BLOCK_DURATION;
|
||||
fprintf(stderr, "IP %s blocked due to %d auth failures\n", ip, entry->auth_failure_count);
|
||||
}
|
||||
|
||||
pthread_mutex_unlock(&g_rate_limit_lock);
|
||||
}
|
||||
|
||||
void ratelimit_release_ip(const char *ip) {
|
||||
if (!ip || ip[0] == '\0') {
|
||||
return;
|
||||
}
|
||||
|
||||
pthread_mutex_lock(&g_rate_limit_lock);
|
||||
ip_rate_limit_t *entry = get_rate_limit_entry(ip);
|
||||
if (entry->active_connections > 0) {
|
||||
entry->active_connections--;
|
||||
}
|
||||
pthread_mutex_unlock(&g_rate_limit_lock);
|
||||
}
|
||||
|
||||
bool ratelimit_check_and_increment_total(void) {
|
||||
pthread_mutex_lock(&g_conn_count_lock);
|
||||
|
||||
if (g_total_connections >= g_max_connections) {
|
||||
pthread_mutex_unlock(&g_conn_count_lock);
|
||||
return false;
|
||||
}
|
||||
|
||||
g_total_connections++;
|
||||
pthread_mutex_unlock(&g_conn_count_lock);
|
||||
return true;
|
||||
}
|
||||
|
||||
void ratelimit_decrement_total(void) {
|
||||
pthread_mutex_lock(&g_conn_count_lock);
|
||||
if (g_total_connections > 0) {
|
||||
g_total_connections--;
|
||||
}
|
||||
pthread_mutex_unlock(&g_conn_count_lock);
|
||||
}
|
||||
|
||||
int ratelimit_get_active_total(void) {
|
||||
int count;
|
||||
pthread_mutex_lock(&g_conn_count_lock);
|
||||
count = g_total_connections;
|
||||
pthread_mutex_unlock(&g_conn_count_lock);
|
||||
return count;
|
||||
}
|
||||
269
src/ssh_server.c
269
src/ssh_server.c
|
|
@ -1,4 +1,5 @@
|
|||
#include "ssh_server.h"
|
||||
#include "ratelimit.h"
|
||||
#include "tui.h"
|
||||
#include "utf8.h"
|
||||
#include <libssh/libssh.h>
|
||||
|
|
@ -38,33 +39,11 @@ typedef struct {
|
|||
char client_ip[INET6_ADDRSTRLEN];
|
||||
} accepted_session_t;
|
||||
|
||||
/* Rate limiting and connection tracking */
|
||||
#define MAX_TRACKED_IPS 256
|
||||
#define RATE_LIMIT_WINDOW 60 /* seconds */
|
||||
#define MAX_AUTH_FAILURES 5 /* auth failures before block */
|
||||
#define BLOCK_DURATION 300 /* seconds to block after too many failures */
|
||||
|
||||
typedef struct {
|
||||
char ip[INET6_ADDRSTRLEN];
|
||||
time_t window_start;
|
||||
int recent_connection_count;
|
||||
int active_connections;
|
||||
int auth_failure_count;
|
||||
bool is_blocked;
|
||||
time_t block_until;
|
||||
} ip_rate_limit_t;
|
||||
|
||||
static ip_rate_limit_t g_rate_limits[MAX_TRACKED_IPS];
|
||||
static pthread_mutex_t g_rate_limit_lock = PTHREAD_MUTEX_INITIALIZER;
|
||||
static int g_total_connections = 0;
|
||||
static pthread_mutex_t g_conn_count_lock = PTHREAD_MUTEX_INITIALIZER;
|
||||
static time_t g_server_start_time = 0;
|
||||
|
||||
/* Configuration from environment variables */
|
||||
static int g_max_connections = 64;
|
||||
static int g_max_conn_per_ip = 5;
|
||||
static int g_max_conn_rate_per_ip = 10;
|
||||
static int g_rate_limit_enabled = 1;
|
||||
/* Configuration from environment variables. Rate-limiting / connection-count
|
||||
* config has moved to ratelimit.{c,h}; the two below stay here until the auth
|
||||
* and input modules are extracted in later PR2 steps. */
|
||||
static char g_access_token[256] = "";
|
||||
static int g_idle_timeout = DEFAULT_IDLE_TIMEOUT;
|
||||
|
||||
|
|
@ -88,200 +67,6 @@ static bool constant_time_strcmp(const char *a, const char *b) {
|
|||
return length_diff == 0 && byte_diff == 0;
|
||||
}
|
||||
|
||||
/* Safe integer parse from environment variable; returns fallback on error. */
|
||||
static int env_int(const char *name, int fallback, int min_val, int max_val) {
|
||||
const char *env = getenv(name);
|
||||
if (!env || env[0] == '\0') return fallback;
|
||||
char *end;
|
||||
long val = strtol(env, &end, 10);
|
||||
if (*end != '\0' || val < min_val || val > max_val) return fallback;
|
||||
return (int)val;
|
||||
}
|
||||
|
||||
/* Initialize rate limit configuration from environment */
|
||||
static void init_rate_limit_config(void) {
|
||||
const char *env;
|
||||
|
||||
g_max_connections = env_int("TNT_MAX_CONNECTIONS", 64, 1, 1024);
|
||||
g_max_conn_per_ip = env_int("TNT_MAX_CONN_PER_IP", 5, 1, 1024);
|
||||
g_max_conn_rate_per_ip = env_int("TNT_MAX_CONN_RATE_PER_IP", 10, 1, 1024);
|
||||
g_rate_limit_enabled = env_int("TNT_RATE_LIMIT", 1, 0, 1);
|
||||
|
||||
g_idle_timeout = env_int("TNT_IDLE_TIMEOUT", DEFAULT_IDLE_TIMEOUT, 0, 86400);
|
||||
|
||||
if ((env = getenv("TNT_ACCESS_TOKEN")) != NULL) {
|
||||
strncpy(g_access_token, env, sizeof(g_access_token) - 1);
|
||||
g_access_token[sizeof(g_access_token) - 1] = '\0';
|
||||
}
|
||||
}
|
||||
|
||||
/* Get or create rate limit entry for an IP */
|
||||
static ip_rate_limit_t* get_rate_limit_entry(const char *ip) {
|
||||
/* Look for existing entry */
|
||||
for (int i = 0; i < MAX_TRACKED_IPS; i++) {
|
||||
if (strcmp(g_rate_limits[i].ip, ip) == 0) {
|
||||
return &g_rate_limits[i];
|
||||
}
|
||||
}
|
||||
|
||||
/* Find empty slot */
|
||||
for (int i = 0; i < MAX_TRACKED_IPS; i++) {
|
||||
if (g_rate_limits[i].ip[0] == '\0') {
|
||||
strncpy(g_rate_limits[i].ip, ip, sizeof(g_rate_limits[i].ip) - 1);
|
||||
g_rate_limits[i].window_start = time(NULL);
|
||||
g_rate_limits[i].recent_connection_count = 0;
|
||||
g_rate_limits[i].active_connections = 0;
|
||||
g_rate_limits[i].auth_failure_count = 0;
|
||||
g_rate_limits[i].is_blocked = false;
|
||||
g_rate_limits[i].block_until = 0;
|
||||
return &g_rate_limits[i];
|
||||
}
|
||||
}
|
||||
|
||||
/* Reuse the oldest inactive entry first so active IP accounting stays intact. */
|
||||
int oldest_idx = -1;
|
||||
time_t oldest_time = 0;
|
||||
for (int i = 0; i < MAX_TRACKED_IPS; i++) {
|
||||
if (g_rate_limits[i].active_connections != 0) {
|
||||
continue;
|
||||
}
|
||||
if (oldest_idx < 0 || g_rate_limits[i].window_start < oldest_time) {
|
||||
oldest_time = g_rate_limits[i].window_start;
|
||||
oldest_idx = i;
|
||||
}
|
||||
}
|
||||
|
||||
if (oldest_idx < 0) {
|
||||
/* All slots have active connections — evicting will corrupt their
|
||||
* concurrency accounting. Pick the oldest entry but warn. */
|
||||
oldest_idx = 0;
|
||||
oldest_time = g_rate_limits[0].window_start;
|
||||
for (int i = 1; i < MAX_TRACKED_IPS; i++) {
|
||||
if (g_rate_limits[i].window_start < oldest_time) {
|
||||
oldest_time = g_rate_limits[i].window_start;
|
||||
oldest_idx = i;
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "Warning: rate-limit table full, evicting active IP %s "
|
||||
"(%d active connections lost)\n",
|
||||
g_rate_limits[oldest_idx].ip,
|
||||
g_rate_limits[oldest_idx].active_connections);
|
||||
}
|
||||
|
||||
/* Reset and reuse */
|
||||
strncpy(g_rate_limits[oldest_idx].ip, ip, sizeof(g_rate_limits[oldest_idx].ip) - 1);
|
||||
g_rate_limits[oldest_idx].ip[sizeof(g_rate_limits[oldest_idx].ip) - 1] = '\0';
|
||||
g_rate_limits[oldest_idx].window_start = time(NULL);
|
||||
g_rate_limits[oldest_idx].recent_connection_count = 0;
|
||||
g_rate_limits[oldest_idx].active_connections = 0;
|
||||
g_rate_limits[oldest_idx].auth_failure_count = 0;
|
||||
g_rate_limits[oldest_idx].is_blocked = false;
|
||||
g_rate_limits[oldest_idx].block_until = 0;
|
||||
return &g_rate_limits[oldest_idx];
|
||||
}
|
||||
|
||||
/* Check rate and concurrency limits for an IP */
|
||||
static bool check_ip_connection_policy(const char *ip) {
|
||||
time_t now = time(NULL);
|
||||
|
||||
pthread_mutex_lock(&g_rate_limit_lock);
|
||||
ip_rate_limit_t *entry = get_rate_limit_entry(ip);
|
||||
|
||||
if (entry->active_connections >= g_max_conn_per_ip) {
|
||||
pthread_mutex_unlock(&g_rate_limit_lock);
|
||||
fprintf(stderr, "Concurrent IP limit reached for %s\n", ip);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (g_rate_limit_enabled && entry->is_blocked && now < entry->block_until) {
|
||||
pthread_mutex_unlock(&g_rate_limit_lock);
|
||||
fprintf(stderr, "Blocked IP %s (blocked until %ld)\n", ip, (long)entry->block_until);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (g_rate_limit_enabled && entry->is_blocked && now >= entry->block_until) {
|
||||
entry->is_blocked = false;
|
||||
entry->auth_failure_count = 0;
|
||||
}
|
||||
|
||||
if (g_rate_limit_enabled) {
|
||||
if (now - entry->window_start >= RATE_LIMIT_WINDOW) {
|
||||
entry->window_start = now;
|
||||
entry->recent_connection_count = 0;
|
||||
}
|
||||
|
||||
entry->recent_connection_count++;
|
||||
if (entry->recent_connection_count >= g_max_conn_rate_per_ip) {
|
||||
entry->is_blocked = true;
|
||||
entry->block_until = now + BLOCK_DURATION;
|
||||
pthread_mutex_unlock(&g_rate_limit_lock);
|
||||
fprintf(stderr, "Rate limit exceeded for IP %s\n", ip);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
entry->active_connections++;
|
||||
pthread_mutex_unlock(&g_rate_limit_lock);
|
||||
return true;
|
||||
}
|
||||
|
||||
/* Record authentication failure */
|
||||
static void record_auth_failure(const char *ip) {
|
||||
time_t now = time(NULL);
|
||||
|
||||
if (!g_rate_limit_enabled) {
|
||||
return;
|
||||
}
|
||||
|
||||
pthread_mutex_lock(&g_rate_limit_lock);
|
||||
ip_rate_limit_t *entry = get_rate_limit_entry(ip);
|
||||
|
||||
entry->auth_failure_count++;
|
||||
if (entry->auth_failure_count >= MAX_AUTH_FAILURES) {
|
||||
entry->is_blocked = true;
|
||||
entry->block_until = now + BLOCK_DURATION;
|
||||
fprintf(stderr, "IP %s blocked due to %d auth failures\n", ip, entry->auth_failure_count);
|
||||
}
|
||||
|
||||
pthread_mutex_unlock(&g_rate_limit_lock);
|
||||
}
|
||||
|
||||
static void release_ip_connection(const char *ip) {
|
||||
if (!ip || ip[0] == '\0') {
|
||||
return;
|
||||
}
|
||||
|
||||
pthread_mutex_lock(&g_rate_limit_lock);
|
||||
ip_rate_limit_t *entry = get_rate_limit_entry(ip);
|
||||
if (entry->active_connections > 0) {
|
||||
entry->active_connections--;
|
||||
}
|
||||
pthread_mutex_unlock(&g_rate_limit_lock);
|
||||
}
|
||||
|
||||
/* Check and increment total connection count */
|
||||
static bool check_and_increment_connections(void) {
|
||||
pthread_mutex_lock(&g_conn_count_lock);
|
||||
|
||||
if (g_total_connections >= g_max_connections) {
|
||||
pthread_mutex_unlock(&g_conn_count_lock);
|
||||
return false;
|
||||
}
|
||||
|
||||
g_total_connections++;
|
||||
pthread_mutex_unlock(&g_conn_count_lock);
|
||||
return true;
|
||||
}
|
||||
|
||||
/* Decrement connection count */
|
||||
static void decrement_connections(void) {
|
||||
pthread_mutex_lock(&g_conn_count_lock);
|
||||
if (g_total_connections > 0) {
|
||||
g_total_connections--;
|
||||
}
|
||||
pthread_mutex_unlock(&g_conn_count_lock);
|
||||
}
|
||||
|
||||
/* Get client IP address */
|
||||
static void get_client_ip(ssh_session session, char *ip_buf, size_t buf_size) {
|
||||
int fd = ssh_get_fd(session);
|
||||
|
|
@ -815,9 +600,7 @@ static int exec_command_stats(client_t *client, bool json) {
|
|||
client_capacity = g_room->client_capacity;
|
||||
pthread_rwlock_unlock(&g_room->lock);
|
||||
|
||||
pthread_mutex_lock(&g_conn_count_lock);
|
||||
active_connections = g_total_connections;
|
||||
pthread_mutex_unlock(&g_conn_count_lock);
|
||||
active_connections = ratelimit_get_active_total();
|
||||
|
||||
uptime_seconds = (g_server_start_time > 0 && now >= g_server_start_time)
|
||||
? (long)(now - g_server_start_time)
|
||||
|
|
@ -1810,7 +1593,7 @@ cleanup:
|
|||
message_save(&leave_msg);
|
||||
}
|
||||
|
||||
release_ip_connection(client->client_ip);
|
||||
ratelimit_release_ip(client->client_ip);
|
||||
|
||||
/* Remove channel callbacks before releasing refs to prevent use-after-free
|
||||
* if a callback fires between the two releases. */
|
||||
|
|
@ -1825,7 +1608,7 @@ cleanup:
|
|||
client_release(client);
|
||||
|
||||
/* Decrement connection count */
|
||||
decrement_connections();
|
||||
ratelimit_decrement_total();
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
|
@ -1846,7 +1629,7 @@ static int auth_password(ssh_session session, const char *user,
|
|||
|
||||
/* Limit auth attempts */
|
||||
if (ctx->auth_attempts > 3) {
|
||||
record_auth_failure(ctx->client_ip);
|
||||
ratelimit_record_auth_failure(ctx->client_ip);
|
||||
fprintf(stderr, "Too many auth attempts from %s\n", ctx->client_ip);
|
||||
ssh_disconnect(session);
|
||||
return SSH_AUTH_DENIED;
|
||||
|
|
@ -1861,7 +1644,7 @@ static int auth_password(ssh_session session, const char *user,
|
|||
} else {
|
||||
/* Wrong token — IP blocking handles brute force, no sleep needed here
|
||||
* (sleeping in a libssh callback blocks the entire accept loop). */
|
||||
record_auth_failure(ctx->client_ip);
|
||||
ratelimit_record_auth_failure(ctx->client_ip);
|
||||
return SSH_AUTH_DENIED;
|
||||
}
|
||||
} else {
|
||||
|
|
@ -1948,10 +1731,10 @@ static void cleanup_failed_session(ssh_session session, session_context_t *ctx)
|
|||
}
|
||||
|
||||
if (ctx) {
|
||||
release_ip_connection(ctx->client_ip);
|
||||
ratelimit_release_ip(ctx->client_ip);
|
||||
}
|
||||
destroy_session_context(ctx);
|
||||
decrement_connections();
|
||||
ratelimit_decrement_total();
|
||||
}
|
||||
|
||||
static void setup_session_channel_callbacks(ssh_channel channel,
|
||||
|
|
@ -2168,10 +1951,10 @@ static void *bootstrap_client_session(void *arg) {
|
|||
|
||||
ctx = calloc(1, sizeof(session_context_t));
|
||||
if (!ctx) {
|
||||
release_ip_connection(accepted_ip);
|
||||
ratelimit_release_ip(accepted_ip);
|
||||
ssh_disconnect(session);
|
||||
ssh_free(session);
|
||||
decrement_connections();
|
||||
ratelimit_decrement_total();
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
|
@ -2315,8 +2098,16 @@ static void *bootstrap_client_session(void *arg) {
|
|||
|
||||
/* Initialize SSH server */
|
||||
int ssh_server_init(int port) {
|
||||
/* Initialize rate limiting configuration */
|
||||
init_rate_limit_config();
|
||||
/* Initialize rate-limit / connection-count subsystem */
|
||||
ratelimit_init();
|
||||
|
||||
/* Auth / session config (will move into auth.c and input.c in later PR2 steps) */
|
||||
g_idle_timeout = env_int("TNT_IDLE_TIMEOUT", DEFAULT_IDLE_TIMEOUT, 0, 86400);
|
||||
const char *token_env = getenv("TNT_ACCESS_TOKEN");
|
||||
if (token_env != NULL) {
|
||||
strncpy(g_access_token, token_env, sizeof(g_access_token) - 1);
|
||||
g_access_token[sizeof(g_access_token) - 1] = '\0';
|
||||
}
|
||||
g_listen_port = port;
|
||||
g_server_start_time = time(NULL);
|
||||
|
||||
|
|
@ -2392,15 +2183,15 @@ int ssh_server_start(int unused) {
|
|||
get_client_ip(session, client_ip, sizeof(client_ip));
|
||||
|
||||
/* Check total connection limit */
|
||||
if (!check_and_increment_connections()) {
|
||||
if (!ratelimit_check_and_increment_total()) {
|
||||
fprintf(stderr, "Max connections reached, rejecting %s\n", client_ip);
|
||||
ssh_disconnect(session);
|
||||
ssh_free(session);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!check_ip_connection_policy(client_ip)) {
|
||||
decrement_connections();
|
||||
if (!ratelimit_check_ip(client_ip)) {
|
||||
ratelimit_decrement_total();
|
||||
ssh_disconnect(session);
|
||||
ssh_free(session);
|
||||
continue;
|
||||
|
|
@ -2408,8 +2199,8 @@ int ssh_server_start(int unused) {
|
|||
|
||||
accepted = calloc(1, sizeof(*accepted));
|
||||
if (!accepted) {
|
||||
release_ip_connection(client_ip);
|
||||
decrement_connections();
|
||||
ratelimit_release_ip(client_ip);
|
||||
ratelimit_decrement_total();
|
||||
ssh_disconnect(session);
|
||||
ssh_free(session);
|
||||
continue;
|
||||
|
|
@ -2422,8 +2213,8 @@ int ssh_server_start(int unused) {
|
|||
if (pthread_create(&thread, &attr, bootstrap_client_session, accepted) != 0) {
|
||||
fprintf(stderr, "Thread creation failed: %s\n", strerror(errno));
|
||||
free(accepted);
|
||||
release_ip_connection(client_ip);
|
||||
decrement_connections();
|
||||
ratelimit_release_ip(client_ip);
|
||||
ratelimit_decrement_total();
|
||||
ssh_disconnect(session);
|
||||
ssh_free(session);
|
||||
continue;
|
||||
|
|
|
|||
Loading…
Reference in a new issue