转载自 https://github.com/tower-rs/tower/blob/master/guides/building-a-middleware-from-scratch.md

《发明Service trait》一文中,我们深入探讨了Service的设计动机及其架构原理。虽然我们也动手实现过几个简易中间件,但当时采取了一些取巧方案。本指南将完整构建Tower现有的Timeout中间件,全程不采用任何捷径。

编写健壮的中间件需要运用比常规更底层的异步Rust技术。本指南旨在揭开这些核心概念与模式的神秘面纱,助你掌握中间件开发技能,甚至为Tower生态贡献代码!

准备工作

我们将构建的中间件是tower::timeout::Timeout。该组件会限定内部Service响应future的最大执行时长。若未能在指定时间内生成响应,则返回错误。这使得客户端可以重试请求或向用户报错,而非无限等待。

首先定义包含被包装服务和超时时长的Timeout结构体:

use std::time::Duration;

struct Timeout<S> {
    inner: S,
    timeout: Duration,
}

根据《发明Service trait》的指导,服务实现Clone trait至关重要——这允许将Service::call接收的&mut self转换为可移入响应future的独立所有权。因此我们为结构体添加#[derive(Clone)],同时一并实现Debug

#[derive(Debug, Clone)]
struct Timeout<S> {
    inner: S,
    timeout: Duration,
}

接着实现构造函数:

impl<S> Timeout<S> {
    pub fn new(inner: S, timeout: Duration) -> Self {
        Timeout { inner, timeout }
    }
}

注意我们遵循Rust API指南的建议,即便预期S会实现Service trait,此处也未添加任何约束。

现在进入关键环节:如何为Timeout<S>实现Service?先实现一个简单透传版本:

use tower::Service;
use std::task::{Context, Poll};

impl<S, Request> Service<Request> for Timeout<S>
where
    S: Service<Request>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // 中间件不关注背压,只要内部服务就绪即可
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: Request) -> Self::Future {
        self.inner.call(request)
    }
}

在熟练编写中间件前,先搭建这样的代码骨架能显著降低实现难度。

要实现真正的超时控制,核心在于检测self.inner.call(request)返回的future执行是否超过self.timeout,若超时则终止并返回错误。

我们将采用以下方案:调用tokio::time::sleep获取超时future,然后通过select等待最先完成的future。虽然也可使用tokio::time::timeout,但sleep同样适用。

创建两个future的代码如下:

use tokio::time::sleep;

fn call(&mut self, request: Request) -> Self::Future {
    let response_future = self.inner.call(request);

    // 此变量类型为`tokio::time::Sleep`
    // 由于`self.timeout`实现`Copy` trait,无需显式克隆
    let sleep = tokio::time::sleep(self.timeout);

    // 此处应返回什么?
}

一种可能的返回类型是Pin<Box<dyn Future<...>>>。但为最小化Timeout的开销,我们希望能避免Box分配。设想一个包含数十层嵌套Service的调用栈,若每层都为请求分配新Box,将产生大量内存分配,进而影响性能1

响应future实现

为避免Box分配,我们选择自定义Future实现。首先创建名为ResponseFuture的结构体,需泛型化内部服务的响应future类型。这类似于用服务包装其他服务,但此处是用future包装其他future。

use tokio::time::Sleep;

pub struct ResponseFuture<F> {
    response_future: F,
    sleep: Sleep,
}

其中F对应self.inner.call(request)的类型。更新Service实现:

impl<S, Request> Service<Request> for Timeout<S>
where
    S: Service<Request>,
{
    type Response = S::Response;
    type Error = S::Error;

    // 使用新的`ResponseFuture`类型
    type Future = ResponseFuture<S::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: Request) -> Self::Future {
        let response_future = self.inner.call(request);
        let sleep = tokio::time::sleep(self.timeout);

        // 通过包装内部服务的future创建响应future
        ResponseFuture {
            response_future,
            sleep,
        }
    }
}

