Skip to content

Commit dfc15bc

Browse files
authored
Support custom (binary) data to be written into SSE Event (#3425)
1 parent d66cabd commit dfc15bc

File tree

1 file changed

+135
-94
lines changed

1 file changed

+135
-94
lines changed

axum/src/response/sse.rs

Lines changed: 135 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ use futures_util::stream::TryStream;
3939
use http_body::Frame;
4040
use pin_project_lite::pin_project;
4141
use std::{
42-
fmt, mem,
42+
fmt::{self, Write as _},
43+
io::Write as _,
44+
mem,
4345
pin::Pin,
4446
task::{ready, Context, Poll},
4547
time::Duration,
@@ -174,6 +176,27 @@ pub struct Event {
174176
flags: EventFlags,
175177
}
176178

179+
/// Expose [`Event`] as a [`std::fmt::Write`]
180+
/// such that any form of data can be written as data safely.
181+
///
182+
/// This also ensures that newline characters `\r` and `\n`
183+
/// correctly trigger a split with a new `data: ` prefix.
184+
///
185+
/// # Panics
186+
///
187+
/// Panics if any `data` has already been written prior to the first write
188+
/// of this [`EventDataWriter`] instance.
189+
#[derive(Debug)]
190+
#[must_use]
191+
pub struct EventDataWriter {
192+
event: Event,
193+
194+
// Indicates if _this_ EventDataWriter has written data,
195+
// this does not say anything about whether or not `event` contains
196+
// data or not.
197+
data_written: bool,
198+
}
199+
177200
impl Event {
178201
/// Default keep-alive event
179202
pub const DEFAULT_KEEP_ALIVE: Self = Self::finalized(Bytes::from_static(b":\n\n"));
@@ -185,6 +208,19 @@ impl Event {
185208
}
186209
}
187210

211+
/// Use this [`Event`] as a [`EventDataWriter`] to write custom data.
212+
///
213+
/// - [`Self::data`] can be used as a shortcut to write `str` data
214+
/// - [`Self::json_data`] can be used as a shortcut to write `json` data
215+
///
216+
/// Turn it into an [`Event`] again using [`EventDataWriter::into_event`].
217+
pub fn into_data_writer(self) -> EventDataWriter {
218+
EventDataWriter {
219+
event: self,
220+
data_written: false,
221+
}
222+
}
223+
188224
/// Set the event's data data field(s) (`data: <content>`)
189225
///
190226
/// Newlines in `data` will automatically be broken across `data: ` fields.
@@ -195,25 +231,16 @@ impl Event {
195231
///
196232
/// # Panics
197233
///
198-
/// - Panics if `data` contains any carriage returns, as they cannot be transmitted over SSE.
199-
/// - Panics if `data` or `json_data` have already been called.
234+
/// Panics if any `data` has already been written before.
200235
///
201236
/// [`MessageEvent`'s data field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/data
202-
pub fn data<T>(mut self, data: T) -> Self
237+
pub fn data<T>(self, data: T) -> Self
203238
where
204239
T: AsRef<str>,
205240
{
206-
if self.flags.contains(EventFlags::HAS_DATA) {
207-
panic!("Called `Event::data` multiple times");
208-
}
209-
210-
for line in memchr_split(b'\n', data.as_ref().as_bytes()) {
211-
self.field("data", line);
212-
}
213-
214-
self.flags.insert(EventFlags::HAS_DATA);
215-
216-
self
241+
let mut writer = self.into_data_writer();
242+
let _ = writer.write_str(data.as_ref());
243+
writer.into_event()
217244
}
218245

219246
/// Set the event's data field to a value serialized as unformatted JSON (`data: <content>`).
@@ -222,43 +249,31 @@ impl Event {
222249
///
223250
/// # Panics
224251
///
225-
/// Panics if `data` or `json_data` have already been called.
252+
/// Panics if any `data` has already been written before.
226253
///
227254
/// [`MessageEvent`'s data field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/data
228255
#[cfg(feature = "json")]
229-
pub fn json_data<T>(mut self, data: T) -> Result<Self, axum_core::Error>
256+
pub fn json_data<T>(self, data: T) -> Result<Self, axum_core::Error>
230257
where
231258
T: serde::Serialize,
232259
{
233-
struct IgnoreNewLines<'a>(bytes::buf::Writer<&'a mut BytesMut>);
234-
impl std::io::Write for IgnoreNewLines<'_> {
260+
struct JsonWriter<'a>(&'a mut EventDataWriter);
261+
impl std::io::Write for JsonWriter<'_> {
262+
#[inline]
235263
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
236-
let mut last_split = 0;
237-
for delimiter in memchr::memchr2_iter(b'\n', b'\r', buf) {
238-
self.0.write_all(&buf[last_split..delimiter])?;
239-
last_split = delimiter + 1;
240-
}
241-
self.0.write_all(&buf[last_split..])?;
242-
Ok(buf.len())
264+
Ok(self.0.write_buf(buf))
243265
}
244-
245266
fn flush(&mut self) -> std::io::Result<()> {
246-
self.0.flush()
267+
Ok(())
247268
}
248269
}
249-
if self.flags.contains(EventFlags::HAS_DATA) {
250-
panic!("Called `Event::json_data` multiple times");
251-
}
252270

253-
let buffer = self.buffer.as_mut();
254-
buffer.extend_from_slice(b"data: ");
255-
serde_json::to_writer(IgnoreNewLines(buffer.writer()), &data)
256-
.map_err(axum_core::Error::new)?;
257-
buffer.put_u8(b'\n');
271+
let mut writer = self.into_data_writer();
258272

259-
self.flags.insert(EventFlags::HAS_DATA);
273+
let json_writer = JsonWriter(&mut writer);
274+
serde_json::to_writer(json_writer, &data).map_err(axum_core::Error::new)?;
260275

261-
Ok(self)
276+
Ok(writer.into_event())
262277
}
263278

