diff --git a/include/utf8.h b/include/utf8.h index 3008a86..4dfc9c8 100644 --- a/include/utf8.h +++ b/include/utf8.h @@ -24,4 +24,7 @@ int utf8_strlen(const char *str); /* Remove last UTF-8 character from string */ void utf8_remove_last_char(char *str); +/* Validate a UTF-8 byte sequence */ +bool utf8_is_valid_sequence(const char *bytes, int len); + #endif /* UTF8_H */ diff --git a/src/ssh_server.c b/src/ssh_server.c index 4023611..c532d09 100644 --- a/src/ssh_server.c +++ b/src/ssh_server.c @@ -108,6 +108,12 @@ int client_printf(client_t *client, const char *fmt, ...) { va_start(args, fmt); int len = vsnprintf(buffer, sizeof(buffer), fmt, args); va_end(args); + + /* Check for buffer overflow or encoding error */ + if (len < 0 || len >= (int)sizeof(buffer)) { + return -1; + } + return client_send(client, buffer, len); } @@ -155,6 +161,10 @@ static int read_username(client_t *client) { } else { /* UTF-8 multi-byte */ int len = utf8_byte_length(b); + if (len <= 0 || len > 4) { + /* Invalid UTF-8 start byte */ + continue; + } buf[0] = b; if (len > 1) { int read_bytes = ssh_channel_read(client->channel, &buf[1], len - 1, 0); @@ -163,6 +173,11 @@ static int read_username(client_t *client) { continue; } } + /* Validate the complete UTF-8 sequence */ + if (!utf8_is_valid_sequence(buf, len)) { + /* Invalid UTF-8 sequence */ + continue; + } if (pos + len < MAX_USERNAME_LEN - 1) { memcpy(username + pos, buf, len); pos += len; @@ -175,7 +190,8 @@ static int read_username(client_t *client) { client_printf(client, "\r\n"); if (username[0] == '\0') { - strcpy(client->username, "anonymous"); + strncpy(client->username, "anonymous", MAX_USERNAME_LEN - 1); + client->username[MAX_USERNAME_LEN - 1] = '\0'; } else { strncpy(client->username, username, MAX_USERNAME_LEN - 1); /* Truncate to 20 characters */ @@ -420,7 +436,8 @@ void* client_handle_session(void *arg) { message_t join_msg = { .timestamp = time(NULL), }; - strcpy(join_msg.username, "系统"); + strncpy(join_msg.username, "系统", MAX_USERNAME_LEN - 1); + join_msg.username[MAX_USERNAME_LEN - 1] = '\0'; snprintf(join_msg.content, MAX_MESSAGE_LEN, "%s 加入了聊天室", client->username); room_broadcast(g_room, &join_msg); @@ -472,6 +489,10 @@ void* client_handle_session(void *arg) { } } else if (b >= 128) { /* UTF-8 multi-byte */ int char_len = utf8_byte_length(b); + if (char_len <= 0 || char_len > 4) { + /* Invalid UTF-8 start byte */ + continue; + } buf[0] = b; if (char_len > 1) { int read_bytes = ssh_channel_read(client->channel, &buf[1], char_len - 1, 0); @@ -480,6 +501,11 @@ void* client_handle_session(void *arg) { continue; } } + /* Validate the complete UTF-8 sequence */ + if (!utf8_is_valid_sequence(buf, char_len)) { + /* Invalid UTF-8 sequence */ + continue; + } int len = strlen(input); if (len + char_len < MAX_MESSAGE_LEN - 1) { memcpy(input + len, buf, char_len); @@ -507,7 +533,8 @@ cleanup: message_t leave_msg = { .timestamp = time(NULL), }; - strcpy(leave_msg.username, "系统"); + strncpy(leave_msg.username, "系统", MAX_USERNAME_LEN - 1); + leave_msg.username[MAX_USERNAME_LEN - 1] = '\0'; snprintf(leave_msg.content, MAX_MESSAGE_LEN, "%s 离开了聊天室", client->username); client->connected = false; diff --git a/src/utf8.c b/src/utf8.c index 8afa505..f75db32 100644 --- a/src/utf8.c +++ b/src/utf8.c @@ -135,3 +135,55 @@ void utf8_remove_last_char(char *str) { str[i] = '\0'; } + +/* Validate a UTF-8 byte sequence */ +bool utf8_is_valid_sequence(const char *bytes, int len) { + if (len <= 0 || len > 4 || !bytes) { + return false; + } + + const unsigned char *b = (const unsigned char *)bytes; + + /* Check first byte matches the expected length */ + int expected_len = utf8_byte_length(b[0]); + if (expected_len != len) { + return false; + } + + /* Validate continuation bytes (must be 10xxxxxx) */ + for (int i = 1; i < len; i++) { + if ((b[i] & 0xC0) != 0x80) { + return false; + } + } + + /* Validate codepoint ranges to prevent overlong encodings */ + uint32_t codepoint = 0; + switch (len) { + case 1: + /* 0xxxxxxx - valid range: 0x00-0x7F */ + codepoint = b[0]; + if (codepoint > 0x7F) return false; + break; + case 2: + /* 110xxxxx 10xxxxxx - valid range: 0x80-0x7FF */ + codepoint = ((b[0] & 0x1F) << 6) | (b[1] & 0x3F); + if (codepoint < 0x80 || codepoint > 0x7FF) return false; + break; + case 3: + /* 1110xxxx 10xxxxxx 10xxxxxx - valid range: 0x800-0xFFFF */ + codepoint = ((b[0] & 0x0F) << 12) | ((b[1] & 0x3F) << 6) | (b[2] & 0x3F); + if (codepoint < 0x800 || codepoint > 0xFFFF) return false; + /* Reject UTF-16 surrogates (0xD800-0xDFFF) */ + if (codepoint >= 0xD800 && codepoint <= 0xDFFF) return false; + break; + case 4: + /* 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx - valid range: 0x10000-0x10FFFF */ + codepoint = ((b[0] & 0x07) << 18) | ((b[1] & 0x3F) << 12) | + ((b[2] & 0x3F) << 6) | (b[3] & 0x3F); + if (codepoint < 0x10000 || codepoint > 0x10FFFF) return false; + break; + } + + return true; +}