TNT/src/utf8.c
m1ngsama 629812a2d8 fix: correct pubkey auth response, strncpy warning, and NUL byte validation
- auth_pubkey: return SSH_AUTH_SUCCESS for key offers instead of
  SSH_AUTH_PARTIAL, which incorrectly signals partial authentication
- command history: replace strncpy with snprintf to eliminate
  -Wstringop-truncation warning on GCC
- utf8_is_valid_sequence: reject NUL byte (0x00) in single-byte
  validation to prevent C string truncation attacks

Closes #34
2026-04-19 18:27:50 +08:00

250 lines
7 KiB
C

#include "utf8.h"
/*
 * Return how many bytes the UTF-8 sequence introduced by first_byte
 * occupies, judged only from the lead byte's bit pattern:
 *   0x00-0x7F (0xxxxxxx) -> 1
 *   0xC0-0xDF (110xxxxx) -> 2
 *   0xE0-0xEF (1110xxxx) -> 3
 *   0xF0-0xF7 (11110xxx) -> 4
 * Continuation bytes (0x80-0xBF) and 0xF8-0xFF are not valid lead bytes;
 * they are reported as length 1 so callers can step over them one byte
 * at a time.
 */
int utf8_byte_length(unsigned char first_byte) {
    if (first_byte < 0x80) {
        return 1;
    }
    if (first_byte >= 0xC0 && first_byte <= 0xDF) {
        return 2;
    }
    if (first_byte >= 0xE0 && first_byte <= 0xEF) {
        return 3;
    }
    if (first_byte >= 0xF0 && first_byte <= 0xF7) {
        return 4;
    }
    /* Not a lead byte: advance by a single byte. */
    return 1;
}
/*
 * Decode one UTF-8 character at the start of str and return its codepoint.
 *
 * *bytes_read receives the number of bytes consumed (1-4). It may be NULL
 * if the caller only wants the codepoint.
 *
 * Any malformed sequence is reported as its single lead byte with
 * *bytes_read = 1, so callers always make forward progress. Malformed means:
 *   - an invalid lead byte;
 *   - a missing or invalid continuation byte, including hitting the NUL
 *     terminator;
 *   - an overlong encoding;
 *   - a UTF-16 surrogate;
 *   - a value above U+10FFFF.
 * In particular "\xC0\x80" is never decoded as U+0000.
 *
 * An empty string yields 0 with *bytes_read = 1.
 */
uint32_t utf8_decode(const char *str, int *bytes_read) {
    const unsigned char *s = (const unsigned char *)str;
    int ignored;
    if (!bytes_read) {
        bytes_read = &ignored;
    }

    const unsigned char lead = s[0];
    int len;
    uint32_t codepoint;
    uint32_t min_codepoint; /* smallest value that needs this many bytes */

    /* Classify the lead byte (same bit patterns as utf8_byte_length). */
    if (lead < 0x80) {
        *bytes_read = 1;
        return lead;
    } else if ((lead & 0xE0) == 0xC0) {
        len = 2;
        codepoint = lead & 0x1F;
        min_codepoint = 0x80;
    } else if ((lead & 0xF0) == 0xE0) {
        len = 3;
        codepoint = lead & 0x0F;
        min_codepoint = 0x800;
    } else if ((lead & 0xF8) == 0xF0) {
        len = 4;
        codepoint = lead & 0x07;
        min_codepoint = 0x10000;
    } else {
        /* Stray continuation byte or 0xF8-0xFF: not a lead byte. */
        *bytes_read = 1;
        return lead;
    }

    for (int i = 1; i < len; i++) {
        /*
         * A NUL terminator also fails this test, so we never read past the
         * end of the string.
         */
        if ((s[i] & 0xC0) != 0x80) {
            *bytes_read = 1;
            return lead;
        }
        codepoint = (codepoint << 6) | (uint32_t)(s[i] & 0x3F);
    }

    /* Reject overlong forms, surrogates and values outside Unicode. */
    if (codepoint < min_codepoint || codepoint > 0x10FFFF ||
        (codepoint >= 0xD800 && codepoint <= 0xDFFF)) {
        *bytes_read = 1;
        return lead;
    }

    *bytes_read = len;
    return codepoint;
}
/*
 * Return the number of terminal columns (1 or 2) that codepoint occupies.
 *
 * Wide (2-column) ranges follow the East Asian Wide/Fullwidth ranges used
 * by Markus Kuhn's mk_wcwidth:
 *   - CJK ideographs, Hangul, Hiragana/Katakana, Bopomofo, Yi;
 *   - CJK punctuation such as U+3000 and U+3002;
 *   - fullwidth ASCII variants.
 *
 * Halfwidth Katakana/Hangul (U+FF61-U+FFDC) and halfwidth symbols
 * (U+FFE8-U+FFEE) are deliberately narrow.
 *
 * Everything else, including control characters, combining marks and
 * emoji, is reported as width 1. Emoji width in particular varies between
 * terminals and is not handled here.
 */
