chore: refactored the implementation of suduko mux (#2486)
Some checks failed
Test / test (1.20, macos-15-intel) (push) Has been cancelled
Test / test (1.20, macos-latest) (push) Has been cancelled
Test / test (1.20, ubuntu-24.04-arm) (push) Has been cancelled
Test / test (1.20, ubuntu-latest) (push) Has been cancelled
Test / test (1.20, windows-latest) (push) Has been cancelled
Test / test (1.21, macos-15-intel) (push) Has been cancelled
Test / test (1.21, macos-latest) (push) Has been cancelled
Test / test (1.21, ubuntu-24.04-arm) (push) Has been cancelled
Test / test (1.21, ubuntu-latest) (push) Has been cancelled
Test / test (1.21, windows-latest) (push) Has been cancelled
Test / test (1.22, macos-15-intel) (push) Has been cancelled
Test / test (1.22, macos-latest) (push) Has been cancelled
Test / test (1.22, ubuntu-24.04-arm) (push) Has been cancelled
Test / test (1.22, ubuntu-latest) (push) Has been cancelled
Test / test (1.22, windows-latest) (push) Has been cancelled
Test / test (1.23, macos-15-intel) (push) Has been cancelled
Test / test (1.23, macos-latest) (push) Has been cancelled
Test / test (1.23, ubuntu-24.04-arm) (push) Has been cancelled
Test / test (1.23, ubuntu-latest) (push) Has been cancelled
Test / test (1.23, windows-latest) (push) Has been cancelled
Test / test (1.24, macos-15-intel) (push) Has been cancelled
Test / test (1.24, macos-latest) (push) Has been cancelled
Test / test (1.24, ubuntu-24.04-arm) (push) Has been cancelled
Test / test (1.24, ubuntu-latest) (push) Has been cancelled
Test / test (1.24, windows-latest) (push) Has been cancelled
Test / test (1.25, macos-15-intel) (push) Has been cancelled
Test / test (1.25, macos-latest) (push) Has been cancelled
Test / test (1.25, ubuntu-24.04-arm) (push) Has been cancelled
Test / test (1.25, ubuntu-latest) (push) Has been cancelled
Test / test (1.25, windows-latest) (push) Has been cancelled
Test / test (1.26.0-rc.1, macos-15-intel) (push) Has been cancelled
Test / test (1.26.0-rc.1, macos-latest) (push) Has been cancelled
Test / test (1.26.0-rc.1, ubuntu-24.04-arm) (push) Has been cancelled
Test / test (1.26.0-rc.1, ubuntu-latest) (push) Has been cancelled
Test / test (1.26.0-rc.1, windows-latest) (push) Has been cancelled
Trigger CMFA Update / trigger-CMFA-update (push) Has been cancelled

This commit is contained in:
saba-futai 2026-01-07 00:25:33 +08:00 committed by GitHub
parent b18a33552c
commit 0f2baca2de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 992 additions and 349 deletions

View File

@ -19,6 +19,9 @@ type Sudoku struct {
option *SudokuOption
baseConf sudoku.ProtocolConfig
httpMaskMu sync.Mutex
httpMaskClient *sudoku.HTTPMaskTunnelClient
muxMu sync.Mutex
muxClient *sudoku.MultiplexClient
muxBackoffUntil time.Time
@ -40,7 +43,7 @@ type SudokuOption struct {
HTTPMaskMode string `proxy:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto"
HTTPMaskTLS bool `proxy:"http-mask-tls,omitempty"` // only for http-mask-mode stream/poll/auto
HTTPMaskHost string `proxy:"http-mask-host,omitempty"` // optional Host/SNI override (domain or domain:port)
HTTPMaskMultiplex string `proxy:"http-mask-multiplex,omitempty"` // "off" (default), "auto", "on"
HTTPMaskMultiplex string `proxy:"http-mask-multiplex,omitempty"` // "off" (default), "auto" (reuse h1/h2), "on" (single tunnel, multi-target)
CustomTable string `proxy:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv
CustomTables []string `proxy:"custom-tables,omitempty"` // optional table rotation patterns, overrides custom-table when non-empty
}
@ -53,18 +56,12 @@ func (s *Sudoku) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Con
}
muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex)
if !cfg.DisableHTTPMask && muxMode != "off" {
shouldTry := muxMode == "on" || (muxMode == "auto" && httpTunnelModeEnabled(cfg.HTTPMaskMode))
if shouldTry {
stream, muxErr := s.dialMultiplex(ctx, cfg.TargetAddress, muxMode)
if muxErr == nil {
return NewConn(stream, s), nil
}
if muxMode != "auto" {
return nil, muxErr
}
s.noteMuxFailure(muxMode, muxErr)
if muxMode == "on" && !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) {
stream, muxErr := s.dialMultiplex(ctx, cfg.TargetAddress, muxMode)
if muxErr == nil {
return NewConn(stream, s), nil
}
return nil, muxErr
}
c, err := s.dialAndHandshake(ctx, cfg)
@ -229,6 +226,7 @@ func NewSudoku(option SudokuOption) (*Sudoku, error) {
func (s *Sudoku) Close() error {
s.resetMuxClient()
s.resetHTTPMaskClient()
return s.Base.Close()
}
@ -261,7 +259,17 @@ func (s *Sudoku) dialAndHandshake(ctx context.Context, cfg *sudoku.ProtocolConfi
var c net.Conn
if !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) {
c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, s.dialer.DialContext)
muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex)
switch muxMode {
case "auto", "on":
client, errX := s.getOrCreateHTTPMaskClient(cfg)
if errX != nil {
return nil, errX
}
c, err = client.Dial(ctx)
default:
c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, s.dialer.DialContext)
}
}
if c == nil && err == nil {
c, err = s.dialer.DialContext(ctx, "tcp", s.addr)
@ -380,3 +388,35 @@ func (s *Sudoku) resetMuxClient() {
s.muxClient = nil
}
}
func (s *Sudoku) getOrCreateHTTPMaskClient(cfg *sudoku.ProtocolConfig) (*sudoku.HTTPMaskTunnelClient, error) {
if s == nil {
return nil, fmt.Errorf("nil adapter")
}
if cfg == nil {
return nil, fmt.Errorf("config is required")
}
s.httpMaskMu.Lock()
defer s.httpMaskMu.Unlock()
if s.httpMaskClient != nil {
return s.httpMaskClient, nil
}
c, err := sudoku.NewHTTPMaskTunnelClient(cfg.ServerAddress, cfg, s.dialer.DialContext)
if err != nil {
return nil, err
}
s.httpMaskClient = c
return c, nil
}
func (s *Sudoku) resetHTTPMaskClient() {
s.httpMaskMu.Lock()
defer s.httpMaskMu.Unlock()
if s.httpMaskClient != nil {
s.httpMaskClient.CloseIdleConnections()
s.httpMaskClient = nil
}
}

