diff --git a/go.mod b/go.mod index d28ac4e8..66c1556f 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,6 @@ require ( github.com/gofrs/uuid/v5 v5.3.2 github.com/insomniacslk/dhcp v0.0.0-20250109001534-8abf58130905 github.com/klauspost/compress v1.17.9 // lastest version compatible with golang1.20 - github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40 github.com/mdlayher/netlink v1.7.2 github.com/metacubex/amneziawg-go v0.0.0-20240922133038-fdf3a4d5a4ab github.com/metacubex/bart v0.20.5 diff --git a/go.sum b/go.sum index 374463a9..3c2d3e3f 100644 --- a/go.sum +++ b/go.sum @@ -84,8 +84,6 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= -github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40 h1:EnfXoSqDfSNJv0VBNqY/88RNnhSGYkrHaO0mmFGbVsc= -github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40/go.mod h1:vy1vK6wD6j7xX6O6hXe621WabdtNkou2h7uRtTfRMyg= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= diff --git a/transport/hysteria/core/client.go b/transport/hysteria/core/client.go index 89f1eff6..13c8db75 100644 --- a/transport/hysteria/core/client.go +++ b/transport/hysteria/core/client.go @@ -1,7 +1,6 @@ package core import ( - "bytes" "context" "errors" "fmt" @@ -16,7 +15,6 @@ import ( "github.com/metacubex/mihomo/transport/hysteria/transport" "github.com/metacubex/mihomo/transport/hysteria/utils" - "github.com/lunixbochs/struc" "github.com/metacubex/quic-go" "github.com/metacubex/quic-go/congestion" "github.com/metacubex/randv2" @@ -104,31 +102,23 @@ func (c *Client) connectToServer(dialer utils.PacketDialer) error { } func (c *Client) handleControlStream(qs *quic.Conn, stream *quic.Stream) (bool, string, error) { - // Send protocol version - _, err := stream.Write([]byte{protocolVersion}) - if err != nil { - return false, "", err - } // Send client hello - err = struc.Pack(stream, &clientHello{ - Rate: transmissionRate{ - SendBPS: c.sendBPS, - RecvBPS: c.recvBPS, - }, - Auth: c.auth, + err := WriteClientHello(stream, ClientHello{ + SendBPS: c.sendBPS, + RecvBPS: c.recvBPS, + Auth: c.auth, }) if err != nil { return false, "", err } // Receive server hello - var sh serverHello - err = struc.Unpack(stream, &sh) + sh, err := ReadServerHello(stream) if err != nil { return false, "", err } // Set the congestion accordingly if sh.OK && c.congestionFactory != nil { - qs.SetCongestionControl(c.congestionFactory(sh.Rate.RecvBPS)) + qs.SetCongestionControl(c.congestionFactory(sh.RecvBPS)) } return sh.OK, sh.Message, nil } @@ -140,7 +130,7 @@ func (c *Client) handleMessage(qs *quic.Conn) { break } var udpMsg udpMessage - err = struc.Unpack(bytes.NewBuffer(msg), &udpMsg) + err = udpMsg.Unpack(msg) if err != nil { continue } @@ -200,7 +190,7 @@ func (c *Client) DialTCP(host string, port uint16, dialer utils.PacketDialer) (n return nil, err } // Send request - err = struc.Pack(stream, &clientRequest{ + err = WriteClientRequest(stream, ClientRequest{ UDP: false, Host: host, Port: port, @@ -213,8 +203,8 @@ func (c *Client) DialTCP(host string, port uint16, dialer utils.PacketDialer) (n // and defer the response handling to the first Read() call if !c.fastOpen { // Read response - var sr serverResponse - err = struc.Unpack(stream, &sr) + var sr *ServerResponse + sr, err = ReadServerResponse(stream) if err != nil { _ = stream.Close() return nil, err @@ -239,16 +229,16 @@ func (c *Client) DialUDP(dialer utils.PacketDialer) (UDPConn, error) { return nil, err } // Send request - err = struc.Pack(stream, &clientRequest{ - UDP: true, + err = WriteClientRequest(stream, ClientRequest{ + UDP: false, }) if err != nil { _ = stream.Close() return nil, err } // Read response - var sr serverResponse - err = struc.Unpack(stream, &sr) + var sr *ServerResponse + sr, err = ReadServerResponse(stream) if err != nil { _ = stream.Close() return nil, err @@ -306,8 +296,8 @@ type quicConn struct { func (w *quicConn) Read(b []byte) (n int, err error) { if !w.Established { - var sr serverResponse - err := struc.Unpack(w.Orig, &sr) + var sr *ServerResponse + sr, err = ReadServerResponse(w.Orig) if err != nil { _ = w.Close() return 0, err @@ -401,9 +391,7 @@ func (c *quicPktConn) WriteTo(p []byte, addr string) error { Data: p, } // try no frag first - var msgBuf bytes.Buffer - _ = struc.Pack(&msgBuf, &msg) - err = c.Session.SendDatagram(msgBuf.Bytes()) + err = c.Session.SendDatagram(msg.Pack()) if err != nil { var errSize *quic.DatagramTooLargeError if errors.As(err, &errSize) { @@ -411,9 +399,7 @@ func (c *quicPktConn) WriteTo(p []byte, addr string) error { msg.MsgID = uint16(randv2.IntN(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1 fragMsgs := fragUDPMessage(msg, int(errSize.MaxDatagramPayloadSize)) for _, fragMsg := range fragMsgs { - msgBuf.Reset() - _ = struc.Pack(&msgBuf, &fragMsg) - err = c.Session.SendDatagram(msgBuf.Bytes()) + err = c.Session.SendDatagram(fragMsg.Pack()) if err != nil { return err } diff --git a/transport/hysteria/core/frag.go b/transport/hysteria/core/frag.go index 7a387747..fdcb1fd4 100644 --- a/transport/hysteria/core/frag.go +++ b/transport/hysteria/core/frag.go @@ -18,7 +18,6 @@ func fragUDPMessage(m udpMessage, maxSize int) []udpMessage { frag := m frag.FragID = fragID frag.FragCount = fragCount - frag.DataLen = uint16(payloadSize) frag.Data = fullPayload[off : off+payloadSize] frags = append(frags, frag) off += payloadSize @@ -56,7 +55,6 @@ func (d *defragger) Feed(m udpMessage) *udpMessage { for _, frag := range d.frags { data = append(data, frag.Data...) } - m.DataLen = uint16(len(data)) m.Data = data m.FragID = 0 m.FragCount = 1 diff --git a/transport/hysteria/core/frag_test.go b/transport/hysteria/core/frag_test.go index f2f24625..f5c05cf3 100644 --- a/transport/hysteria/core/frag_test.go +++ b/transport/hysteria/core/frag_test.go @@ -20,13 +20,11 @@ func Test_fragUDPMessage(t *testing.T) { args{ udpMessage{ SessionID: 123, - HostLen: 4, Host: "test", Port: 123, MsgID: 123, FragID: 0, FragCount: 1, - DataLen: 5, Data: []byte("hello"), }, 100, @@ -34,13 +32,11 @@ func Test_fragUDPMessage(t *testing.T) { []udpMessage{ udpMessage{ SessionID: 123, - HostLen: 4, Host: "test", Port: 123, MsgID: 123, FragID: 0, FragCount: 1, - DataLen: 5, Data: []byte("hello"), }, }, @@ -50,13 +46,11 @@ func Test_fragUDPMessage(t *testing.T) { args{ udpMessage{ SessionID: 123, - HostLen: 4, Host: "test", Port: 123, MsgID: 123, FragID: 0, FragCount: 1, - DataLen: 5, Data: []byte("hello"), }, 22, @@ -64,24 +58,20 @@ func Test_fragUDPMessage(t *testing.T) { []udpMessage{ udpMessage{ SessionID: 123, - HostLen: 4, Host: "test", Port: 123, MsgID: 123, FragID: 0, FragCount: 2, - DataLen: 4, Data: []byte("hell"), }, udpMessage{ SessionID: 123, - HostLen: 4, Host: "test", Port: 123, MsgID: 123, FragID: 1, FragCount: 2, - DataLen: 1, Data: []byte("o"), }, }, @@ -91,13 +81,11 @@ func Test_fragUDPMessage(t *testing.T) { args{ udpMessage{ SessionID: 123, - HostLen: 4, Host: "test", Port: 123, MsgID: 123, FragID: 0, FragCount: 1, - DataLen: 20, Data: []byte("wow wow wow lol lmao"), }, 23, @@ -105,46 +93,38 @@ func Test_fragUDPMessage(t *testing.T) { []udpMessage{ udpMessage{ SessionID: 123, - HostLen: 4, Host: "test", Port: 123, MsgID: 123, FragID: 0, FragCount: 4, - DataLen: 5, Data: []byte("wow w"), }, udpMessage{ SessionID: 123, - HostLen: 4, Host: "test", Port: 123, MsgID: 123, FragID: 1, FragCount: 4, - DataLen: 5, Data: []byte("ow wo"), }, udpMessage{ SessionID: 123, - HostLen: 4, Host: "test", Port: 123, MsgID: 123, FragID: 2, FragCount: 4, - DataLen: 5, Data: []byte("w lol"), }, udpMessage{ SessionID: 123, - HostLen: 4, Host: "test", Port: 123, MsgID: 123, FragID: 3, FragCount: 4, - DataLen: 5, Data: []byte(" lmao"), }, }, @@ -174,25 +154,21 @@ func Test_defragger_Feed(t *testing.T) { args{ udpMessage{ SessionID: 123, - HostLen: 4, Host: "test", Port: 123, MsgID: 123, FragID: 0, FragCount: 1, - DataLen: 5, Data: []byte("hello"), }, }, &udpMessage{ SessionID: 123, - HostLen: 4, Host: "test", Port: 123, MsgID: 123, FragID: 0, FragCount: 1, - DataLen: 5, Data: []byte("hello"), }, }, @@ -201,13 +177,11 @@ func Test_defragger_Feed(t *testing.T) { args{ udpMessage{ SessionID: 123, - HostLen: 4, Host: "test", Port: 123, MsgID: 666, FragID: 0, FragCount: 3, - DataLen: 5, Data: []byte("hello"), }, }, @@ -218,13 +192,11 @@ func Test_defragger_Feed(t *testing.T) { args{ udpMessage{ SessionID: 123, - HostLen: 4, Host: "test", Port: 123, MsgID: 666, FragID: 1, FragCount: 3, - DataLen: 8, Data: []byte(" shitty "), }, }, @@ -235,25 +207,21 @@ func Test_defragger_Feed(t *testing.T) { args{ udpMessage{ SessionID: 123, - HostLen: 4, Host: "test", Port: 123, MsgID: 666, FragID: 2, FragCount: 3, - DataLen: 7, Data: []byte("world!!"), }, }, &udpMessage{ SessionID: 123, - HostLen: 4, Host: "test", Port: 123, MsgID: 666, FragID: 0, FragCount: 1, - DataLen: 20, Data: []byte("hello shitty world!!"), }, }, @@ -262,13 +230,11 @@ func Test_defragger_Feed(t *testing.T) { args{ udpMessage{ SessionID: 123, - HostLen: 4, Host: "test", Port: 123, MsgID: 777, FragID: 0, FragCount: 2, - DataLen: 5, Data: []byte("hello"), }, }, @@ -279,13 +245,11 @@ func Test_defragger_Feed(t *testing.T) { args{ udpMessage{ SessionID: 123, - HostLen: 4, Host: "test", Port: 123, MsgID: 778, FragID: 1, FragCount: 2, - DataLen: 5, Data: []byte(" moto"), }, }, @@ -296,13 +260,11 @@ func Test_defragger_Feed(t *testing.T) { args{ udpMessage{ SessionID: 123, - HostLen: 4, Host: "test", Port: 123, MsgID: 777, FragID: 1, FragCount: 2, - DataLen: 5, Data: []byte(" moto"), }, }, @@ -313,25 +275,21 @@ func Test_defragger_Feed(t *testing.T) { args{ udpMessage{ SessionID: 123, - HostLen: 4, Host: "test", Port: 123, MsgID: 777, FragID: 0, FragCount: 2, - DataLen: 5, Data: []byte("hello"), }, }, &udpMessage{ SessionID: 123, - HostLen: 4, Host: "test", Port: 123, MsgID: 777, FragID: 0, FragCount: 1, - DataLen: 10, Data: []byte("hello moto"), }, }, diff --git a/transport/hysteria/core/protocol.go b/transport/hysteria/core/protocol.go index 7fd64d7c..77dfc2e5 100644 --- a/transport/hysteria/core/protocol.go +++ b/transport/hysteria/core/protocol.go @@ -1,60 +1,241 @@ package core import ( + "bytes" + "encoding/binary" + "errors" + "io" "time" ) const ( - protocolVersion = uint8(3) - protocolVersionV2 = uint8(2) - protocolTimeout = 10 * time.Second + protocolVersion = uint8(3) + protocolTimeout = 10 * time.Second closeErrorCodeGeneric = 0 closeErrorCodeProtocol = 1 closeErrorCodeAuth = 2 ) -type transmissionRate struct { +type ClientHello struct { SendBPS uint64 RecvBPS uint64 -} - -type clientHello struct { - Rate transmissionRate - AuthLen uint16 `struc:"sizeof=Auth"` Auth []byte } -type serverHello struct { - OK bool - Rate transmissionRate - MessageLen uint16 `struc:"sizeof=Message"` - Message string +func WriteClientHello(stream io.Writer, hello ClientHello) error { + var requestLen int + requestLen += 1 // version + requestLen += 8 // sendBPS + requestLen += 8 // recvBPS + requestLen += 2 // auth len + requestLen += len(hello.Auth) + request := make([]byte, requestLen) + request[0] = protocolVersion + binary.BigEndian.PutUint64(request[1:9], hello.SendBPS) + binary.BigEndian.PutUint64(request[9:17], hello.RecvBPS) + binary.BigEndian.PutUint16(request[17:19], uint16(len(hello.Auth))) + copy(request[19:], hello.Auth) + _, err := stream.Write(request) + return err } -type clientRequest struct { - UDP bool - HostLen uint16 `struc:"sizeof=Host"` - Host string - Port uint16 +func ReadClientHello(stream io.Reader) (*ClientHello, error) { + var responseLen int + responseLen += 1 // ok + responseLen += 8 // sendBPS + responseLen += 8 // recvBPS + responseLen += 2 // auth len + response := make([]byte, responseLen) + _, err := io.ReadFull(stream, response) + if err != nil { + return nil, err + } + + if response[0] != protocolVersion { + return nil, errors.New("unsupported client version") + } + var clientHello ClientHello + clientHello.SendBPS = binary.BigEndian.Uint64(response[1:9]) + clientHello.RecvBPS = binary.BigEndian.Uint64(response[9:17]) + authLen := binary.BigEndian.Uint16(response[17:19]) + + if clientHello.SendBPS == 0 || clientHello.RecvBPS == 0 { + return nil, errors.New("invalid rate from client") + } + + authBytes := make([]byte, authLen) + _, err = io.ReadFull(stream, authBytes) + if err != nil { + return nil, err + } + clientHello.Auth = authBytes + return &clientHello, nil } -type serverResponse struct { +type ServerHello struct { + OK bool + SendBPS uint64 + RecvBPS uint64 + Message string +} + +func ReadServerHello(stream io.Reader) (*ServerHello, error) { + var responseLen int + responseLen += 1 // ok + responseLen += 8 // sendBPS + responseLen += 8 // recvBPS + responseLen += 2 // message len + response := make([]byte, responseLen) + _, err := io.ReadFull(stream, response) + if err != nil { + return nil, err + } + var serverHello ServerHello + serverHello.OK = response[0] == 1 + serverHello.SendBPS = binary.BigEndian.Uint64(response[1:9]) + serverHello.RecvBPS = binary.BigEndian.Uint64(response[9:17]) + messageLen := binary.BigEndian.Uint16(response[17:19]) + if messageLen == 0 { + return &serverHello, nil + } + message := make([]byte, messageLen) + _, err = io.ReadFull(stream, message) + if err != nil { + return nil, err + } + serverHello.Message = string(message) + return &serverHello, nil +} + +func WriteServerHello(stream io.Writer, hello ServerHello) error { + var responseLen int + responseLen += 1 // ok + responseLen += 8 // sendBPS + responseLen += 8 // recvBPS + responseLen += 2 // message len + responseLen += len(hello.Message) + response := make([]byte, responseLen) + if hello.OK { + response[0] = 1 + } else { + response[0] = 0 + } + binary.BigEndian.PutUint64(response[1:9], hello.SendBPS) + binary.BigEndian.PutUint64(response[9:17], hello.RecvBPS) + binary.BigEndian.PutUint16(response[17:19], uint16(len(hello.Message))) + copy(response[19:], hello.Message) + _, err := stream.Write(response) + return err +} + +type ClientRequest struct { + UDP bool + Host string + Port uint16 +} + +func ReadClientRequest(stream io.Reader) (*ClientRequest, error) { + var clientRequest ClientRequest + err := binary.Read(stream, binary.BigEndian, &clientRequest.UDP) + if err != nil { + return nil, err + } + var hostLen uint16 + err = binary.Read(stream, binary.BigEndian, &hostLen) + if err != nil { + return nil, err + } + host := make([]byte, hostLen) + _, err = io.ReadFull(stream, host) + if err != nil { + return nil, err + } + clientRequest.Host = string(host) + err = binary.Read(stream, binary.BigEndian, &clientRequest.Port) + if err != nil { + return nil, err + } + return &clientRequest, nil +} + +func WriteClientRequest(stream io.Writer, request ClientRequest) error { + var requestLen int + requestLen += 1 // udp + requestLen += 2 // host len + requestLen += len(request.Host) + requestLen += 2 // port + buffer := make([]byte, requestLen) + if request.UDP { + buffer[0] = 1 + } else { + buffer[0] = 0 + } + binary.BigEndian.PutUint16(buffer[1:3], uint16(len(request.Host))) + n := copy(buffer[3:], request.Host) + binary.BigEndian.PutUint16(buffer[3+n:3+n+2], request.Port) + _, err := stream.Write(buffer) + return err +} + +type ServerResponse struct { OK bool UDPSessionID uint32 - MessageLen uint16 `struc:"sizeof=Message"` Message string } +func ReadServerResponse(stream io.Reader) (*ServerResponse, error) { + var responseLen int + responseLen += 1 // ok + responseLen += 4 // udp session id + responseLen += 2 // message len + response := make([]byte, responseLen) + _, err := io.ReadFull(stream, response) + if err != nil { + return nil, err + } + var serverResponse ServerResponse + serverResponse.OK = response[0] == 1 + serverResponse.UDPSessionID = binary.BigEndian.Uint32(response[1:5]) + messageLen := binary.BigEndian.Uint16(response[5:7]) + if messageLen == 0 { + return &serverResponse, nil + } + message := make([]byte, messageLen) + _, err = io.ReadFull(stream, message) + if err != nil { + return nil, err + } + serverResponse.Message = string(message) + return &serverResponse, nil +} + +func WriteServerResponse(stream io.Writer, response ServerResponse) error { + var responseLen int + responseLen += 1 // ok + responseLen += 4 // udp session id + responseLen += 2 // message len + responseLen += len(response.Message) + buffer := make([]byte, responseLen) + if response.OK { + buffer[0] = 1 + } else { + buffer[0] = 0 + } + binary.BigEndian.PutUint32(buffer[1:5], response.UDPSessionID) + binary.BigEndian.PutUint16(buffer[5:7], uint16(len(response.Message))) + copy(buffer[7:], response.Message) + _, err := stream.Write(buffer) + return err +} + type udpMessage struct { SessionID uint32 - HostLen uint16 `struc:"sizeof=Host"` Host string Port uint16 MsgID uint16 // doesn't matter when not fragmented, but must not be 0 when fragmented FragID uint8 // doesn't matter when not fragmented, starts at 0 when fragmented FragCount uint8 // must be 1 when not fragmented - DataLen uint16 `struc:"sizeof=Data"` Data []byte } @@ -66,11 +247,62 @@ func (m udpMessage) Size() int { return m.HeaderSize() + len(m.Data) } -type udpMessageV2 struct { - SessionID uint32 - HostLen uint16 `struc:"sizeof=Host"` - Host string - Port uint16 - DataLen uint16 `struc:"sizeof=Data"` - Data []byte +func (m udpMessage) Pack() []byte { + data := make([]byte, m.Size()) + buffer := bytes.NewBuffer(data) + _ = binary.Write(buffer, binary.BigEndian, m.SessionID) + _ = binary.Write(buffer, binary.BigEndian, uint16(len(m.Host))) + buffer.WriteString(m.Host) + _ = binary.Write(buffer, binary.BigEndian, m.Port) + _ = binary.Write(buffer, binary.BigEndian, m.MsgID) + _ = binary.Write(buffer, binary.BigEndian, m.FragID) + _ = binary.Write(buffer, binary.BigEndian, m.FragCount) + _ = binary.Write(buffer, binary.BigEndian, uint16(len(m.Data))) + buffer.Write(m.Data) + return buffer.Bytes() +} + +func (m *udpMessage) Unpack(data []byte) error { + reader := bytes.NewReader(data) + err := binary.Read(reader, binary.BigEndian, &m.SessionID) + if err != nil { + return err + } + var hostLen uint16 + err = binary.Read(reader, binary.BigEndian, &hostLen) + if err != nil { + return err + } + hostBytes := make([]byte, hostLen) + _, err = io.ReadFull(reader, hostBytes) + if err != nil { + return err + } + m.Host = string(hostBytes) + err = binary.Read(reader, binary.BigEndian, &m.Port) + if err != nil { + return err + } + err = binary.Read(reader, binary.BigEndian, &m.MsgID) + if err != nil { + return err + } + err = binary.Read(reader, binary.BigEndian, &m.FragID) + if err != nil { + return err + } + err = binary.Read(reader, binary.BigEndian, &m.FragCount) + if err != nil { + return err + } + var dataLen uint16 + err = binary.Read(reader, binary.BigEndian, &dataLen) + if err != nil { + return err + } + if reader.Len() != int(dataLen) { + return errors.New("invalid data length") + } + m.Data = data[len(data)-reader.Len():] + return nil }