chore: sharing sudoku internal code

This commit is contained in:
wwqgtxx 2025-12-03 22:23:37 +08:00
parent e4cdb9b600
commit 30891f8781
5 changed files with 108 additions and 133 deletions

View File

@ -5,7 +5,6 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io"
"net" "net"
"strconv" "strconv"
"strings" "strings"
@ -19,7 +18,7 @@ import (
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
ts "github.com/metacubex/mihomo/transport/sudoku" "github.com/metacubex/mihomo/transport/sudoku"
) )
type Sudoku struct { type Sudoku struct {
@ -101,12 +100,12 @@ func (s *Sudoku) ListenPacketContext(ctx context.Context, metadata *C.Metadata)
return nil, err return nil, err
} }
if err = ts.WritePreface(c); err != nil { if err = sudoku.WritePreface(c); err != nil {
_ = c.Close() _ = c.Close()
return nil, fmt.Errorf("send uot preface failed: %w", err) return nil, fmt.Errorf("send uot preface failed: %w", err)
} }
return newPacketConn(N.NewThreadSafePacketConn(ts.NewUoTPacketConn(c)), s), nil return newPacketConn(N.NewThreadSafePacketConn(sudoku.NewUoTPacketConn(c)), s), nil
} }
// SupportUOT implements C.ProxyAdapter // SupportUOT implements C.ProxyAdapter
@ -163,7 +162,12 @@ func (s *Sudoku) streamConn(rawConn net.Conn, cfg *apis.ProtocolConfig) (_ net.C
return nil, err return nil, err
} }
if err = writeTargetAddress(cConn, cfg.TargetAddress); err != nil { addrBuf, err := sudoku.EncodeAddress(cfg.TargetAddress)
if err != nil {
return nil, fmt.Errorf("encode target address failed: %w", err)
}
if _, err = cConn.Write(addrBuf); err != nil {
cConn.Close() cConn.Close()
return nil, fmt.Errorf("send target address failed: %w", err) return nil, fmt.Errorf("send target address failed: %w", err)
} }
@ -256,40 +260,3 @@ func buildSudokuHandshakePayload(key string) [16]byte {
copy(payload[8:], hash[:8]) copy(payload[8:], hash[:8])
return payload return payload
} }
func writeTargetAddress(w io.Writer, rawAddr string) error {
host, portStr, err := net.SplitHostPort(rawAddr)
if err != nil {
return err
}
portInt, err := net.LookupPort("tcp", portStr)
if err != nil {
return err
}
var buf []byte
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
buf = append(buf, 0x01) // IPv4
buf = append(buf, ip4...)
} else {
buf = append(buf, 0x04) // IPv6
buf = append(buf, ip...)
}
} else {
if len(host) > 255 {
return fmt.Errorf("domain too long")
}
buf = append(buf, 0x03) // domain
buf = append(buf, byte(len(host)))
buf = append(buf, host...)
}
var portBytes [2]byte
binary.BigEndian.PutUint16(portBytes[:], uint16(portInt))
buf = append(buf, portBytes[:]...)
_, err = w.Write(buf)
return err
}

View File

