mihomo/transport/sudoku/handshake_test.go
2025-12-10 17:27:45 +08:00

250 lines
5.6 KiB
Go

package sudoku
import (
"bytes"
"fmt"
"io"
"net"
"sync"
"testing"
"time"
"github.com/saba-futai/sudoku/apis"
sudokuobfs "github.com/saba-futai/sudoku/pkg/obfs/sudoku"
)
// TestPackedConnRoundTrip_WithPadding writes a short payload through a
// sudokuobfs PackedConn over an in-memory net.Pipe and checks that a second
// PackedConn built from the same table decodes it byte-for-byte. It runs once
// per table type.
func TestPackedConnRoundTrip_WithPadding(t *testing.T) {
	payload := []byte{0x3a, 0x1f, 0x71, 0x00, 0xff, 0x10, 0x22}
	tableTypes := []string{"prefer_ascii", "prefer_entropy"}
	for _, tt := range tableTypes {
		t.Run(tt, func(t *testing.T) {
			serverConn, clientConn := net.Pipe()
			defer serverConn.Close()
			defer clientConn.Close()

			table := sudokuobfs.NewTable("roundtrip-seed", tt)
			writer := sudokuobfs.NewPackedConn(serverConn, table, 30, 80)
			reader := sudokuobfs.NewPackedConn(clientConn, table, 30, 80)

			// Buffered so the writer goroutine can always report and exit.
			writeErr := make(chan error, 1)
			go func() {
				writeErr <- func() error {
					// Close the writing end on every path, not only after a
					// successful Flush. The reader then always reaches EOF
					// instead of hanging until the read timeout, which would
					// hide the real write error.
					defer serverConn.Close()
					if _, err := writer.Write(payload); err != nil {
						return err
					}
					return writer.Flush()
				}()
			}()

			done := make(chan struct{})
			var got []byte
			var readErr error
			go func() {
				defer close(done)
				got, readErr = io.ReadAll(reader)
			}()
			select {
			case <-done:
			case <-time.After(5 * time.Second):
				t.Fatal("read timeout")
			}

			// Check the read side first. If it failed early, the writer may
			// still be blocked on the pipe; failing here runs the deferred
			// Closes, which unblock it. io.ReadAll reports a clean EOF as
			// a nil error.
			if readErr != nil {
				t.Fatalf("read side error: %v", readErr)
			}
			// Bounded wait, so a stuck writer cannot hang the whole test binary.
			select {
			case err := <-writeErr:
				if err != nil && err != io.EOF {
					t.Fatalf("write side error: %v", err)
				}
			case <-time.After(5 * time.Second):
				t.Fatal("write timeout")
			}
			if !bytes.Equal(got, payload) {
				t.Fatalf("payload mismatch, want %x got %x", payload, got)
			}
		})
	}
}
// newPackedConfig builds the protocol config shared by the handshake tests.
// It starts from apis.DefaultConfig and then:
//   - installs the given table;
//   - sets a fixed key and server address;
//   - sets padding bounds 10..30;
//   - disables EnablePureDownlink and the HTTP mask.
func newPackedConfig(table *sudokuobfs.Table) *apis.ProtocolConfig {
	c := apis.DefaultConfig()
	c.Table = table
	c.Key = "sudoku-test-key"
	c.ServerAddress = "example.com:443"
	c.PaddingMin, c.PaddingMax = 10, 30
	c.EnablePureDownlink = false
	c.DisableHTTPMask = true
	return c
}
// TestPackedDownlinkSoak concurrently runs `sessions` TCP sessions and
// `sessions` UoT sessions against one shared config built by newPackedConfig.
// It fails if any session reports an error or if the whole batch takes longer
// than 10 seconds.
func TestPackedDownlinkSoak(t *testing.T) {
	const sessions = 16
	table := sudokuobfs.NewTable("soak-seed", "prefer_ascii")
	cfg := newPackedConfig(table)

	// Two helper calls per iteration. Reserve two slots per call in case both
	// the server and the client side of a session report an error. No sender
	// can then block on a full channel, stall wg.Wait, and turn a real
	// failure into a bare timeout.
	errCh := make(chan error, sessions*4)

	var wg sync.WaitGroup
	for i := 0; i < sessions; i++ {
		wg.Add(2)
		go func(id int) {
			defer wg.Done()
			runPackedTCPSession(id, cfg, errCh)
		}(i)
		go func(id int) {
			defer wg.Done()
			runPackedUoTSession(id, cfg, errCh)
		}(i)
	}

	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()
	select {
	case <-done:
	case <-time.After(10 * time.Second):
		// A hung session is often the consequence of its peer failing first,
		// so surface an already-reported error if there is one.
		select {
		case err := <-errCh:
			t.Fatalf("soak test timeout; first reported error: %v", err)
		default:
			t.Fatal("soak test timeout")
		}
	}

	// Drain without closing errCh, so a straggling sender can never hit a
	// closed channel. This goroutine is the only receiver, so the receive
	// below cannot block once len reports a pending value.
	for len(errCh) > 0 {
		t.Errorf("soak error: %v", <-errCh)
	}
}
// runPackedTCPSession drives one TCP session through the handshake over an
// in-memory net.Pipe.
//
// The server goroutine runs ServerHandshake, expects a SessionTypeTCP session
// for target "1.1.1.N:80" (N derived from id), and writes a 2-byte payload.
// The client side runs ClientHandshake on a copy of cfg, sends the encoded
// target address, and reads the payload back.
//
// Failures are reported on errCh with at most one send per call, and that
// send happens only after the server goroutine has exited. Callers may
// therefore stop reading errCh, or close it, as soon as this function has
// returned.
func runPackedTCPSession(id int, cfg *apis.ProtocolConfig, errCh chan<- error) {
	serverConn, clientConn := net.Pipe()
	target := fmt.Sprintf("1.1.1.%d:80", (id%200)+1)
	payload := []byte{0x42, byte(id)}

	// Server side. Buffered so the goroutine never blocks on reporting.
	serverErr := make(chan error, 1)
	go func() {
		serverErr <- func() error {
			// Close our end on every exit path. A client still waiting on us
			// then sees EOF instead of hanging forever.
			defer serverConn.Close()
			session, err := ServerHandshake(serverConn, cfg)
			if err != nil {
				return fmt.Errorf("server handshake tcp: %w", err)
			}
			defer session.Conn.Close()
			if session.Type != SessionTypeTCP {
				return fmt.Errorf("unexpected session type: %v", session.Type)
			}
			if session.Target != target {
				return fmt.Errorf("target mismatch want %s got %s", target, session.Target)
			}
			if _, err := session.Conn.Write(payload); err != nil {
				return fmt.Errorf("server write: %w", err)
			}
			return nil
		}()
	}()

	// Client side.
	clientErr := func() error {
		clientCfg := *cfg
		cConn, err := ClientHandshake(clientConn, &clientCfg)
		if err != nil {
			return fmt.Errorf("client handshake tcp: %w", err)
		}
		defer cConn.Close()
		addrBuf, err := EncodeAddress(target)
		if err != nil {
			return fmt.Errorf("encode address: %w", err)
		}
		if _, err := cConn.Write(addrBuf); err != nil {
			return fmt.Errorf("client send addr: %w", err)
		}
		buf := make([]byte, len(payload))
		if _, err := io.ReadFull(cConn, buf); err != nil {
			return fmt.Errorf("client read: %w", err)
		}
		if !bytes.Equal(buf, payload) {
			return fmt.Errorf("payload mismatch want %x got %x", payload, buf)
		}
		return nil
	}()

	// Close the client end explicitly: ClientHandshake may have failed before
	// any wrapper existed. This unblocks a server still reading or writing.
	// Then wait for the server, so it cannot outlive this call.
	clientConn.Close()
	srvErr := <-serverErr

	// Report at most one error. When both sides failed, keep both messages,
	// because either one may be the root cause.
	switch {
	case srvErr != nil && clientErr != nil:
		errCh <- fmt.Errorf("%w (client side: %v)", srvErr, clientErr)
	case srvErr != nil:
		errCh <- srvErr
	case clientErr != nil:
		errCh <- clientErr
	}
}
// runPackedUoTSession drives one UDP-over-TCP session through the handshake
// over an in-memory net.Pipe.
//
// The server goroutine runs ServerHandshake, expects a SessionTypeUoT session,
// and writes one datagram (target 8.8.8.8:53, 2-byte payload derived from id).
// The client side runs ClientHandshake on a copy of cfg, writes the preface,
// reads the datagram back, and checks its address and payload.
//
// Failures are reported on errCh with at most one send per call, and that
// send happens only after the server goroutine has exited. Callers may
// therefore stop reading errCh, or close it, as soon as this function has
// returned.
func runPackedUoTSession(id int, cfg *apis.ProtocolConfig, errCh chan<- error) {
	serverConn, clientConn := net.Pipe()
	target := "8.8.8.8:53"
	payload := []byte{0xaa, byte(id)}

	// Server side. Buffered so the goroutine never blocks on reporting.
	serverErr := make(chan error, 1)
	go func() {
		serverErr <- func() error {
			// Close our end on every exit path. A client still waiting on us
			// then sees EOF instead of hanging forever.
			defer serverConn.Close()
			session, err := ServerHandshake(serverConn, cfg)
			if err != nil {
				return fmt.Errorf("server handshake uot: %w", err)
			}
			defer session.Conn.Close()
			if session.Type != SessionTypeUoT {
				return fmt.Errorf("unexpected session type: %v", session.Type)
			}
			if err := WriteDatagram(session.Conn, target, payload); err != nil {
				return fmt.Errorf("server write datagram: %w", err)
			}
			return nil
		}()
	}()

	// Client side.
	clientErr := func() error {
		clientCfg := *cfg
		cConn, err := ClientHandshake(clientConn, &clientCfg)
		if err != nil {
			return fmt.Errorf("client handshake uot: %w", err)
		}
		defer cConn.Close()
		if err := WritePreface(cConn); err != nil {
			return fmt.Errorf("client write preface: %w", err)
		}
		addr, data, err := ReadDatagram(cConn)
		if err != nil {
			return fmt.Errorf("client read datagram: %w", err)
		}
		if addr != target {
			return fmt.Errorf("uot target mismatch want %s got %s", target, addr)
		}
		if !bytes.Equal(data, payload) {
			return fmt.Errorf("uot payload mismatch want %x got %x", payload, data)
		}
		return nil
	}()

	// Close the client end explicitly: ClientHandshake may have failed before
	// any wrapper existed. This unblocks a server still reading or writing.
	// Then wait for the server, so it cannot outlive this call.
	clientConn.Close()
	srvErr := <-serverErr

	// Report at most one error. When both sides failed, keep both messages,
	// because either one may be the root cause.
	switch {
	case srvErr != nil && clientErr != nil:
		errCh <- fmt.Errorf("%w (client side: %v)", srvErr, clientErr)
	case srvErr != nil:
		errCh <- srvErr
	case clientErr != nil:
		errCh <- clientErr
	}
}
// TestCustomTableHandshake builds a table with sudokuobfs.NewTableWithCustom
// using the custom spec "xpxvvpvv". It then runs one TCP session and one UoT
// session through the full handshake with that table and fails on any
// reported error.
func TestCustomTableHandshake(t *testing.T) {
	table, err := sudokuobfs.NewTableWithCustom("custom-seed", "prefer_entropy", "xpxvvpvv")
	if err != nil {
		t.Fatalf("build custom table: %v", err)
	}
	cfg := newPackedConfig(table)

	// The helpers' client sides send on this very goroutine, so a send on a
	// full channel would deadlock the test. Reserve two slots per helper call
	// in case both the server and the client side of a session report an
	// error.
	errCh := make(chan error, 4)
	runPackedTCPSession(42, cfg, errCh)
	runPackedUoTSession(43, cfg, errCh)

	// Drain without closing errCh, so a late sender can never hit a closed
	// channel. Every reported error is printed, not just the first.
	for len(errCh) > 0 {
		t.Errorf("custom table handshake failed: %v", <-errCh)
	}
}