refactor: extract ratelimit module (PR2-M1)

Move IP rate-limiting, auth-failure tracking, and global connection
counting out of ssh_server.c into a dedicated module.

New API (include/ratelimit.h):
- ratelimit_init()
- ratelimit_check_ip() / ratelimit_release_ip()
- ratelimit_record_auth_failure()
- ratelimit_check_and_increment_total() / ratelimit_decrement_total()
- ratelimit_get_active_total()  (replaces the direct g_total_connections
  read that exec_command_stats was doing under g_conn_count_lock)

env_int() also moves up to common.{c,h} since multiple modules need it.

ssh_server.c drops from 2469 to 2200 lines.  Behaviour is preserved:
the new functions are byte-for-byte the same implementations, only the
file boundary moved.

g_idle_timeout and g_access_token reads stay inline in ssh_server_init()
for now; they will follow the auth.c and input.c extractions later.
This commit is contained in:
m1ngsama 2026-05-16 23:06:56 +08:00
parent d9382882d1
commit 562ee5296d
5 changed files with 282 additions and 239 deletions

View file

@ -62,4 +62,9 @@ void buffer_append_bytes(char *buffer, size_t buf_size, size_t *pos,
void buffer_appendf(char *buffer, size_t buf_size, size_t *pos, void buffer_appendf(char *buffer, size_t buf_size, size_t *pos,
const char *fmt, ...); const char *fmt, ...);
/* Parse an integer from `getenv(name)`, clamping accepted values to
* [min_val, max_val]. Returns `fallback` when the variable is unset, empty,
* non-numeric, or out of range. */
int env_int(const char *name, int fallback, int min_val, int max_val);
#endif /* COMMON_H */ #endif /* COMMON_H */

27
include/ratelimit.h Normal file
View file

@ -0,0 +1,27 @@
#ifndef RATELIMIT_H
#define RATELIMIT_H
#include <stdbool.h>
/* Read TNT_MAX_CONNECTIONS / TNT_MAX_CONN_PER_IP / TNT_MAX_CONN_RATE_PER_IP /
* TNT_RATE_LIMIT from the environment. Idempotent, call once at startup. */
void ratelimit_init(void);
/* Per-IP entry point: returns false if the IP has hit any limit (concurrent,
* rate, or block). On success, increments the IP's active counter caller
* MUST pair with ratelimit_release_ip() when the connection ends. */
bool ratelimit_check_ip(const char *ip);
void ratelimit_release_ip(const char *ip);
/* Auth-failure ledger. After enough failures within the window the IP is
* blocked for a fixed duration. */
void ratelimit_record_auth_failure(const char *ip);
/* Global active-connection cap (separate from per-IP). Pair them. */
bool ratelimit_check_and_increment_total(void);
void ratelimit_decrement_total(void);
/* Read-only accessor for stats subcommand. */
int ratelimit_get_active_total(void);
#endif /* RATELIMIT_H */

View file

@ -126,3 +126,12 @@ void buffer_appendf(char *buffer, size_t buf_size, size_t *pos,
*pos += (size_t)written; *pos += (size_t)written;
} }
} }
int env_int(const char *name, int fallback, int min_val, int max_val) {
const char *env = getenv(name);
if (!env || env[0] == '\0') return fallback;
char *end;
long val = strtol(env, &end, 10);
if (*end != '\0' || val < min_val || val > max_val) return fallback;
return (int)val;
}

211
src/ratelimit.c Normal file
View file

@ -0,0 +1,211 @@
#include "ratelimit.h"
#include "common.h"
#include <arpa/inet.h>
#include <pthread.h>
#include <stdbool.h>
#include <stdio.h>
#include <string.h>
#include <time.h>
#define MAX_TRACKED_IPS 256
#define RATE_LIMIT_WINDOW 60 /* seconds */
#define MAX_AUTH_FAILURES 5 /* auth failures before block */
#define BLOCK_DURATION 300 /* seconds to block after too many failures */
typedef struct {
char ip[INET6_ADDRSTRLEN];
time_t window_start;
int recent_connection_count;
int active_connections;
int auth_failure_count;
bool is_blocked;
time_t block_until;
} ip_rate_limit_t;
static ip_rate_limit_t g_rate_limits[MAX_TRACKED_IPS];
static pthread_mutex_t g_rate_limit_lock = PTHREAD_MUTEX_INITIALIZER;
static int g_total_connections = 0;
static pthread_mutex_t g_conn_count_lock = PTHREAD_MUTEX_INITIALIZER;
static int g_max_connections = 64;
static int g_max_conn_per_ip = 5;
static int g_max_conn_rate_per_ip = 10;
static int g_rate_limit_enabled = 1;
void ratelimit_init(void) {
g_max_connections = env_int("TNT_MAX_CONNECTIONS", 64, 1, 1024);
g_max_conn_per_ip = env_int("TNT_MAX_CONN_PER_IP", 5, 1, 1024);
g_max_conn_rate_per_ip = env_int("TNT_MAX_CONN_RATE_PER_IP", 10, 1, 1024);
g_rate_limit_enabled = env_int("TNT_RATE_LIMIT", 1, 0, 1);
}
/* Caller MUST hold g_rate_limit_lock. */
static ip_rate_limit_t* get_rate_limit_entry(const char *ip) {
/* Look for existing entry */
for (int i = 0; i < MAX_TRACKED_IPS; i++) {
if (strcmp(g_rate_limits[i].ip, ip) == 0) {
return &g_rate_limits[i];
}
}
/* Find empty slot */
for (int i = 0; i < MAX_TRACKED_IPS; i++) {
if (g_rate_limits[i].ip[0] == '\0') {
strncpy(g_rate_limits[i].ip, ip, sizeof(g_rate_limits[i].ip) - 1);
g_rate_limits[i].window_start = time(NULL);
g_rate_limits[i].recent_connection_count = 0;
g_rate_limits[i].active_connections = 0;
g_rate_limits[i].auth_failure_count = 0;
g_rate_limits[i].is_blocked = false;
g_rate_limits[i].block_until = 0;
return &g_rate_limits[i];
}
}
/* Reuse the oldest inactive entry first so active IP accounting stays intact. */
int oldest_idx = -1;
time_t oldest_time = 0;
for (int i = 0; i < MAX_TRACKED_IPS; i++) {
if (g_rate_limits[i].active_connections != 0) {
continue;
}
if (oldest_idx < 0 || g_rate_limits[i].window_start < oldest_time) {
oldest_time = g_rate_limits[i].window_start;
oldest_idx = i;
}
}
if (oldest_idx < 0) {
/* All slots have active connections — evicting will corrupt their
* concurrency accounting. Pick the oldest entry but warn. */
oldest_idx = 0;
oldest_time = g_rate_limits[0].window_start;
for (int i = 1; i < MAX_TRACKED_IPS; i++) {
if (g_rate_limits[i].window_start < oldest_time) {
oldest_time = g_rate_limits[i].window_start;
oldest_idx = i;
}
}
fprintf(stderr, "Warning: rate-limit table full, evicting active IP %s "
"(%d active connections lost)\n",
g_rate_limits[oldest_idx].ip,
g_rate_limits[oldest_idx].active_connections);
}
/* Reset and reuse */
strncpy(g_rate_limits[oldest_idx].ip, ip, sizeof(g_rate_limits[oldest_idx].ip) - 1);
g_rate_limits[oldest_idx].ip[sizeof(g_rate_limits[oldest_idx].ip) - 1] = '\0';
g_rate_limits[oldest_idx].window_start = time(NULL);
g_rate_limits[oldest_idx].recent_connection_count = 0;
g_rate_limits[oldest_idx].active_connections = 0;
g_rate_limits[oldest_idx].auth_failure_count = 0;
g_rate_limits[oldest_idx].is_blocked = false;
g_rate_limits[oldest_idx].block_until = 0;
return &g_rate_limits[oldest_idx];
}
bool ratelimit_check_ip(const char *ip) {
time_t now = time(NULL);
pthread_mutex_lock(&g_rate_limit_lock);
ip_rate_limit_t *entry = get_rate_limit_entry(ip);
if (entry->active_connections >= g_max_conn_per_ip) {
pthread_mutex_unlock(&g_rate_limit_lock);
fprintf(stderr, "Concurrent IP limit reached for %s\n", ip);
return false;
}
if (g_rate_limit_enabled && entry->is_blocked && now < entry->block_until) {
pthread_mutex_unlock(&g_rate_limit_lock);
fprintf(stderr, "Blocked IP %s (blocked until %ld)\n", ip, (long)entry->block_until);
return false;
}
if (g_rate_limit_enabled && entry->is_blocked && now >= entry->block_until) {
entry->is_blocked = false;
entry->auth_failure_count = 0;
}
if (g_rate_limit_enabled) {
if (now - entry->window_start >= RATE_LIMIT_WINDOW) {
entry->window_start = now;
entry->recent_connection_count = 0;
}
entry->recent_connection_count++;
if (entry->recent_connection_count >= g_max_conn_rate_per_ip) {
entry->is_blocked = true;
entry->block_until = now + BLOCK_DURATION;
pthread_mutex_unlock(&g_rate_limit_lock);
fprintf(stderr, "Rate limit exceeded for IP %s\n", ip);
return false;
}
}
entry->active_connections++;
pthread_mutex_unlock(&g_rate_limit_lock);
return true;
}
void ratelimit_record_auth_failure(const char *ip) {
time_t now = time(NULL);
if (!g_rate_limit_enabled) {
return;
}
pthread_mutex_lock(&g_rate_limit_lock);
ip_rate_limit_t *entry = get_rate_limit_entry(ip);
entry->auth_failure_count++;
if (entry->auth_failure_count >= MAX_AUTH_FAILURES) {
entry->is_blocked = true;
entry->block_until = now + BLOCK_DURATION;
fprintf(stderr, "IP %s blocked due to %d auth failures\n", ip, entry->auth_failure_count);
}
pthread_mutex_unlock(&g_rate_limit_lock);
}
void ratelimit_release_ip(const char *ip) {
if (!ip || ip[0] == '\0') {
return;
}
pthread_mutex_lock(&g_rate_limit_lock);
ip_rate_limit_t *entry = get_rate_limit_entry(ip);
if (entry->active_connections > 0) {
entry->active_connections--;
}
pthread_mutex_unlock(&g_rate_limit_lock);
}
bool ratelimit_check_and_increment_total(void) {
pthread_mutex_lock(&g_conn_count_lock);
if (g_total_connections >= g_max_connections) {
pthread_mutex_unlock(&g_conn_count_lock);
return false;
}
g_total_connections++;
pthread_mutex_unlock(&g_conn_count_lock);
return true;
}
void ratelimit_decrement_total(void) {
pthread_mutex_lock(&g_conn_count_lock);
if (g_total_connections > 0) {
g_total_connections--;
}
pthread_mutex_unlock(&g_conn_count_lock);
}
int ratelimit_get_active_total(void) {
int count;
pthread_mutex_lock(&g_conn_count_lock);
count = g_total_connections;
pthread_mutex_unlock(&g_conn_count_lock);
return count;
}

View file

@ -1,4 +1,5 @@
#include "ssh_server.h" #include "ssh_server.h"
#include "ratelimit.h"
#include "tui.h" #include "tui.h"
#include "utf8.h" #include "utf8.h"
#include <libssh/libssh.h> #include <libssh/libssh.h>
@ -38,33 +39,11 @@ typedef struct {
char client_ip[INET6_ADDRSTRLEN]; char client_ip[INET6_ADDRSTRLEN];
} accepted_session_t; } accepted_session_t;
/* Rate limiting and connection tracking */
#define MAX_TRACKED_IPS 256
#define RATE_LIMIT_WINDOW 60 /* seconds */
#define MAX_AUTH_FAILURES 5 /* auth failures before block */
#define BLOCK_DURATION 300 /* seconds to block after too many failures */
typedef struct {
char ip[INET6_ADDRSTRLEN];
time_t window_start;
int recent_connection_count;
int active_connections;
int auth_failure_count;
bool is_blocked;
time_t block_until;
} ip_rate_limit_t;
static ip_rate_limit_t g_rate_limits[MAX_TRACKED_IPS];
static pthread_mutex_t g_rate_limit_lock = PTHREAD_MUTEX_INITIALIZER;
static int g_total_connections = 0;
static pthread_mutex_t g_conn_count_lock = PTHREAD_MUTEX_INITIALIZER;
static time_t g_server_start_time = 0; static time_t g_server_start_time = 0;
/* Configuration from environment variables */ /* Configuration from environment variables. Rate-limiting / connection-count
static int g_max_connections = 64; * config has moved to ratelimit.{c,h}; the two below stay here until the auth
static int g_max_conn_per_ip = 5; * and input modules are extracted in later PR2 steps. */
static int g_max_conn_rate_per_ip = 10;
static int g_rate_limit_enabled = 1;
static char g_access_token[256] = ""; static char g_access_token[256] = "";
static int g_idle_timeout = DEFAULT_IDLE_TIMEOUT; static int g_idle_timeout = DEFAULT_IDLE_TIMEOUT;
@ -88,200 +67,6 @@ static bool constant_time_strcmp(const char *a, const char *b) {
return length_diff == 0 && byte_diff == 0; return length_diff == 0 && byte_diff == 0;
} }
/* Safe integer parse from environment variable; returns fallback on error. */
static int env_int(const char *name, int fallback, int min_val, int max_val) {
const char *env = getenv(name);
if (!env || env[0] == '\0') return fallback;
char *end;
long val = strtol(env, &end, 10);
if (*end != '\0' || val < min_val || val > max_val) return fallback;
return (int)val;
}
/* Initialize rate limit configuration from environment */
static void init_rate_limit_config(void) {
const char *env;
g_max_connections = env_int("TNT_MAX_CONNECTIONS", 64, 1, 1024);
g_max_conn_per_ip = env_int("TNT_MAX_CONN_PER_IP", 5, 1, 1024);
g_max_conn_rate_per_ip = env_int("TNT_MAX_CONN_RATE_PER_IP", 10, 1, 1024);
g_rate_limit_enabled = env_int("TNT_RATE_LIMIT", 1, 0, 1);
g_idle_timeout = env_int("TNT_IDLE_TIMEOUT", DEFAULT_IDLE_TIMEOUT, 0, 86400);
if ((env = getenv("TNT_ACCESS_TOKEN")) != NULL) {
strncpy(g_access_token, env, sizeof(g_access_token) - 1);
g_access_token[sizeof(g_access_token) - 1] = '\0';
}
}
/* Get or create rate limit entry for an IP */
static ip_rate_limit_t* get_rate_limit_entry(const char *ip) {
/* Look for existing entry */
for (int i = 0; i < MAX_TRACKED_IPS; i++) {
if (strcmp(g_rate_limits[i].ip, ip) == 0) {
return &g_rate_limits[i];
}
}
/* Find empty slot */
for (int i = 0; i < MAX_TRACKED_IPS; i++) {
if (g_rate_limits[i].ip[0] == '\0') {
strncpy(g_rate_limits[i].ip, ip, sizeof(g_rate_limits[i].ip) - 1);
g_rate_limits[i].window_start = time(NULL);
g_rate_limits[i].recent_connection_count = 0;
g_rate_limits[i].active_connections = 0;
g_rate_limits[i].auth_failure_count = 0;
g_rate_limits[i].is_blocked = false;
g_rate_limits[i].block_until = 0;
return &g_rate_limits[i];
}
}
/* Reuse the oldest inactive entry first so active IP accounting stays intact. */
int oldest_idx = -1;
time_t oldest_time = 0;
for (int i = 0; i < MAX_TRACKED_IPS; i++) {
if (g_rate_limits[i].active_connections != 0) {
continue;
}
if (oldest_idx < 0 || g_rate_limits[i].window_start < oldest_time) {
oldest_time = g_rate_limits[i].window_start;
oldest_idx = i;
}
}
if (oldest_idx < 0) {
/* All slots have active connections — evicting will corrupt their
* concurrency accounting. Pick the oldest entry but warn. */
oldest_idx = 0;
oldest_time = g_rate_limits[0].window_start;
for (int i = 1; i < MAX_TRACKED_IPS; i++) {
if (g_rate_limits[i].window_start < oldest_time) {
oldest_time = g_rate_limits[i].window_start;
oldest_idx = i;
}
}
fprintf(stderr, "Warning: rate-limit table full, evicting active IP %s "
"(%d active connections lost)\n",
g_rate_limits[oldest_idx].ip,
g_rate_limits[oldest_idx].active_connections);
}
/* Reset and reuse */
strncpy(g_rate_limits[oldest_idx].ip, ip, sizeof(g_rate_limits[oldest_idx].ip) - 1);
g_rate_limits[oldest_idx].ip[sizeof(g_rate_limits[oldest_idx].ip) - 1] = '\0';
g_rate_limits[oldest_idx].window_start = time(NULL);
g_rate_limits[oldest_idx].recent_connection_count = 0;
g_rate_limits[oldest_idx].active_connections = 0;
g_rate_limits[oldest_idx].auth_failure_count = 0;
g_rate_limits[oldest_idx].is_blocked = false;
g_rate_limits[oldest_idx].block_until = 0;
return &g_rate_limits[oldest_idx];
}
/* Check rate and concurrency limits for an IP */
static bool check_ip_connection_policy(const char *ip) {
time_t now = time(NULL);
pthread_mutex_lock(&g_rate_limit_lock);
ip_rate_limit_t *entry = get_rate_limit_entry(ip);
if (entry->active_connections >= g_max_conn_per_ip) {
pthread_mutex_unlock(&g_rate_limit_lock);
fprintf(stderr, "Concurrent IP limit reached for %s\n", ip);
return false;
}
if (g_rate_limit_enabled && entry->is_blocked && now < entry->block_until) {
pthread_mutex_unlock(&g_rate_limit_lock);
fprintf(stderr, "Blocked IP %s (blocked until %ld)\n", ip, (long)entry->block_until);
return false;
}
if (g_rate_limit_enabled && entry->is_blocked && now >= entry->block_until) {
entry->is_blocked = false;
entry->auth_failure_count = 0;
}
if (g_rate_limit_enabled) {
if (now - entry->window_start >= RATE_LIMIT_WINDOW) {
entry->window_start = now;
entry->recent_connection_count = 0;
}
entry->recent_connection_count++;
if (entry->recent_connection_count >= g_max_conn_rate_per_ip) {
entry->is_blocked = true;
entry->block_until = now + BLOCK_DURATION;
pthread_mutex_unlock(&g_rate_limit_lock);
fprintf(stderr, "Rate limit exceeded for IP %s\n", ip);
return false;
}
}
entry->active_connections++;
pthread_mutex_unlock(&g_rate_limit_lock);
return true;
}
/* Record authentication failure */
static void record_auth_failure(const char *ip) {
time_t now = time(NULL);
if (!g_rate_limit_enabled) {
return;
}
pthread_mutex_lock(&g_rate_limit_lock);
ip_rate_limit_t *entry = get_rate_limit_entry(ip);
entry->auth_failure_count++;
if (entry->auth_failure_count >= MAX_AUTH_FAILURES) {
entry->is_blocked = true;
entry->block_until = now + BLOCK_DURATION;
fprintf(stderr, "IP %s blocked due to %d auth failures\n", ip, entry->auth_failure_count);
}
pthread_mutex_unlock(&g_rate_limit_lock);
}
static void release_ip_connection(const char *ip) {
if (!ip || ip[0] == '\0') {
return;
}
pthread_mutex_lock(&g_rate_limit_lock);
ip_rate_limit_t *entry = get_rate_limit_entry(ip);
if (entry->active_connections > 0) {
entry->active_connections--;
}
pthread_mutex_unlock(&g_rate_limit_lock);
}
/* Check and increment total connection count */
static bool check_and_increment_connections(void) {
pthread_mutex_lock(&g_conn_count_lock);
if (g_total_connections >= g_max_connections) {
pthread_mutex_unlock(&g_conn_count_lock);
return false;
}
g_total_connections++;
pthread_mutex_unlock(&g_conn_count_lock);
return true;
}
/* Decrement connection count */
static void decrement_connections(void) {
pthread_mutex_lock(&g_conn_count_lock);
if (g_total_connections > 0) {
g_total_connections--;
}
pthread_mutex_unlock(&g_conn_count_lock);
}
/* Get client IP address */ /* Get client IP address */
static void get_client_ip(ssh_session session, char *ip_buf, size_t buf_size) { static void get_client_ip(ssh_session session, char *ip_buf, size_t buf_size) {
int fd = ssh_get_fd(session); int fd = ssh_get_fd(session);
@ -815,9 +600,7 @@ static int exec_command_stats(client_t *client, bool json) {
client_capacity = g_room->client_capacity; client_capacity = g_room->client_capacity;
pthread_rwlock_unlock(&g_room->lock); pthread_rwlock_unlock(&g_room->lock);
pthread_mutex_lock(&g_conn_count_lock); active_connections = ratelimit_get_active_total();
active_connections = g_total_connections;
pthread_mutex_unlock(&g_conn_count_lock);
uptime_seconds = (g_server_start_time > 0 && now >= g_server_start_time) uptime_seconds = (g_server_start_time > 0 && now >= g_server_start_time)
? (long)(now - g_server_start_time) ? (long)(now - g_server_start_time)
@ -1810,7 +1593,7 @@ cleanup:
message_save(&leave_msg); message_save(&leave_msg);
} }
release_ip_connection(client->client_ip); ratelimit_release_ip(client->client_ip);
/* Remove channel callbacks before releasing refs to prevent use-after-free /* Remove channel callbacks before releasing refs to prevent use-after-free
* if a callback fires between the two releases. */ * if a callback fires between the two releases. */
@ -1825,7 +1608,7 @@ cleanup:
client_release(client); client_release(client);
/* Decrement connection count */ /* Decrement connection count */
decrement_connections(); ratelimit_decrement_total();
return NULL; return NULL;
} }
@ -1846,7 +1629,7 @@ static int auth_password(ssh_session session, const char *user,
/* Limit auth attempts */ /* Limit auth attempts */
if (ctx->auth_attempts > 3) { if (ctx->auth_attempts > 3) {
record_auth_failure(ctx->client_ip); ratelimit_record_auth_failure(ctx->client_ip);
fprintf(stderr, "Too many auth attempts from %s\n", ctx->client_ip); fprintf(stderr, "Too many auth attempts from %s\n", ctx->client_ip);
ssh_disconnect(session); ssh_disconnect(session);
return SSH_AUTH_DENIED; return SSH_AUTH_DENIED;
@ -1861,7 +1644,7 @@ static int auth_password(ssh_session session, const char *user,
} else { } else {
/* Wrong token — IP blocking handles brute force, no sleep needed here /* Wrong token — IP blocking handles brute force, no sleep needed here
* (sleeping in a libssh callback blocks the entire accept loop). */ * (sleeping in a libssh callback blocks the entire accept loop). */
record_auth_failure(ctx->client_ip); ratelimit_record_auth_failure(ctx->client_ip);
return SSH_AUTH_DENIED; return SSH_AUTH_DENIED;
} }
} else { } else {
@ -1948,10 +1731,10 @@ static void cleanup_failed_session(ssh_session session, session_context_t *ctx)
} }
if (ctx) { if (ctx) {
release_ip_connection(ctx->client_ip); ratelimit_release_ip(ctx->client_ip);
} }
destroy_session_context(ctx); destroy_session_context(ctx);
decrement_connections(); ratelimit_decrement_total();
} }
static void setup_session_channel_callbacks(ssh_channel channel, static void setup_session_channel_callbacks(ssh_channel channel,
@ -2168,10 +1951,10 @@ static void *bootstrap_client_session(void *arg) {
ctx = calloc(1, sizeof(session_context_t)); ctx = calloc(1, sizeof(session_context_t));
if (!ctx) { if (!ctx) {
release_ip_connection(accepted_ip); ratelimit_release_ip(accepted_ip);
ssh_disconnect(session); ssh_disconnect(session);
ssh_free(session); ssh_free(session);
decrement_connections(); ratelimit_decrement_total();
return NULL; return NULL;
} }
@ -2315,8 +2098,16 @@ static void *bootstrap_client_session(void *arg) {
/* Initialize SSH server */ /* Initialize SSH server */
int ssh_server_init(int port) { int ssh_server_init(int port) {
/* Initialize rate limiting configuration */ /* Initialize rate-limit / connection-count subsystem */
init_rate_limit_config(); ratelimit_init();
/* Auth / session config (will move into auth.c and input.c in later PR2 steps) */
g_idle_timeout = env_int("TNT_IDLE_TIMEOUT", DEFAULT_IDLE_TIMEOUT, 0, 86400);
const char *token_env = getenv("TNT_ACCESS_TOKEN");
if (token_env != NULL) {
strncpy(g_access_token, token_env, sizeof(g_access_token) - 1);
g_access_token[sizeof(g_access_token) - 1] = '\0';
}
g_listen_port = port; g_listen_port = port;
g_server_start_time = time(NULL); g_server_start_time = time(NULL);
@ -2392,15 +2183,15 @@ int ssh_server_start(int unused) {
get_client_ip(session, client_ip, sizeof(client_ip)); get_client_ip(session, client_ip, sizeof(client_ip));
/* Check total connection limit */ /* Check total connection limit */
if (!check_and_increment_connections()) { if (!ratelimit_check_and_increment_total()) {
fprintf(stderr, "Max connections reached, rejecting %s\n", client_ip); fprintf(stderr, "Max connections reached, rejecting %s\n", client_ip);
ssh_disconnect(session); ssh_disconnect(session);
ssh_free(session); ssh_free(session);
continue; continue;
} }
if (!check_ip_connection_policy(client_ip)) { if (!ratelimit_check_ip(client_ip)) {
decrement_connections(); ratelimit_decrement_total();
ssh_disconnect(session); ssh_disconnect(session);
ssh_free(session); ssh_free(session);
continue; continue;
@ -2408,8 +2199,8 @@ int ssh_server_start(int unused) {
accepted = calloc(1, sizeof(*accepted)); accepted = calloc(1, sizeof(*accepted));
if (!accepted) { if (!accepted) {
release_ip_connection(client_ip); ratelimit_release_ip(client_ip);
decrement_connections(); ratelimit_decrement_total();
ssh_disconnect(session); ssh_disconnect(session);
ssh_free(session); ssh_free(session);
continue; continue;
@ -2422,8 +2213,8 @@ int ssh_server_start(int unused) {
if (pthread_create(&thread, &attr, bootstrap_client_session, accepted) != 0) { if (pthread_create(&thread, &attr, bootstrap_client_session, accepted) != 0) {
fprintf(stderr, "Thread creation failed: %s\n", strerror(errno)); fprintf(stderr, "Thread creation failed: %s\n", strerror(errno));
free(accepted); free(accepted);
release_ip_connection(client_ip); ratelimit_release_ip(client_ip);
decrement_connections(); ratelimit_decrement_total();
ssh_disconnect(session); ssh_disconnect(session);
ssh_free(session); ssh_free(session);
continue; continue;