diff --git a/adapter/outbound/sudoku.go b/adapter/outbound/sudoku.go index 71dac311..0e6c3bba 100644 --- a/adapter/outbound/sudoku.go +++ b/adapter/outbound/sudoku.go @@ -5,7 +5,6 @@ import ( "crypto/sha256" "encoding/binary" "fmt" - "io" "net" "strconv" "strings" @@ -19,7 +18,7 @@ import ( N "github.com/metacubex/mihomo/common/net" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/log" - ts "github.com/metacubex/mihomo/transport/sudoku" + "github.com/metacubex/mihomo/transport/sudoku" ) type Sudoku struct { @@ -101,12 +100,12 @@ func (s *Sudoku) ListenPacketContext(ctx context.Context, metadata *C.Metadata) return nil, err } - if err = ts.WritePreface(c); err != nil { + if err = sudoku.WritePreface(c); err != nil { _ = c.Close() return nil, fmt.Errorf("send uot preface failed: %w", err) } - return newPacketConn(N.NewThreadSafePacketConn(ts.NewUoTPacketConn(c)), s), nil + return newPacketConn(N.NewThreadSafePacketConn(sudoku.NewUoTPacketConn(c)), s), nil } // SupportUOT implements C.ProxyAdapter @@ -163,7 +162,12 @@ func (s *Sudoku) streamConn(rawConn net.Conn, cfg *apis.ProtocolConfig) (_ net.C return nil, err } - if err = writeTargetAddress(cConn, cfg.TargetAddress); err != nil { + addrBuf, err := sudoku.EncodeAddress(cfg.TargetAddress) + if err != nil { + return nil, fmt.Errorf("encode target address failed: %w", err) + } + + if _, err = cConn.Write(addrBuf); err != nil { cConn.Close() return nil, fmt.Errorf("send target address failed: %w", err) } @@ -256,40 +260,3 @@ func buildSudokuHandshakePayload(key string) [16]byte { copy(payload[8:], hash[:8]) return payload } - -func writeTargetAddress(w io.Writer, rawAddr string) error { - host, portStr, err := net.SplitHostPort(rawAddr) - if err != nil { - return err - } - - portInt, err := net.LookupPort("tcp", portStr) - if err != nil { - return err - } - - var buf []byte - if ip := net.ParseIP(host); ip != nil { - if ip4 := ip.To4(); ip4 != nil { - buf = append(buf, 0x01) // IPv4 - buf = append(buf, ip4...) - } else { - buf = append(buf, 0x04) // IPv6 - buf = append(buf, ip...) - } - } else { - if len(host) > 255 { - return fmt.Errorf("domain too long") - } - buf = append(buf, 0x03) // domain - buf = append(buf, byte(len(host))) - buf = append(buf, host...) - } - - var portBytes [2]byte - binary.BigEndian.PutUint16(portBytes[:], uint16(portInt)) - buf = append(buf, portBytes[:]...) - - _, err = w.Write(buf) - return err -} diff --git a/listener/sudoku/server.go b/listener/sudoku/server.go index 1755a040..f3260eb3 100644 --- a/listener/sudoku/server.go +++ b/listener/sudoku/server.go @@ -14,7 +14,7 @@ import ( LC "github.com/metacubex/mihomo/listener/config" "github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/transport/socks5" - ts "github.com/metacubex/mihomo/transport/sudoku" + "github.com/metacubex/mihomo/transport/sudoku" ) type Listener struct { @@ -47,14 +47,14 @@ func (l *Listener) Close() error { } func (l *Listener) handleConn(conn net.Conn, tunnel C.Tunnel, additions ...inbound.Addition) { - session, err := ts.ServerHandshake(conn, &l.protoConf) + session, err := sudoku.ServerHandshake(conn, &l.protoConf) if err != nil { _ = conn.Close() return } switch session.Type { - case ts.SessionTypeUoT: + case sudoku.SessionTypeUoT: l.handleUoTSession(session.Conn, tunnel, additions...) default: targetAddr := socks5.ParseAddr(session.Target) @@ -67,11 +67,11 @@ func (l *Listener) handleConn(conn net.Conn, tunnel C.Tunnel, additions ...inbou } func (l *Listener) handleUoTSession(conn net.Conn, tunnel C.Tunnel, additions ...inbound.Addition) { - writer := ts.NewUoTPacketConn(conn) + writer := sudoku.NewUoTPacketConn(conn) remoteAddr := conn.RemoteAddr() for { - addrStr, payload, err := ts.ReadDatagram(conn) + addrStr, payload, err := sudoku.ReadDatagram(conn) if err != nil { if !errors.Is(err, io.EOF) { log.Debugln("[Sudoku][UoT] session closed: %v", err) @@ -97,7 +97,7 @@ func (l *Listener) handleUoTSession(conn net.Conn, tunnel C.Tunnel, additions .. type uotPacket struct { payload []byte - writer *ts.UoTPacketConn + writer *sudoku.UoTPacketConn rAddr net.Addr } diff --git a/transport/sudoku/address.go b/transport/sudoku/address.go new file mode 100644 index 00000000..4a666a23 --- /dev/null +++ b/transport/sudoku/address.go @@ -0,0 +1,90 @@ +package sudoku + +import ( + "encoding/binary" + "fmt" + "io" + "net" +) + +func EncodeAddress(rawAddr string) ([]byte, error) { + host, portStr, err := net.SplitHostPort(rawAddr) + if err != nil { + return nil, err + } + + portInt, err := net.LookupPort("udp", portStr) + if err != nil { + return nil, err + } + + var buf []byte + if ip := net.ParseIP(host); ip != nil { + if ip4 := ip.To4(); ip4 != nil { + buf = append(buf, 0x01) // IPv4 + buf = append(buf, ip4...) + } else { + buf = append(buf, 0x04) // IPv6 + buf = append(buf, ip...) + } + } else { + if len(host) > 255 { + return nil, fmt.Errorf("domain too long") + } + buf = append(buf, 0x03) // domain + buf = append(buf, byte(len(host))) + buf = append(buf, host...) + } + + var portBytes [2]byte + binary.BigEndian.PutUint16(portBytes[:], uint16(portInt)) + buf = append(buf, portBytes[:]...) + return buf, nil +} + +func DecodeAddress(r io.Reader) (string, error) { + var atyp [1]byte + if _, err := io.ReadFull(r, atyp[:]); err != nil { + return "", err + } + + switch atyp[0] { + case 0x01: // IPv4 + var ipBuf [net.IPv4len]byte + if _, err := io.ReadFull(r, ipBuf[:]); err != nil { + return "", err + } + var portBuf [2]byte + if _, err := io.ReadFull(r, portBuf[:]); err != nil { + return "", err + } + return net.JoinHostPort(net.IP(ipBuf[:]).String(), fmt.Sprint(binary.BigEndian.Uint16(portBuf[:]))), nil + case 0x04: // IPv6 + var ipBuf [net.IPv6len]byte + if _, err := io.ReadFull(r, ipBuf[:]); err != nil { + return "", err + } + var portBuf [2]byte + if _, err := io.ReadFull(r, portBuf[:]); err != nil { + return "", err + } + return net.JoinHostPort(net.IP(ipBuf[:]).String(), fmt.Sprint(binary.BigEndian.Uint16(portBuf[:]))), nil + case 0x03: // domain + var lengthBuf [1]byte + if _, err := io.ReadFull(r, lengthBuf[:]); err != nil { + return "", err + } + l := int(lengthBuf[0]) + hostBuf := make([]byte, l) + if _, err := io.ReadFull(r, hostBuf); err != nil { + return "", err + } + var portBuf [2]byte + if _, err := io.ReadFull(r, portBuf[:]); err != nil { + return "", err + } + return net.JoinHostPort(string(hostBuf), fmt.Sprint(binary.BigEndian.Uint16(portBuf[:]))), nil + default: + return "", fmt.Errorf("unknown address type: %d", atyp[0]) + } +} diff --git a/transport/sudoku/handshake.go b/transport/sudoku/handshake.go index 0c65ae25..1f822707 100644 --- a/transport/sudoku/handshake.go +++ b/transport/sudoku/handshake.go @@ -131,7 +131,7 @@ func ServerHandshake(rawConn net.Conn, cfg *apis.ProtocolConfig) (*ServerSession } prefixed := &preBufferedConn{Conn: cConn, buf: firstByte} - target, err := decodeAddress(prefixed) + target, err := DecodeAddress(prefixed) if err != nil { cConn.Close() return nil, fmt.Errorf("read target address failed: %w", err) diff --git a/transport/sudoku/uot.go b/transport/sudoku/uot.go index 97a5bb48..8b03b09b 100644 --- a/transport/sudoku/uot.go +++ b/transport/sudoku/uot.go @@ -25,91 +25,9 @@ func WritePreface(w io.Writer) error { return err } -func encodeAddress(rawAddr string) ([]byte, error) { - host, portStr, err := net.SplitHostPort(rawAddr) - if err != nil { - return nil, err - } - - portInt, err := net.LookupPort("udp", portStr) - if err != nil { - return nil, err - } - - var buf []byte - if ip := net.ParseIP(host); ip != nil { - if ip4 := ip.To4(); ip4 != nil { - buf = append(buf, 0x01) // IPv4 - buf = append(buf, ip4...) - } else { - buf = append(buf, 0x04) // IPv6 - buf = append(buf, ip...) - } - } else { - if len(host) > 255 { - return nil, fmt.Errorf("domain too long") - } - buf = append(buf, 0x03) // domain - buf = append(buf, byte(len(host))) - buf = append(buf, host...) - } - - var portBytes [2]byte - binary.BigEndian.PutUint16(portBytes[:], uint16(portInt)) - buf = append(buf, portBytes[:]...) - return buf, nil -} - -func decodeAddress(r io.Reader) (string, error) { - var atyp [1]byte - if _, err := io.ReadFull(r, atyp[:]); err != nil { - return "", err - } - - switch atyp[0] { - case 0x01: // IPv4 - var ipBuf [net.IPv4len]byte - if _, err := io.ReadFull(r, ipBuf[:]); err != nil { - return "", err - } - var portBuf [2]byte - if _, err := io.ReadFull(r, portBuf[:]); err != nil { - return "", err - } - return net.JoinHostPort(net.IP(ipBuf[:]).String(), fmt.Sprint(binary.BigEndian.Uint16(portBuf[:]))), nil - case 0x04: // IPv6 - var ipBuf [net.IPv6len]byte - if _, err := io.ReadFull(r, ipBuf[:]); err != nil { - return "", err - } - var portBuf [2]byte - if _, err := io.ReadFull(r, portBuf[:]); err != nil { - return "", err - } - return net.JoinHostPort(net.IP(ipBuf[:]).String(), fmt.Sprint(binary.BigEndian.Uint16(portBuf[:]))), nil - case 0x03: // domain - var lengthBuf [1]byte - if _, err := io.ReadFull(r, lengthBuf[:]); err != nil { - return "", err - } - l := int(lengthBuf[0]) - hostBuf := make([]byte, l) - if _, err := io.ReadFull(r, hostBuf); err != nil { - return "", err - } - var portBuf [2]byte - if _, err := io.ReadFull(r, portBuf[:]); err != nil { - return "", err - } - return net.JoinHostPort(string(hostBuf), fmt.Sprint(binary.BigEndian.Uint16(portBuf[:]))), nil - default: - return "", fmt.Errorf("unknown address type: %d", atyp[0]) - } -} - // WriteDatagram sends a single UDP datagram frame over a reliable stream. func WriteDatagram(w io.Writer, addr string, payload []byte) error { - addrBuf, err := encodeAddress(addr) + addrBuf, err := EncodeAddress(addr) if err != nil { return fmt.Errorf("encode address: %w", err) } @@ -157,7 +75,7 @@ func ReadDatagram(r io.Reader) (string, []byte, error) { return "", nil, err } - addr, err := decodeAddress(bytes.NewReader(addrBuf)) + addr, err := DecodeAddress(bytes.NewReader(addrBuf)) if err != nil { return "", nil, fmt.Errorf("decode address: %w", err) }