这里的关键在于Rust future具有_惰性_特性,即除非被await或poll,否则不会执行任何操作。因此self.inner.call(request)会立即返回而不会实际处理请求。

接下来为ResponseFuture实现Future

use std::{pin::Pin, future::Future};

impl<F, Response, Error> Future for ResponseFuture<F>
where
    F: Future<Output = Result<Response, Error>>,
{
    type Output = Result<Response, Error>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // 此处如何实现?
    }
}

理想情况下我们期望实现以下逻辑:

  1. 首先poll self.response_future,若就绪则返回其响应或错误
  2. 否则poll self.sleep,若就绪则返回超时错误
  3. 若两者均未就绪则返回Poll::Pending

初步尝试可能如下:

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
    match self.response_future.poll(cx) {
        Poll::Ready(result) => return Poll::Ready(result),
        Poll::Pending => {}
    }

    todo!()
}

但这会导致如下错误:

error[E0599]: 在当前作用域中未找到类型参数`F`的`poll`方法
  --> src/lib.rs:56:29
   |
56 |         match self.response_future.poll(cx) {
   |                             ^^^^ 方法未找到
   |
   = 帮助: 只有类型参数受trait约束时才能使用trait中的项
帮助: 以下trait定义了`poll`项,可能需要为类型参数`F`添加约束:
   |
49 | impl<F: Future, Response, Error> Future for ResponseFuture<F>
   |      ^^^^^^^^^

虽然Rust的错误提示建议添加F: Future约束,但实际上我们已通过where F: Future<Output = Result<Response, E>>实现约束。真正的问题与[Pin]相关。

关于固定(Pinning)的完整讨论超出本指南范围。若对Pin不熟悉,推荐阅读Jon Gjengset的《Rust中Pinning的为什么、是什么和怎么做》

Rust试图告诉我们的是:需要Pin<&mut F>才能调用poll。当selfPin<&mut Self>时,通过self.response_future访问F无法正常工作。

我们需要"pin投影"——即从Pin<&mut Struct>转换到Pin<&mut Field>。通常这需要编写unsafe代码,但优秀的[pin-project] crate能安全处理这些底层细节。

使用pin-project时,我们用#[pin_project]标注结构体,并为需要固定访问的字段添加#[pin]

use pin_project::pin_project;

#[pin_project]
pub struct ResponseFuture<F> {
    #[pin]
    response_future: F,
    #[pin]
    sleep: Sleep,
}

impl<F, Response, Error> Future for ResponseFuture<F>
where
    F: Future<Output = Result<Response, Error>>,
{
    type Output = Result<Response, Error>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // 调用`#[pin_project]`生成的`project`魔法方法
        let this = self.project();

        // `project`返回`__ResponseFutureProjection`类型(可忽略具体类型)
        // 其字段与`ResponseFuture`匹配,并为标注`#[pin]`的字段维护pin

        // `this.response_future`现在是`Pin<&mut F>`
        let response_future: Pin<&mut F> = this.response_future;

        // `this.sleep`是`Pin<&mut Sleep>`
        let sleep: Pin<&mut Sleep> = this.sleep;

        // 若有未标注`#[pin]`的字段,则获得普通`&mut`引用(无`Pin`)

        // ...
    }
}

Rust中的固定机制虽然复杂难懂,但借助pin-project我们可以规避大部分复杂性。关键在于,即使不完全理解固定机制,也能编写Tower中间件!所以如果你对PinUnpin还存有疑惑,请放心使用pin-project!

注意在前述代码中,我们获得了Pin<&mut F>Pin<&mut Sleep>,这正是调用poll所需的:

