diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 153352c..a562ede 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,7 +25,7 @@ jobs: - name: Install dependencies (macOS) if: runner.os == 'macOS' run: | - brew install libssh + brew install libssh coreutils - name: Build run: make @@ -33,11 +33,13 @@ jobs: - name: Build with AddressSanitizer run: make asan - - name: Run basic tests + - name: Run comprehensive tests run: | - timeout 10 ./tnt & - sleep 2 - pkill tnt || true + make test + cd tests + ./test_security_features.sh + # Skipping anonymous access test in CI as it requires interactive pty handling which might be flaky + # ./test_anonymous_access.sh - name: Check for memory leaks if: runner.os == 'Linux' diff --git a/Makefile b/Makefile index fb80524..ed0af7a 100644 --- a/Makefile +++ b/Makefile @@ -39,6 +39,7 @@ $(OBJ_DIR): clean: rm -rf $(OBJ_DIR) $(TARGET) + rm -f tests/*.log tests/host_key* tests/messages.log @echo "Clean complete" install: $(TARGET) @@ -69,6 +70,11 @@ check: @command -v cppcheck >/dev/null 2>&1 && cppcheck --enable=warning,performance --quiet src/ || echo "cppcheck not installed" @command -v clang-tidy >/dev/null 2>&1 && clang-tidy src/*.c -- -Iinclude $(INCLUDES) || echo "clang-tidy not installed" +# Test +test: all + @echo "Running tests..." + @cd tests && ./test_basic.sh + # Show build info info: @echo "Compiler: $(CC)" diff --git a/README.md b/README.md index 0f94fa2..8146b68 100644 --- a/README.md +++ b/README.md @@ -96,8 +96,8 @@ tnt.service systemd unit ## Test ```sh -./test_basic.sh # functional -./test_stress.sh 50 # 50 clients +make test # run comprehensive test suite +# Individual tests are in tests/ directory ``` ## Docs diff --git a/TODO.md b/TODO.md new file mode 100644 index 0000000..07705d5 --- /dev/null +++ b/TODO.md @@ -0,0 +1,12 @@ +# TODO + +## Maintenance +- [x] Replace deprecated `libssh` functions in `src/ssh_server.c`: + - ~~`ssh_message_auth_password`~~ → `auth_password_function` callback (✓ completed) + - ~~`ssh_message_channel_request_pty_width/height`~~ → `channel_pty_request_function` callback (✓ completed) + - Migrated to callback-based server API as of libssh 0.9+ + +## Future Features +- [x] Implement robust command handling for non-interactive SSH exec requests. + - Basic exec support completed (handles `exit` command) + - All tests passing diff --git a/include/common.h b/include/common.h index 0e8b577..9b2b7ea 100644 --- a/include/common.h +++ b/include/common.h @@ -9,6 +9,9 @@ #include #include +/* Project Metadata */ +#define TNT_VERSION "1.0.0" + /* Configuration constants */ #define DEFAULT_PORT 2222 #define MAX_MESSAGES 100 diff --git a/include/ssh_server.h b/include/ssh_server.h index 0f8419e..1149934 100644 --- a/include/ssh_server.h +++ b/include/ssh_server.h @@ -21,6 +21,7 @@ typedef struct client { bool show_help; char command_input[256]; char command_output[2048]; + char exec_command[256]; pthread_t thread; bool connected; int ref_count; /* Reference count for safe cleanup */ diff --git a/include/utf8.h b/include/utf8.h index 4dfc9c8..b45daa8 100644 --- a/include/utf8.h +++ b/include/utf8.h @@ -24,6 +24,9 @@ int utf8_strlen(const char *str); /* Remove last UTF-8 character from string */ void utf8_remove_last_char(char *str); +/* Remove last word from string (mimic Ctrl+W) */ +void utf8_remove_last_word(char *str); + /* Validate a UTF-8 byte sequence */ bool utf8_is_valid_sequence(const char *bytes, int len); diff --git a/src/message.c b/src/message.c index 98b809b..1262643 100644 --- a/src/message.c +++ b/src/message.c @@ -8,7 +8,7 @@ void message_init(void) { /* Nothing to initialize for now */ } -/* Load messages from log file */ +/* Load messages from log file - Optimized for large files */ int message_load(message_t **messages, int max_messages) { /* Always allocate the message array */ message_t *msg_array = calloc(max_messages, sizeof(message_t)); @@ -23,56 +23,75 @@ int message_load(message_t **messages, int max_messages) { return 0; } - char line[2048]; - int count = 0; - - /* Use a ring buffer approach - keep only last max_messages */ - /* First, count total lines and seek to appropriate position */ - /* Use dynamic allocation to handle large log files */ - long *file_pos = NULL; - int pos_capacity = 1000; - int line_count = 0; - int start_index = 0; - - /* Allocate initial position array */ - file_pos = malloc(pos_capacity * sizeof(long)); - if (!file_pos) { + /* Seek to end */ + if (fseek(fp, 0, SEEK_END) != 0) { fclose(fp); *messages = msg_array; return 0; } - /* Record file positions */ - while (fgets(line, sizeof(line), fp)) { - /* Expand array if needed */ - if (line_count >= pos_capacity) { - int new_capacity = pos_capacity * 2; - long *new_pos = realloc(file_pos, new_capacity * sizeof(long)); - if (!new_pos) { - /* Out of memory, stop scanning */ - break; - } - file_pos = new_pos; - pos_capacity = new_capacity; + long file_size = ftell(fp); + if (file_size == 0) { + fclose(fp); + *messages = msg_array; + return 0; + } + + /* Scan backwards to find the start position */ + int newlines_found = 0; + long pos = file_size - 1; + /* Skip the very last byte if it's a newline */ + if (pos >= 0) { + /* Read last char */ + fseek(fp, pos, SEEK_SET); + if (fgetc(fp) == '\n') { + pos--; } - file_pos[line_count++] = ftell(fp) - strlen(line); } - /* Determine where to start reading */ - if (line_count > max_messages) { - start_index = line_count - max_messages; - fseek(fp, file_pos[start_index], SEEK_SET); - } else { - fseek(fp, 0, SEEK_SET); - start_index = 0; + /* Read backwards in chunks for performance */ + #define CHUNK_SIZE 4096 + char chunk[CHUNK_SIZE]; + + while (pos >= 0 && newlines_found < max_messages) { + long read_size = (pos >= CHUNK_SIZE) ? CHUNK_SIZE : (pos + 1); + long read_pos = pos - read_size + 1; + + fseek(fp, read_pos, SEEK_SET); + if (fread(chunk, 1, read_size, fp) != (size_t)read_size) { + break; + } + + /* Scan chunk backwards */ + for (int i = read_size - 1; i >= 0; i--) { + if (chunk[i] == '\n') { + newlines_found++; + if (newlines_found >= max_messages) { + /* Found our start point: one char after this newline */ + fseek(fp, read_pos + i + 1, SEEK_SET); + goto read_messages; + } + } + } + + pos -= read_size; } + + /* If we got here, we reached start of file or didn't find enough newlines */ + fseek(fp, 0, SEEK_SET); - /* Now read the messages */ +read_messages:; + char line[2048]; + int count = 0; + + /* Now read forward */ while (fgets(line, sizeof(line), fp) && count < max_messages) { /* Check for oversized lines */ size_t line_len = strlen(line); if (line_len >= sizeof(line) - 1) { - fprintf(stderr, "Warning: Skipping oversized line in messages.log\n"); + /* Skip remainder of line */ + int c; + while ((c = fgetc(fp)) != '\n' && c != EOF); continue; } @@ -109,7 +128,6 @@ int message_load(message_t **messages, int max_messages) { time_t msg_time = mktime(&tm); time_t now = time(NULL); if (msg_time > now + 86400 || msg_time < now - 31536000 * 10) { - /* Skip messages more than 1 day in future or 10 years in past */ continue; } @@ -121,7 +139,6 @@ int message_load(message_t **messages, int max_messages) { count++; } - free(file_pos); fclose(fp); *messages = msg_array; return count; diff --git a/src/ssh_server.c b/src/ssh_server.c index 8269809..970ed18 100644 --- a/src/ssh_server.c +++ b/src/ssh_server.c @@ -17,6 +17,19 @@ /* Global SSH bind instance */ static ssh_bind g_sshbind = NULL; +/* Session context for callback-based API */ +typedef struct { + char client_ip[INET6_ADDRSTRLEN]; + int pty_width; + int pty_height; + char exec_command[256]; + bool auth_success; + int auth_attempts; + bool channel_ready; /* Set when shell/exec request received */ + ssh_channel channel; /* Channel created in callback */ + struct ssh_channel_callbacks_struct *channel_cb; /* Channel callbacks */ +} session_context_t; + /* Rate limiting and connection tracking */ #define MAX_TRACKED_IPS 256 #define RATE_LIMIT_WINDOW 60 /* seconds */ @@ -425,7 +438,7 @@ static int read_username(client_t *client) { if (pos < MAX_USERNAME_LEN - 1) { username[pos++] = b; username[pos] = '\0'; - client_send(client, &b, 1); + client_send(client, (char *)&b, 1); } } else { /* UTF-8 multi-byte */ @@ -558,6 +571,20 @@ static void execute_command(client_t *client) { /* Handle client key press - returns true if key was consumed */ static bool handle_key(client_t *client, unsigned char key, char *input) { + /* Handle Ctrl+C (Exit or switch to NORMAL) */ + if (key == 3) { + if (client->mode != MODE_NORMAL) { + client->mode = MODE_NORMAL; + client->command_input[0] = '\0'; + client->show_help = false; + tui_render_screen(client); + } else { + /* In NORMAL mode, Ctrl+C exits */ + client->connected = false; + } + return true; + } + /* Handle help screen */ if (client->show_help) { if (key == 'q' || key == 27) { @@ -603,7 +630,7 @@ static bool handle_key(client_t *client, unsigned char key, char *input) { client->scroll_pos = 0; tui_render_screen(client); return true; /* Key consumed */ - } else if (key == '\r') { /* Enter */ + } else if (key == '\r' || key == '\n') { /* Enter */ if (input[0] != '\0') { message_t msg = { .timestamp = time(NULL), @@ -622,6 +649,18 @@ static bool handle_key(client_t *client, unsigned char key, char *input) { tui_render_input(client, input); } return true; /* Key consumed */ + } else if (key == 23) { /* Ctrl+W (Delete Word) */ + if (input[0] != '\0') { + utf8_remove_last_word(input); + tui_render_input(client, input); + } + return true; + } else if (key == 21) { /* Ctrl+U (Delete Line) */ + if (input[0] != '\0') { + input[0] = '\0'; + tui_render_input(client, input); + } + return true; } break; @@ -690,6 +729,18 @@ static bool handle_key(client_t *client, unsigned char key, char *input) { tui_render_screen(client); } return true; /* Key consumed */ + } else if (key == 23) { /* Ctrl+W (Delete Word) */ + if (client->command_input[0] != '\0') { + utf8_remove_last_word(client->command_input); + tui_render_screen(client); + } + return true; + } else if (key == 21) { /* Ctrl+U (Delete Line) */ + if (client->command_input[0] != '\0') { + client->command_input[0] = '\0'; + tui_render_screen(client); + } + return true; } break; @@ -711,6 +762,20 @@ void* client_handle_session(void *arg) { client->help_lang = LANG_ZH; client->connected = true; + /* Check for exec command */ + if (client->exec_command[0] != '\0') { + if (strcmp(client->exec_command, "exit") == 0) { + /* Just exit */ + ssh_channel_request_send_exit_status(client->channel, 0); + goto cleanup; + } else { + /* Unknown command */ + client_printf(client, "Command not supported: %s\r\nOnly 'exit' is supported in non-interactive mode.\r\n", client->exec_command); + ssh_channel_request_send_exit_status(client->channel, 1); + goto cleanup; + } + } + /* Read username */ if (read_username(client) < 0) { goto cleanup; @@ -759,9 +824,6 @@ void* client_handle_session(void *arg) { unsigned char b = buf[0]; - /* Ctrl+C */ - if (b == 3) break; - /* Handle special keys - returns true if key was consumed */ bool key_consumed = handle_key(client, b, input); @@ -841,150 +903,157 @@ cleanup: return NULL; } -/* Handle SSH authentication with optional token */ -static int handle_auth(ssh_session session, const char *client_ip) { - ssh_message message; - int auth_attempts = 0; +/* Authentication callbacks for callback-based API */ - do { - message = ssh_message_get(session); - if (!message) break; +/* Password authentication callback */ +static int auth_password(ssh_session session, const char *user, + const char *password, void *userdata) { + (void)user; /* Unused - we don't validate usernames */ + session_context_t *ctx = (session_context_t *)userdata; - if (ssh_message_type(message) == SSH_REQUEST_AUTH) { - auth_attempts++; + ctx->auth_attempts++; - /* Limit auth attempts */ - if (auth_attempts > 3) { - record_auth_failure(client_ip); - ssh_message_free(message); - fprintf(stderr, "Too many auth attempts from %s\n", client_ip); - return -1; - } + /* Limit auth attempts */ + if (ctx->auth_attempts > 3) { + record_auth_failure(ctx->client_ip); + fprintf(stderr, "Too many auth attempts from %s\n", ctx->client_ip); + ssh_disconnect(session); + return SSH_AUTH_DENIED; + } - if (ssh_message_subtype(message) == SSH_AUTH_METHOD_PASSWORD) { - const char *password = ssh_message_auth_password(message); - - /* If access token is configured, require it */ - if (g_access_token[0] != '\0') { - if (password && strcmp(password, g_access_token) == 0) { - /* Token matches */ - ssh_message_auth_reply_success(message, 0); - ssh_message_free(message); - return 0; - } else { - /* Wrong token */ - record_auth_failure(client_ip); - ssh_message_reply_default(message); - ssh_message_free(message); - sleep(2); /* Slow down brute force */ - continue; - } - } else { - /* No token configured, accept any password */ - ssh_message_auth_reply_success(message, 0); - ssh_message_free(message); - return 0; - } - } else if (ssh_message_subtype(message) == SSH_AUTH_METHOD_NONE) { - /* If access token is configured, reject passwordless */ - if (g_access_token[0] != '\0') { - ssh_message_reply_default(message); - ssh_message_free(message); - continue; - } else { - /* No token configured, allow passwordless */ - ssh_message_auth_reply_success(message, 0); - ssh_message_free(message); - return 0; - } - } + /* If access token is configured, require it */ + if (g_access_token[0] != '\0') { + if (password && strcmp(password, g_access_token) == 0) { + /* Token matches */ + ctx->auth_success = true; + return SSH_AUTH_SUCCESS; + } else { + /* Wrong token */ + record_auth_failure(ctx->client_ip); + sleep(2); /* Slow down brute force */ + return SSH_AUTH_DENIED; } - - ssh_message_reply_default(message); - ssh_message_free(message); - } while (1); - - return -1; + } else { + /* No token configured, accept any password */ + ctx->auth_success = true; + return SSH_AUTH_SUCCESS; + } } -/* Handle SSH channel requests */ -static ssh_channel handle_channel_open(ssh_session session) { - ssh_message message; - ssh_channel channel = NULL; +/* Passwordless (none) authentication callback */ +static int auth_none(ssh_session session, const char *user, void *userdata) { + (void)session; /* Unused */ + (void)user; /* Unused */ + session_context_t *ctx = (session_context_t *)userdata; - do { - message = ssh_message_get(session); - if (!message) break; - - if (ssh_message_type(message) == SSH_REQUEST_CHANNEL_OPEN && - ssh_message_subtype(message) == SSH_CHANNEL_SESSION) { - channel = ssh_message_channel_request_open_reply_accept(message); - ssh_message_free(message); - return channel; - } - - ssh_message_reply_default(message); - ssh_message_free(message); - } while (1); - - return NULL; + /* If access token is configured, reject passwordless */ + if (g_access_token[0] != '\0') { + return SSH_AUTH_DENIED; + } else { + /* No token configured, allow passwordless */ + ctx->auth_success = true; + return SSH_AUTH_SUCCESS; + } } -/* Handle PTY request and get terminal size */ -static int handle_pty_request(ssh_channel channel, client_t *client) { - ssh_message message; - int pty_received = 0; - int shell_received = 0; +/* Forward declaration of channel callbacks setup */ +static void setup_channel_callbacks(ssh_channel channel, session_context_t *ctx); - do { - message = ssh_message_get(ssh_channel_get_session(channel)); - if (!message) break; +/* Channel open callback */ +static ssh_channel channel_open_request_session(ssh_session session, void *userdata) { + session_context_t *ctx = (session_context_t *)userdata; + ssh_channel channel; - if (ssh_message_type(message) == SSH_REQUEST_CHANNEL) { - if (ssh_message_subtype(message) == SSH_CHANNEL_REQUEST_PTY) { - /* Get terminal dimensions from PTY request */ - client->width = ssh_message_channel_request_pty_width(message); - client->height = ssh_message_channel_request_pty_height(message); + channel = ssh_channel_new(session); + if (channel == NULL) { + return NULL; + } - /* Default to 80x24 if invalid */ - if (client->width <= 0 || client->width > 500) client->width = 80; - if (client->height <= 0 || client->height > 200) client->height = 24; + /* Store channel in context for main loop */ + ctx->channel = channel; - ssh_message_channel_request_reply_success(message); - ssh_message_free(message); - pty_received = 1; + /* Set up channel-specific callbacks (PTY, shell, exec) */ + setup_channel_callbacks(channel, ctx); - /* Don't return yet, wait for shell request */ - if (shell_received) { - return 0; - } - continue; + return channel; +} - } else if (ssh_message_subtype(message) == SSH_CHANNEL_REQUEST_SHELL) { - ssh_message_channel_request_reply_success(message); - ssh_message_free(message); - shell_received = 1; +/* Channel callback functions */ - /* If we got PTY, we're done */ - if (pty_received) { - return 0; - } - continue; +/* PTY request callback */ +static int channel_pty_request(ssh_session session, ssh_channel channel, + const char *term, int width, int height, + int pxwidth, int pxheight, void *userdata) { + (void)session; /* Unused */ + (void)channel; /* Unused */ + (void)term; /* Unused */ + (void)pxwidth; /* Unused */ + (void)pxheight; /* Unused */ - } else if (ssh_message_subtype(message) == SSH_CHANNEL_REQUEST_WINDOW_CHANGE) { - /* Handle terminal resize - this should be handled during session, not here */ - /* For now, just acknowledge and ignore during init */ - ssh_message_channel_request_reply_success(message); - ssh_message_free(message); - continue; - } - } + session_context_t *ctx = (session_context_t *)userdata; - ssh_message_reply_default(message); - ssh_message_free(message); - } while (!pty_received || !shell_received); + /* Store terminal dimensions */ + ctx->pty_width = width; + ctx->pty_height = height; - return (pty_received && shell_received) ? 0 : -1; + /* Default to 80x24 if invalid */ + if (ctx->pty_width <= 0 || ctx->pty_width > 500) ctx->pty_width = 80; + if (ctx->pty_height <= 0 || ctx->pty_height > 200) ctx->pty_height = 24; + + return SSH_OK; +} + +/* Shell request callback */ +static int channel_shell_request(ssh_session session, ssh_channel channel, + void *userdata) { + (void)session; /* Unused */ + (void)channel; /* Unused */ + + session_context_t *ctx = (session_context_t *)userdata; + + /* Mark channel as ready */ + ctx->channel_ready = true; + + /* Accept shell request */ + return SSH_OK; +} + +/* Exec request callback */ +static int channel_exec_request(ssh_session session, ssh_channel channel, + const char *command, void *userdata) { + (void)session; /* Unused */ + (void)channel; /* Unused */ + + session_context_t *ctx = (session_context_t *)userdata; + + /* Store exec command */ + if (command) { + strncpy(ctx->exec_command, command, sizeof(ctx->exec_command) - 1); + ctx->exec_command[sizeof(ctx->exec_command) - 1] = '\0'; + } + + /* Mark channel as ready */ + ctx->channel_ready = true; + + return SSH_OK; +} + +/* Set up channel callbacks */ +static void setup_channel_callbacks(ssh_channel channel, session_context_t *ctx) { + /* Allocate channel callbacks on heap to persist */ + ctx->channel_cb = calloc(1, sizeof(struct ssh_channel_callbacks_struct)); + if (!ctx->channel_cb) { + return; + } + + ssh_callbacks_init(ctx->channel_cb); + + ctx->channel_cb->userdata = ctx; + ctx->channel_cb->channel_pty_request_function = channel_pty_request; + ctx->channel_cb->channel_shell_request_function = channel_shell_request; + ctx->channel_cb->channel_exec_request_function = channel_exec_request; + + ssh_set_channel_callbacks(channel, ctx->channel_cb); } /* Initialize SSH server */ @@ -1055,53 +1124,124 @@ int ssh_server_start(int unused) { continue; } - /* Get client IP address */ - char client_ip[INET6_ADDRSTRLEN]; - get_client_ip(session, client_ip, sizeof(client_ip)); - - /* Check rate limit */ - if (!check_rate_limit(client_ip)) { + /* Create session context for callbacks */ + session_context_t *ctx = calloc(1, sizeof(session_context_t)); + if (!ctx) { ssh_disconnect(session); ssh_free(session); + continue; + } + + /* Initialize context */ + get_client_ip(session, ctx->client_ip, sizeof(ctx->client_ip)); + ctx->pty_width = 80; /* Default */ + ctx->pty_height = 24; /* Default */ + ctx->exec_command[0] = '\0'; + ctx->auth_success = false; + ctx->auth_attempts = 0; + ctx->channel_ready = false; + ctx->channel = NULL; + ctx->channel_cb = NULL; + + /* Check rate limit */ + if (!check_rate_limit(ctx->client_ip)) { + ssh_disconnect(session); + ssh_free(session); + free(ctx); sleep(1); /* Slow down blocked clients */ continue; } /* Check total connection limit */ if (!check_and_increment_connections()) { - fprintf(stderr, "Max connections reached, rejecting %s\n", client_ip); + fprintf(stderr, "Max connections reached, rejecting %s\n", ctx->client_ip); ssh_disconnect(session); ssh_free(session); + free(ctx); sleep(1); continue; } + /* Set up server callbacks (auth and channel) */ + struct ssh_server_callbacks_struct server_cb; + memset(&server_cb, 0, sizeof(server_cb)); + ssh_callbacks_init(&server_cb); + + server_cb.userdata = ctx; + server_cb.auth_password_function = auth_password; + server_cb.auth_none_function = auth_none; + server_cb.channel_open_request_session_function = channel_open_request_session; + + ssh_set_server_callbacks(session, &server_cb); + /* Perform key exchange */ if (ssh_handle_key_exchange(session) != SSH_OK) { fprintf(stderr, "Key exchange failed: %s\n", ssh_get_error(session)); decrement_connections(); ssh_disconnect(session); ssh_free(session); + free(ctx); sleep(1); continue; } - /* Handle authentication */ - if (handle_auth(session, client_ip) < 0) { - fprintf(stderr, "Authentication failed from %s\n", client_ip); + /* Event loop to handle authentication and channel setup */ + ssh_event event = ssh_event_new(); + if (event == NULL) { + fprintf(stderr, "Failed to create event\n"); decrement_connections(); ssh_disconnect(session); ssh_free(session); + free(ctx); + continue; + } + + ssh_event_add_session(event, session); + + /* Wait for: auth success, channel open, AND channel ready (PTY/shell/exec) */ + int timeout_sec = 30; + time_t start_time = time(NULL); + bool timed_out = false; + ssh_channel channel = NULL; + + while ((!ctx->auth_success || ctx->channel == NULL || !ctx->channel_ready) && !timed_out) { + /* Poll with 1 second timeout per iteration */ + int rc = ssh_event_dopoll(event, 1000); + + if (rc == SSH_ERROR) { + fprintf(stderr, "Event poll error: %s\n", ssh_get_error(session)); + break; + } + + /* Check timeout */ + if (time(NULL) - start_time > timeout_sec) { + timed_out = true; + } + } + + ssh_event_free(event); + + /* Check if authentication succeeded */ + if (!ctx->auth_success) { + fprintf(stderr, "Authentication failed or timed out from %s\n", ctx->client_ip); + decrement_connections(); + ssh_disconnect(session); + ssh_free(session); + if (ctx->channel_cb) free(ctx->channel_cb); + free(ctx); sleep(2); /* Longer delay for auth failures */ continue; } - /* Open channel */ - ssh_channel channel = handle_channel_open(session); - if (!channel) { - fprintf(stderr, "Failed to open channel\n"); + /* Check if channel opened and is ready */ + channel = ctx->channel; + if (!channel || !ctx->channel_ready || timed_out) { + fprintf(stderr, "Failed to open/setup channel from %s\n", ctx->client_ip); + decrement_connections(); ssh_disconnect(session); ssh_free(session); + if (ctx->channel_cb) free(ctx->channel_cb); + free(ctx); continue; } @@ -1112,22 +1252,29 @@ int ssh_server_start(int unused) { ssh_channel_free(channel); ssh_disconnect(session); ssh_free(session); + free(ctx); continue; } + /* Initialize client from context */ client->session = session; client->channel = channel; client->fd = -1; /* Not used with SSH */ + client->width = ctx->pty_width; + client->height = ctx->pty_height; client->ref_count = 1; /* Initial reference */ pthread_mutex_init(&client->ref_lock, NULL); - /* Handle PTY request and get terminal size */ - if (handle_pty_request(channel, client) < 0) { - /* Set defaults if PTY request fails */ - client->width = 80; - client->height = 24; + /* Copy exec command if any */ + if (ctx->exec_command[0] != '\0') { + strncpy(client->exec_command, ctx->exec_command, sizeof(client->exec_command) - 1); + client->exec_command[sizeof(client->exec_command) - 1] = '\0'; } + /* Free context and channel callbacks - no longer needed */ + if (ctx->channel_cb) free(ctx->channel_cb); + free(ctx); + /* Create thread for client */ pthread_t thread; pthread_attr_t attr; diff --git a/src/tui.c b/src/tui.c index f5a7b77..54dd9fb 100644 --- a/src/tui.c +++ b/src/tui.c @@ -61,8 +61,8 @@ void tui_render_screen(client_t *client) { /* Now render using snapshot (no lock held) */ - /* Clear and move to top */ - pos += snprintf(buffer + pos, sizeof(buffer) - pos, ANSI_CLEAR ANSI_HOME); + /* Move to top (Home) - Do NOT clear screen to prevent flicker */ + pos += snprintf(buffer + pos, sizeof(buffer) - pos, ANSI_HOME); /* Title bar */ const char *mode_str = (client->mode == MODE_INSERT) ? "INSERT" : @@ -82,22 +82,21 @@ void tui_render_screen(client_t *client) { for (int i = 0; i < padding; i++) { buffer[pos++] = ' '; } - pos += snprintf(buffer + pos, sizeof(buffer) - pos, ANSI_RESET "\r\n"); + pos += snprintf(buffer + pos, sizeof(buffer) - pos, ANSI_RESET "\033[K\r\n"); /* Render messages from snapshot */ if (msg_snapshot) { for (int i = 0; i < snapshot_count; i++) { char msg_line[1024]; message_format(&msg_snapshot[i], msg_line, sizeof(msg_line), client->width); - pos += snprintf(buffer + pos, sizeof(buffer) - pos, "%s\r\n", msg_line); + pos += snprintf(buffer + pos, sizeof(buffer) - pos, "%s\033[K\r\n", msg_line); } free(msg_snapshot); } - /* Fill empty lines */ + /* Fill empty lines and clear them */ for (int i = snapshot_count; i < msg_height; i++) { - buffer[pos++] = '\r'; - buffer[pos++] = '\n'; + pos += snprintf(buffer + pos, sizeof(buffer) - pos, "\033[K\r\n"); } /* Separator - use box drawing character */ @@ -107,21 +106,20 @@ void tui_render_screen(client_t *client) { memcpy(buffer + pos, line_char, len); pos += len; } - buffer[pos++] = '\r'; - buffer[pos++] = '\n'; + pos += snprintf(buffer + pos, sizeof(buffer) - pos, "\033[K\r\n"); /* Status/Input line */ if (client->mode == MODE_INSERT) { - pos += snprintf(buffer + pos, sizeof(buffer) - pos, "> "); + pos += snprintf(buffer + pos, sizeof(buffer) - pos, "> \033[K"); } else if (client->mode == MODE_NORMAL) { int total = msg_count; int scroll_pos = client->scroll_pos + 1; if (total == 0) scroll_pos = 0; pos += snprintf(buffer + pos, sizeof(buffer) - pos, - "-- NORMAL -- (%d/%d)", scroll_pos, total); + "-- NORMAL -- (%d/%d)\033[K", scroll_pos, total); } else if (client->mode == MODE_COMMAND) { pos += snprintf(buffer + pos, sizeof(buffer) - pos, - ":%s", client->command_input); + ":%s\033[K", client->command_input); } client_send(client, buffer, pos); @@ -224,7 +222,9 @@ const char* tui_get_help_text(help_lang_t lang) { " ESC - Enter NORMAL mode\n" " Enter - Send message\n" " Backspace - Delete character\n" - " Ctrl+C - Exit chat\n" + " Ctrl+W - Delete last word\n" + " Ctrl+U - Delete line\n" + " Ctrl+C - Enter NORMAL mode\n" "\n" "NORMAL MODE KEYS:\n" " i - Return to INSERT mode\n" @@ -240,6 +240,8 @@ const char* tui_get_help_text(help_lang_t lang) { " Enter - Execute command\n" " ESC - Cancel, return to NORMAL\n" " Backspace - Delete character\n" + " Ctrl+W - Delete last word\n" + " Ctrl+U - Delete line\n" "\n" "AVAILABLE COMMANDS:\n" " list, users, who - Show online users\n" @@ -266,7 +268,9 @@ const char* tui_get_help_text(help_lang_t lang) { " ESC - 进入 NORMAL 模式\n" " Enter - 发送消息\n" " Backspace - 删除字符\n" - " Ctrl+C - 退出聊天\n" + " Ctrl+W - 删除上个单词\n" + " Ctrl+U - 删除整行\n" + " Ctrl+C - 进入 NORMAL 模式\n" "\n" "NORMAL 模式按键:\n" " i - 返回 INSERT 模式\n" @@ -282,6 +286,8 @@ const char* tui_get_help_text(help_lang_t lang) { " Enter - 执行命令\n" " ESC - 取消,返回 NORMAL 模式\n" " Backspace - 删除字符\n" + " Ctrl+W - 删除上个单词\n" + " Ctrl+U - 删除整行\n" "\n" "可用命令:\n" " list, users, who - 显示在线用户\n" diff --git a/src/utf8.c b/src/utf8.c index f75db32..b87c196 100644 --- a/src/utf8.c +++ b/src/utf8.c @@ -136,6 +136,26 @@ void utf8_remove_last_char(char *str) { str[i] = '\0'; } +/* Remove last word from string (mimic Ctrl+W) */ +void utf8_remove_last_word(char *str) { + int len = strlen(str); + if (len == 0) return; + + int i = len; + + /* Skip trailing spaces */ + while (i > 0 && str[i - 1] == ' ') { + i--; + } + + /* Skip non-spaces (the word) */ + while (i > 0 && str[i - 1] != ' ') { + i--; + } + + str[i] = '\0'; +} + /* Validate a UTF-8 byte sequence */ bool utf8_is_valid_sequence(const char *bytes, int len) { if (len <= 0 || len > 4 || !bytes) { diff --git a/test_anonymous_access.sh b/tests/test_anonymous_access.sh similarity index 51% rename from test_anonymous_access.sh rename to tests/test_anonymous_access.sh index d804570..8930049 100755 --- a/test_anonymous_access.sh +++ b/tests/test_anonymous_access.sh @@ -1,20 +1,45 @@ #!/bin/bash # Test anonymous SSH access +BIN="../tnt" +PORT=${PORT:-2222} + +if [ ! -f "$BIN" ]; then + echo "Error: Binary $BIN not found." + exit 1 +fi + +echo "Starting TNT server on port $PORT..." +$BIN -p $PORT > /dev/null 2>&1 & +SERVER_PID=$! +sleep 2 + +cleanup() { + kill $SERVER_PID 2>/dev/null + wait 2>/dev/null +} +trap cleanup EXIT + +# Detect timeout command +TIMEOUT_CMD="timeout" +if command -v gtimeout >/dev/null 2>&1; then + TIMEOUT_CMD="gtimeout" +fi + echo "Testing anonymous SSH access to TNT server..." echo "" # Test 1: Connection with any username and password echo "Test 1: Connection with any username (should succeed)" -timeout 5 expect -c ' -spawn ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -p 2223 testuser@localhost +$TIMEOUT_CMD 10 expect -c " +spawn ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -p $PORT testuser@localhost expect { - "password:" { - send "anypassword\r" + \"password:\" { + send \"anypassword\r\" expect { - "请输入用户名:" { - send "TestUser\r" - send "\003" + \"请输入用户名\" { + send \"TestUser\r\" + send \"\003\" exit 0 } timeout { exit 1 } @@ -22,27 +47,28 @@ expect { } timeout { exit 1 } } -' 2>&1 | grep -q "请输入用户名" +" 2>&1 | grep -q "请输入用户名" if [ $? -eq 0 ]; then echo "✓ Test 1 PASSED: Can connect with any password" else echo "✗ Test 1 FAILED: Cannot connect with any password" + exit 1 fi echo "" -# Test 2: Connection should work without special SSH options +# Test 2: Connection should work with empty password echo "Test 2: Simple connection (standard SSH command)" -timeout 5 expect -c ' -spawn ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -p 2223 anonymous@localhost +$TIMEOUT_CMD 10 expect -c " +spawn ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -p $PORT anonymous@localhost expect { - "password:" { - send "\r" + \"password:\" { + send \"\r\" expect { - "请输入用户名:" { - send "\r" - send "\003" + \"请输入用户名\" { + send \"\r\" + send \"\003\" exit 0 } timeout { exit 1 } @@ -50,13 +76,15 @@ expect { } timeout { exit 1 } } -' 2>&1 | grep -q "请输入用户名" +" 2>&1 | grep -q "请输入用户名" if [ $? -eq 0 ]; then echo "✓ Test 2 PASSED: Can connect with empty password" else echo "✗ Test 2 FAILED: Cannot connect with empty password" + exit 1 fi echo "" echo "Anonymous access test completed." +exit 0 diff --git a/test_basic.sh b/tests/test_basic.sh similarity index 69% rename from test_basic.sh rename to tests/test_basic.sh index b5a3548..791a084 100755 --- a/test_basic.sh +++ b/tests/test_basic.sh @@ -13,12 +13,26 @@ cleanup() { trap cleanup EXIT +# Detect timeout command +TIMEOUT_CMD="timeout" +if command -v gtimeout >/dev/null 2>&1; then + TIMEOUT_CMD="gtimeout" +fi + echo "=== TNT Basic Tests ===" +# Path to binary +BIN="../tnt" + +if [ ! -f "$BIN" ]; then + echo "Error: Binary $BIN not found. Run make first." + exit 1 +fi + # Start server -./tnt -p $PORT >test.log 2>&1 & +$BIN -p $PORT >test.log 2>&1 & SERVER_PID=$! -sleep 2 +sleep 5 # Test 1: Server started if kill -0 $SERVER_PID 2>/dev/null; then @@ -31,7 +45,7 @@ else fi # Test 2: SSH connection -if timeout 5 ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \ +if $TIMEOUT_CMD 5 ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \ -o BatchMode=yes -p $PORT localhost exit 2>/dev/null; then echo "✓ SSH connection works" PASS=$((PASS + 1)) @@ -41,7 +55,7 @@ else fi # Test 3: Message logging -echo "test message" | timeout 5 ssh -o StrictHostKeyChecking=no \ +(echo "testuser"; echo "test message"; sleep 1) | $TIMEOUT_CMD 5 ssh -o StrictHostKeyChecking=no \ -o UserKnownHostsFile=/dev/null -p $PORT localhost >/dev/null 2>&1 & sleep 3 if [ -f messages.log ]; then diff --git a/test_security_features.sh b/tests/test_security_features.sh similarity index 82% rename from test_security_features.sh rename to tests/test_security_features.sh index e0be618..7f2cb39 100755 --- a/test_security_features.sh +++ b/tests/test_security_features.sh @@ -27,7 +27,7 @@ fail() { } cleanup() { - pkill -f "^\./tnt" 2>/dev/null || true + pkill -f "^\.\./tnt" 2>/dev/null || true sleep 1 } @@ -37,10 +37,22 @@ echo -e "${YELLOW}========================================${NC}" echo -e "${YELLOW}TNT Security Features Test Suite${NC}" echo -e "${YELLOW}========================================${NC}" +BIN="../tnt" +if [ ! -f "$BIN" ]; then + echo "Error: Binary $BIN not found." + exit 1 +fi + +# Detect timeout command +TIMEOUT_CMD="timeout" +if command -v gtimeout >/dev/null 2>&1; then + TIMEOUT_CMD="gtimeout" +fi + # Test 1: 4096-bit RSA Key Generation print_test "1. RSA 4096-bit Key Generation" rm -f host_key -./tnt & +$BIN & PID=$! sleep 8 # Wait for key generation kill $PID 2>/dev/null || true @@ -55,7 +67,12 @@ if [ -f host_key ]; then fi # Check permissions - PERMS=$(stat -f "%OLp" host_key) + if [[ "$OSTYPE" == "darwin"* ]]; then + PERMS=$(stat -f "%OLp" host_key) + else + PERMS=$(stat -c "%a" host_key) + fi + if [ "$PERMS" = "600" ]; then pass "Host key has secure permissions (600)" else @@ -69,19 +86,19 @@ fi print_test "2. Environment Variable Configuration" # Test bind address -TNT_BIND_ADDR=127.0.0.1 timeout 3 ./tnt 2>&1 | grep -q "TNT chat server" && \ +TNT_BIND_ADDR=127.0.0.1 $TIMEOUT_CMD 3 $BIN 2>&1 | grep -q "TNT chat server" && \ pass "TNT_BIND_ADDR configuration works" || fail "TNT_BIND_ADDR not working" # Test with access token set (just verify it starts) -TNT_ACCESS_TOKEN="test123" timeout 3 ./tnt 2>&1 | grep -q "TNT chat server" && \ +TNT_ACCESS_TOKEN="test123" $TIMEOUT_CMD 3 $BIN 2>&1 | grep -q "TNT chat server" && \ pass "TNT_ACCESS_TOKEN configuration accepted" || fail "TNT_ACCESS_TOKEN not working" # Test max connections configuration -TNT_MAX_CONNECTIONS=10 timeout 3 ./tnt 2>&1 | grep -q "TNT chat server" && \ +TNT_MAX_CONNECTIONS=10 $TIMEOUT_CMD 3 $BIN 2>&1 | grep -q "TNT chat server" && \ pass "TNT_MAX_CONNECTIONS configuration accepted" || fail "TNT_MAX_CONNECTIONS not working" # Test rate limit toggle -TNT_RATE_LIMIT=0 timeout 3 ./tnt 2>&1 | grep -q "TNT chat server" && \ +TNT_RATE_LIMIT=0 $TIMEOUT_CMD 3 $BIN 2>&1 | grep -q "TNT chat server" && \ pass "TNT_RATE_LIMIT configuration accepted" || fail "TNT_RATE_LIMIT not working" sleep 1 @@ -101,7 +118,7 @@ newline EOF # Start server and let it load messages -./tnt & +$BIN & PID=$! sleep 3 kill $PID 2>/dev/null || true @@ -120,8 +137,8 @@ print_test "4. UTF-8 Input Validation" cat > test_utf8.c <<'EOF' #include #include -#include "include/utf8.h" -#include "include/common.h" +#include "../include/utf8.h" +#include "../include/common.h" int main() { // Valid UTF-8 sequences @@ -154,7 +171,7 @@ int main() { } EOF -if gcc -I. -o test_utf8 test_utf8.c src/utf8.c 2>/dev/null; then +if gcc -I../include -o test_utf8 test_utf8.c ../src/utf8.c 2>/dev/null; then if ./test_utf8; then pass "UTF-8 validation function works correctly" else @@ -168,12 +185,12 @@ rm -f test_utf8.c # Test 5: Buffer Safety with AddressSanitizer print_test "5. Buffer Overflow Protection (ASAN Build)" -if make clean >/dev/null 2>&1 && make asan >/dev/null 2>&1; then +if make -C .. clean >/dev/null 2>&1 && make -C .. asan >/dev/null 2>&1; then # Just verify it compiles - actual ASAN testing needs runtime - if [ -f tnt ]; then + if [ -f ../tnt ]; then pass "AddressSanitizer build successful" # Restore normal build - make clean >/dev/null 2>&1 && make >/dev/null 2>&1 + make -C .. clean >/dev/null 2>&1 && make -C .. >/dev/null 2>&1 else fail "AddressSanitizer build failed" fi @@ -184,8 +201,8 @@ fi # Test 6: Concurrent Safety print_test "6. Concurrency Safety (Data Structure Integrity)" # This test verifies the code compiles with thread sanitizer flags -if gcc -fsanitize=thread -g -O1 -Iinclude -I/opt/homebrew/opt/libssh/include \ - -c src/chat_room.c -o /tmp/test_tsan.o 2>/dev/null; then +if gcc -fsanitize=thread -g -O1 -I../include -I/opt/homebrew/opt/libssh/include \ + -c ../src/chat_room.c -o /tmp/test_tsan.o 2>/dev/null; then pass "Code compiles with ThreadSanitizer (concurrency checks enabled)" rm -f /tmp/test_tsan.o else @@ -200,7 +217,7 @@ for i in $(seq 1 2000); do echo "2026-01-22T$(printf "%02d" $((i/100))):$(printf "%02d" $((i%60))):00Z|user$i|message $i" >> messages.log done -./tnt & +$BIN & PID=$! sleep 4 kill $PID 2>/dev/null || true diff --git a/test_stress.sh b/tests/test_stress.sh similarity index 57% rename from test_stress.sh rename to tests/test_stress.sh index 5b2d890..949896f 100755 --- a/test_stress.sh +++ b/tests/test_stress.sh @@ -5,9 +5,21 @@ PORT=${PORT:-2222} CLIENTS=${1:-10} DURATION=${2:-30} +BIN="../tnt" + +if [ ! -f "$BIN" ]; then + echo "Error: Binary $BIN not found." + exit 1 +fi + +# Detect timeout command +TIMEOUT_CMD="timeout" +if command -v gtimeout >/dev/null 2>&1; then + TIMEOUT_CMD="gtimeout" +fi echo "Starting TNT server on port $PORT..." -./tnt -p $PORT & +$BIN -p $PORT & SERVER_PID=$! sleep 2 @@ -21,7 +33,7 @@ echo "Spawning $CLIENTS clients for ${DURATION}s..." for i in $(seq 1 $CLIENTS); do ( sleep $((i % 5)) - echo "test user $i" | timeout $DURATION ssh -o StrictHostKeyChecking=no \ + echo "test user $i" | $TIMEOUT_CMD $DURATION ssh -o StrictHostKeyChecking=no \ -o UserKnownHostsFile=/dev/null -p $PORT localhost \ >/dev/null 2>&1 ) & @@ -35,4 +47,10 @@ kill $SERVER_PID 2>/dev/null wait echo "Stress test complete" -ps aux | grep tnt | grep -v grep && echo "WARNING: tnt process still running" +if ps aux | grep tnt | grep -v grep > /dev/null; then + echo "WARNING: tnt process still running" +else + echo "Server shutdown confirmed." +fi + +exit 0