mirror of
https://github.com/MetaCubeX/mihomo.git
synced 2026-01-10 00:29:01 +08:00
feat: support http-mask-multiplex for suduko (#2482)
Some checks are pending
Test / test (1.20, macos-15-intel) (push) Waiting to run
Test / test (1.20, macos-latest) (push) Waiting to run
Test / test (1.20, ubuntu-24.04-arm) (push) Waiting to run
Test / test (1.20, ubuntu-latest) (push) Waiting to run
Test / test (1.20, windows-latest) (push) Waiting to run
Test / test (1.21, macos-15-intel) (push) Waiting to run
Test / test (1.21, macos-latest) (push) Waiting to run
Test / test (1.21, ubuntu-24.04-arm) (push) Waiting to run
Test / test (1.21, ubuntu-latest) (push) Waiting to run
Test / test (1.21, windows-latest) (push) Waiting to run
Test / test (1.22, macos-15-intel) (push) Waiting to run
Test / test (1.22, macos-latest) (push) Waiting to run
Test / test (1.22, ubuntu-24.04-arm) (push) Waiting to run
Test / test (1.22, ubuntu-latest) (push) Waiting to run
Test / test (1.22, windows-latest) (push) Waiting to run
Test / test (1.23, macos-15-intel) (push) Waiting to run
Test / test (1.23, macos-latest) (push) Waiting to run
Test / test (1.23, ubuntu-24.04-arm) (push) Waiting to run
Test / test (1.23, ubuntu-latest) (push) Waiting to run
Test / test (1.23, windows-latest) (push) Waiting to run
Test / test (1.24, macos-15-intel) (push) Waiting to run
Test / test (1.24, macos-latest) (push) Waiting to run
Test / test (1.24, ubuntu-24.04-arm) (push) Waiting to run
Test / test (1.24, ubuntu-latest) (push) Waiting to run
Test / test (1.24, windows-latest) (push) Waiting to run
Test / test (1.25, macos-15-intel) (push) Waiting to run
Test / test (1.25, macos-latest) (push) Waiting to run
Test / test (1.25, ubuntu-24.04-arm) (push) Waiting to run
Test / test (1.25, ubuntu-latest) (push) Waiting to run
Test / test (1.25, windows-latest) (push) Waiting to run
Test / test (1.26.0-rc.1, macos-15-intel) (push) Waiting to run
Test / test (1.26.0-rc.1, macos-latest) (push) Waiting to run
Test / test (1.26.0-rc.1, ubuntu-24.04-arm) (push) Waiting to run
Test / test (1.26.0-rc.1, ubuntu-latest) (push) Waiting to run
Test / test (1.26.0-rc.1, windows-latest) (push) Waiting to run
Trigger CMFA Update / trigger-CMFA-update (push) Waiting to run
Some checks are pending
Test / test (1.20, macos-15-intel) (push) Waiting to run
Test / test (1.20, macos-latest) (push) Waiting to run
Test / test (1.20, ubuntu-24.04-arm) (push) Waiting to run
Test / test (1.20, ubuntu-latest) (push) Waiting to run
Test / test (1.20, windows-latest) (push) Waiting to run
Test / test (1.21, macos-15-intel) (push) Waiting to run
Test / test (1.21, macos-latest) (push) Waiting to run
Test / test (1.21, ubuntu-24.04-arm) (push) Waiting to run
Test / test (1.21, ubuntu-latest) (push) Waiting to run
Test / test (1.21, windows-latest) (push) Waiting to run
Test / test (1.22, macos-15-intel) (push) Waiting to run
Test / test (1.22, macos-latest) (push) Waiting to run
Test / test (1.22, ubuntu-24.04-arm) (push) Waiting to run
Test / test (1.22, ubuntu-latest) (push) Waiting to run
Test / test (1.22, windows-latest) (push) Waiting to run
Test / test (1.23, macos-15-intel) (push) Waiting to run
Test / test (1.23, macos-latest) (push) Waiting to run
Test / test (1.23, ubuntu-24.04-arm) (push) Waiting to run
Test / test (1.23, ubuntu-latest) (push) Waiting to run
Test / test (1.23, windows-latest) (push) Waiting to run
Test / test (1.24, macos-15-intel) (push) Waiting to run
Test / test (1.24, macos-latest) (push) Waiting to run
Test / test (1.24, ubuntu-24.04-arm) (push) Waiting to run
Test / test (1.24, ubuntu-latest) (push) Waiting to run
Test / test (1.24, windows-latest) (push) Waiting to run
Test / test (1.25, macos-15-intel) (push) Waiting to run
Test / test (1.25, macos-latest) (push) Waiting to run
Test / test (1.25, ubuntu-24.04-arm) (push) Waiting to run
Test / test (1.25, ubuntu-latest) (push) Waiting to run
Test / test (1.25, windows-latest) (push) Waiting to run
Test / test (1.26.0-rc.1, macos-15-intel) (push) Waiting to run
Test / test (1.26.0-rc.1, macos-latest) (push) Waiting to run
Test / test (1.26.0-rc.1, ubuntu-24.04-arm) (push) Waiting to run
Test / test (1.26.0-rc.1, ubuntu-latest) (push) Waiting to run
Test / test (1.26.0-rc.1, windows-latest) (push) Waiting to run
Trigger CMFA Update / trigger-CMFA-update (push) Waiting to run
This commit is contained in:
parent
4d7670339b
commit
d6b1263236
@ -6,6 +6,8 @@ import (
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
N "github.com/metacubex/mihomo/common/net"
|
||||
C "github.com/metacubex/mihomo/constant"
|
||||
@ -16,6 +18,11 @@ type Sudoku struct {
|
||||
*Base
|
||||
option *SudokuOption
|
||||
baseConf sudoku.ProtocolConfig
|
||||
|
||||
muxMu sync.Mutex
|
||||
muxClient *sudoku.MultiplexClient
|
||||
muxBackoffUntil time.Time
|
||||
muxLastErr error
|
||||
}
|
||||
|
||||
type SudokuOption struct {
|
||||
@ -30,12 +37,12 @@ type SudokuOption struct {
|
||||
TableType string `proxy:"table-type,omitempty"` // "prefer_ascii" or "prefer_entropy"
|
||||
EnablePureDownlink *bool `proxy:"enable-pure-downlink,omitempty"`
|
||||
HTTPMask bool `proxy:"http-mask,omitempty"`
|
||||
HTTPMaskMode string `proxy:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto"
|
||||
HTTPMaskTLS bool `proxy:"http-mask-tls,omitempty"` // only for http-mask-mode stream/poll/auto
|
||||
HTTPMaskHost string `proxy:"http-mask-host,omitempty"` // optional Host/SNI override (domain or domain:port)
|
||||
HTTPMaskStrategy string `proxy:"http-mask-strategy,omitempty"` // "random" (default), "post", "websocket"
|
||||
CustomTable string `proxy:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv
|
||||
CustomTables []string `proxy:"custom-tables,omitempty"` // optional table rotation patterns, overrides custom-table when non-empty
|
||||
HTTPMaskMode string `proxy:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto"
|
||||
HTTPMaskTLS bool `proxy:"http-mask-tls,omitempty"` // only for http-mask-mode stream/poll/auto
|
||||
HTTPMaskHost string `proxy:"http-mask-host,omitempty"` // optional Host/SNI override (domain or domain:port)
|
||||
HTTPMaskMultiplex string `proxy:"http-mask-multiplex,omitempty"` // "off" (default), "auto", "on"
|
||||
CustomTable string `proxy:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv
|
||||
CustomTables []string `proxy:"custom-tables,omitempty"` // optional table rotation patterns, overrides custom-table when non-empty
|
||||
}
|
||||
|
||||
// DialContext implements C.ProxyAdapter
|
||||
@ -45,40 +52,26 @@ func (s *Sudoku) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Con
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var c net.Conn
|
||||
if !cfg.DisableHTTPMask {
|
||||
switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) {
|
||||
case "stream", "poll", "auto":
|
||||
c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, s.dialer.DialContext)
|
||||
muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex)
|
||||
if !cfg.DisableHTTPMask && muxMode != "off" {
|
||||
shouldTry := muxMode == "on" || (muxMode == "auto" && httpTunnelModeEnabled(cfg.HTTPMaskMode))
|
||||
if shouldTry {
|
||||
stream, muxErr := s.dialMultiplex(ctx, cfg.TargetAddress, muxMode)
|
||||
if muxErr == nil {
|
||||
return NewConn(stream, s), nil
|
||||
}
|
||||
if muxMode != "auto" {
|
||||
return nil, muxErr
|
||||
}
|
||||
s.noteMuxFailure(muxMode, muxErr)
|
||||
}
|
||||
}
|
||||
if c == nil && err == nil {
|
||||
c, err = s.dialer.DialContext(ctx, "tcp", s.addr)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s connect error: %w", s.addr, err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
safeConnClose(c, err)
|
||||
}()
|
||||
|
||||
if ctx.Done() != nil {
|
||||
done := N.SetupContextForConn(ctx, c)
|
||||
defer done(&err)
|
||||
}
|
||||
|
||||
handshakeCfg := *cfg
|
||||
if !handshakeCfg.DisableHTTPMask {
|
||||
switch strings.ToLower(strings.TrimSpace(handshakeCfg.HTTPMaskMode)) {
|
||||
case "stream", "poll", "auto":
|
||||
handshakeCfg.DisableHTTPMask = true
|
||||
}
|
||||
}
|
||||
c, err = sudoku.ClientHandshakeWithOptions(c, &handshakeCfg, sudoku.ClientHandshakeOptions{HTTPMaskStrategy: s.option.HTTPMaskStrategy})
|
||||
c, err := s.dialAndHandshake(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { safeConnClose(c, err) }()
|
||||
|
||||
addrBuf, err := sudoku.EncodeAddress(cfg.TargetAddress)
|
||||
if err != nil {
|
||||
@ -86,7 +79,6 @@ func (s *Sudoku) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Con
|
||||
}
|
||||
|
||||
if _, err = c.Write(addrBuf); err != nil {
|
||||
_ = c.Close()
|
||||
return nil, fmt.Errorf("send target address failed: %w", err)
|
||||
}
|
||||
|
||||
@ -104,37 +96,7 @@ func (s *Sudoku) ListenPacketContext(ctx context.Context, metadata *C.Metadata)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var c net.Conn
|
||||
if !cfg.DisableHTTPMask {
|
||||
switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) {
|
||||
case "stream", "poll", "auto":
|
||||
c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, s.dialer.DialContext)
|
||||
}
|
||||
}
|
||||
if c == nil && err == nil {
|
||||
c, err = s.dialer.DialContext(ctx, "tcp", s.addr)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s connect error: %w", s.addr, err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
safeConnClose(c, err)
|
||||
}()
|
||||
|
||||
if ctx.Done() != nil {
|
||||
done := N.SetupContextForConn(ctx, c)
|
||||
defer done(&err)
|
||||
}
|
||||
|
||||
handshakeCfg := *cfg
|
||||
if !handshakeCfg.DisableHTTPMask {
|
||||
switch strings.ToLower(strings.TrimSpace(handshakeCfg.HTTPMaskMode)) {
|
||||
case "stream", "poll", "auto":
|
||||
handshakeCfg.DisableHTTPMask = true
|
||||
}
|
||||
}
|
||||
c, err = sudoku.ClientHandshakeWithOptions(c, &handshakeCfg, sudoku.ClientHandshakeOptions{HTTPMaskStrategy: s.option.HTTPMaskStrategy})
|
||||
c, err := s.dialAndHandshake(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -224,10 +186,14 @@ func NewSudoku(option SudokuOption) (*Sudoku, error) {
|
||||
HTTPMaskMode: defaultConf.HTTPMaskMode,
|
||||
HTTPMaskTLSEnabled: option.HTTPMaskTLS,
|
||||
HTTPMaskHost: option.HTTPMaskHost,
|
||||
HTTPMaskMultiplex: defaultConf.HTTPMaskMultiplex,
|
||||
}
|
||||
if option.HTTPMaskMode != "" {
|
||||
baseConf.HTTPMaskMode = option.HTTPMaskMode
|
||||
}
|
||||
if option.HTTPMaskMultiplex != "" {
|
||||
baseConf.HTTPMaskMultiplex = option.HTTPMaskMultiplex
|
||||
}
|
||||
tables, err := sudoku.NewTablesWithCustomPatterns(sudoku.ClientAEADSeed(option.Key), tableType, option.CustomTable, option.CustomTables)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build table(s) failed: %w", err)
|
||||
@ -260,3 +226,157 @@ func NewSudoku(option SudokuOption) (*Sudoku, error) {
|
||||
outbound.dialer = option.NewDialer(outbound.DialOptions())
|
||||
return outbound, nil
|
||||
}
|
||||
|
||||
func (s *Sudoku) Close() error {
|
||||
s.resetMuxClient()
|
||||
return s.Base.Close()
|
||||
}
|
||||
|
||||
func normalizeHTTPMaskMultiplex(mode string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(mode)) {
|
||||
case "", "off":
|
||||
return "off"
|
||||
case "auto":
|
||||
return "auto"
|
||||
case "on":
|
||||
return "on"
|
||||
default:
|
||||
return "off"
|
||||
}
|
||||
}
|
||||
|
||||
func httpTunnelModeEnabled(mode string) bool {
|
||||
switch strings.ToLower(strings.TrimSpace(mode)) {
|
||||
case "stream", "poll", "auto":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Sudoku) dialAndHandshake(ctx context.Context, cfg *sudoku.ProtocolConfig) (_ net.Conn, err error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config is required")
|
||||
}
|
||||
|
||||
var c net.Conn
|
||||
if !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) {
|
||||
c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, s.dialer.DialContext)
|
||||
}
|
||||
if c == nil && err == nil {
|
||||
c, err = s.dialer.DialContext(ctx, "tcp", s.addr)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s connect error: %w", s.addr, err)
|
||||
}
|
||||
|
||||
defer func() { safeConnClose(c, err) }()
|
||||
|
||||
if ctx.Done() != nil {
|
||||
done := N.SetupContextForConn(ctx, c)
|
||||
defer done(&err)
|
||||
}
|
||||
|
||||
handshakeCfg := *cfg
|
||||
if !handshakeCfg.DisableHTTPMask && httpTunnelModeEnabled(handshakeCfg.HTTPMaskMode) {
|
||||
handshakeCfg.DisableHTTPMask = true
|
||||
}
|
||||
|
||||
c, err = sudoku.ClientHandshakeWithOptions(c, &handshakeCfg, sudoku.ClientHandshakeOptions{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (s *Sudoku) dialMultiplex(ctx context.Context, targetAddress string, mode string) (net.Conn, error) {
|
||||
for attempt := 0; attempt < 2; attempt++ {
|
||||
client, err := s.getOrCreateMuxClient(ctx, mode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stream, err := client.Dial(ctx, targetAddress)
|
||||
if err != nil {
|
||||
s.resetMuxClient()
|
||||
continue
|
||||
}
|
||||
|
||||
return stream, nil
|
||||
}
|
||||
return nil, fmt.Errorf("multiplex open stream failed")
|
||||
}
|
||||
|
||||
func (s *Sudoku) getOrCreateMuxClient(ctx context.Context, mode string) (*sudoku.MultiplexClient, error) {
|
||||
if s == nil {
|
||||
return nil, fmt.Errorf("nil adapter")
|
||||
}
|
||||
|
||||
if mode == "auto" {
|
||||
s.muxMu.Lock()
|
||||
backoffUntil := s.muxBackoffUntil
|
||||
lastErr := s.muxLastErr
|
||||
s.muxMu.Unlock()
|
||||
if time.Now().Before(backoffUntil) {
|
||||
return nil, fmt.Errorf("multiplex temporarily disabled: %v", lastErr)
|
||||
}
|
||||
}
|
||||
|
||||
s.muxMu.Lock()
|
||||
if s.muxClient != nil && !s.muxClient.IsClosed() {
|
||||
client := s.muxClient
|
||||
s.muxMu.Unlock()
|
||||
return client, nil
|
||||
}
|
||||
s.muxMu.Unlock()
|
||||
|
||||
s.muxMu.Lock()
|
||||
defer s.muxMu.Unlock()
|
||||
|
||||
if s.muxClient != nil && !s.muxClient.IsClosed() {
|
||||
return s.muxClient, nil
|
||||
}
|
||||
|
||||
baseCfg := s.baseConf
|
||||
baseConn, err := s.dialAndHandshake(ctx, &baseCfg)
|
||||
if err != nil {
|
||||
if mode == "auto" {
|
||||
s.muxLastErr = err
|
||||
s.muxBackoffUntil = time.Now().Add(45 * time.Second)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client, err := sudoku.StartMultiplexClient(baseConn)
|
||||
if err != nil {
|
||||
_ = baseConn.Close()
|
||||
if mode == "auto" {
|
||||
s.muxLastErr = err
|
||||
s.muxBackoffUntil = time.Now().Add(45 * time.Second)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.muxClient = client
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *Sudoku) noteMuxFailure(mode string, err error) {
|
||||
if mode != "auto" {
|
||||
return
|
||||
}
|
||||
s.muxMu.Lock()
|
||||
s.muxLastErr = err
|
||||
s.muxBackoffUntil = time.Now().Add(45 * time.Second)
|
||||
s.muxMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Sudoku) resetMuxClient() {
|
||||
s.muxMu.Lock()
|
||||
defer s.muxMu.Unlock()
|
||||
if s.muxClient != nil {
|
||||
_ = s.muxClient.Close()
|
||||
s.muxClient = nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
//go:build android && cgo
|
||||
// +build android,cgo
|
||||
|
||||
// kanged from https://github.com/golang/mobile/blob/c713f31d574bb632a93f169b2cc99c9e753fef0e/app/android.go#L89
|
||||
|
||||
|
||||
@ -1066,7 +1066,7 @@ proxies: # socks5
|
||||
# http-mask-mode: legacy # 可选:legacy(默认)、stream、poll、auto;stream/poll/auto 支持走 CDN/反代
|
||||
# http-mask-tls: true # 可选:仅在 http-mask-mode 为 stream/poll/auto 时生效;true 强制 https;false 强制 http(不会根据端口自动推断)
|
||||
# http-mask-host: "" # 可选:覆盖 Host/SNI(支持 example.com 或 example.com:443);仅在 http-mask-mode 为 stream/poll/auto 时生效
|
||||
# http-mask-strategy: random # 可选:random(默认)、post、websocket;仅 legacy 下生效
|
||||
# http-mask-multiplex: off # 可选:off(默认)、auto、on;复用单条隧道并在其内多路复用多个目标连接
|
||||
enable-pure-downlink: false # 是否启用混淆下行,false的情况下能在保证数据安全的前提下极大提升下行速度,与服务端端保持相同(如果此处为false,则要求aead不可为none)
|
||||
|
||||
# anytls
|
||||
|
||||
@ -2,7 +2,6 @@ package inbound_test
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/metacubex/mihomo/adapter/outbound"
|
||||
@ -167,10 +166,6 @@ func TestInboundSudoku_CustomTable(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInboundSudoku_HTTPMaskMode(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("temporarily skipped on windows due to intermittent failures; tracked in PR")
|
||||
}
|
||||
|
||||
key := "test_key_http_mask_mode"
|
||||
|
||||
for _, mode := range []string{"legacy", "stream", "poll", "auto"} {
|
||||
|
||||
@ -78,6 +78,26 @@ func (l *Listener) handleConn(conn net.Conn, tunnel C.Tunnel, additions ...inbou
|
||||
switch session.Type {
|
||||
case sudoku.SessionTypeUoT:
|
||||
l.handleUoTSession(session.Conn, tunnel, additions...)
|
||||
case sudoku.SessionTypeMultiplex:
|
||||
mux, err := sudoku.AcceptMultiplexServer(session.Conn)
|
||||
if err != nil {
|
||||
_ = session.Conn.Close()
|
||||
return
|
||||
}
|
||||
defer mux.Close()
|
||||
|
||||
for {
|
||||
stream, target, err := mux.AcceptTCP()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
targetAddr := socks5.ParseAddr(target)
|
||||
if targetAddr == nil {
|
||||
_ = stream.Close()
|
||||
continue
|
||||
}
|
||||
go l.handler.HandleSocket(targetAddr, stream, additions...)
|
||||
}
|
||||
default:
|
||||
targetAddr := socks5.ParseAddr(session.Target)
|
||||
if targetAddr == nil {
|
||||
|
||||
@ -57,6 +57,11 @@ type ProtocolConfig struct {
|
||||
|
||||
// HTTPMaskHost optionally overrides the HTTP Host header / SNI host for HTTP tunnel modes (client-side).
|
||||
HTTPMaskHost string
|
||||
|
||||
// HTTPMaskMultiplex controls whether the client reuses a single (HTTP-masked) tunnel connection and
|
||||
// opens multiple logical target streams inside it (reduces RTT for subsequent connections).
|
||||
// Values: "off" / "auto" / "on".
|
||||
HTTPMaskMultiplex string
|
||||
}
|
||||
|
||||
func (c *ProtocolConfig) Validate() error {
|
||||
@ -103,6 +108,12 @@ func (c *ProtocolConfig) Validate() error {
|
||||
return fmt.Errorf("invalid http-mask-mode: %s, must be one of: legacy, stream, poll, auto", c.HTTPMaskMode)
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(c.HTTPMaskMultiplex)) {
|
||||
case "", "off", "auto", "on":
|
||||
default:
|
||||
return fmt.Errorf("invalid http-mask-multiplex: %s, must be one of: off, auto, on", c.HTTPMaskMultiplex)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -127,6 +138,7 @@ func DefaultConfig() *ProtocolConfig {
|
||||
EnablePureDownlink: true,
|
||||
HandshakeTimeoutSeconds: 5,
|
||||
HTTPMaskMode: "legacy",
|
||||
HTTPMaskMultiplex: "off",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@ -23,12 +24,17 @@ type SessionType int
|
||||
const (
|
||||
SessionTypeTCP SessionType = iota
|
||||
SessionTypeUoT
|
||||
SessionTypeMultiplex
|
||||
)
|
||||
|
||||
type ServerSession struct {
|
||||
Conn net.Conn
|
||||
Type SessionType
|
||||
Target string
|
||||
|
||||
// UserHash is a stable per-key identifier derived from the handshake payload.
|
||||
// It is primarily useful for debugging / user attribution when table rotation is enabled.
|
||||
UserHash string
|
||||
}
|
||||
|
||||
type bufferedConn struct {
|
||||
@ -147,7 +153,14 @@ func buildServerObfsConn(raw net.Conn, cfg *ProtocolConfig, table *sudoku.Table,
|
||||
func buildHandshakePayload(key string) [16]byte {
|
||||
var payload [16]byte
|
||||
binary.BigEndian.PutUint64(payload[:8], uint64(time.Now().Unix()))
|
||||
hash := sha256.Sum256([]byte(key))
|
||||
// Hash the decoded HEX bytes of the key, not the HEX string itself.
|
||||
// This ensures the user hash is computed on the actual key bytes.
|
||||
keyBytes, err := hex.DecodeString(key)
|
||||
if err != nil {
|
||||
// Fallback: if key is not valid HEX (e.g., a UUID or plain string), hash the string bytes
|
||||
keyBytes = []byte(key)
|
||||
}
|
||||
hash := sha256.Sum256(keyBytes)
|
||||
copy(payload[8:], hash[:8])
|
||||
return payload
|
||||
}
|
||||
@ -216,7 +229,7 @@ func ClientHandshakeWithOptions(rawConn net.Conn, cfg *ProtocolConfig, opt Clien
|
||||
|
||||
handshake := buildHandshakePayload(cfg.Key)
|
||||
if len(cfg.tableCandidates()) > 1 {
|
||||
handshake[15] = tableID
|
||||
handshake[8] = tableID
|
||||
}
|
||||
if _, err := cConn.Write(handshake[:]); err != nil {
|
||||
cConn.Close()
|
||||
@ -280,6 +293,7 @@ func ServerHandshake(rawConn net.Conn, cfg *ProtocolConfig) (*ServerSession, err
|
||||
return nil, fmt.Errorf("timestamp skew detected")
|
||||
}
|
||||
|
||||
userHash := userHashFromHandshake(handshakeBuf[:])
|
||||
sConn.StopRecording()
|
||||
|
||||
modeBuf := []byte{0}
|
||||
@ -298,6 +312,11 @@ func ServerHandshake(rawConn net.Conn, cfg *ProtocolConfig) (*ServerSession, err
|
||||
return nil, fmt.Errorf("read first byte failed: %w", err)
|
||||
}
|
||||
|
||||
if firstByte[0] == MultiplexMagicByte {
|
||||
rawConn.SetReadDeadline(time.Time{})
|
||||
return &ServerSession{Conn: cConn, Type: SessionTypeMultiplex, UserHash: userHash}, nil
|
||||
}
|
||||
|
||||
if firstByte[0] == UoTMagicByte {
|
||||
version := make([]byte, 1)
|
||||
if _, err := io.ReadFull(cConn, version); err != nil {
|
||||
@ -309,7 +328,7 @@ func ServerHandshake(rawConn net.Conn, cfg *ProtocolConfig) (*ServerSession, err
|
||||
return nil, fmt.Errorf("unsupported uot version: %d", version[0])
|
||||
}
|
||||
rawConn.SetReadDeadline(time.Time{})
|
||||
return &ServerSession{Conn: cConn, Type: SessionTypeUoT}, nil
|
||||
return &ServerSession{Conn: cConn, Type: SessionTypeUoT, UserHash: userHash}, nil
|
||||
}
|
||||
|
||||
prefixed := &preBufferedConn{Conn: cConn, buf: firstByte}
|
||||
@ -322,9 +341,10 @@ func ServerHandshake(rawConn net.Conn, cfg *ProtocolConfig) (*ServerSession, err
|
||||
rawConn.SetReadDeadline(time.Time{})
|
||||
log.Debugln("[Sudoku] incoming TCP session target: %s", target)
|
||||
return &ServerSession{
|
||||
Conn: prefixed,
|
||||
Type: SessionTypeTCP,
|
||||
Target: target,
|
||||
Conn: prefixed,
|
||||
Type: SessionTypeTCP,
|
||||
Target: target,
|
||||
UserHash: userHash,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -364,3 +384,11 @@ func randomByte() byte {
|
||||
}
|
||||
return byte(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func userHashFromHandshake(handshakeBuf []byte) string {
|
||||
if len(handshakeBuf) < 16 {
|
||||
return ""
|
||||
}
|
||||
// handshake[8] may be a table ID when table rotation is enabled; use [9:16] as stable user hash bytes.
|
||||
return hex.EncodeToString(handshakeBuf[9:16])
|
||||
}
|
||||
|
||||
150
transport/sudoku/multiplex.go
Normal file
150
transport/sudoku/multiplex.go
Normal file
@ -0,0 +1,150 @@
|
||||
package sudoku
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/metacubex/mihomo/transport/sudoku/multiplex"
|
||||
)
|
||||
|
||||
const (
|
||||
MultiplexMagicByte byte = multiplex.MagicByte
|
||||
MultiplexVersion byte = multiplex.Version
|
||||
)
|
||||
|
||||
// StartMultiplexClient writes the multiplex preface and upgrades an already-handshaked Sudoku tunnel into a multiplex session.
|
||||
func StartMultiplexClient(conn net.Conn) (*MultiplexClient, error) {
|
||||
if conn == nil {
|
||||
return nil, fmt.Errorf("nil conn")
|
||||
}
|
||||
|
||||
if err := multiplex.WritePreface(conn); err != nil {
|
||||
return nil, fmt.Errorf("write multiplex preface failed: %w", err)
|
||||
}
|
||||
|
||||
sess, err := multiplex.NewClientSession(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("start multiplex session failed: %w", err)
|
||||
}
|
||||
|
||||
return &MultiplexClient{sess: sess}, nil
|
||||
}
|
||||
|
||||
type MultiplexClient struct {
|
||||
sess *multiplex.Session
|
||||
}
|
||||
|
||||
// Dial opens a new logical stream, writes the target address, and returns the stream as net.Conn.
|
||||
func (c *MultiplexClient) Dial(ctx context.Context, targetAddress string) (net.Conn, error) {
|
||||
if c == nil || c.sess == nil || c.sess.IsClosed() {
|
||||
return nil, fmt.Errorf("multiplex session is closed")
|
||||
}
|
||||
if strings.TrimSpace(targetAddress) == "" {
|
||||
return nil, fmt.Errorf("target address cannot be empty")
|
||||
}
|
||||
|
||||
stream, err := c.sess.OpenStream()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
_ = stream.SetWriteDeadline(deadline)
|
||||
defer stream.SetWriteDeadline(time.Time{})
|
||||
}
|
||||
|
||||
addrBuf, err := EncodeAddress(targetAddress)
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return nil, fmt.Errorf("encode target address failed: %w", err)
|
||||
}
|
||||
if _, err := stream.Write(addrBuf); err != nil {
|
||||
_ = stream.Close()
|
||||
return nil, fmt.Errorf("send target address failed: %w", err)
|
||||
}
|
||||
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (c *MultiplexClient) Close() error {
|
||||
if c == nil || c.sess == nil {
|
||||
return nil
|
||||
}
|
||||
return c.sess.Close()
|
||||
}
|
||||
|
||||
func (c *MultiplexClient) IsClosed() bool {
|
||||
if c == nil || c.sess == nil {
|
||||
return true
|
||||
}
|
||||
return c.sess.IsClosed()
|
||||
}
|
||||
|
||||
// AcceptMultiplexServer upgrades a server-side, already-handshaked Sudoku connection into a multiplex session.
|
||||
//
|
||||
// The caller must have already consumed the multiplex magic byte (MultiplexMagicByte). This function consumes the
|
||||
// multiplex version byte and starts the session.
|
||||
func AcceptMultiplexServer(conn net.Conn) (*MultiplexServer, error) {
|
||||
if conn == nil {
|
||||
return nil, fmt.Errorf("nil conn")
|
||||
}
|
||||
v, err := multiplex.ReadVersion(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := multiplex.ValidateVersion(v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sess, err := multiplex.NewServerSession(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &MultiplexServer{sess: sess}, nil
|
||||
}
|
||||
|
||||
// MultiplexServer wraps a multiplex session created from a handshaked Sudoku tunnel connection.
|
||||
type MultiplexServer struct {
|
||||
sess *multiplex.Session
|
||||
}
|
||||
|
||||
func (s *MultiplexServer) AcceptStream() (net.Conn, error) {
|
||||
if s == nil || s.sess == nil {
|
||||
return nil, fmt.Errorf("nil session")
|
||||
}
|
||||
return s.sess.AcceptStream()
|
||||
}
|
||||
|
||||
// AcceptTCP accepts a multiplex stream and reads the target address preface, returning the stream positioned at
|
||||
// application data.
|
||||
func (s *MultiplexServer) AcceptTCP() (net.Conn, string, error) {
|
||||
stream, err := s.AcceptStream()
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
target, err := DecodeAddress(stream)
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return stream, target, nil
|
||||
}
|
||||
|
||||
func (s *MultiplexServer) Close() error {
|
||||
if s == nil || s.sess == nil {
|
||||
return nil
|
||||
}
|
||||
return s.sess.Close()
|
||||
}
|
||||
|
||||
func (s *MultiplexServer) IsClosed() bool {
|
||||
if s == nil || s.sess == nil {
|
||||
return true
|
||||
}
|
||||
return s.sess.IsClosed()
|
||||
}
|
||||
|
||||
101
transport/sudoku/multiplex/session.go
Normal file
101
transport/sudoku/multiplex/session.go
Normal file
@ -0,0 +1,101 @@
|
||||
package multiplex
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/metacubex/smux"
|
||||
)
|
||||
|
||||
const (
|
||||
// MagicByte marks a Sudoku tunnel connection that will switch into multiplex mode.
|
||||
// It is sent after the Sudoku handshake + downlink mode byte.
|
||||
MagicByte byte = 0xEF
|
||||
Version = 0x01
|
||||
)
|
||||
|
||||
func WritePreface(w io.Writer) error {
|
||||
_, err := w.Write([]byte{MagicByte, Version})
|
||||
return err
|
||||
}
|
||||
|
||||
func ReadVersion(r io.Reader) (byte, error) {
|
||||
var b [1]byte
|
||||
if _, err := io.ReadFull(r, b[:]); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return b[0], nil
|
||||
}
|
||||
|
||||
func ValidateVersion(v byte) error {
|
||||
if v != Version {
|
||||
return fmt.Errorf("unsupported multiplex version: %d", v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func defaultSmuxConfig() *smux.Config {
|
||||
cfg := smux.DefaultConfig()
|
||||
cfg.KeepAliveInterval = 15 * time.Second
|
||||
cfg.KeepAliveTimeout = 45 * time.Second
|
||||
return cfg
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
sess *smux.Session
|
||||
}
|
||||
|
||||
func NewClientSession(conn net.Conn) (*Session, error) {
|
||||
if conn == nil {
|
||||
return nil, fmt.Errorf("nil conn")
|
||||
}
|
||||
s, err := smux.Client(conn, defaultSmuxConfig())
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return &Session{sess: s}, nil
|
||||
}
|
||||
|
||||
func NewServerSession(conn net.Conn) (*Session, error) {
|
||||
if conn == nil {
|
||||
return nil, fmt.Errorf("nil conn")
|
||||
}
|
||||
s, err := smux.Server(conn, defaultSmuxConfig())
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return &Session{sess: s}, nil
|
||||
}
|
||||
|
||||
func (s *Session) OpenStream() (net.Conn, error) {
|
||||
if s == nil || s.sess == nil {
|
||||
return nil, fmt.Errorf("nil session")
|
||||
}
|
||||
return s.sess.OpenStream()
|
||||
}
|
||||
|
||||
func (s *Session) AcceptStream() (net.Conn, error) {
|
||||
if s == nil || s.sess == nil {
|
||||
return nil, fmt.Errorf("nil session")
|
||||
}
|
||||
return s.sess.AcceptStream()
|
||||
}
|
||||
|
||||
func (s *Session) Close() error {
|
||||
if s == nil || s.sess == nil {
|
||||
return nil
|
||||
}
|
||||
return s.sess.Close()
|
||||
}
|
||||
|
||||
func (s *Session) IsClosed() bool {
|
||||
if s == nil || s.sess == nil {
|
||||
return true
|
||||
}
|
||||
return s.sess.IsClosed()
|
||||
}
|
||||
|
||||
261
transport/sudoku/multiplex_test.go
Normal file
261
transport/sudoku/multiplex_test.go
Normal file
@ -0,0 +1,261 @@
|
||||
package sudoku
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
sudokuobfs "github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku"
|
||||
)
|
||||
|
||||
func TestUserHash_StableAcrossTableRotation(t *testing.T) {
|
||||
tables := []*sudokuobfs.Table{
|
||||
sudokuobfs.NewTable("seed-a", "prefer_ascii"),
|
||||
sudokuobfs.NewTable("seed-b", "prefer_ascii"),
|
||||
}
|
||||
key := "userhash-stability-key"
|
||||
target := "example.com:80"
|
||||
|
||||
serverCfg := DefaultConfig()
|
||||
serverCfg.Key = key
|
||||
serverCfg.AEADMethod = "chacha20-poly1305"
|
||||
serverCfg.Tables = tables
|
||||
serverCfg.PaddingMin = 0
|
||||
serverCfg.PaddingMax = 0
|
||||
serverCfg.EnablePureDownlink = true
|
||||
serverCfg.HandshakeTimeoutSeconds = 5
|
||||
serverCfg.DisableHTTPMask = true
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = ln.Close() })
|
||||
|
||||
const attempts = 32
|
||||
hashCh := make(chan string, attempts)
|
||||
errCh := make(chan error, attempts)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
session, err := ServerHandshake(conn, serverCfg)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
defer session.Conn.Close()
|
||||
hashCh <- session.UserHash
|
||||
}(c)
|
||||
}
|
||||
}()
|
||||
|
||||
clientCfg := DefaultConfig()
|
||||
*clientCfg = *serverCfg
|
||||
clientCfg.ServerAddress = ln.Addr().String()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
for i := 0; i < attempts; i++ {
|
||||
raw, err := (&net.Dialer{}).DialContext(ctx, "tcp", clientCfg.ServerAddress)
|
||||
if err != nil {
|
||||
t.Fatalf("dial %d: %v", i, err)
|
||||
}
|
||||
cConn, err := ClientHandshake(raw, clientCfg)
|
||||
if err != nil {
|
||||
_ = raw.Close()
|
||||
t.Fatalf("handshake %d: %v", i, err)
|
||||
}
|
||||
|
||||
addrBuf, err := EncodeAddress(target)
|
||||
if err != nil {
|
||||
_ = cConn.Close()
|
||||
t.Fatalf("encode addr %d: %v", i, err)
|
||||
}
|
||||
if _, err := cConn.Write(addrBuf); err != nil {
|
||||
_ = cConn.Close()
|
||||
t.Fatalf("write addr %d: %v", i, err)
|
||||
}
|
||||
_ = cConn.Close()
|
||||
}
|
||||
|
||||
unique := map[string]struct{}{}
|
||||
deadline := time.After(10 * time.Second)
|
||||
for i := 0; i < attempts; i++ {
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Fatalf("server handshake error: %v", err)
|
||||
case h := <-hashCh:
|
||||
if h == "" {
|
||||
t.Fatalf("empty user hash")
|
||||
}
|
||||
if len(h) != 14 {
|
||||
t.Fatalf("unexpected user hash length: %d", len(h))
|
||||
}
|
||||
unique[h] = struct{}{}
|
||||
case <-deadline:
|
||||
t.Fatalf("timeout waiting for server handshakes")
|
||||
}
|
||||
}
|
||||
if len(unique) != 1 {
|
||||
t.Fatalf("user hash should be stable across table rotation; got %d distinct values", len(unique))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiplex_TCP_Echo(t *testing.T) {
|
||||
table := sudokuobfs.NewTable("seed", "prefer_ascii")
|
||||
key := "test-key-mux"
|
||||
target := "example.com:80"
|
||||
|
||||
serverCfg := DefaultConfig()
|
||||
serverCfg.Key = key
|
||||
serverCfg.AEADMethod = "chacha20-poly1305"
|
||||
serverCfg.Table = table
|
||||
serverCfg.PaddingMin = 0
|
||||
serverCfg.PaddingMax = 0
|
||||
serverCfg.EnablePureDownlink = true
|
||||
serverCfg.HandshakeTimeoutSeconds = 5
|
||||
serverCfg.DisableHTTPMask = true
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = ln.Close() })
|
||||
|
||||
var handshakes int64
|
||||
var streams int64
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(done)
|
||||
raw, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer raw.Close()
|
||||
|
||||
session, err := ServerHandshake(raw, serverCfg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
atomic.AddInt64(&handshakes, 1)
|
||||
|
||||
if session.Type != SessionTypeMultiplex {
|
||||
_ = session.Conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
mux, err := AcceptMultiplexServer(session.Conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer mux.Close()
|
||||
|
||||
for {
|
||||
stream, dst, err := mux.AcceptTCP()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if dst != target {
|
||||
_ = stream.Close()
|
||||
return
|
||||
}
|
||||
atomic.AddInt64(&streams, 1)
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
_, _ = io.Copy(c, c)
|
||||
}(stream)
|
||||
}
|
||||
}()
|
||||
|
||||
clientCfg := DefaultConfig()
|
||||
*clientCfg = *serverCfg
|
||||
clientCfg.ServerAddress = ln.Addr().String()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
raw, err := (&net.Dialer{}).DialContext(ctx, "tcp", clientCfg.ServerAddress)
|
||||
if err != nil {
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = raw.Close() })
|
||||
|
||||
cConn, err := ClientHandshake(raw, clientCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("client handshake: %v", err)
|
||||
}
|
||||
|
||||
mux, err := StartMultiplexClient(cConn)
|
||||
if err != nil {
|
||||
_ = cConn.Close()
|
||||
t.Fatalf("start mux: %v", err)
|
||||
}
|
||||
defer mux.Close()
|
||||
|
||||
for i := 0; i < 6; i++ {
|
||||
s, err := mux.Dial(ctx, target)
|
||||
if err != nil {
|
||||
t.Fatalf("dial stream %d: %v", i, err)
|
||||
}
|
||||
|
||||
msg := []byte("hello-mux")
|
||||
if _, err := s.Write(msg); err != nil {
|
||||
_ = s.Close()
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
buf := make([]byte, len(msg))
|
||||
if _, err := io.ReadFull(s, buf); err != nil {
|
||||
_ = s.Close()
|
||||
t.Fatalf("read: %v", err)
|
||||
}
|
||||
_ = s.Close()
|
||||
if !bytes.Equal(buf, msg) {
|
||||
t.Fatalf("echo mismatch: got %q", buf)
|
||||
}
|
||||
}
|
||||
|
||||
_ = mux.Close()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatalf("server did not exit")
|
||||
}
|
||||
|
||||
if got := atomic.LoadInt64(&handshakes); got != 1 {
|
||||
t.Fatalf("unexpected handshake count: %d", got)
|
||||
}
|
||||
if got := atomic.LoadInt64(&streams); got < 6 {
|
||||
t.Fatalf("unexpected stream count: %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiplex_Boundary_InvalidVersion(t *testing.T) {
|
||||
client, server := net.Pipe()
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
t.Cleanup(func() { _ = server.Close() })
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := AcceptMultiplexServer(server)
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
// AcceptMultiplexServer expects the magic byte to have been consumed already; write a bad version byte.
|
||||
_, _ = client.Write([]byte{0xFF})
|
||||
if err := <-errCh; err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
}
|
||||
|
||||
@ -203,7 +203,7 @@ func newHTTPClient(serverAddress string, opts TunnelDialOptions, maxIdleConns in
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
ForceAttemptHTTP2: true,
|
||||
ForceAttemptHTTP2: scheme == "https",
|
||||
DisableCompression: true,
|
||||
MaxIdleConns: maxIdleConns,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
@ -261,10 +261,10 @@ func dialStreamOne(ctx context.Context, serverAddress string, opts TunnelDialOpt
|
||||
|
||||
reqBodyR, reqBodyW := io.Pipe()
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), reqBodyR)
|
||||
connCtx, connCancel := context.WithCancel(context.Background())
|
||||
req, err := http.NewRequestWithContext(connCtx, http.MethodPost, u.String(), reqBodyR)
|
||||
if err != nil {
|
||||
cancel()
|
||||
connCancel()
|
||||
_ = reqBodyW.Close()
|
||||
return nil, err
|
||||
}
|
||||
@ -273,27 +273,43 @@ func dialStreamOne(ctx context.Context, serverAddress string, opts TunnelDialOpt
|
||||
applyTunnelHeaders(req.Header, target.headerHost, TunnelModeStream)
|
||||
req.Header.Set("Content-Type", ctype)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
cancel()
|
||||
_ = reqBodyW.Close()
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4*1024))
|
||||
cancel()
|
||||
_ = reqBodyW.Close()
|
||||
return nil, fmt.Errorf("stream bad status: %s (%s)", resp.Status, strings.TrimSpace(string(body)))
|
||||
type doResult struct {
|
||||
resp *http.Response
|
||||
err error
|
||||
}
|
||||
doCh := make(chan doResult, 1)
|
||||
go func() {
|
||||
resp, doErr := client.Do(req)
|
||||
doCh <- doResult{resp: resp, err: doErr}
|
||||
}()
|
||||
|
||||
return &httpStreamConn{
|
||||
reader: resp.Body,
|
||||
writer: reqBodyW,
|
||||
cancel: cancel,
|
||||
localAddr: &net.TCPAddr{},
|
||||
remoteAddr: &net.TCPAddr{},
|
||||
}, nil
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
connCancel()
|
||||
_ = reqBodyW.Close()
|
||||
return nil, ctx.Err()
|
||||
case r := <-doCh:
|
||||
if r.err != nil {
|
||||
connCancel()
|
||||
_ = reqBodyW.Close()
|
||||
return nil, r.err
|
||||
}
|
||||
if r.resp.StatusCode != http.StatusOK {
|
||||
defer r.resp.Body.Close()
|
||||
body, _ := io.ReadAll(io.LimitReader(r.resp.Body, 4*1024))
|
||||
connCancel()
|
||||
_ = reqBodyW.Close()
|
||||
return nil, fmt.Errorf("stream bad status: %s (%s)", r.resp.Status, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
return &httpStreamConn{
|
||||
reader: r.resp.Body,
|
||||
writer: reqBodyW,
|
||||
cancel: connCancel,
|
||||
localAddr: &net.TCPAddr{},
|
||||
remoteAddr: &net.TCPAddr{},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
type streamSplitConn struct {
|
||||
@ -456,8 +472,15 @@ func (c *streamSplitConn) pullLoop() {
|
||||
requestTimeout = 30 * time.Second
|
||||
readChunkSize = 32 * 1024
|
||||
idleBackoff = 25 * time.Millisecond
|
||||
maxDialRetry = 12
|
||||
minBackoff = 10 * time.Millisecond
|
||||
maxBackoff = 250 * time.Millisecond
|
||||
)
|
||||
|
||||
var (
|
||||
dialRetry int
|
||||
backoff = minBackoff
|
||||
)
|
||||
buf := make([]byte, readChunkSize)
|
||||
for {
|
||||
select {
|
||||
@ -479,9 +502,24 @@ func (c *streamSplitConn) pullLoop() {
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
cancel()
|
||||
if isDialError(err) && dialRetry < maxDialRetry {
|
||||
dialRetry++
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-c.closed:
|
||||
return
|
||||
}
|
||||
backoff *= 2
|
||||
if backoff > maxBackoff {
|
||||
backoff = maxBackoff
|
||||
}
|
||||
continue
|
||||
}
|
||||
_ = c.Close()
|
||||
return
|
||||
}
|
||||
dialRetry = 0
|
||||
backoff = minBackoff
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
_ = resp.Body.Close()
|
||||
@ -533,6 +571,9 @@ func (c *streamSplitConn) pushLoop() {
|
||||
maxBatchBytes = 256 * 1024
|
||||
flushInterval = 5 * time.Millisecond
|
||||
requestTimeout = 20 * time.Second
|
||||
maxDialRetry = 12
|
||||
minBackoff = 10 * time.Millisecond
|
||||
maxBackoff = 250 * time.Millisecond
|
||||
)
|
||||
|
||||
var (
|
||||
@ -541,16 +582,16 @@ func (c *streamSplitConn) pushLoop() {
|
||||
)
|
||||
defer timer.Stop()
|
||||
|
||||
flush := func() bool {
|
||||
flush := func() error {
|
||||
if buf.Len() == 0 {
|
||||
return true
|
||||
return nil
|
||||
}
|
||||
|
||||
reqCtx, cancel := context.WithTimeout(c.ctx, requestTimeout)
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.pushURL, bytes.NewReader(buf.Bytes()))
|
||||
if err != nil {
|
||||
cancel()
|
||||
return false
|
||||
return err
|
||||
}
|
||||
req.Host = c.headerHost
|
||||
applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream)
|
||||
@ -559,17 +600,41 @@ func (c *streamSplitConn) pushLoop() {
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return false
|
||||
return err
|
||||
}
|
||||
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024))
|
||||
_ = resp.Body.Close()
|
||||
cancel()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return false
|
||||
return fmt.Errorf("bad status: %s", resp.Status)
|
||||
}
|
||||
|
||||
buf.Reset()
|
||||
return true
|
||||
return nil
|
||||
}
|
||||
|
||||
flushWithRetry := func() error {
|
||||
dialRetry := 0
|
||||
backoff := minBackoff
|
||||
for {
|
||||
if err := flush(); err == nil {
|
||||
return nil
|
||||
} else if isDialError(err) && dialRetry < maxDialRetry {
|
||||
dialRetry++
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-c.closed:
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
backoff *= 2
|
||||
if backoff > maxBackoff {
|
||||
backoff = maxBackoff
|
||||
}
|
||||
continue
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resetTimer := func() {
|
||||
@ -588,14 +653,14 @@ func (c *streamSplitConn) pushLoop() {
|
||||
select {
|
||||
case b, ok := <-c.writeCh:
|
||||
if !ok {
|
||||
_ = flush()
|
||||
_ = flushWithRetry()
|
||||
return
|
||||
}
|
||||
if len(b) == 0 {
|
||||
continue
|
||||
}
|
||||
if buf.Len()+len(b) > maxBatchBytes {
|
||||
if !flush() {
|
||||
if err := flushWithRetry(); err != nil {
|
||||
_ = c.Close()
|
||||
return
|
||||
}
|
||||
@ -603,20 +668,20 @@ func (c *streamSplitConn) pushLoop() {
|
||||
}
|
||||
_, _ = buf.Write(b)
|
||||
if buf.Len() >= maxBatchBytes {
|
||||
if !flush() {
|
||||
if err := flushWithRetry(); err != nil {
|
||||
_ = c.Close()
|
||||
return
|
||||
}
|
||||
resetTimer()
|
||||
}
|
||||
case <-timer.C:
|
||||
if !flush() {
|
||||
if err := flushWithRetry(); err != nil {
|
||||
_ = c.Close()
|
||||
return
|
||||
}
|
||||
resetTimer()
|
||||
case <-c.closed:
|
||||
_ = flush()
|
||||
_ = flushWithRetry()
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -636,16 +701,72 @@ type pollConn struct {
|
||||
|
||||
mu sync.Mutex
|
||||
readBuf []byte
|
||||
closeErr error
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func isDialError(err error) bool {
|
||||
var urlErr *url.Error
|
||||
if errors.As(err, &urlErr) {
|
||||
return isDialError(urlErr.Err)
|
||||
}
|
||||
var opErr *net.OpError
|
||||
if errors.As(err, &opErr) {
|
||||
if opErr.Op == "dial" || opErr.Op == "connect" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *pollConn) closeWithError(err error) error {
|
||||
c.mu.Lock()
|
||||
select {
|
||||
case <-c.closed:
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
default:
|
||||
if err == nil {
|
||||
err = io.ErrClosedPipe
|
||||
}
|
||||
if c.closeErr == nil {
|
||||
c.closeErr = err
|
||||
}
|
||||
close(c.closed)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
// Best-effort session close signal (avoid leaking server-side sessions).
|
||||
req, reqErr := http.NewRequest(http.MethodPost, c.closeURL, nil)
|
||||
if reqErr == nil {
|
||||
req.Host = c.headerHost
|
||||
req.Header.Set("X-Sudoku-Tunnel", string(TunnelModePoll))
|
||||
req.Header.Set("X-Sudoku-Version", "1")
|
||||
if resp, doErr := c.client.Do(req); doErr == nil && resp != nil {
|
||||
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024))
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *pollConn) closedErr() error {
|
||||
c.mu.Lock()
|
||||
err := c.closeErr
|
||||
c.mu.Unlock()
|
||||
if err == nil {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *pollConn) Read(b []byte) (n int, err error) {
|
||||
if len(c.readBuf) == 0 {
|
||||
select {
|
||||
case c.readBuf = <-c.rxc:
|
||||
case <-c.closed:
|
||||
return 0, io.ErrClosedPipe
|
||||
return 0, c.closedErr()
|
||||
}
|
||||
}
|
||||
n = copy(b, c.readBuf)
|
||||
@ -661,7 +782,7 @@ func (c *pollConn) Write(b []byte) (n int, err error) {
|
||||
select {
|
||||
case <-c.closed:
|
||||
c.mu.Unlock()
|
||||
return 0, io.ErrClosedPipe
|
||||
return 0, c.closedErr()
|
||||
default:
|
||||
}
|
||||
c.mu.Unlock()
|
||||
@ -672,33 +793,12 @@ func (c *pollConn) Write(b []byte) (n int, err error) {
|
||||
case c.writeCh <- payload:
|
||||
return len(b), nil
|
||||
case <-c.closed:
|
||||
return 0, io.ErrClosedPipe
|
||||
return 0, c.closedErr()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *pollConn) Close() error {
|
||||
c.mu.Lock()
|
||||
select {
|
||||
case <-c.closed:
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
default:
|
||||
close(c.closed)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
close(c.writeCh)
|
||||
|
||||
// Best-effort session close signal (avoid leaking server-side sessions).
|
||||
req, err := http.NewRequest(http.MethodPost, c.closeURL, nil)
|
||||
if err == nil {
|
||||
req.Host = c.headerHost
|
||||
req.Header.Set("X-Sudoku-Tunnel", string(TunnelModePoll))
|
||||
req.Header.Set("X-Sudoku-Version", "1")
|
||||
_, _ = c.client.Do(req)
|
||||
}
|
||||
|
||||
return nil
|
||||
return c.closeWithError(io.ErrClosedPipe)
|
||||
}
|
||||
|
||||
func (c *pollConn) LocalAddr() net.Addr { return c.localAddr }
|
||||
@ -766,6 +866,15 @@ func dialPoll(ctx context.Context, serverAddress string, opts TunnelDialOptions)
|
||||
}
|
||||
|
||||
func (c *pollConn) pullLoop() {
|
||||
const (
|
||||
maxDialRetry = 12
|
||||
minBackoff = 10 * time.Millisecond
|
||||
maxBackoff = 250 * time.Millisecond
|
||||
)
|
||||
var (
|
||||
dialRetry int
|
||||
backoff = minBackoff
|
||||
)
|
||||
for {
|
||||
select {
|
||||
case <-c.closed:
|
||||
@ -783,13 +892,28 @@ func (c *pollConn) pullLoop() {
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
_ = c.Close()
|
||||
if isDialError(err) && dialRetry < maxDialRetry {
|
||||
dialRetry++
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-c.closed:
|
||||
return
|
||||
}
|
||||
backoff *= 2
|
||||
if backoff > maxBackoff {
|
||||
backoff = maxBackoff
|
||||
}
|
||||
continue
|
||||
}
|
||||
_ = c.closeWithError(fmt.Errorf("poll pull request failed: %w", err))
|
||||
return
|
||||
}
|
||||
dialRetry = 0
|
||||
backoff = minBackoff
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
_ = resp.Body.Close()
|
||||
_ = c.Close()
|
||||
_ = c.closeWithError(fmt.Errorf("poll pull bad status: %s", resp.Status))
|
||||
return
|
||||
}
|
||||
|
||||
@ -802,7 +926,7 @@ func (c *pollConn) pullLoop() {
|
||||
payload, err := base64.StdEncoding.DecodeString(line)
|
||||
if err != nil {
|
||||
_ = resp.Body.Close()
|
||||
_ = c.Close()
|
||||
_ = c.closeWithError(fmt.Errorf("poll pull decode failed: %w", err))
|
||||
return
|
||||
}
|
||||
select {
|
||||
@ -814,7 +938,7 @@ func (c *pollConn) pullLoop() {
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = c.Close()
|
||||
_ = c.closeWithError(fmt.Errorf("poll pull scan failed: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -825,6 +949,9 @@ func (c *pollConn) pushLoop() {
|
||||
maxBatchBytes = 64 * 1024
|
||||
flushInterval = 5 * time.Millisecond
|
||||
maxLineRawBytes = 16 * 1024
|
||||
maxDialRetry = 12
|
||||
minBackoff = 10 * time.Millisecond
|
||||
maxBackoff = 250 * time.Millisecond
|
||||
)
|
||||
|
||||
var (
|
||||
@ -834,14 +961,14 @@ func (c *pollConn) pushLoop() {
|
||||
)
|
||||
defer timer.Stop()
|
||||
|
||||
flush := func() bool {
|
||||
flush := func() error {
|
||||
if buf.Len() == 0 {
|
||||
return true
|
||||
return nil
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, c.pushURL, bytes.NewReader(buf.Bytes()))
|
||||
if err != nil {
|
||||
return false
|
||||
return err
|
||||
}
|
||||
req.Host = c.headerHost
|
||||
applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll)
|
||||
@ -849,17 +976,41 @@ func (c *pollConn) pushLoop() {
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return false
|
||||
return err
|
||||
}
|
||||
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024))
|
||||
_ = resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return false
|
||||
return fmt.Errorf("bad status: %s", resp.Status)
|
||||
}
|
||||
|
||||
buf.Reset()
|
||||
pendingRaw = 0
|
||||
return true
|
||||
return nil
|
||||
}
|
||||
|
||||
flushWithRetry := func() error {
|
||||
dialRetry := 0
|
||||
backoff := minBackoff
|
||||
for {
|
||||
if err := flush(); err == nil {
|
||||
return nil
|
||||
} else if isDialError(err) && dialRetry < maxDialRetry {
|
||||
dialRetry++
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-c.closed:
|
||||
return c.closedErr()
|
||||
}
|
||||
backoff *= 2
|
||||
if backoff > maxBackoff {
|
||||
backoff = maxBackoff
|
||||
}
|
||||
continue
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resetTimer := func() {
|
||||
@ -878,7 +1029,7 @@ func (c *pollConn) pushLoop() {
|
||||
select {
|
||||
case b, ok := <-c.writeCh:
|
||||
if !ok {
|
||||
_ = flush()
|
||||
_ = flushWithRetry()
|
||||
return
|
||||
}
|
||||
if len(b) == 0 {
|
||||
@ -895,8 +1046,8 @@ func (c *pollConn) pushLoop() {
|
||||
|
||||
encLen := base64.StdEncoding.EncodedLen(len(chunk))
|
||||
if pendingRaw+len(chunk) > maxBatchBytes || buf.Len()+encLen+1 > maxBatchBytes*2 {
|
||||
if !flush() {
|
||||
_ = c.Close()
|
||||
if err := flushWithRetry(); err != nil {
|
||||
_ = c.closeWithError(fmt.Errorf("poll push flush failed: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -909,20 +1060,20 @@ func (c *pollConn) pushLoop() {
|
||||
}
|
||||
|
||||
if pendingRaw >= maxBatchBytes {
|
||||
if !flush() {
|
||||
_ = c.Close()
|
||||
if err := flushWithRetry(); err != nil {
|
||||
_ = c.closeWithError(fmt.Errorf("poll push flush failed: %w", err))
|
||||
return
|
||||
}
|
||||
resetTimer()
|
||||
}
|
||||
case <-timer.C:
|
||||
if !flush() {
|
||||
_ = c.Close()
|
||||
if err := flushWithRetry(); err != nil {
|
||||
_ = c.closeWithError(fmt.Errorf("poll push flush failed: %w", err))
|
||||
return
|
||||
}
|
||||
resetTimer()
|
||||
case <-c.closed:
|
||||
_ = flush()
|
||||
_ = flushWithRetry()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user