impl<F, Response, Error> Future for ResponseFuture<F>
where
    F: Future<Output = Result<Response, Error>>,
{
    type Output = Result<Response, Error>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();

        // 首先检查响应future是否就绪
        match this.response_future.poll(cx) {
            Poll::Ready(result) => {
                // 内部服务已准备好响应或失败
                return Poll::Ready(result);
            }
            Poll::Pending => {
                // 尚未就绪...
            }
        }

        // 然后检查sleep是否就绪。若就绪则表示响应超时
        match this.sleep.poll(cx) {
            Poll::Ready(()) => {
                // 超时触发,但应返回什么错误?!
                todo!()
            }
            Poll::Pending => {
                // 仍有剩余时间...
            }
        }

        // 若两者均未就绪,则返回Pending
        Poll::Pending
    }
}

现在唯一剩下的问题是:当sleep先完成时,应该返回什么错误?

错误类型设计

当前我们承诺返回的泛型Error类型与内部服务的错误类型相同。但对该类型我们一无所知——它完全不透明且无法构造其值。

我们有三条路径可选:

  1. 返回装箱的错误特征对象,如Box<dyn std::error::Error + Send + Sync>
  2. 返回包含服务错误和超时错误的枚举类型
  3. 定义TimeoutError结构体,并要求泛型错误类型可通过TimeoutError: Into<Error>构造

虽然选项3看似最灵活,但要求使用自定义错误类型的用户手动实现From<TimeoutError> for MyError。当使用多个自带错误类型的中间件时,这种操作会变得繁琐。

选项2需要定义如下枚举:

enum TimeoutError<Error> {
    // 超时触发的变体
    Timeout(InnerTimeoutError),
    // 内部服务产生错误的变体
    Service(Error),
}

虽然表面上看能保留完整类型信息且可通过match精确处理错误,但存在三个问题:

  1. 实践中常会嵌套大量中间件,导致最终错误枚举异常庞大。类似BufferError<RateLimitError<TimeoutError<MyError>>>的类型很常见,对此类类型进行模式匹配(例如判断错误是否可重试)将非常繁琐
  2. 调整中间件顺序会改变最终错误类型,需要同步更新模式匹配
  3. 最终错误类型可能占用大量栈空间

因此我们选择选项1:将内部服务错误转换为装箱特征对象Box<dyn std::error::Error + Send + Sync>。这样可将多种错误类型统一处理,具有以下优势:

  1. 错误处理更健壮,调整中间件顺序不会改变最终错误类型
  2. 错误类型具有固定大小,不受中间件数量影响
  3. 提取错误时无需大型match,可使用error.downcast_ref::<Timeout>()

但也存在以下缺点:

  1. 使用动态转换后,编译器无法保证检查所有可能的错误类型
  2. 创建错误需要进行内存分配。实践中错误应属罕见,故影响有限

选择哪种方案取决于个人偏好。Tower最终采用的是装箱特征对象方案,原始讨论参见这里

对于Timeout中间件,我们需要创建实现std::error::Error的结构体,以便转换为Box<dyn std::error::Error + Send + Sync>。同时要求内部服务的错误类型实现Into<Box<dyn std::error::Error + Send + Sync>>。幸运的是大多数错误类型自动满足该条件,用户无需编写额外代码。根据标准库建议,我们使用Into而非From作为trait约束。

错误类型实现如下:

use std::fmt;

#[derive(Debug, Default)]
pub struct TimeoutError(());

impl fmt::Display for TimeoutError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad("request timed out")
    }
}

impl std::error::Error for TimeoutError {}

我们向 TimeoutError 添加了一个私有字段,这样Tower外部的用户就无法构造自己的TimeoutError 。他们只能通过我们的中间件获取。

Box<dyn std::error::Error + Send + Sync> 这个表达确实有些冗长,因此我们为它定义一个类型别名:

// 该类型在`tower`中定义为`tower::BoxError`
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;

现在future实现更新为:

