diff --git a/adapter/outbound/sudoku.go b/adapter/outbound/sudoku.go index b1063f04..36cf4421 100644 --- a/adapter/outbound/sudoku.go +++ b/adapter/outbound/sudoku.go @@ -6,6 +6,8 @@ import ( "net" "strconv" "strings" + "sync" + "time" N "github.com/metacubex/mihomo/common/net" C "github.com/metacubex/mihomo/constant" @@ -16,6 +18,11 @@ type Sudoku struct { *Base option *SudokuOption baseConf sudoku.ProtocolConfig + + muxMu sync.Mutex + muxClient *sudoku.MultiplexClient + muxBackoffUntil time.Time + muxLastErr error } type SudokuOption struct { @@ -30,12 +37,12 @@ type SudokuOption struct { TableType string `proxy:"table-type,omitempty"` // "prefer_ascii" or "prefer_entropy" EnablePureDownlink *bool `proxy:"enable-pure-downlink,omitempty"` HTTPMask bool `proxy:"http-mask,omitempty"` - HTTPMaskMode string `proxy:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto" - HTTPMaskTLS bool `proxy:"http-mask-tls,omitempty"` // only for http-mask-mode stream/poll/auto - HTTPMaskHost string `proxy:"http-mask-host,omitempty"` // optional Host/SNI override (domain or domain:port) - HTTPMaskStrategy string `proxy:"http-mask-strategy,omitempty"` // "random" (default), "post", "websocket" - CustomTable string `proxy:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv - CustomTables []string `proxy:"custom-tables,omitempty"` // optional table rotation patterns, overrides custom-table when non-empty + HTTPMaskMode string `proxy:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto" + HTTPMaskTLS bool `proxy:"http-mask-tls,omitempty"` // only for http-mask-mode stream/poll/auto + HTTPMaskHost string `proxy:"http-mask-host,omitempty"` // optional Host/SNI override (domain or domain:port) + HTTPMaskMultiplex string `proxy:"http-mask-multiplex,omitempty"` // "off" (default), "auto", "on" + CustomTable string `proxy:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv + CustomTables []string `proxy:"custom-tables,omitempty"` // optional table rotation patterns, overrides custom-table when non-empty } // DialContext implements C.ProxyAdapter @@ -45,40 +52,26 @@ func (s *Sudoku) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Con return nil, err } - var c net.Conn - if !cfg.DisableHTTPMask { - switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) { - case "stream", "poll", "auto": - c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, s.dialer.DialContext) + muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex) + if !cfg.DisableHTTPMask && muxMode != "off" { + shouldTry := muxMode == "on" || (muxMode == "auto" && httpTunnelModeEnabled(cfg.HTTPMaskMode)) + if shouldTry { + stream, muxErr := s.dialMultiplex(ctx, cfg.TargetAddress, muxMode) + if muxErr == nil { + return NewConn(stream, s), nil + } + if muxMode != "auto" { + return nil, muxErr + } + s.noteMuxFailure(muxMode, muxErr) } } - if c == nil && err == nil { - c, err = s.dialer.DialContext(ctx, "tcp", s.addr) - } - if err != nil { - return nil, fmt.Errorf("%s connect error: %w", s.addr, err) - } - defer func() { - safeConnClose(c, err) - }() - - if ctx.Done() != nil { - done := N.SetupContextForConn(ctx, c) - defer done(&err) - } - - handshakeCfg := *cfg - if !handshakeCfg.DisableHTTPMask { - switch strings.ToLower(strings.TrimSpace(handshakeCfg.HTTPMaskMode)) { - case "stream", "poll", "auto": - handshakeCfg.DisableHTTPMask = true - } - } - c, err = sudoku.ClientHandshakeWithOptions(c, &handshakeCfg, sudoku.ClientHandshakeOptions{HTTPMaskStrategy: s.option.HTTPMaskStrategy}) + c, err := s.dialAndHandshake(ctx, cfg) if err != nil { return nil, err } + defer func() { safeConnClose(c, err) }() addrBuf, err := sudoku.EncodeAddress(cfg.TargetAddress) if err != nil { @@ -86,7 +79,6 @@ func (s *Sudoku) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Con } if _, err = c.Write(addrBuf); err != nil { - _ = c.Close() return nil, fmt.Errorf("send target address failed: %w", err) } @@ -104,37 +96,7 @@ func (s *Sudoku) ListenPacketContext(ctx context.Context, metadata *C.Metadata) return nil, err } - var c net.Conn - if !cfg.DisableHTTPMask { - switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) { - case "stream", "poll", "auto": - c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, s.dialer.DialContext) - } - } - if c == nil && err == nil { - c, err = s.dialer.DialContext(ctx, "tcp", s.addr) - } - if err != nil { - return nil, fmt.Errorf("%s connect error: %w", s.addr, err) - } - - defer func() { - safeConnClose(c, err) - }() - - if ctx.Done() != nil { - done := N.SetupContextForConn(ctx, c) - defer done(&err) - } - - handshakeCfg := *cfg - if !handshakeCfg.DisableHTTPMask { - switch strings.ToLower(strings.TrimSpace(handshakeCfg.HTTPMaskMode)) { - case "stream", "poll", "auto": - handshakeCfg.DisableHTTPMask = true - } - } - c, err = sudoku.ClientHandshakeWithOptions(c, &handshakeCfg, sudoku.ClientHandshakeOptions{HTTPMaskStrategy: s.option.HTTPMaskStrategy}) + c, err := s.dialAndHandshake(ctx, cfg) if err != nil { return nil, err } @@ -224,10 +186,14 @@ func NewSudoku(option SudokuOption) (*Sudoku, error) { HTTPMaskMode: defaultConf.HTTPMaskMode, HTTPMaskTLSEnabled: option.HTTPMaskTLS, HTTPMaskHost: option.HTTPMaskHost, + HTTPMaskMultiplex: defaultConf.HTTPMaskMultiplex, } if option.HTTPMaskMode != "" { baseConf.HTTPMaskMode = option.HTTPMaskMode } + if option.HTTPMaskMultiplex != "" { + baseConf.HTTPMaskMultiplex = option.HTTPMaskMultiplex + } tables, err := sudoku.NewTablesWithCustomPatterns(sudoku.ClientAEADSeed(option.Key), tableType, option.CustomTable, option.CustomTables) if err != nil { return nil, fmt.Errorf("build table(s) failed: %w", err) @@ -260,3 +226,157 @@ func NewSudoku(option SudokuOption) (*Sudoku, error) { outbound.dialer = option.NewDialer(outbound.DialOptions()) return outbound, nil } + +func (s *Sudoku) Close() error { + s.resetMuxClient() + return s.Base.Close() +} + +func normalizeHTTPMaskMultiplex(mode string) string { + switch strings.ToLower(strings.TrimSpace(mode)) { + case "", "off": + return "off" + case "auto": + return "auto" + case "on": + return "on" + default: + return "off" + } +} + +func httpTunnelModeEnabled(mode string) bool { + switch strings.ToLower(strings.TrimSpace(mode)) { + case "stream", "poll", "auto": + return true + default: + return false + } +} + +func (s *Sudoku) dialAndHandshake(ctx context.Context, cfg *sudoku.ProtocolConfig) (_ net.Conn, err error) { + if cfg == nil { + return nil, fmt.Errorf("config is required") + } + + var c net.Conn + if !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) { + c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, s.dialer.DialContext) + } + if c == nil && err == nil { + c, err = s.dialer.DialContext(ctx, "tcp", s.addr) + } + if err != nil { + return nil, fmt.Errorf("%s connect error: %w", s.addr, err) + } + + defer func() { safeConnClose(c, err) }() + + if ctx.Done() != nil { + done := N.SetupContextForConn(ctx, c) + defer done(&err) + } + + handshakeCfg := *cfg + if !handshakeCfg.DisableHTTPMask && httpTunnelModeEnabled(handshakeCfg.HTTPMaskMode) { + handshakeCfg.DisableHTTPMask = true + } + + c, err = sudoku.ClientHandshakeWithOptions(c, &handshakeCfg, sudoku.ClientHandshakeOptions{}) + if err != nil { + return nil, err + } + + return c, nil +} + +func (s *Sudoku) dialMultiplex(ctx context.Context, targetAddress string, mode string) (net.Conn, error) { + for attempt := 0; attempt < 2; attempt++ { + client, err := s.getOrCreateMuxClient(ctx, mode) + if err != nil { + return nil, err + } + + stream, err := client.Dial(ctx, targetAddress) + if err != nil { + s.resetMuxClient() + continue + } + + return stream, nil + } + return nil, fmt.Errorf("multiplex open stream failed") +} + +func (s *Sudoku) getOrCreateMuxClient(ctx context.Context, mode string) (*sudoku.MultiplexClient, error) { + if s == nil { + return nil, fmt.Errorf("nil adapter") + } + + if mode == "auto" { + s.muxMu.Lock() + backoffUntil := s.muxBackoffUntil + lastErr := s.muxLastErr + s.muxMu.Unlock() + if time.Now().Before(backoffUntil) { + return nil, fmt.Errorf("multiplex temporarily disabled: %v", lastErr) + } + } + + s.muxMu.Lock() + if s.muxClient != nil && !s.muxClient.IsClosed() { + client := s.muxClient + s.muxMu.Unlock() + return client, nil + } + s.muxMu.Unlock() + + s.muxMu.Lock() + defer s.muxMu.Unlock() + + if s.muxClient != nil && !s.muxClient.IsClosed() { + return s.muxClient, nil + } + + baseCfg := s.baseConf + baseConn, err := s.dialAndHandshake(ctx, &baseCfg) + if err != nil { + if mode == "auto" { + s.muxLastErr = err + s.muxBackoffUntil = time.Now().Add(45 * time.Second) + } + return nil, err + } + + client, err := sudoku.StartMultiplexClient(baseConn) + if err != nil { + _ = baseConn.Close() + if mode == "auto" { + s.muxLastErr = err + s.muxBackoffUntil = time.Now().Add(45 * time.Second) + } + return nil, err + } + + s.muxClient = client + return client, nil +} + +func (s *Sudoku) noteMuxFailure(mode string, err error) { + if mode != "auto" { + return + } + s.muxMu.Lock() + s.muxLastErr = err + s.muxBackoffUntil = time.Now().Add(45 * time.Second) + s.muxMu.Unlock() +} + +func (s *Sudoku) resetMuxClient() { + s.muxMu.Lock() + defer s.muxMu.Unlock() + if s.muxClient != nil { + _ = s.muxClient.Close() + s.muxClient = nil + } +} diff --git a/android_tz.go b/android_tz.go index 82fc38e3..37677730 100644 --- a/android_tz.go +++ b/android_tz.go @@ -1,6 +1,8 @@ // Copyright 2014 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build android && cgo +// +build android,cgo // kanged from https://github.com/golang/mobile/blob/c713f31d574bb632a93f169b2cc99c9e753fef0e/app/android.go#L89 diff --git a/docs/config.yaml b/docs/config.yaml index 80e1c8af..b565fb03 100644 --- a/docs/config.yaml +++ b/docs/config.yaml @@ -1066,7 +1066,7 @@ proxies: # socks5 # http-mask-mode: legacy # 可选:legacy(默认)、stream、poll、auto;stream/poll/auto 支持走 CDN/反代 # http-mask-tls: true # 可选:仅在 http-mask-mode 为 stream/poll/auto 时生效;true 强制 https;false 强制 http(不会根据端口自动推断) # http-mask-host: "" # 可选:覆盖 Host/SNI(支持 example.com 或 example.com:443);仅在 http-mask-mode 为 stream/poll/auto 时生效 - # http-mask-strategy: random # 可选:random(默认)、post、websocket;仅 legacy 下生效 + # http-mask-multiplex: off # 可选:off(默认)、auto、on;复用单条隧道并在其内多路复用多个目标连接 enable-pure-downlink: false # 是否启用混淆下行,false的情况下能在保证数据安全的前提下极大提升下行速度,与服务端端保持相同(如果此处为false,则要求aead不可为none) # anytls diff --git a/listener/inbound/sudoku_test.go b/listener/inbound/sudoku_test.go index 5596bf91..1d6c4e59 100644 --- a/listener/inbound/sudoku_test.go +++ b/listener/inbound/sudoku_test.go @@ -2,7 +2,6 @@ package inbound_test import ( "net/netip" - "runtime" "testing" "github.com/metacubex/mihomo/adapter/outbound" @@ -167,10 +166,6 @@ func TestInboundSudoku_CustomTable(t *testing.T) { } func TestInboundSudoku_HTTPMaskMode(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("temporarily skipped on windows due to intermittent failures; tracked in PR") - } - key := "test_key_http_mask_mode" for _, mode := range []string{"legacy", "stream", "poll", "auto"} { diff --git a/listener/sudoku/server.go b/listener/sudoku/server.go index 7652783f..d0d3b404 100644 --- a/listener/sudoku/server.go +++ b/listener/sudoku/server.go @@ -78,6 +78,26 @@ func (l *Listener) handleConn(conn net.Conn, tunnel C.Tunnel, additions ...inbou switch session.Type { case sudoku.SessionTypeUoT: l.handleUoTSession(session.Conn, tunnel, additions...) + case sudoku.SessionTypeMultiplex: + mux, err := sudoku.AcceptMultiplexServer(session.Conn) + if err != nil { + _ = session.Conn.Close() + return + } + defer mux.Close() + + for { + stream, target, err := mux.AcceptTCP() + if err != nil { + return + } + targetAddr := socks5.ParseAddr(target) + if targetAddr == nil { + _ = stream.Close() + continue + } + go l.handler.HandleSocket(targetAddr, stream, additions...) + } default: targetAddr := socks5.ParseAddr(session.Target) if targetAddr == nil { diff --git a/transport/sudoku/config.go b/transport/sudoku/config.go index 4fee6b67..8eee292d 100644 --- a/transport/sudoku/config.go +++ b/transport/sudoku/config.go @@ -57,6 +57,11 @@ type ProtocolConfig struct { // HTTPMaskHost optionally overrides the HTTP Host header / SNI host for HTTP tunnel modes (client-side). HTTPMaskHost string + + // HTTPMaskMultiplex controls whether the client reuses a single (HTTP-masked) tunnel connection and + // opens multiple logical target streams inside it (reduces RTT for subsequent connections). + // Values: "off" / "auto" / "on". + HTTPMaskMultiplex string } func (c *ProtocolConfig) Validate() error { @@ -103,6 +108,12 @@ func (c *ProtocolConfig) Validate() error { return fmt.Errorf("invalid http-mask-mode: %s, must be one of: legacy, stream, poll, auto", c.HTTPMaskMode) } + switch strings.ToLower(strings.TrimSpace(c.HTTPMaskMultiplex)) { + case "", "off", "auto", "on": + default: + return fmt.Errorf("invalid http-mask-multiplex: %s, must be one of: off, auto, on", c.HTTPMaskMultiplex) + } + return nil } @@ -127,6 +138,7 @@ func DefaultConfig() *ProtocolConfig { EnablePureDownlink: true, HandshakeTimeoutSeconds: 5, HTTPMaskMode: "legacy", + HTTPMaskMultiplex: "off", } } diff --git a/transport/sudoku/handshake.go b/transport/sudoku/handshake.go index 2a0437d6..6963add5 100644 --- a/transport/sudoku/handshake.go +++ b/transport/sudoku/handshake.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "crypto/sha256" "encoding/binary" + "encoding/hex" "fmt" "io" "net" @@ -23,12 +24,17 @@ type SessionType int const ( SessionTypeTCP SessionType = iota SessionTypeUoT + SessionTypeMultiplex ) type ServerSession struct { Conn net.Conn Type SessionType Target string + + // UserHash is a stable per-key identifier derived from the handshake payload. + // It is primarily useful for debugging / user attribution when table rotation is enabled. + UserHash string } type bufferedConn struct { @@ -147,7 +153,14 @@ func buildServerObfsConn(raw net.Conn, cfg *ProtocolConfig, table *sudoku.Table, func buildHandshakePayload(key string) [16]byte { var payload [16]byte binary.BigEndian.PutUint64(payload[:8], uint64(time.Now().Unix())) - hash := sha256.Sum256([]byte(key)) + // Hash the decoded HEX bytes of the key, not the HEX string itself. + // This ensures the user hash is computed on the actual key bytes. + keyBytes, err := hex.DecodeString(key) + if err != nil { + // Fallback: if key is not valid HEX (e.g., a UUID or plain string), hash the string bytes + keyBytes = []byte(key) + } + hash := sha256.Sum256(keyBytes) copy(payload[8:], hash[:8]) return payload } @@ -216,7 +229,7 @@ func ClientHandshakeWithOptions(rawConn net.Conn, cfg *ProtocolConfig, opt Clien handshake := buildHandshakePayload(cfg.Key) if len(cfg.tableCandidates()) > 1 { - handshake[15] = tableID + handshake[8] = tableID } if _, err := cConn.Write(handshake[:]); err != nil { cConn.Close() @@ -280,6 +293,7 @@ func ServerHandshake(rawConn net.Conn, cfg *ProtocolConfig) (*ServerSession, err return nil, fmt.Errorf("timestamp skew detected") } + userHash := userHashFromHandshake(handshakeBuf[:]) sConn.StopRecording() modeBuf := []byte{0} @@ -298,6 +312,11 @@ func ServerHandshake(rawConn net.Conn, cfg *ProtocolConfig) (*ServerSession, err return nil, fmt.Errorf("read first byte failed: %w", err) } + if firstByte[0] == MultiplexMagicByte { + rawConn.SetReadDeadline(time.Time{}) + return &ServerSession{Conn: cConn, Type: SessionTypeMultiplex, UserHash: userHash}, nil + } + if firstByte[0] == UoTMagicByte { version := make([]byte, 1) if _, err := io.ReadFull(cConn, version); err != nil { @@ -309,7 +328,7 @@ func ServerHandshake(rawConn net.Conn, cfg *ProtocolConfig) (*ServerSession, err return nil, fmt.Errorf("unsupported uot version: %d", version[0]) } rawConn.SetReadDeadline(time.Time{}) - return &ServerSession{Conn: cConn, Type: SessionTypeUoT}, nil + return &ServerSession{Conn: cConn, Type: SessionTypeUoT, UserHash: userHash}, nil } prefixed := &preBufferedConn{Conn: cConn, buf: firstByte} @@ -322,9 +341,10 @@ func ServerHandshake(rawConn net.Conn, cfg *ProtocolConfig) (*ServerSession, err rawConn.SetReadDeadline(time.Time{}) log.Debugln("[Sudoku] incoming TCP session target: %s", target) return &ServerSession{ - Conn: prefixed, - Type: SessionTypeTCP, - Target: target, + Conn: prefixed, + Type: SessionTypeTCP, + Target: target, + UserHash: userHash, }, nil } @@ -364,3 +384,11 @@ func randomByte() byte { } return byte(time.Now().UnixNano()) } + +func userHashFromHandshake(handshakeBuf []byte) string { + if len(handshakeBuf) < 16 { + return "" + } + // handshake[8] may be a table ID when table rotation is enabled; use [9:16] as stable user hash bytes. + return hex.EncodeToString(handshakeBuf[9:16]) +} diff --git a/transport/sudoku/multiplex.go b/transport/sudoku/multiplex.go new file mode 100644 index 00000000..2834ef80 --- /dev/null +++ b/transport/sudoku/multiplex.go @@ -0,0 +1,150 @@ +package sudoku + +import ( + "context" + "fmt" + "net" + "strings" + "time" + + "github.com/metacubex/mihomo/transport/sudoku/multiplex" +) + +const ( + MultiplexMagicByte byte = multiplex.MagicByte + MultiplexVersion byte = multiplex.Version +) + +// StartMultiplexClient writes the multiplex preface and upgrades an already-handshaked Sudoku tunnel into a multiplex session. +func StartMultiplexClient(conn net.Conn) (*MultiplexClient, error) { + if conn == nil { + return nil, fmt.Errorf("nil conn") + } + + if err := multiplex.WritePreface(conn); err != nil { + return nil, fmt.Errorf("write multiplex preface failed: %w", err) + } + + sess, err := multiplex.NewClientSession(conn) + if err != nil { + return nil, fmt.Errorf("start multiplex session failed: %w", err) + } + + return &MultiplexClient{sess: sess}, nil +} + +type MultiplexClient struct { + sess *multiplex.Session +} + +// Dial opens a new logical stream, writes the target address, and returns the stream as net.Conn. +func (c *MultiplexClient) Dial(ctx context.Context, targetAddress string) (net.Conn, error) { + if c == nil || c.sess == nil || c.sess.IsClosed() { + return nil, fmt.Errorf("multiplex session is closed") + } + if strings.TrimSpace(targetAddress) == "" { + return nil, fmt.Errorf("target address cannot be empty") + } + + stream, err := c.sess.OpenStream() + if err != nil { + return nil, err + } + + if deadline, ok := ctx.Deadline(); ok { + _ = stream.SetWriteDeadline(deadline) + defer stream.SetWriteDeadline(time.Time{}) + } + + addrBuf, err := EncodeAddress(targetAddress) + if err != nil { + _ = stream.Close() + return nil, fmt.Errorf("encode target address failed: %w", err) + } + if _, err := stream.Write(addrBuf); err != nil { + _ = stream.Close() + return nil, fmt.Errorf("send target address failed: %w", err) + } + + return stream, nil +} + +func (c *MultiplexClient) Close() error { + if c == nil || c.sess == nil { + return nil + } + return c.sess.Close() +} + +func (c *MultiplexClient) IsClosed() bool { + if c == nil || c.sess == nil { + return true + } + return c.sess.IsClosed() +} + +// AcceptMultiplexServer upgrades a server-side, already-handshaked Sudoku connection into a multiplex session. +// +// The caller must have already consumed the multiplex magic byte (MultiplexMagicByte). This function consumes the +// multiplex version byte and starts the session. +func AcceptMultiplexServer(conn net.Conn) (*MultiplexServer, error) { + if conn == nil { + return nil, fmt.Errorf("nil conn") + } + v, err := multiplex.ReadVersion(conn) + if err != nil { + return nil, err + } + if err := multiplex.ValidateVersion(v); err != nil { + return nil, err + } + sess, err := multiplex.NewServerSession(conn) + if err != nil { + return nil, err + } + return &MultiplexServer{sess: sess}, nil +} + +// MultiplexServer wraps a multiplex session created from a handshaked Sudoku tunnel connection. +type MultiplexServer struct { + sess *multiplex.Session +} + +func (s *MultiplexServer) AcceptStream() (net.Conn, error) { + if s == nil || s.sess == nil { + return nil, fmt.Errorf("nil session") + } + return s.sess.AcceptStream() +} + +// AcceptTCP accepts a multiplex stream and reads the target address preface, returning the stream positioned at +// application data. +func (s *MultiplexServer) AcceptTCP() (net.Conn, string, error) { + stream, err := s.AcceptStream() + if err != nil { + return nil, "", err + } + + target, err := DecodeAddress(stream) + if err != nil { + _ = stream.Close() + return nil, "", err + } + + return stream, target, nil +} + +func (s *MultiplexServer) Close() error { + if s == nil || s.sess == nil { + return nil + } + return s.sess.Close() +} + +func (s *MultiplexServer) IsClosed() bool { + if s == nil || s.sess == nil { + return true + } + return s.sess.IsClosed() +} + diff --git a/transport/sudoku/multiplex/session.go b/transport/sudoku/multiplex/session.go new file mode 100644 index 00000000..9e24c6b7 --- /dev/null +++ b/transport/sudoku/multiplex/session.go @@ -0,0 +1,101 @@ +package multiplex + +import ( + "fmt" + "io" + "net" + "time" + + "github.com/metacubex/smux" +) + +const ( + // MagicByte marks a Sudoku tunnel connection that will switch into multiplex mode. + // It is sent after the Sudoku handshake + downlink mode byte. + MagicByte byte = 0xEF + Version = 0x01 +) + +func WritePreface(w io.Writer) error { + _, err := w.Write([]byte{MagicByte, Version}) + return err +} + +func ReadVersion(r io.Reader) (byte, error) { + var b [1]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return 0, err + } + return b[0], nil +} + +func ValidateVersion(v byte) error { + if v != Version { + return fmt.Errorf("unsupported multiplex version: %d", v) + } + return nil +} + +func defaultSmuxConfig() *smux.Config { + cfg := smux.DefaultConfig() + cfg.KeepAliveInterval = 15 * time.Second + cfg.KeepAliveTimeout = 45 * time.Second + return cfg +} + +type Session struct { + sess *smux.Session +} + +func NewClientSession(conn net.Conn) (*Session, error) { + if conn == nil { + return nil, fmt.Errorf("nil conn") + } + s, err := smux.Client(conn, defaultSmuxConfig()) + if err != nil { + _ = conn.Close() + return nil, err + } + return &Session{sess: s}, nil +} + +func NewServerSession(conn net.Conn) (*Session, error) { + if conn == nil { + return nil, fmt.Errorf("nil conn") + } + s, err := smux.Server(conn, defaultSmuxConfig()) + if err != nil { + _ = conn.Close() + return nil, err + } + return &Session{sess: s}, nil +} + +func (s *Session) OpenStream() (net.Conn, error) { + if s == nil || s.sess == nil { + return nil, fmt.Errorf("nil session") + } + return s.sess.OpenStream() +} + +func (s *Session) AcceptStream() (net.Conn, error) { + if s == nil || s.sess == nil { + return nil, fmt.Errorf("nil session") + } + return s.sess.AcceptStream() +} + +func (s *Session) Close() error { + if s == nil || s.sess == nil { + return nil + } + return s.sess.Close() +} + +func (s *Session) IsClosed() bool { + if s == nil || s.sess == nil { + return true + } + return s.sess.IsClosed() +} + diff --git a/transport/sudoku/multiplex_test.go b/transport/sudoku/multiplex_test.go new file mode 100644 index 00000000..694b6daa --- /dev/null +++ b/transport/sudoku/multiplex_test.go @@ -0,0 +1,261 @@ +package sudoku + +import ( + "bytes" + "context" + "io" + "net" + "sync/atomic" + "testing" + "time" + + sudokuobfs "github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku" +) + +func TestUserHash_StableAcrossTableRotation(t *testing.T) { + tables := []*sudokuobfs.Table{ + sudokuobfs.NewTable("seed-a", "prefer_ascii"), + sudokuobfs.NewTable("seed-b", "prefer_ascii"), + } + key := "userhash-stability-key" + target := "example.com:80" + + serverCfg := DefaultConfig() + serverCfg.Key = key + serverCfg.AEADMethod = "chacha20-poly1305" + serverCfg.Tables = tables + serverCfg.PaddingMin = 0 + serverCfg.PaddingMax = 0 + serverCfg.EnablePureDownlink = true + serverCfg.HandshakeTimeoutSeconds = 5 + serverCfg.DisableHTTPMask = true + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + t.Cleanup(func() { _ = ln.Close() }) + + const attempts = 32 + hashCh := make(chan string, attempts) + errCh := make(chan error, attempts) + + go func() { + for { + c, err := ln.Accept() + if err != nil { + return + } + go func(conn net.Conn) { + defer conn.Close() + session, err := ServerHandshake(conn, serverCfg) + if err != nil { + errCh <- err + return + } + defer session.Conn.Close() + hashCh <- session.UserHash + }(c) + } + }() + + clientCfg := DefaultConfig() + *clientCfg = *serverCfg + clientCfg.ServerAddress = ln.Addr().String() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + for i := 0; i < attempts; i++ { + raw, err := (&net.Dialer{}).DialContext(ctx, "tcp", clientCfg.ServerAddress) + if err != nil { + t.Fatalf("dial %d: %v", i, err) + } + cConn, err := ClientHandshake(raw, clientCfg) + if err != nil { + _ = raw.Close() + t.Fatalf("handshake %d: %v", i, err) + } + + addrBuf, err := EncodeAddress(target) + if err != nil { + _ = cConn.Close() + t.Fatalf("encode addr %d: %v", i, err) + } + if _, err := cConn.Write(addrBuf); err != nil { + _ = cConn.Close() + t.Fatalf("write addr %d: %v", i, err) + } + _ = cConn.Close() + } + + unique := map[string]struct{}{} + deadline := time.After(10 * time.Second) + for i := 0; i < attempts; i++ { + select { + case err := <-errCh: + t.Fatalf("server handshake error: %v", err) + case h := <-hashCh: + if h == "" { + t.Fatalf("empty user hash") + } + if len(h) != 14 { + t.Fatalf("unexpected user hash length: %d", len(h)) + } + unique[h] = struct{}{} + case <-deadline: + t.Fatalf("timeout waiting for server handshakes") + } + } + if len(unique) != 1 { + t.Fatalf("user hash should be stable across table rotation; got %d distinct values", len(unique)) + } +} + +func TestMultiplex_TCP_Echo(t *testing.T) { + table := sudokuobfs.NewTable("seed", "prefer_ascii") + key := "test-key-mux" + target := "example.com:80" + + serverCfg := DefaultConfig() + serverCfg.Key = key + serverCfg.AEADMethod = "chacha20-poly1305" + serverCfg.Table = table + serverCfg.PaddingMin = 0 + serverCfg.PaddingMax = 0 + serverCfg.EnablePureDownlink = true + serverCfg.HandshakeTimeoutSeconds = 5 + serverCfg.DisableHTTPMask = true + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + t.Cleanup(func() { _ = ln.Close() }) + + var handshakes int64 + var streams int64 + done := make(chan struct{}) + + go func() { + defer close(done) + raw, err := ln.Accept() + if err != nil { + return + } + defer raw.Close() + + session, err := ServerHandshake(raw, serverCfg) + if err != nil { + return + } + atomic.AddInt64(&handshakes, 1) + + if session.Type != SessionTypeMultiplex { + _ = session.Conn.Close() + return + } + + mux, err := AcceptMultiplexServer(session.Conn) + if err != nil { + return + } + defer mux.Close() + + for { + stream, dst, err := mux.AcceptTCP() + if err != nil { + return + } + if dst != target { + _ = stream.Close() + return + } + atomic.AddInt64(&streams, 1) + go func(c net.Conn) { + defer c.Close() + _, _ = io.Copy(c, c) + }(stream) + } + }() + + clientCfg := DefaultConfig() + *clientCfg = *serverCfg + clientCfg.ServerAddress = ln.Addr().String() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + raw, err := (&net.Dialer{}).DialContext(ctx, "tcp", clientCfg.ServerAddress) + if err != nil { + t.Fatalf("dial: %v", err) + } + t.Cleanup(func() { _ = raw.Close() }) + + cConn, err := ClientHandshake(raw, clientCfg) + if err != nil { + t.Fatalf("client handshake: %v", err) + } + + mux, err := StartMultiplexClient(cConn) + if err != nil { + _ = cConn.Close() + t.Fatalf("start mux: %v", err) + } + defer mux.Close() + + for i := 0; i < 6; i++ { + s, err := mux.Dial(ctx, target) + if err != nil { + t.Fatalf("dial stream %d: %v", i, err) + } + + msg := []byte("hello-mux") + if _, err := s.Write(msg); err != nil { + _ = s.Close() + t.Fatalf("write: %v", err) + } + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(s, buf); err != nil { + _ = s.Close() + t.Fatalf("read: %v", err) + } + _ = s.Close() + if !bytes.Equal(buf, msg) { + t.Fatalf("echo mismatch: got %q", buf) + } + } + + _ = mux.Close() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatalf("server did not exit") + } + + if got := atomic.LoadInt64(&handshakes); got != 1 { + t.Fatalf("unexpected handshake count: %d", got) + } + if got := atomic.LoadInt64(&streams); got < 6 { + t.Fatalf("unexpected stream count: %d", got) + } +} + +func TestMultiplex_Boundary_InvalidVersion(t *testing.T) { + client, server := net.Pipe() + t.Cleanup(func() { _ = client.Close() }) + t.Cleanup(func() { _ = server.Close() }) + + errCh := make(chan error, 1) + go func() { + _, err := AcceptMultiplexServer(server) + errCh <- err + }() + + // AcceptMultiplexServer expects the magic byte to have been consumed already; write a bad version byte. + _, _ = client.Write([]byte{0xFF}) + if err := <-errCh; err == nil { + t.Fatalf("expected error") + } +} + diff --git a/transport/sudoku/obfs/httpmask/tunnel.go b/transport/sudoku/obfs/httpmask/tunnel.go index b4c880bb..e281f479 100644 --- a/transport/sudoku/obfs/httpmask/tunnel.go +++ b/transport/sudoku/obfs/httpmask/tunnel.go @@ -203,7 +203,7 @@ func newHTTPClient(serverAddress string, opts TunnelDialOptions, maxIdleConns in } transport := &http.Transport{ - ForceAttemptHTTP2: true, + ForceAttemptHTTP2: scheme == "https", DisableCompression: true, MaxIdleConns: maxIdleConns, IdleConnTimeout: 30 * time.Second, @@ -261,10 +261,10 @@ func dialStreamOne(ctx context.Context, serverAddress string, opts TunnelDialOpt reqBodyR, reqBodyW := io.Pipe() - ctx, cancel := context.WithCancel(ctx) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), reqBodyR) + connCtx, connCancel := context.WithCancel(context.Background()) + req, err := http.NewRequestWithContext(connCtx, http.MethodPost, u.String(), reqBodyR) if err != nil { - cancel() + connCancel() _ = reqBodyW.Close() return nil, err } @@ -273,27 +273,43 @@ func dialStreamOne(ctx context.Context, serverAddress string, opts TunnelDialOpt applyTunnelHeaders(req.Header, target.headerHost, TunnelModeStream) req.Header.Set("Content-Type", ctype) - resp, err := client.Do(req) - if err != nil { - cancel() - _ = reqBodyW.Close() - return nil, err - } - if resp.StatusCode != http.StatusOK { - defer resp.Body.Close() - body, _ := io.ReadAll(io.LimitReader(resp.Body, 4*1024)) - cancel() - _ = reqBodyW.Close() - return nil, fmt.Errorf("stream bad status: %s (%s)", resp.Status, strings.TrimSpace(string(body))) + type doResult struct { + resp *http.Response + err error } + doCh := make(chan doResult, 1) + go func() { + resp, doErr := client.Do(req) + doCh <- doResult{resp: resp, err: doErr} + }() - return &httpStreamConn{ - reader: resp.Body, - writer: reqBodyW, - cancel: cancel, - localAddr: &net.TCPAddr{}, - remoteAddr: &net.TCPAddr{}, - }, nil + select { + case <-ctx.Done(): + connCancel() + _ = reqBodyW.Close() + return nil, ctx.Err() + case r := <-doCh: + if r.err != nil { + connCancel() + _ = reqBodyW.Close() + return nil, r.err + } + if r.resp.StatusCode != http.StatusOK { + defer r.resp.Body.Close() + body, _ := io.ReadAll(io.LimitReader(r.resp.Body, 4*1024)) + connCancel() + _ = reqBodyW.Close() + return nil, fmt.Errorf("stream bad status: %s (%s)", r.resp.Status, strings.TrimSpace(string(body))) + } + + return &httpStreamConn{ + reader: r.resp.Body, + writer: reqBodyW, + cancel: connCancel, + localAddr: &net.TCPAddr{}, + remoteAddr: &net.TCPAddr{}, + }, nil + } } type streamSplitConn struct { @@ -456,8 +472,15 @@ func (c *streamSplitConn) pullLoop() { requestTimeout = 30 * time.Second readChunkSize = 32 * 1024 idleBackoff = 25 * time.Millisecond + maxDialRetry = 12 + minBackoff = 10 * time.Millisecond + maxBackoff = 250 * time.Millisecond ) + var ( + dialRetry int + backoff = minBackoff + ) buf := make([]byte, readChunkSize) for { select { @@ -479,9 +502,24 @@ func (c *streamSplitConn) pullLoop() { resp, err := c.client.Do(req) if err != nil { cancel() + if isDialError(err) && dialRetry < maxDialRetry { + dialRetry++ + select { + case <-time.After(backoff): + case <-c.closed: + return + } + backoff *= 2 + if backoff > maxBackoff { + backoff = maxBackoff + } + continue + } _ = c.Close() return } + dialRetry = 0 + backoff = minBackoff if resp.StatusCode != http.StatusOK { _ = resp.Body.Close() @@ -533,6 +571,9 @@ func (c *streamSplitConn) pushLoop() { maxBatchBytes = 256 * 1024 flushInterval = 5 * time.Millisecond requestTimeout = 20 * time.Second + maxDialRetry = 12 + minBackoff = 10 * time.Millisecond + maxBackoff = 250 * time.Millisecond ) var ( @@ -541,16 +582,16 @@ func (c *streamSplitConn) pushLoop() { ) defer timer.Stop() - flush := func() bool { + flush := func() error { if buf.Len() == 0 { - return true + return nil } reqCtx, cancel := context.WithTimeout(c.ctx, requestTimeout) req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.pushURL, bytes.NewReader(buf.Bytes())) if err != nil { cancel() - return false + return err } req.Host = c.headerHost applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream) @@ -559,17 +600,41 @@ func (c *streamSplitConn) pushLoop() { resp, err := c.client.Do(req) if err != nil { cancel() - return false + return err } _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024)) _ = resp.Body.Close() cancel() if resp.StatusCode != http.StatusOK { - return false + return fmt.Errorf("bad status: %s", resp.Status) } buf.Reset() - return true + return nil + } + + flushWithRetry := func() error { + dialRetry := 0 + backoff := minBackoff + for { + if err := flush(); err == nil { + return nil + } else if isDialError(err) && dialRetry < maxDialRetry { + dialRetry++ + select { + case <-time.After(backoff): + case <-c.closed: + return io.ErrClosedPipe + } + backoff *= 2 + if backoff > maxBackoff { + backoff = maxBackoff + } + continue + } else { + return err + } + } } resetTimer := func() { @@ -588,14 +653,14 @@ func (c *streamSplitConn) pushLoop() { select { case b, ok := <-c.writeCh: if !ok { - _ = flush() + _ = flushWithRetry() return } if len(b) == 0 { continue } if buf.Len()+len(b) > maxBatchBytes { - if !flush() { + if err := flushWithRetry(); err != nil { _ = c.Close() return } @@ -603,20 +668,20 @@ func (c *streamSplitConn) pushLoop() { } _, _ = buf.Write(b) if buf.Len() >= maxBatchBytes { - if !flush() { + if err := flushWithRetry(); err != nil { _ = c.Close() return } resetTimer() } case <-timer.C: - if !flush() { + if err := flushWithRetry(); err != nil { _ = c.Close() return } resetTimer() case <-c.closed: - _ = flush() + _ = flushWithRetry() return } } @@ -636,16 +701,72 @@ type pollConn struct { mu sync.Mutex readBuf []byte + closeErr error localAddr net.Addr remoteAddr net.Addr } +func isDialError(err error) bool { + var urlErr *url.Error + if errors.As(err, &urlErr) { + return isDialError(urlErr.Err) + } + var opErr *net.OpError + if errors.As(err, &opErr) { + if opErr.Op == "dial" || opErr.Op == "connect" { + return true + } + } + return false +} + +func (c *pollConn) closeWithError(err error) error { + c.mu.Lock() + select { + case <-c.closed: + c.mu.Unlock() + return nil + default: + if err == nil { + err = io.ErrClosedPipe + } + if c.closeErr == nil { + c.closeErr = err + } + close(c.closed) + } + c.mu.Unlock() + + // Best-effort session close signal (avoid leaking server-side sessions). + req, reqErr := http.NewRequest(http.MethodPost, c.closeURL, nil) + if reqErr == nil { + req.Host = c.headerHost + req.Header.Set("X-Sudoku-Tunnel", string(TunnelModePoll)) + req.Header.Set("X-Sudoku-Version", "1") + if resp, doErr := c.client.Do(req); doErr == nil && resp != nil { + _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024)) + _ = resp.Body.Close() + } + } + return nil +} + +func (c *pollConn) closedErr() error { + c.mu.Lock() + err := c.closeErr + c.mu.Unlock() + if err == nil { + return io.ErrClosedPipe + } + return err +} + func (c *pollConn) Read(b []byte) (n int, err error) { if len(c.readBuf) == 0 { select { case c.readBuf = <-c.rxc: case <-c.closed: - return 0, io.ErrClosedPipe + return 0, c.closedErr() } } n = copy(b, c.readBuf) @@ -661,7 +782,7 @@ func (c *pollConn) Write(b []byte) (n int, err error) { select { case <-c.closed: c.mu.Unlock() - return 0, io.ErrClosedPipe + return 0, c.closedErr() default: } c.mu.Unlock() @@ -672,33 +793,12 @@ func (c *pollConn) Write(b []byte) (n int, err error) { case c.writeCh <- payload: return len(b), nil case <-c.closed: - return 0, io.ErrClosedPipe + return 0, c.closedErr() } } func (c *pollConn) Close() error { - c.mu.Lock() - select { - case <-c.closed: - c.mu.Unlock() - return nil - default: - close(c.closed) - } - c.mu.Unlock() - - close(c.writeCh) - - // Best-effort session close signal (avoid leaking server-side sessions). - req, err := http.NewRequest(http.MethodPost, c.closeURL, nil) - if err == nil { - req.Host = c.headerHost - req.Header.Set("X-Sudoku-Tunnel", string(TunnelModePoll)) - req.Header.Set("X-Sudoku-Version", "1") - _, _ = c.client.Do(req) - } - - return nil + return c.closeWithError(io.ErrClosedPipe) } func (c *pollConn) LocalAddr() net.Addr { return c.localAddr } @@ -766,6 +866,15 @@ func dialPoll(ctx context.Context, serverAddress string, opts TunnelDialOptions) } func (c *pollConn) pullLoop() { + const ( + maxDialRetry = 12 + minBackoff = 10 * time.Millisecond + maxBackoff = 250 * time.Millisecond + ) + var ( + dialRetry int + backoff = minBackoff + ) for { select { case <-c.closed: @@ -783,13 +892,28 @@ func (c *pollConn) pullLoop() { resp, err := c.client.Do(req) if err != nil { - _ = c.Close() + if isDialError(err) && dialRetry < maxDialRetry { + dialRetry++ + select { + case <-time.After(backoff): + case <-c.closed: + return + } + backoff *= 2 + if backoff > maxBackoff { + backoff = maxBackoff + } + continue + } + _ = c.closeWithError(fmt.Errorf("poll pull request failed: %w", err)) return } + dialRetry = 0 + backoff = minBackoff if resp.StatusCode != http.StatusOK { _ = resp.Body.Close() - _ = c.Close() + _ = c.closeWithError(fmt.Errorf("poll pull bad status: %s", resp.Status)) return } @@ -802,7 +926,7 @@ func (c *pollConn) pullLoop() { payload, err := base64.StdEncoding.DecodeString(line) if err != nil { _ = resp.Body.Close() - _ = c.Close() + _ = c.closeWithError(fmt.Errorf("poll pull decode failed: %w", err)) return } select { @@ -814,7 +938,7 @@ func (c *pollConn) pullLoop() { } _ = resp.Body.Close() if err := scanner.Err(); err != nil { - _ = c.Close() + _ = c.closeWithError(fmt.Errorf("poll pull scan failed: %w", err)) return } } @@ -825,6 +949,9 @@ func (c *pollConn) pushLoop() { maxBatchBytes = 64 * 1024 flushInterval = 5 * time.Millisecond maxLineRawBytes = 16 * 1024 + maxDialRetry = 12 + minBackoff = 10 * time.Millisecond + maxBackoff = 250 * time.Millisecond ) var ( @@ -834,14 +961,14 @@ func (c *pollConn) pushLoop() { ) defer timer.Stop() - flush := func() bool { + flush := func() error { if buf.Len() == 0 { - return true + return nil } req, err := http.NewRequest(http.MethodPost, c.pushURL, bytes.NewReader(buf.Bytes())) if err != nil { - return false + return err } req.Host = c.headerHost applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll) @@ -849,17 +976,41 @@ func (c *pollConn) pushLoop() { resp, err := c.client.Do(req) if err != nil { - return false + return err } _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024)) _ = resp.Body.Close() if resp.StatusCode != http.StatusOK { - return false + return fmt.Errorf("bad status: %s", resp.Status) } buf.Reset() pendingRaw = 0 - return true + return nil + } + + flushWithRetry := func() error { + dialRetry := 0 + backoff := minBackoff + for { + if err := flush(); err == nil { + return nil + } else if isDialError(err) && dialRetry < maxDialRetry { + dialRetry++ + select { + case <-time.After(backoff): + case <-c.closed: + return c.closedErr() + } + backoff *= 2 + if backoff > maxBackoff { + backoff = maxBackoff + } + continue + } else { + return err + } + } } resetTimer := func() { @@ -878,7 +1029,7 @@ func (c *pollConn) pushLoop() { select { case b, ok := <-c.writeCh: if !ok { - _ = flush() + _ = flushWithRetry() return } if len(b) == 0 { @@ -895,8 +1046,8 @@ func (c *pollConn) pushLoop() { encLen := base64.StdEncoding.EncodedLen(len(chunk)) if pendingRaw+len(chunk) > maxBatchBytes || buf.Len()+encLen+1 > maxBatchBytes*2 { - if !flush() { - _ = c.Close() + if err := flushWithRetry(); err != nil { + _ = c.closeWithError(fmt.Errorf("poll push flush failed: %w", err)) return } } @@ -909,20 +1060,20 @@ func (c *pollConn) pushLoop() { } if pendingRaw >= maxBatchBytes { - if !flush() { - _ = c.Close() + if err := flushWithRetry(); err != nil { + _ = c.closeWithError(fmt.Errorf("poll push flush failed: %w", err)) return } resetTimer() } case <-timer.C: - if !flush() { - _ = c.Close() + if err := flushWithRetry(); err != nil { + _ = c.closeWithError(fmt.Errorf("poll push flush failed: %w", err)) return } resetTimer() case <-c.closed: - _ = flush() + _ = flushWithRetry() return } }