mirror of
https://github.com/m1ngsama/TNT.git
synced 2026-02-08 00:54:03 +00:00
Merge branch 'fix/auth-protection' into feat/security-audit-fixes
# Conflicts: # src/ssh_server.c
This commit is contained in:
commit
93c29ca2e9
1 changed files with 284 additions and 14 deletions
286
src/ssh_server.c
286
src/ssh_server.c
|
|
@ -17,6 +17,210 @@
|
||||||
/* Global SSH bind instance */
|
/* Global SSH bind instance */
|
||||||
static ssh_bind g_sshbind = NULL;
|
static ssh_bind g_sshbind = NULL;
|
||||||
|
|
||||||
|
/* Rate limiting and connection tracking */
|
||||||
|
#define MAX_TRACKED_IPS 256
|
||||||
|
#define RATE_LIMIT_WINDOW 60 /* seconds */
|
||||||
|
#define MAX_CONN_PER_WINDOW 10 /* connections per IP per window */
|
||||||
|
#define MAX_AUTH_FAILURES 5 /* auth failures before block */
|
||||||
|
#define BLOCK_DURATION 300 /* seconds to block after too many failures */
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
char ip[INET6_ADDRSTRLEN];
|
||||||
|
time_t window_start;
|
||||||
|
int connection_count;
|
||||||
|
int auth_failure_count;
|
||||||
|
bool is_blocked;
|
||||||
|
time_t block_until;
|
||||||
|
} ip_rate_limit_t;
|
||||||
|
|
||||||
|
static ip_rate_limit_t g_rate_limits[MAX_TRACKED_IPS];
|
||||||
|
static pthread_mutex_t g_rate_limit_lock = PTHREAD_MUTEX_INITIALIZER;
|
||||||
|
static int g_total_connections = 0;
|
||||||
|
static pthread_mutex_t g_conn_count_lock = PTHREAD_MUTEX_INITIALIZER;
|
||||||
|
|
||||||
|
/* Configuration from environment variables */
|
||||||
|
static int g_max_connections = 64;
|
||||||
|
static int g_max_conn_per_ip = 5;
|
||||||
|
static int g_rate_limit_enabled = 1;
|
||||||
|
static char g_access_token[256] = "";
|
||||||
|
|
||||||
|
/* Initialize rate limit configuration from environment */
|
||||||
|
static void init_rate_limit_config(void) {
|
||||||
|
const char *env;
|
||||||
|
|
||||||
|
if ((env = getenv("TNT_MAX_CONNECTIONS")) != NULL) {
|
||||||
|
int val = atoi(env);
|
||||||
|
if (val > 0 && val <= 1024) {
|
||||||
|
g_max_connections = val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((env = getenv("TNT_MAX_CONN_PER_IP")) != NULL) {
|
||||||
|
int val = atoi(env);
|
||||||
|
if (val > 0 && val <= 100) {
|
||||||
|
g_max_conn_per_ip = val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((env = getenv("TNT_RATE_LIMIT")) != NULL) {
|
||||||
|
g_rate_limit_enabled = atoi(env);
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((env = getenv("TNT_ACCESS_TOKEN")) != NULL) {
|
||||||
|
strncpy(g_access_token, env, sizeof(g_access_token) - 1);
|
||||||
|
g_access_token[sizeof(g_access_token) - 1] = '\0';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Get or create rate limit entry for an IP */
|
||||||
|
static ip_rate_limit_t* get_rate_limit_entry(const char *ip) {
|
||||||
|
/* Look for existing entry */
|
||||||
|
for (int i = 0; i < MAX_TRACKED_IPS; i++) {
|
||||||
|
if (strcmp(g_rate_limits[i].ip, ip) == 0) {
|
||||||
|
return &g_rate_limits[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Find empty slot */
|
||||||
|
for (int i = 0; i < MAX_TRACKED_IPS; i++) {
|
||||||
|
if (g_rate_limits[i].ip[0] == '\0') {
|
||||||
|
strncpy(g_rate_limits[i].ip, ip, sizeof(g_rate_limits[i].ip) - 1);
|
||||||
|
g_rate_limits[i].window_start = time(NULL);
|
||||||
|
g_rate_limits[i].connection_count = 0;
|
||||||
|
g_rate_limits[i].auth_failure_count = 0;
|
||||||
|
g_rate_limits[i].is_blocked = false;
|
||||||
|
g_rate_limits[i].block_until = 0;
|
||||||
|
return &g_rate_limits[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Find oldest entry to replace */
|
||||||
|
int oldest_idx = 0;
|
||||||
|
time_t oldest_time = g_rate_limits[0].window_start;
|
||||||
|
for (int i = 1; i < MAX_TRACKED_IPS; i++) {
|
||||||
|
if (g_rate_limits[i].window_start < oldest_time) {
|
||||||
|
oldest_time = g_rate_limits[i].window_start;
|
||||||
|
oldest_idx = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Reset and reuse */
|
||||||
|
strncpy(g_rate_limits[oldest_idx].ip, ip, sizeof(g_rate_limits[oldest_idx].ip) - 1);
|
||||||
|
g_rate_limits[oldest_idx].ip[sizeof(g_rate_limits[oldest_idx].ip) - 1] = '\0';
|
||||||
|
g_rate_limits[oldest_idx].window_start = time(NULL);
|
||||||
|
g_rate_limits[oldest_idx].connection_count = 0;
|
||||||
|
g_rate_limits[oldest_idx].auth_failure_count = 0;
|
||||||
|
g_rate_limits[oldest_idx].is_blocked = false;
|
||||||
|
g_rate_limits[oldest_idx].block_until = 0;
|
||||||
|
return &g_rate_limits[oldest_idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Check rate limit for an IP */
|
||||||
|
static bool check_rate_limit(const char *ip) {
|
||||||
|
if (!g_rate_limit_enabled) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
time_t now = time(NULL);
|
||||||
|
|
||||||
|
pthread_mutex_lock(&g_rate_limit_lock);
|
||||||
|
ip_rate_limit_t *entry = get_rate_limit_entry(ip);
|
||||||
|
|
||||||
|
/* Check if blocked */
|
||||||
|
if (entry->is_blocked && now < entry->block_until) {
|
||||||
|
pthread_mutex_unlock(&g_rate_limit_lock);
|
||||||
|
fprintf(stderr, "Blocked IP %s (blocked until %ld)\n", ip, (long)entry->block_until);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Unblock if block duration passed */
|
||||||
|
if (entry->is_blocked && now >= entry->block_until) {
|
||||||
|
entry->is_blocked = false;
|
||||||
|
entry->auth_failure_count = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Reset window if expired */
|
||||||
|
if (now - entry->window_start >= RATE_LIMIT_WINDOW) {
|
||||||
|
entry->window_start = now;
|
||||||
|
entry->connection_count = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Check connection rate */
|
||||||
|
entry->connection_count++;
|
||||||
|
if (entry->connection_count > MAX_CONN_PER_WINDOW) {
|
||||||
|
entry->is_blocked = true;
|
||||||
|
entry->block_until = now + BLOCK_DURATION;
|
||||||
|
pthread_mutex_unlock(&g_rate_limit_lock);
|
||||||
|
fprintf(stderr, "Rate limit exceeded for IP %s\n", ip);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
pthread_mutex_unlock(&g_rate_limit_lock);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Record authentication failure */
|
||||||
|
static void record_auth_failure(const char *ip) {
|
||||||
|
time_t now = time(NULL);
|
||||||
|
|
||||||
|
pthread_mutex_lock(&g_rate_limit_lock);
|
||||||
|
ip_rate_limit_t *entry = get_rate_limit_entry(ip);
|
||||||
|
|
||||||
|
entry->auth_failure_count++;
|
||||||
|
if (entry->auth_failure_count >= MAX_AUTH_FAILURES) {
|
||||||
|
entry->is_blocked = true;
|
||||||
|
entry->block_until = now + BLOCK_DURATION;
|
||||||
|
fprintf(stderr, "IP %s blocked due to %d auth failures\n", ip, entry->auth_failure_count);
|
||||||
|
}
|
||||||
|
|
||||||
|
pthread_mutex_unlock(&g_rate_limit_lock);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Check and increment total connection count */
|
||||||
|
static bool check_and_increment_connections(void) {
|
||||||
|
pthread_mutex_lock(&g_conn_count_lock);
|
||||||
|
|
||||||
|
if (g_total_connections >= g_max_connections) {
|
||||||
|
pthread_mutex_unlock(&g_conn_count_lock);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
g_total_connections++;
|
||||||
|
pthread_mutex_unlock(&g_conn_count_lock);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Decrement connection count */
|
||||||
|
static void decrement_connections(void) {
|
||||||
|
pthread_mutex_lock(&g_conn_count_lock);
|
||||||
|
if (g_total_connections > 0) {
|
||||||
|
g_total_connections--;
|
||||||
|
}
|
||||||
|
pthread_mutex_unlock(&g_conn_count_lock);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Get client IP address */
|
||||||
|
static void get_client_ip(ssh_session session, char *ip_buf, size_t buf_size) {
|
||||||
|
int fd = ssh_get_fd(session);
|
||||||
|
struct sockaddr_storage addr;
|
||||||
|
socklen_t addr_len = sizeof(addr);
|
||||||
|
|
||||||
|
if (getpeername(fd, (struct sockaddr *)&addr, &addr_len) == 0) {
|
||||||
|
if (addr.ss_family == AF_INET) {
|
||||||
|
struct sockaddr_in *s = (struct sockaddr_in *)&addr;
|
||||||
|
inet_ntop(AF_INET, &s->sin_addr, ip_buf, buf_size);
|
||||||
|
} else if (addr.ss_family == AF_INET6) {
|
||||||
|
struct sockaddr_in6 *s = (struct sockaddr_in6 *)&addr;
|
||||||
|
inet_ntop(AF_INET6, &s->sin6_addr, ip_buf, buf_size);
|
||||||
|
} else {
|
||||||
|
strncpy(ip_buf, "unknown", buf_size - 1);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
strncpy(ip_buf, "unknown", buf_size - 1);
|
||||||
|
}
|
||||||
|
ip_buf[buf_size - 1] = '\0';
|
||||||
|
}
|
||||||
|
|
||||||
/* Validate username to prevent injection attacks */
|
/* Validate username to prevent injection attacks */
|
||||||
static bool is_valid_username(const char *username) {
|
static bool is_valid_username(const char *username) {
|
||||||
if (!username || username[0] == '\0') {
|
if (!username || username[0] == '\0') {
|
||||||
|
|
@ -43,7 +247,6 @@ static bool is_valid_username(const char *username) {
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Generate or load SSH host key */
|
/* Generate or load SSH host key */
|
||||||
static int setup_host_key(ssh_bind sshbind) {
|
static int setup_host_key(ssh_bind sshbind) {
|
||||||
struct stat st;
|
struct stat st;
|
||||||
|
|
@ -616,31 +819,70 @@ cleanup:
|
||||||
/* Release the main reference - client will be freed when all refs are gone */
|
/* Release the main reference - client will be freed when all refs are gone */
|
||||||
client_release(client);
|
client_release(client);
|
||||||
|
|
||||||
|
/* Decrement connection count */
|
||||||
|
decrement_connections();
|
||||||
|
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Handle SSH authentication */
|
/* Handle SSH authentication with optional token */
|
||||||
static int handle_auth(ssh_session session) {
|
static int handle_auth(ssh_session session, const char *client_ip) {
|
||||||
ssh_message message;
|
ssh_message message;
|
||||||
|
int auth_attempts = 0;
|
||||||
|
|
||||||
do {
|
do {
|
||||||
message = ssh_message_get(session);
|
message = ssh_message_get(session);
|
||||||
if (!message) break;
|
if (!message) break;
|
||||||
|
|
||||||
if (ssh_message_type(message) == SSH_REQUEST_AUTH) {
|
if (ssh_message_type(message) == SSH_REQUEST_AUTH) {
|
||||||
|
auth_attempts++;
|
||||||
|
|
||||||
|
/* Limit auth attempts */
|
||||||
|
if (auth_attempts > 3) {
|
||||||
|
record_auth_failure(client_ip);
|
||||||
|
ssh_message_free(message);
|
||||||
|
fprintf(stderr, "Too many auth attempts from %s\n", client_ip);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
if (ssh_message_subtype(message) == SSH_AUTH_METHOD_PASSWORD) {
|
if (ssh_message_subtype(message) == SSH_AUTH_METHOD_PASSWORD) {
|
||||||
/* Accept any password for simplicity */
|
const char *password = ssh_message_auth_password(message);
|
||||||
/* In production, you'd want to verify against a user database */
|
|
||||||
|
/* If access token is configured, require it */
|
||||||
|
if (g_access_token[0] != '\0') {
|
||||||
|
if (password && strcmp(password, g_access_token) == 0) {
|
||||||
|
/* Token matches */
|
||||||
ssh_message_auth_reply_success(message, 0);
|
ssh_message_auth_reply_success(message, 0);
|
||||||
ssh_message_free(message);
|
ssh_message_free(message);
|
||||||
return 0;
|
return 0;
|
||||||
|
} else {
|
||||||
|
/* Wrong token */
|
||||||
|
record_auth_failure(client_ip);
|
||||||
|
ssh_message_reply_default(message);
|
||||||
|
ssh_message_free(message);
|
||||||
|
sleep(2); /* Slow down brute force */
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
/* No token configured, accept any password */
|
||||||
|
ssh_message_auth_reply_success(message, 0);
|
||||||
|
ssh_message_free(message);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
} else if (ssh_message_subtype(message) == SSH_AUTH_METHOD_NONE) {
|
} else if (ssh_message_subtype(message) == SSH_AUTH_METHOD_NONE) {
|
||||||
/* Accept passwordless authentication for open chatroom */
|
/* If access token is configured, reject passwordless */
|
||||||
|
if (g_access_token[0] != '\0') {
|
||||||
|
ssh_message_reply_default(message);
|
||||||
|
ssh_message_free(message);
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
/* No token configured, allow passwordless */
|
||||||
ssh_message_auth_reply_success(message, 0);
|
ssh_message_auth_reply_success(message, 0);
|
||||||
ssh_message_free(message);
|
ssh_message_free(message);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ssh_message_reply_default(message);
|
ssh_message_reply_default(message);
|
||||||
ssh_message_free(message);
|
ssh_message_free(message);
|
||||||
|
|
@ -731,6 +973,9 @@ static int handle_pty_request(ssh_channel channel, client_t *client) {
|
||||||
|
|
||||||
/* Initialize SSH server */
|
/* Initialize SSH server */
|
||||||
int ssh_server_init(int port) {
|
int ssh_server_init(int port) {
|
||||||
|
/* Initialize rate limiting configuration */
|
||||||
|
init_rate_limit_config();
|
||||||
|
|
||||||
g_sshbind = ssh_bind_new();
|
g_sshbind = ssh_bind_new();
|
||||||
if (!g_sshbind) {
|
if (!g_sshbind) {
|
||||||
fprintf(stderr, "Failed to create SSH bind\n");
|
fprintf(stderr, "Failed to create SSH bind\n");
|
||||||
|
|
@ -794,19 +1039,44 @@ int ssh_server_start(int unused) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Get client IP address */
|
||||||
|
char client_ip[INET6_ADDRSTRLEN];
|
||||||
|
get_client_ip(session, client_ip, sizeof(client_ip));
|
||||||
|
|
||||||
|
/* Check rate limit */
|
||||||
|
if (!check_rate_limit(client_ip)) {
|
||||||
|
ssh_disconnect(session);
|
||||||
|
ssh_free(session);
|
||||||
|
sleep(1); /* Slow down blocked clients */
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Check total connection limit */
|
||||||
|
if (!check_and_increment_connections()) {
|
||||||
|
fprintf(stderr, "Max connections reached, rejecting %s\n", client_ip);
|
||||||
|
ssh_disconnect(session);
|
||||||
|
ssh_free(session);
|
||||||
|
sleep(1);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
/* Perform key exchange */
|
/* Perform key exchange */
|
||||||
if (ssh_handle_key_exchange(session) != SSH_OK) {
|
if (ssh_handle_key_exchange(session) != SSH_OK) {
|
||||||
fprintf(stderr, "Key exchange failed: %s\n", ssh_get_error(session));
|
fprintf(stderr, "Key exchange failed: %s\n", ssh_get_error(session));
|
||||||
|
decrement_connections();
|
||||||
ssh_disconnect(session);
|
ssh_disconnect(session);
|
||||||
ssh_free(session);
|
ssh_free(session);
|
||||||
|
sleep(1);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Handle authentication */
|
/* Handle authentication */
|
||||||
if (handle_auth(session) < 0) {
|
if (handle_auth(session, client_ip) < 0) {
|
||||||
fprintf(stderr, "Authentication failed\n");
|
fprintf(stderr, "Authentication failed from %s\n", client_ip);
|
||||||
|
decrement_connections();
|
||||||
ssh_disconnect(session);
|
ssh_disconnect(session);
|
||||||
ssh_free(session);
|
ssh_free(session);
|
||||||
|
sleep(2); /* Longer delay for auth failures */
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue