package libcore import ( "bytes" "context" "crypto/sha256" "crypto/tls" "crypto/x509" "encoding/hex" "errors" "fmt" "io" "libcore/device" "libcore/ech" "log" "net" "net/http" "net/url" "os" "strconv" "sync" "sync/atomic" "time" "github.com/sagernet/quic-go" "github.com/sagernet/quic-go/http3" "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/protocol/socks" "github.com/sagernet/sing/protocol/socks/socks5" ) var errFailConnectSocks5 = errors.New("fail connect socks5") type HTTPClient interface { RestrictedTLS() ModernTLS() PinnedTLS12() PinnedSHA256(sumHex string) TrySocks5(port int32) TryH3Direct() KeepAlive() NewRequest() HTTPRequest Close() } type HTTPRequest interface { SetURL(link string) error SetMethod(method string) SetHeader(key string, value string) SetContent(content []byte) SetContentString(content string) SetUserAgent(userAgent string) AllowInsecure() Execute() (HTTPResponse, error) } type HTTPResponse interface { GetHeader(string) *StringBox GetContent() ([]byte, error) GetContentString() (*StringBox, error) WriteTo(path string) error } var ( _ HTTPClient = (*httpClient)(nil) _ HTTPRequest = (*httpRequest)(nil) _ HTTPResponse = (*httpResponse)(nil) ) type httpClient struct { tls tls.Config h1h2Transport http.Transport h1h2Client http.Client trySocks5 bool tryH3Direct bool } func NewHttpClient() HTTPClient { client := new(httpClient) client.h1h2Client.Transport = &client.h1h2Transport client.h1h2Transport.TLSClientConfig = &client.tls client.h1h2Transport.DisableKeepAlives = true return client } func (c *httpClient) ModernTLS() { c.tls.MinVersion = tls.VersionTLS12 // c.tls.CipherSuites = nekoutils.Map(tls.CipherSuites(), func(it *tls.CipherSuite) uint16 { return it.ID }) } func (c *httpClient) RestrictedTLS() { c.tls.MinVersion = tls.VersionTLS13 // c.tls.CipherSuites = nekoutils.Map(nekoutils.Filter(tls.CipherSuites(), func(it *tls.CipherSuite) bool { // return nekoutils.Contains(it.SupportedVersions, uint16(tls.VersionTLS13)) // }), func(it *tls.CipherSuite) uint16 { // return it.ID // }) } func (c *httpClient) PinnedTLS12() { c.tls.MinVersion = tls.VersionTLS12 c.tls.MaxVersion = tls.VersionTLS12 } func (c *httpClient) PinnedSHA256(sumHex string) { c.tls.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { for _, rawCert := range rawCerts { certSum := sha256.Sum256(rawCert) if sumHex == hex.EncodeToString(certSum[:]) { return nil } } return errors.New("pinned sha256 sum mismatch") } } func (c *httpClient) TrySocks5(port int32) { dialer := new(net.Dialer) c.h1h2Transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { for { socksConn, err := dialer.DialContext(ctx, "tcp", "127.0.0.1:"+strconv.Itoa(int(port))) if err != nil { if c.tryH3Direct { return nil, errFailConnectSocks5 } break } _, err = socks.ClientHandshake5(socksConn, socks5.CommandConnect, metadata.ParseSocksaddr(addr), "", "") if err != nil { if c.tryH3Direct { return nil, errFailConnectSocks5 } break } return socksConn, err } return dialer.DialContext(ctx, network, addr) } c.trySocks5 = true } func (c *httpClient) TryH3Direct() { c.tryH3Direct = true } func (c *httpClient) KeepAlive() { c.h1h2Transport.ForceAttemptHTTP2 = true c.h1h2Transport.DisableKeepAlives = false } func (c *httpClient) NewRequest() HTTPRequest { req := &httpRequest{httpClient: c} req.request = http.Request{ Method: "GET", Header: http.Header{}, } return req } func (c *httpClient) Close() { c.h1h2Transport.CloseIdleConnections() } type httpRequest struct { *httpClient request http.Request } func (r *httpRequest) AllowInsecure() { r.tls.InsecureSkipVerify = true } func (r *httpRequest) SetURL(link string) (err error) { r.request.URL, err = url.Parse(link) if err != nil { return } if r.request.URL.User != nil { user := r.request.URL.User.Username() password, _ := r.request.URL.User.Password() r.request.SetBasicAuth(user, password) } return } func (r *httpRequest) SetMethod(method string) { r.request.Method = method } func (r *httpRequest) SetHeader(key string, value string) { r.request.Header.Set(key, value) } func (r *httpRequest) SetUserAgent(userAgent string) { r.request.Header.Set("User-Agent", userAgent) } func (r *httpRequest) SetContent(content []byte) { buffer := bytes.Buffer{} buffer.Write(content) r.request.Body = io.NopCloser(bytes.NewReader(buffer.Bytes())) r.request.ContentLength = int64(len(content)) } func (r *httpRequest) SetContentString(content string) { r.SetContent([]byte(content)) } func (r *httpRequest) Execute() (HTTPResponse, error) { defer device.DeferPanicToError("http execute", func(err error) { log.Println(err) }) // full direct if r.tryH3Direct && !r.trySocks5 { return r.doH3Direct() } response, err := r.h1h2Client.Do(&r.request) if err != nil { // trySocks5 && tryH3Direct if r.tryH3Direct && errors.Is(err, errFailConnectSocks5) { return r.doH3Direct() } return nil, err } httpResp := &httpResponse{Response: response} if response.StatusCode != http.StatusOK { return nil, errors.New(httpResp.errorString()) } return httpResp, nil } type requestFunc func() (response *http.Response, err error) func (r *httpRequest) doH3Direct() (HTTPResponse, error) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() successCh := make(chan *http.Response, 1) var finalErr error var failedCount atomic.Uint32 var successCount atomic.Uint32 var mu sync.Mutex funcs := []requestFunc{ // 普通,不再重试 socks5 func() (response *http.Response, err error) { request := r.request.Clone(context.Background()) h1h2Client := &http.Client{ Transport: &http.Transport{ DisableKeepAlives: true, }, } return h1h2Client.Do(request) }, // ECH HTTPS func() (response *http.Response, err error) { request := r.request.Clone(context.Background()) echClient := &http.Client{ Transport: &http.Transport{ DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { var d net.Dialer c, err := d.DialContext(ctx, network, addr) if err != nil { return c, err } domain := addr if host, _, _ := net.SplitHostPort(addr); host != "" { domain = host } echTls := ech.NewECHClientConfig(domain, &r.tls, gLocalDNSTransport) return echTls.ClientHandshake(ctx, c) }, DisableKeepAlives: true, }, } return echClient.Do(request) }, // H3 HTTPS func() (response *http.Response, err error) { request := r.request.Clone(context.Background()) h3Client := &http.Client{ Transport: &http3.Transport{ TLSClientConfig: r.tls.Clone(), QUICConfig: &quic.Config{ MaxIdleTimeout: time.Second, }, }, } return h3Client.Do(request) }, } if r.request.URL.Scheme == "http" { funcs = funcs[:1] } for i, f := range funcs { go func(f requestFunc) { defer device.DeferPanicToError("http", func(err error) { log.Println(err) }) defer func() { if successCount.Load() == 0 { if failedCount.Add(1) >= uint32(len(funcs)) { // 全部失败了 cancel() } } }() var t string switch i { case 0: t = "h1h2" case 1: t = "ech" case 2: t = "h3" } // 执行HTTP请求 rsp, err := f() if rsp == nil || err != nil { mu.Lock() finalErr = errors.Join(finalErr, fmt.Errorf("%s: %w", t, err)) mu.Unlock() if rsp != nil && rsp.Body != nil { rsp.Body.Close() } return } // 处理 HTTP 状态码 if rsp.StatusCode != http.StatusOK { hr := &httpResponse{Response: rsp} err = fmt.Errorf("%s: %s", t, hr.errorString()) mu.Lock() finalErr = errors.Join(finalErr, err) mu.Unlock() return } select { case successCh <- rsp: // 第一个成功的请求,不要关闭 body successCount.Add(1) default: rsp.Body.Close() } }(f) } select { case result := <-successCh: return &httpResponse{Response: result}, nil case <-ctx.Done(): return nil, finalErr } } type httpResponse struct { *http.Response getContentOnce sync.Once content []byte contentError error } func (h *httpResponse) errorString() string { content, err := h.getContentString() if err != nil { return fmt.Sprint("HTTP ", h.Status) } if len(content) > 100 { content = content[:100] + " ..." } return fmt.Sprint("HTTP ", h.Status, ": ", content) } func (h *httpResponse) GetHeader(key string) *StringBox { return wrapString(h.Header.Get(key)) } func (h *httpResponse) GetContent() ([]byte, error) { h.getContentOnce.Do(func() { defer h.Body.Close() h.content, h.contentError = io.ReadAll(h.Body) }) return h.content, h.contentError } func (h *httpResponse) GetContentString() (*StringBox, error) { content, err := h.getContentString() if err != nil { return nil, err } return wrapString(content), nil } func (h *httpResponse) getContentString() (string, error) { content, err := h.GetContent() if err != nil { return "", err } return string(content), nil } func (h *httpResponse) WriteTo(path string) error { defer h.Body.Close() file, err := os.Create(path) if err != nil { return err } defer file.Close() _, err = io.Copy(file, h.Body) return err }