Merge branch 'fix/concurrency-safety' into feat/security-audit-fixes

This commit is contained in:
m1ngsama 2026-01-22 14:08:45 +08:00
commit bc08269743
3 changed files with 72 additions and 35 deletions

View file

@ -98,8 +98,18 @@ void room_broadcast(chat_room_t *room, const message_t *msg) {
/* Render to each client (outside of lock) */
for (int i = 0; i < count; i++) {
client_t *client = clients_copy[i];
if (client->connected && !client->show_help &&
client->command_output[0] == '\0') {
/* Check client state before rendering (while holding ref) */
bool should_render = false;
pthread_mutex_lock(&client->ref_lock);
if (client->ref_count > 0) {
should_render = client->connected &&
!client->show_help &&
client->command_output[0] == '\0';
}
pthread_mutex_unlock(&client->ref_lock);
if (should_render) {
tui_render_screen(client);
}

View file

@ -632,7 +632,13 @@ static bool handle_key(client_t *client, unsigned char key, char *input) {
tui_render_screen(client);
return true; /* Key consumed - prevents double colon */
} else if (key == 'j') {
int max_scroll = room_get_message_count(g_room) - 1;
/* Get message count atomically to prevent TOCTOU */
int max_scroll = room_get_message_count(g_room);
int msg_height = client->height - 3;
if (msg_height < 1) msg_height = 1;
max_scroll = max_scroll - msg_height;
if (max_scroll < 0) max_scroll = 0;
if (client->scroll_pos < max_scroll) {
client->scroll_pos++;
tui_render_screen(client);
@ -647,8 +653,14 @@ static bool handle_key(client_t *client, unsigned char key, char *input) {
tui_render_screen(client);
return true; /* Key consumed */
} else if (key == 'G') {
client->scroll_pos = room_get_message_count(g_room) - 1;
if (client->scroll_pos < 0) client->scroll_pos = 0;
/* Get message count atomically to prevent TOCTOU */
int max_scroll = room_get_message_count(g_room);
int msg_height = client->height - 3;
if (msg_height < 1) msg_height = 1;
max_scroll = max_scroll - msg_height;
if (max_scroll < 0) max_scroll = 0;
client->scroll_pos = max_scroll;
tui_render_screen(client);
return true; /* Key consumed */
} else if (key == '?') {

View file

@ -19,11 +19,48 @@ void tui_render_screen(client_t *client) {
char buffer[8192];
int pos = 0;
/* Acquire all data in one lock to prevent TOCTOU */
pthread_rwlock_rdlock(&g_room->lock);
int online = g_room->client_count;
int msg_count = g_room->message_count;
/* Calculate which messages to show */
int msg_height = client->height - 3;
if (msg_height < 1) msg_height = 1;
int start = 0;
if (client->mode == MODE_NORMAL) {
start = client->scroll_pos;
if (start > msg_count - msg_height) {
start = msg_count - msg_height;
}
if (start < 0) start = 0;
} else {
/* INSERT mode: show latest */
if (msg_count > msg_height) {
start = msg_count - msg_height;
}
}
int end = start + msg_height;
if (end > msg_count) end = msg_count;
/* Create snapshot of messages to display */
message_t *msg_snapshot = NULL;
int snapshot_count = end - start;
if (snapshot_count > 0) {
msg_snapshot = calloc(snapshot_count, sizeof(message_t));
if (msg_snapshot) {
memcpy(msg_snapshot, &g_room->messages[start],
snapshot_count * sizeof(message_t));
}
}
pthread_rwlock_unlock(&g_room->lock);
/* Now render using snapshot (no lock held) */
/* Clear and move to top */
pos += snprintf(buffer + pos, sizeof(buffer) - pos, ANSI_CLEAR ANSI_HOME);
@ -47,40 +84,18 @@ void tui_render_screen(client_t *client) {
}
pos += snprintf(buffer + pos, sizeof(buffer) - pos, ANSI_RESET "\r\n");
/* Messages area */
int msg_height = client->height - 3;
if (msg_height < 1) msg_height = 1;
/* Calculate which messages to show */
int start = 0;
if (client->mode == MODE_NORMAL) {
start = client->scroll_pos;
if (start > msg_count - msg_height) {
start = msg_count - msg_height;
}
if (start < 0) start = 0;
} else {
/* INSERT mode: show latest */
if (msg_count > msg_height) {
start = msg_count - msg_height;
/* Render messages from snapshot */
if (msg_snapshot) {
for (int i = 0; i < snapshot_count; i++) {
char msg_line[1024];
message_format(&msg_snapshot[i], msg_line, sizeof(msg_line), client->width);
pos += snprintf(buffer + pos, sizeof(buffer) - pos, "%s\r\n", msg_line);
}
free(msg_snapshot);
}
pthread_rwlock_rdlock(&g_room->lock);
int end = start + msg_height;
if (end > msg_count) end = msg_count;
for (int i = start; i < end; i++) {
char msg_line[1024];
message_format(&g_room->messages[i], msg_line, sizeof(msg_line), client->width);
pos += snprintf(buffer + pos, sizeof(buffer) - pos, "%s\r\n", msg_line);
}
pthread_rwlock_unlock(&g_room->lock);
/* Fill empty lines */
for (int i = end - start; i < msg_height; i++) {
for (int i = snapshot_count; i < msg_height; i++) {
buffer[pos++] = '\r';
buffer[pos++] = '\n';
}