View File

@ -1068,7 +1068,7 @@ proxies: # socks5
# http-mask-mode: legacy # 可选legacy默认、stream、poll、autostream/poll/auto 支持走 CDN/反代
# http-mask-tls: true # 可选:仅在 http-mask-mode 为 stream/poll/auto 时生效true 强制 httpsfalse 强制 http不会根据端口自动推断
# http-mask-host: "" # 可选:覆盖 Host/SNI支持 example.com 或 example.com:443仅在 http-mask-mode 为 stream/poll/auto 时生效
# http-mask-multiplex: off # 可选off默认、auto、on复用单条隧道并在其内多路复用多个目标连接
# http-mask-multiplex: off # 可选off默认、auto(复用 h1.1 keep-alive / h2 连接,减少每次建链 RTT、on单条隧道内多路复用多个目标连接仅在 http-mask-mode=stream/poll/auto 生效)
enable-pure-downlink: false # 是否启用混淆下行false的情况下能在保证数据安全的前提下极大提升下行速度与服务端端保持相同(如果此处为false则要求aead不可为none)
# anytls

View File

@ -58,9 +58,10 @@ type ProtocolConfig struct {
// HTTPMaskHost optionally overrides the HTTP Host header / SNI host for HTTP tunnel modes (client-side).
HTTPMaskHost string
// HTTPMaskMultiplex controls whether the client reuses a single (HTTP-masked) tunnel connection and
// opens multiple logical target streams inside it (reduces RTT for subsequent connections).
// Values: "off" / "auto" / "on".
// HTTPMaskMultiplex controls multiplex behavior when HTTPMask tunnel modes are enabled:
// - "off": disable reuse; each Dial establishes its own HTTPMask tunnel
// - "auto": reuse underlying HTTP connections across multiple tunnel dials (HTTP/1.1 keep-alive / HTTP/2)
// - "on": enable "single tunnel, multi-target" mux (Sudoku-level multiplex; Dial behaves like "auto" otherwise)
HTTPMaskMultiplex string
}

View File

@ -83,6 +83,59 @@ func DialHTTPMaskTunnel(ctx context.Context, serverAddress string, cfg *Protocol
Mode: cfg.HTTPMaskMode,
TLSEnabled: cfg.HTTPMaskTLSEnabled,
HostOverride: cfg.HTTPMaskHost,
Multiplex: cfg.HTTPMaskMultiplex,
DialContext: dial,
})
}
type HTTPMaskTunnelClient struct {
mode string
client *httpmask.TunnelClient
}
func NewHTTPMaskTunnelClient(serverAddress string, cfg *ProtocolConfig, dial TunnelDialer) (*HTTPMaskTunnelClient, error) {
if cfg == nil {
return nil, fmt.Errorf("config is required")
}
if cfg.DisableHTTPMask {
return nil, fmt.Errorf("http mask is disabled")
}
switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) {
case "stream", "poll", "auto":
default:
return nil, fmt.Errorf("http-mask-mode=%q does not use http tunnel", cfg.HTTPMaskMode)
}
switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMultiplex)) {
case "auto", "on":
default:
return nil, fmt.Errorf("http-mask-multiplex=%q does not enable reuse", cfg.HTTPMaskMultiplex)
}
c, err := httpmask.NewTunnelClient(serverAddress, httpmask.TunnelClientOptions{
TLSEnabled: cfg.HTTPMaskTLSEnabled,
HostOverride: cfg.HTTPMaskHost,
DialContext: dial,
})
if err != nil {
return nil, err
}
return &HTTPMaskTunnelClient{
mode: cfg.HTTPMaskMode,
client: c,
}, nil
}
func (c *HTTPMaskTunnelClient) Dial(ctx context.Context) (net.Conn, error) {
if c == nil || c.client == nil {
return nil, fmt.Errorf("nil httpmask tunnel client")
}
return c.client.DialTunnel(ctx, c.mode)
}
func (c *HTTPMaskTunnelClient) CloseIdleConnections() {
if c == nil || c.client == nil {
return
}
c.client.CloseIdleConnections()
}

View File

