|  | 
| 1 |  | -use std::fmt; | 
| 2 |  | -use std::num::NonZeroU8; | 
|  | 1 | +use std::{fmt, str}; | 
| 3 | 2 | 
 | 
|  | 3 | +#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))] | 
| 4 | 4 | #[allow(unused)] | 
| 5 | 5 | pub(crate) fn write_escaped_str(mut fmt: impl fmt::Write, string: &str) -> fmt::Result { | 
| 6 |  | -    let mut escaped_buf = *b"&#__;"; | 
|  | 6 | +    // Even though [`jetscii`] ships a generic implementation for unsupported platforms, | 
|  | 7 | +    // it is not well optimized for this case. This implementation should work well enough in | 
|  | 8 | +    // the meantime, until portable SIMD gets stabilized. | 
|  | 9 | + | 
|  | 10 | +    // Instead of testing the platform, we could test the CPU features. But given that the needed | 
|  | 11 | +    // instruction set SSE 4.2 was introduced in 2008, that it has an 99.61 % availability rate | 
|  | 12 | +    // in Steam's June 2024 hardware survey, and is a prerequisite to run Windows 11, I don't | 
|  | 13 | +    // think we need to care. | 
|  | 14 | + | 
|  | 15 | +    let mut escaped_buf = ESCAPED_BUF_INIT; | 
| 7 | 16 |     let mut last = 0; | 
| 8 | 17 | 
 | 
| 9 | 18 |     for (index, byte) in string.bytes().enumerate() { | 
| 10 | 19 |         let escaped = match byte { | 
| 11 | 20 |             MIN_CHAR..=MAX_CHAR => TABLE.lookup[(byte - MIN_CHAR) as usize], | 
| 12 |  | -            _ => None, | 
|  | 21 | +            _ => 0, | 
| 13 | 22 |         }; | 
| 14 |  | -        if let Some(escaped) = escaped { | 
| 15 |  | -            escaped_buf[2] = escaped[0].get(); | 
| 16 |  | -            escaped_buf[3] = escaped[1].get(); | 
| 17 |  | -            fmt.write_str(&string[last..index])?; | 
| 18 |  | -            fmt.write_str(unsafe { std::str::from_utf8_unchecked(escaped_buf.as_slice()) })?; | 
|  | 23 | +        if escaped != 0 { | 
|  | 24 | +            [escaped_buf[2], escaped_buf[3]] = escaped.to_ne_bytes(); | 
|  | 25 | +            write_str_if_nonempty(&mut fmt, &string[last..index])?; | 
|  | 26 | +            // SAFETY: the content of `escaped_buf` is pure ASCII | 
|  | 27 | +            fmt.write_str(unsafe { | 
|  | 28 | +                std::str::from_utf8_unchecked(&escaped_buf[..ESCAPED_BUF_LEN]) | 
|  | 29 | +            })?; | 
| 19 | 30 |             last = index + 1; | 
| 20 | 31 |         } | 
| 21 | 32 |     } | 
| 22 |  | -    fmt.write_str(&string[last..]) | 
|  | 33 | +    write_str_if_nonempty(&mut fmt, &string[last..]) | 
|  | 34 | +} | 
|  | 35 | + | 
|  | 36 | +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] | 
|  | 37 | +#[allow(unused)] | 
|  | 38 | +pub(crate) fn write_escaped_str(mut fmt: impl fmt::Write, mut string: &str) -> fmt::Result { | 
|  | 39 | +    let jetscii = jetscii::bytes!(b'"', b'&', b'\'', b'<', b'>'); | 
|  | 40 | + | 
|  | 41 | +    let mut escaped_buf = ESCAPED_BUF_INIT; | 
|  | 42 | +    loop { | 
|  | 43 | +        if string.is_empty() { | 
|  | 44 | +            return Ok(()); | 
|  | 45 | +        } | 
|  | 46 | + | 
|  | 47 | +        let found = if string.len() >= 16 { | 
|  | 48 | +            // Only strings of at least 16 bytes can be escaped using SSE instructions. | 
|  | 49 | +            match jetscii.find(string.as_bytes()) { | 
|  | 50 | +                Some(index) => { | 
|  | 51 | +                    let escaped = TABLE.lookup[(string.as_bytes()[index] - MIN_CHAR) as usize]; | 
|  | 52 | +                    Some((index, escaped)) | 
|  | 53 | +                } | 
|  | 54 | +                None => None, | 
|  | 55 | +            } | 
|  | 56 | +        } else { | 
|  | 57 | +            // The small-string fallback of [`jetscii`] is quite slow, so we roll our own | 
|  | 58 | +            // implementation. | 
|  | 59 | +            string.as_bytes().iter().find_map(|byte: &u8| { | 
|  | 60 | +                let escaped = get_escaped(*byte)?; | 
|  | 61 | +                let index = (byte as *const u8 as usize) - (string.as_ptr() as usize); | 
|  | 62 | +                Some((index, escaped)) | 
|  | 63 | +            }) | 
|  | 64 | +        }; | 
|  | 65 | +        let Some((index, escaped)) = found else { | 
|  | 66 | +            return fmt.write_str(string); | 
|  | 67 | +        }; | 
|  | 68 | + | 
|  | 69 | +        [escaped_buf[2], escaped_buf[3]] = escaped.to_ne_bytes(); | 
|  | 70 | + | 
|  | 71 | +        // SAFETY: index points at an ASCII char in `string` | 
|  | 72 | +        let front; | 
|  | 73 | +        (front, string) = unsafe { | 
|  | 74 | +            ( | 
|  | 75 | +                string.get_unchecked(..index), | 
|  | 76 | +                string.get_unchecked(index + 1..), | 
|  | 77 | +            ) | 
|  | 78 | +        }; | 
|  | 79 | + | 
|  | 80 | +        write_str_if_nonempty(&mut fmt, front)?; | 
|  | 81 | +        // SAFETY: the content of `escaped_buf` is pure ASCII | 
|  | 82 | +        fmt.write_str(unsafe { str::from_utf8_unchecked(&escaped_buf[..ESCAPED_BUF_LEN]) })?; | 
|  | 83 | +    } | 
| 23 | 84 | } | 
| 24 | 85 | 
 | 