impl<F, Response, Error> Future for ResponseFuture<F>
where
    F: Future<Output = Result<Response, Error>>,
    // 要求内部服务错误可转换为`BoxError`
    Error: Into<BoxError>,
{
    type Output = Result<
        Response,
        // `ResponseFuture`的错误类型现在是`BoxError`
        BoxError,
    >;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();

        match this.response_future.poll(cx) {
            Poll::Ready(result) => {
                // 使用`map_err`转换错误类型
                let result = result.map_err(Into::into);
                return Poll::Ready(result);
            }
            Poll::Pending => {}
        }

        match this.sleep.poll(cx) {
            Poll::Ready(()) => {
                // 构造并返回超时错误
                let error = Box::new(TimeoutError(()));
                return Poll::Ready(Err(error));
            }
            Poll::Pending => {}
        }

        Poll::Pending
    }
}

最后需要更新Service实现,同样使用BoxError

impl<S, Request> Service<Request> for Timeout<S>
where
    S: Service<Request>,
    // 与`ResponseFuture`的future实现相同约束
    S::Error: Into<BoxError>,
{
    type Response = S::Response;
    // `Timeout`的错误类型现在是`BoxError`
    type Error = BoxError;
    type Future = ResponseFuture<S::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // 此处也需要转换错误类型
        self.inner.poll_ready(cx).map_err(Into::into)
    }

    fn call(&mut self, request: Request) -> Self::Future {
        let response_future = self.inner.call(request);
        let sleep = tokio::time::sleep(self.timeout);

        ResponseFuture {
            response_future,
            sleep,
        }
    }
}

最终成果

至此我们已成功实现了与Tower现有版本完全一致的Timeout中间件!

完整实现如下:

use pin_project::pin_project;
use std::time::Duration;
use std::{
    fmt,
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};
use tokio::time::Sleep;
use tower::Service;

#[derive(Debug, Clone)]
struct Timeout<S> {
    inner: S,
    timeout: Duration,
}

impl<S> Timeout<S> {
    fn new(inner: S, timeout: Duration) -> Self {
        Timeout { inner, timeout }
    }
}

impl<S, Request> Service<Request> for Timeout<S>
where
    S: Service<Request>,
    S::Error: Into<BoxError>,
{
    type Response = S::Response;
    type Error = BoxError;
    type Future = ResponseFuture<S::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx).map_err(Into::into)
    }

    fn call(&mut self, request: Request) -> Self::Future {
        let response_future = self.inner.call(request);
        let sleep = tokio::time::sleep(self.timeout);

        ResponseFuture {
            response_future,
            sleep,
        }
    }
}

#[pin_project]
struct ResponseFuture<F> {
    #[pin]
    response_future: F,
    #[pin]
    sleep: Sleep,
}

impl<F, Response, Error> Future for ResponseFuture<F>
where
    F: Future<Output = Result<Response, Error>>,
    Error: Into<BoxError>,
{
    type Output = Result<Response, BoxError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();

        match this.response_future.poll(cx) {
            Poll::Ready(result) => {
                let result = result.map_err(Into::into);
                return Poll::Ready(result);
            }
            Poll::Pending => {}
        }

        match this.sleep.poll(cx) {
            Poll::Ready(()) => {
                let error = Box::new(TimeoutError(()));
                return Poll::Ready(Err(error));
            }
            Poll::Pending => {}
        }

        Poll::Pending
    }
}

#[derive(Debug, Default)]
struct TimeoutError(());

impl fmt::Display for TimeoutError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad("request timed out")
    }
}

impl std::error::Error for TimeoutError {}

type BoxError = Box<dyn std::error::Error + Send + Sync>;

Tower中的完整实现参见这里

这种为包装其他Service的类型实现Service trait,并返回包装其他FutureFuture的模式,是大多数Tower中间件的工作方式。

其他典型示例包括:

掌握这些知识后,你应该已具备编写生产级中间件的能力。以下练习可供实践:

如有疑问,欢迎加入Tokio Discord服务器#tower频道交流。


  1. Rust编译器团队计划增加"类型别名中的impl Trait"功能,将允许从call返回impl Future,但目前尚未实现。