@ -1,11 +1,11 @@
package sudoku
import (
"bytes"
"context"
"fmt"
"net"
"strings"
"time"
"github.com/metacubex/mihomo/transport/sudoku/multiplex"
)
@ -46,26 +46,19 @@ func (c *MultiplexClient) Dial(ctx context.Context, targetAddress string) (net.C
return nil, fmt.Errorf("target address cannot be empty")
}
stream, err := c.sess.OpenStream()
addrBuf, err := EncodeAddress(targetAddress)
if err != nil {
return nil, fmt.Errorf("encode target address failed: %w", err)
}
if ctx != nil && ctx.Err() != nil {
return nil, ctx.Err()
}
stream, err := c.sess.OpenStream(addrBuf)
if err != nil {
return nil, err
}
if deadline, ok := ctx.Deadline(); ok {
_ = stream.SetWriteDeadline(deadline)
defer stream.SetWriteDeadline(time.Time{})
}
addrBuf, err := EncodeAddress(targetAddress)
if err != nil {
_ = stream.Close()
return nil, fmt.Errorf("encode target address failed: %w", err)
}
if _, err := stream.Write(addrBuf); err != nil {
_ = stream.Close()
return nil, fmt.Errorf("send target address failed: %w", err)
}
return stream, nil
}
@ -114,18 +107,21 @@ func (s *MultiplexServer) AcceptStream() (net.Conn, error) {
if s == nil || s.sess == nil {
return nil, fmt.Errorf("nil session")
}
return s.sess.AcceptStream()
c, _, err := s.sess.AcceptStream()
return c, err
}
// AcceptTCP accepts a multiplex stream and reads the target address preface, returning the stream positioned at
// application data.
// AcceptTCP accepts a multiplex stream and returns the target address declared in the open frame.
func (s *MultiplexServer) AcceptTCP() (net.Conn, string, error) {
stream, err := s.AcceptStream()
if s == nil || s.sess == nil {
return nil, "", fmt.Errorf("nil session")
}
stream, payload, err := s.sess.AcceptStream()
if err != nil {
return nil, "", err
}
target, err := DecodeAddress(stream)
target, err := DecodeAddress(bytes.NewReader(payload))
if err != nil {
_ = stream.Close()
return nil, "", err
@ -147,4 +143,3 @@ func (s *MultiplexServer) IsClosed() bool {
}
return s.sess.IsClosed()
}

View File

@ -0,0 +1,39 @@
package multiplex
import (
"fmt"
"io"
)
const (
// MagicByte marks a Sudoku tunnel connection that will switch into multiplex mode.
// It is sent after the Sudoku handshake + downlink mode byte.
//
// Keep it distinct from UoTMagicByte and address type bytes.
MagicByte byte = 0xED
Version byte = 0x01
)
func WritePreface(w io.Writer) error {
if w == nil {
return fmt.Errorf("nil writer")
}
_, err := w.Write([]byte{MagicByte, Version})
return err
}
func ReadVersion(r io.Reader) (byte, error) {
var b [1]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return 0, err
}
return b[0], nil
}
func ValidateVersion(v byte) error {
if v != Version {
return fmt.Errorf("unsupported multiplex version: %d", v)
}
return nil
}

View File

