基于serde实现一个RESP序列化库

此前的文章介绍了RESP协议的基本内容,现在我们基于serde库在rust中编写序列化RESP的库。

与Redis建立连接

为了方便后续测试,我们首先连接到Redis服务端,并手动输入一些命令。

打开Redis服务器,首先使用netcat测试连接,并手动输入指令,确认连接。

1
2
netcat -v localhost 6379
PING
1
+PONG

接下来我们使用Rust中的std::net编写连接到Redis的程序。

1
2
3
4
5
6
7
8
9
10
11
12
fn main() -> Result<(), io::Error>{
let addr = "127.0.0.1:6379";
let mut stream = TcpStream::connect(addr)?;
println!("Connect to {addr}");
let msg = b"PING\r\n";
stream.write(msg)?;
let mut data = [0u8; 5];
stream.read(&mut data)?;
let res = from_utf8(&data).unwrap();
println!("{res}");
Ok(())
}
1
2
Connect to 127.0.0.1:6379
+PONG

使用Serde编写自己的序列化库

在Serde的相关文档中介绍了如何使用Serde进行二次开发。

错误类型

首先单独定义Error枚举类型,定义在序列化,反序列化中可能遇到的错误。根据规范,需要在其中添加Error::Message(String)枚举以向serde返回错误消息。

这里我定义了一些可能遇到的错误:

  • Eof:读取数据时读取到Eof
  • Syntax:广义上的语法错误,当错误匹配不到具体类型时,记为该错误
  • ExpectedSign:首字节没有读取到+-等符号。
  • ExpectedBulkString$<SIZE>\r\n消息之后没有跟上字符串内容。
  • ExpectedArrayElement*<SIZE>\r\n消息之后遗失了数组元素。
  • UnexpectedCR:当读取简单字符串时其中包含了CR
  • UnexpectedType:类型不匹配。
  • IntegerOverflow:数字大小超过64位整数范围。
  • BulkStringOverflow:大容量字符串大小超过512MB。
  • WrongSizeOfBulkString$<SIZE>\r\n后的大容量字符串长度与<SIZE>不匹配。

RESP类型

我们定义RESP类型作为使用该库的主体,它是一个包含RESP五种类型的枚举类

1
2
3
4
5
6
7
pub enum RESPType {
SimpleString(String),
Integer(i64),
Error(String),
BulkString(Option<Vec<u8>>),
Array(Option<Vec<RESPType>>)
}

BulkStringArray枚举中都添加了Some结构,这是由于RESP允许空值,必须要把空值和""以及[]区分开来。我们希望完成编写后,可以实现这样的功能:

1
2
3
4
5
6
7
let arr = vec![
RESPType::Integer(32),
RESPType::SimpleString("foobar".to_owned()),
RESPType::BulkString(Some("really bulk".as_bytes().to_vec())),
];
let resp_arr = RESPType::Array(Some(arr));
assert_eq!(to_string(&resp_arr)?, "*3\r\n:32\r\n+foobar\r\n$11\r\nreally bulk\r\n");

这样的语法看起来有些冗长,我们可以考虑在以后的更新中使用宏编写简洁的代码。在目前我们以实现功能为主。

序列化

我们需要做两件事:编写序列化器,为RESP类型实现序列化。

编写Serializer

首先定义Serializer并为其实现Serializetrait

1
2
3
4
pub struct Serializer<W: Write> {
buffer: itoa::Buffer,
writer: W
}

考虑到这个协议大多用于通信,我们围绕W: impl Write实现我们的功能。注意到序列化器中有一个itoa::Buffer,它是itoa库中的内容,用于实现快速的数字转字符串。

Serde为自行编写序列化器提供了相当便利的方法,我们只需要实现SerializeTrait就行。这个Trait中包含了各种方法,每个方法序列化一种类型的值。

Serde库将所有Rust中的值划分为一些数据模型,具体内容可以从官方文档中查看到。RESP并不是一种通用数据格式,它只能用于一些特定值的序列化,因此我们不需要实现诸如tuple struct等数据类型的序列化,只需要针对每种RESP类型实现一个方法即可。调用其他方法直接返回Error::UnexpectedType,以期望使用库的用户不要在不支持的类型上进行序列化。