@ -14,7 +14,7 @@ import (
LC "github.com/metacubex/mihomo/listener/config" LC "github.com/metacubex/mihomo/listener/config"
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
"github.com/metacubex/mihomo/transport/socks5" "github.com/metacubex/mihomo/transport/socks5"
ts "github.com/metacubex/mihomo/transport/sudoku" "github.com/metacubex/mihomo/transport/sudoku"
) )
type Listener struct { type Listener struct {
@ -47,14 +47,14 @@ func (l *Listener) Close() error {
} }
func (l *Listener) handleConn(conn net.Conn, tunnel C.Tunnel, additions ...inbound.Addition) { func (l *Listener) handleConn(conn net.Conn, tunnel C.Tunnel, additions ...inbound.Addition) {
session, err := ts.ServerHandshake(conn, &l.protoConf) session, err := sudoku.ServerHandshake(conn, &l.protoConf)
if err != nil { if err != nil {
_ = conn.Close() _ = conn.Close()
return return
} }
switch session.Type { switch session.Type {
case ts.SessionTypeUoT: case sudoku.SessionTypeUoT:
l.handleUoTSession(session.Conn, tunnel, additions...) l.handleUoTSession(session.Conn, tunnel, additions...)
default: default:
targetAddr := socks5.ParseAddr(session.Target) targetAddr := socks5.ParseAddr(session.Target)
@ -67,11 +67,11 @@ func (l *Listener) handleConn(conn net.Conn, tunnel C.Tunnel, additions ...inbou
} }
func (l *Listener) handleUoTSession(conn net.Conn, tunnel C.Tunnel, additions ...inbound.Addition) { func (l *Listener) handleUoTSession(conn net.Conn, tunnel C.Tunnel, additions ...inbound.Addition) {
writer := ts.NewUoTPacketConn(conn) writer := sudoku.NewUoTPacketConn(conn)
remoteAddr := conn.RemoteAddr() remoteAddr := conn.RemoteAddr()
for { for {
addrStr, payload, err := ts.ReadDatagram(conn) addrStr, payload, err := sudoku.ReadDatagram(conn)
if err != nil { if err != nil {
if !errors.Is(err, io.EOF) { if !errors.Is(err, io.EOF) {
log.Debugln("[Sudoku][UoT] session closed: %v", err) log.Debugln("[Sudoku][UoT] session closed: %v", err)
@ -97,7 +97,7 @@ func (l *Listener) handleUoTSession(conn net.Conn, tunnel C.Tunnel, additions ..
type uotPacket struct { type uotPacket struct {
payload []byte payload []byte
writer *ts.UoTPacketConn writer *sudoku.UoTPacketConn
rAddr net.Addr rAddr net.Addr
} }

View File

@ -0,0 +1,90 @@
package sudoku
import (
"encoding/binary"
"fmt"
"io"
"net"
)
func EncodeAddress(rawAddr string) ([]byte, error) {
host, portStr, err := net.SplitHostPort(rawAddr)
if err != nil {
return nil, err
}
portInt, err := net.LookupPort("udp", portStr)
if err != nil {
return nil, err
}
var buf []byte
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
buf = append(buf, 0x01) // IPv4
buf = append(buf, ip4...)
} else {
buf = append(buf, 0x04) // IPv6
buf = append(buf, ip...)
}
} else {
if len(host) > 255 {
return nil, fmt.Errorf("domain too long")
}
buf = append(buf, 0x03) // domain
buf = append(buf, byte(len(host)))
buf = append(buf, host...)
}
var portBytes [2]byte
binary.BigEndian.PutUint16(portBytes[:], uint16(portInt))
buf = append(buf, portBytes[:]...)
return buf, nil
}
func DecodeAddress(r io.Reader) (string, error) {
var atyp [1]byte
if _, err := io.ReadFull(r, atyp[:]); err != nil {
return "", err
}
switch atyp[0] {
case 0x01: // IPv4
var ipBuf [net.IPv4len]byte
if _, err := io.ReadFull(r, ipBuf[:]); err != nil {
return "", err
}
var portBuf [2]byte
if _, err := io.ReadFull(r, portBuf[:]); err != nil {
return "", err
}
return net.JoinHostPort(net.IP(ipBuf[:]).String(), fmt.Sprint(binary.BigEndian.Uint16(portBuf[:]))), nil
case 0x04: // IPv6
var ipBuf [net.IPv6len]byte
if _, err := io.ReadFull(r, ipBuf[:]); err != nil {
return "", err
}
var portBuf [2]byte
if _, err := io.ReadFull(r, portBuf[:]); err != nil {
return "", err
}
return net.JoinHostPort(net.IP(ipBuf[:]).String(), fmt.Sprint(binary.BigEndian.Uint16(portBuf[:]))), nil
case 0x03: // domain
var lengthBuf [1]byte
if _, err := io.ReadFull(r, lengthBuf[:]); err != nil {
return "", err
}
l := int(lengthBuf[0])
hostBuf := make([]byte, l)
if _, err := io.ReadFull(r, hostBuf); err != nil {
return "", err
}
var portBuf [2]byte
if _, err := io.ReadFull(r, portBuf[:]); err != nil {
return "", err
}
return net.JoinHostPort(string(hostBuf), fmt.Sprint(binary.BigEndian.Uint16(portBuf[:]))), nil
default:
return "", fmt.Errorf("unknown address type: %d", atyp[0])
}
}

View File

@ -131,7 +131,7 @@ func ServerHandshake(rawConn net.Conn, cfg *apis.ProtocolConfig) (*ServerSession
} }
prefixed := &preBufferedConn{Conn: cConn, buf: firstByte} prefixed := &preBufferedConn{Conn: cConn, buf: firstByte}
target, err := decodeAddress(prefixed) target, err := DecodeAddress(prefixed)
if err != nil { if err != nil {
cConn.Close() cConn.Close()
return nil, fmt.Errorf("read target address failed: %w", err) return nil, fmt.Errorf("read target address failed: %w", err)

View File

@ -25,91 +25,9 @@ func WritePreface(w io.Writer) error {
return err return err
} }
func encodeAddress(rawAddr string) ([]byte, error) {
host, portStr, err := net.SplitHostPort(rawAddr)
if err != nil {
return nil, err
}
portInt, err := net.LookupPort("udp", portStr)
if err != nil {
return nil, err
}
var buf []byte
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
buf = append(buf, 0x01) // IPv4
buf = append(buf, ip4...)
} else {
buf = append(buf, 0x04) // IPv6
buf = append(buf, ip...)
}
} else {
if len(host) > 255 {
return nil, fmt.Errorf("domain too long")
}
buf = append(buf, 0x03) // domain
buf = append(buf, byte(len(host)))
buf = append(buf, host...)
}
var portBytes [2]byte
binary.BigEndian.PutUint16(portBytes[:], uint16(portInt))
buf = append(buf, portBytes[:]...)
return buf, nil
}
func decodeAddress(r io.Reader) (string, error) {
var atyp [1]byte
if _, err := io.ReadFull(r, atyp[:]); err != nil {
return "", err
}
switch atyp[0] {
case 0x01: // IPv4
var ipBuf [net.IPv4len]byte
if _, err := io.ReadFull(r, ipBuf[:]); err != nil {
return "", err
}
var portBuf [2]byte
if _, err := io.ReadFull(r, portBuf[:]); err != nil {
return "", err
}
return net.JoinHostPort(net.IP(ipBuf[:]).String(), fmt.Sprint(binary.BigEndian.Uint16(portBuf[:]))), nil
case 0x04: // IPv6
var ipBuf [net.IPv6len]byte
if _, err := io.ReadFull(r, ipBuf[:]); err != nil {
return "", err
}
var portBuf [2]byte
if _, err := io.ReadFull(r, portBuf[:]); err != nil {
return "", err
}
return net.JoinHostPort(net.IP(ipBuf[:]).String(), fmt.Sprint(binary.BigEndian.Uint16(portBuf[:]))), nil
case 0x03: // domain
var lengthBuf [1]byte
if _, err := io.ReadFull(r, lengthBuf[:]); err != nil {
return "", err
}
l := int(lengthBuf[0])
hostBuf := make([]byte, l)
if _, err := io.ReadFull(r, hostBuf); err != nil {
return "", err
}
var portBuf [2]byte
if _, err := io.ReadFull(r, portBuf[:]); err != nil {
return "", err
}
return net.JoinHostPort(string(hostBuf), fmt.Sprint(binary.BigEndian.Uint16(portBuf[:]))), nil
default:
return "", fmt.Errorf("unknown address type: %d", atyp[0])
}
}
// WriteDatagram sends a single UDP datagram frame over a reliable stream. // WriteDatagram sends a single UDP datagram frame over a reliable stream.
func WriteDatagram(w io.Writer, addr string, payload []byte) error { func WriteDatagram(w io.Writer, addr string, payload []byte) error {
addrBuf, err := encodeAddress(addr) addrBuf, err := EncodeAddress(addr)
if err != nil { if err != nil {
return fmt.Errorf("encode address: %w", err) return fmt.Errorf("encode address: %w", err)
} }
@ -157,7 +75,7 @@ func ReadDatagram(r io.Reader) (string, []byte, error) {
return "", nil, err return "", nil, err
} }
addr, err := decodeAddress(bytes.NewReader(addrBuf)) addr, err := DecodeAddress(bytes.NewReader(addrBuf))
if err != nil { if err != nil {
return "", nil, fmt.Errorf("decode address: %w", err) return "", nil, fmt.Errorf("decode address: %w", err)
} }