@ -1,101 +1,504 @@
package multiplex
import (
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
"github.com/metacubex/smux"
)
const (
// MagicByte marks a Sudoku tunnel connection that will switch into multiplex mode.
// It is sent after the Sudoku handshake + downlink mode byte.
MagicByte byte = 0xEF
Version = 0x01
frameOpen byte = 0x01
frameData byte = 0x02
frameClose byte = 0x03
frameReset byte = 0x04
)
func WritePreface(w io.Writer) error {
_, err := w.Write([]byte{MagicByte, Version})
return err
}
const (
headerSize = 1 + 4 + 4
maxFrameSize = 256 * 1024
maxDataPayload = 32 * 1024
)
func ReadVersion(r io.Reader) (byte, error) {
var b [1]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return 0, err
}
return b[0], nil
}
func ValidateVersion(v byte) error {
if v != Version {
return fmt.Errorf("unsupported multiplex version: %d", v)
}
return nil
}
func defaultSmuxConfig() *smux.Config {
cfg := smux.DefaultConfig()
cfg.KeepAliveInterval = 15 * time.Second
cfg.KeepAliveTimeout = 45 * time.Second
return cfg
type acceptEvent struct {
stream *stream
payload []byte
}
type Session struct {
sess *smux.Session
conn net.Conn
writeMu sync.Mutex
streamsMu sync.Mutex
streams map[uint32]*stream
nextID uint32
acceptCh chan acceptEvent
closed chan struct{}
closeOnce sync.Once
closeErr error
}
func NewClientSession(conn net.Conn) (*Session, error) {
if conn == nil {
return nil, fmt.Errorf("nil conn")
}
s, err := smux.Client(conn, defaultSmuxConfig())
if err != nil {
_ = conn.Close()
return nil, err
s := &Session{
conn: conn,
streams: make(map[uint32]*stream),
closed: make(chan struct{}),
}
return &Session{sess: s}, nil
go s.readLoop()
return s, nil
}
func NewServerSession(conn net.Conn) (*Session, error) {
if conn == nil {
return nil, fmt.Errorf("nil conn")
}
s, err := smux.Server(conn, defaultSmuxConfig())
if err != nil {
_ = conn.Close()
return nil, err
s := &Session{
conn: conn,
streams: make(map[uint32]*stream),
acceptCh: make(chan acceptEvent, 256),
closed: make(chan struct{}),
}
return &Session{sess: s}, nil
}
func (s *Session) OpenStream() (net.Conn, error) {
if s == nil || s.sess == nil {
return nil, fmt.Errorf("nil session")
}
return s.sess.OpenStream()
}
func (s *Session) AcceptStream() (net.Conn, error) {
if s == nil || s.sess == nil {
return nil, fmt.Errorf("nil session")
}
return s.sess.AcceptStream()
}
func (s *Session) Close() error {
if s == nil || s.sess == nil {
return nil
}
return s.sess.Close()
go s.readLoop()
return s, nil
}
func (s *Session) IsClosed() bool {
if s == nil || s.sess == nil {
if s == nil {
return true
}
return s.sess.IsClosed()
select {
case <-s.closed:
return true
default:
return false
}
}
func (s *Session) closedErr() error {
s.streamsMu.Lock()
err := s.closeErr
s.streamsMu.Unlock()
if err == nil {
return io.ErrClosedPipe
}
return err
}
func (s *Session) closeWithError(err error) {
if err == nil {
err = io.ErrClosedPipe
}
s.closeOnce.Do(func() {
s.streamsMu.Lock()
if s.closeErr == nil {
s.closeErr = err
}
streams := make([]*stream, 0, len(s.streams))
for _, st := range s.streams {
streams = append(streams, st)
}
s.streams = make(map[uint32]*stream)
s.streamsMu.Unlock()
for _, st := range streams {
st.closeNoSend(err)
}
close(s.closed)
_ = s.conn.Close()
})
}
func (s *Session) Close() error {
if s == nil {
return nil
}
s.closeWithError(io.ErrClosedPipe)
return nil
}
func (s *Session) registerStream(st *stream) {
s.streamsMu.Lock()
s.streams[st.id] = st
s.streamsMu.Unlock()
}
func (s *Session) getStream(id uint32) *stream {
s.streamsMu.Lock()
st := s.streams[id]
s.streamsMu.Unlock()
return st
}
func (s *Session) removeStream(id uint32) {
s.streamsMu.Lock()
delete(s.streams, id)
s.streamsMu.Unlock()
}
func (s *Session) nextStreamID() uint32 {
s.streamsMu.Lock()
s.nextID++
id := s.nextID
if id == 0 {
s.nextID++
id = s.nextID
}
s.streamsMu.Unlock()
return id
}
func (s *Session) sendFrame(frameType byte, streamID uint32, payload []byte) error {
if s.IsClosed() {
return s.closedErr()
}
if len(payload) > maxFrameSize {
return fmt.Errorf("mux payload too large: %d", len(payload))
}
var header [headerSize]byte
header[0] = frameType
binary.BigEndian.PutUint32(header[1:5], streamID)
binary.BigEndian.PutUint32(header[5:9], uint32(len(payload)))
s.writeMu.Lock()
defer s.writeMu.Unlock()
if err := writeFull(s.conn, header[:]); err != nil {
s.closeWithError(err)
return err
}
if len(payload) > 0 {
if err := writeFull(s.conn, payload); err != nil {
s.closeWithError(err)
return err
}
}
return nil
}
func (s *Session) sendReset(streamID uint32, msg string) {
if msg == "" {
msg = "reset"
}
_ = s.sendFrame(frameReset, streamID, []byte(msg))
_ = s.sendFrame(frameClose, streamID, nil)
}
func (s *Session) OpenStream(openPayload []byte) (net.Conn, error) {
if s == nil {
return nil, fmt.Errorf("nil session")
}
if s.IsClosed() {
return nil, s.closedErr()
}
streamID := s.nextStreamID()
st := newStream(s, streamID)
s.registerStream(st)
if err := s.sendFrame(frameOpen, streamID, openPayload); err != nil {
st.closeNoSend(err)
s.removeStream(streamID)
return nil, fmt.Errorf("mux open failed: %w", err)
}
return st, nil
}
func (s *Session) AcceptStream() (net.Conn, []byte, error) {
if s == nil {
return nil, nil, fmt.Errorf("nil session")
}
if s.acceptCh == nil {
return nil, nil, fmt.Errorf("accept is not supported on client sessions")
}
select {
case ev := <-s.acceptCh:
return ev.stream, ev.payload, nil
case <-s.closed:
return nil, nil, s.closedErr()
}
}
func (s *Session) readLoop() {
var header [headerSize]byte
for {
if _, err := io.ReadFull(s.conn, header[:]); err != nil {
s.closeWithError(err)
return
}
frameType := header[0]
streamID := binary.BigEndian.Uint32(header[1:5])
n := int(binary.BigEndian.Uint32(header[5:9]))
if n < 0 || n > maxFrameSize {
s.closeWithError(fmt.Errorf("invalid mux frame length: %d", n))
return
}
var payload []byte
if n > 0 {
payload = make([]byte, n)
if _, err := io.ReadFull(s.conn, payload); err != nil {
s.closeWithError(err)
return
}
}
switch frameType {
case frameOpen:
if s.acceptCh == nil {
s.sendReset(streamID, "unexpected open")
continue
}
if streamID == 0 {
s.sendReset(streamID, "invalid stream id")
continue
}
if existing := s.getStream(streamID); existing != nil {
s.sendReset(streamID, "stream already exists")
continue
}
st := newStream(s, streamID)
s.registerStream(st)
go func() {
select {
case s.acceptCh <- acceptEvent{stream: st, payload: payload}:
case <-s.closed:
st.closeNoSend(io.ErrClosedPipe)
s.removeStream(streamID)
}
}()
case frameData:
st := s.getStream(streamID)
if st == nil {
continue
}
if len(payload) == 0 {
continue
}
st.enqueue(payload)
case frameClose:
st := s.getStream(streamID)
if st == nil {
continue
}
st.closeNoSend(io.EOF)
s.removeStream(streamID)
case frameReset:
st := s.getStream(streamID)
if st == nil {
continue
}
msg := trimASCII(payload)
if msg == "" {
msg = "reset"
}
st.closeNoSend(errors.New(msg))
s.removeStream(streamID)
default:
s.closeWithError(fmt.Errorf("unknown mux frame type: %d", frameType))
return
}
}
}
func writeFull(w io.Writer, b []byte) error {
for len(b) > 0 {
n, err := w.Write(b)
if err != nil {
return err
}
b = b[n:]
}
return nil
}
func trimASCII(b []byte) string {
i := 0
j := len(b)
for i < j {
c := b[i]
if c != ' ' && c != '\n' && c != '\r' && c != '\t' {
break
}
i++
}
for j > i {
c := b[j-1]
if c != ' ' && c != '\n' && c != '\r' && c != '\t' {
break
}
j--
}
if i >= j {
return ""
}
out := make([]byte, j-i)
copy(out, b[i:j])
return string(out)
}
type stream struct {
session *Session
id uint32
mu sync.Mutex
cond *sync.Cond
closed bool
closeErr error
readBuf []byte
queue [][]byte
localAddr net.Addr
remoteAddr net.Addr
}
func newStream(session *Session, id uint32) *stream {
st := &stream{
session: session,
id: id,
localAddr: &net.TCPAddr{},
remoteAddr: &net.TCPAddr{},
}
st.cond = sync.NewCond(&st.mu)
return st
}
func (c *stream) enqueue(payload []byte) {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return
}
c.queue = append(c.queue, payload)
c.cond.Signal()
c.mu.Unlock()
}
func (c *stream) closeNoSend(err error) {
if err == nil {
err = io.EOF
}
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return
}
c.closed = true
if c.closeErr == nil {
c.closeErr = err
}
c.cond.Broadcast()
c.mu.Unlock()
}
func (c *stream) closedErr() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.closeErr == nil {
return io.ErrClosedPipe
}
return c.closeErr
}
func (c *stream) Read(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
c.mu.Lock()
defer c.mu.Unlock()
for len(c.readBuf) == 0 && len(c.queue) == 0 && !c.closed {
c.cond.Wait()
}
if len(c.readBuf) == 0 && len(c.queue) > 0 {
c.readBuf = c.queue[0]
c.queue = c.queue[1:]
}
if len(c.readBuf) == 0 && c.closed {
if c.closeErr == nil {
return 0, io.ErrClosedPipe
}
return 0, c.closeErr
}
n := copy(p, c.readBuf)
c.readBuf = c.readBuf[n:]
return n, nil
}
func (c *stream) Write(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
if c.session == nil || c.session.IsClosed() {
if c.session != nil {
return 0, c.session.closedErr()
}
return 0, io.ErrClosedPipe
}
c.mu.Lock()
closed := c.closed
c.mu.Unlock()
if closed {
return 0, c.closedErr()
}
written := 0
for len(p) > 0 {
chunk := p
if len(chunk) > maxDataPayload {
chunk = p[:maxDataPayload]
}
if err := c.session.sendFrame(frameData, c.id, chunk); err != nil {
return written, err
}
written += len(chunk)
p = p[len(chunk):]
}
return written, nil
}
func (c *stream) Close() error {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return nil
}
c.closed = true
if c.closeErr == nil {
c.closeErr = io.ErrClosedPipe
}
c.cond.Broadcast()
c.mu.Unlock()
_ = c.session.sendFrame(frameClose, c.id, nil)
c.session.removeStream(c.id)
return nil
}
func (c *stream) LocalAddr() net.Addr { return c.localAddr }
func (c *stream) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *stream) SetDeadline(t time.Time) error {
_ = c.SetReadDeadline(t)
_ = c.SetWriteDeadline(t)
return nil
}
func (c *stream) SetReadDeadline(time.Time) error { return nil }
func (c *stream) SetWriteDeadline(time.Time) error { return nil }

