diff --git a/Cargo.lock b/Cargo.lock index 8e1bedf8..0f803c2a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -100,9 +100,9 @@ dependencies = [ [[package]] name = "byteorder" -version = "1.3.2" +version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7c3dd8985a7111efc5c80b44e23ecdd8c007de8ade3b96595387e812b957cf5" +checksum = "ae44d1a3d5a19df61dd0c8beb138458ac2a53a7ac09eba97d55592540004306b" [[package]] name = "bytes" @@ -643,7 +643,9 @@ dependencies = [ "log", "neovim-lib", "rmp", + "rmp-serde", "rmpv", + "serde", "tempdir", "tokio", "tokio-util", @@ -717,9 +719,9 @@ checksum = "369a6ed065f249a159e06c45752c780bda2fb53c995718f9e484d08daa9eb42e" [[package]] name = "proc-macro2" -version = "1.0.7" +version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0319972dcae462681daf4da1adeeaa066e3ebd29c69be96c6abb1259d2ee2bcc" +checksum = "1e0704ee1a7e00d7bb417d0770ea303c1bccbabf0ef1667dae92b5967f5f8a71" dependencies = [ "unicode-xid", ] @@ -856,6 +858,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "rmp-serde" +version = "0.15.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "839395ef53057db96b84c9238ab29e1a13f2e5c8ec9f66bef853ab4197303924" +dependencies = [ + "byteorder", + "rmp", + "serde", +] + [[package]] name = "rmpv" version = "0.4.3" @@ -915,9 +928,12 @@ checksum = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3" [[package]] name = "serde" -version = "1.0.104" +version = "1.0.123" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "414115f25f818d7dfccec8ee535d76949ae78584fc4f79a6f45a904bf8ab4449" +checksum = "92d5161132722baa40d802cc70b15262b98258453e85e5d1d365c757c73869ae" +dependencies = [ + "serde_derive", +] [[package]] name = "serde_bytes" @@ -930,9 +946,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.104" +version = "1.0.123" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "128f9e303a5a29922045a830221b8f78ec74a5f544944f3d5984f8ec3895ef64" +checksum = "9391c295d64fc0abb2c556bad848f33cb8296276b1ad2677d1ae1ace4f258f31" dependencies = [ "proc-macro2", "quote", @@ -984,9 +1000,9 @@ dependencies = [ [[package]] name = "syn" -version = "1.0.13" +version = "1.0.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e4ff033220a41d1a57d8125eab57bf5263783dfdcc18688b1dacc6ce9651ef8" +checksum = "c700597eca8a5a762beb35753ef6b94df201c81cca676604f547495a0d7f0081" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index e190fca6..485565b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,18 +29,21 @@ bench = false [dependencies] rmp = "0.8.9" -rmpv = "0.4.3" +rmpv = { version = "0.4.3", features = ["with-serde"] } log = "0.4.8" async-trait = "0.1.22" futures = { version = "0.3.1", features = ["io-compat"] } -tokio = { version = "1.1.1", features = ["full"] , optional = true} +tokio = { version = "1.1.1", features = ["full"] , optional = true } tokio-util = { version = "0.6.3", features = ["compat"], optional = true } async-std = { version = "1.4.0", features = ["attributes"], optional = true } neovim-lib = { version = "0.6.1", optional = true } +serde = { version = "1.0.123", features = ["derive"] } +rmp-serde = "0.15.4" [dev-dependencies] tempdir = "0.3" criterion = "0.3.0" +tokio = { version = "1.1.1", features = ["full"] } [profile.bench] lto = true diff --git a/src/error.rs b/src/error.rs index 862f5da4..88a5ae1c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -114,6 +114,7 @@ impl Display for InvalidMessage { pub enum DecodeError { /// Reading from the internal buffer failed. BufferError(RmpvDecodeError), + SerdeBufferError(rmp_serde::decode::Error), /// Reading from the stream failed. This is probably unrecoverable from, but /// might also mean that neovim closed the stream and wants the plugin to /// finish. See examples/quitting.rs on how this might be caught. @@ -126,6 +127,7 @@ impl Error for DecodeError { fn source(&self) -> Option<&(dyn Error + 'static)> { match *self { DecodeError::BufferError(ref e) => Some(e), + DecodeError::SerdeBufferError(ref e) => Some(e), DecodeError::InvalidMessage(ref e) => Some(e), DecodeError::ReaderError(ref e) => Some(e), } @@ -136,6 +138,9 @@ impl Display for DecodeError { fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> { let s = match *self { DecodeError::BufferError(_) => "Error while reading from buffer", + DecodeError::SerdeBufferError(_) => { + "Error reading from buffer using serde" + } DecodeError::InvalidMessage(_) => "Error while decoding", DecodeError::ReaderError(_) => "Error while reading from Reader", }; @@ -150,6 +155,12 @@ impl From for Box { } } +impl From for Box { + fn from(err: rmp_serde::decode::Error) -> Box { + Box::new(DecodeError::SerdeBufferError(err)) + } +} + impl From for Box { fn from(err: InvalidMessage) -> Box { Box::new(DecodeError::InvalidMessage(err)) @@ -167,6 +178,8 @@ impl From for Box { pub enum EncodeError { /// Encoding the message into the internal buffer has failed. BufferError(RmpvEncodeError), + SerdeBufferError(rmp_serde::encode::Error), + ToValueError(rmpv::ext::Error), /// Writing the encoded message to the stream failed. WriterError(io::Error), } @@ -175,6 +188,8 @@ impl Error for EncodeError { fn source(&self) -> Option<&(dyn Error + 'static)> { match *self { EncodeError::BufferError(ref e) => Some(e), + EncodeError::ToValueError(ref e) => Some(e), + EncodeError::SerdeBufferError(ref e) => Some(e), EncodeError::WriterError(ref e) => Some(e), } } @@ -184,6 +199,8 @@ impl Display for EncodeError { fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> { let s = match *self { Self::BufferError(_) => "Error writing to buffer", + Self::SerdeBufferError(_) => "Error writing to buffer using serde", + Self::ToValueError(_) => "Error converting serializable to Value", Self::WriterError(_) => "Error writing to the Writer", }; @@ -197,6 +214,18 @@ impl From for Box { } } +impl From for Box { + fn from(err: rmpv::ext::Error) -> Box { + Box::new(EncodeError::ToValueError(err)) + } +} + +impl From for Box { + fn from(err: rmp_serde::encode::Error) -> Self { + Box::new(EncodeError::SerdeBufferError(err)) + } +} + impl From for Box { fn from(err: io::Error) -> Box { Box::new(EncodeError::WriterError(err)) diff --git a/src/neovim.rs b/src/neovim.rs index ff93e71b..952c5d6f 100644 --- a/src/neovim.rs +++ b/src/neovim.rs @@ -1,5 +1,6 @@ //! An active neovim session. use std::{ + fmt, future::Future, sync::{ atomic::{AtomicU64, Ordering}, @@ -23,13 +24,19 @@ use crate::{ }, uioptions::UiAttachOptions, }; -use rmpv::Value; +use rmpv::{ext::to_value, Value}; +use serde::{self, Deserialize, Deserializer, Serialize}; /// Pack the given arguments into a `Vec`, suitable for using it for a /// [`call`](crate::neovim::Neovim::call) to neovim. #[macro_export] macro_rules! call_args { - () => (Vec::new()); + () => { + { + let vec: Vec<$crate::Value> = Vec::new(); + vec + } + }; ($($e:expr), +,) => (call_args![$($e),*]); ($($e:expr), +) => {{ let mut vec = Vec::new(); @@ -96,17 +103,31 @@ where (req, fut) } - async fn send_msg( + /// Will panic if args is serialized into something that is not an array + async fn send_msg( &self, method: &str, - args: Vec, + args: T, ) -> Result, Box> { let msgid = self.msgid_counter.fetch_add(1, Ordering::SeqCst); + fn get_args( + args: T, + ) -> Result, Box> { + debug!("Args value is {:?}", args); + let args_value = to_value(args)?; + debug!("Args value is {:?}", args_value); + + Ok(match args_value { + Value::Array(arr) => arr, + v => vec![v], + }) + } + let req = RpcMessage::RpcRequest { msgid, method: method.to_owned(), - params: args, + params: get_args(args)?, }; let (sender, receiver) = oneshot::channel(); @@ -119,10 +140,10 @@ where Ok(receiver) } - pub async fn call( + pub async fn call( &self, method: &str, - args: Vec, + args: T, ) -> Result, Box> { let receiver = self .send_msg(method, args) diff --git a/src/neovim_api.rs b/src/neovim_api.rs index be6f8b3a..de57fe09 100644 --- a/src/neovim_api.rs +++ b/src/neovim_api.rs @@ -1,7 +1,10 @@ //! The auto generated API for [`neovim`](crate::neovim::Neovim) //! //! Auto generated 2020-08-18 09:13:24.551223 +use std::fmt; + use futures::io::AsyncWrite; +use serde::Serialize; use crate::{ error::CallError, @@ -1009,13 +1012,17 @@ where .map_err(|v| Box::new(CallError::WrongValueType(v))) } - pub async fn call_function( + pub async fn call_function( &self, fname: &str, - args: Vec, + args: T, ) -> Result> { + + #[derive(Debug, Serialize)] + struct Args<'a, T: Serialize>(&'a str, T); + self - .call("nvim_call_function", call_args![fname, args]) + .call("nvim_call_function", Args(fname, args)) .await?? .try_unpack() .map_err(|v| Box::new(CallError::WrongValueType(v))) diff --git a/src/rpc/model.rs b/src/rpc/model.rs index 1a46d090..6c386d00 100644 --- a/src/rpc/model.rs +++ b/src/rpc/model.rs @@ -2,6 +2,7 @@ use std::{ self, convert::TryInto, + fmt, io::{self, Cursor, ErrorKind, Read}, sync::Arc, }; @@ -10,7 +11,14 @@ use futures::{ io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufWriter}, lock::Mutex, }; +use rmp_serde::encode; +use rmpv::ext::from_value; use rmpv::{decode::read_value, encode::write_value, Value}; +use serde::ser::{SerializeSeq, SerializeTuple}; +use serde::{ + de::{self, SeqAccess, Visitor}, + Deserialize, Deserializer, Serialize, Serializer, +}; use crate::error::{DecodeError, EncodeError}; @@ -34,16 +42,110 @@ pub enum RpcMessage { }, // 2 } -macro_rules! rpc_args { - ($($e:expr), *) => {{ - let mut vec = Vec::new(); - $( - vec.push(Value::from($e)); - )* - Value::from(vec) - }} +impl Serialize for RpcMessage { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match self { + RpcMessage::RpcRequest { + msgid, + method, + params, + } => { + let mut seq = serializer.serialize_seq(Some(4))?; + seq.serialize_element(&0)?; + seq.serialize_element(msgid)?; + seq.serialize_element(method)?; + seq.serialize_element(params)?; + seq.end() + } + RpcMessage::RpcResponse { + msgid, + error, + result, + } => { + let mut seq = serializer.serialize_seq(Some(4))?; + seq.serialize_element(&1)?; + seq.serialize_element(msgid)?; + seq.serialize_element(error)?; + seq.serialize_element(result)?; + seq.end() + } + + RpcMessage::RpcNotification { method, params } => { + let mut seq = serializer.serialize_seq(Some(3))?; + seq.serialize_element(&2)?; + seq.serialize_element(method)?; + seq.serialize_element(params)?; + seq.end() + } + } + } } +impl<'de> Deserialize<'de> for RpcMessage { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct RpcVisitor; + + impl<'de> Visitor<'de> for RpcVisitor { + type Value = RpcMessage; + + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("an array") + } + + fn visit_seq(self, mut seq: V) -> Result + where + V: SeqAccess<'de>, + { + let res = match seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(0, &self))? + { + 0 => RpcMessage::RpcRequest { + msgid: seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(1, &self))?, + method: seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(2, &self))?, + params: seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(3, &self))?, + }, + 1 => RpcMessage::RpcResponse { + msgid: seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(1, &self))?, + error: seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(2, &self))?, + result: seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(3, &self))?, + }, + 2 => RpcMessage::RpcNotification { + method: seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(1, &self))?, + params: seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(2, &self))?, + }, + i => return Err(de::Error::custom(format!("invalid id: {}", i))), + }; + + Ok(res) + } + } + + deserializer.deserialize_seq(RpcVisitor) + } +} /// Continously reads from reader, pushing onto `rest`. Then tries to decode the /// contents of `rest`. If it succeeds, returns the message, and leaves any /// non-decoded bytes in `rest`. If we did not read enough for a full message, @@ -66,12 +168,24 @@ pub async fn decode( *rest = rest.split_off(pos as usize); // TODO: more efficiency return Ok(msg); } + Err(DecodeError::BufferError(e)) if e.kind() == ErrorKind::UnexpectedEof => { debug!("Not enough data, reading more!"); bytes_read = reader.read(&mut *buf).await; } + + Err(DecodeError::SerdeBufferError( + rmp_serde::decode::Error::InvalidMarkerRead(e), + )) + | Err(DecodeError::SerdeBufferError( + rmp_serde::decode::Error::InvalidDataRead(e), + )) if e.kind() == ErrorKind::UnexpectedEof => { + debug!("Not enough data, reading more!"); + bytes_read = reader.read(&mut *buf).await; + } + Err(err) => return Err(err.into()), } @@ -91,6 +205,14 @@ pub async fn decode( /// give detailed errors if something went wrong. fn decode_buffer( reader: &mut R, +) -> std::result::Result> { + Ok(rmp_serde::decode::from_read(reader)?) +} + +/// Syncronously decode the content of a reader into an rpc message. Tries to +/// give detailed errors if something went wrong. +fn decode_buffer_other( + reader: &mut R, ) -> std::result::Result> { use crate::error::InvalidMessage::*; @@ -169,32 +291,10 @@ pub async fn encode( writer: Arc>>, msg: RpcMessage, ) -> std::result::Result<(), Box> { - let mut v: Vec = vec![]; - match msg { - RpcMessage::RpcRequest { - msgid, - method, - params, - } => { - let val = rpc_args!(0, msgid, method, params); - write_value(&mut v, &val)?; - } - RpcMessage::RpcResponse { - msgid, - error, - result, - } => { - let val = rpc_args!(1, msgid, error, result); - write_value(&mut v, &val)?; - } - RpcMessage::RpcNotification { method, params } => { - let val = rpc_args!(2, method, params); - write_value(&mut v, &val)?; - } - }; - + let mut buf: Vec = vec![]; + encode::write(&mut buf, &msg)?; let mut writer = writer.lock().await; - writer.write_all(&v).await?; + writer.write_all(&buf).await?; writer.flush().await?; Ok(()) @@ -265,13 +365,13 @@ impl IntoVal for Vec<(Value, Value)> { } } -#[cfg(all(test, feature = "use_tokio"))] +// #[cfg(all(test, feature = "use_tokio"))] mod test { use super::*; use futures::{io::BufWriter, lock::Mutex}; use std::{io::Cursor, sync::Arc}; - use tokio; + // use tokio; #[tokio::test] async fn request_test() { @@ -334,4 +434,30 @@ mod test { let msg_dest_2 = decode_buffer(&mut cursor).unwrap(); assert_eq!(msg_2, msg_dest_2); } + + #[tokio::test] + async fn decode_test() { + let msg_1 = RpcMessage::RpcRequest { + msgid: 1, + method: "test_method".to_owned(), + params: vec![], + }; + + let buff: Vec = vec![]; + let tmp = Arc::new(Mutex::new(BufWriter::new(buff))); + + let tmp_c = tmp.clone(); + encode(tmp_c.clone(), msg_1.clone()).await.unwrap(); + + let v = &mut *tmp_c.lock().await; + let x = v.get_mut(); + let mut cursor = Cursor::new(x.as_slice()); + println!("{:?}", cursor); + let msg_dest_1 = decode_buffer(&mut cursor).unwrap(); + // println!("{:?}", cursor); + // let msg_dest_2 = decode_buffer_other(&mut cursor).unwrap(); + + assert_eq!(msg_1, msg_dest_1); + // assert_eq!(msg_1, msg_dest_2); + } }