diff --git a/adapter/outbound/sudoku.go b/adapter/outbound/sudoku.go index 36cf4421..8f952dd6 100644 --- a/adapter/outbound/sudoku.go +++ b/adapter/outbound/sudoku.go @@ -19,6 +19,9 @@ type Sudoku struct { option *SudokuOption baseConf sudoku.ProtocolConfig + httpMaskMu sync.Mutex + httpMaskClient *sudoku.HTTPMaskTunnelClient + muxMu sync.Mutex muxClient *sudoku.MultiplexClient muxBackoffUntil time.Time @@ -40,7 +43,7 @@ type SudokuOption struct { HTTPMaskMode string `proxy:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto" HTTPMaskTLS bool `proxy:"http-mask-tls,omitempty"` // only for http-mask-mode stream/poll/auto HTTPMaskHost string `proxy:"http-mask-host,omitempty"` // optional Host/SNI override (domain or domain:port) - HTTPMaskMultiplex string `proxy:"http-mask-multiplex,omitempty"` // "off" (default), "auto", "on" + HTTPMaskMultiplex string `proxy:"http-mask-multiplex,omitempty"` // "off" (default), "auto" (reuse h1/h2), "on" (single tunnel, multi-target) CustomTable string `proxy:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv CustomTables []string `proxy:"custom-tables,omitempty"` // optional table rotation patterns, overrides custom-table when non-empty } @@ -53,18 +56,12 @@ func (s *Sudoku) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Con } muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex) - if !cfg.DisableHTTPMask && muxMode != "off" { - shouldTry := muxMode == "on" || (muxMode == "auto" && httpTunnelModeEnabled(cfg.HTTPMaskMode)) - if shouldTry { - stream, muxErr := s.dialMultiplex(ctx, cfg.TargetAddress, muxMode) - if muxErr == nil { - return NewConn(stream, s), nil - } - if muxMode != "auto" { - return nil, muxErr - } - s.noteMuxFailure(muxMode, muxErr) + if muxMode == "on" && !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) { + stream, muxErr := s.dialMultiplex(ctx, cfg.TargetAddress, muxMode) + if muxErr == nil { + return NewConn(stream, s), nil } + return nil, muxErr } c, err := s.dialAndHandshake(ctx, cfg) @@ -229,6 +226,7 @@ func NewSudoku(option SudokuOption) (*Sudoku, error) { func (s *Sudoku) Close() error { s.resetMuxClient() + s.resetHTTPMaskClient() return s.Base.Close() } @@ -261,7 +259,17 @@ func (s *Sudoku) dialAndHandshake(ctx context.Context, cfg *sudoku.ProtocolConfi var c net.Conn if !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) { - c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, s.dialer.DialContext) + muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex) + switch muxMode { + case "auto", "on": + client, errX := s.getOrCreateHTTPMaskClient(cfg) + if errX != nil { + return nil, errX + } + c, err = client.Dial(ctx) + default: + c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, s.dialer.DialContext) + } } if c == nil && err == nil { c, err = s.dialer.DialContext(ctx, "tcp", s.addr) @@ -380,3 +388,35 @@ func (s *Sudoku) resetMuxClient() { s.muxClient = nil } } + +func (s *Sudoku) getOrCreateHTTPMaskClient(cfg *sudoku.ProtocolConfig) (*sudoku.HTTPMaskTunnelClient, error) { + if s == nil { + return nil, fmt.Errorf("nil adapter") + } + if cfg == nil { + return nil, fmt.Errorf("config is required") + } + + s.httpMaskMu.Lock() + defer s.httpMaskMu.Unlock() + + if s.httpMaskClient != nil { + return s.httpMaskClient, nil + } + + c, err := sudoku.NewHTTPMaskTunnelClient(cfg.ServerAddress, cfg, s.dialer.DialContext) + if err != nil { + return nil, err + } + s.httpMaskClient = c + return c, nil +} + +func (s *Sudoku) resetHTTPMaskClient() { + s.httpMaskMu.Lock() + defer s.httpMaskMu.Unlock() + if s.httpMaskClient != nil { + s.httpMaskClient.CloseIdleConnections() + s.httpMaskClient = nil + } +} diff --git a/docs/config.yaml b/docs/config.yaml index 1a2c2a90..69107872 100644 --- a/docs/config.yaml +++ b/docs/config.yaml @@ -1068,7 +1068,7 @@ proxies: # socks5 # http-mask-mode: legacy # 可选:legacy(默认)、stream、poll、auto;stream/poll/auto 支持走 CDN/反代 # http-mask-tls: true # 可选:仅在 http-mask-mode 为 stream/poll/auto 时生效;true 强制 https;false 强制 http(不会根据端口自动推断) # http-mask-host: "" # 可选:覆盖 Host/SNI(支持 example.com 或 example.com:443);仅在 http-mask-mode 为 stream/poll/auto 时生效 - # http-mask-multiplex: off # 可选:off(默认)、auto、on;复用单条隧道并在其内多路复用多个目标连接 + # http-mask-multiplex: off # 可选:off(默认)、auto(复用 h1.1 keep-alive / h2 连接,减少每次建链 RTT)、on(单条隧道内多路复用多个目标连接;仅在 http-mask-mode=stream/poll/auto 生效) enable-pure-downlink: false # 是否启用混淆下行,false的情况下能在保证数据安全的前提下极大提升下行速度,与服务端端保持相同(如果此处为false,则要求aead不可为none) # anytls diff --git a/transport/sudoku/config.go b/transport/sudoku/config.go index 8eee292d..27649fbc 100644 --- a/transport/sudoku/config.go +++ b/transport/sudoku/config.go @@ -58,9 +58,10 @@ type ProtocolConfig struct { // HTTPMaskHost optionally overrides the HTTP Host header / SNI host for HTTP tunnel modes (client-side). HTTPMaskHost string - // HTTPMaskMultiplex controls whether the client reuses a single (HTTP-masked) tunnel connection and - // opens multiple logical target streams inside it (reduces RTT for subsequent connections). - // Values: "off" / "auto" / "on". + // HTTPMaskMultiplex controls multiplex behavior when HTTPMask tunnel modes are enabled: + // - "off": disable reuse; each Dial establishes its own HTTPMask tunnel + // - "auto": reuse underlying HTTP connections across multiple tunnel dials (HTTP/1.1 keep-alive / HTTP/2) + // - "on": enable "single tunnel, multi-target" mux (Sudoku-level multiplex; Dial behaves like "auto" otherwise) HTTPMaskMultiplex string } diff --git a/transport/sudoku/httpmask_tunnel.go b/transport/sudoku/httpmask_tunnel.go index aeedfe15..48d1846c 100644 --- a/transport/sudoku/httpmask_tunnel.go +++ b/transport/sudoku/httpmask_tunnel.go @@ -83,6 +83,59 @@ func DialHTTPMaskTunnel(ctx context.Context, serverAddress string, cfg *Protocol Mode: cfg.HTTPMaskMode, TLSEnabled: cfg.HTTPMaskTLSEnabled, HostOverride: cfg.HTTPMaskHost, + Multiplex: cfg.HTTPMaskMultiplex, DialContext: dial, }) } + +type HTTPMaskTunnelClient struct { + mode string + client *httpmask.TunnelClient +} + +func NewHTTPMaskTunnelClient(serverAddress string, cfg *ProtocolConfig, dial TunnelDialer) (*HTTPMaskTunnelClient, error) { + if cfg == nil { + return nil, fmt.Errorf("config is required") + } + if cfg.DisableHTTPMask { + return nil, fmt.Errorf("http mask is disabled") + } + switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) { + case "stream", "poll", "auto": + default: + return nil, fmt.Errorf("http-mask-mode=%q does not use http tunnel", cfg.HTTPMaskMode) + } + switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMultiplex)) { + case "auto", "on": + default: + return nil, fmt.Errorf("http-mask-multiplex=%q does not enable reuse", cfg.HTTPMaskMultiplex) + } + + c, err := httpmask.NewTunnelClient(serverAddress, httpmask.TunnelClientOptions{ + TLSEnabled: cfg.HTTPMaskTLSEnabled, + HostOverride: cfg.HTTPMaskHost, + DialContext: dial, + }) + if err != nil { + return nil, err + } + + return &HTTPMaskTunnelClient{ + mode: cfg.HTTPMaskMode, + client: c, + }, nil +} + +func (c *HTTPMaskTunnelClient) Dial(ctx context.Context) (net.Conn, error) { + if c == nil || c.client == nil { + return nil, fmt.Errorf("nil httpmask tunnel client") + } + return c.client.DialTunnel(ctx, c.mode) +} + +func (c *HTTPMaskTunnelClient) CloseIdleConnections() { + if c == nil || c.client == nil { + return + } + c.client.CloseIdleConnections() +} diff --git a/transport/sudoku/multiplex.go b/transport/sudoku/multiplex.go index 2834ef80..da635708 100644 --- a/transport/sudoku/multiplex.go +++ b/transport/sudoku/multiplex.go @@ -1,11 +1,11 @@ package sudoku import ( + "bytes" "context" "fmt" "net" "strings" - "time" "github.com/metacubex/mihomo/transport/sudoku/multiplex" ) @@ -46,26 +46,19 @@ func (c *MultiplexClient) Dial(ctx context.Context, targetAddress string) (net.C return nil, fmt.Errorf("target address cannot be empty") } - stream, err := c.sess.OpenStream() + addrBuf, err := EncodeAddress(targetAddress) + if err != nil { + return nil, fmt.Errorf("encode target address failed: %w", err) + } + + if ctx != nil && ctx.Err() != nil { + return nil, ctx.Err() + } + + stream, err := c.sess.OpenStream(addrBuf) if err != nil { return nil, err } - - if deadline, ok := ctx.Deadline(); ok { - _ = stream.SetWriteDeadline(deadline) - defer stream.SetWriteDeadline(time.Time{}) - } - - addrBuf, err := EncodeAddress(targetAddress) - if err != nil { - _ = stream.Close() - return nil, fmt.Errorf("encode target address failed: %w", err) - } - if _, err := stream.Write(addrBuf); err != nil { - _ = stream.Close() - return nil, fmt.Errorf("send target address failed: %w", err) - } - return stream, nil } @@ -114,18 +107,21 @@ func (s *MultiplexServer) AcceptStream() (net.Conn, error) { if s == nil || s.sess == nil { return nil, fmt.Errorf("nil session") } - return s.sess.AcceptStream() + c, _, err := s.sess.AcceptStream() + return c, err } -// AcceptTCP accepts a multiplex stream and reads the target address preface, returning the stream positioned at -// application data. +// AcceptTCP accepts a multiplex stream and returns the target address declared in the open frame. func (s *MultiplexServer) AcceptTCP() (net.Conn, string, error) { - stream, err := s.AcceptStream() + if s == nil || s.sess == nil { + return nil, "", fmt.Errorf("nil session") + } + stream, payload, err := s.sess.AcceptStream() if err != nil { return nil, "", err } - target, err := DecodeAddress(stream) + target, err := DecodeAddress(bytes.NewReader(payload)) if err != nil { _ = stream.Close() return nil, "", err @@ -147,4 +143,3 @@ func (s *MultiplexServer) IsClosed() bool { } return s.sess.IsClosed() } - diff --git a/transport/sudoku/multiplex/mux.go b/transport/sudoku/multiplex/mux.go new file mode 100644 index 00000000..38bc34c8 --- /dev/null +++ b/transport/sudoku/multiplex/mux.go @@ -0,0 +1,39 @@ +package multiplex + +import ( + "fmt" + "io" +) + +const ( + // MagicByte marks a Sudoku tunnel connection that will switch into multiplex mode. + // It is sent after the Sudoku handshake + downlink mode byte. + // + // Keep it distinct from UoTMagicByte and address type bytes. + MagicByte byte = 0xED + Version byte = 0x01 +) + +func WritePreface(w io.Writer) error { + if w == nil { + return fmt.Errorf("nil writer") + } + _, err := w.Write([]byte{MagicByte, Version}) + return err +} + +func ReadVersion(r io.Reader) (byte, error) { + var b [1]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return 0, err + } + return b[0], nil +} + +func ValidateVersion(v byte) error { + if v != Version { + return fmt.Errorf("unsupported multiplex version: %d", v) + } + return nil +} + diff --git a/transport/sudoku/multiplex/session.go b/transport/sudoku/multiplex/session.go index 9e24c6b7..64205dec 100644 --- a/transport/sudoku/multiplex/session.go +++ b/transport/sudoku/multiplex/session.go @@ -1,101 +1,504 @@ package multiplex import ( + "encoding/binary" + "errors" "fmt" "io" "net" + "sync" "time" - - "github.com/metacubex/smux" ) const ( - // MagicByte marks a Sudoku tunnel connection that will switch into multiplex mode. - // It is sent after the Sudoku handshake + downlink mode byte. - MagicByte byte = 0xEF - Version = 0x01 + frameOpen byte = 0x01 + frameData byte = 0x02 + frameClose byte = 0x03 + frameReset byte = 0x04 ) -func WritePreface(w io.Writer) error { - _, err := w.Write([]byte{MagicByte, Version}) - return err -} +const ( + headerSize = 1 + 4 + 4 + maxFrameSize = 256 * 1024 + maxDataPayload = 32 * 1024 +) -func ReadVersion(r io.Reader) (byte, error) { - var b [1]byte - if _, err := io.ReadFull(r, b[:]); err != nil { - return 0, err - } - return b[0], nil -} - -func ValidateVersion(v byte) error { - if v != Version { - return fmt.Errorf("unsupported multiplex version: %d", v) - } - return nil -} - -func defaultSmuxConfig() *smux.Config { - cfg := smux.DefaultConfig() - cfg.KeepAliveInterval = 15 * time.Second - cfg.KeepAliveTimeout = 45 * time.Second - return cfg +type acceptEvent struct { + stream *stream + payload []byte } type Session struct { - sess *smux.Session + conn net.Conn + + writeMu sync.Mutex + + streamsMu sync.Mutex + streams map[uint32]*stream + nextID uint32 + + acceptCh chan acceptEvent + + closed chan struct{} + closeOnce sync.Once + closeErr error } func NewClientSession(conn net.Conn) (*Session, error) { if conn == nil { return nil, fmt.Errorf("nil conn") } - s, err := smux.Client(conn, defaultSmuxConfig()) - if err != nil { - _ = conn.Close() - return nil, err + s := &Session{ + conn: conn, + streams: make(map[uint32]*stream), + closed: make(chan struct{}), } - return &Session{sess: s}, nil + go s.readLoop() + return s, nil } func NewServerSession(conn net.Conn) (*Session, error) { if conn == nil { return nil, fmt.Errorf("nil conn") } - s, err := smux.Server(conn, defaultSmuxConfig()) - if err != nil { - _ = conn.Close() - return nil, err + s := &Session{ + conn: conn, + streams: make(map[uint32]*stream), + acceptCh: make(chan acceptEvent, 256), + closed: make(chan struct{}), } - return &Session{sess: s}, nil -} - -func (s *Session) OpenStream() (net.Conn, error) { - if s == nil || s.sess == nil { - return nil, fmt.Errorf("nil session") - } - return s.sess.OpenStream() -} - -func (s *Session) AcceptStream() (net.Conn, error) { - if s == nil || s.sess == nil { - return nil, fmt.Errorf("nil session") - } - return s.sess.AcceptStream() -} - -func (s *Session) Close() error { - if s == nil || s.sess == nil { - return nil - } - return s.sess.Close() + go s.readLoop() + return s, nil } func (s *Session) IsClosed() bool { - if s == nil || s.sess == nil { + if s == nil { return true } - return s.sess.IsClosed() + select { + case <-s.closed: + return true + default: + return false + } } +func (s *Session) closedErr() error { + s.streamsMu.Lock() + err := s.closeErr + s.streamsMu.Unlock() + if err == nil { + return io.ErrClosedPipe + } + return err +} + +func (s *Session) closeWithError(err error) { + if err == nil { + err = io.ErrClosedPipe + } + s.closeOnce.Do(func() { + s.streamsMu.Lock() + if s.closeErr == nil { + s.closeErr = err + } + streams := make([]*stream, 0, len(s.streams)) + for _, st := range s.streams { + streams = append(streams, st) + } + s.streams = make(map[uint32]*stream) + s.streamsMu.Unlock() + + for _, st := range streams { + st.closeNoSend(err) + } + + close(s.closed) + _ = s.conn.Close() + }) +} + +func (s *Session) Close() error { + if s == nil { + return nil + } + s.closeWithError(io.ErrClosedPipe) + return nil +} + +func (s *Session) registerStream(st *stream) { + s.streamsMu.Lock() + s.streams[st.id] = st + s.streamsMu.Unlock() +} + +func (s *Session) getStream(id uint32) *stream { + s.streamsMu.Lock() + st := s.streams[id] + s.streamsMu.Unlock() + return st +} + +func (s *Session) removeStream(id uint32) { + s.streamsMu.Lock() + delete(s.streams, id) + s.streamsMu.Unlock() +} + +func (s *Session) nextStreamID() uint32 { + s.streamsMu.Lock() + s.nextID++ + id := s.nextID + if id == 0 { + s.nextID++ + id = s.nextID + } + s.streamsMu.Unlock() + return id +} + +func (s *Session) sendFrame(frameType byte, streamID uint32, payload []byte) error { + if s.IsClosed() { + return s.closedErr() + } + if len(payload) > maxFrameSize { + return fmt.Errorf("mux payload too large: %d", len(payload)) + } + + var header [headerSize]byte + header[0] = frameType + binary.BigEndian.PutUint32(header[1:5], streamID) + binary.BigEndian.PutUint32(header[5:9], uint32(len(payload))) + + s.writeMu.Lock() + defer s.writeMu.Unlock() + + if err := writeFull(s.conn, header[:]); err != nil { + s.closeWithError(err) + return err + } + if len(payload) > 0 { + if err := writeFull(s.conn, payload); err != nil { + s.closeWithError(err) + return err + } + } + return nil +} + +func (s *Session) sendReset(streamID uint32, msg string) { + if msg == "" { + msg = "reset" + } + _ = s.sendFrame(frameReset, streamID, []byte(msg)) + _ = s.sendFrame(frameClose, streamID, nil) +} + +func (s *Session) OpenStream(openPayload []byte) (net.Conn, error) { + if s == nil { + return nil, fmt.Errorf("nil session") + } + if s.IsClosed() { + return nil, s.closedErr() + } + + streamID := s.nextStreamID() + st := newStream(s, streamID) + s.registerStream(st) + + if err := s.sendFrame(frameOpen, streamID, openPayload); err != nil { + st.closeNoSend(err) + s.removeStream(streamID) + return nil, fmt.Errorf("mux open failed: %w", err) + } + return st, nil +} + +func (s *Session) AcceptStream() (net.Conn, []byte, error) { + if s == nil { + return nil, nil, fmt.Errorf("nil session") + } + if s.acceptCh == nil { + return nil, nil, fmt.Errorf("accept is not supported on client sessions") + } + select { + case ev := <-s.acceptCh: + return ev.stream, ev.payload, nil + case <-s.closed: + return nil, nil, s.closedErr() + } +} + +func (s *Session) readLoop() { + var header [headerSize]byte + for { + if _, err := io.ReadFull(s.conn, header[:]); err != nil { + s.closeWithError(err) + return + } + frameType := header[0] + streamID := binary.BigEndian.Uint32(header[1:5]) + n := int(binary.BigEndian.Uint32(header[5:9])) + if n < 0 || n > maxFrameSize { + s.closeWithError(fmt.Errorf("invalid mux frame length: %d", n)) + return + } + + var payload []byte + if n > 0 { + payload = make([]byte, n) + if _, err := io.ReadFull(s.conn, payload); err != nil { + s.closeWithError(err) + return + } + } + + switch frameType { + case frameOpen: + if s.acceptCh == nil { + s.sendReset(streamID, "unexpected open") + continue + } + if streamID == 0 { + s.sendReset(streamID, "invalid stream id") + continue + } + if existing := s.getStream(streamID); existing != nil { + s.sendReset(streamID, "stream already exists") + continue + } + st := newStream(s, streamID) + s.registerStream(st) + go func() { + select { + case s.acceptCh <- acceptEvent{stream: st, payload: payload}: + case <-s.closed: + st.closeNoSend(io.ErrClosedPipe) + s.removeStream(streamID) + } + }() + + case frameData: + st := s.getStream(streamID) + if st == nil { + continue + } + if len(payload) == 0 { + continue + } + st.enqueue(payload) + + case frameClose: + st := s.getStream(streamID) + if st == nil { + continue + } + st.closeNoSend(io.EOF) + s.removeStream(streamID) + + case frameReset: + st := s.getStream(streamID) + if st == nil { + continue + } + msg := trimASCII(payload) + if msg == "" { + msg = "reset" + } + st.closeNoSend(errors.New(msg)) + s.removeStream(streamID) + + default: + s.closeWithError(fmt.Errorf("unknown mux frame type: %d", frameType)) + return + } + } +} + +func writeFull(w io.Writer, b []byte) error { + for len(b) > 0 { + n, err := w.Write(b) + if err != nil { + return err + } + b = b[n:] + } + return nil +} + +func trimASCII(b []byte) string { + i := 0 + j := len(b) + for i < j { + c := b[i] + if c != ' ' && c != '\n' && c != '\r' && c != '\t' { + break + } + i++ + } + for j > i { + c := b[j-1] + if c != ' ' && c != '\n' && c != '\r' && c != '\t' { + break + } + j-- + } + if i >= j { + return "" + } + out := make([]byte, j-i) + copy(out, b[i:j]) + return string(out) +} + +type stream struct { + session *Session + id uint32 + + mu sync.Mutex + cond *sync.Cond + closed bool + closeErr error + readBuf []byte + queue [][]byte + + localAddr net.Addr + remoteAddr net.Addr +} + +func newStream(session *Session, id uint32) *stream { + st := &stream{ + session: session, + id: id, + localAddr: &net.TCPAddr{}, + remoteAddr: &net.TCPAddr{}, + } + st.cond = sync.NewCond(&st.mu) + return st +} + +func (c *stream) enqueue(payload []byte) { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return + } + c.queue = append(c.queue, payload) + c.cond.Signal() + c.mu.Unlock() +} + +func (c *stream) closeNoSend(err error) { + if err == nil { + err = io.EOF + } + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return + } + c.closed = true + if c.closeErr == nil { + c.closeErr = err + } + c.cond.Broadcast() + c.mu.Unlock() +} + +func (c *stream) closedErr() error { + c.mu.Lock() + defer c.mu.Unlock() + if c.closeErr == nil { + return io.ErrClosedPipe + } + return c.closeErr +} + +func (c *stream) Read(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + c.mu.Lock() + defer c.mu.Unlock() + + for len(c.readBuf) == 0 && len(c.queue) == 0 && !c.closed { + c.cond.Wait() + } + if len(c.readBuf) == 0 && len(c.queue) > 0 { + c.readBuf = c.queue[0] + c.queue = c.queue[1:] + } + if len(c.readBuf) == 0 && c.closed { + if c.closeErr == nil { + return 0, io.ErrClosedPipe + } + return 0, c.closeErr + } + + n := copy(p, c.readBuf) + c.readBuf = c.readBuf[n:] + return n, nil +} + +func (c *stream) Write(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + if c.session == nil || c.session.IsClosed() { + if c.session != nil { + return 0, c.session.closedErr() + } + return 0, io.ErrClosedPipe + } + + c.mu.Lock() + closed := c.closed + c.mu.Unlock() + if closed { + return 0, c.closedErr() + } + + written := 0 + for len(p) > 0 { + chunk := p + if len(chunk) > maxDataPayload { + chunk = p[:maxDataPayload] + } + if err := c.session.sendFrame(frameData, c.id, chunk); err != nil { + return written, err + } + written += len(chunk) + p = p[len(chunk):] + } + return written, nil +} + +func (c *stream) Close() error { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil + } + c.closed = true + if c.closeErr == nil { + c.closeErr = io.ErrClosedPipe + } + c.cond.Broadcast() + c.mu.Unlock() + + _ = c.session.sendFrame(frameClose, c.id, nil) + c.session.removeStream(c.id) + return nil +} + +func (c *stream) LocalAddr() net.Addr { return c.localAddr } +func (c *stream) RemoteAddr() net.Addr { return c.remoteAddr } + +func (c *stream) SetDeadline(t time.Time) error { + _ = c.SetReadDeadline(t) + _ = c.SetWriteDeadline(t) + return nil +} +func (c *stream) SetReadDeadline(time.Time) error { return nil } +func (c *stream) SetWriteDeadline(time.Time) error { return nil } + diff --git a/transport/sudoku/obfs/httpmask/tunnel.go b/transport/sudoku/obfs/httpmask/tunnel.go index e281f479..88bc1a3a 100644 --- a/transport/sudoku/obfs/httpmask/tunnel.go +++ b/transport/sudoku/obfs/httpmask/tunnel.go @@ -62,11 +62,84 @@ type TunnelDialOptions struct { Mode string TLSEnabled bool // when true, use HTTPS; otherwise, use HTTP (no port-based inference) HostOverride string // optional Host header / SNI host (without scheme); accepts "example.com" or "example.com:443" + // Multiplex controls whether the caller should reuse underlying HTTP connections (HTTP/1.1 keep-alive / HTTP/2). + // To reuse across multiple dials, create a TunnelClient per proxy and reuse it. + // Values: "off" disables reuse; "auto"/"on" enables it. + Multiplex string // DialContext overrides how the HTTP tunnel dials raw TCP/TLS connections. // It must not be nil; passing nil is a programming error. DialContext func(ctx context.Context, network, addr string) (net.Conn, error) } +type TunnelClientOptions struct { + TLSEnabled bool + HostOverride string + DialContext func(ctx context.Context, network, addr string) (net.Conn, error) + MaxIdleConns int +} + +type TunnelClient struct { + client *http.Client + transport *http.Transport + target httpClientTarget +} + +func NewTunnelClient(serverAddress string, opts TunnelClientOptions) (*TunnelClient, error) { + maxIdle := opts.MaxIdleConns + if maxIdle <= 0 { + maxIdle = 32 + } + + transport, target, err := buildHTTPTransport(serverAddress, opts.TLSEnabled, opts.HostOverride, opts.DialContext, maxIdle) + if err != nil { + return nil, err + } + + return &TunnelClient{ + client: &http.Client{Transport: transport}, + transport: transport, + target: target, + }, nil +} + +func (c *TunnelClient) CloseIdleConnections() { + if c == nil || c.transport == nil { + return + } + c.transport.CloseIdleConnections() +} + +func (c *TunnelClient) DialTunnel(ctx context.Context, mode string) (net.Conn, error) { + if c == nil || c.client == nil { + return nil, fmt.Errorf("nil tunnel client") + } + tm := normalizeTunnelMode(mode) + if tm == TunnelModeLegacy { + return nil, fmt.Errorf("legacy mode does not use http tunnel") + } + + switch tm { + case TunnelModeStream: + return dialStreamWithClient(ctx, c.client, c.target) + case TunnelModePoll: + return dialPollWithClient(ctx, c.client, c.target) + case TunnelModeAuto: + streamCtx, cancelX := context.WithTimeout(ctx, 3*time.Second) + c1, errX := dialStreamWithClient(streamCtx, c.client, c.target) + cancelX() + if errX == nil { + return c1, nil + } + c2, errP := dialPollWithClient(ctx, c.client, c.target) + if errP == nil { + return c2, nil + } + return nil, fmt.Errorf("auto tunnel failed: stream: %v; poll: %w", errX, errP) + default: + return dialStreamWithClient(ctx, c.client, c.target) + } +} + // DialTunnel establishes a bidirectional stream over HTTP: // - stream: a single streaming POST (request body uplink, response body downlink) // - poll: authorize + push/pull polling tunnel (base64 framed) @@ -192,43 +265,154 @@ type httpClientTarget struct { headerHost string } -func newHTTPClient(serverAddress string, opts TunnelDialOptions, maxIdleConns int) (*http.Client, httpClientTarget, error) { - if opts.DialContext == nil { +func buildHTTPTransport(serverAddress string, tlsEnabled bool, hostOverride string, dialContext func(ctx context.Context, network, addr string) (net.Conn, error), maxIdleConns int) (*http.Transport, httpClientTarget, error) { + if dialContext == nil { panic("httpmask: DialContext is nil") } - scheme, urlHost, dialAddr, serverName, err := normalizeHTTPDialTarget(serverAddress, opts.TLSEnabled, opts.HostOverride) + scheme, urlHost, dialAddr, serverName, err := normalizeHTTPDialTarget(serverAddress, tlsEnabled, hostOverride) if err != nil { return nil, httpClientTarget{}, err } transport := &http.Transport{ - ForceAttemptHTTP2: scheme == "https", - DisableCompression: true, - MaxIdleConns: maxIdleConns, - IdleConnTimeout: 30 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, + ForceAttemptHTTP2: scheme == "https", + DisableCompression: true, + MaxIdleConns: maxIdleConns, + MaxIdleConnsPerHost: maxIdleConns, + IdleConnTimeout: 30 * time.Second, + ResponseHeaderTimeout: 20 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, DialContext: func(dialCtx context.Context, network, _ string) (net.Conn, error) { - return opts.DialContext(dialCtx, network, dialAddr) + return dialContext(dialCtx, network, dialAddr) }, } if scheme == "https" { - transport.TLSClientConfig, err = ca.GetTLSConfig(ca.Option{TLSConfig: &tls.Config{ + var tlsConf *tls.Config + tlsConf, err = ca.GetTLSConfig(ca.Option{TLSConfig: &tls.Config{ ServerName: serverName, MinVersion: tls.VersionTLS12, }}) if err != nil { return nil, httpClientTarget{}, err } + transport.TLSClientConfig = tlsConf } - return &http.Client{Transport: transport}, httpClientTarget{ + return transport, httpClientTarget{ scheme: scheme, urlHost: urlHost, headerHost: canonicalHeaderHost(urlHost, scheme), }, nil } +func newHTTPClient(serverAddress string, opts TunnelDialOptions, maxIdleConns int) (*http.Client, httpClientTarget, error) { + transport, target, err := buildHTTPTransport(serverAddress, opts.TLSEnabled, opts.HostOverride, opts.DialContext, maxIdleConns) + if err != nil { + return nil, httpClientTarget{}, err + } + return &http.Client{Transport: transport}, target, nil +} + +type sessionDialInfo struct { + client *http.Client + pushURL string + pullURL string + closeURL string + headerHost string +} + +func dialSessionWithClient(ctx context.Context, client *http.Client, target httpClientTarget, mode TunnelMode) (*sessionDialInfo, error) { + if client == nil { + return nil, fmt.Errorf("nil http client") + } + + authorizeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/session"}).String() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, authorizeURL, nil) + if err != nil { + return nil, err + } + req.Host = target.headerHost + applyTunnelHeaders(req.Header, target.headerHost, mode) + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, 4*1024)) + _ = resp.Body.Close() + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%s authorize bad status: %s (%s)", mode, resp.Status, strings.TrimSpace(string(bodyBytes))) + } + + token, err := parseTunnelToken(bodyBytes) + if err != nil { + return nil, fmt.Errorf("%s authorize failed: %q", mode, strings.TrimSpace(string(bodyBytes))) + } + if token == "" { + return nil, fmt.Errorf("%s authorize empty token", mode) + } + + pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token)}).String() + pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/stream", RawQuery: "token=" + url.QueryEscape(token)}).String() + closeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token) + "&close=1"}).String() + + return &sessionDialInfo{ + client: client, + pushURL: pushURL, + pullURL: pullURL, + closeURL: closeURL, + headerHost: target.headerHost, + }, nil +} + +func dialSession(ctx context.Context, serverAddress string, opts TunnelDialOptions, mode TunnelMode) (*sessionDialInfo, error) { + client, target, err := newHTTPClient(serverAddress, opts, 32) + if err != nil { + return nil, err + } + return dialSessionWithClient(ctx, client, target, mode) +} + +func bestEffortCloseSession(client *http.Client, closeURL, headerHost string, mode TunnelMode) { + if client == nil || closeURL == "" || headerHost == "" { + return + } + + closeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(closeCtx, http.MethodPost, closeURL, nil) + if err != nil { + return + } + req.Host = headerHost + applyTunnelHeaders(req.Header, headerHost, mode) + + resp, err := client.Do(req) + if err != nil || resp == nil { + return + } + _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024)) + _ = resp.Body.Close() +} + +func dialStreamWithClient(ctx context.Context, client *http.Client, target httpClientTarget) (net.Conn, error) { + // Prefer split session (Cloudflare-friendly). Fall back to stream-one for older servers / environments. + c, errSplit := dialStreamSplitWithClient(ctx, client, target) + if errSplit == nil { + return c, nil + } + c2, errOne := dialStreamOneWithClient(ctx, client, target) + if errOne == nil { + return c2, nil + } + return nil, fmt.Errorf("dial stream failed: split: %v; stream-one: %w", errSplit, errOne) +} + func dialStream(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) { // Prefer split session (Cloudflare-friendly). Fall back to stream-one for older servers / environments. c, errSplit := dialStreamSplit(ctx, serverAddress, opts) @@ -242,10 +426,9 @@ func dialStream(ctx context.Context, serverAddress string, opts TunnelDialOption return nil, fmt.Errorf("dial stream failed: split: %v; stream-one: %w", errSplit, errOne) } -func dialStreamOne(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) { - client, target, err := newHTTPClient(serverAddress, opts, 16) - if err != nil { - return nil, err +func dialStreamOneWithClient(ctx context.Context, client *http.Client, target httpClientTarget) (net.Conn, error) { + if client == nil { + return nil, fmt.Errorf("nil http client") } r := rngPool.Get().(*mrand.Rand) @@ -312,16 +495,15 @@ func dialStreamOne(ctx context.Context, serverAddress string, opts TunnelDialOpt } } -type streamSplitConn struct { - ctx context.Context - cancel context.CancelFunc - - client *http.Client - pushURL string - pullURL string - closeURL string - headerHost string +func dialStreamOne(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) { + client, target, err := newHTTPClient(serverAddress, opts, 32) + if err != nil { + return nil, err + } + return dialStreamOneWithClient(ctx, client, target) +} +type queuedConn struct { rxc chan []byte closed chan struct{} @@ -329,16 +511,46 @@ type streamSplitConn struct { mu sync.Mutex readBuf []byte + closeErr error localAddr net.Addr remoteAddr net.Addr } -func (c *streamSplitConn) Read(b []byte) (n int, err error) { +func (c *queuedConn) closeWithError(err error) error { + c.mu.Lock() + select { + case <-c.closed: + c.mu.Unlock() + return nil + default: + if err == nil { + err = io.ErrClosedPipe + } + if c.closeErr == nil { + c.closeErr = err + } + close(c.closed) + } + c.mu.Unlock() + return nil +} + +func (c *queuedConn) closedErr() error { + c.mu.Lock() + err := c.closeErr + c.mu.Unlock() + if err == nil { + return io.ErrClosedPipe + } + return err +} + +func (c *queuedConn) Read(b []byte) (n int, err error) { if len(c.readBuf) == 0 { select { case c.readBuf = <-c.rxc: case <-c.closed: - return 0, io.ErrClosedPipe + return 0, c.closedErr() } } n = copy(b, c.readBuf) @@ -346,7 +558,7 @@ func (c *streamSplitConn) Read(b []byte) (n int, err error) { return n, nil } -func (c *streamSplitConn) Write(b []byte) (n int, err error) { +func (c *queuedConn) Write(b []byte) (n int, err error) { if len(b) == 0 { return 0, nil } @@ -354,7 +566,7 @@ func (c *streamSplitConn) Write(b []byte) (n int, err error) { select { case <-c.closed: c.mu.Unlock() - return 0, io.ErrClosedPipe + return 0, c.closedErr() default: } c.mu.Unlock() @@ -365,111 +577,97 @@ func (c *streamSplitConn) Write(b []byte) (n int, err error) { case c.writeCh <- payload: return len(b), nil case <-c.closed: - return 0, io.ErrClosedPipe + return 0, c.closedErr() } } +func (c *queuedConn) LocalAddr() net.Addr { return c.localAddr } +func (c *queuedConn) RemoteAddr() net.Addr { return c.remoteAddr } + +func (c *queuedConn) SetDeadline(time.Time) error { return nil } +func (c *queuedConn) SetReadDeadline(time.Time) error { return nil } +func (c *queuedConn) SetWriteDeadline(time.Time) error { return nil } + +type streamSplitConn struct { + queuedConn + + ctx context.Context + cancel context.CancelFunc + + client *http.Client + pushURL string + pullURL string + closeURL string + headerHost string +} + func (c *streamSplitConn) Close() error { - c.mu.Lock() - select { - case <-c.closed: - c.mu.Unlock() - return nil - default: - close(c.closed) - } - c.mu.Unlock() + _ = c.closeWithError(io.ErrClosedPipe) if c.cancel != nil { c.cancel() } - - // Best-effort session close signal (avoid leaking server-side sessions). - closeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - req, err := http.NewRequestWithContext(closeCtx, http.MethodPost, c.closeURL, nil) - if err == nil { - req.Host = c.headerHost - applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream) - if resp, doErr := c.client.Do(req); doErr == nil && resp != nil { - _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024)) - _ = resp.Body.Close() - } - } - + bestEffortCloseSession(c.client, c.closeURL, c.headerHost, TunnelModeStream) return nil } -func (c *streamSplitConn) LocalAddr() net.Addr { return c.localAddr } -func (c *streamSplitConn) RemoteAddr() net.Addr { return c.remoteAddr } - -func (c *streamSplitConn) SetDeadline(time.Time) error { return nil } -func (c *streamSplitConn) SetReadDeadline(time.Time) error { return nil } -func (c *streamSplitConn) SetWriteDeadline(time.Time) error { return nil } - -func dialStreamSplit(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) { - client, target, err := newHTTPClient(serverAddress, opts, 32) - if err != nil { - return nil, err +func newStreamSplitConnFromInfo(info *sessionDialInfo) *streamSplitConn { + if info == nil { + return nil } - authorizeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/session"}).String() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, authorizeURL, nil) - if err != nil { - return nil, err - } - req.Host = target.headerHost - applyTunnelHeaders(req.Header, target.headerHost, TunnelModeStream) - - resp, err := client.Do(req) - if err != nil { - return nil, err - } - bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, 4*1024)) - _ = resp.Body.Close() - if err != nil { - return nil, err - } - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("stream authorize bad status: %s (%s)", resp.Status, strings.TrimSpace(string(bodyBytes))) - } - - token, err := parseTunnelToken(bodyBytes) - if err != nil { - return nil, fmt.Errorf("stream authorize failed: %q", strings.TrimSpace(string(bodyBytes))) - } - if token == "" { - return nil, fmt.Errorf("stream authorize empty token") - } - - pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token)}).String() - pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/stream", RawQuery: "token=" + url.QueryEscape(token)}).String() - closeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token) + "&close=1"}).String() - connCtx, cancel := context.WithCancel(context.Background()) c := &streamSplitConn{ ctx: connCtx, cancel: cancel, - client: client, - pushURL: pushURL, - pullURL: pullURL, - closeURL: closeURL, - headerHost: target.headerHost, - rxc: make(chan []byte, 256), - closed: make(chan struct{}), - writeCh: make(chan []byte, 256), - localAddr: &net.TCPAddr{}, - remoteAddr: &net.TCPAddr{}, + client: info.client, + pushURL: info.pushURL, + pullURL: info.pullURL, + closeURL: info.closeURL, + headerHost: info.headerHost, + queuedConn: queuedConn{ + rxc: make(chan []byte, 256), + closed: make(chan struct{}), + writeCh: make(chan []byte, 256), + localAddr: &net.TCPAddr{}, + remoteAddr: &net.TCPAddr{}, + }, } go c.pullLoop() go c.pushLoop() + return c +} + +func dialStreamSplitWithClient(ctx context.Context, client *http.Client, target httpClientTarget) (net.Conn, error) { + info, err := dialSessionWithClient(ctx, client, target, TunnelModeStream) + if err != nil { + return nil, err + } + c := newStreamSplitConnFromInfo(info) + if c == nil { + return nil, fmt.Errorf("failed to build stream split conn") + } + return c, nil +} + +func dialStreamSplit(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) { + info, err := dialSession(ctx, serverAddress, opts, TunnelModeStream) + if err != nil { + return nil, err + } + c := newStreamSplitConnFromInfo(info) + if c == nil { + return nil, fmt.Errorf("failed to build stream split conn") + } return c, nil } func (c *streamSplitConn) pullLoop() { const ( - requestTimeout = 30 * time.Second + // requestTimeout must be long enough for continuous high-throughput streams (e.g. mux + large downloads). + // If it is too short, the client cancels the response mid-body and corrupts the byte stream. + requestTimeout = 2 * time.Minute readChunkSize = 32 * 1024 idleBackoff = 25 * time.Millisecond maxDialRetry = 12 @@ -688,22 +886,16 @@ func (c *streamSplitConn) pushLoop() { } type pollConn struct { + queuedConn + + ctx context.Context + cancel context.CancelFunc + client *http.Client pushURL string pullURL string closeURL string headerHost string - - rxc chan []byte - closed chan struct{} - - writeCh chan []byte - - mu sync.Mutex - readBuf []byte - closeErr error - localAddr net.Addr - remoteAddr net.Addr } func isDialError(err error) bool { @@ -721,147 +913,67 @@ func isDialError(err error) bool { } func (c *pollConn) closeWithError(err error) error { - c.mu.Lock() - select { - case <-c.closed: - c.mu.Unlock() - return nil - default: - if err == nil { - err = io.ErrClosedPipe - } - if c.closeErr == nil { - c.closeErr = err - } - close(c.closed) - } - c.mu.Unlock() - - // Best-effort session close signal (avoid leaking server-side sessions). - req, reqErr := http.NewRequest(http.MethodPost, c.closeURL, nil) - if reqErr == nil { - req.Host = c.headerHost - req.Header.Set("X-Sudoku-Tunnel", string(TunnelModePoll)) - req.Header.Set("X-Sudoku-Version", "1") - if resp, doErr := c.client.Do(req); doErr == nil && resp != nil { - _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024)) - _ = resp.Body.Close() - } + _ = c.queuedConn.closeWithError(err) + if c.cancel != nil { + c.cancel() } + bestEffortCloseSession(c.client, c.closeURL, c.headerHost, TunnelModePoll) return nil } -func (c *pollConn) closedErr() error { - c.mu.Lock() - err := c.closeErr - c.mu.Unlock() - if err == nil { - return io.ErrClosedPipe - } - return err -} - -func (c *pollConn) Read(b []byte) (n int, err error) { - if len(c.readBuf) == 0 { - select { - case c.readBuf = <-c.rxc: - case <-c.closed: - return 0, c.closedErr() - } - } - n = copy(b, c.readBuf) - c.readBuf = c.readBuf[n:] - return n, nil -} - -func (c *pollConn) Write(b []byte) (n int, err error) { - if len(b) == 0 { - return 0, nil - } - c.mu.Lock() - select { - case <-c.closed: - c.mu.Unlock() - return 0, c.closedErr() - default: - } - c.mu.Unlock() - - payload := make([]byte, len(b)) - copy(payload, b) - select { - case c.writeCh <- payload: - return len(b), nil - case <-c.closed: - return 0, c.closedErr() - } -} - func (c *pollConn) Close() error { return c.closeWithError(io.ErrClosedPipe) } -func (c *pollConn) LocalAddr() net.Addr { return c.localAddr } -func (c *pollConn) RemoteAddr() net.Addr { return c.remoteAddr } - -func (c *pollConn) SetDeadline(time.Time) error { return nil } -func (c *pollConn) SetReadDeadline(time.Time) error { return nil } -func (c *pollConn) SetWriteDeadline(time.Time) error { return nil } - -func dialPoll(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) { - client, target, err := newHTTPClient(serverAddress, opts, 32) - if err != nil { - return nil, err +func newPollConnFromInfo(info *sessionDialInfo) *pollConn { + if info == nil { + return nil } - authorizeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/session"}).String() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, authorizeURL, nil) - if err != nil { - return nil, err - } - req.Host = target.headerHost - applyTunnelHeaders(req.Header, target.headerHost, TunnelModePoll) - - resp, err := client.Do(req) - if err != nil { - return nil, err - } - bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, 4*1024)) - _ = resp.Body.Close() - if err != nil { - return nil, err - } - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("poll authorize bad status: %s (%s)", resp.Status, strings.TrimSpace(string(bodyBytes))) - } - - token, err := parseTunnelToken(bodyBytes) - if err != nil { - return nil, fmt.Errorf("poll authorize failed: %q", strings.TrimSpace(string(bodyBytes))) - } - if token == "" { - return nil, fmt.Errorf("poll authorize empty token") - } - - pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token)}).String() - pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/stream", RawQuery: "token=" + url.QueryEscape(token)}).String() - closeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token) + "&close=1"}).String() - + connCtx, cancel := context.WithCancel(context.Background()) c := &pollConn{ - client: client, - pushURL: pushURL, - pullURL: pullURL, - closeURL: closeURL, - headerHost: target.headerHost, - rxc: make(chan []byte, 128), - closed: make(chan struct{}), - writeCh: make(chan []byte, 256), - localAddr: &net.TCPAddr{}, - remoteAddr: &net.TCPAddr{}, + ctx: connCtx, + cancel: cancel, + client: info.client, + pushURL: info.pushURL, + pullURL: info.pullURL, + closeURL: info.closeURL, + headerHost: info.headerHost, + queuedConn: queuedConn{ + rxc: make(chan []byte, 128), + closed: make(chan struct{}), + writeCh: make(chan []byte, 256), + localAddr: &net.TCPAddr{}, + remoteAddr: &net.TCPAddr{}, + }, } go c.pullLoop() go c.pushLoop() + return c +} + +func dialPollWithClient(ctx context.Context, client *http.Client, target httpClientTarget) (net.Conn, error) { + info, err := dialSessionWithClient(ctx, client, target, TunnelModePoll) + if err != nil { + return nil, err + } + c := newPollConnFromInfo(info) + if c == nil { + return nil, fmt.Errorf("failed to build poll conn") + } + return c, nil +} + +func dialPoll(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) { + info, err := dialSession(ctx, serverAddress, opts, TunnelModePoll) + if err != nil { + return nil, err + } + c := newPollConnFromInfo(info) + if c == nil { + return nil, fmt.Errorf("failed to build poll conn") + } return c, nil }