264279
/// Set the event's comment field (`:<comment-text>`).
@@ -407,6 +422,60 @@ impl Event {
407422
}
408423
}
409424

425+
impl EventDataWriter {
426+
/// Consume the [`EventDataWriter`] and return the [`Event`] once again.
427+
///
428+
/// In case any data was written by this instance
429+
/// it will also write the trailing `\n` character.
430+
pub fn into_event(self) -> Event {
431+
let mut event = self.event;
432+
if self.data_written {
433+
let _ = event.buffer.as_mut().write_char('\n');
434+
}
435+
event
436+
}
437+
}
438+
439+
impl EventDataWriter {
440+
// Assumption: underlying writer never returns an error:
441+
// <https://docs.rs/bytes/latest/src/bytes/buf/writer.rs.html#79-82>
442+
fn write_buf(&mut self, buf: &[u8]) -> usize {
443+
if buf.is_empty() {
444+
return 0;
445+
}
446+
447+
let buffer = self.event.buffer.as_mut();
448+
449+
if !std::mem::replace(&mut self.data_written, true) {
450+
if self.event.flags.contains(EventFlags::HAS_DATA) {
451+
panic!("Called `Event::data*` multiple times");
452+
}
453+
454+
let _ = buffer.write_str("data: ");
455+
self.event.flags.insert(EventFlags::HAS_DATA);
456+
}
457+
458+
let mut writer = buffer.writer();
459+
460+
let mut last_split = 0;
461+
for delimiter in memchr::memchr2_iter(b'\n', b'\r', buf) {
462+
let _ = writer.write_all(&buf[last_split..=delimiter]);
463+
let _ = writer.write_all(b"data: ");
464+
last_split = delimiter + 1;
465+
}
466+
let _ = writer.write_all(&buf[last_split..]);
467+
468+
buf.len()
469+
}
470+
}
471+
472+
impl fmt::Write for EventDataWriter {
473+
fn write_str(&mut self, s: &str) -> fmt::Result {
474+
let _ = self.write_buf(s.as_bytes());
475+
Ok(())
476+
}
477+
}
478+
410479
impl Default for Event {
411480
fn default() -> Self {
412481
Self {
@@ -566,32 +635,6 @@ where
566635
}
567636
}
568637

569-
fn memchr_split(needle: u8, haystack: &[u8]) -> MemchrSplit<'_> {
570-
MemchrSplit {
571-
needle,
572-
haystack: Some(haystack),
573-
}
574-
}
575-
576-
struct MemchrSplit<'a> {
577-
needle: u8,
578-
haystack: Option<&'a [u8]>,
579-
}
580-
581-
impl<'a> Iterator for MemchrSplit<'a> {
582-
type Item = &'a [u8];
583-
fn next(&mut self) -> Option<Self::Item> {
584-
let haystack = self.haystack?;
585-
if let Some(pos) = memchr::memchr(self.needle, haystack) {
586-
let (front, back) = haystack.split_at(pos);
587-
self.haystack = Some(&back[1..]);
588-
Some(front)
589-
} else {
590-
self.haystack.take()
591-
}
592-
}
593-
}
594-
595638
#[cfg(test)]
596639
mod tests {
597640
use super::*;
@@ -611,14 +654,40 @@ mod tests {
611654
}
612655

613656
#[test]
614-
fn valid_json_raw_value_chars_stripped() {
657+
fn write_data_writer_str() {
658+
// also confirm that nop writers do nothing :)
659+
let mut writer = Event::default()
660+
.into_data_writer()
661+
.into_event()
662+
.into_data_writer();
663+
writer.write_str("").unwrap();
664+
let mut writer = writer.into_event().into_data_writer();
665+
666+
writer.write_str("").unwrap();
667+
writer.write_str("moon ").unwrap();
668+
writer.write_str("star\nsun").unwrap();
669+
writer.write_str("").unwrap();
670+
writer.write_str("set").unwrap();
671+
writer.write_str("").unwrap();
672+
writer.write_str(" bye\r").unwrap();
673+
674+
let event = writer.into_event();
675+
676+
assert_eq!(
677+
&*event.finalize(),
678+
b"data: moon star\ndata: sunset bye\rdata: \n\n"
679+
);
680+
}
681+
682+
#[test]
683+
fn valid_json_raw_value_chars_handled() {
615684
let json_string = "{\r\"foo\": \n\r\r \"bar\\n\"\n}";
616685
let json_raw_value_event = Event::default()
617686
.json_data(serde_json::from_str::<&RawValue>(json_string).unwrap())
618687
.unwrap();
619688
assert_eq!(
620689
&*json_raw_value_event.finalize(),
621-
format!("data: {}\n\n", json_string.replace(['\n', '\r'], "")).as_bytes()
690+
b"data: {\rdata: \"foo\": \ndata: \rdata: \rdata: \"bar\\n\"\ndata: }\n\n"
622691
);
623692
}
624693

@@ -763,32 +832,4 @@ mod tests {
763832

764833
fields
765834
}
766-
767-
#[test]
768-
fn memchr_splitting() {
769-
assert_eq!(
770-
memchr_split(2, &[]).collect::<Vec<_>>(),
771-
[&[]] as [&[u8]; 1]
772-
);
773-
assert_eq!(
774-
memchr_split(2, &[2]).collect::<Vec<_>>(),
775-
[&[], &[]] as [&[u8]; 2]
776-
);
777-
assert_eq!(
778-
memchr_split(2, &[1]).collect::<Vec<_>>(),
779-
[&[1]] as [&[u8]; 1]
780-
);
781-
assert_eq!(
782-
memchr_split(2, &[1, 2]).collect::<Vec<_>>(),
783-
[&[1], &[]] as [&[u8]; 2]
784-
);
785-
assert_eq!(
786-
memchr_split(2, &[2, 1]).collect::<Vec<_>>(),
787-
[&[], &[1]] as [&[u8]; 2]
788-
);
789-
assert_eq!(
790-
memchr_split(2, &[1, 2, 2, 1]).collect::<Vec<_>>(),
791-
[&[1], &[], &[1]] as [&[u8]; 3]
792-
);
793-
}
794835
}

0 commit comments

Comments
 (0)