From 7daf37bc15b0a52117ecf92090a4815a796a0876 Mon Sep 17 00:00:00 2001 From: saba-futai <120904569+saba-futai@users.noreply.github.com> Date: Tue, 23 Dec 2025 23:08:38 +0800 Subject: [PATCH] feat: support `http-mask-mode`, `http-mask-tls` and `http-mask-host` for sudoku (#2456) --- adapter/outbound/sudoku.go | 53 +- docs/config.yaml | 9 +- go.mod | 4 +- go.sum | 2 - listener/config/sudoku.go | 2 + listener/inbound/sudoku.go | 4 + listener/inbound/sudoku_test.go | 25 + listener/sudoku/server.go | 30 +- transport/sudoku/config.go | 144 ++ transport/sudoku/crypto/aead.go | 130 ++ transport/sudoku/crypto/ed25519.go | 116 ++ transport/sudoku/features_test.go | 2 +- transport/sudoku/handshake.go | 27 +- transport/sudoku/handshake_test.go | 11 +- transport/sudoku/httpmask_strategy.go | 2 +- transport/sudoku/httpmask_tunnel.go | 88 ++ transport/sudoku/httpmask_tunnel_test.go | 445 ++++++ transport/sudoku/obfs/httpmask/masker.go | 246 +++ transport/sudoku/obfs/httpmask/tunnel.go | 1684 +++++++++++++++++++++ transport/sudoku/obfs/sudoku/conn.go | 212 +++ transport/sudoku/obfs/sudoku/grid.go | 46 + transport/sudoku/obfs/sudoku/layout.go | 204 +++ transport/sudoku/obfs/sudoku/packed.go | 332 ++++ transport/sudoku/obfs/sudoku/table.go | 153 ++ transport/sudoku/obfs/sudoku/table_set.go | 38 + transport/sudoku/obfs_writer.go | 2 +- transport/sudoku/table_probe.go | 26 +- transport/sudoku/tables.go | 2 +- 28 files changed, 3977 insertions(+), 62 deletions(-) create mode 100644 transport/sudoku/config.go create mode 100644 transport/sudoku/crypto/aead.go create mode 100644 transport/sudoku/crypto/ed25519.go create mode 100644 transport/sudoku/httpmask_tunnel.go create mode 100644 transport/sudoku/httpmask_tunnel_test.go create mode 100644 transport/sudoku/obfs/httpmask/masker.go create mode 100644 transport/sudoku/obfs/httpmask/tunnel.go create mode 100644 transport/sudoku/obfs/sudoku/conn.go create mode 100644 transport/sudoku/obfs/sudoku/grid.go create mode 100644 transport/sudoku/obfs/sudoku/layout.go create mode 100644 transport/sudoku/obfs/sudoku/packed.go create mode 100644 transport/sudoku/obfs/sudoku/table.go create mode 100644 transport/sudoku/obfs/sudoku/table_set.go diff --git a/adapter/outbound/sudoku.go b/adapter/outbound/sudoku.go index bd393ec6..b1063f04 100644 --- a/adapter/outbound/sudoku.go +++ b/adapter/outbound/sudoku.go @@ -30,6 +30,9 @@ type SudokuOption struct { TableType string `proxy:"table-type,omitempty"` // "prefer_ascii" or "prefer_entropy" EnablePureDownlink *bool `proxy:"enable-pure-downlink,omitempty"` HTTPMask bool `proxy:"http-mask,omitempty"` + HTTPMaskMode string `proxy:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto" + HTTPMaskTLS bool `proxy:"http-mask-tls,omitempty"` // only for http-mask-mode stream/poll/auto + HTTPMaskHost string `proxy:"http-mask-host,omitempty"` // optional Host/SNI override (domain or domain:port) HTTPMaskStrategy string `proxy:"http-mask-strategy,omitempty"` // "random" (default), "post", "websocket" CustomTable string `proxy:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv CustomTables []string `proxy:"custom-tables,omitempty"` // optional table rotation patterns, overrides custom-table when non-empty @@ -42,7 +45,16 @@ func (s *Sudoku) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Con return nil, err } - c, err := s.dialer.DialContext(ctx, "tcp", s.addr) + var c net.Conn + if !cfg.DisableHTTPMask { + switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) { + case "stream", "poll", "auto": + c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, s.dialer.DialContext) + } + } + if c == nil && err == nil { + c, err = s.dialer.DialContext(ctx, "tcp", s.addr) + } if err != nil { return nil, fmt.Errorf("%s connect error: %w", s.addr, err) } @@ -56,9 +68,14 @@ func (s *Sudoku) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Con defer done(&err) } - c, err = sudoku.ClientHandshakeWithOptions(c, cfg, sudoku.ClientHandshakeOptions{ - HTTPMaskStrategy: s.option.HTTPMaskStrategy, - }) + handshakeCfg := *cfg + if !handshakeCfg.DisableHTTPMask { + switch strings.ToLower(strings.TrimSpace(handshakeCfg.HTTPMaskMode)) { + case "stream", "poll", "auto": + handshakeCfg.DisableHTTPMask = true + } + } + c, err = sudoku.ClientHandshakeWithOptions(c, &handshakeCfg, sudoku.ClientHandshakeOptions{HTTPMaskStrategy: s.option.HTTPMaskStrategy}) if err != nil { return nil, err } @@ -87,7 +104,16 @@ func (s *Sudoku) ListenPacketContext(ctx context.Context, metadata *C.Metadata) return nil, err } - c, err := s.dialer.DialContext(ctx, "tcp", s.addr) + var c net.Conn + if !cfg.DisableHTTPMask { + switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) { + case "stream", "poll", "auto": + c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, s.dialer.DialContext) + } + } + if c == nil && err == nil { + c, err = s.dialer.DialContext(ctx, "tcp", s.addr) + } if err != nil { return nil, fmt.Errorf("%s connect error: %w", s.addr, err) } @@ -101,9 +127,14 @@ func (s *Sudoku) ListenPacketContext(ctx context.Context, metadata *C.Metadata) defer done(&err) } - c, err = sudoku.ClientHandshakeWithOptions(c, cfg, sudoku.ClientHandshakeOptions{ - HTTPMaskStrategy: s.option.HTTPMaskStrategy, - }) + handshakeCfg := *cfg + if !handshakeCfg.DisableHTTPMask { + switch strings.ToLower(strings.TrimSpace(handshakeCfg.HTTPMaskMode)) { + case "stream", "poll", "auto": + handshakeCfg.DisableHTTPMask = true + } + } + c, err = sudoku.ClientHandshakeWithOptions(c, &handshakeCfg, sudoku.ClientHandshakeOptions{HTTPMaskStrategy: s.option.HTTPMaskStrategy}) if err != nil { return nil, err } @@ -190,6 +221,12 @@ func NewSudoku(option SudokuOption) (*Sudoku, error) { EnablePureDownlink: enablePureDownlink, HandshakeTimeoutSeconds: defaultConf.HandshakeTimeoutSeconds, DisableHTTPMask: !option.HTTPMask, + HTTPMaskMode: defaultConf.HTTPMaskMode, + HTTPMaskTLSEnabled: option.HTTPMaskTLS, + HTTPMaskHost: option.HTTPMaskHost, + } + if option.HTTPMaskMode != "" { + baseConf.HTTPMaskMode = option.HTTPMaskMode } tables, err := sudoku.NewTablesWithCustomPatterns(sudoku.ClientAEADSeed(option.Key), tableType, option.CustomTable, option.CustomTables) if err != nil { diff --git a/docs/config.yaml b/docs/config.yaml index 04d15bd2..a350d1af 100644 --- a/docs/config.yaml +++ b/docs/config.yaml @@ -1041,7 +1041,7 @@ proxies: # socks5 # sudoku - name: sudoku type: sudoku - server: serverip # 1.2.3.4 + server: server_ip/domain # 1.2.3.4 or domain port: 443 key: "" # 如果你使用sudoku生成的ED25519密钥对,请填写密钥对中的私钥,否则填入和服务端相同的uuid aead-method: chacha20-poly1305 # 可选值:chacha20-poly1305、aes-128-gcm、none 我们保证在none的情况下sudoku混淆层仍然确保安全 @@ -1051,7 +1051,10 @@ proxies: # socks5 # custom-table: xpxvvpvv # 可选,自定义字节布局,必须包含2个x、2个p、4个v,可随意组合。启用此处则需配置`table-type`为`prefer_entropy` # custom-tables: ["xpxvvpvv", "vxpvxvvp"] # 可选,自定义字节布局列表(x/v/p),用于 xvp 模式轮换;非空时覆盖 custom-table http-mask: true # 是否启用http掩码 - # http-mask-strategy: random # 可选:random(默认)、post、websocket;仅在 http-mask=true 时生效 + # http-mask-mode: legacy # 可选:legacy(默认)、stream、poll、auto;stream/poll/auto 支持走 CDN/反代 + # http-mask-tls: true # 可选:仅在 http-mask-mode 为 stream/poll/auto 时生效;true 强制 https;false 强制 http(不会根据端口自动推断) + # http-mask-host: "" # 可选:覆盖 Host/SNI(支持 example.com 或 example.com:443);仅在 http-mask-mode 为 stream/poll/auto 时生效 + # http-mask-strategy: random # 可选:random(默认)、post、websocket;仅 legacy 下生效 enable-pure-downlink: false # 是否启用混淆下行,false的情况下能在保证数据安全的前提下极大提升下行速度,与服务端端保持相同(如果此处为false,则要求aead不可为none) # anytls @@ -1596,6 +1599,8 @@ listeners: # custom-tables: ["xpxvvpvv", "vxpvxvvp"] # 可选,自定义字节布局列表(x/v/p),用于 xvp 模式轮换;非空时覆盖 custom-table handshake-timeout: 5 # optional enable-pure-downlink: false # 是否启用混淆下行,false的情况下能在保证数据安全的前提下极大提升下行速度,与客户端保持相同(如果此处为false,则要求aead不可为none) + disable-http-mask: false # 可选:禁用 http 掩码/隧道(默认为 false) + # http-mask-mode: legacy # 可选:legacy(默认)、stream、poll、auto;stream/poll/auto 支持走 CDN/反代 diff --git a/go.mod b/go.mod index bb558122..77d59a9a 100644 --- a/go.mod +++ b/go.mod @@ -46,7 +46,6 @@ require ( github.com/mroth/weightedrand/v2 v2.1.0 github.com/openacid/low v0.1.21 github.com/oschwald/maxminddb-golang v1.12.0 // lastest version compatible with golang1.20 - github.com/saba-futai/sudoku v0.0.2-d github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a github.com/samber/lo v1.52.0 github.com/sirupsen/logrus v1.9.3 @@ -65,8 +64,9 @@ require ( gopkg.in/yaml.v3 v3.0.1 ) +require filippo.io/edwards25519 v1.1.0 + require ( - filippo.io/edwards25519 v1.1.0 // indirect github.com/RyuaNerin/go-krypto v1.3.0 // indirect github.com/Yawning/aez v0.0.0-20211027044916-e49e68abd344 // indirect github.com/ajg/form v1.5.1 // indirect diff --git a/go.sum b/go.sum index 4756c83b..ac72c1dc 100644 --- a/go.sum +++ b/go.sum @@ -170,8 +170,6 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g= -github.com/saba-futai/sudoku v0.0.2-d h1:HW/gIyNUFcDchpMN+ZhluM86U/HGkWkkRV+9Km6WZM8= -github.com/saba-futai/sudoku v0.0.2-d/go.mod h1:Rvggsoprp7HQM7bMIZUd1M27bPj8THRsZdY1dGbIAvo= github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZNjr6sGeT00J8uU7JF4cNUdb44/Duis= github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM= github.com/samber/lo v1.52.0 h1:Rvi+3BFHES3A8meP33VPAxiBZX/Aws5RxrschYGjomw= diff --git a/listener/config/sudoku.go b/listener/config/sudoku.go index 848db875..118e252c 100644 --- a/listener/config/sudoku.go +++ b/listener/config/sudoku.go @@ -20,6 +20,8 @@ type SudokuServer struct { EnablePureDownlink *bool `json:"enable-pure-downlink,omitempty"` CustomTable string `json:"custom-table,omitempty"` CustomTables []string `json:"custom-tables,omitempty"` + DisableHTTPMask bool `json:"disable-http-mask,omitempty"` + HTTPMaskMode string `json:"http-mask-mode,omitempty"` // mihomo private extension (not the part of standard Sudoku protocol) MuxOption sing.MuxOption `json:"mux-option,omitempty"` diff --git a/listener/inbound/sudoku.go b/listener/inbound/sudoku.go index 43397602..fc37cb79 100644 --- a/listener/inbound/sudoku.go +++ b/listener/inbound/sudoku.go @@ -22,6 +22,8 @@ type SudokuOption struct { EnablePureDownlink *bool `inbound:"enable-pure-downlink,omitempty"` CustomTable string `inbound:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv CustomTables []string `inbound:"custom-tables,omitempty"` + DisableHTTPMask bool `inbound:"disable-http-mask,omitempty"` + HTTPMaskMode string `inbound:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto" // mihomo private extension (not the part of standard Sudoku protocol) MuxOption MuxOption `inbound:"mux-option,omitempty"` @@ -59,6 +61,8 @@ func NewSudoku(options *SudokuOption) (*Sudoku, error) { EnablePureDownlink: options.EnablePureDownlink, CustomTable: options.CustomTable, CustomTables: options.CustomTables, + DisableHTTPMask: options.DisableHTTPMask, + HTTPMaskMode: options.HTTPMaskMode, } serverConf.MuxOption = options.MuxOption.Build() diff --git a/listener/inbound/sudoku_test.go b/listener/inbound/sudoku_test.go index 6ba9e63b..5596bf91 100644 --- a/listener/inbound/sudoku_test.go +++ b/listener/inbound/sudoku_test.go @@ -2,6 +2,7 @@ package inbound_test import ( "net/netip" + "runtime" "testing" "github.com/metacubex/mihomo/adapter/outbound" @@ -164,3 +165,27 @@ func TestInboundSudoku_CustomTable(t *testing.T) { testInboundSudoku(t, inboundOptions, outboundOptions) }) } + +func TestInboundSudoku_HTTPMaskMode(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("temporarily skipped on windows due to intermittent failures; tracked in PR") + } + + key := "test_key_http_mask_mode" + + for _, mode := range []string{"legacy", "stream", "poll", "auto"} { + mode := mode + t.Run(mode, func(t *testing.T) { + inboundOptions := inbound.SudokuOption{ + Key: key, + HTTPMaskMode: mode, + } + outboundOptions := outbound.SudokuOption{ + Key: key, + HTTPMask: true, + HTTPMaskMode: mode, + } + testInboundSudoku(t, inboundOptions, outboundOptions) + }) + } +} diff --git a/listener/sudoku/server.go b/listener/sudoku/server.go index e90e231c..7652783f 100644 --- a/listener/sudoku/server.go +++ b/listener/sudoku/server.go @@ -20,6 +20,7 @@ type Listener struct { addr string closed bool protoConf sudoku.ProtocolConfig + tunnelSrv *sudoku.HTTPMaskTunnelServer handler *sing.ListenerHandler } @@ -46,9 +47,31 @@ func (l *Listener) Close() error { } func (l *Listener) handleConn(conn net.Conn, tunnel C.Tunnel, additions ...inbound.Addition) { - session, err := sudoku.ServerHandshake(conn, &l.protoConf) + handshakeConn := conn + handshakeCfg := &l.protoConf + if l.tunnelSrv != nil { + c, cfg, done, err := l.tunnelSrv.WrapConn(conn) + if err != nil { + _ = conn.Close() + return + } + if done { + return + } + if c != nil { + handshakeConn = c + } + if cfg != nil { + handshakeCfg = cfg + } + } + + session, err := sudoku.ServerHandshake(handshakeConn, handshakeCfg) if err != nil { - _ = conn.Close() + _ = handshakeConn.Close() + if handshakeConn != conn { + _ = conn.Close() + } return } @@ -184,6 +207,8 @@ func New(config LC.SudokuServer, tunnel C.Tunnel, additions ...inbound.Addition) PaddingMax: paddingMax, EnablePureDownlink: enablePureDownlink, HandshakeTimeoutSeconds: handshakeTimeout, + DisableHTTPMask: config.DisableHTTPMask, + HTTPMaskMode: config.HTTPMaskMode, } if len(tables) == 1 { protoConf.Table = tables[0] @@ -200,6 +225,7 @@ func New(config LC.SudokuServer, tunnel C.Tunnel, additions ...inbound.Addition) protoConf: protoConf, handler: h, } + sl.tunnelSrv = sudoku.NewHTTPMaskTunnelServer(&sl.protoConf) go func() { for { diff --git a/transport/sudoku/config.go b/transport/sudoku/config.go new file mode 100644 index 00000000..4fee6b67 --- /dev/null +++ b/transport/sudoku/config.go @@ -0,0 +1,144 @@ +package sudoku + +import ( + "fmt" + "strings" + + "github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku" +) + +// ProtocolConfig defines the configuration for the Sudoku protocol stack. +// It is intentionally kept close to the upstream Sudoku project to ensure wire compatibility. +type ProtocolConfig struct { + // Client-only: "host:port". + ServerAddress string + + // Pre-shared key (or ED25519 key material) used to derive crypto and tables. + Key string + + // "aes-128-gcm", "chacha20-poly1305", or "none". + AEADMethod string + + // Table is the single obfuscation table to use when table rotation is disabled. + Table *sudoku.Table + + // Tables is an optional candidate set for table rotation. + // If provided (len>0), the client will pick one table per connection and the server will + // probe the handshake to detect which one was used, keeping the handshake format unchanged. + // When Tables is set, Table may be nil. + Tables []*sudoku.Table + + // Padding insertion ratio (0-100). Must satisfy PaddingMax >= PaddingMin. + PaddingMin int + PaddingMax int + + // EnablePureDownlink toggles the bandwidth-optimized downlink mode. + EnablePureDownlink bool + + // Client-only: final target "host:port". + TargetAddress string + + // Server-side handshake timeout (seconds). + HandshakeTimeoutSeconds int + + // DisableHTTPMask disables all HTTP camouflage layers. + DisableHTTPMask bool + + // HTTPMaskMode controls how the HTTP layer behaves: + // - "legacy": write a fake HTTP/1.1 header then switch to raw stream (default, not CDN-compatible) + // - "stream": real HTTP tunnel (stream-one or split), CDN-compatible + // - "poll": plain HTTP tunnel (authorize/push/pull), strong restricted-network pass-through + // - "auto": try stream then fall back to poll + HTTPMaskMode string + + // HTTPMaskTLSEnabled enables HTTPS for HTTP tunnel modes (client-side). + // If false, the tunnel uses HTTP (no port-based inference). + HTTPMaskTLSEnabled bool + + // HTTPMaskHost optionally overrides the HTTP Host header / SNI host for HTTP tunnel modes (client-side). + HTTPMaskHost string +} + +func (c *ProtocolConfig) Validate() error { + if c.Table == nil && len(c.Tables) == 0 { + return fmt.Errorf("table cannot be nil (or provide tables)") + } + for i, t := range c.Tables { + if t == nil { + return fmt.Errorf("tables[%d] cannot be nil", i) + } + } + + if c.Key == "" { + return fmt.Errorf("key cannot be empty") + } + + switch c.AEADMethod { + case "aes-128-gcm", "chacha20-poly1305", "none": + default: + return fmt.Errorf("invalid aead-method: %s, must be one of: aes-128-gcm, chacha20-poly1305, none", c.AEADMethod) + } + + if c.PaddingMin < 0 || c.PaddingMin > 100 { + return fmt.Errorf("padding-min must be between 0 and 100, got %d", c.PaddingMin) + } + if c.PaddingMax < 0 || c.PaddingMax > 100 { + return fmt.Errorf("padding-max must be between 0 and 100, got %d", c.PaddingMax) + } + if c.PaddingMax < c.PaddingMin { + return fmt.Errorf("padding-max (%d) must be >= padding-min (%d)", c.PaddingMax, c.PaddingMin) + } + + if !c.EnablePureDownlink && c.AEADMethod == "none" { + return fmt.Errorf("bandwidth optimized downlink requires AEAD") + } + + if c.HandshakeTimeoutSeconds < 0 { + return fmt.Errorf("handshake-timeout must be >= 0, got %d", c.HandshakeTimeoutSeconds) + } + + switch strings.ToLower(strings.TrimSpace(c.HTTPMaskMode)) { + case "", "legacy", "stream", "poll", "auto": + default: + return fmt.Errorf("invalid http-mask-mode: %s, must be one of: legacy, stream, poll, auto", c.HTTPMaskMode) + } + + return nil +} + +func (c *ProtocolConfig) ValidateClient() error { + if err := c.Validate(); err != nil { + return err + } + if c.ServerAddress == "" { + return fmt.Errorf("server address cannot be empty") + } + if c.TargetAddress == "" { + return fmt.Errorf("target address cannot be empty") + } + return nil +} + +func DefaultConfig() *ProtocolConfig { + return &ProtocolConfig{ + AEADMethod: "chacha20-poly1305", + PaddingMin: 10, + PaddingMax: 30, + EnablePureDownlink: true, + HandshakeTimeoutSeconds: 5, + HTTPMaskMode: "legacy", + } +} + +func (c *ProtocolConfig) tableCandidates() []*sudoku.Table { + if c == nil { + return nil + } + if len(c.Tables) > 0 { + return c.Tables + } + if c.Table != nil { + return []*sudoku.Table{c.Table} + } + return nil +} diff --git a/transport/sudoku/crypto/aead.go b/transport/sudoku/crypto/aead.go new file mode 100644 index 00000000..b5f574d9 --- /dev/null +++ b/transport/sudoku/crypto/aead.go @@ -0,0 +1,130 @@ +package crypto + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + + "golang.org/x/crypto/chacha20poly1305" +) + +type AEADConn struct { + net.Conn + aead cipher.AEAD + readBuf bytes.Buffer + nonceSize int +} + +func NewAEADConn(c net.Conn, key string, method string) (*AEADConn, error) { + if method == "none" { + return &AEADConn{Conn: c, aead: nil}, nil + } + + h := sha256.New() + h.Write([]byte(key)) + keyBytes := h.Sum(nil) + + var aead cipher.AEAD + var err error + + switch method { + case "aes-128-gcm": + block, _ := aes.NewCipher(keyBytes[:16]) + aead, err = cipher.NewGCM(block) + case "chacha20-poly1305": + aead, err = chacha20poly1305.New(keyBytes) + default: + return nil, fmt.Errorf("unsupported cipher: %s", method) + } + if err != nil { + return nil, err + } + + return &AEADConn{ + Conn: c, + aead: aead, + nonceSize: aead.NonceSize(), + }, nil +} + +func (cc *AEADConn) Write(p []byte) (int, error) { + if cc.aead == nil { + return cc.Conn.Write(p) + } + + maxPayload := 65535 - cc.nonceSize - cc.aead.Overhead() + totalWritten := 0 + var frameBuf bytes.Buffer + header := make([]byte, 2) + nonce := make([]byte, cc.nonceSize) + + for len(p) > 0 { + chunkSize := len(p) + if chunkSize > maxPayload { + chunkSize = maxPayload + } + chunk := p[:chunkSize] + p = p[chunkSize:] + + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return totalWritten, err + } + + ciphertext := cc.aead.Seal(nil, nonce, chunk, nil) + frameLen := len(nonce) + len(ciphertext) + binary.BigEndian.PutUint16(header, uint16(frameLen)) + + frameBuf.Reset() + frameBuf.Write(header) + frameBuf.Write(nonce) + frameBuf.Write(ciphertext) + + if _, err := cc.Conn.Write(frameBuf.Bytes()); err != nil { + return totalWritten, err + } + totalWritten += chunkSize + } + return totalWritten, nil +} + +func (cc *AEADConn) Read(p []byte) (int, error) { + if cc.aead == nil { + return cc.Conn.Read(p) + } + + if cc.readBuf.Len() > 0 { + return cc.readBuf.Read(p) + } + + header := make([]byte, 2) + if _, err := io.ReadFull(cc.Conn, header); err != nil { + return 0, err + } + frameLen := int(binary.BigEndian.Uint16(header)) + + body := make([]byte, frameLen) + if _, err := io.ReadFull(cc.Conn, body); err != nil { + return 0, err + } + + if len(body) < cc.nonceSize { + return 0, errors.New("frame too short") + } + nonce := body[:cc.nonceSize] + ciphertext := body[cc.nonceSize:] + + plaintext, err := cc.aead.Open(nil, nonce, ciphertext, nil) + if err != nil { + return 0, errors.New("decryption failed") + } + + cc.readBuf.Write(plaintext) + return cc.readBuf.Read(p) +} diff --git a/transport/sudoku/crypto/ed25519.go b/transport/sudoku/crypto/ed25519.go new file mode 100644 index 00000000..7a2d0a12 --- /dev/null +++ b/transport/sudoku/crypto/ed25519.go @@ -0,0 +1,116 @@ +package crypto + +import ( + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + + "filippo.io/edwards25519" +) + +// KeyPair holds the scalar private key and point public key +type KeyPair struct { + Private *edwards25519.Scalar + Public *edwards25519.Point +} + +// GenerateMasterKey generates a random master private key (scalar) and its public key (point) +func GenerateMasterKey() (*KeyPair, error) { + // 1. Generate random scalar x (32 bytes) + var seed [64]byte + if _, err := rand.Read(seed[:]); err != nil { + return nil, err + } + + x, err := edwards25519.NewScalar().SetUniformBytes(seed[:]) + if err != nil { + return nil, err + } + + // 2. Calculate Public Key P = x * G + P := new(edwards25519.Point).ScalarBaseMult(x) + + return &KeyPair{Private: x, Public: P}, nil +} + +// SplitPrivateKey takes a master private key x and returns a new random split key (r, k) +// such that x = r + k (mod L). +// Returns hex encoded string of r || k (64 bytes) +func SplitPrivateKey(x *edwards25519.Scalar) (string, error) { + // 1. Generate random r (32 bytes) + var seed [64]byte + if _, err := rand.Read(seed[:]); err != nil { + return "", err + } + r, err := edwards25519.NewScalar().SetUniformBytes(seed[:]) + if err != nil { + return "", err + } + + // 2. Calculate k = x - r (mod L) + k := new(edwards25519.Scalar).Subtract(x, r) + + // 3. Encode r and k + rBytes := r.Bytes() + kBytes := k.Bytes() + + full := make([]byte, 64) + copy(full[:32], rBytes) + copy(full[32:], kBytes) + + return hex.EncodeToString(full), nil +} + +// RecoverPublicKey takes a split private key (r, k) or a master private key (x) +// and returns the public key P. +// Input can be: +// - 32 bytes hex (Master Scalar x) +// - 64 bytes hex (Split Key r || k) +func RecoverPublicKey(keyHex string) (*edwards25519.Point, error) { + keyBytes, err := hex.DecodeString(keyHex) + if err != nil { + return nil, fmt.Errorf("invalid hex: %w", err) + } + + if len(keyBytes) == 32 { + // Master Key x + x, err := edwards25519.NewScalar().SetCanonicalBytes(keyBytes) + if err != nil { + return nil, fmt.Errorf("invalid scalar: %w", err) + } + return new(edwards25519.Point).ScalarBaseMult(x), nil + + } else if len(keyBytes) == 64 { + // Split Key r || k + rBytes := keyBytes[:32] + kBytes := keyBytes[32:] + + r, err := edwards25519.NewScalar().SetCanonicalBytes(rBytes) + if err != nil { + return nil, fmt.Errorf("invalid scalar r: %w", err) + } + k, err := edwards25519.NewScalar().SetCanonicalBytes(kBytes) + if err != nil { + return nil, fmt.Errorf("invalid scalar k: %w", err) + } + + // sum = r + k + sum := new(edwards25519.Scalar).Add(r, k) + + // P = sum * G + return new(edwards25519.Point).ScalarBaseMult(sum), nil + } + + return nil, errors.New("invalid key length: must be 32 bytes (Master) or 64 bytes (Split)") +} + +// EncodePoint returns the hex string of the compressed point +func EncodePoint(p *edwards25519.Point) string { + return hex.EncodeToString(p.Bytes()) +} + +// EncodeScalar returns the hex string of the scalar +func EncodeScalar(s *edwards25519.Scalar) string { + return hex.EncodeToString(s.Bytes()) +} diff --git a/transport/sudoku/features_test.go b/transport/sudoku/features_test.go index 8eb3aedd..470598ce 100644 --- a/transport/sudoku/features_test.go +++ b/transport/sudoku/features_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - sudokuobfs "github.com/saba-futai/sudoku/pkg/obfs/sudoku" + sudokuobfs "github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku" ) type discardConn struct{} diff --git a/transport/sudoku/handshake.go b/transport/sudoku/handshake.go index 989d2813..2a0437d6 100644 --- a/transport/sudoku/handshake.go +++ b/transport/sudoku/handshake.go @@ -11,18 +11,13 @@ import ( "strings" "time" - "github.com/saba-futai/sudoku/apis" - "github.com/saba-futai/sudoku/pkg/crypto" - "github.com/saba-futai/sudoku/pkg/obfs/httpmask" - "github.com/saba-futai/sudoku/pkg/obfs/sudoku" + "github.com/metacubex/mihomo/transport/sudoku/crypto" + "github.com/metacubex/mihomo/transport/sudoku/obfs/httpmask" + "github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku" "github.com/metacubex/mihomo/log" ) -type ProtocolConfig = apis.ProtocolConfig - -func DefaultConfig() *ProtocolConfig { return apis.DefaultConfig() } - type SessionType int const ( @@ -105,14 +100,14 @@ const ( downlinkModePacked byte = 0x02 ) -func downlinkMode(cfg *apis.ProtocolConfig) byte { +func downlinkMode(cfg *ProtocolConfig) byte { if cfg.EnablePureDownlink { return downlinkModePure } return downlinkModePacked } -func buildClientObfsConn(raw net.Conn, cfg *apis.ProtocolConfig, table *sudoku.Table) net.Conn { +func buildClientObfsConn(raw net.Conn, cfg *ProtocolConfig, table *sudoku.Table) net.Conn { baseReader := sudoku.NewConn(raw, table, cfg.PaddingMin, cfg.PaddingMax, false) baseWriter := newSudokuObfsWriter(raw, table, cfg.PaddingMin, cfg.PaddingMax) if cfg.EnablePureDownlink { @@ -130,7 +125,7 @@ func buildClientObfsConn(raw net.Conn, cfg *apis.ProtocolConfig, table *sudoku.T } } -func buildServerObfsConn(raw net.Conn, cfg *apis.ProtocolConfig, table *sudoku.Table, record bool) (*sudoku.Conn, net.Conn) { +func buildServerObfsConn(raw net.Conn, cfg *ProtocolConfig, table *sudoku.Table, record bool) (*sudoku.Conn, net.Conn) { uplink := sudoku.NewConn(raw, table, cfg.PaddingMin, cfg.PaddingMax, record) if cfg.EnablePureDownlink { downlink := &directionalConn{ @@ -189,12 +184,12 @@ type ClientHandshakeOptions struct { } // ClientHandshake performs the client-side Sudoku handshake (without sending target address). -func ClientHandshake(rawConn net.Conn, cfg *apis.ProtocolConfig) (net.Conn, error) { +func ClientHandshake(rawConn net.Conn, cfg *ProtocolConfig) (net.Conn, error) { return ClientHandshakeWithOptions(rawConn, cfg, ClientHandshakeOptions{}) } // ClientHandshakeWithOptions performs the client-side Sudoku handshake (without sending target address). -func ClientHandshakeWithOptions(rawConn net.Conn, cfg *apis.ProtocolConfig, opt ClientHandshakeOptions) (net.Conn, error) { +func ClientHandshakeWithOptions(rawConn net.Conn, cfg *ProtocolConfig, opt ClientHandshakeOptions) (net.Conn, error) { if cfg == nil { return nil, fmt.Errorf("config is required") } @@ -220,7 +215,7 @@ func ClientHandshakeWithOptions(rawConn net.Conn, cfg *apis.ProtocolConfig, opt } handshake := buildHandshakePayload(cfg.Key) - if len(tableCandidates(cfg)) > 1 { + if len(cfg.tableCandidates()) > 1 { handshake[15] = tableID } if _, err := cConn.Write(handshake[:]); err != nil { @@ -236,7 +231,7 @@ func ClientHandshakeWithOptions(rawConn net.Conn, cfg *apis.ProtocolConfig, opt } // ServerHandshake performs Sudoku server-side handshake and detects UoT preface. -func ServerHandshake(rawConn net.Conn, cfg *apis.ProtocolConfig) (*ServerSession, error) { +func ServerHandshake(rawConn net.Conn, cfg *ProtocolConfig) (*ServerSession, error) { if cfg == nil { return nil, fmt.Errorf("config is required") } @@ -260,7 +255,7 @@ func ServerHandshake(rawConn net.Conn, cfg *apis.ProtocolConfig) (*ServerSession } } - selectedTable, preRead, err := selectTableByProbe(bufReader, cfg, tableCandidates(cfg)) + selectedTable, preRead, err := selectTableByProbe(bufReader, cfg, cfg.tableCandidates()) if err != nil { return nil, err } diff --git a/transport/sudoku/handshake_test.go b/transport/sudoku/handshake_test.go index 5d9443df..b2f0999e 100644 --- a/transport/sudoku/handshake_test.go +++ b/transport/sudoku/handshake_test.go @@ -9,8 +9,7 @@ import ( "testing" "time" - "github.com/saba-futai/sudoku/apis" - sudokuobfs "github.com/saba-futai/sudoku/pkg/obfs/sudoku" + sudokuobfs "github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku" ) func TestPackedConnRoundTrip_WithPadding(t *testing.T) { @@ -67,8 +66,8 @@ func TestPackedConnRoundTrip_WithPadding(t *testing.T) { } } -func newPackedConfig(table *sudokuobfs.Table) *apis.ProtocolConfig { - cfg := apis.DefaultConfig() +func newPackedConfig(table *sudokuobfs.Table) *ProtocolConfig { + cfg := DefaultConfig() cfg.Key = "sudoku-test-key" cfg.Table = table cfg.PaddingMin = 10 @@ -118,7 +117,7 @@ func TestPackedDownlinkSoak(t *testing.T) { } } -func runPackedTCPSession(id int, cfg *apis.ProtocolConfig, errCh chan<- error) { +func runPackedTCPSession(id int, cfg *ProtocolConfig, errCh chan<- error) { serverConn, clientConn := net.Pipe() target := fmt.Sprintf("1.1.1.%d:80", (id%200)+1) payload := []byte{0x42, byte(id)} @@ -176,7 +175,7 @@ func runPackedTCPSession(id int, cfg *apis.ProtocolConfig, errCh chan<- error) { } } -func runPackedUoTSession(id int, cfg *apis.ProtocolConfig, errCh chan<- error) { +func runPackedUoTSession(id int, cfg *ProtocolConfig, errCh chan<- error) { serverConn, clientConn := net.Pipe() target := "8.8.8.8:53" payload := []byte{0xaa, byte(id)} diff --git a/transport/sudoku/httpmask_strategy.go b/transport/sudoku/httpmask_strategy.go index dc90991d..fa11b249 100644 --- a/transport/sudoku/httpmask_strategy.go +++ b/transport/sudoku/httpmask_strategy.go @@ -10,7 +10,7 @@ import ( "sync" "time" - "github.com/saba-futai/sudoku/pkg/obfs/httpmask" + "github.com/metacubex/mihomo/transport/sudoku/obfs/httpmask" ) var ( diff --git a/transport/sudoku/httpmask_tunnel.go b/transport/sudoku/httpmask_tunnel.go new file mode 100644 index 00000000..aeedfe15 --- /dev/null +++ b/transport/sudoku/httpmask_tunnel.go @@ -0,0 +1,88 @@ +package sudoku + +import ( + "context" + "fmt" + "net" + "strings" + + "github.com/metacubex/mihomo/transport/sudoku/obfs/httpmask" +) + +type HTTPMaskTunnelServer struct { + cfg *ProtocolConfig + ts *httpmask.TunnelServer +} + +func NewHTTPMaskTunnelServer(cfg *ProtocolConfig) *HTTPMaskTunnelServer { + if cfg == nil { + return &HTTPMaskTunnelServer{} + } + + var ts *httpmask.TunnelServer + if !cfg.DisableHTTPMask { + switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) { + case "stream", "poll", "auto": + ts = httpmask.NewTunnelServer(httpmask.TunnelServerOptions{Mode: cfg.HTTPMaskMode}) + } + } + return &HTTPMaskTunnelServer{cfg: cfg, ts: ts} +} + +// WrapConn inspects an accepted TCP connection and upgrades it to an HTTP tunnel stream when needed. +// +// Returns: +// - done=true: this TCP connection has been fully handled (e.g., stream/poll control request), caller should return +// - done=false: handshakeConn+cfg are ready for ServerHandshake +func (s *HTTPMaskTunnelServer) WrapConn(rawConn net.Conn) (handshakeConn net.Conn, cfg *ProtocolConfig, done bool, err error) { + if rawConn == nil { + return nil, nil, true, fmt.Errorf("nil conn") + } + if s == nil { + return rawConn, nil, false, nil + } + if s.ts == nil { + return rawConn, s.cfg, false, nil + } + + res, c, err := s.ts.HandleConn(rawConn) + if err != nil { + return nil, nil, true, err + } + + switch res { + case httpmask.HandleDone: + return nil, nil, true, nil + case httpmask.HandlePassThrough: + return c, s.cfg, false, nil + case httpmask.HandleStartTunnel: + inner := *s.cfg + inner.DisableHTTPMask = true + return c, &inner, false, nil + default: + return nil, nil, true, nil + } +} + +type TunnelDialer func(ctx context.Context, network, addr string) (net.Conn, error) + +// DialHTTPMaskTunnel dials a CDN-capable HTTP tunnel (stream/poll/auto) and returns a stream carrying raw Sudoku bytes. +func DialHTTPMaskTunnel(ctx context.Context, serverAddress string, cfg *ProtocolConfig, dial TunnelDialer) (net.Conn, error) { + if cfg == nil { + return nil, fmt.Errorf("config is required") + } + if cfg.DisableHTTPMask { + return nil, fmt.Errorf("http mask is disabled") + } + switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) { + case "stream", "poll", "auto": + default: + return nil, fmt.Errorf("http-mask-mode=%q does not use http tunnel", cfg.HTTPMaskMode) + } + return httpmask.DialTunnel(ctx, serverAddress, httpmask.TunnelDialOptions{ + Mode: cfg.HTTPMaskMode, + TLSEnabled: cfg.HTTPMaskTLSEnabled, + HostOverride: cfg.HTTPMaskHost, + DialContext: dial, + }) +} diff --git a/transport/sudoku/httpmask_tunnel_test.go b/transport/sudoku/httpmask_tunnel_test.go new file mode 100644 index 00000000..eab310f9 --- /dev/null +++ b/transport/sudoku/httpmask_tunnel_test.go @@ -0,0 +1,445 @@ +package sudoku + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "strings" + "sync" + "testing" + "time" +) + +func startTunnelServer(t *testing.T, cfg *ProtocolConfig, handle func(*ServerSession) error) (addr string, stop func(), errCh <-chan error) { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + + errC := make(chan error, 128) + done := make(chan struct{}) + + tunnelSrv := NewHTTPMaskTunnelServer(cfg) + var wg sync.WaitGroup + var stopOnce sync.Once + + wg.Add(1) + go func() { + defer wg.Done() + for { + c, err := ln.Accept() + if err != nil { + close(done) + return + } + wg.Add(1) + go func(conn net.Conn) { + defer wg.Done() + + handshakeConn, handshakeCfg, handled, err := tunnelSrv.WrapConn(conn) + if err != nil { + _ = conn.Close() + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + return + } + if err == io.EOF { + return + } + errC <- err + return + } + if handled { + return + } + if handshakeConn == nil || handshakeCfg == nil { + _ = conn.Close() + errC <- fmt.Errorf("wrap conn returned nil") + return + } + + session, err := ServerHandshake(handshakeConn, handshakeCfg) + if err != nil { + _ = handshakeConn.Close() + if handshakeConn != conn { + _ = conn.Close() + } + errC <- err + return + } + defer session.Conn.Close() + + if handleErr := handle(session); handleErr != nil { + errC <- handleErr + } + }(c) + } + }() + + stop = func() { + stopOnce.Do(func() { + _ = ln.Close() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatalf("server did not stop") + } + + ch := make(chan struct{}) + go func() { + wg.Wait() + close(ch) + }() + select { + case <-ch: + case <-time.After(10 * time.Second): + t.Fatalf("server goroutines did not exit") + } + close(errC) + }) + } + + return ln.Addr().String(), stop, errC +} + +func newTunnelTestTable(t *testing.T, key string) *ProtocolConfig { + t.Helper() + + tables, err := NewTablesWithCustomPatterns(ClientAEADSeed(key), "prefer_ascii", "", nil) + if err != nil { + t.Fatalf("build tables: %v", err) + } + if len(tables) != 1 { + t.Fatalf("unexpected tables: %d", len(tables)) + } + + cfg := DefaultConfig() + cfg.Key = key + cfg.AEADMethod = "chacha20-poly1305" + cfg.Table = tables[0] + cfg.PaddingMin = 0 + cfg.PaddingMax = 0 + cfg.HandshakeTimeoutSeconds = 5 + cfg.EnablePureDownlink = true + cfg.DisableHTTPMask = false + return cfg +} + +func TestHTTPMaskTunnel_Stream_TCPRoundTrip(t *testing.T) { + key := "tunnel-stream-key" + target := "1.1.1.1:80" + + serverCfg := newTunnelTestTable(t, key) + serverCfg.HTTPMaskMode = "stream" + + addr, stop, errCh := startTunnelServer(t, serverCfg, func(s *ServerSession) error { + if s.Type != SessionTypeTCP { + return fmt.Errorf("unexpected session type: %v", s.Type) + } + if s.Target != target { + return fmt.Errorf("target mismatch: %s", s.Target) + } + _, _ = s.Conn.Write([]byte("ok")) + return nil + }) + defer stop() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + clientCfg := *serverCfg + clientCfg.ServerAddress = addr + clientCfg.HTTPMaskHost = "example.com" + + tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext) + if err != nil { + t.Fatalf("dial tunnel: %v", err) + } + defer tunnelConn.Close() + + handshakeCfg := clientCfg + handshakeCfg.DisableHTTPMask = true + cConn, err := ClientHandshake(tunnelConn, &handshakeCfg) + if err != nil { + t.Fatalf("client handshake: %v", err) + } + defer cConn.Close() + + addrBuf, err := EncodeAddress(target) + if err != nil { + t.Fatalf("encode addr: %v", err) + } + if _, err := cConn.Write(addrBuf); err != nil { + t.Fatalf("write addr: %v", err) + } + + buf := make([]byte, 2) + if _, err := io.ReadFull(cConn, buf); err != nil { + t.Fatalf("read: %v", err) + } + if string(buf) != "ok" { + t.Fatalf("unexpected payload: %q", buf) + } + + stop() + for err := range errCh { + t.Fatalf("server error: %v", err) + } +} + +func TestHTTPMaskTunnel_Poll_UoTRoundTrip(t *testing.T) { + key := "tunnel-poll-key" + target := "8.8.8.8:53" + payload := []byte{0xaa, 0xbb, 0xcc, 0xdd} + + serverCfg := newTunnelTestTable(t, key) + serverCfg.HTTPMaskMode = "poll" + + addr, stop, errCh := startTunnelServer(t, serverCfg, func(s *ServerSession) error { + if s.Type != SessionTypeUoT { + return fmt.Errorf("unexpected session type: %v", s.Type) + } + gotAddr, gotPayload, err := ReadDatagram(s.Conn) + if err != nil { + return fmt.Errorf("server read datagram: %w", err) + } + if gotAddr != target { + return fmt.Errorf("uot target mismatch: %s", gotAddr) + } + if !bytes.Equal(gotPayload, payload) { + return fmt.Errorf("uot payload mismatch: %x", gotPayload) + } + if err := WriteDatagram(s.Conn, gotAddr, gotPayload); err != nil { + return fmt.Errorf("server write datagram: %w", err) + } + return nil + }) + defer stop() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + clientCfg := *serverCfg + clientCfg.ServerAddress = addr + + tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext) + if err != nil { + t.Fatalf("dial tunnel: %v", err) + } + defer tunnelConn.Close() + + handshakeCfg := clientCfg + handshakeCfg.DisableHTTPMask = true + cConn, err := ClientHandshake(tunnelConn, &handshakeCfg) + if err != nil { + t.Fatalf("client handshake: %v", err) + } + defer cConn.Close() + + if err := WritePreface(cConn); err != nil { + t.Fatalf("write preface: %v", err) + } + if err := WriteDatagram(cConn, target, payload); err != nil { + t.Fatalf("write datagram: %v", err) + } + gotAddr, gotPayload, err := ReadDatagram(cConn) + if err != nil { + t.Fatalf("read datagram: %v", err) + } + if gotAddr != target { + t.Fatalf("uot target mismatch: %s", gotAddr) + } + if !bytes.Equal(gotPayload, payload) { + t.Fatalf("uot payload mismatch: %x", gotPayload) + } + + stop() + for err := range errCh { + t.Fatalf("server error: %v", err) + } +} + +func TestHTTPMaskTunnel_Auto_TCPRoundTrip(t *testing.T) { + key := "tunnel-auto-key" + target := "9.9.9.9:443" + + serverCfg := newTunnelTestTable(t, key) + serverCfg.HTTPMaskMode = "auto" + + addr, stop, errCh := startTunnelServer(t, serverCfg, func(s *ServerSession) error { + if s.Type != SessionTypeTCP { + return fmt.Errorf("unexpected session type: %v", s.Type) + } + if s.Target != target { + return fmt.Errorf("target mismatch: %s", s.Target) + } + _, _ = s.Conn.Write([]byte("ok")) + return nil + }) + defer stop() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + clientCfg := *serverCfg + clientCfg.ServerAddress = addr + + tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext) + if err != nil { + t.Fatalf("dial tunnel: %v", err) + } + defer tunnelConn.Close() + + handshakeCfg := clientCfg + handshakeCfg.DisableHTTPMask = true + cConn, err := ClientHandshake(tunnelConn, &handshakeCfg) + if err != nil { + t.Fatalf("client handshake: %v", err) + } + defer cConn.Close() + + addrBuf, err := EncodeAddress(target) + if err != nil { + t.Fatalf("encode addr: %v", err) + } + if _, err := cConn.Write(addrBuf); err != nil { + t.Fatalf("write addr: %v", err) + } + + buf := make([]byte, 2) + if _, err := io.ReadFull(cConn, buf); err != nil { + t.Fatalf("read: %v", err) + } + if string(buf) != "ok" { + t.Fatalf("unexpected payload: %q", buf) + } + + stop() + for err := range errCh { + t.Fatalf("server error: %v", err) + } +} + +func TestHTTPMaskTunnel_Validation(t *testing.T) { + cfg := DefaultConfig() + cfg.Key = "k" + cfg.Table = NewTable("seed", "prefer_ascii") + cfg.ServerAddress = "127.0.0.1:1" + + cfg.DisableHTTPMask = true + cfg.HTTPMaskMode = "stream" + if _, err := DialHTTPMaskTunnel(context.Background(), cfg.ServerAddress, cfg, (&net.Dialer{}).DialContext); err == nil { + t.Fatalf("expected error for disabled http mask") + } + + cfg.DisableHTTPMask = false + cfg.HTTPMaskMode = "legacy" + if _, err := DialHTTPMaskTunnel(context.Background(), cfg.ServerAddress, cfg, (&net.Dialer{}).DialContext); err == nil { + t.Fatalf("expected error for legacy mode") + } +} + +func TestHTTPMaskTunnel_Soak_Concurrent(t *testing.T) { + key := "tunnel-soak-key" + target := "1.0.0.1:80" + + serverCfg := newTunnelTestTable(t, key) + serverCfg.HTTPMaskMode = "stream" + serverCfg.EnablePureDownlink = false + + const ( + sessions = 8 + payloadLen = 64 * 1024 + ) + + addr, stop, errCh := startTunnelServer(t, serverCfg, func(s *ServerSession) error { + if s.Type != SessionTypeTCP { + return fmt.Errorf("unexpected session type: %v", s.Type) + } + if s.Target != target { + return fmt.Errorf("target mismatch: %s", s.Target) + } + buf := make([]byte, payloadLen) + if _, err := io.ReadFull(s.Conn, buf); err != nil { + return fmt.Errorf("server read payload: %w", err) + } + _, err := s.Conn.Write(buf) + return err + }) + defer stop() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + var wg sync.WaitGroup + runErr := make(chan error, sessions) + + for i := 0; i < sessions; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + clientCfg := *serverCfg + clientCfg.ServerAddress = addr + clientCfg.HTTPMaskHost = strings.TrimSpace(clientCfg.HTTPMaskHost) + + tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext) + if err != nil { + runErr <- fmt.Errorf("dial: %w", err) + return + } + defer tunnelConn.Close() + + handshakeCfg := clientCfg + handshakeCfg.DisableHTTPMask = true + cConn, err := ClientHandshake(tunnelConn, &handshakeCfg) + if err != nil { + runErr <- fmt.Errorf("handshake: %w", err) + return + } + defer cConn.Close() + + addrBuf, err := EncodeAddress(target) + if err != nil { + runErr <- fmt.Errorf("encode addr: %w", err) + return + } + if _, err := cConn.Write(addrBuf); err != nil { + runErr <- fmt.Errorf("write addr: %w", err) + return + } + + payload := bytes.Repeat([]byte{byte(id)}, payloadLen) + if _, err := cConn.Write(payload); err != nil { + runErr <- fmt.Errorf("write payload: %w", err) + return + } + echo := make([]byte, payloadLen) + if _, err := io.ReadFull(cConn, echo); err != nil { + runErr <- fmt.Errorf("read echo: %w", err) + return + } + if !bytes.Equal(echo, payload) { + runErr <- fmt.Errorf("echo mismatch") + return + } + runErr <- nil + }(i) + } + + wg.Wait() + close(runErr) + + for err := range runErr { + if err != nil { + t.Fatalf("soak: %v", err) + } + } + + stop() + for err := range errCh { + t.Fatalf("server error: %v", err) + } +} diff --git a/transport/sudoku/obfs/httpmask/masker.go b/transport/sudoku/obfs/httpmask/masker.go new file mode 100644 index 00000000..540a8911 --- /dev/null +++ b/transport/sudoku/obfs/httpmask/masker.go @@ -0,0 +1,246 @@ +package httpmask + +import ( + "bufio" + "bytes" + "encoding/base64" + "fmt" + "io" + "math/rand" + "net" + "strconv" + "strings" + "sync" + "time" +) + +var ( + userAgents = []string{ + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.0.0 Safari/537.36", + "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:121.0) Gecko/20100101 Firefox/121.0", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.2 Safari/605.1.15", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 14_2_1) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.2 Safari/605.1.15", + "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + "Mozilla/5.0 (iPhone; CPU iPhone OS 17_2 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.2 Mobile/15E148 Safari/604.1", + "Mozilla/5.0 (Linux; Android 14; Pixel 7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.0.0 Mobile Safari/537.36", + } + accepts = []string{ + "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", + "application/json, text/plain, */*", + "application/octet-stream", + "*/*", + } + acceptLanguages = []string{ + "en-US,en;q=0.9", + "en-GB,en;q=0.9", + "zh-CN,zh;q=0.9,en-US;q=0.8,en;q=0.7", + "ja-JP,ja;q=0.9,en-US;q=0.8,en;q=0.7", + "de-DE,de;q=0.9,en-US;q=0.8,en;q=0.7", + } + acceptEncodings = []string{ + "gzip, deflate, br", + "gzip, deflate", + "br, gzip, deflate", + } + paths = []string{ + "/api/v1/upload", + "/data/sync", + "/uploads/raw", + "/api/report", + "/feed/update", + "/v2/events", + "/v1/telemetry", + "/session", + "/stream", + "/ws", + } + contentTypes = []string{ + "application/octet-stream", + "application/x-protobuf", + "application/json", + } +) + +var ( + rngPool = sync.Pool{ + New: func() interface{} { + return rand.New(rand.NewSource(time.Now().UnixNano())) + }, + } + headerBufPool = sync.Pool{ + New: func() interface{} { + b := make([]byte, 0, 1024) + return &b + }, + } +) + +// LooksLikeHTTPRequestStart reports whether peek4 looks like a supported HTTP/1.x request method prefix. +func LooksLikeHTTPRequestStart(peek4 []byte) bool { + if len(peek4) < 4 { + return false + } + // Common methods: "GET ", "POST", "HEAD", "PUT ", "OPTI" (OPTIONS), "PATC" (PATCH), "DELE" (DELETE) + return bytes.Equal(peek4, []byte("GET ")) || + bytes.Equal(peek4, []byte("POST")) || + bytes.Equal(peek4, []byte("HEAD")) || + bytes.Equal(peek4, []byte("PUT ")) || + bytes.Equal(peek4, []byte("OPTI")) || + bytes.Equal(peek4, []byte("PATC")) || + bytes.Equal(peek4, []byte("DELE")) +} + +func trimPortForHost(host string) string { + if host == "" { + return host + } + // Accept "example.com:443" / "1.2.3.4:443" / "[::1]:443" + h, _, err := net.SplitHostPort(host) + if err == nil && h != "" { + return h + } + // If it's not in host:port form, keep as-is. + return host +} + +func appendCommonHeaders(buf []byte, host string, r *rand.Rand) []byte { + ua := userAgents[r.Intn(len(userAgents))] + accept := accepts[r.Intn(len(accepts))] + lang := acceptLanguages[r.Intn(len(acceptLanguages))] + enc := acceptEncodings[r.Intn(len(acceptEncodings))] + + buf = append(buf, "Host: "...) + buf = append(buf, host...) + buf = append(buf, "\r\nUser-Agent: "...) + buf = append(buf, ua...) + buf = append(buf, "\r\nAccept: "...) + buf = append(buf, accept...) + buf = append(buf, "\r\nAccept-Language: "...) + buf = append(buf, lang...) + buf = append(buf, "\r\nAccept-Encoding: "...) + buf = append(buf, enc...) + buf = append(buf, "\r\nConnection: keep-alive\r\n"...) + + // A couple of common cache headers; keep them static for simplicity. + buf = append(buf, "Cache-Control: no-cache\r\nPragma: no-cache\r\n"...) + return buf +} + +// WriteRandomRequestHeader writes a plausible HTTP/1.1 request header as a mask. +func WriteRandomRequestHeader(w io.Writer, host string) error { + // Get RNG from pool + r := rngPool.Get().(*rand.Rand) + defer rngPool.Put(r) + + path := paths[r.Intn(len(paths))] + ctype := contentTypes[r.Intn(len(contentTypes))] + + // Use buffer pool + bufPtr := headerBufPool.Get().(*[]byte) + buf := *bufPtr + buf = buf[:0] + defer func() { + if cap(buf) <= 4096 { + *bufPtr = buf + headerBufPool.Put(bufPtr) + } + }() + + // Weighted template selection. Keep a conservative default (POST w/ Content-Length), + // but occasionally rotate to other realistic templates (e.g. WebSocket upgrade). + switch r.Intn(10) { + case 0, 1: // ~20% WebSocket-like upgrade + hostNoPort := trimPortForHost(host) + var keyBytes [16]byte + for i := 0; i < len(keyBytes); i++ { + keyBytes[i] = byte(r.Intn(256)) + } + wsKey := base64.StdEncoding.EncodeToString(keyBytes[:]) + + buf = append(buf, "GET "...) + buf = append(buf, path...) + buf = append(buf, " HTTP/1.1\r\n"...) + buf = appendCommonHeaders(buf, host, r) + buf = append(buf, "Upgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: "...) + buf = append(buf, wsKey...) + buf = append(buf, "\r\nOrigin: https://"...) + buf = append(buf, hostNoPort...) + buf = append(buf, "\r\n\r\n"...) + default: // ~80% POST upload + // Random Content-Length: 4KB–10MB. Small enough to look plausible, large enough + // to justify long-lived writes on keep-alive connections. + const minCL = int64(4 * 1024) + const maxCL = int64(10 * 1024 * 1024) + contentLength := minCL + r.Int63n(maxCL-minCL+1) + + buf = append(buf, "POST "...) + buf = append(buf, path...) + buf = append(buf, " HTTP/1.1\r\n"...) + buf = appendCommonHeaders(buf, host, r) + buf = append(buf, "Content-Type: "...) + buf = append(buf, ctype...) + buf = append(buf, "\r\nContent-Length: "...) + buf = strconv.AppendInt(buf, contentLength, 10) + // A couple of extra headers seen in real clients. + if r.Intn(2) == 0 { + buf = append(buf, "\r\nX-Requested-With: XMLHttpRequest"...) + } + if r.Intn(3) == 0 { + buf = append(buf, "\r\nReferer: https://"...) + buf = append(buf, trimPortForHost(host)...) + buf = append(buf, "/"...) + } + buf = append(buf, "\r\n\r\n"...) + } + + _, err := w.Write(buf) + return err +} + +// ConsumeHeader 读取并消耗 HTTP 头部,返回消耗的数据和剩余的 reader 数据 +// 如果不是 POST 请求或格式严重错误,返回 error +func ConsumeHeader(r *bufio.Reader) ([]byte, error) { + var consumed bytes.Buffer + + // 1. 读取请求行 + // Use ReadSlice to avoid allocation if line fits in buffer + line, err := r.ReadSlice('\n') + if err != nil { + return nil, err + } + consumed.Write(line) + + // Basic method validation: accept common HTTP/1.x methods used by our masker. + // Keep it strict enough to reject obvious garbage. + switch { + case bytes.HasPrefix(line, []byte("POST ")), + bytes.HasPrefix(line, []byte("GET ")), + bytes.HasPrefix(line, []byte("HEAD ")), + bytes.HasPrefix(line, []byte("PUT ")), + bytes.HasPrefix(line, []byte("DELETE ")), + bytes.HasPrefix(line, []byte("OPTIONS ")), + bytes.HasPrefix(line, []byte("PATCH ")): + default: + return consumed.Bytes(), fmt.Errorf("invalid method or garbage: %s", strings.TrimSpace(string(line))) + } + + // 2. 循环读取头部,直到遇到空行 + for { + line, err = r.ReadSlice('\n') + if err != nil { + return consumed.Bytes(), err + } + consumed.Write(line) + + // Check for empty line (\r\n or \n) + // ReadSlice includes the delimiter + n := len(line) + if n == 2 && line[0] == '\r' && line[1] == '\n' { + return consumed.Bytes(), nil + } + if n == 1 && line[0] == '\n' { + return consumed.Bytes(), nil + } + } +} diff --git a/transport/sudoku/obfs/httpmask/tunnel.go b/transport/sudoku/obfs/httpmask/tunnel.go new file mode 100644 index 00000000..3e0b6aa1 --- /dev/null +++ b/transport/sudoku/obfs/httpmask/tunnel.go @@ -0,0 +1,1684 @@ +package httpmask + +import ( + "bufio" + "bytes" + "context" + crand "crypto/rand" + "crypto/tls" + "encoding/base64" + "errors" + "fmt" + "io" + mrand "math/rand" + "net" + "net/http" + "net/http/httputil" + "net/url" + "os" + "strconv" + "strings" + "sync" + "time" +) + +type TunnelMode string + +const ( + TunnelModeLegacy TunnelMode = "legacy" + TunnelModeStream TunnelMode = "stream" + TunnelModePoll TunnelMode = "poll" + TunnelModeAuto TunnelMode = "auto" +) + +func normalizeTunnelMode(mode string) TunnelMode { + switch strings.ToLower(strings.TrimSpace(mode)) { + case "", string(TunnelModeLegacy): + return TunnelModeLegacy + case string(TunnelModeStream): + return TunnelModeStream + case string(TunnelModePoll): + return TunnelModePoll + case string(TunnelModeAuto): + return TunnelModeAuto + default: + // Be conservative: unknown => legacy + return TunnelModeLegacy + } +} + +type HandleResult int + +const ( + HandlePassThrough HandleResult = iota + HandleStartTunnel + HandleDone +) + +type TunnelDialOptions struct { + Mode string + TLSEnabled bool // when true, use HTTPS; otherwise, use HTTP (no port-based inference) + HostOverride string // optional Host header / SNI host (without scheme); accepts "example.com" or "example.com:443" + // DialContext overrides how the HTTP tunnel dials raw TCP/TLS connections. + // It must not be nil; passing nil is a programming error. + DialContext func(ctx context.Context, network, addr string) (net.Conn, error) +} + +// DialTunnel establishes a bidirectional stream over HTTP: +// - stream: a single streaming POST (request body uplink, response body downlink) +// - poll: authorize + push/pull polling tunnel (base64 framed) +// - auto: try stream then fall back to poll +// +// The returned net.Conn carries the raw Sudoku stream (no HTTP headers). +func DialTunnel(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) { + mode := normalizeTunnelMode(opts.Mode) + if mode == TunnelModeLegacy { + return nil, fmt.Errorf("legacy mode does not use http tunnel") + } + + switch mode { + case TunnelModeStream: + return dialStreamFn(ctx, serverAddress, opts) + case TunnelModePoll: + return dialPollFn(ctx, serverAddress, opts) + case TunnelModeAuto: + // "stream" can hang on some CDNs that buffer uploads until request body completes. + // Keep it on a short leash so we can fall back to poll within the caller's deadline. + streamCtx, cancelX := context.WithTimeout(ctx, 3*time.Second) + c, errX := dialStreamFn(streamCtx, serverAddress, opts) + cancelX() + if errX == nil { + return c, nil + } + c, errP := dialPollFn(ctx, serverAddress, opts) + if errP == nil { + return c, nil + } + return nil, fmt.Errorf("auto tunnel failed: stream: %v; poll: %w", errX, errP) + default: + return dialStreamFn(ctx, serverAddress, opts) + } +} + +var ( + dialStreamFn = dialStream + dialPollFn = dialPoll +) + +func canonicalHeaderHost(urlHost, scheme string) string { + host, port, err := net.SplitHostPort(urlHost) + if err != nil { + return urlHost + } + + defaultPort := "" + switch scheme { + case "https": + defaultPort = "443" + case "http": + defaultPort = "80" + } + if defaultPort == "" || port != defaultPort { + return urlHost + } + + // If we strip the port from an IPv6 literal, re-add brackets to keep the Host header valid. + if strings.Contains(host, ":") { + return "[" + host + "]" + } + return host +} + +func parseTunnelToken(body []byte) (string, error) { + s := strings.TrimSpace(string(body)) + idx := strings.Index(s, "token=") + if idx < 0 { + return "", errors.New("missing token") + } + s = s[idx+len("token="):] + if s == "" { + return "", errors.New("empty token") + } + // Token is base64.RawURLEncoding (A-Z a-z 0-9 - _). Strip any trailing bytes (e.g. from CDN compression). + var b strings.Builder + for i := 0; i < len(s); i++ { + c := s[i] + if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '-' || c == '_' { + b.WriteByte(c) + continue + } + break + } + token := b.String() + if token == "" { + return "", errors.New("empty token") + } + return token, nil +} + +type httpStreamConn struct { + reader io.ReadCloser + writer *io.PipeWriter + cancel context.CancelFunc + + localAddr net.Addr + remoteAddr net.Addr +} + +func (c *httpStreamConn) Read(p []byte) (int, error) { return c.reader.Read(p) } +func (c *httpStreamConn) Write(p []byte) (int, error) { return c.writer.Write(p) } + +func (c *httpStreamConn) Close() error { + if c.cancel != nil { + c.cancel() + } + _ = c.writer.CloseWithError(io.ErrClosedPipe) + return c.reader.Close() +} + +func (c *httpStreamConn) LocalAddr() net.Addr { return c.localAddr } +func (c *httpStreamConn) RemoteAddr() net.Addr { return c.remoteAddr } + +func (c *httpStreamConn) SetDeadline(time.Time) error { return nil } +func (c *httpStreamConn) SetReadDeadline(time.Time) error { return nil } +func (c *httpStreamConn) SetWriteDeadline(time.Time) error { return nil } + +type httpClientTarget struct { + scheme string + urlHost string + headerHost string +} + +func newHTTPClient(serverAddress string, opts TunnelDialOptions, maxIdleConns int) (*http.Client, httpClientTarget, error) { + if opts.DialContext == nil { + panic("httpmask: DialContext is nil") + } + + scheme, urlHost, dialAddr, serverName, err := normalizeHTTPDialTarget(serverAddress, opts.TLSEnabled, opts.HostOverride) + if err != nil { + return nil, httpClientTarget{}, err + } + + transport := &http.Transport{ + ForceAttemptHTTP2: true, + DisableCompression: true, + MaxIdleConns: maxIdleConns, + IdleConnTimeout: 30 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + DialContext: func(dialCtx context.Context, network, _ string) (net.Conn, error) { + return opts.DialContext(dialCtx, network, dialAddr) + }, + } + if scheme == "https" { + transport.TLSClientConfig = &tls.Config{ + ServerName: serverName, + MinVersion: tls.VersionTLS12, + } + } + + return &http.Client{Transport: transport}, httpClientTarget{ + scheme: scheme, + urlHost: urlHost, + headerHost: canonicalHeaderHost(urlHost, scheme), + }, nil +} + +func dialStream(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) { + // Prefer split session (Cloudflare-friendly). Fall back to stream-one for older servers / environments. + c, errSplit := dialStreamSplit(ctx, serverAddress, opts) + if errSplit == nil { + return c, nil + } + c2, errOne := dialStreamOne(ctx, serverAddress, opts) + if errOne == nil { + return c2, nil + } + return nil, fmt.Errorf("dial stream failed: split: %v; stream-one: %w", errSplit, errOne) +} + +func dialStreamOne(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) { + client, target, err := newHTTPClient(serverAddress, opts, 16) + if err != nil { + return nil, err + } + + r := rngPool.Get().(*mrand.Rand) + path := paths[r.Intn(len(paths))] + ctype := contentTypes[r.Intn(len(contentTypes))] + rngPool.Put(r) + + u := url.URL{ + Scheme: target.scheme, + Host: target.urlHost, + Path: path, + } + + reqBodyR, reqBodyW := io.Pipe() + + ctx, cancel := context.WithCancel(ctx) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), reqBodyR) + if err != nil { + cancel() + _ = reqBodyW.Close() + return nil, err + } + req.Host = target.headerHost + + applyTunnelHeaders(req.Header, target.headerHost, TunnelModeStream) + req.Header.Set("Content-Type", ctype) + + resp, err := client.Do(req) + if err != nil { + cancel() + _ = reqBodyW.Close() + return nil, err + } + if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4*1024)) + cancel() + _ = reqBodyW.Close() + return nil, fmt.Errorf("stream bad status: %s (%s)", resp.Status, strings.TrimSpace(string(body))) + } + + return &httpStreamConn{ + reader: resp.Body, + writer: reqBodyW, + cancel: cancel, + localAddr: &net.TCPAddr{}, + remoteAddr: &net.TCPAddr{}, + }, nil +} + +type streamSplitConn struct { + ctx context.Context + cancel context.CancelFunc + + client *http.Client + pushURL string + pullURL string + closeURL string + headerHost string + + rxc chan []byte + closed chan struct{} + + writeCh chan []byte + + mu sync.Mutex + readBuf []byte + localAddr net.Addr + remoteAddr net.Addr +} + +func (c *streamSplitConn) Read(b []byte) (n int, err error) { + if len(c.readBuf) == 0 { + select { + case c.readBuf = <-c.rxc: + case <-c.closed: + return 0, io.ErrClosedPipe + } + } + n = copy(b, c.readBuf) + c.readBuf = c.readBuf[n:] + return n, nil +} + +func (c *streamSplitConn) Write(b []byte) (n int, err error) { + if len(b) == 0 { + return 0, nil + } + c.mu.Lock() + select { + case <-c.closed: + c.mu.Unlock() + return 0, io.ErrClosedPipe + default: + } + c.mu.Unlock() + + payload := make([]byte, len(b)) + copy(payload, b) + select { + case c.writeCh <- payload: + return len(b), nil + case <-c.closed: + return 0, io.ErrClosedPipe + } +} + +func (c *streamSplitConn) Close() error { + c.mu.Lock() + select { + case <-c.closed: + c.mu.Unlock() + return nil + default: + close(c.closed) + } + c.mu.Unlock() + + if c.cancel != nil { + c.cancel() + } + + // Best-effort session close signal (avoid leaking server-side sessions). + closeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(closeCtx, http.MethodPost, c.closeURL, nil) + if err == nil { + req.Host = c.headerHost + applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream) + if resp, doErr := c.client.Do(req); doErr == nil && resp != nil { + _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024)) + _ = resp.Body.Close() + } + } + + return nil +} + +func (c *streamSplitConn) LocalAddr() net.Addr { return c.localAddr } +func (c *streamSplitConn) RemoteAddr() net.Addr { return c.remoteAddr } + +func (c *streamSplitConn) SetDeadline(time.Time) error { return nil } +func (c *streamSplitConn) SetReadDeadline(time.Time) error { return nil } +func (c *streamSplitConn) SetWriteDeadline(time.Time) error { return nil } + +func dialStreamSplit(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) { + client, target, err := newHTTPClient(serverAddress, opts, 32) + if err != nil { + return nil, err + } + + authorizeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/session"}).String() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, authorizeURL, nil) + if err != nil { + return nil, err + } + req.Host = target.headerHost + applyTunnelHeaders(req.Header, target.headerHost, TunnelModeStream) + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, 4*1024)) + _ = resp.Body.Close() + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("stream authorize bad status: %s (%s)", resp.Status, strings.TrimSpace(string(bodyBytes))) + } + + token, err := parseTunnelToken(bodyBytes) + if err != nil { + return nil, fmt.Errorf("stream authorize failed: %q", strings.TrimSpace(string(bodyBytes))) + } + if token == "" { + return nil, fmt.Errorf("stream authorize empty token") + } + + pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token)}).String() + pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/stream", RawQuery: "token=" + url.QueryEscape(token)}).String() + closeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token) + "&close=1"}).String() + + connCtx, cancel := context.WithCancel(context.Background()) + c := &streamSplitConn{ + ctx: connCtx, + cancel: cancel, + client: client, + pushURL: pushURL, + pullURL: pullURL, + closeURL: closeURL, + headerHost: target.headerHost, + rxc: make(chan []byte, 256), + closed: make(chan struct{}), + writeCh: make(chan []byte, 256), + localAddr: &net.TCPAddr{}, + remoteAddr: &net.TCPAddr{}, + } + + go c.pullLoop() + go c.pushLoop() + return c, nil +} + +func (c *streamSplitConn) pullLoop() { + const ( + requestTimeout = 30 * time.Second + readChunkSize = 32 * 1024 + idleBackoff = 25 * time.Millisecond + ) + + buf := make([]byte, readChunkSize) + for { + select { + case <-c.closed: + return + default: + } + + reqCtx, cancel := context.WithTimeout(c.ctx, requestTimeout) + req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, c.pullURL, nil) + if err != nil { + cancel() + _ = c.Close() + return + } + req.Host = c.headerHost + applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream) + + resp, err := c.client.Do(req) + if err != nil { + cancel() + _ = c.Close() + return + } + + if resp.StatusCode != http.StatusOK { + _ = resp.Body.Close() + cancel() + _ = c.Close() + return + } + + readAny := false + for { + n, rerr := resp.Body.Read(buf) + if n > 0 { + readAny = true + payload := make([]byte, n) + copy(payload, buf[:n]) + select { + case c.rxc <- payload: + case <-c.closed: + _ = resp.Body.Close() + return + } + } + if rerr != nil { + _ = resp.Body.Close() + cancel() + if errors.Is(rerr, io.EOF) { + // Long-poll ended; retry. + break + } + _ = c.Close() + return + } + } + cancel() + if !readAny { + // Avoid tight loop if the server replied quickly with an empty body. + select { + case <-time.After(idleBackoff): + case <-c.closed: + return + } + } + } +} + +func (c *streamSplitConn) pushLoop() { + const ( + maxBatchBytes = 256 * 1024 + flushInterval = 5 * time.Millisecond + requestTimeout = 20 * time.Second + ) + + var ( + buf bytes.Buffer + timer = time.NewTimer(flushInterval) + ) + defer timer.Stop() + + flush := func() bool { + if buf.Len() == 0 { + return true + } + + reqCtx, cancel := context.WithTimeout(c.ctx, requestTimeout) + req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.pushURL, bytes.NewReader(buf.Bytes())) + if err != nil { + cancel() + return false + } + req.Host = c.headerHost + applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream) + req.Header.Set("Content-Type", "application/octet-stream") + + resp, err := c.client.Do(req) + if err != nil { + cancel() + return false + } + _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024)) + _ = resp.Body.Close() + cancel() + if resp.StatusCode != http.StatusOK { + return false + } + + buf.Reset() + return true + } + + resetTimer := func() { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timer.Reset(flushInterval) + } + + resetTimer() + + for { + select { + case b, ok := <-c.writeCh: + if !ok { + _ = flush() + return + } + if len(b) == 0 { + continue + } + if buf.Len()+len(b) > maxBatchBytes { + if !flush() { + _ = c.Close() + return + } + resetTimer() + } + _, _ = buf.Write(b) + if buf.Len() >= maxBatchBytes { + if !flush() { + _ = c.Close() + return + } + resetTimer() + } + case <-timer.C: + if !flush() { + _ = c.Close() + return + } + resetTimer() + case <-c.closed: + _ = flush() + return + } + } +} + +type pollConn struct { + client *http.Client + pushURL string + pullURL string + closeURL string + headerHost string + + rxc chan []byte + closed chan struct{} + + writeCh chan []byte + + mu sync.Mutex + readBuf []byte + localAddr net.Addr + remoteAddr net.Addr +} + +func (c *pollConn) Read(b []byte) (n int, err error) { + if len(c.readBuf) == 0 { + select { + case c.readBuf = <-c.rxc: + case <-c.closed: + return 0, io.ErrClosedPipe + } + } + n = copy(b, c.readBuf) + c.readBuf = c.readBuf[n:] + return n, nil +} + +func (c *pollConn) Write(b []byte) (n int, err error) { + if len(b) == 0 { + return 0, nil + } + c.mu.Lock() + select { + case <-c.closed: + c.mu.Unlock() + return 0, io.ErrClosedPipe + default: + } + c.mu.Unlock() + + payload := make([]byte, len(b)) + copy(payload, b) + select { + case c.writeCh <- payload: + return len(b), nil + case <-c.closed: + return 0, io.ErrClosedPipe + } +} + +func (c *pollConn) Close() error { + c.mu.Lock() + select { + case <-c.closed: + c.mu.Unlock() + return nil + default: + close(c.closed) + } + c.mu.Unlock() + + close(c.writeCh) + + // Best-effort session close signal (avoid leaking server-side sessions). + req, err := http.NewRequest(http.MethodPost, c.closeURL, nil) + if err == nil { + req.Host = c.headerHost + req.Header.Set("X-Sudoku-Tunnel", string(TunnelModePoll)) + req.Header.Set("X-Sudoku-Version", "1") + _, _ = c.client.Do(req) + } + + return nil +} + +func (c *pollConn) LocalAddr() net.Addr { return c.localAddr } +func (c *pollConn) RemoteAddr() net.Addr { return c.remoteAddr } + +func (c *pollConn) SetDeadline(time.Time) error { return nil } +func (c *pollConn) SetReadDeadline(time.Time) error { return nil } +func (c *pollConn) SetWriteDeadline(time.Time) error { return nil } + +func dialPoll(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) { + client, target, err := newHTTPClient(serverAddress, opts, 32) + if err != nil { + return nil, err + } + + authorizeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/session"}).String() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, authorizeURL, nil) + if err != nil { + return nil, err + } + req.Host = target.headerHost + applyTunnelHeaders(req.Header, target.headerHost, TunnelModePoll) + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, 4*1024)) + _ = resp.Body.Close() + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("poll authorize bad status: %s (%s)", resp.Status, strings.TrimSpace(string(bodyBytes))) + } + + token, err := parseTunnelToken(bodyBytes) + if err != nil { + return nil, fmt.Errorf("poll authorize failed: %q", strings.TrimSpace(string(bodyBytes))) + } + if token == "" { + return nil, fmt.Errorf("poll authorize empty token") + } + + pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token)}).String() + pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/stream", RawQuery: "token=" + url.QueryEscape(token)}).String() + closeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token) + "&close=1"}).String() + + c := &pollConn{ + client: client, + pushURL: pushURL, + pullURL: pullURL, + closeURL: closeURL, + headerHost: target.headerHost, + rxc: make(chan []byte, 128), + closed: make(chan struct{}), + writeCh: make(chan []byte, 256), + localAddr: &net.TCPAddr{}, + remoteAddr: &net.TCPAddr{}, + } + + go c.pullLoop() + go c.pushLoop() + return c, nil +} + +func (c *pollConn) pullLoop() { + for { + select { + case <-c.closed: + return + default: + } + + req, err := http.NewRequest(http.MethodGet, c.pullURL, nil) + if err != nil { + _ = c.Close() + return + } + req.Host = c.headerHost + applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll) + + resp, err := c.client.Do(req) + if err != nil { + _ = c.Close() + return + } + + if resp.StatusCode != http.StatusOK { + _ = resp.Body.Close() + _ = c.Close() + return + } + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + if line == "" { + continue + } + payload, err := base64.StdEncoding.DecodeString(line) + if err != nil { + _ = resp.Body.Close() + _ = c.Close() + return + } + select { + case c.rxc <- payload: + case <-c.closed: + _ = resp.Body.Close() + return + } + } + _ = resp.Body.Close() + if err := scanner.Err(); err != nil { + _ = c.Close() + return + } + } +} + +func (c *pollConn) pushLoop() { + const ( + maxBatchBytes = 64 * 1024 + flushInterval = 5 * time.Millisecond + maxLineRawBytes = 16 * 1024 + ) + + var ( + buf bytes.Buffer + pendingRaw int + timer = time.NewTimer(flushInterval) + ) + defer timer.Stop() + + flush := func() bool { + if buf.Len() == 0 { + return true + } + + req, err := http.NewRequest(http.MethodPost, c.pushURL, bytes.NewReader(buf.Bytes())) + if err != nil { + return false + } + req.Host = c.headerHost + applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll) + req.Header.Set("Content-Type", "text/plain") + + resp, err := c.client.Do(req) + if err != nil { + return false + } + _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024)) + _ = resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return false + } + + buf.Reset() + pendingRaw = 0 + return true + } + + resetTimer := func() { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timer.Reset(flushInterval) + } + + resetTimer() + + for { + select { + case b, ok := <-c.writeCh: + if !ok { + _ = flush() + return + } + if len(b) == 0 { + continue + } + + // Split large writes into multiple base64 lines to cap per-line size. + for len(b) > 0 { + chunk := b + if len(chunk) > maxLineRawBytes { + chunk = b[:maxLineRawBytes] + } + b = b[len(chunk):] + + encLen := base64.StdEncoding.EncodedLen(len(chunk)) + if pendingRaw+len(chunk) > maxBatchBytes || buf.Len()+encLen+1 > maxBatchBytes*2 { + if !flush() { + _ = c.Close() + return + } + } + + tmp := make([]byte, base64.StdEncoding.EncodedLen(len(chunk))) + base64.StdEncoding.Encode(tmp, chunk) + buf.Write(tmp) + buf.WriteByte('\n') + pendingRaw += len(chunk) + } + + if pendingRaw >= maxBatchBytes { + if !flush() { + _ = c.Close() + return + } + resetTimer() + } + case <-timer.C: + if !flush() { + _ = c.Close() + return + } + resetTimer() + case <-c.closed: + _ = flush() + return + } + } +} + +func normalizeHTTPDialTarget(serverAddress string, tlsEnabled bool, hostOverride string) (scheme, urlHost, dialAddr, serverName string, err error) { + host, port, err := net.SplitHostPort(serverAddress) + if err != nil { + return "", "", "", "", fmt.Errorf("invalid server address %q: %w", serverAddress, err) + } + + if hostOverride != "" { + // Allow "example.com" or "example.com:443" + if h, p, splitErr := net.SplitHostPort(hostOverride); splitErr == nil { + if h != "" { + hostOverride = h + } + if p != "" { + port = p + } + } + serverName = hostOverride + urlHost = net.JoinHostPort(hostOverride, port) + } else { + serverName = host + urlHost = net.JoinHostPort(host, port) + } + + if tlsEnabled { + scheme = "https" + } else { + scheme = "http" + } + + dialAddr = net.JoinHostPort(host, port) + return scheme, urlHost, dialAddr, trimPortForHost(serverName), nil +} + +func applyTunnelHeaders(h http.Header, host string, mode TunnelMode) { + r := rngPool.Get().(*mrand.Rand) + ua := userAgents[r.Intn(len(userAgents))] + accept := accepts[r.Intn(len(accepts))] + lang := acceptLanguages[r.Intn(len(acceptLanguages))] + enc := acceptEncodings[r.Intn(len(acceptEncodings))] + rngPool.Put(r) + + h.Set("User-Agent", ua) + h.Set("Accept", accept) + h.Set("Accept-Language", lang) + h.Set("Accept-Encoding", enc) + h.Set("Cache-Control", "no-cache") + h.Set("Pragma", "no-cache") + h.Set("Connection", "keep-alive") + h.Set("Host", host) + h.Set("X-Sudoku-Tunnel", string(mode)) + h.Set("X-Sudoku-Version", "1") +} + +type TunnelServerOptions struct { + Mode string + // PullReadTimeout controls how long the server long-poll waits for tunnel downlink data before replying with a keepalive newline. + PullReadTimeout time.Duration + // SessionTTL is a best-effort TTL to prevent leaked sessions. 0 uses a conservative default. + SessionTTL time.Duration +} + +type TunnelServer struct { + mode TunnelMode + + pullReadTimeout time.Duration + sessionTTL time.Duration + + mu sync.Mutex + sessions map[string]*tunnelSession +} + +type tunnelSession struct { + conn net.Conn + lastActive time.Time +} + +func NewTunnelServer(opts TunnelServerOptions) *TunnelServer { + mode := normalizeTunnelMode(opts.Mode) + if mode == TunnelModeLegacy { + // Server-side "legacy" means: don't accept stream/poll tunnels; only passthrough. + } + timeout := opts.PullReadTimeout + if timeout <= 0 { + timeout = 10 * time.Second + } + ttl := opts.SessionTTL + if ttl <= 0 { + ttl = 2 * time.Minute + } + return &TunnelServer{ + mode: mode, + pullReadTimeout: timeout, + sessionTTL: ttl, + sessions: make(map[string]*tunnelSession), + } +} + +// HandleConn inspects rawConn. If it is an HTTP tunnel request (X-Sudoku-Tunnel header), it is handled here and: +// - returns HandleStartTunnel + a net.Conn that carries the raw Sudoku stream (stream mode or poll session pipe) +// - or returns HandleDone if the HTTP request is a poll control request (push/pull) and no Sudoku handshake should run on this TCP conn +// +// If it is not an HTTP tunnel request (or server mode is legacy), it returns HandlePassThrough with a conn that replays any pre-read bytes. +func (s *TunnelServer) HandleConn(rawConn net.Conn) (HandleResult, net.Conn, error) { + if rawConn == nil { + return HandleDone, nil, errors.New("nil conn") + } + + // Small header read deadline to avoid stalling Accept loops. The actual Sudoku handshake has its own deadlines. + _ = rawConn.SetReadDeadline(time.Now().Add(5 * time.Second)) + var first [4]byte + n, err := io.ReadFull(rawConn, first[:]) + if err != nil { + _ = rawConn.SetReadDeadline(time.Time{}) + // Even if short-read, preserve bytes for downstream handlers. + if n > 0 { + return HandlePassThrough, newPreBufferedConn(rawConn, first[:n]), nil + } + return HandleDone, nil, err + } + pc := newPreBufferedConn(rawConn, first[:]) + br := bufio.NewReader(pc) + + if !LooksLikeHTTPRequestStart(first[:]) { + _ = rawConn.SetReadDeadline(time.Time{}) + return HandlePassThrough, pc, nil + } + + req, headerBytes, buffered, err := readHTTPHeader(br) + _ = rawConn.SetReadDeadline(time.Time{}) + if err != nil { + // Not a valid HTTP request; hand it back to the legacy path with replay. + prefix := make([]byte, 0, len(first)+len(headerBytes)+len(buffered)) + if len(headerBytes) == 0 || !bytes.HasPrefix(headerBytes, first[:]) { + prefix = append(prefix, first[:]...) + } + prefix = append(prefix, headerBytes...) + prefix = append(prefix, buffered...) + return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil + } + + tunnelHeader := strings.ToLower(strings.TrimSpace(req.headers["x-sudoku-tunnel"])) + if tunnelHeader == "" { + // Not our tunnel; replay full bytes to legacy handler. + prefix := make([]byte, 0, len(headerBytes)+len(buffered)) + prefix = append(prefix, headerBytes...) + prefix = append(prefix, buffered...) + return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil + } + if s.mode == TunnelModeLegacy { + _ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found") + _ = rawConn.Close() + return HandleDone, nil, nil + } + + switch TunnelMode(tunnelHeader) { + case TunnelModeStream: + if s.mode != TunnelModeStream && s.mode != TunnelModeAuto { + _ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found") + _ = rawConn.Close() + return HandleDone, nil, nil + } + return s.handleStream(rawConn, req, buffered) + case TunnelModePoll: + if s.mode != TunnelModePoll && s.mode != TunnelModeAuto { + _ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found") + _ = rawConn.Close() + return HandleDone, nil, nil + } + return s.handlePoll(rawConn, req, buffered) + default: + _ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found") + _ = rawConn.Close() + return HandleDone, nil, nil + } +} + +type httpRequestHeader struct { + method string + target string // path + query + proto string + headers map[string]string // lower-case keys +} + +func readHTTPHeader(r *bufio.Reader) (*httpRequestHeader, []byte, []byte, error) { + const maxHeaderBytes = 32 * 1024 + + var consumed bytes.Buffer + readLine := func() ([]byte, error) { + line, err := r.ReadSlice('\n') + if len(line) > 0 { + if consumed.Len()+len(line) > maxHeaderBytes { + return line, fmt.Errorf("http header too large") + } + consumed.Write(line) + } + return line, err + } + + // Request line + line, err := readLine() + if err != nil { + return nil, consumed.Bytes(), readAllBuffered(r), err + } + lineStr := strings.TrimRight(string(line), "\r\n") + parts := strings.SplitN(lineStr, " ", 3) + if len(parts) != 3 { + return nil, consumed.Bytes(), readAllBuffered(r), fmt.Errorf("invalid request line") + } + req := &httpRequestHeader{ + method: parts[0], + target: parts[1], + proto: parts[2], + headers: make(map[string]string), + } + + // Headers + for { + line, err = readLine() + if err != nil { + return nil, consumed.Bytes(), readAllBuffered(r), err + } + trimmed := strings.TrimRight(string(line), "\r\n") + if trimmed == "" { + break + } + k, v, ok := strings.Cut(trimmed, ":") + if !ok { + continue + } + k = strings.ToLower(strings.TrimSpace(k)) + v = strings.TrimSpace(v) + if k == "" { + continue + } + // Keep the first value; we only care about a small set. + if _, exists := req.headers[k]; !exists { + req.headers[k] = v + } + } + + return req, consumed.Bytes(), readAllBuffered(r), nil +} + +func readAllBuffered(r *bufio.Reader) []byte { + n := r.Buffered() + if n <= 0 { + return nil + } + b, err := r.Peek(n) + if err != nil { + return nil + } + out := make([]byte, n) + copy(out, b) + return out +} + +type preBufferedConn struct { + net.Conn + buf []byte +} + +func newPreBufferedConn(conn net.Conn, pre []byte) net.Conn { + cpy := make([]byte, len(pre)) + copy(cpy, pre) + return &preBufferedConn{Conn: conn, buf: cpy} +} + +func (p *preBufferedConn) Read(b []byte) (int, error) { + if len(p.buf) > 0 { + n := copy(b, p.buf) + p.buf = p.buf[n:] + return n, nil + } + return p.Conn.Read(b) +} + +type bodyConn struct { + net.Conn + reader io.Reader + writer io.WriteCloser + tail io.Writer + flush func() error +} + +func (c *bodyConn) Read(p []byte) (int, error) { return c.reader.Read(p) } +func (c *bodyConn) Write(p []byte) (int, error) { + n, err := c.writer.Write(p) + if c.flush != nil { + _ = c.flush() + } + return n, err +} + +func (c *bodyConn) Close() error { + var firstErr error + if c.writer != nil { + if err := c.writer.Close(); err != nil && firstErr == nil { + firstErr = err + } + // NewChunkedWriter does not write the final CRLF. Ensure a clean terminator. + if c.tail != nil { + _, _ = c.tail.Write([]byte("\r\n")) + } else { + _, _ = c.Conn.Write([]byte("\r\n")) + } + if c.flush != nil { + _ = c.flush() + } + } + if err := c.Conn.Close(); err != nil && firstErr == nil { + firstErr = err + } + return firstErr +} + +func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, buffered []byte) (HandleResult, net.Conn, error) { + u, err := url.ParseRequestURI(req.target) + if err != nil { + _ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request") + _ = rawConn.Close() + return HandleDone, nil, nil + } + + // Only accept plausible paths to reduce accidental exposure. + if !isAllowedPath(req.target) { + _ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found") + _ = rawConn.Close() + return HandleDone, nil, nil + } + + token := u.Query().Get("token") + closeFlag := u.Query().Get("close") == "1" + + switch strings.ToUpper(req.method) { + case http.MethodGet: + // Stream split-session: GET /session (no token) => token + start tunnel on a server-side pipe. + if token == "" && u.Path == "/session" { + return s.authorizeSession(rawConn) + } + // Stream split-session: GET /stream?token=... => downlink poll. + if token != "" && u.Path == "/stream" { + return s.streamPull(rawConn, token) + } + _ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request") + _ = rawConn.Close() + return HandleDone, nil, nil + + case http.MethodPost: + // Stream split-session: POST /api/v1/upload?token=... => uplink push. + if token != "" && u.Path == "/api/v1/upload" { + if closeFlag { + s.closeSession(token) + _ = writeSimpleHTTPResponse(rawConn, http.StatusOK, "") + _ = rawConn.Close() + return HandleDone, nil, nil + } + bodyReader, err := newRequestBodyReader(newPreBufferedConn(rawConn, buffered), req.headers) + if err != nil { + _ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request") + _ = rawConn.Close() + return HandleDone, nil, nil + } + return s.streamPush(rawConn, token, bodyReader) + } + + // Stream-one: single full-duplex POST. + if err := writeTunnelResponseHeader(rawConn); err != nil { + _ = rawConn.Close() + return HandleDone, nil, err + } + + bodyReader, err := newRequestBodyReader(newPreBufferedConn(rawConn, buffered), req.headers) + if err != nil { + _ = rawConn.Close() + return HandleDone, nil, err + } + + bw := bufio.NewWriterSize(rawConn, 32*1024) + chunked := httputil.NewChunkedWriter(bw) + stream := &bodyConn{ + Conn: rawConn, + reader: bodyReader, + writer: chunked, + tail: bw, + flush: bw.Flush, + } + return HandleStartTunnel, stream, nil + + default: + _ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request") + _ = rawConn.Close() + return HandleDone, nil, nil + } +} + +func isAllowedPath(target string) bool { + u, err := url.ParseRequestURI(target) + if err != nil { + return false + } + for _, p := range paths { + if u.Path == p { + return true + } + } + return false +} + +func newRequestBodyReader(conn net.Conn, headers map[string]string) (io.Reader, error) { + br := bufio.NewReaderSize(conn, 32*1024) + + te := strings.ToLower(headers["transfer-encoding"]) + if strings.Contains(te, "chunked") { + return httputil.NewChunkedReader(br), nil + } + if clStr := headers["content-length"]; clStr != "" { + n, err := strconv.ParseInt(strings.TrimSpace(clStr), 10, 64) + if err != nil || n < 0 { + return nil, fmt.Errorf("invalid content-length") + } + return io.LimitReader(br, n), nil + } + return br, nil +} + +func writeTunnelResponseHeader(w io.Writer) error { + _, err := io.WriteString(w, + "HTTP/1.1 200 OK\r\n"+ + "Content-Type: application/octet-stream\r\n"+ + "Transfer-Encoding: chunked\r\n"+ + "Cache-Control: no-store\r\n"+ + "Pragma: no-cache\r\n"+ + "Connection: keep-alive\r\n"+ + "X-Accel-Buffering: no\r\n"+ + "\r\n") + return err +} + +func writeSimpleHTTPResponse(w io.Writer, code int, body string) error { + if body == "" { + body = http.StatusText(code) + } + body = strings.TrimRight(body, "\r\n") + _, err := io.WriteString(w, + fmt.Sprintf("HTTP/1.1 %d %s\r\nContent-Type: text/plain\r\nContent-Length: %d\r\nConnection: close\r\n\r\n%s", + code, http.StatusText(code), len(body), body)) + return err +} + +func writeTokenHTTPResponse(w io.Writer, token string) error { + token = strings.TrimRight(token, "\r\n") + // Use application/octet-stream to avoid CDN auto-compression (e.g. brotli) breaking clients that expect a plain token string. + _, err := io.WriteString(w, + fmt.Sprintf("HTTP/1.1 200 OK\r\nContent-Type: application/octet-stream\r\nCache-Control: no-store\r\nPragma: no-cache\r\nContent-Length: %d\r\nConnection: close\r\n\r\ntoken=%s", + len("token=")+len(token), token)) + return err +} + +func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, buffered []byte) (HandleResult, net.Conn, error) { + u, err := url.ParseRequestURI(req.target) + if err != nil { + _ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request") + _ = rawConn.Close() + return HandleDone, nil, nil + } + + if !isAllowedPath(req.target) { + _ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found") + _ = rawConn.Close() + return HandleDone, nil, nil + } + + token := u.Query().Get("token") + closeFlag := u.Query().Get("close") == "1" + switch strings.ToUpper(req.method) { + case http.MethodGet: + if token == "" { + return s.authorizeSession(rawConn) + } + return s.pollPull(rawConn, token) + case http.MethodPost: + if token == "" { + _ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "missing token") + _ = rawConn.Close() + return HandleDone, nil, nil + } + if closeFlag { + s.closeSession(token) + _ = writeSimpleHTTPResponse(rawConn, http.StatusOK, "") + _ = rawConn.Close() + return HandleDone, nil, nil + } + bodyReader, err := newRequestBodyReader(newPreBufferedConn(rawConn, buffered), req.headers) + if err != nil { + _ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request") + _ = rawConn.Close() + return HandleDone, nil, nil + } + return s.pollPush(rawConn, token, bodyReader) + default: + _ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request") + _ = rawConn.Close() + return HandleDone, nil, nil + } +} + +func (s *TunnelServer) authorizeSession(rawConn net.Conn) (HandleResult, net.Conn, error) { + token, err := newSessionToken() + if err != nil { + _ = writeSimpleHTTPResponse(rawConn, http.StatusInternalServerError, "internal error") + _ = rawConn.Close() + return HandleDone, nil, nil + } + + c1, c2 := net.Pipe() + + s.mu.Lock() + s.sessions[token] = &tunnelSession{conn: c2, lastActive: time.Now()} + s.mu.Unlock() + + go s.reapSessionLater(token) + + _ = writeTokenHTTPResponse(rawConn, token) + _ = rawConn.Close() + return HandleStartTunnel, c1, nil +} + +func newSessionToken() (string, error) { + var b [16]byte + if _, err := crand.Read(b[:]); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b[:]), nil +} + +func (s *TunnelServer) reapSessionLater(token string) { + ttl := s.sessionTTL + if ttl <= 0 { + return + } + timer := time.NewTimer(ttl) + defer timer.Stop() + <-timer.C + + s.mu.Lock() + sess, ok := s.sessions[token] + if !ok { + s.mu.Unlock() + return + } + if time.Since(sess.lastActive) < ttl { + s.mu.Unlock() + return + } + delete(s.sessions, token) + s.mu.Unlock() + _ = sess.conn.Close() +} + +func (s *TunnelServer) getSession(token string) (*tunnelSession, bool) { + s.mu.Lock() + defer s.mu.Unlock() + sess, ok := s.sessions[token] + if !ok { + return nil, false + } + sess.lastActive = time.Now() + return sess, true +} + +func (s *TunnelServer) closeSession(token string) { + s.mu.Lock() + sess, ok := s.sessions[token] + if ok { + delete(s.sessions, token) + } + s.mu.Unlock() + if ok { + _ = sess.conn.Close() + } +} + +func (s *TunnelServer) pollPush(rawConn net.Conn, token string, body io.Reader) (HandleResult, net.Conn, error) { + sess, ok := s.getSession(token) + if !ok { + _ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden") + _ = rawConn.Close() + return HandleDone, nil, nil + } + + payload, err := io.ReadAll(io.LimitReader(body, 1<<20)) // 1MiB per request cap + if err != nil { + _ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request") + _ = rawConn.Close() + return HandleDone, nil, nil + } + + lines := bytes.Split(payload, []byte{'\n'}) + for _, line := range lines { + line = bytes.TrimSpace(line) + if len(line) == 0 { + continue + } + decoded := make([]byte, base64.StdEncoding.DecodedLen(len(line))) + n, decErr := base64.StdEncoding.Decode(decoded, line) + if decErr != nil { + _ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request") + _ = rawConn.Close() + return HandleDone, nil, nil + } + if n == 0 { + continue + } + _ = sess.conn.SetWriteDeadline(time.Now().Add(30 * time.Second)) + _, werr := sess.conn.Write(decoded[:n]) + _ = sess.conn.SetWriteDeadline(time.Time{}) + if werr != nil { + s.closeSession(token) + _ = writeSimpleHTTPResponse(rawConn, http.StatusGone, "gone") + _ = rawConn.Close() + return HandleDone, nil, nil + } + } + + _ = writeSimpleHTTPResponse(rawConn, http.StatusOK, "") + _ = rawConn.Close() + return HandleDone, nil, nil +} + +func (s *TunnelServer) streamPush(rawConn net.Conn, token string, body io.Reader) (HandleResult, net.Conn, error) { + sess, ok := s.getSession(token) + if !ok { + _ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden") + _ = rawConn.Close() + return HandleDone, nil, nil + } + + const maxUploadBytes = 1 << 20 + payload, err := io.ReadAll(io.LimitReader(body, maxUploadBytes+1)) + if err != nil { + _ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request") + _ = rawConn.Close() + return HandleDone, nil, nil + } + if len(payload) > maxUploadBytes { + _ = writeSimpleHTTPResponse(rawConn, http.StatusRequestEntityTooLarge, "too large") + _ = rawConn.Close() + return HandleDone, nil, nil + } + + if len(payload) > 0 { + _ = sess.conn.SetWriteDeadline(time.Now().Add(30 * time.Second)) + _, werr := sess.conn.Write(payload) + _ = sess.conn.SetWriteDeadline(time.Time{}) + if werr != nil { + s.closeSession(token) + _ = writeSimpleHTTPResponse(rawConn, http.StatusGone, "gone") + _ = rawConn.Close() + return HandleDone, nil, nil + } + } + + _ = writeSimpleHTTPResponse(rawConn, http.StatusOK, "") + _ = rawConn.Close() + return HandleDone, nil, nil +} + +func (s *TunnelServer) streamPull(rawConn net.Conn, token string) (HandleResult, net.Conn, error) { + sess, ok := s.getSession(token) + if !ok { + _ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden") + _ = rawConn.Close() + return HandleDone, nil, nil + } + + // Streaming response (chunked) with raw bytes (no base64 framing). + if err := writeTunnelResponseHeader(rawConn); err != nil { + _ = rawConn.Close() + return HandleDone, nil, err + } + + bw := bufio.NewWriterSize(rawConn, 32*1024) + cw := httputil.NewChunkedWriter(bw) + defer func() { + _ = cw.Close() + _, _ = bw.WriteString("\r\n") + _ = bw.Flush() + _ = rawConn.Close() + }() + + buf := make([]byte, 32*1024) + for { + _ = sess.conn.SetReadDeadline(time.Now().Add(s.pullReadTimeout)) + n, err := sess.conn.Read(buf) + if n > 0 { + _, _ = cw.Write(buf[:n]) + _ = bw.Flush() + } + if err != nil { + if errors.Is(err, os.ErrDeadlineExceeded) { + // End this long-poll response; client will re-issue. + return HandleDone, nil, nil + } + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, net.ErrClosed) { + return HandleDone, nil, nil + } + s.closeSession(token) + return HandleDone, nil, nil + } + } +} + +func (s *TunnelServer) pollPull(rawConn net.Conn, token string) (HandleResult, net.Conn, error) { + sess, ok := s.getSession(token) + if !ok { + _ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden") + _ = rawConn.Close() + return HandleDone, nil, nil + } + + // Streaming response (chunked) with base64 lines. + if err := writeTunnelResponseHeader(rawConn); err != nil { + _ = rawConn.Close() + return HandleDone, nil, err + } + + bw := bufio.NewWriterSize(rawConn, 32*1024) + cw := httputil.NewChunkedWriter(bw) + defer func() { + _ = cw.Close() + _, _ = bw.WriteString("\r\n") + _ = bw.Flush() + _ = rawConn.Close() + }() + + buf := make([]byte, 32*1024) + for { + _ = sess.conn.SetReadDeadline(time.Now().Add(s.pullReadTimeout)) + n, err := sess.conn.Read(buf) + if n > 0 { + line := make([]byte, base64.StdEncoding.EncodedLen(n)) + base64.StdEncoding.Encode(line, buf[:n]) + _, _ = cw.Write(append(line, '\n')) + _ = bw.Flush() + } + if err != nil { + if errors.Is(err, os.ErrDeadlineExceeded) { + // Keepalive: send an empty line then end this long-poll response. + _, _ = cw.Write([]byte("\n")) + _ = bw.Flush() + return HandleDone, nil, nil + } + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, net.ErrClosed) { + return HandleDone, nil, nil + } + s.closeSession(token) + return HandleDone, nil, nil + } + } +} diff --git a/transport/sudoku/obfs/sudoku/conn.go b/transport/sudoku/obfs/sudoku/conn.go new file mode 100644 index 00000000..d09c8a68 --- /dev/null +++ b/transport/sudoku/obfs/sudoku/conn.go @@ -0,0 +1,212 @@ +package sudoku + +import ( + "bufio" + "bytes" + crypto_rand "crypto/rand" + "encoding/binary" + "errors" + "math/rand" + "net" + "sync" +) + +const IOBufferSize = 32 * 1024 + +var perm4 = [24][4]byte{ + {0, 1, 2, 3}, + {0, 1, 3, 2}, + {0, 2, 1, 3}, + {0, 2, 3, 1}, + {0, 3, 1, 2}, + {0, 3, 2, 1}, + {1, 0, 2, 3}, + {1, 0, 3, 2}, + {1, 2, 0, 3}, + {1, 2, 3, 0}, + {1, 3, 0, 2}, + {1, 3, 2, 0}, + {2, 0, 1, 3}, + {2, 0, 3, 1}, + {2, 1, 0, 3}, + {2, 1, 3, 0}, + {2, 3, 0, 1}, + {2, 3, 1, 0}, + {3, 0, 1, 2}, + {3, 0, 2, 1}, + {3, 1, 0, 2}, + {3, 1, 2, 0}, + {3, 2, 0, 1}, + {3, 2, 1, 0}, +} + +type Conn struct { + net.Conn + table *Table + reader *bufio.Reader + recorder *bytes.Buffer + recording bool + recordLock sync.Mutex + + rawBuf []byte + pendingData []byte + hintBuf []byte + + rng *rand.Rand + paddingRate float32 +} + +func NewConn(c net.Conn, table *Table, pMin, pMax int, record bool) *Conn { + var seedBytes [8]byte + if _, err := crypto_rand.Read(seedBytes[:]); err != nil { + binary.BigEndian.PutUint64(seedBytes[:], uint64(rand.Int63())) + } + seed := int64(binary.BigEndian.Uint64(seedBytes[:])) + localRng := rand.New(rand.NewSource(seed)) + + min := float32(pMin) / 100.0 + rng := float32(pMax-pMin) / 100.0 + rate := min + localRng.Float32()*rng + + sc := &Conn{ + Conn: c, + table: table, + reader: bufio.NewReaderSize(c, IOBufferSize), + rawBuf: make([]byte, IOBufferSize), + pendingData: make([]byte, 0, 4096), + hintBuf: make([]byte, 0, 4), + rng: localRng, + paddingRate: rate, + } + if record { + sc.recorder = new(bytes.Buffer) + sc.recording = true + } + return sc +} + +func (sc *Conn) StopRecording() { + sc.recordLock.Lock() + sc.recording = false + sc.recorder = nil + sc.recordLock.Unlock() +} + +func (sc *Conn) GetBufferedAndRecorded() []byte { + if sc == nil { + return nil + } + + sc.recordLock.Lock() + defer sc.recordLock.Unlock() + + var recorded []byte + if sc.recorder != nil { + recorded = sc.recorder.Bytes() + } + + buffered := sc.reader.Buffered() + if buffered > 0 { + peeked, _ := sc.reader.Peek(buffered) + full := make([]byte, len(recorded)+len(peeked)) + copy(full, recorded) + copy(full[len(recorded):], peeked) + return full + } + return recorded +} + +func (sc *Conn) Write(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + + outCapacity := len(p) * 6 + out := make([]byte, 0, outCapacity) + pads := sc.table.PaddingPool + padLen := len(pads) + + for _, b := range p { + if sc.rng.Float32() < sc.paddingRate { + out = append(out, pads[sc.rng.Intn(padLen)]) + } + + puzzles := sc.table.EncodeTable[b] + puzzle := puzzles[sc.rng.Intn(len(puzzles))] + + perm := perm4[sc.rng.Intn(len(perm4))] + for _, idx := range perm { + if sc.rng.Float32() < sc.paddingRate { + out = append(out, pads[sc.rng.Intn(padLen)]) + } + out = append(out, puzzle[idx]) + } + } + + if sc.rng.Float32() < sc.paddingRate { + out = append(out, pads[sc.rng.Intn(padLen)]) + } + + _, err = sc.Conn.Write(out) + return len(p), err +} + +func (sc *Conn) Read(p []byte) (n int, err error) { + if len(sc.pendingData) > 0 { + n = copy(p, sc.pendingData) + if n == len(sc.pendingData) { + sc.pendingData = sc.pendingData[:0] + } else { + sc.pendingData = sc.pendingData[n:] + } + return n, nil + } + + for { + if len(sc.pendingData) > 0 { + break + } + + nr, rErr := sc.reader.Read(sc.rawBuf) + if nr > 0 { + chunk := sc.rawBuf[:nr] + sc.recordLock.Lock() + if sc.recording { + sc.recorder.Write(chunk) + } + sc.recordLock.Unlock() + + for _, b := range chunk { + if !sc.table.layout.isHint(b) { + continue + } + + sc.hintBuf = append(sc.hintBuf, b) + if len(sc.hintBuf) == 4 { + key := packHintsToKey([4]byte{sc.hintBuf[0], sc.hintBuf[1], sc.hintBuf[2], sc.hintBuf[3]}) + val, ok := sc.table.DecodeMap[key] + if !ok { + return 0, errors.New("INVALID_SUDOKU_MAP_MISS") + } + sc.pendingData = append(sc.pendingData, val) + sc.hintBuf = sc.hintBuf[:0] + } + } + } + + if rErr != nil { + return 0, rErr + } + if len(sc.pendingData) > 0 { + break + } + } + + n = copy(p, sc.pendingData) + if n == len(sc.pendingData) { + sc.pendingData = sc.pendingData[:0] + } else { + sc.pendingData = sc.pendingData[n:] + } + return n, nil +} diff --git a/transport/sudoku/obfs/sudoku/grid.go b/transport/sudoku/obfs/sudoku/grid.go new file mode 100644 index 00000000..3e802989 --- /dev/null +++ b/transport/sudoku/obfs/sudoku/grid.go @@ -0,0 +1,46 @@ +package sudoku + +// Grid represents a 4x4 sudoku grid +type Grid [16]uint8 + +// GenerateAllGrids generates all valid 4x4 Sudoku grids +func GenerateAllGrids() []Grid { + var grids []Grid + var g Grid + var backtrack func(int) + + backtrack = func(idx int) { + if idx == 16 { + grids = append(grids, g) + return + } + row, col := idx/4, idx%4 + br, bc := (row/2)*2, (col/2)*2 + for num := uint8(1); num <= 4; num++ { + valid := true + for i := 0; i < 4; i++ { + if g[row*4+i] == num || g[i*4+col] == num { + valid = false + break + } + } + if valid { + for r := 0; r < 2; r++ { + for c := 0; c < 2; c++ { + if g[(br+r)*4+(bc+c)] == num { + valid = false + break + } + } + } + } + if valid { + g[idx] = num + backtrack(idx + 1) + g[idx] = 0 + } + } + } + backtrack(0) + return grids +} diff --git a/transport/sudoku/obfs/sudoku/layout.go b/transport/sudoku/obfs/sudoku/layout.go new file mode 100644 index 00000000..72c569f5 --- /dev/null +++ b/transport/sudoku/obfs/sudoku/layout.go @@ -0,0 +1,204 @@ +package sudoku + +import ( + "fmt" + "math/bits" + "sort" + "strings" +) + +type byteLayout struct { + name string + hintMask byte + hintValue byte + padMarker byte + paddingPool []byte + + encodeHint func(val, pos byte) byte + encodeGroup func(group byte) byte + decodeGroup func(b byte) (byte, bool) +} + +func (l *byteLayout) isHint(b byte) bool { + return (b & l.hintMask) == l.hintValue +} + +// resolveLayout picks the byte layout based on ASCII preference and optional custom pattern. +// ASCII always wins if requested. Custom patterns are ignored when ASCII is preferred. +func resolveLayout(mode string, customPattern string) (*byteLayout, error) { + switch strings.ToLower(mode) { + case "ascii", "prefer_ascii": + return newASCIILayout(), nil + case "entropy", "prefer_entropy", "": + // fallback to entropy unless a custom pattern is provided + default: + return nil, fmt.Errorf("invalid ascii mode: %s", mode) + } + + if strings.TrimSpace(customPattern) != "" { + return newCustomLayout(customPattern) + } + return newEntropyLayout(), nil +} + +func newASCIILayout() *byteLayout { + padding := make([]byte, 0, 32) + for i := 0; i < 32; i++ { + padding = append(padding, byte(0x20+i)) + } + return &byteLayout{ + name: "ascii", + hintMask: 0x40, + hintValue: 0x40, + padMarker: 0x3F, + paddingPool: padding, + encodeHint: func(val, pos byte) byte { + return 0x40 | ((val & 0x03) << 4) | (pos & 0x0F) + }, + encodeGroup: func(group byte) byte { + return 0x40 | (group & 0x3F) + }, + decodeGroup: func(b byte) (byte, bool) { + if (b & 0x40) == 0 { + return 0, false + } + return b & 0x3F, true + }, + } +} + +func newEntropyLayout() *byteLayout { + padding := make([]byte, 0, 16) + for i := 0; i < 8; i++ { + padding = append(padding, byte(0x80+i)) + padding = append(padding, byte(0x10+i)) + } + return &byteLayout{ + name: "entropy", + hintMask: 0x90, + hintValue: 0x00, + padMarker: 0x80, + paddingPool: padding, + encodeHint: func(val, pos byte) byte { + return ((val & 0x03) << 5) | (pos & 0x0F) + }, + encodeGroup: func(group byte) byte { + v := group & 0x3F + return ((v & 0x30) << 1) | (v & 0x0F) + }, + decodeGroup: func(b byte) (byte, bool) { + if (b & 0x90) != 0 { + return 0, false + } + return ((b >> 1) & 0x30) | (b & 0x0F), true + }, + } +} + +func newCustomLayout(pattern string) (*byteLayout, error) { + cleaned := strings.ToLower(strings.ReplaceAll(strings.TrimSpace(pattern), " ", "")) + if len(cleaned) != 8 { + return nil, fmt.Errorf("custom table must have 8 symbols, got %d", len(cleaned)) + } + + var xBits, pBits, vBits []uint8 + for i, c := range cleaned { + bit := uint8(7 - i) + switch c { + case 'x': + xBits = append(xBits, bit) + case 'p': + pBits = append(pBits, bit) + case 'v': + vBits = append(vBits, bit) + default: + return nil, fmt.Errorf("invalid char %q in custom table", c) + } + } + + if len(xBits) != 2 || len(pBits) != 2 || len(vBits) != 4 { + return nil, fmt.Errorf("custom table must contain exactly 2 x, 2 p, 4 v") + } + + xMask := byte(0) + for _, b := range xBits { + xMask |= 1 << b + } + + encodeBits := func(val, pos byte, dropX int) byte { + var out byte + out |= xMask + if dropX >= 0 { + out &^= 1 << xBits[dropX] + } + if (val & 0x02) != 0 { + out |= 1 << pBits[0] + } + if (val & 0x01) != 0 { + out |= 1 << pBits[1] + } + for i, bit := range vBits { + if (pos>>(3-uint8(i)))&0x01 == 1 { + out |= 1 << bit + } + } + return out + } + + decodeGroup := func(b byte) (byte, bool) { + if (b & xMask) != xMask { + return 0, false + } + var val, pos byte + if b&(1<= 5 { + if _, ok := paddingSet[b]; !ok { + paddingSet[b] = struct{}{} + padding = append(padding, b) + } + } + } + } + } + sort.Slice(padding, func(i, j int) bool { return padding[i] < padding[j] }) + if len(padding) == 0 { + return nil, fmt.Errorf("custom table produced empty padding pool") + } + + return &byteLayout{ + name: fmt.Sprintf("custom(%s)", cleaned), + hintMask: xMask, + hintValue: xMask, + padMarker: padding[0], + paddingPool: padding, + encodeHint: func(val, pos byte) byte { + return encodeBits(val, pos, -1) + }, + encodeGroup: func(group byte) byte { + val := (group >> 4) & 0x03 + pos := group & 0x0F + return encodeBits(val, pos, -1) + }, + decodeGroup: decodeGroup, + }, nil +} diff --git a/transport/sudoku/obfs/sudoku/packed.go b/transport/sudoku/obfs/sudoku/packed.go new file mode 100644 index 00000000..567afe73 --- /dev/null +++ b/transport/sudoku/obfs/sudoku/packed.go @@ -0,0 +1,332 @@ +package sudoku + +import ( + "bufio" + crypto_rand "crypto/rand" + "encoding/binary" + "io" + "math/rand" + "net" + "sync" +) + +const ( + // 每次从 RNG 获取批量随机数的缓存大小,减少 RNG 函数调用开销 + RngBatchSize = 128 +) + +// 1. 使用 12字节->16组 的块处理优化 Write (减少循环开销) +// 2. 使用浮点随机概率判断 Padding,与纯 Sudoku 保持流量特征一致 +// 3. Read 使用 copy 移动避免底层数组泄漏 +type PackedConn struct { + net.Conn + table *Table + reader *bufio.Reader + + // 读缓冲 + rawBuf []byte + pendingData []byte // 解码后尚未被 Read 取走的字节 + + // 写缓冲与状态 + writeMu sync.Mutex + writeBuf []byte + bitBuf uint64 // 暂存的位数据 + bitCount int // 暂存的位数 + + // 读状态 + readBitBuf uint64 + readBits int + + // 随机数与填充控制 - 使用浮点随机,与 Conn 一致 + rng *rand.Rand + paddingRate float32 // 与 Conn 保持一致的随机概率模型 + padMarker byte + padPool []byte +} + +func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn { + var seedBytes [8]byte + if _, err := crypto_rand.Read(seedBytes[:]); err != nil { + binary.BigEndian.PutUint64(seedBytes[:], uint64(rand.Int63())) + } + seed := int64(binary.BigEndian.Uint64(seedBytes[:])) + localRng := rand.New(rand.NewSource(seed)) + + // 与 Conn 保持一致的 padding 概率计算 + min := float32(pMin) / 100.0 + rng := float32(pMax-pMin) / 100.0 + rate := min + localRng.Float32()*rng + + pc := &PackedConn{ + Conn: c, + table: table, + reader: bufio.NewReaderSize(c, IOBufferSize), + rawBuf: make([]byte, IOBufferSize), + pendingData: make([]byte, 0, 4096), + writeBuf: make([]byte, 0, 4096), + rng: localRng, + paddingRate: rate, + } + + pc.padMarker = table.layout.padMarker + for _, b := range table.PaddingPool { + if b != pc.padMarker { + pc.padPool = append(pc.padPool, b) + } + } + if len(pc.padPool) == 0 { + pc.padPool = append(pc.padPool, pc.padMarker) + } + return pc +} + +// maybeAddPadding 内联辅助:根据浮点概率插入 padding +func (pc *PackedConn) maybeAddPadding(out []byte) []byte { + if pc.rng.Float32() < pc.paddingRate { + out = append(out, pc.getPaddingByte()) + } + return out +} + +// Write 极致优化版 - 批量处理 12 字节 +func (pc *PackedConn) Write(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + + pc.writeMu.Lock() + defer pc.writeMu.Unlock() + + // 1. 预分配内存,避免 append 导致的多次扩容 + // 预估:原数据 * 1.5 (4/3 + padding 余量) + needed := len(p)*3/2 + 32 + if cap(pc.writeBuf) < needed { + pc.writeBuf = make([]byte, 0, needed) + } + out := pc.writeBuf[:0] + + i := 0 + n := len(p) + + // 2. 头部对齐处理 (Slow Path) + for pc.bitCount > 0 && i < n { + out = pc.maybeAddPadding(out) + b := p[i] + i++ + pc.bitBuf = (pc.bitBuf << 8) | uint64(b) + pc.bitCount += 8 + for pc.bitCount >= 6 { + pc.bitCount -= 6 + group := byte(pc.bitBuf >> pc.bitCount) + if pc.bitCount == 0 { + pc.bitBuf = 0 + } else { + pc.bitBuf &= (1 << pc.bitCount) - 1 + } + out = pc.maybeAddPadding(out) + out = append(out, pc.encodeGroup(group&0x3F)) + } + } + + // 3. 极速批量处理 (Fast Path) - 每次处理 12 字节 → 生成 16 个编码组 + for i+11 < n { + // 处理 4 组,每组 3 字节 + for batch := 0; batch < 4; batch++ { + b1, b2, b3 := p[i], p[i+1], p[i+2] + i += 3 + + g1 := (b1 >> 2) & 0x3F + g2 := ((b1 & 0x03) << 4) | ((b2 >> 4) & 0x0F) + g3 := ((b2 & 0x0F) << 2) | ((b3 >> 6) & 0x03) + g4 := b3 & 0x3F + + // 每个组之前都有概率插入 padding + out = pc.maybeAddPadding(out) + out = append(out, pc.encodeGroup(g1)) + out = pc.maybeAddPadding(out) + out = append(out, pc.encodeGroup(g2)) + out = pc.maybeAddPadding(out) + out = append(out, pc.encodeGroup(g3)) + out = pc.maybeAddPadding(out) + out = append(out, pc.encodeGroup(g4)) + } + } + + // 4. 处理剩余的 3 字节块 + for i+2 < n { + b1, b2, b3 := p[i], p[i+1], p[i+2] + i += 3 + + g1 := (b1 >> 2) & 0x3F + g2 := ((b1 & 0x03) << 4) | ((b2 >> 4) & 0x0F) + g3 := ((b2 & 0x0F) << 2) | ((b3 >> 6) & 0x03) + g4 := b3 & 0x3F + + out = pc.maybeAddPadding(out) + out = append(out, pc.encodeGroup(g1)) + out = pc.maybeAddPadding(out) + out = append(out, pc.encodeGroup(g2)) + out = pc.maybeAddPadding(out) + out = append(out, pc.encodeGroup(g3)) + out = pc.maybeAddPadding(out) + out = append(out, pc.encodeGroup(g4)) + } + + // 5. 尾部处理 (Tail Path) - 处理剩余的 1 或 2 个字节 + for ; i < n; i++ { + b := p[i] + pc.bitBuf = (pc.bitBuf << 8) | uint64(b) + pc.bitCount += 8 + for pc.bitCount >= 6 { + pc.bitCount -= 6 + group := byte(pc.bitBuf >> pc.bitCount) + if pc.bitCount == 0 { + pc.bitBuf = 0 + } else { + pc.bitBuf &= (1 << pc.bitCount) - 1 + } + out = pc.maybeAddPadding(out) + out = append(out, pc.encodeGroup(group&0x3F)) + } + } + + // 6. 处理残留位 + if pc.bitCount > 0 { + out = pc.maybeAddPadding(out) + group := byte(pc.bitBuf << (6 - pc.bitCount)) + pc.bitBuf = 0 + pc.bitCount = 0 + out = append(out, pc.encodeGroup(group&0x3F)) + out = append(out, pc.padMarker) + } + + // 尾部可能添加 padding + out = pc.maybeAddPadding(out) + + // 发送数据 + if len(out) > 0 { + _, err := pc.Conn.Write(out) + pc.writeBuf = out[:0] + return len(p), err + } + pc.writeBuf = out[:0] + return len(p), nil +} + +// Flush 处理最后不足 6 bit 的情况 +func (pc *PackedConn) Flush() error { + pc.writeMu.Lock() + defer pc.writeMu.Unlock() + + out := pc.writeBuf[:0] + if pc.bitCount > 0 { + group := byte(pc.bitBuf << (6 - pc.bitCount)) + pc.bitBuf = 0 + pc.bitCount = 0 + + out = append(out, pc.encodeGroup(group&0x3F)) + out = append(out, pc.padMarker) + } + + // 尾部随机添加 padding + out = pc.maybeAddPadding(out) + + if len(out) > 0 { + _, err := pc.Conn.Write(out) + pc.writeBuf = out[:0] + return err + } + return nil +} + +// Read 优化版:减少切片操作,避免内存泄漏 +func (pc *PackedConn) Read(p []byte) (int, error) { + // 1. 优先返回待处理区的数据 + if len(pc.pendingData) > 0 { + n := copy(p, pc.pendingData) + if n == len(pc.pendingData) { + pc.pendingData = pc.pendingData[:0] + } else { + // 优化:移动剩余数据到数组头部,避免切片指向中间导致内存泄漏 + remaining := len(pc.pendingData) - n + copy(pc.pendingData, pc.pendingData[n:]) + pc.pendingData = pc.pendingData[:remaining] + } + return n, nil + } + + // 2. 循环读取直到解出数据或出错 + for { + nr, rErr := pc.reader.Read(pc.rawBuf) + if nr > 0 { + // 缓存频繁访问的变量 + rBuf := pc.readBitBuf + rBits := pc.readBits + padMarker := pc.padMarker + layout := pc.table.layout + + for _, b := range pc.rawBuf[:nr] { + if !layout.isHint(b) { + if b == padMarker { + rBuf = 0 + rBits = 0 + } + continue + } + + group, ok := layout.decodeGroup(b) + if !ok { + return 0, ErrInvalidSudokuMapMiss + } + + rBuf = (rBuf << 6) | uint64(group) + rBits += 6 + + if rBits >= 8 { + rBits -= 8 + val := byte(rBuf >> rBits) + pc.pendingData = append(pc.pendingData, val) + } + } + + pc.readBitBuf = rBuf + pc.readBits = rBits + } + + if rErr != nil { + if rErr == io.EOF { + pc.readBitBuf = 0 + pc.readBits = 0 + } + if len(pc.pendingData) > 0 { + break + } + return 0, rErr + } + + if len(pc.pendingData) > 0 { + break + } + } + + // 3. 返回解码后的数据 - 优化:避免底层数组泄漏 + n := copy(p, pc.pendingData) + if n == len(pc.pendingData) { + pc.pendingData = pc.pendingData[:0] + } else { + remaining := len(pc.pendingData) - n + copy(pc.pendingData, pc.pendingData[n:]) + pc.pendingData = pc.pendingData[:remaining] + } + return n, nil +} + +// getPaddingByte 从 Pool 中随机取 Padding 字节 +func (pc *PackedConn) getPaddingByte() byte { + return pc.padPool[pc.rng.Intn(len(pc.padPool))] +} + +// encodeGroup 编码 6-bit 组 +func (pc *PackedConn) encodeGroup(group byte) byte { + return pc.table.layout.encodeGroup(group) +} diff --git a/transport/sudoku/obfs/sudoku/table.go b/transport/sudoku/obfs/sudoku/table.go new file mode 100644 index 00000000..d86e642f --- /dev/null +++ b/transport/sudoku/obfs/sudoku/table.go @@ -0,0 +1,153 @@ +package sudoku + +import ( + "crypto/sha256" + "encoding/binary" + "errors" + "log" + "math/rand" + "time" +) + +var ( + ErrInvalidSudokuMapMiss = errors.New("INVALID_SUDOKU_MAP_MISS") +) + +type Table struct { + EncodeTable [256][][4]byte + DecodeMap map[uint32]byte + PaddingPool []byte + IsASCII bool // 标记当前模式 + layout *byteLayout +} + +// NewTable initializes the obfuscation tables with built-in layouts. +// Equivalent to calling NewTableWithCustom(key, mode, ""). +func NewTable(key string, mode string) *Table { + t, err := NewTableWithCustom(key, mode, "") + if err != nil { + log.Panicf("failed to build table: %v", err) + } + return t +} + +// NewTableWithCustom initializes obfuscation tables using either predefined or custom layouts. +// mode: "prefer_ascii" or "prefer_entropy". If a custom pattern is provided, ASCII mode still takes precedence. +// The customPattern must contain 8 characters with exactly 2 x, 2 p, and 4 v (case-insensitive). +func NewTableWithCustom(key string, mode string, customPattern string) (*Table, error) { + start := time.Now() + + layout, err := resolveLayout(mode, customPattern) + if err != nil { + return nil, err + } + + t := &Table{ + DecodeMap: make(map[uint32]byte), + IsASCII: layout.name == "ascii", + layout: layout, + } + t.PaddingPool = append(t.PaddingPool, layout.paddingPool...) + + // 生成数独网格 (逻辑不变) + allGrids := GenerateAllGrids() + h := sha256.New() + h.Write([]byte(key)) + seed := int64(binary.BigEndian.Uint64(h.Sum(nil)[:8])) + rng := rand.New(rand.NewSource(seed)) + + shuffledGrids := make([]Grid, 288) + copy(shuffledGrids, allGrids) + rng.Shuffle(len(shuffledGrids), func(i, j int) { + shuffledGrids[i], shuffledGrids[j] = shuffledGrids[j], shuffledGrids[i] + }) + + // 预计算组合 + var combinations [][]int + var combine func(int, int, []int) + combine = func(s, k int, c []int) { + if k == 0 { + tmp := make([]int, len(c)) + copy(tmp, c) + combinations = append(combinations, tmp) + return + } + for i := s; i <= 16-k; i++ { + c = append(c, i) + combine(i+1, k-1, c) + c = c[:len(c)-1] + } + } + combine(0, 4, []int{}) + + // 构建映射表 + for byteVal := 0; byteVal < 256; byteVal++ { + targetGrid := shuffledGrids[byteVal] + for _, positions := range combinations { + var currentHints [4]byte + + // 1. 计算抽象提示 (Abstract Hints) + // 我们先计算出 val 和 pos,后面再根据模式编码成 byte + var rawParts [4]struct{ val, pos byte } + + for i, pos := range positions { + val := targetGrid[pos] // 1..4 + rawParts[i] = struct{ val, pos byte }{val, uint8(pos)} + } + + // 检查唯一性 (数独逻辑) + matchCount := 0 + for _, g := range allGrids { + match := true + for _, p := range rawParts { + if g[p.pos] != p.val { + match = false + break + } + } + if match { + matchCount++ + if matchCount > 1 { + break + } + } + } + + if matchCount == 1 { + // 唯一确定,生成最终编码字节 + for i, p := range rawParts { + currentHints[i] = t.layout.encodeHint(p.val-1, p.pos) + } + + t.EncodeTable[byteVal] = append(t.EncodeTable[byteVal], currentHints) + // 生成解码键 (需要对 Hints 进行排序以忽略传输顺序) + key := packHintsToKey(currentHints) + t.DecodeMap[key] = byte(byteVal) + } + } + } + log.Printf("[Init] Sudoku Tables initialized (%s) in %v", layout.name, time.Since(start)) + return t, nil +} + +func packHintsToKey(hints [4]byte) uint32 { + // Sorting network for 4 elements (Bubble sort unrolled) + // Swap if a > b + if hints[0] > hints[1] { + hints[0], hints[1] = hints[1], hints[0] + } + if hints[2] > hints[3] { + hints[2], hints[3] = hints[3], hints[2] + } + if hints[0] > hints[2] { + hints[0], hints[2] = hints[2], hints[0] + } + if hints[1] > hints[3] { + hints[1], hints[3] = hints[3], hints[1] + } + if hints[1] > hints[2] { + hints[1], hints[2] = hints[2], hints[1] + } + + return uint32(hints[0])<<24 | uint32(hints[1])<<16 | uint32(hints[2])<<8 | uint32(hints[3]) +} diff --git a/transport/sudoku/obfs/sudoku/table_set.go b/transport/sudoku/obfs/sudoku/table_set.go new file mode 100644 index 00000000..59d3c98f --- /dev/null +++ b/transport/sudoku/obfs/sudoku/table_set.go @@ -0,0 +1,38 @@ +package sudoku + +import "fmt" + +// TableSet is a small helper for managing multiple Sudoku tables (e.g. for per-connection rotation). +// It is intentionally decoupled from the tunnel/app layers. +type TableSet struct { + Tables []*Table +} + +// NewTableSet builds one or more tables from key/mode and a list of custom X/P/V patterns. +// If patterns is empty, it builds a single default table (customPattern=""). +func NewTableSet(key string, mode string, patterns []string) (*TableSet, error) { + if len(patterns) == 0 { + t, err := NewTableWithCustom(key, mode, "") + if err != nil { + return nil, err + } + return &TableSet{Tables: []*Table{t}}, nil + } + + tables := make([]*Table, 0, len(patterns)) + for i, pattern := range patterns { + t, err := NewTableWithCustom(key, mode, pattern) + if err != nil { + return nil, fmt.Errorf("build table[%d] (%q): %w", i, pattern, err) + } + tables = append(tables, t) + } + return &TableSet{Tables: tables}, nil +} + +func (ts *TableSet) Candidates() []*Table { + if ts == nil { + return nil + } + return ts.Tables +} diff --git a/transport/sudoku/obfs_writer.go b/transport/sudoku/obfs_writer.go index f9803591..3dc94b4e 100644 --- a/transport/sudoku/obfs_writer.go +++ b/transport/sudoku/obfs_writer.go @@ -6,7 +6,7 @@ import ( "math/rand" "net" - "github.com/saba-futai/sudoku/pkg/obfs/sudoku" + "github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku" ) // perm4 matches github.com/saba-futai/sudoku/pkg/obfs/sudoku perm4. diff --git a/transport/sudoku/table_probe.go b/transport/sudoku/table_probe.go index f12c1722..8def6fd4 100644 --- a/transport/sudoku/table_probe.go +++ b/transport/sudoku/table_probe.go @@ -10,26 +10,12 @@ import ( "net" "time" - "github.com/saba-futai/sudoku/apis" - "github.com/saba-futai/sudoku/pkg/crypto" - "github.com/saba-futai/sudoku/pkg/obfs/sudoku" + "github.com/metacubex/mihomo/transport/sudoku/crypto" + "github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku" ) -func tableCandidates(cfg *apis.ProtocolConfig) []*sudoku.Table { - if cfg == nil { - return nil - } - if len(cfg.Tables) > 0 { - return cfg.Tables - } - if cfg.Table != nil { - return []*sudoku.Table{cfg.Table} - } - return nil -} - -func pickClientTable(cfg *apis.ProtocolConfig) (*sudoku.Table, byte, error) { - candidates := tableCandidates(cfg) +func pickClientTable(cfg *ProtocolConfig) (*sudoku.Table, byte, error) { + candidates := cfg.tableCandidates() if len(candidates) == 0 { return nil, 0, fmt.Errorf("no table configured") } @@ -62,7 +48,7 @@ func drainBuffered(r *bufio.Reader) ([]byte, error) { return out, err } -func probeHandshakeBytes(probe []byte, cfg *apis.ProtocolConfig, table *sudoku.Table) error { +func probeHandshakeBytes(probe []byte, cfg *ProtocolConfig, table *sudoku.Table) error { rc := &readOnlyConn{Reader: bytes.NewReader(probe)} _, obfsConn := buildServerObfsConn(rc, cfg, table, false) cConn, err := crypto.NewAEADConn(obfsConn, cfg.Key, cfg.AEADMethod) @@ -90,7 +76,7 @@ func probeHandshakeBytes(probe []byte, cfg *apis.ProtocolConfig, table *sudoku.T return nil } -func selectTableByProbe(r *bufio.Reader, cfg *apis.ProtocolConfig, tables []*sudoku.Table) (*sudoku.Table, []byte, error) { +func selectTableByProbe(r *bufio.Reader, cfg *ProtocolConfig, tables []*sudoku.Table) (*sudoku.Table, []byte, error) { const ( maxProbeBytes = 64 * 1024 readChunk = 4 * 1024 diff --git a/transport/sudoku/tables.go b/transport/sudoku/tables.go index 429a4ab3..2630ea52 100644 --- a/transport/sudoku/tables.go +++ b/transport/sudoku/tables.go @@ -3,7 +3,7 @@ package sudoku import ( "strings" - "github.com/saba-futai/sudoku/pkg/obfs/sudoku" + "github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku" ) // NewTablesWithCustomPatterns builds one or more obfuscation tables from x/v/p custom patterns.