golang http中的context

源码分析

对 HTTP 服务器和客户端来说,超时处理是最容易犯错的问题之一。因为在网络连接到请求处理的多个阶段,都可能有相对应的超时时间。以 HTTP 请求为例,http.Client 有一个参数 Timeout 用于指定当前请求的总超时时间,它包括从连接、发送请求、到处理服务器响应的时间的总和。

1
2
3
4
client := &http.Client{
Timeout: 1 * time.Second,
}
resp, err := client.Get("https://baidu.com")

标准库 client.Do 方法内部会将超时时间换算为Deadline并传递到下一层。setRequestCancel 函数内部则会调用 context.WithDeadline ,派生出一个子 Context 并赋值给 req 中的 Context

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
// net/http/client.go

func (c *Client) Do(req *Request) (*Response, error) {
return c.do(req)
}

func (c *Client) do(req *Request) (retres *Response, reterr error) {
...
deadline = c.deadline()
if resp, didTimeout, err = c.send(req, deadline); err != nil {
...
}
}

func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, didTimeout func() bool, err error) {
...
stopTimer, didTimeout := setRequestCancel(req, rt, deadline)
}

func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) {
// 选择一个较短的deadline
if oldCtx := req.Context(); timeBeforeContextDeadline(deadline, oldCtx) {
req.ctx, cancelCtx = context.WithDeadline(oldCtx, deadline)
}
...
}

在获取连接时,如果从闲置连接中找不到连接,则需要陷入 select 中去等待。如果连接时间超时,req.Context().Done() 通道会收到信号立即退出。在实际发送数据的 transport.roundTrip 函数中,也有很多通过在 select 语句中监听 Context 退出信号来实现超时控制的例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
// net/http/transport.go 

func (t *Transport) roundTrip(req *Request) (*Response, error) {
for {
select {
case <-ctx.Done():
req.closeBody()
return nil, ctx.Err()
default:
}

...
pconn, err := t.getConn(treq, cm)
}
}

func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persistConn, err error){
...
// 等待获取连接动作完成或取消
select {
case <-w.ready:
if w.pc != nil && w.pc.alt == nil && trace != nil && trace.GotConn != nil {
// 执行 nettrace.Trace GotConn hook
trace.GotConn(httptrace.GotConnInfo{Conn: w.pc.conn, Reused: w.pc.isReused()})
}
if w.err != nil {
select {
case <-req.Cancel:
return nil, errRequestCanceledConn
case <-req.Context().Done():
return nil, req.Context().Err()
case err := <-cancelc:
if err == errRequestCanceled {
err = errRequestCanceledConn
}
return nil, err
default:
}
}
return w.pc, w.err
case <-req.Cancel:
return nil, errRequestCanceledConn
case <-req.Context().Done():
return nil, req.Context().Err()
case err := <-cancelc:
if err == errRequestCanceled {
err = errRequestCanceledConn
}
return nil, err
}
}

获取 TCP 连接需要调用 sysDialer.dialSerial 方法,dialSerial 的功能是逐个遍历addrList中的addr ,如果与任一addr能成功建立连接则立即返回。代码如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
// net/dial.go

func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) {
// dns解析,获取addr list
addrs, err := d.resolver().resolveAddrList(resolveCtx, "dial", network, address, d.LocalAddr)
if len(fallbacks) > 0 {
c, err = sd.dialParallel(ctx, primaries, fallbacks)
} else {
c, err = sd.dialSerial(ctx, primaries)
}
}

func (sd *sysDialer) dialSerial(ctx context.Context, ras addrList) (Conn, error) {
var firstErr error // The error from the first address is most relevant.

for i, ra := range ras {
select {
case <-ctx.Done():
return nil, &OpError{Op: "dial", Net: sd.network, Source: sd.LocalAddr, Addr: ra, Err: mapErr(ctx.Err())}
default:
}

dialCtx := ctx
if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
// 计算连接的超时时间
partialDeadline, err := partialDeadline(time.Now(), deadline, len(ras)-i)
if err != nil {
// io timeout.
if firstErr == nil {
firstErr = &OpError{Op: "dial", Net: sd.network, Source: sd.LocalAddr, Addr: ra, Err: err}
}
break
}
if partialDeadline.Before(deadline) {
var cancel context.CancelFunc
dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
defer cancel()
}
}

c, err := sd.dialSingle(dialCtx, ra)
if err == nil {
return c, nil
}
if firstErr == nil {
firstErr = err
}
}

if firstErr == nil {
firstErr = &OpError{Op: "dial", Net: sd.network, Source: nil, Addr: nil, Err: errMissingAddress}
}
return nil, firstErr
}