View File

@ -62,11 +62,84 @@ type TunnelDialOptions struct {
Mode string
TLSEnabled bool // when true, use HTTPS; otherwise, use HTTP (no port-based inference)
HostOverride string // optional Host header / SNI host (without scheme); accepts "example.com" or "example.com:443"
// Multiplex controls whether the caller should reuse underlying HTTP connections (HTTP/1.1 keep-alive / HTTP/2).
// To reuse across multiple dials, create a TunnelClient per proxy and reuse it.
// Values: "off" disables reuse; "auto"/"on" enables it.
Multiplex string
// DialContext overrides how the HTTP tunnel dials raw TCP/TLS connections.
// It must not be nil; passing nil is a programming error.
DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
}
type TunnelClientOptions struct {
TLSEnabled bool
HostOverride string
DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
MaxIdleConns int
}
type TunnelClient struct {
client *http.Client
transport *http.Transport
target httpClientTarget
}
func NewTunnelClient(serverAddress string, opts TunnelClientOptions) (*TunnelClient, error) {
maxIdle := opts.MaxIdleConns
if maxIdle <= 0 {
maxIdle = 32
}
transport, target, err := buildHTTPTransport(serverAddress, opts.TLSEnabled, opts.HostOverride, opts.DialContext, maxIdle)
if err != nil {
return nil, err
}
return &TunnelClient{
client: &http.Client{Transport: transport},
transport: transport,
target: target,
}, nil
}
func (c *TunnelClient) CloseIdleConnections() {
if c == nil || c.transport == nil {
return
}
c.transport.CloseIdleConnections()
}
func (c *TunnelClient) DialTunnel(ctx context.Context, mode string) (net.Conn, error) {
if c == nil || c.client == nil {
return nil, fmt.Errorf("nil tunnel client")
}
tm := normalizeTunnelMode(mode)
if tm == TunnelModeLegacy {
return nil, fmt.Errorf("legacy mode does not use http tunnel")
}
switch tm {
case TunnelModeStream:
return dialStreamWithClient(ctx, c.client, c.target)
case TunnelModePoll:
return dialPollWithClient(ctx, c.client, c.target)
case TunnelModeAuto:
streamCtx, cancelX := context.WithTimeout(ctx, 3*time.Second)
c1, errX := dialStreamWithClient(streamCtx, c.client, c.target)
cancelX()
if errX == nil {
return c1, nil
}
c2, errP := dialPollWithClient(ctx, c.client, c.target)
if errP == nil {
return c2, nil
}
return nil, fmt.Errorf("auto tunnel failed: stream: %v; poll: %w", errX, errP)
default:
return dialStreamWithClient(ctx, c.client, c.target)
}
}
// DialTunnel establishes a bidirectional stream over HTTP:
// - stream: a single streaming POST (request body uplink, response body downlink)
// - poll: authorize + push/pull polling tunnel (base64 framed)
@ -192,43 +265,154 @@ type httpClientTarget struct {
headerHost string
}
func newHTTPClient(serverAddress string, opts TunnelDialOptions, maxIdleConns int) (*http.Client, httpClientTarget, error) {
if opts.DialContext == nil {
func buildHTTPTransport(serverAddress string, tlsEnabled bool, hostOverride string, dialContext func(ctx context.Context, network, addr string) (net.Conn, error), maxIdleConns int) (*http.Transport, httpClientTarget, error) {
if dialContext == nil {
panic("httpmask: DialContext is nil")
}
scheme, urlHost, dialAddr, serverName, err := normalizeHTTPDialTarget(serverAddress, opts.TLSEnabled, opts.HostOverride)
scheme, urlHost, dialAddr, serverName, err := normalizeHTTPDialTarget(serverAddress, tlsEnabled, hostOverride)
if err != nil {
return nil, httpClientTarget{}, err
}
transport := &http.Transport{
ForceAttemptHTTP2: scheme == "https",
DisableCompression: true,
MaxIdleConns: maxIdleConns,
IdleConnTimeout: 30 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ForceAttemptHTTP2: scheme == "https",
DisableCompression: true,
MaxIdleConns: maxIdleConns,
MaxIdleConnsPerHost: maxIdleConns,
IdleConnTimeout: 30 * time.Second,
ResponseHeaderTimeout: 20 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
DialContext: func(dialCtx context.Context, network, _ string) (net.Conn, error) {
return opts.DialContext(dialCtx, network, dialAddr)
return dialContext(dialCtx, network, dialAddr)
},
}
if scheme == "https" {
transport.TLSClientConfig, err = ca.GetTLSConfig(ca.Option{TLSConfig: &tls.Config{
var tlsConf *tls.Config
tlsConf, err = ca.GetTLSConfig(ca.Option{TLSConfig: &tls.Config{
ServerName: serverName,
MinVersion: tls.VersionTLS12,
}})
if err != nil {
return nil, httpClientTarget{}, err
}
transport.TLSClientConfig = tlsConf
}
return &http.Client{Transport: transport}, httpClientTarget{
return transport, httpClientTarget{
scheme: scheme,
urlHost: urlHost,
headerHost: canonicalHeaderHost(urlHost, scheme),
}, nil
}
func newHTTPClient(serverAddress string, opts TunnelDialOptions, maxIdleConns int) (*http.Client, httpClientTarget, error) {
transport, target, err := buildHTTPTransport(serverAddress, opts.TLSEnabled, opts.HostOverride, opts.DialContext, maxIdleConns)
if err != nil {
return nil, httpClientTarget{}, err
}
return &http.Client{Transport: transport}, target, nil
}
type sessionDialInfo struct {
client *http.Client
pushURL string
pullURL string
closeURL string
headerHost string
}
func dialSessionWithClient(ctx context.Context, client *http.Client, target httpClientTarget, mode TunnelMode) (*sessionDialInfo, error) {
if client == nil {
return nil, fmt.Errorf("nil http client")
}
authorizeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/session"}).String()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, authorizeURL, nil)
if err != nil {
return nil, err
}
req.Host = target.headerHost
applyTunnelHeaders(req.Header, target.headerHost, mode)
resp, err := client.Do(req)
if err != nil {
return nil, err
}
bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, 4*1024))
_ = resp.Body.Close()
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%s authorize bad status: %s (%s)", mode, resp.Status, strings.TrimSpace(string(bodyBytes)))
}
token, err := parseTunnelToken(bodyBytes)
if err != nil {
return nil, fmt.Errorf("%s authorize failed: %q", mode, strings.TrimSpace(string(bodyBytes)))
}
if token == "" {
return nil, fmt.Errorf("%s authorize empty token", mode)
}
pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token)}).String()
pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/stream", RawQuery: "token=" + url.QueryEscape(token)}).String()
closeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token) + "&close=1"}).String()
return &sessionDialInfo{
client: client,
pushURL: pushURL,
pullURL: pullURL,
closeURL: closeURL,
headerHost: target.headerHost,
}, nil
}
func dialSession(ctx context.Context, serverAddress string, opts TunnelDialOptions, mode TunnelMode) (*sessionDialInfo, error) {
client, target, err := newHTTPClient(serverAddress, opts, 32)
if err != nil {
return nil, err
}
return dialSessionWithClient(ctx, client, target, mode)
}
func bestEffortCloseSession(client *http.Client, closeURL, headerHost string, mode TunnelMode) {
if client == nil || closeURL == "" || headerHost == "" {
return
}
closeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(closeCtx, http.MethodPost, closeURL, nil)
if err != nil {
return
}
req.Host = headerHost
applyTunnelHeaders(req.Header, headerHost, mode)
resp, err := client.Do(req)
if err != nil || resp == nil {
return
}
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024))
_ = resp.Body.Close()
}
func dialStreamWithClient(ctx context.Context, client *http.Client, target httpClientTarget) (net.Conn, error) {
// Prefer split session (Cloudflare-friendly). Fall back to stream-one for older servers / environments.
c, errSplit := dialStreamSplitWithClient(ctx, client, target)
if errSplit == nil {
return c, nil
}
c2, errOne := dialStreamOneWithClient(ctx, client, target)
if errOne == nil {
return c2, nil
}
return nil, fmt.Errorf("dial stream failed: split: %v; stream-one: %w", errSplit, errOne)
}
func dialStream(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
// Prefer split session (Cloudflare-friendly). Fall back to stream-one for older servers / environments.
c, errSplit := dialStreamSplit(ctx, serverAddress, opts)
@ -242,10 +426,9 @@ func dialStream(ctx context.Context, serverAddress string, opts TunnelDialOption
return nil, fmt.Errorf("dial stream failed: split: %v; stream-one: %w", errSplit, errOne)
}
func dialStreamOne(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
client, target, err := newHTTPClient(serverAddress, opts, 16)
if err != nil {
return nil, err
func dialStreamOneWithClient(ctx context.Context, client *http.Client, target httpClientTarget) (net.Conn, error) {
if client == nil {
return nil, fmt.Errorf("nil http client")
}
r := rngPool.Get().(*mrand.Rand)
@ -312,16 +495,15 @@ func dialStreamOne(ctx context.Context, serverAddress string, opts TunnelDialOpt
}
}
type streamSplitConn struct {
ctx context.Context
cancel context.CancelFunc
client *http.Client
pushURL string
pullURL string
closeURL string
headerHost string
func dialStreamOne(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
client, target, err := newHTTPClient(serverAddress, opts, 32)
if err != nil {
return nil, err
}
return dialStreamOneWithClient(ctx, client, target)
}
type queuedConn struct {
rxc chan []byte
closed chan struct{}
@ -329,16 +511,46 @@ type streamSplitConn struct {
mu sync.Mutex
readBuf []byte
closeErr error
localAddr net.Addr
remoteAddr net.Addr
}
func (c *streamSplitConn) Read(b []byte) (n int, err error) {
func (c *queuedConn) closeWithError(err error) error {
c.mu.Lock()
select {
case <-c.closed:
c.mu.Unlock()
return nil
default:
if err == nil {
err = io.ErrClosedPipe
}
if c.closeErr == nil {
c.closeErr = err
}
close(c.closed)
}
c.mu.Unlock()
return nil
}
func (c *queuedConn) closedErr() error {
c.mu.Lock()
err := c.closeErr
c.mu.Unlock()
if err == nil {
return io.ErrClosedPipe
}
return err
}
func (c *queuedConn) Read(b []byte) (n int, err error) {
if len(c.readBuf) == 0 {
select {
case c.readBuf = <-c.rxc:
case <-c.closed:
return 0, io.ErrClosedPipe
return 0, c.closedErr()
}
}
n = copy(b, c.readBuf)
@ -346,7 +558,7 @@ func (c *streamSplitConn) Read(b []byte) (n int, err error) {
return n, nil
}
func (c *streamSplitConn) Write(b []byte) (n int, err error) {
func (c *queuedConn) Write(b []byte) (n int, err error) {
if len(b) == 0 {
return 0, nil
}
@ -354,7 +566,7 @@ func (c *streamSplitConn) Write(b []byte) (n int, err error) {
select {
case <-c.closed:
c.mu.Unlock()
return 0, io.ErrClosedPipe
return 0, c.closedErr()
default:
}
c.mu.Unlock()
@ -365,111 +577,97 @@ func (c *streamSplitConn) Write(b []byte) (n int, err error) {
case c.writeCh <- payload:
return len(b), nil
case <-c.closed:
return 0, io.ErrClosedPipe
return 0, c.closedErr()
}
}
func (c *queuedConn) LocalAddr() net.Addr { return c.localAddr }
func (c *queuedConn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *queuedConn) SetDeadline(time.Time) error { return nil }
func (c *queuedConn) SetReadDeadline(time.Time) error { return nil }
func (c *queuedConn) SetWriteDeadline(time.Time) error { return nil }
type streamSplitConn struct {
queuedConn
ctx context.Context
cancel context.CancelFunc
client *http.Client
pushURL string
pullURL string
closeURL string
headerHost string
}
func (c *streamSplitConn) Close() error {
c.mu.Lock()
select {
case <-c.closed:
c.mu.Unlock()
return nil
default:
close(c.closed)
}
c.mu.Unlock()
_ = c.closeWithError(io.ErrClosedPipe)
if c.cancel != nil {
c.cancel()
}
// Best-effort session close signal (avoid leaking server-side sessions).
closeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(closeCtx, http.MethodPost, c.closeURL, nil)
if err == nil {
req.Host = c.headerHost
applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream)
if resp, doErr := c.client.Do(req); doErr == nil && resp != nil {
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024))
_ = resp.Body.Close()
}
}
bestEffortCloseSession(c.client, c.closeURL, c.headerHost, TunnelModeStream)
return nil
}
func (c *streamSplitConn) LocalAddr() net.Addr { return c.localAddr }
func (c *streamSplitConn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *streamSplitConn) SetDeadline(time.Time) error { return nil }
func (c *streamSplitConn) SetReadDeadline(time.Time) error { return nil }
func (c *streamSplitConn) SetWriteDeadline(time.Time) error { return nil }
func dialStreamSplit(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
client, target, err := newHTTPClient(serverAddress, opts, 32)
if err != nil {
return nil, err
func newStreamSplitConnFromInfo(info *sessionDialInfo) *streamSplitConn {
if info == nil {
return nil
}
authorizeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/session"}).String()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, authorizeURL, nil)
if err != nil {
return nil, err
}
req.Host = target.headerHost
applyTunnelHeaders(req.Header, target.headerHost, TunnelModeStream)
resp, err := client.Do(req)
if err != nil {
return nil, err
}
bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, 4*1024))
_ = resp.Body.Close()
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("stream authorize bad status: %s (%s)", resp.Status, strings.TrimSpace(string(bodyBytes)))
}
token, err := parseTunnelToken(bodyBytes)
if err != nil {
return nil, fmt.Errorf("stream authorize failed: %q", strings.TrimSpace(string(bodyBytes)))
}
if token == "" {
return nil, fmt.Errorf("stream authorize empty token")
}
pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token)}).String()
pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/stream", RawQuery: "token=" + url.QueryEscape(token)}).String()
closeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token) + "&close=1"}).String()
connCtx, cancel := context.WithCancel(context.Background())
c := &streamSplitConn{
ctx: connCtx,
cancel: cancel,
client: client,
pushURL: pushURL,
pullURL: pullURL,
closeURL: closeURL,
headerHost: target.headerHost,
rxc: make(chan []byte, 256),
closed: make(chan struct{}),
writeCh: make(chan []byte, 256),
localAddr: &net.TCPAddr{},
remoteAddr: &net.TCPAddr{},
client: info.client,
pushURL: info.pushURL,
pullURL: info.pullURL,
closeURL: info.closeURL,
headerHost: info.headerHost,
queuedConn: queuedConn{
rxc: make(chan []byte, 256),
closed: make(chan struct{}),
writeCh: make(chan []byte, 256),
localAddr: &net.TCPAddr{},
remoteAddr: &net.TCPAddr{},
},
}
go c.pullLoop()
go c.pushLoop()
return c
}
func dialStreamSplitWithClient(ctx context.Context, client *http.Client, target httpClientTarget) (net.Conn, error) {
info, err := dialSessionWithClient(ctx, client, target, TunnelModeStream)
if err != nil {
return nil, err
}
c := newStreamSplitConnFromInfo(info)
if c == nil {
return nil, fmt.Errorf("failed to build stream split conn")
}
return c, nil
}
func dialStreamSplit(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
info, err := dialSession(ctx, serverAddress, opts, TunnelModeStream)
if err != nil {
return nil, err
}
c := newStreamSplitConnFromInfo(info)
if c == nil {
return nil, fmt.Errorf("failed to build stream split conn")
}
return c, nil
}
func (c *streamSplitConn) pullLoop() {
const (
requestTimeout = 30 * time.Second
// requestTimeout must be long enough for continuous high-throughput streams (e.g. mux + large downloads).
// If it is too short, the client cancels the response mid-body and corrupts the byte stream.
requestTimeout = 2 * time.Minute
readChunkSize = 32 * 1024
idleBackoff = 25 * time.Millisecond
maxDialRetry = 12
@ -688,22 +886,16 @@ func (c *streamSplitConn) pushLoop() {
}
type pollConn struct {
queuedConn
ctx context.Context
cancel context.CancelFunc
client *http.Client
pushURL string
pullURL string
closeURL string
headerHost string
rxc chan []byte
closed chan struct{}
writeCh chan []byte
mu sync.Mutex
readBuf []byte
closeErr error
localAddr net.Addr
remoteAddr net.Addr
}
func isDialError(err error) bool {
@ -721,147 +913,67 @@ func isDialError(err error) bool {
}
func (c *pollConn) closeWithError(err error) error {
c.mu.Lock()
select {
case <-c.closed:
c.mu.Unlock()
return nil
default:
if err == nil {
err = io.ErrClosedPipe
}
if c.closeErr == nil {
c.closeErr = err
}
close(c.closed)
}
c.mu.Unlock()
// Best-effort session close signal (avoid leaking server-side sessions).
req, reqErr := http.NewRequest(http.MethodPost, c.closeURL, nil)
if reqErr == nil {
req.Host = c.headerHost
req.Header.Set("X-Sudoku-Tunnel", string(TunnelModePoll))
req.Header.Set("X-Sudoku-Version", "1")
if resp, doErr := c.client.Do(req); doErr == nil && resp != nil {
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024))
_ = resp.Body.Close()
}
_ = c.queuedConn.closeWithError(err)
if c.cancel != nil {
c.cancel()
}
bestEffortCloseSession(c.client, c.closeURL, c.headerHost, TunnelModePoll)
return nil
}
func (c *pollConn) closedErr() error {
c.mu.Lock()
err := c.closeErr
c.mu.Unlock()
if err == nil {
return io.ErrClosedPipe
}
return err
}
func (c *pollConn) Read(b []byte) (n int, err error) {
if len(c.readBuf) == 0 {
select {
case c.readBuf = <-c.rxc:
case <-c.closed:
return 0, c.closedErr()
}
}
n = copy(b, c.readBuf)
c.readBuf = c.readBuf[n:]
return n, nil
}
func (c *pollConn) Write(b []byte) (n int, err error) {
if len(b) == 0 {
return 0, nil
}
c.mu.Lock()
select {
case <-c.closed:
c.mu.Unlock()
return 0, c.closedErr()
default:
}
c.mu.Unlock()
payload := make([]byte, len(b))
copy(payload, b)
select {
case c.writeCh <- payload:
return len(b), nil
case <-c.closed:
return 0, c.closedErr()
}
}
func (c *pollConn) Close() error {
return c.closeWithError(io.ErrClosedPipe)
}
func (c *pollConn) LocalAddr() net.Addr { return c.localAddr }
func (c *pollConn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *pollConn) SetDeadline(time.Time) error { return nil }
func (c *pollConn) SetReadDeadline(time.Time) error { return nil }
func (c *pollConn) SetWriteDeadline(time.Time) error { return nil }
func dialPoll(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
client, target, err := newHTTPClient(serverAddress, opts, 32)
if err != nil {
return nil, err
func newPollConnFromInfo(info *sessionDialInfo) *pollConn {
if info == nil {
return nil
}
authorizeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/session"}).String()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, authorizeURL, nil)
if err != nil {
return nil, err
}
req.Host = target.headerHost
applyTunnelHeaders(req.Header, target.headerHost, TunnelModePoll)
resp, err := client.Do(req)
if err != nil {
return nil, err
}
bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, 4*1024))
_ = resp.Body.Close()
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("poll authorize bad status: %s (%s)", resp.Status, strings.TrimSpace(string(bodyBytes)))
}
token, err := parseTunnelToken(bodyBytes)
if err != nil {
return nil, fmt.Errorf("poll authorize failed: %q", strings.TrimSpace(string(bodyBytes)))
}
if token == "" {
return nil, fmt.Errorf("poll authorize empty token")
}
pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token)}).String()
pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/stream", RawQuery: "token=" + url.QueryEscape(token)}).String()
closeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token) + "&close=1"}).String()
connCtx, cancel := context.WithCancel(context.Background())
c := &pollConn{
client: client,
pushURL: pushURL,
pullURL: pullURL,
closeURL: closeURL,
headerHost: target.headerHost,
rxc: make(chan []byte, 128),
closed: make(chan struct{}),
writeCh: make(chan []byte, 256),
localAddr: &net.TCPAddr{},
remoteAddr: &net.TCPAddr{},
ctx: connCtx,
cancel: cancel,
client: info.client,
pushURL: info.pushURL,
pullURL: info.pullURL,
closeURL: info.closeURL,
headerHost: info.headerHost,
queuedConn: queuedConn{
rxc: make(chan []byte, 128),
closed: make(chan struct{}),
writeCh: make(chan []byte, 256),
localAddr: &net.TCPAddr{},
remoteAddr: &net.TCPAddr{},
},
}
go c.pullLoop()
go c.pushLoop()
return c
}
func dialPollWithClient(ctx context.Context, client *http.Client, target httpClientTarget) (net.Conn, error) {
info, err := dialSessionWithClient(ctx, client, target, TunnelModePoll)
if err != nil {
return nil, err
}
c := newPollConnFromInfo(info)
if c == nil {
return nil, fmt.Errorf("failed to build poll conn")
}
return c, nil
}
func dialPoll(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
info, err := dialSession(ctx, serverAddress, opts, TunnelModePoll)
if err != nil {
return nil, err
}
c := newPollConnFromInfo(info)
if c == nil {
return nil, fmt.Errorf("failed to build poll conn")
}
return c, nil
}