| 25 | 86 | #[allow(unused)] | 
| 26 | 87 | pub(crate) fn write_escaped_char(mut fmt: impl fmt::Write, c: char) -> fmt::Result { | 
| 27 |  | -    fmt.write_str(match (c.is_ascii(), c as u8) { | 
| 28 |  | -        (true, b'"') => """, | 
| 29 |  | -        (true, b'&') => "&", | 
| 30 |  | -        (true, b'\'') => "'", | 
| 31 |  | -        (true, b'<') => "<", | 
| 32 |  | -        (true, b'>') => ">", | 
| 33 |  | -        _ => return fmt.write_char(c), | 
| 34 |  | -    }) | 
|  | 88 | +    if !c.is_ascii() { | 
|  | 89 | +        fmt.write_char(c) | 
|  | 90 | +    } else if let Some(escaped) = get_escaped(c as u8) { | 
|  | 91 | +        let mut escaped_buf = ESCAPED_BUF_INIT; | 
|  | 92 | +        [escaped_buf[2], escaped_buf[3]] = escaped.to_ne_bytes(); | 
|  | 93 | +        // SAFETY: the content of `escaped_buf` is pure ASCII | 
|  | 94 | +        fmt.write_str(unsafe { str::from_utf8_unchecked(&escaped_buf[..ESCAPED_BUF_LEN]) }) | 
|  | 95 | +    } else { | 
|  | 96 | +        // RATIONALE: `write_char(c)` gets optimized if it is known that `c.is_ascii()` | 
|  | 97 | +        fmt.write_char(c) | 
|  | 98 | +    } | 
| 35 | 99 | } | 
| 36 | 100 | 
 | 
| 37 |  | -const MIN_CHAR: u8 = b'"'; | 
| 38 |  | -const MAX_CHAR: u8 = b'>'; | 
|  | 101 | +#[inline(always)] | 
|  | 102 | +fn get_escaped(byte: u8) -> Option<u16> { | 
|  | 103 | +    let c = byte.wrapping_sub(MIN_CHAR); | 
|  | 104 | +    if (c < u32::BITS as u8) && (BITS & (1 << c as u32) != 0) { | 
|  | 105 | +        Some(TABLE.lookup[c as usize]) | 
|  | 106 | +    } else { | 
|  | 107 | +        None | 
|  | 108 | +    } | 
|  | 109 | +} | 
| 39 | 110 | 
 | 
| 40 |  | -struct Table { | 
| 41 |  | -    _align: [usize; 0], | 
| 42 |  | -    lookup: [Option<[NonZeroU8; 2]>; (MAX_CHAR - MIN_CHAR + 1) as usize], | 
|  | 111 | +#[inline(always)] | 
|  | 112 | +fn write_str_if_nonempty(output: &mut impl fmt::Write, input: &str) -> fmt::Result { | 
|  | 113 | +    if !input.is_empty() { | 
|  | 114 | +        output.write_str(input) | 
|  | 115 | +    } else { | 
|  | 116 | +        Ok(()) | 
|  | 117 | +    } | 
| 43 | 118 | } | 
| 44 | 119 | 
 | 
| 45 |  | -const TABLE: Table = { | 
| 46 |  | -    const fn n(c: u8) -> Option<[NonZeroU8; 2]> { | 
| 47 |  | -        assert!(MIN_CHAR <= c && c <= MAX_CHAR); | 
|  | 120 | +/// List of characters that need HTML escaping, not necessarily in ordinal order. | 
|  | 121 | +/// Filling the [`TABLE`] and [`BITS`] constants will ensure that the range of lowest to hightest | 
|  | 122 | +/// codepoint wont exceed [`u32::BITS`] (=32) items. | 
|  | 123 | +const CHARS: &[u8] = br#""&'<>"#; | 
| 48 | 124 | 
 | 
| 49 |  | -        let n0 = match NonZeroU8::new(c / 10 + b'0') { | 
| 50 |  | -            Some(n) => n, | 
| 51 |  | -            None => panic!(), | 
| 52 |  | -        }; | 
| 53 |  | -        let n1 = match NonZeroU8::new(c % 10 + b'0') { | 
| 54 |  | -            Some(n) => n, | 
| 55 |  | -            None => panic!(), | 
| 56 |  | -        }; | 
| 57 |  | -        Some([n0, n1]) | 
|  | 125 | +/// The character with the smallest codepoint that needs HTML escaping. | 
|  | 126 | +/// Both [`TABLE`] and [`BITS`] start at this value instead of `0`. | 
|  | 127 | +const MIN_CHAR: u8 = { | 
|  | 128 | +    let mut v = u8::MAX; | 
|  | 129 | +    let mut i = 0; | 
|  | 130 | +    while i < CHARS.len() { | 
|  | 131 | +        if v > CHARS[i] { | 
|  | 132 | +            v = CHARS[i]; | 
|  | 133 | +        } | 
|  | 134 | +        i += 1; | 
|  | 135 | +    } | 
|  | 136 | +    v | 
|  | 137 | +}; | 
|  | 138 | + | 
|  | 139 | +#[allow(unused)] | 
|  | 140 | +const MAX_CHAR: u8 = { | 
|  | 141 | +    let mut v = u8::MIN; | 
|  | 142 | +    let mut i = 0; | 
|  | 143 | +    while i < CHARS.len() { | 
|  | 144 | +        if v < CHARS[i] { | 
|  | 145 | +            v = CHARS[i]; | 
|  | 146 | +        } | 
|  | 147 | +        i += 1; | 
| 58 | 148 |     } | 
|  | 149 | +    v | 
|  | 150 | +}; | 
|  | 151 | + | 
|  | 152 | +struct Table { | 
|  | 153 | +    _align: [usize; 0], | 
|  | 154 | +    lookup: [u16; u32::BITS as usize], | 
|  | 155 | +} | 
| 59 | 156 | 
 | 
|  | 157 | +/// For characters that need HTML escaping, the codepoint formatted as decimal digits, | 
|  | 158 | +/// otherwise `b"\0\0"`. Starting at [`MIN_CHAR`]. | 
|  | 159 | +const TABLE: Table = { | 
| 60 | 160 |     let mut table = Table { | 
| 61 | 161 |         _align: [], | 
| 62 |  | -        lookup: [None; (MAX_CHAR - MIN_CHAR + 1) as usize], | 
|  | 162 | +        lookup: [0; u32::BITS as usize], | 
| 63 | 163 |     }; | 
| 64 |  | - | 
| 65 |  | -    table.lookup[(b'"' - MIN_CHAR) as usize] = n(b'"'); | 
| 66 |  | -    table.lookup[(b'&' - MIN_CHAR) as usize] = n(b'&'); | 
| 67 |  | -    table.lookup[(b'\'' - MIN_CHAR) as usize] = n(b'\''); | 
| 68 |  | -    table.lookup[(b'<' - MIN_CHAR) as usize] = n(b'<'); | 
| 69 |  | -    table.lookup[(b'>' - MIN_CHAR) as usize] = n(b'>'); | 
|  | 164 | +    let mut i = 0; | 
|  | 165 | +    while i < CHARS.len() { | 
|  | 166 | +        let c = CHARS[i]; | 
|  | 167 | +        let h = c / 10 + b'0'; | 
|  | 168 | +        let l = c % 10 + b'0'; | 
|  | 169 | +        table.lookup[(c - MIN_CHAR) as usize] = u16::from_ne_bytes([h, l]); | 
|  | 170 | +        i += 1; | 
|  | 171 | +    } | 
| 70 | 172 |     table | 
| 71 | 173 | }; | 
|  | 174 | + | 
|  | 175 | +/// A bitset of the characters that need escaping, starting at [`MIN_CHAR`] | 
|  | 176 | +const BITS: u32 = { | 
|  | 177 | +    let mut i = 0; | 
|  | 178 | +    let mut bits = 0; | 
|  | 179 | +    while i < CHARS.len() { | 
|  | 180 | +        bits |= 1 << (CHARS[i] - MIN_CHAR) as u32; | 
|  | 181 | +        i += 1; | 
|  | 182 | +    } | 
|  | 183 | +    bits | 
|  | 184 | +}; | 
|  | 185 | + | 
|  | 186 | +// RATIONALE: llvm generates better code if the buffer is register sized | 
|  | 187 | +const ESCAPED_BUF_INIT: [u8; 8] = *b"&#__;\0\0\0"; | 
|  | 188 | +const ESCAPED_BUF_LEN: usize = b"&#__;".len(); | 
|  | 189 | + | 
|  | 190 | +#[test] | 
|  | 191 | +fn simple() { | 
|  | 192 | +    let mut buf = String::new(); | 
|  | 193 | +    write_escaped_str(&mut buf, "<script>").unwrap(); | 
|  | 194 | +    assert_eq!(buf, "<script>"); | 
|  | 195 | + | 
|  | 196 | +    buf.clear(); | 
|  | 197 | +    write_escaped_str(&mut buf, "s<crip>t").unwrap(); | 
|  | 198 | +    assert_eq!(buf, "s<crip>t"); | 
|  | 199 | + | 
|  | 200 | +    buf.clear(); | 
|  | 201 | +    write_escaped_str(&mut buf, "s<cripcripcripcripcripcripcripcripcripcrip>t").unwrap(); | 
|  | 202 | +    assert_eq!(buf, "s<cripcripcripcripcripcripcripcripcripcrip>t"); | 
|  | 203 | +} | 
0 commit comments