具体来说,我们实现以下方法:

  • serialize_ix:序列化整数
  • serialize_str:序列化简单字符串和错误
  • serialize_none:序列化空值
  • serialize_bytes:序列化大容量字符串
  • serialize_seq:序列化数组

值得注意的是,虽然字符串和错误类型的格式不同,但是Serde对此没有细粒度的区分,只有一种字符串序列化的方法,因此我们需要在为RESP类型实现序列化时再做区分。

对于数组类型,还需实现SerializeSeq。序列化数组分为三部:建立首部,逐个序列化元素,建立尾部。在Serialize中实现了建立首部,在SerializeSeq中实现逐个序列化和建立尾部。RESP的尾部没有任何标记,直接返回即可。对于每个元素,只需要调用各自的serialize方法。

为RESPType实现Serialize

这一部分较为简单。前四个枚举值基本只需调用serialize方法即可,对于Array,Serde有固定格式的序列化语句:

1
2
3
4
5
let mut serializer = serializer.serialize_seq(Some(arr.len()))?;
for element in arr {
serializer.serialize_element(element)?;
}
serializer.end()

注意到serialize_seq就是我们之前编写的,支持链式调用的方法,它为Array添加了首部之后返回自己。

反序列化

类似的,反序列化的目标是实现一个from_str函数,从这个函数返回RESPType。通过实现自己的DeserializerSeqAccessVisitor来实现这个函数。

我们的Deserializer定义为:

1
2
3
4
pub struct Deserializer<'de> {
input: &'de str,
offset: usize,
}

它的内部是一个生命周期为de的字符串切片与offset,表示正在处理的内容的偏移量,在异常时使用。基本的思路为:为Deserializer实现DeserializeTrait并实现单独序列化一个RESPType的方法。为了反序列化RESPArray,我们需要使用VisitorSeqAccess,后者提供了迭代RESPArray的方法,前者执行将Serialize的返回值转化为RESPType类型的过程。

具体的算法在Deserializer<'de>的方法中,依据RESP协议,我们从输入中获取数据,并解析为对应的&str i64 &[u8] 等值。随后我们使用基本方法实现Deserialize的相关方法。

Deserialize的方法依赖于Visitor,以deserialize_bytes为例:

1
2
3
4
5
6
7
8
9
fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
match self.parse_bytes()? {
Some(bytes) => visitor.visit_bytes(bytes),
None => visitor.visit_none()
}
}

数据处理过程为:Deserializer处理为基本顺序→Visitor包装为RESPType数据。

为了遍历数组,定义RESPArrayAccess类,该类包含一个Deserializerremain_cnt,用于表示剩余的元素数量。

1
2
3
4
struct RESPArrayAccess<'a, 'de: 'a> {
de: &'a mut Deserializer<'de>,
remain_cnt: usize,
}

我们实现SeqAccess的方法,并在Visitor中的visit_seq中使用它:

1
2
3
4
5
6
7
8
9
10
fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let mut array: Vec<RESPType> = vec![];
while let Some(element) = seq.next_element()? {
array.push(element);
}
Ok(RESPType::Array(array))
}

由于RESP协议是自解释的,所以我们可以为Deserializer实现deserialize_any()

1
2
3
4
5
6
7
8
9
10
11
12
13
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
match self.peek_char()? {
'+' => self.deserialize_str(visitor),
'-' => self.deserialize_string(visitor),
':' => self.deserialize_i64(visitor),
'$' => self.deserialize_bytes(visitor),
'*' => self.deserialize_seq(visitor),
_ => Err(Error::ExpectedSign(self.offset)),
}
}

最后我们封装功能到from_str中,它会使用Deserializer进行解析。

1
2
3
4
5
6
7
8
9
10
11
12
pub fn from_str<T>(s: & str) -> Result<T>
where
T: DeserializeOwned,
{
let mut de = Deserializer::from_str(s);
let t = T::deserialize(&mut de)?;
if de.input.is_empty() {
Ok(t)
} else {
Err(Error::TrailingCharacters)
}
}

注意到泛型T为实现了DeserializeOwned的类型而非Deserialize<'a>,这是因为后者序列化为可能含有借用的类型,例如我们有一段数据,想要序列化为&str,需要保证&str的生命周期和数据的生命周期一致。但是我们使用的RESPType全部都是独占数据,不需要考虑这个问题。类似地实现from_reader

最后实现的项目开源在Github上,还有许多需要改进的内容(如实现宏)。仓库地址