dialSerial 函数中有几个典型的 Context 用法

  • 第16行代码遍历addr list时,判断 Context 是否已经退出,如果没有退出,会进入到 select 的 default 分支。如果通道已经退出了,则函数直接return

  • 第14行代码通过 ctx.Deadline() 判断是否传递进来的 Context 有超时时间。如果有超时时间,我们需要协调好后面每一个连接的超时时间。partialDeadline 会计算每一个连接的新的到期时间,如果该到期时间小于总到期时间,将派生出一个子 Context 传递给 dialSingle 函数,用于控制该连接的超时

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
    if deadline.IsZero() {
    return deadline, nil
    }
    timeRemaining := deadline.Sub(now)
    if timeRemaining <= 0 {
    return time.Time{}, errTimeout
    }
    // 暂时为每个剩余的addr分配相等的时间
    timeout := timeRemaining / time.Duration(addrsRemaining)
    // 如果每个addr获得的时间太短(小于2s),则使用所有remaining time
    const saneMinimum = 2 * time.Second
    if timeout < saneMinimum {
    if timeRemaining < saneMinimum {
    timeout = timeRemaining
    } else {
    timeout = saneMinimum
    }
    }
    return now.Add(timeout), nil
    }
  • dialSingle 函数中调用了 ctx.Value,用来获取一个特殊的接口 nettrace.Tracenettrace.Trace 用于对网络包中一些特殊的地方进行 hook。dialSingle 函数作为网络连接的起点,如果上下文中注入了 trace.ConnectStart 函数,则会在 dialSingle 函数之前调用 trace.ConnectStart 函数,如果上下文中注入了 trace.ConnectDone 函数,则会在执行 dialSingle 函数之后调用 trace.ConnectDone 函数

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    // net/dial.go

    func (sd *sysDialer) dialSingle(ctx context.Context, ra Addr) (c Conn, err error) {
    trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace)
    if trace != nil {
    raStr := ra.String()
    if trace.ConnectStart != nil {
    trace.ConnectStart(sd.network, raStr)
    }
    if trace.ConnectDone != nil {
    defer func() { trace.ConnectDone(sd.network, raStr, err) }()
    }
    }
    la := sd.LocalAddr
    switch ra := ra.(type) {
    case *TCPAddr:
    la, _ := la.(*TCPAddr)
    // tcp连接
    c, err = sd.dialTCP(ctx, la, ra)
    ...
    }

一个例子

在这个例子中,我在本地起了一个服务监听5001端口,并使用自定义的dns resolver,对于www.example.io这个domain,该dns resolver会返回两个地址:127.0.0.2 & 127.0.0.1,只有后一个addr可以建立连接。完整代码可以从这个仓库获取

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
package main

import (
"context"
"io/ioutil"
"log"
"net"
"net/http"
"net/http/httptrace"
"time"

"go.uber.org/zap"
)

func main() {
var (
dnsResolverIP = "127.0.0.1:1053"
dnsResolverProto = "udp"
dnsResolverTimeoutMs = 500
// 上文提到,如果addr获得的时间小于2s,则会使用所有的remaining time,
// 如果这里改成一个小于2000的值,addr2将没有机会connect
dialTimeoutMs = 3000
)

l, err := zap.NewDevelopment()
if err != nil {
panic(err)
}
logger := l.Sugar()

dialer := &net.Dialer{
Timeout: time.Duration(dialTimeoutMs) * time.Millisecond,
Resolver: &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
dialer := net.Dialer{
Timeout: time.Duration(dnsResolverTimeoutMs) * time.Millisecond,
}
return dialer.DialContext(ctx, dnsResolverProto, dnsResolverIP)
},
},
}

dialContext := func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, network, addr)
}

tr := http.DefaultTransport
tr.(*http.Transport).DialContext = dialContext

httpClient := &http.Client{
Transport: tr,
}

ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()

req, err := http.NewRequestWithContext(ctx, "GET", "http://www.example.io:5001", nil)
if err != nil {
log.Fatalln(err)
}

trace := &httptrace.ClientTrace{
GotConn: func(connInfo httptrace.GotConnInfo) {
logger.Debugf("[trace] Got Conn: %+v", connInfo)
},
ConnectDone: func(network, addr string, err error) {
logger.Debugf("[trace] Conn done, addr, err : %+v, %v", addr, err)
},
DNSDone: func(dnsInfo httptrace.DNSDoneInfo) {
logger.Debugf("[trace] DNS Info: %+v", dnsInfo)
},
}
req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))

resp, err := httpClient.Do(req)
if err != nil {
logger.Fatal(err)
}
defer resp.Body.Close()

body, err := ioutil.ReadAll(resp.Body)
if err != nil {
logger.Fatal(err)
}

logger.Debugw("get response body", "data", string(body))
}

输出

1
2
3
4
5
6
$ go run main.go
2023-02-26T16:44:01.226+0800 DEBUG gohttp/main.go:76 [trace] DNS Info: {Addrs:[{IP:127.0.0.2 Zone:} {IP:127.0.0.1 Zone:}] Err:<nil> Coalesced:false}
2023-02-26T16:44:03.230+0800 DEBUG gohttp/main.go:73 [trace] Conn done, addr, err : 127.0.0.2:5001, dial tcp 127.0.0.2:5001: i/o timeout
2023-02-26T16:44:03.230+0800 DEBUG gohttp/main.go:73 [trace] Conn done, addr, err : 127.0.0.1:5001, <nil>
2023-02-26T16:44:03.230+0800 DEBUG gohttp/main.go:70 [trace] Got Conn: {Conn:0xc000010010 Reused:false WasIdle:false IdleTime:0s}
2023-02-26T16:44:03.235+0800 DEBUG gohttp/main.go:93 get response body {"data": "hello"}

可以看到,第一个addr过了2s还未连接成功,返回io timeout错误,第二个addr连接成功,Transport获取到conn,最终完成了一个http请求的发送