int utf8_char_width(uint32_t codepoint) {
    /* Sorted, non-overlapping inclusive ranges of wide codepoints. */
    static const struct {
        uint32_t first;
        uint32_t last;
    } wide_ranges[] = {
        { 0x1100, 0x115F },   /* Hangul Jamo leading consonants */
        { 0x2329, 0x232A },   /* Left/right-pointing angle brackets */
        { 0x2E80, 0x303E },   /* CJK Radicals .. CJK Symbols & Punctuation */
        { 0x3040, 0xA4CF },   /* Kana, Bopomofo, CJK Ext A, CJK Unified, Yi */
        { 0xAC00, 0xD7A3 },   /* Hangul Syllables */
        { 0xF900, 0xFAFF },   /* CJK Compatibility Ideographs */
        { 0xFE10, 0xFE19 },   /* Vertical Forms */
        { 0xFE30, 0xFE6F },   /* CJK Compatibility Forms, Small Form Variants */
        { 0xFF00, 0xFF60 },   /* Fullwidth ASCII variants and brackets */
        { 0xFFE0, 0xFFE6 },   /* Fullwidth currency/symbol signs */
        { 0x20000, 0x2FFFD }, /* Plane 2: CJK Ext B-F, Compat Supplement */
        { 0x30000, 0x3FFFD }, /* Plane 3: CJK Ext G and later */
    };

    /* Nothing below U+1100 is wide; this is the fast path for ASCII/Latin. */
    if (codepoint < 0x1100) {
        return 1;
    }

    for (size_t i = 0; i < sizeof wide_ranges / sizeof wide_ranges[0]; i++) {
        if (codepoint < wide_ranges[i].first) {
            break; /* table is sorted: no later range can match */
        }
        if (codepoint <= wide_ranges[i].last) {
            return 2;
        }
    }
    return 1;
}
/*
 * Return the total display width of the NUL-terminated UTF-8 string str.
 *
 * The string is walked one decoded character at a time, and the
 * utf8_char_width of each character is summed. Malformed bytes are counted
 * as whatever utf8_decode reports for them.
 */
int utf8_string_width(const char *str) {
    int total = 0;
    for (const char *cur = str; *cur != '\0';) {
        int step;
        total += utf8_char_width(utf8_decode(cur, &step));
        cur += step;
    }
    return total;
}
/*
 * Return the number of characters in the NUL-terminated UTF-8 string str.
 *
 * This counts decoded characters, not bytes. Each step advances by the
 * byte count that utf8_decode reports.
 */
int utf8_strlen(const char *str) {
    int chars = 0;
    const char *cur = str;
    while (*cur) {
        int step;
        (void)utf8_decode(cur, &step); /* only the step size is needed */
        cur += step;
        ++chars;
    }
    return chars;
}
/*
 * Shorten str in place so that its display width, as measured by
 * utf8_char_width, does not exceed max_width.
 *
 * The cut always falls on a character boundary reported by utf8_decode.
 * The first character that would overflow the limit is removed together
 * with everything after it. A max_width of 0 or less empties the string.
 */
void utf8_truncate(char *str, int max_width) {
    int used = 0;
    char *cut = str;
    while (*cut != '\0') {
        int step;
        int w = utf8_char_width(utf8_decode(cut, &step));
        if (used + w > max_width) {
            break;
        }
        used += w;
        cut += step;
    }
    *cut = '\0';
}
/*
 * Remove the last character from the NUL-terminated UTF-8 string str, in
 * place (backspace behaviour). NULL and empty strings are left alone.
 *
 * Step back over at most three continuation bytes to find a lead byte.
 * If that lead byte announces exactly the number of bytes up to the end of
 * the string, the whole character is removed.
 *
 * Otherwise the tail is malformed (stray continuation bytes or a truncated
 * sequence), and only the final byte is removed. This matches utf8_decode,
 * which treats such bytes one at a time, and it never deletes a valid
 * character that precedes garbage.
 */
