mirror of
https://github.com/MetaCubeX/mihomo.git
synced 2025-12-20 00:50:06 +08:00
Merge 6846af6376 into 17966b5418
This commit is contained in:
commit
6ed49bbc41
@ -14,6 +14,7 @@ import (
|
||||
tlsC "github.com/metacubex/mihomo/component/tls"
|
||||
C "github.com/metacubex/mihomo/constant"
|
||||
"github.com/metacubex/mihomo/transport/gun"
|
||||
"github.com/metacubex/mihomo/transport/splithttp"
|
||||
"github.com/metacubex/mihomo/transport/vless"
|
||||
"github.com/metacubex/mihomo/transport/vless/encryption"
|
||||
"github.com/metacubex/mihomo/transport/vmess"
|
||||
@ -61,6 +62,7 @@ type VlessOption struct {
|
||||
HTTPOpts HTTPOptions `proxy:"http-opts,omitempty"`
|
||||
HTTP2Opts HTTP2Options `proxy:"h2-opts,omitempty"`
|
||||
GrpcOpts GrpcOptions `proxy:"grpc-opts,omitempty"`
|
||||
SplitHTTPOpts SplitHTTPOptions `proxy:"splithttp-opts,omitempty"`
|
||||
WSOpts WSOptions `proxy:"ws-opts,omitempty"`
|
||||
WSHeaders map[string]string `proxy:"ws-headers,omitempty"`
|
||||
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
|
||||
@ -150,6 +152,70 @@ func (v *Vless) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.M
|
||||
}
|
||||
|
||||
c, err = vmess.StreamH2Conn(ctx, c, h2Opts)
|
||||
case "splithttp", "xhttp":
|
||||
c.Close()
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
if v.option.TLS {
|
||||
host, _, _ := net.SplitHostPort(v.addr)
|
||||
tlsConfig, err = ca.GetTLSConfig(ca.Option{
|
||||
TLSConfig: &tls.Config{
|
||||
ServerName: host,
|
||||
InsecureSkipVerify: v.option.SkipCertVerify,
|
||||
NextProtos: v.option.ALPN,
|
||||
},
|
||||
Fingerprint: v.option.Fingerprint,
|
||||
Certificate: v.option.Certificate,
|
||||
PrivateKey: v.option.PrivateKey,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if v.option.ServerName != "" {
|
||||
tlsConfig.ServerName = v.option.ServerName
|
||||
} else if host := v.option.SplitHTTPOpts.Headers["Host"]; host != "" {
|
||||
tlsConfig.ServerName = host
|
||||
}
|
||||
|
||||
// Default ALPN if empty
|
||||
if len(tlsConfig.NextProtos) == 0 {
|
||||
tlsConfig.NextProtos = []string{"h2", "http/1.1"}
|
||||
}
|
||||
}
|
||||
|
||||
conf := &splithttp.Config{
|
||||
Host: v.option.SplitHTTPOpts.Host,
|
||||
Path: v.option.SplitHTTPOpts.Path,
|
||||
Mode: v.option.SplitHTTPOpts.Mode,
|
||||
Headers: v.option.SplitHTTPOpts.Headers,
|
||||
NoGRPCHeader: v.option.SplitHTTPOpts.NoGRPCHeader,
|
||||
XPaddingBytes: v.option.SplitHTTPOpts.XPaddingBytes,
|
||||
ScMaxEachPostBytes: v.option.SplitHTTPOpts.ScMaxEachPostBytes,
|
||||
ScMinPostsIntervalMs: v.option.SplitHTTPOpts.ScMinPostsIntervalMs,
|
||||
ScMaxBufferedPosts: int32(v.option.SplitHTTPOpts.ScMaxBufferedPosts),
|
||||
ScStreamUpServerSecs: v.option.SplitHTTPOpts.ScStreamUpServerSecs,
|
||||
Xmux: v.option.SplitHTTPOpts.Xmux,
|
||||
}
|
||||
|
||||
if conf.Host == "" {
|
||||
host, _, _ := net.SplitHostPort(v.addr)
|
||||
conf.Host = host
|
||||
}
|
||||
|
||||
dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
var err error
|
||||
var cDialer C.Dialer = dialer.NewDialer(v.DialOptions()...)
|
||||
if len(v.option.DialerProxy) > 0 {
|
||||
cDialer, err = proxydialer.NewByName(v.option.DialerProxy, cDialer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return cDialer.DialContext(ctx, "tcp", v.addr)
|
||||
}
|
||||
|
||||
c, err = splithttp.Dial(ctx, dialFn, conf, tlsConfig)
|
||||
case "grpc":
|
||||
c, err = gun.StreamGunWithConn(c, v.gunTLSConfig, v.gunConfig, v.echConfig, v.realityConfig)
|
||||
default:
|
||||
|
||||
@ -17,6 +17,7 @@ import (
|
||||
C "github.com/metacubex/mihomo/constant"
|
||||
"github.com/metacubex/mihomo/ntp"
|
||||
"github.com/metacubex/mihomo/transport/gun"
|
||||
"github.com/metacubex/mihomo/transport/splithttp"
|
||||
mihomoVMess "github.com/metacubex/mihomo/transport/vmess"
|
||||
|
||||
"github.com/metacubex/http"
|
||||
@ -64,6 +65,7 @@ type VmessOption struct {
|
||||
HTTPOpts HTTPOptions `proxy:"http-opts,omitempty"`
|
||||
HTTP2Opts HTTP2Options `proxy:"h2-opts,omitempty"`
|
||||
GrpcOpts GrpcOptions `proxy:"grpc-opts,omitempty"`
|
||||
SplitHTTPOpts SplitHTTPOptions `proxy:"splithttp-opts,omitempty"`
|
||||
WSOpts WSOptions `proxy:"ws-opts,omitempty"`
|
||||
PacketAddr bool `proxy:"packet-addr,omitempty"`
|
||||
XUDP bool `proxy:"xudp,omitempty"`
|
||||
@ -73,6 +75,20 @@ type VmessOption struct {
|
||||
ClientFingerprint string `proxy:"client-fingerprint,omitempty"`
|
||||
}
|
||||
|
||||
type SplitHTTPOptions struct {
|
||||
Host string `proxy:"host,omitempty"`
|
||||
Path string `proxy:"path,omitempty"`
|
||||
Mode string `proxy:"mode,omitempty"`
|
||||
Headers map[string]string `proxy:"headers,omitempty"`
|
||||
NoGRPCHeader bool `proxy:"no-grpc-header,omitempty"`
|
||||
XPaddingBytes *splithttp.RangeConfig `proxy:"x-padding-bytes,omitempty"`
|
||||
ScMaxEachPostBytes *splithttp.RangeConfig `proxy:"max-each-post-bytes,omitempty"`
|
||||
ScMinPostsIntervalMs *splithttp.RangeConfig `proxy:"min-posts-interval,omitempty"`
|
||||
ScMaxBufferedPosts int `proxy:"max-buffered-posts,omitempty"`
|
||||
ScStreamUpServerSecs *splithttp.RangeConfig `proxy:"stream-up-server-secs,omitempty"`
|
||||
Xmux *splithttp.XmuxConfig `proxy:"xmux,omitempty"`
|
||||
}
|
||||
|
||||
type HTTPOptions struct {
|
||||
Method string `proxy:"method,omitempty"`
|
||||
Path []string `proxy:"path,omitempty"`
|
||||
@ -203,6 +219,72 @@ func (v *Vmess) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.M
|
||||
}
|
||||
|
||||
c, err = mihomoVMess.StreamH2Conn(ctx, c, h2Opts)
|
||||
case "splithttp", "xhttp":
|
||||
c.Close()
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
if v.option.TLS {
|
||||
host, _, _ := net.SplitHostPort(v.addr)
|
||||
tlsConfig, err = ca.GetTLSConfig(ca.Option{
|
||||
TLSConfig: &tls.Config{
|
||||
ServerName: host,
|
||||
InsecureSkipVerify: v.option.SkipCertVerify,
|
||||
NextProtos: v.option.ALPN,
|
||||
},
|
||||
Fingerprint: v.option.Fingerprint,
|
||||
Certificate: v.option.Certificate,
|
||||
PrivateKey: v.option.PrivateKey,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if v.option.ServerName != "" {
|
||||
tlsConfig.ServerName = v.option.ServerName
|
||||
} else if host := v.option.SplitHTTPOpts.Headers["Host"]; host != "" {
|
||||
tlsConfig.ServerName = host
|
||||
}
|
||||
|
||||
// Default ALPN if empty
|
||||
if len(tlsConfig.NextProtos) == 0 {
|
||||
tlsConfig.NextProtos = []string{"h2", "http/1.1"}
|
||||
}
|
||||
|
||||
// Handle ECH and Reality if needed (skipped for now or requires more adaptation)
|
||||
}
|
||||
|
||||
conf := &splithttp.Config{
|
||||
Host: v.option.SplitHTTPOpts.Host,
|
||||
Path: v.option.SplitHTTPOpts.Path,
|
||||
Mode: v.option.SplitHTTPOpts.Mode,
|
||||
Headers: v.option.SplitHTTPOpts.Headers,
|
||||
NoGRPCHeader: v.option.SplitHTTPOpts.NoGRPCHeader,
|
||||
XPaddingBytes: v.option.SplitHTTPOpts.XPaddingBytes,
|
||||
ScMaxEachPostBytes: v.option.SplitHTTPOpts.ScMaxEachPostBytes,
|
||||
ScMinPostsIntervalMs: v.option.SplitHTTPOpts.ScMinPostsIntervalMs,
|
||||
ScMaxBufferedPosts: int32(v.option.SplitHTTPOpts.ScMaxBufferedPosts),
|
||||
ScStreamUpServerSecs: v.option.SplitHTTPOpts.ScStreamUpServerSecs,
|
||||
Xmux: v.option.SplitHTTPOpts.Xmux,
|
||||
}
|
||||
|
||||
if conf.Host == "" {
|
||||
host, _, _ := net.SplitHostPort(v.addr)
|
||||
conf.Host = host
|
||||
}
|
||||
|
||||
dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
var err error
|
||||
var cDialer C.Dialer = dialer.NewDialer(v.DialOptions()...)
|
||||
if len(v.option.DialerProxy) > 0 {
|
||||
cDialer, err = proxydialer.NewByName(v.option.DialerProxy, cDialer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return cDialer.DialContext(ctx, "tcp", v.addr)
|
||||
}
|
||||
|
||||
c, err = splithttp.Dial(ctx, dialFn, conf, tlsConfig)
|
||||
case "grpc":
|
||||
c, err = gun.StreamGunWithConn(c, v.gunTLSConfig, v.gunConfig, v.echConfig, v.realityConfig)
|
||||
default:
|
||||
|
||||
515
transport/splithttp/client.go
Normal file
515
transport/splithttp/client.go
Normal file
@ -0,0 +1,515 @@
|
||||
package splithttp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gofrs/uuid/v5"
|
||||
"github.com/metacubex/mihomo/log"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
type DialerClient interface {
|
||||
IsClosed() bool
|
||||
OpenStream(context.Context, string, io.Reader, bool) (io.ReadCloser, net.Addr, net.Addr, error)
|
||||
PostPacket(context.Context, string, io.Reader, int64) error
|
||||
}
|
||||
|
||||
type DefaultDialerClient struct {
|
||||
transportConfig *Config
|
||||
client *http.Client
|
||||
closed bool
|
||||
httpVersion string
|
||||
uploadRawPool *sync.Pool
|
||||
dialUploadConn func(ctxInner context.Context) (net.Conn, error)
|
||||
}
|
||||
|
||||
func (c *DefaultDialerClient) IsClosed() bool {
|
||||
return c.closed
|
||||
}
|
||||
|
||||
func (c *DefaultDialerClient) OpenStream(ctx context.Context, url string, body io.Reader, uploadOnly bool) (wrc io.ReadCloser, remoteAddr, localAddr net.Addr, err error) {
|
||||
gotConn := make(chan struct{})
|
||||
ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
|
||||
GotConn: func(connInfo httptrace.GotConnInfo) {
|
||||
remoteAddr = connInfo.Conn.RemoteAddr()
|
||||
localAddr = connInfo.Conn.LocalAddr()
|
||||
close(gotConn)
|
||||
},
|
||||
})
|
||||
|
||||
method := "GET"
|
||||
if body != nil {
|
||||
method = "POST"
|
||||
}
|
||||
req, _ := http.NewRequestWithContext(ctx, method, url, body)
|
||||
req.Header = c.transportConfig.GetRequestHeader(url)
|
||||
if method == "POST" && !c.transportConfig.NoGRPCHeader {
|
||||
req.Header.Set("Content-Type", "application/grpc")
|
||||
}
|
||||
|
||||
wrcWrapper := &WaitReadCloser{Wait: make(chan struct{})}
|
||||
wrc = wrcWrapper
|
||||
|
||||
go func() {
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
if !uploadOnly {
|
||||
c.closed = true
|
||||
log.Debugln("failed to %s %s: %s", method, url, err)
|
||||
}
|
||||
// If GotConn didn't fire (e.g. connection error), we need to unblock the caller
|
||||
select {
|
||||
case <-gotConn:
|
||||
default:
|
||||
close(gotConn)
|
||||
}
|
||||
wrcWrapper.Close()
|
||||
return
|
||||
}
|
||||
if resp.StatusCode != 200 && !uploadOnly {
|
||||
log.Debugln("unexpected status %d", resp.StatusCode)
|
||||
}
|
||||
if resp.StatusCode != 200 || uploadOnly {
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
resp.Body.Close()
|
||||
wrcWrapper.Close()
|
||||
return
|
||||
}
|
||||
wrcWrapper.Set(resp.Body)
|
||||
}()
|
||||
|
||||
<-gotConn
|
||||
return
|
||||
}
|
||||
|
||||
func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, body io.Reader, contentLength int64) error {
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.ContentLength = contentLength
|
||||
req.Header = c.transportConfig.GetRequestHeader(url)
|
||||
|
||||
if c.httpVersion != "1.1" {
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
c.closed = true
|
||||
return err
|
||||
}
|
||||
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return fmt.Errorf("bad status code: %s", resp.Status)
|
||||
}
|
||||
} else {
|
||||
requestBuff := new(bytes.Buffer)
|
||||
// req.Write in standard lib writes the wire format.
|
||||
// Check if Xray's req.Write usage implies standard http.Request.Write.
|
||||
// Yes, Xray uses common.Must(req.Write(requestBuff))
|
||||
if err := req.Write(requestBuff); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var uploadConn any
|
||||
var h1UploadConn *H1Conn
|
||||
|
||||
for {
|
||||
uploadConn = c.uploadRawPool.Get()
|
||||
newConnection := uploadConn == nil
|
||||
if newConnection {
|
||||
newConn, err := c.dialUploadConn(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h1UploadConn = NewH1Conn(newConn)
|
||||
uploadConn = h1UploadConn
|
||||
} else {
|
||||
h1UploadConn = uploadConn.(*H1Conn)
|
||||
|
||||
if h1UploadConn.UnreadedResponsesCount > 0 {
|
||||
resp, err := http.ReadResponse(h1UploadConn.RespBufReader, req)
|
||||
if err != nil {
|
||||
c.closed = true
|
||||
return fmt.Errorf("error while reading response: %s", err.Error())
|
||||
}
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
return fmt.Errorf("got non-200 error response code: %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_, err := h1UploadConn.Write(requestBuff.Bytes())
|
||||
if err == nil {
|
||||
break
|
||||
} else if newConnection {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
c.uploadRawPool.Put(uploadConn)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type WaitReadCloser struct {
|
||||
Wait chan struct{}
|
||||
io.ReadCloser
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (w *WaitReadCloser) Set(rc io.ReadCloser) {
|
||||
w.mu.Lock()
|
||||
w.ReadCloser = rc
|
||||
w.mu.Unlock()
|
||||
|
||||
// Avoid panic if closed twice
|
||||
defer func() { recover() }()
|
||||
close(w.Wait)
|
||||
}
|
||||
|
||||
func (w *WaitReadCloser) Read(b []byte) (int, error) {
|
||||
if w.ReadCloser == nil {
|
||||
if <-w.Wait; w.ReadCloser == nil {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
}
|
||||
return w.ReadCloser.Read(b)
|
||||
}
|
||||
|
||||
func (w *WaitReadCloser) Close() error {
|
||||
w.mu.Lock()
|
||||
rc := w.ReadCloser
|
||||
w.mu.Unlock()
|
||||
|
||||
if rc != nil {
|
||||
return rc.Close()
|
||||
}
|
||||
defer func() { recover() }()
|
||||
close(w.Wait)
|
||||
return nil
|
||||
}
|
||||
|
||||
type DialFn func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
type dialerConf struct {
|
||||
Host string // used as key
|
||||
*Config
|
||||
}
|
||||
|
||||
var (
|
||||
globalDialerMap map[dialerConf]*XmuxManager
|
||||
globalDialerAccess sync.Mutex
|
||||
)
|
||||
|
||||
func decideHTTPVersion(tlsConfig *tls.Config) string {
|
||||
if tlsConfig == nil {
|
||||
return "1.1"
|
||||
}
|
||||
if len(tlsConfig.NextProtos) != 1 {
|
||||
return "2"
|
||||
}
|
||||
if tlsConfig.NextProtos[0] == "http/1.1" {
|
||||
return "1.1"
|
||||
}
|
||||
if tlsConfig.NextProtos[0] == "h3" {
|
||||
return "3"
|
||||
}
|
||||
return "2"
|
||||
}
|
||||
|
||||
func getHTTPClient(ctx context.Context, dialFn DialFn, config *Config, tlsConfig *tls.Config) (DialerClient, *XmuxClient) {
|
||||
globalDialerAccess.Lock()
|
||||
defer globalDialerAccess.Unlock()
|
||||
|
||||
if globalDialerMap == nil {
|
||||
globalDialerMap = make(map[dialerConf]*XmuxManager)
|
||||
}
|
||||
|
||||
// Use Host and Config pointer as key.
|
||||
// Note: Config might be different instance with same content.
|
||||
// Ideally we should use value semantics or consistent ID.
|
||||
// For now using Host and Config pointer.
|
||||
key := dialerConf{config.Host, config}
|
||||
|
||||
xmuxManager, found := globalDialerMap[key]
|
||||
|
||||
if !found {
|
||||
var xmuxConfig XmuxConfig
|
||||
if config.Xmux != nil {
|
||||
xmuxConfig = *config.Xmux
|
||||
}
|
||||
|
||||
xmuxManager = NewXmuxManager(xmuxConfig, func() XmuxConn {
|
||||
return createHTTPClient(dialFn, config, tlsConfig)
|
||||
})
|
||||
globalDialerMap[key] = xmuxManager
|
||||
}
|
||||
|
||||
xmuxClient := xmuxManager.GetXmuxClient(ctx)
|
||||
return xmuxClient.XmuxConn.(DialerClient), xmuxClient
|
||||
}
|
||||
|
||||
func createHTTPClient(dialFn DialFn, config *Config, tlsConfig *tls.Config) DialerClient {
|
||||
httpVersion := decideHTTPVersion(tlsConfig)
|
||||
|
||||
dialContext := func(ctxInner context.Context) (net.Conn, error) {
|
||||
// network and addr are not used by dialFn in this context usually,
|
||||
// because dialFn is already bound to a destination in vless.
|
||||
// But DialFn signature has network/addr.
|
||||
// In vless.go, dialFn is `func(ctx, network, addr)`.
|
||||
return dialFn(ctxInner, "tcp", config.Host)
|
||||
}
|
||||
|
||||
var keepAlivePeriod time.Duration
|
||||
if config.Xmux != nil {
|
||||
keepAlivePeriod = time.Duration(config.Xmux.HKeepAlivePeriod) * time.Second
|
||||
}
|
||||
|
||||
var transport http.RoundTripper
|
||||
|
||||
if httpVersion == "3" {
|
||||
// quic-go setup
|
||||
// H3 support disabled for now to avoid dependency issues
|
||||
return nil
|
||||
} else if httpVersion == "2" {
|
||||
if keepAlivePeriod == 0 {
|
||||
keepAlivePeriod = 30 * time.Second // Default
|
||||
}
|
||||
transport = &http2.Transport{
|
||||
DialTLSContext: func(ctxInner context.Context, network string, addr string, cfg *tls.Config) (net.Conn, error) {
|
||||
// We ignore cfg here as we use the one from creation or handle TLS in dialFn?
|
||||
// Actually dialContext should return a TLS connection if TLS is enabled?
|
||||
// In Xray `dialContext` returns `tls.Client(conn)`.
|
||||
// Here `dialFn` from vless seems to return plain TCP usually, then wrapped.
|
||||
|
||||
conn, err := dialContext(ctxInner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if tlsConfig != nil {
|
||||
// Wrap with TLS
|
||||
conn = tls.Client(conn, tlsConfig)
|
||||
// Handshake?
|
||||
}
|
||||
return conn, nil
|
||||
},
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
ReadIdleTimeout: keepAlivePeriod,
|
||||
}
|
||||
} else {
|
||||
httpDialContext := func(ctxInner context.Context, network string, addr string) (net.Conn, error) {
|
||||
conn, err := dialContext(ctxInner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tlsConfig != nil {
|
||||
conn = tls.Client(conn, tlsConfig)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
transport = &http.Transport{
|
||||
DialTLSContext: httpDialContext,
|
||||
DialContext: httpDialContext,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
DisableKeepAlives: true,
|
||||
}
|
||||
}
|
||||
|
||||
client := &DefaultDialerClient{
|
||||
transportConfig: config,
|
||||
client: &http.Client{
|
||||
Transport: transport,
|
||||
},
|
||||
httpVersion: httpVersion,
|
||||
uploadRawPool: &sync.Pool{},
|
||||
dialUploadConn: dialContext,
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func Dial(ctx context.Context, dialFn DialFn, config *Config, tlsConfig *tls.Config) (net.Conn, error) {
|
||||
httpVersion := decideHTTPVersion(tlsConfig)
|
||||
|
||||
requestURL := url.URL{}
|
||||
if tlsConfig != nil {
|
||||
requestURL.Scheme = "https"
|
||||
} else {
|
||||
requestURL.Scheme = "http"
|
||||
}
|
||||
requestURL.Host = config.Host
|
||||
if requestURL.Host == "" && tlsConfig != nil {
|
||||
requestURL.Host = tlsConfig.ServerName
|
||||
}
|
||||
|
||||
sessionId, _ := uuid.NewV4()
|
||||
requestURL.Path = config.GetNormalizedPath() + sessionId.String()
|
||||
requestURL.RawQuery = config.GetNormalizedQuery()
|
||||
|
||||
httpClient, xmuxClient := getHTTPClient(ctx, dialFn, config, tlsConfig)
|
||||
|
||||
mode := config.Mode
|
||||
if mode == "" || mode == "auto" {
|
||||
mode = "packet-up"
|
||||
}
|
||||
|
||||
log.Debugln("XHTTP is dialing to %s, mode %s, HTTP version %s, host %s", config.Host, mode, httpVersion, requestURL.Host)
|
||||
|
||||
if xmuxClient != nil {
|
||||
xmuxClient.OpenUsage.Add(1)
|
||||
}
|
||||
|
||||
var closed atomic.Int32
|
||||
reader, writer := io.Pipe() // Use io.Pipe instead of Xray pipe
|
||||
|
||||
conn := &splitConn{
|
||||
writer: writer,
|
||||
onClose: func() {
|
||||
if closed.Add(1) > 1 {
|
||||
return
|
||||
}
|
||||
if xmuxClient != nil {
|
||||
xmuxClient.OpenUsage.Add(-1)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
var err error
|
||||
if mode == "stream-one" {
|
||||
requestURL.Path = config.GetNormalizedPath()
|
||||
if xmuxClient != nil {
|
||||
xmuxClient.LeftRequests.Add(-1)
|
||||
}
|
||||
conn.reader, conn.remoteAddr, conn.localAddr, err = httpClient.OpenStream(ctx, requestURL.String(), reader, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Split mode (stream-up/packet-up + stream-down)
|
||||
|
||||
// Stream Down
|
||||
// We need a separate client for download? Xray handles DownloadSettings.
|
||||
// For simplicity, use the same client for now or if DownloadSettings is present, create another.
|
||||
// Assuming same config for now.
|
||||
httpClientDown := httpClient
|
||||
if config.DownloadSettings != nil {
|
||||
// TODO: Handle separate download settings
|
||||
// For now use same client
|
||||
}
|
||||
|
||||
if xmuxClient != nil {
|
||||
xmuxClient.LeftRequests.Add(-1)
|
||||
}
|
||||
conn.reader, conn.remoteAddr, conn.localAddr, err = httpClientDown.OpenStream(ctx, requestURL.String(), nil, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if mode == "stream-up" {
|
||||
if xmuxClient != nil {
|
||||
xmuxClient.LeftRequests.Add(-1)
|
||||
}
|
||||
_, _, _, err = httpClient.OpenStream(ctx, requestURL.String(), reader, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// packet-up
|
||||
scMaxEachPostBytes := config.GetNormalizedScMaxEachPostBytes()
|
||||
scMinPostsIntervalMs := config.GetNormalizedScMinPostsIntervalMs()
|
||||
|
||||
// We need a buffer to read from pipe and chunk it
|
||||
maxUploadSize := scMaxEachPostBytes.rand()
|
||||
|
||||
go func() {
|
||||
var seq int64
|
||||
var lastWrite time.Time
|
||||
buf := make([]byte, maxUploadSize)
|
||||
|
||||
for {
|
||||
wroteRequest := make(chan struct{})
|
||||
ctxTrace := httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
|
||||
WroteRequest: func(httptrace.WroteRequestInfo) {
|
||||
close(wroteRequest)
|
||||
},
|
||||
})
|
||||
|
||||
url := requestURL
|
||||
url.Path += "/" + strconv.FormatInt(seq, 10)
|
||||
seq++
|
||||
|
||||
if scMinPostsIntervalMs.From > 0 {
|
||||
time.Sleep(time.Duration(scMinPostsIntervalMs.rand())*time.Millisecond - time.Since(lastWrite))
|
||||
}
|
||||
|
||||
// Read from pipe
|
||||
n, err := reader.Read(buf)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
lastWrite = time.Now()
|
||||
|
||||
if xmuxClient != nil && (xmuxClient.LeftRequests.Add(-1) <= 0 ||
|
||||
(xmuxClient.UnreusableAt != time.Time{} && lastWrite.After(xmuxClient.UnreusableAt))) {
|
||||
httpClient, xmuxClient = getHTTPClient(ctx, dialFn, config, tlsConfig)
|
||||
}
|
||||
|
||||
chunk := buf[:n]
|
||||
// Copy chunk because buf is reused
|
||||
chunkCopy := make([]byte, n)
|
||||
copy(chunkCopy, chunk)
|
||||
|
||||
go func() {
|
||||
// PostPacket expects io.Reader
|
||||
err := httpClient.PostPacket(
|
||||
ctxTrace,
|
||||
url.String(),
|
||||
bytes.NewReader(chunkCopy),
|
||||
int64(n),
|
||||
)
|
||||
|
||||
// If WroteRequest wasn't called (error), close channel
|
||||
select {
|
||||
case <-wroteRequest:
|
||||
default:
|
||||
close(wroteRequest)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Debugln("failed to send upload: %v", err)
|
||||
reader.CloseWithError(err)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, ok := httpClient.(*DefaultDialerClient); ok {
|
||||
<-wroteRequest
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
205
transport/splithttp/config.go
Normal file
205
transport/splithttp/config.go
Normal file
@ -0,0 +1,205 @@
|
||||
package splithttp
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type RangeConfig struct {
|
||||
From int32 `json:"from" proxy:"from"`
|
||||
To int32 `json:"to" proxy:"to"`
|
||||
}
|
||||
|
||||
func (c RangeConfig) rand() int32 {
|
||||
return int32(randBetween(int64(c.From), int64(c.To)))
|
||||
}
|
||||
|
||||
type XmuxConfig struct {
|
||||
MaxConcurrency *RangeConfig `json:"maxConcurrency" proxy:"max-concurrency"`
|
||||
MaxConnections *RangeConfig `json:"maxConnections" proxy:"max-connections"`
|
||||
CMaxReuseTimes *RangeConfig `json:"cMaxReuseTimes" proxy:"c-max-reuse-times"`
|
||||
HMaxRequestTimes *RangeConfig `json:"hMaxRequestTimes" proxy:"h-max-request-times"`
|
||||
HMaxReusableSecs *RangeConfig `json:"hMaxReusableSecs" proxy:"h-max-reusable-secs"`
|
||||
HKeepAlivePeriod int64 `json:"hKeepAlivePeriod" proxy:"h-keep-alive-period"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Host string `json:"host"`
|
||||
Path string `json:"path"`
|
||||
Mode string `json:"mode"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
XPaddingBytes *RangeConfig `json:"xPaddingBytes"`
|
||||
NoGRPCHeader bool `json:"noGRPCHeader"`
|
||||
ScMaxEachPostBytes *RangeConfig `json:"scMaxEachPostBytes"`
|
||||
ScMinPostsIntervalMs *RangeConfig `json:"scMinPostsIntervalMs"`
|
||||
ScMaxBufferedPosts int32 `json:"scMaxBufferedPosts"`
|
||||
ScStreamUpServerSecs *RangeConfig `json:"scStreamUpServerSecs"`
|
||||
Xmux *XmuxConfig `json:"xmux"`
|
||||
DownloadSettings *Config `json:"downloadSettings"` // Simplified for now, assume same protocol
|
||||
}
|
||||
|
||||
func (c *Config) GetNormalizedPath() string {
|
||||
pathAndQuery := strings.SplitN(c.Path, "?", 2)
|
||||
path := pathAndQuery[0]
|
||||
|
||||
if path == "" || path[0] != '/' {
|
||||
path = "/" + path
|
||||
}
|
||||
|
||||
if path[len(path)-1] != '/' {
|
||||
path = path + "/"
|
||||
}
|
||||
|
||||
return path
|
||||
}
|
||||
|
||||
func (c *Config) GetNormalizedQuery() string {
|
||||
pathAndQuery := strings.SplitN(c.Path, "?", 2)
|
||||
query := ""
|
||||
|
||||
if len(pathAndQuery) > 1 {
|
||||
query = pathAndQuery[1]
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
func (c *Config) GetRequestHeader(rawURL string) http.Header {
|
||||
header := http.Header{}
|
||||
for k, v := range c.Headers {
|
||||
header.Add(k, v)
|
||||
}
|
||||
|
||||
u, _ := url.Parse(rawURL)
|
||||
u.RawQuery = "x_padding=" + strings.Repeat("X", int(c.GetNormalizedXPaddingBytes().rand()))
|
||||
header.Set("Referer", u.String())
|
||||
|
||||
return header
|
||||
}
|
||||
|
||||
func (c *Config) WriteResponseHeader(writer http.ResponseWriter) {
|
||||
writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
writer.Header().Set("Access-Control-Allow-Methods", "GET, POST")
|
||||
writer.Header().Set("X-Padding", strings.Repeat("X", int(c.GetNormalizedXPaddingBytes().rand())))
|
||||
}
|
||||
|
||||
func (c *Config) GetNormalizedXPaddingBytes() RangeConfig {
|
||||
if c.XPaddingBytes == nil || c.XPaddingBytes.To == 0 {
|
||||
return RangeConfig{
|
||||
From: 100,
|
||||
To: 1000,
|
||||
}
|
||||
}
|
||||
|
||||
return *c.XPaddingBytes
|
||||
}
|
||||
|
||||
func (c *Config) GetNormalizedScMaxEachPostBytes() RangeConfig {
|
||||
if c.ScMaxEachPostBytes == nil || c.ScMaxEachPostBytes.To == 0 {
|
||||
return RangeConfig{
|
||||
From: 1000000,
|
||||
To: 1000000,
|
||||
}
|
||||
}
|
||||
|
||||
return *c.ScMaxEachPostBytes
|
||||
}
|
||||
|
||||
func (c *Config) GetNormalizedScMinPostsIntervalMs() RangeConfig {
|
||||
if c.ScMinPostsIntervalMs == nil || c.ScMinPostsIntervalMs.To == 0 {
|
||||
return RangeConfig{
|
||||
From: 30,
|
||||
To: 30,
|
||||
}
|
||||
}
|
||||
|
||||
return *c.ScMinPostsIntervalMs
|
||||
}
|
||||
|
||||
func (c *Config) GetNormalizedScMaxBufferedPosts() int {
|
||||
if c.ScMaxBufferedPosts == 0 {
|
||||
return 30
|
||||
}
|
||||
|
||||
return int(c.ScMaxBufferedPosts)
|
||||
}
|
||||
|
||||
func (c *Config) GetNormalizedScStreamUpServerSecs() RangeConfig {
|
||||
if c.ScStreamUpServerSecs == nil || c.ScStreamUpServerSecs.To == 0 {
|
||||
return RangeConfig{
|
||||
From: 20,
|
||||
To: 80,
|
||||
}
|
||||
}
|
||||
|
||||
return *c.ScMinPostsIntervalMs
|
||||
}
|
||||
|
||||
func (m *XmuxConfig) GetNormalizedMaxConcurrency() RangeConfig {
|
||||
if m.MaxConcurrency == nil {
|
||||
return RangeConfig{
|
||||
From: 0,
|
||||
To: 0,
|
||||
}
|
||||
}
|
||||
|
||||
return *m.MaxConcurrency
|
||||
}
|
||||
|
||||
func (m *XmuxConfig) GetNormalizedMaxConnections() RangeConfig {
|
||||
if m.MaxConnections == nil {
|
||||
return RangeConfig{
|
||||
From: 0,
|
||||
To: 0,
|
||||
}
|
||||
}
|
||||
|
||||
return *m.MaxConnections
|
||||
}
|
||||
|
||||
func (m *XmuxConfig) GetNormalizedCMaxReuseTimes() RangeConfig {
|
||||
if m.CMaxReuseTimes == nil {
|
||||
return RangeConfig{
|
||||
From: 0,
|
||||
To: 0,
|
||||
}
|
||||
}
|
||||
|
||||
return *m.CMaxReuseTimes
|
||||
}
|
||||
|
||||
func (m *XmuxConfig) GetNormalizedHMaxRequestTimes() RangeConfig {
|
||||
if m.HMaxRequestTimes == nil {
|
||||
return RangeConfig{
|
||||
From: 0,
|
||||
To: 0,
|
||||
}
|
||||
}
|
||||
|
||||
return *m.HMaxRequestTimes
|
||||
}
|
||||
|
||||
func (m *XmuxConfig) GetNormalizedHMaxReusableSecs() RangeConfig {
|
||||
if m.HMaxReusableSecs == nil {
|
||||
return RangeConfig{
|
||||
From: 0,
|
||||
To: 0,
|
||||
}
|
||||
}
|
||||
|
||||
return *m.HMaxReusableSecs
|
||||
}
|
||||
|
||||
func randBetween(min, max int64) int64 {
|
||||
if min == max {
|
||||
return min
|
||||
}
|
||||
if min > max {
|
||||
min, max = max, min
|
||||
}
|
||||
n, _ := rand.Int(rand.Reader, big.NewInt(max-min+1))
|
||||
return min + n.Int64()
|
||||
}
|
||||
64
transport/splithttp/connection.go
Normal file
64
transport/splithttp/connection.go
Normal file
@ -0,0 +1,64 @@
|
||||
package splithttp
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type splitConn struct {
|
||||
writer io.WriteCloser
|
||||
reader io.ReadCloser
|
||||
remoteAddr net.Addr
|
||||
localAddr net.Addr
|
||||
onClose func()
|
||||
}
|
||||
|
||||
func (c *splitConn) Write(b []byte) (int, error) {
|
||||
return c.writer.Write(b)
|
||||
}
|
||||
|
||||
func (c *splitConn) Read(b []byte) (int, error) {
|
||||
return c.reader.Read(b)
|
||||
}
|
||||
|
||||
func (c *splitConn) Close() error {
|
||||
if c.onClose != nil {
|
||||
c.onClose()
|
||||
}
|
||||
|
||||
err := c.writer.Close()
|
||||
err2 := c.reader.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err2 != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *splitConn) LocalAddr() net.Addr {
|
||||
return c.localAddr
|
||||
}
|
||||
|
||||
func (c *splitConn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *splitConn) SetDeadline(t time.Time) error {
|
||||
// TODO cannot do anything useful
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *splitConn) SetReadDeadline(t time.Time) error {
|
||||
// TODO cannot do anything useful
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *splitConn) SetWriteDeadline(t time.Time) error {
|
||||
// TODO cannot do anything useful
|
||||
return nil
|
||||
}
|
||||
19
transport/splithttp/h1_conn.go
Normal file
19
transport/splithttp/h1_conn.go
Normal file
@ -0,0 +1,19 @@
|
||||
package splithttp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
)
|
||||
|
||||
type H1Conn struct {
|
||||
UnreadedResponsesCount int
|
||||
RespBufReader *bufio.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func NewH1Conn(conn net.Conn) *H1Conn {
|
||||
return &H1Conn{
|
||||
RespBufReader: bufio.NewReader(conn),
|
||||
Conn: conn,
|
||||
}
|
||||
}
|
||||
114
transport/splithttp/mux.go
Normal file
114
transport/splithttp/mux.go
Normal file
@ -0,0 +1,114 @@
|
||||
package splithttp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"math"
|
||||
"math/big"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/metacubex/mihomo/log"
|
||||
)
|
||||
|
||||
type XmuxConn interface {
|
||||
IsClosed() bool
|
||||
}
|
||||
|
||||
type XmuxClient struct {
|
||||
XmuxConn XmuxConn
|
||||
OpenUsage atomic.Int32
|
||||
leftUsage int32
|
||||
LeftRequests atomic.Int32
|
||||
UnreusableAt time.Time
|
||||
}
|
||||
|
||||
type XmuxManager struct {
|
||||
xmuxConfig XmuxConfig
|
||||
concurrency int32
|
||||
connections int32
|
||||
newConnFunc func() XmuxConn
|
||||
xmuxClients []*XmuxClient
|
||||
}
|
||||
|
||||
func NewXmuxManager(xmuxConfig XmuxConfig, newConnFunc func() XmuxConn) *XmuxManager {
|
||||
return &XmuxManager{
|
||||
xmuxConfig: xmuxConfig,
|
||||
concurrency: xmuxConfig.GetNormalizedMaxConcurrency().rand(),
|
||||
connections: xmuxConfig.GetNormalizedMaxConnections().rand(),
|
||||
newConnFunc: newConnFunc,
|
||||
xmuxClients: make([]*XmuxClient, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *XmuxManager) newXmuxClient() *XmuxClient {
|
||||
xmuxClient := &XmuxClient{
|
||||
XmuxConn: m.newConnFunc(),
|
||||
leftUsage: -1,
|
||||
}
|
||||
if x := m.xmuxConfig.GetNormalizedCMaxReuseTimes().rand(); x > 0 {
|
||||
xmuxClient.leftUsage = x - 1
|
||||
}
|
||||
xmuxClient.LeftRequests.Store(math.MaxInt32)
|
||||
if x := m.xmuxConfig.GetNormalizedHMaxRequestTimes().rand(); x > 0 {
|
||||
xmuxClient.LeftRequests.Store(x)
|
||||
}
|
||||
if x := m.xmuxConfig.GetNormalizedHMaxReusableSecs().rand(); x > 0 {
|
||||
xmuxClient.UnreusableAt = time.Now().Add(time.Duration(x) * time.Second)
|
||||
}
|
||||
m.xmuxClients = append(m.xmuxClients, xmuxClient)
|
||||
return xmuxClient
|
||||
}
|
||||
|
||||
func (m *XmuxManager) GetXmuxClient(ctx context.Context) *XmuxClient { // when locking
|
||||
for i := 0; i < len(m.xmuxClients); {
|
||||
xmuxClient := m.xmuxClients[i]
|
||||
if xmuxClient.XmuxConn.IsClosed() ||
|
||||
xmuxClient.leftUsage == 0 ||
|
||||
xmuxClient.LeftRequests.Load() <= 0 ||
|
||||
(xmuxClient.UnreusableAt != time.Time{} && time.Now().After(xmuxClient.UnreusableAt)) {
|
||||
log.Debugln("XMUX: removing xmuxClient, IsClosed() = %v, OpenUsage = %d, leftUsage = %d, LeftRequests = %d, UnreusableAt = %v",
|
||||
xmuxClient.XmuxConn.IsClosed(),
|
||||
xmuxClient.OpenUsage.Load(),
|
||||
xmuxClient.leftUsage,
|
||||
xmuxClient.LeftRequests.Load(),
|
||||
xmuxClient.UnreusableAt)
|
||||
m.xmuxClients = append(m.xmuxClients[:i], m.xmuxClients[i+1:]...)
|
||||
} else {
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
if len(m.xmuxClients) == 0 {
|
||||
log.Debugln("XMUX: creating xmuxClient because xmuxClients is empty")
|
||||
return m.newXmuxClient()
|
||||
}
|
||||
|
||||
if m.connections > 0 && len(m.xmuxClients) < int(m.connections) {
|
||||
log.Debugln("XMUX: creating xmuxClient because maxConnections was not hit, xmuxClients = %d", len(m.xmuxClients))
|
||||
return m.newXmuxClient()
|
||||
}
|
||||
|
||||
xmuxClients := make([]*XmuxClient, 0)
|
||||
if m.concurrency > 0 {
|
||||
for _, xmuxClient := range m.xmuxClients {
|
||||
if xmuxClient.OpenUsage.Load() < m.concurrency {
|
||||
xmuxClients = append(xmuxClients, xmuxClient)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
xmuxClients = m.xmuxClients
|
||||
}
|
||||
|
||||
if len(xmuxClients) == 0 {
|
||||
log.Debugln("XMUX: creating xmuxClient because maxConcurrency was hit, xmuxClients = %d", len(m.xmuxClients))
|
||||
return m.newXmuxClient()
|
||||
}
|
||||
|
||||
i, _ := rand.Int(rand.Reader, big.NewInt(int64(len(xmuxClients))))
|
||||
xmuxClient := xmuxClients[i.Int64()]
|
||||
if xmuxClient.leftUsage > 0 {
|
||||
xmuxClient.leftUsage -= 1
|
||||
}
|
||||
return xmuxClient
|
||||
}
|
||||
170
transport/splithttp/upload_queue.go
Normal file
170
transport/splithttp/upload_queue.go
Normal file
@ -0,0 +1,170 @@
|
||||
package splithttp
|
||||
|
||||
// upload_queue is a specialized priorityqueue + channel to reorder generic
|
||||
// packets by a sequence number
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"errors"
|
||||
"io"
|
||||
"runtime"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Packet struct {
|
||||
Reader io.ReadCloser
|
||||
Payload []byte
|
||||
Seq uint64
|
||||
}
|
||||
|
||||
type uploadQueue struct {
|
||||
reader io.ReadCloser
|
||||
nomore bool
|
||||
pushedPackets chan Packet
|
||||
writeCloseMutex sync.Mutex
|
||||
heap uploadHeap
|
||||
nextSeq uint64
|
||||
closed bool
|
||||
maxPackets int
|
||||
}
|
||||
|
||||
func NewUploadQueue(maxPackets int) *uploadQueue {
|
||||
return &uploadQueue{
|
||||
pushedPackets: make(chan Packet, maxPackets),
|
||||
heap: uploadHeap{},
|
||||
nextSeq: 0,
|
||||
closed: false,
|
||||
maxPackets: maxPackets,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *uploadQueue) Push(p Packet) error {
|
||||
h.writeCloseMutex.Lock()
|
||||
defer h.writeCloseMutex.Unlock()
|
||||
|
||||
if h.closed {
|
||||
return errors.New("packet queue closed")
|
||||
}
|
||||
if h.nomore {
|
||||
return errors.New("h.reader already exists")
|
||||
}
|
||||
if p.Reader != nil {
|
||||
h.nomore = true
|
||||
}
|
||||
h.pushedPackets <- p
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *uploadQueue) Close() error {
|
||||
h.writeCloseMutex.Lock()
|
||||
defer h.writeCloseMutex.Unlock()
|
||||
|
||||
if !h.closed {
|
||||
h.closed = true
|
||||
runtime.Gosched() // hope Read() gets the packet
|
||||
f:
|
||||
for {
|
||||
select {
|
||||
case p := <-h.pushedPackets:
|
||||
if p.Reader != nil {
|
||||
h.reader = p.Reader
|
||||
}
|
||||
default:
|
||||
break f
|
||||
}
|
||||
}
|
||||
close(h.pushedPackets)
|
||||
}
|
||||
if h.reader != nil {
|
||||
return h.reader.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *uploadQueue) Read(b []byte) (int, error) {
|
||||
if h.reader != nil {
|
||||
return h.reader.Read(b)
|
||||
}
|
||||
|
||||
if h.closed {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
if len(h.heap) == 0 {
|
||||
packet, more := <-h.pushedPackets
|
||||
if !more {
|
||||
return 0, io.EOF
|
||||
}
|
||||
if packet.Reader != nil {
|
||||
h.reader = packet.Reader
|
||||
return h.reader.Read(b)
|
||||
}
|
||||
heap.Push(&h.heap, packet)
|
||||
}
|
||||
|
||||
for len(h.heap) > 0 {
|
||||
packet := heap.Pop(&h.heap).(Packet)
|
||||
n := 0
|
||||
|
||||
if packet.Seq == h.nextSeq {
|
||||
copy(b, packet.Payload)
|
||||
n = min(len(b), len(packet.Payload))
|
||||
|
||||
if n < len(packet.Payload) {
|
||||
// partial read
|
||||
packet.Payload = packet.Payload[n:]
|
||||
heap.Push(&h.heap, packet)
|
||||
} else {
|
||||
h.nextSeq = packet.Seq + 1
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// misordered packet
|
||||
if packet.Seq > h.nextSeq {
|
||||
if len(h.heap) > h.maxPackets {
|
||||
// the "reassembly buffer" is too large, and we want to
|
||||
// constrain memory usage somehow. let's tear down the
|
||||
// connection, and hope the application retries.
|
||||
return 0, errors.New("packet queue is too large")
|
||||
}
|
||||
heap.Push(&h.heap, packet)
|
||||
packet2, more := <-h.pushedPackets
|
||||
if !more {
|
||||
return 0, io.EOF
|
||||
}
|
||||
heap.Push(&h.heap, packet2)
|
||||
}
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// heap code directly taken from https://pkg.go.dev/container/heap
|
||||
type uploadHeap []Packet
|
||||
|
||||
func (h uploadHeap) Len() int { return len(h) }
|
||||
func (h uploadHeap) Less(i, j int) bool { return h[i].Seq < h[j].Seq }
|
||||
func (h uploadHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
|
||||
|
||||
func (h *uploadHeap) Push(x any) {
|
||||
// Push and Pop use pointer receivers because they modify the slice's length,
|
||||
// not just its contents.
|
||||
*h = append(*h, x.(Packet))
|
||||
}
|
||||
|
||||
func (h *uploadHeap) Pop() any {
|
||||
old := *h
|
||||
n := len(old)
|
||||
x := old[n-1]
|
||||
*h = old[0 : n-1]
|
||||
return x
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user