void utf8_remove_last_char(char *str) {
    if (!str) {
        return;
    }
    size_t len = strlen(str);
    if (len == 0) {
        return;
    }

    const unsigned char *s = (const unsigned char *)str;
    size_t start = len - 1;
    /* A UTF-8 character has at most three continuation bytes (10xxxxxx). */
    while (start > 0 && len - start < 4 && (s[start] & 0xC0) == 0x80) {
        start--;
    }

    /* Length announced by the candidate lead byte; 0 if it is not one. */
    const unsigned char lead = s[start];
    size_t expected;
    if (lead < 0x80) {
        expected = 1;
    } else if ((lead & 0xE0) == 0xC0) {
        expected = 2;
    } else if ((lead & 0xF0) == 0xE0) {
        expected = 3;
    } else if ((lead & 0xF8) == 0xF0) {
        expected = 4;
    } else {
        expected = 0;
    }

    if (expected == len - start) {
        str[start] = '\0'; /* well-formed last character */
    } else {
        str[len - 1] = '\0'; /* malformed tail: drop one byte */
    }
}
/*
 * Delete the last word from str in place, in the style of a shell's Ctrl+W.
 *
 * Trailing spaces are removed first, then the run of non-space bytes
 * before them. The space that separated the word from the previous text
 * is kept. Only ASCII space (0x20) counts as a separator.
 *
 * Working on bytes is UTF-8 safe because 0x20 never occurs inside a
 * multi-byte sequence.
 */
void utf8_remove_last_word(char *str) {
    size_t end = strlen(str);
    if (end == 0) {
        return;
    }
    for (; end > 0 && str[end - 1] == ' '; end--) {
        /* skip trailing spaces */
    }
    for (; end > 0 && str[end - 1] != ' '; end--) {
        /* skip the word itself */
    }
    str[end] = '\0';
}
/*
 * Check whether exactly len bytes at bytes form one well-formed UTF-8
 * character (RFC 3629).
 *
 * All of the following must hold:
 *   - bytes is non-NULL and len is 1-4;
 *   - the lead byte announces exactly len bytes;
 *   - every following byte is a continuation byte (10xxxxxx);
 *   - the decoded value lies in the range for that length, so overlong
 *     forms and values above U+10FFFF are rejected;
 *   - the value is not a UTF-16 surrogate (U+D800-U+DFFF).
 *
 * NUL (0x00) is rejected as well, so that validated input cannot truncate
 * a C string.
 */
bool utf8_is_valid_sequence(const char *bytes, int len) {
    if (!bytes || len < 1 || len > 4) {
        return false;
    }

    const unsigned char *u = (const unsigned char *)bytes;
    const unsigned char lead = u[0];
    int need;
    uint32_t value;
    uint32_t lowest;
    uint32_t highest;

    if (lead < 0x80) {
        need = 1; value = lead;        lowest = 0x01;    highest = 0x7F;
    } else if ((lead & 0xE0) == 0xC0) {
        need = 2; value = lead & 0x1F; lowest = 0x80;    highest = 0x7FF;
    } else if ((lead & 0xF0) == 0xE0) {
        need = 3; value = lead & 0x0F; lowest = 0x800;   highest = 0xFFFF;
    } else if ((lead & 0xF8) == 0xF0) {
        need = 4; value = lead & 0x07; lowest = 0x10000; highest = 0x10FFFF;
    } else {
        /* Continuation byte or 0xF8-0xFF cannot start a character. */
        return false;
    }

    if (need != len) {
        return false;
    }

    for (int k = 1; k < len; k++) {
        if ((u[k] & 0xC0) != 0x80) {
            return false;
        }
        value = (value << 6) | (uint32_t)(u[k] & 0x3F);
    }

    if (value < lowest || value > highest) {
        return false; /* NUL, overlong, or beyond U+10FFFF */
    }
    return !(value >= 0xD800 && value <= 0xDFFF);
}
/*
 * Return true if the NUL-terminated string str consists entirely of
 * well-formed UTF-8 characters, each accepted by utf8_is_valid_sequence.
 * NULL returns false; the empty string returns true.
 *
 * Before validating each character, the code checks that the terminator
 * does not fall inside the announced sequence, so nothing past the end of
 * the string is read.
 */
bool utf8_is_valid_string(const char *str) {
    if (str == NULL) {
        return false;
    }
    for (const unsigned char *cur = (const unsigned char *)str; *cur != '\0';) {
        int seq_len = utf8_byte_length(*cur);
        if (seq_len < 1 || seq_len > 4) {
            return false;
        }
        for (int k = 1; k < seq_len; k++) {
            if (cur[k] == '\0') {
                return false; /* string ends mid-sequence */
            }
        }
        if (!utf8_is_valid_sequence((const char *)cur, seq_len)) {
            return false;
        }
        cur += seq_len;
    }
    return true;
}