diff --git a/adapter/outbound/anytls.go b/adapter/outbound/anytls.go index 1dea5579..02541f23 100644 --- a/adapter/outbound/anytls.go +++ b/adapter/outbound/anytls.go @@ -2,7 +2,6 @@ package outbound import ( "context" - "errors" "net" "strconv" "time" @@ -10,7 +9,6 @@ import ( CN "github.com/metacubex/mihomo/common/net" "github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/proxydialer" - "github.com/metacubex/mihomo/component/resolver" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/transport/anytls" "github.com/metacubex/mihomo/transport/vmess" @@ -53,6 +51,10 @@ func (t *AnyTLS) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Con } func (t *AnyTLS) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) { + if err = t.ResolveUDP(ctx, metadata); err != nil { + return nil, err + } + // create tcp c, err := t.client.CreateProxy(ctx, uot.RequestDestination(2)) if err != nil { @@ -60,13 +62,6 @@ func (t *AnyTLS) ListenPacketContext(ctx context.Context, metadata *C.Metadata) } // create uot on tcp - if !metadata.Resolved() { - ip, err := resolver.ResolveIP(ctx, metadata.Host) - if err != nil { - return nil, errors.New("can't resolve ip") - } - metadata.DstIP = ip - } destination := M.SocksaddrFromNet(metadata.UDPAddr()) return newPacketConn(CN.NewThreadSafePacketConn(uot.NewLazyConn(c, uot.Request{Destination: destination})), t), nil } diff --git a/adapter/outbound/base.go b/adapter/outbound/base.go index a526ae12..2fdf9b49 100644 --- a/adapter/outbound/base.go +++ b/adapter/outbound/base.go @@ -3,6 +3,7 @@ package outbound import ( "context" "encoding/json" + "fmt" "net" "runtime" "sync" @@ -11,6 +12,7 @@ import ( N "github.com/metacubex/mihomo/common/net" "github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/component/dialer" + "github.com/metacubex/mihomo/component/resolver" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/log" ) @@ -18,6 +20,7 @@ import ( type ProxyAdapter interface { C.ProxyAdapter DialOptions() []dialer.Option + ResolveUDP(ctx context.Context, metadata *C.Metadata) error } type Base struct { @@ -159,6 +162,17 @@ func (b *Base) DialOptions() (opts []dialer.Option) { return opts } +func (b *Base) ResolveUDP(ctx context.Context, metadata *C.Metadata) error { + if !metadata.Resolved() { + ip, err := resolver.ResolveIP(ctx, metadata.Host) + if err != nil { + return fmt.Errorf("can't resolve ip: %w", err) + } + metadata.DstIP = ip + } + return nil +} + func (b *Base) Close() error { return nil } @@ -258,6 +272,11 @@ type packetConn struct { adapterName string connID string adapterAddr string + resolveUDP func(ctx context.Context, metadata *C.Metadata) error +} + +func (c *packetConn) ResolveUDP(ctx context.Context, metadata *C.Metadata) error { + return c.resolveUDP(ctx, metadata) } func (c *packetConn) RemoteDestination() string { @@ -296,12 +315,12 @@ func (c *packetConn) AddRef(ref any) { c.EnhancePacketConn = N.NewRefPacketConn(c.EnhancePacketConn, ref) // add ref for autoCloseProxyAdapter } -func newPacketConn(pc net.PacketConn, a C.ProxyAdapter) C.PacketConn { +func newPacketConn(pc net.PacketConn, a ProxyAdapter) C.PacketConn { epc := N.NewEnhancePacketConn(pc) if _, ok := pc.(syscall.Conn); !ok { // exclusion system conn like *net.UDPConn epc = N.NewDeadlineEnhancePacketConn(epc) // most conn from outbound can't handle readDeadline correctly } - return &packetConn{epc, []string{a.Name()}, a.Name(), utils.NewUUIDV4().String(), a.Addr()} + return &packetConn{epc, []string{a.Name()}, a.Name(), utils.NewUUIDV4().String(), a.Addr(), a.ResolveUDP} } type AddRef interface { diff --git a/adapter/outbound/direct.go b/adapter/outbound/direct.go index 788a4b93..be8367ba 100644 --- a/adapter/outbound/direct.go +++ b/adapter/outbound/direct.go @@ -2,7 +2,8 @@ package outbound import ( "context" - "errors" + "fmt" + "github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/loopback" "github.com/metacubex/mihomo/component/resolver" @@ -38,13 +39,8 @@ func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata) if err := d.loopBack.CheckPacketConn(metadata); err != nil { return nil, err } - // net.UDPConn.WriteTo only working with *net.UDPAddr, so we need a net.UDPAddr - if !metadata.Resolved() { - ip, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, resolver.DirectHostResolver) - if err != nil { - return nil, errors.New("can't resolve ip") - } - metadata.DstIP = ip + if err := d.ResolveUDP(ctx, metadata); err != nil { + return nil, err } pc, err := dialer.NewDialer(d.DialOptions()...).ListenPacket(ctx, "udp", "", metadata.AddrPort()) if err != nil { @@ -53,6 +49,17 @@ func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata) return d.loopBack.NewPacketConn(newPacketConn(pc, d)), nil } +func (d *Direct) ResolveUDP(ctx context.Context, metadata *C.Metadata) error { + if (!metadata.Resolved() || resolver.DirectHostResolver != resolver.DefaultResolver) && metadata.Host != "" { + ip, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, resolver.DirectHostResolver) + if err != nil { + return fmt.Errorf("can't resolve ip: %w", err) + } + metadata.DstIP = ip + } + return nil +} + func (d *Direct) IsL3Protocol(metadata *C.Metadata) bool { return true // tell DNSDialer don't send domain to DialContext, avoid lookback to DefaultResolver } diff --git a/adapter/outbound/dns.go b/adapter/outbound/dns.go index 40e70f25..25228502 100644 --- a/adapter/outbound/dns.go +++ b/adapter/outbound/dns.go @@ -3,6 +3,7 @@ package outbound import ( "context" "net" + "net/netip" "time" N "github.com/metacubex/mihomo/common/net" @@ -31,6 +32,9 @@ func (d *Dns) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, er // ListenPacketContext implements C.ProxyAdapter func (d *Dns) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { log.Debugln("[DNS] hijack udp:%s from %s", metadata.RemoteAddress(), metadata.SourceAddrPort()) + if err := d.ResolveUDP(ctx, metadata); err != nil { + return nil, err + } ctx, cancel := context.WithCancel(context.Background()) @@ -41,6 +45,13 @@ func (d *Dns) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C. }, d), nil } +func (d *Dns) ResolveUDP(ctx context.Context, metadata *C.Metadata) error { + if !metadata.Resolved() { + metadata.DstIP = netip.AddrFrom4([4]byte{127, 0, 0, 2}) + } + return nil +} + type dnsPacket struct { data []byte put func() diff --git a/adapter/outbound/hysteria.go b/adapter/outbound/hysteria.go index bd92260a..966fa745 100644 --- a/adapter/outbound/hysteria.go +++ b/adapter/outbound/hysteria.go @@ -60,6 +60,9 @@ func (h *Hysteria) DialContext(ctx context.Context, metadata *C.Metadata) (C.Con } func (h *Hysteria) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { + if err := h.ResolveUDP(ctx, metadata); err != nil { + return nil, err + } udpConn, err := h.client.DialUDP(h.genHdc(ctx)) if err != nil { return nil, err diff --git a/adapter/outbound/hysteria2.go b/adapter/outbound/hysteria2.go index b9e41c46..cb844e1f 100644 --- a/adapter/outbound/hysteria2.go +++ b/adapter/outbound/hysteria2.go @@ -77,6 +77,9 @@ func (h *Hysteria2) DialContext(ctx context.Context, metadata *C.Metadata) (_ C. } func (h *Hysteria2) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) { + if err = h.ResolveUDP(ctx, metadata); err != nil { + return nil, err + } pc, err := h.client.ListenPacket(ctx) if err != nil { return nil, err diff --git a/adapter/outbound/mieru.go b/adapter/outbound/mieru.go index 1d8c78f7..fca7ede4 100644 --- a/adapter/outbound/mieru.go +++ b/adapter/outbound/mieru.go @@ -54,6 +54,9 @@ func (m *Mieru) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, // ListenPacketContext implements C.ProxyAdapter func (m *Mieru) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) { + if err = m.ResolveUDP(ctx, metadata); err != nil { + return nil, err + } if err := m.ensureClientIsRunning(); err != nil { return nil, err } diff --git a/adapter/outbound/reject.go b/adapter/outbound/reject.go index da3a4e3c..9e0bb022 100644 --- a/adapter/outbound/reject.go +++ b/adapter/outbound/reject.go @@ -4,6 +4,7 @@ import ( "context" "io" "net" + "net/netip" "time" "github.com/metacubex/mihomo/common/buf" @@ -29,9 +30,19 @@ func (r *Reject) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, // ListenPacketContext implements C.ProxyAdapter func (r *Reject) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { + if err := r.ResolveUDP(ctx, metadata); err != nil { + return nil, err + } return newPacketConn(&nopPacketConn{}, r), nil } +func (r *Reject) ResolveUDP(ctx context.Context, metadata *C.Metadata) error { + if !metadata.Resolved() { + metadata.DstIP = netip.IPv4Unspecified() + } + return nil +} + func NewRejectWithOption(option RejectOption) *Reject { return &Reject{ Base: &Base{ diff --git a/adapter/outbound/shadowsocks.go b/adapter/outbound/shadowsocks.go index fc2d5c4d..595ef0e5 100644 --- a/adapter/outbound/shadowsocks.go +++ b/adapter/outbound/shadowsocks.go @@ -2,7 +2,6 @@ package outbound import ( "context" - "errors" "fmt" "net" "strconv" @@ -11,7 +10,6 @@ import ( "github.com/metacubex/mihomo/common/structure" "github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/proxydialer" - "github.com/metacubex/mihomo/component/resolver" C "github.com/metacubex/mihomo/constant" gost "github.com/metacubex/mihomo/transport/gost-plugin" "github.com/metacubex/mihomo/transport/restls" @@ -202,6 +200,9 @@ func (ss *ShadowSocks) ListenPacketWithDialer(ctx context.Context, dialer C.Dial return nil, err } } + if err = ss.ResolveUDP(ctx, metadata); err != nil { + return nil, err + } addr, err := resolveUDPAddr(ctx, "udp", ss.addr, ss.prefer) if err != nil { return nil, err @@ -230,15 +231,9 @@ func (ss *ShadowSocks) ProxyInfo() C.ProxyInfo { // ListenPacketOnStreamConn implements C.ProxyAdapter func (ss *ShadowSocks) ListenPacketOnStreamConn(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ C.PacketConn, err error) { if ss.option.UDPOverTCP { - // ss uot use stream-oriented udp with a special address, so we need a net.UDPAddr - if !metadata.Resolved() { - ip, err := resolver.ResolveIP(ctx, metadata.Host) - if err != nil { - return nil, errors.New("can't resolve ip") - } - metadata.DstIP = ip + if err = ss.ResolveUDP(ctx, metadata); err != nil { + return nil, err } - destination := M.SocksaddrFromNet(metadata.UDPAddr()) if ss.option.UDPOverTCPVersion == uot.LegacyVersion { return newPacketConn(N.NewThreadSafePacketConn(uot.NewConn(c, uot.Request{Destination: destination})), ss), nil diff --git a/adapter/outbound/shadowsocksr.go b/adapter/outbound/shadowsocksr.go index d7b932e6..a9975107 100644 --- a/adapter/outbound/shadowsocksr.go +++ b/adapter/outbound/shadowsocksr.go @@ -105,6 +105,9 @@ func (ssr *ShadowSocksR) ListenPacketWithDialer(ctx context.Context, dialer C.Di return nil, err } } + if err = ssr.ResolveUDP(ctx, metadata); err != nil { + return nil, err + } addr, err := resolveUDPAddr(ctx, "udp", ssr.addr, ssr.prefer) if err != nil { return nil, err diff --git a/adapter/outbound/singmux.go b/adapter/outbound/singmux.go index 819f2387..cd9ec904 100644 --- a/adapter/outbound/singmux.go +++ b/adapter/outbound/singmux.go @@ -2,12 +2,10 @@ package outbound import ( "context" - "errors" CN "github.com/metacubex/mihomo/common/net" "github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/proxydialer" - "github.com/metacubex/mihomo/component/resolver" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/log" @@ -53,16 +51,9 @@ func (s *SingMux) ListenPacketContext(ctx context.Context, metadata *C.Metadata) if s.onlyTcp { return s.ProxyAdapter.ListenPacketContext(ctx, metadata) } - - // sing-mux use stream-oriented udp with a special address, so we need a net.UDPAddr - if !metadata.Resolved() { - ip, err := resolver.ResolveIP(ctx, metadata.Host) - if err != nil { - return nil, errors.New("can't resolve ip") - } - metadata.DstIP = ip + if err = s.ProxyAdapter.ResolveUDP(ctx, metadata); err != nil { + return nil, err } - pc, err := s.client.ListenPacket(ctx, M.SocksaddrFromNet(metadata.UDPAddr())) if err != nil { return nil, err diff --git a/adapter/outbound/snell.go b/adapter/outbound/snell.go index 275f7263..ef3603e6 100644 --- a/adapter/outbound/snell.go +++ b/adapter/outbound/snell.go @@ -127,6 +127,9 @@ func (s *Snell) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, met return nil, err } } + if err = s.ResolveUDP(ctx, metadata); err != nil { + return nil, err + } c, err := dialer.DialContext(ctx, "tcp", s.addr) if err != nil { return nil, err diff --git a/adapter/outbound/socks5.go b/adapter/outbound/socks5.go index 4bb774b9..26c64dce 100644 --- a/adapter/outbound/socks5.go +++ b/adapter/outbound/socks5.go @@ -109,6 +109,9 @@ func (ss *Socks5) ListenPacketContext(ctx context.Context, metadata *C.Metadata) return nil, err } } + if err = ss.ResolveUDP(ctx, metadata); err != nil { + return nil, err + } c, err := cDialer.DialContext(ctx, "tcp", ss.addr) if err != nil { err = fmt.Errorf("%s connect error: %w", ss.addr, err) diff --git a/adapter/outbound/trojan.go b/adapter/outbound/trojan.go index 0397126a..9795ed78 100644 --- a/adapter/outbound/trojan.go +++ b/adapter/outbound/trojan.go @@ -219,6 +219,10 @@ func (t *Trojan) DialContextWithDialer(ctx context.Context, dialer C.Dialer, met // ListenPacketContext implements C.ProxyAdapter func (t *Trojan) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) { + if err = t.ResolveUDP(ctx, metadata); err != nil { + return nil, err + } + var c net.Conn // grpc transport @@ -250,6 +254,9 @@ func (t *Trojan) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, me return nil, err } } + if err = t.ResolveUDP(ctx, metadata); err != nil { + return nil, err + } c, err := dialer.DialContext(ctx, "tcp", t.addr) if err != nil { return nil, fmt.Errorf("%s connect error: %w", t.addr, err) @@ -271,12 +278,6 @@ func (t *Trojan) SupportWithDialer() C.NetWork { return C.ALLNet } -// ListenPacketOnStreamConn implements C.ProxyAdapter -func (t *Trojan) ListenPacketOnStreamConn(c net.Conn, metadata *C.Metadata) (_ C.PacketConn, err error) { - pc := trojan.NewPacketConn(c) - return newPacketConn(pc, t), err -} - // SupportUOT implements C.ProxyAdapter func (t *Trojan) SupportUOT() bool { return true diff --git a/adapter/outbound/tuic.go b/adapter/outbound/tuic.go index 525f9ec6..7a913fbf 100644 --- a/adapter/outbound/tuic.go +++ b/adapter/outbound/tuic.go @@ -3,7 +3,6 @@ package outbound import ( "context" "crypto/tls" - "errors" "fmt" "math" "net" @@ -14,7 +13,6 @@ import ( "github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/ech" "github.com/metacubex/mihomo/component/proxydialer" - "github.com/metacubex/mihomo/component/resolver" tlsC "github.com/metacubex/mihomo/component/tls" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/transport/tuic" @@ -91,6 +89,10 @@ func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ // ListenPacketWithDialer implements C.ProxyAdapter func (t *Tuic) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) { + if err = t.ResolveUDP(ctx, metadata); err != nil { + return nil, err + } + if t.option.UDPOverStream { uotDestination := uot.RequestDestination(uint8(t.option.UDPOverStreamVersion)) uotMetadata := *metadata @@ -102,13 +104,6 @@ func (t *Tuic) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, meta } // tuic uos use stream-oriented udp with a special address, so we need a net.UDPAddr - if !metadata.Resolved() { - ip, err := resolver.ResolveIP(ctx, metadata.Host) - if err != nil { - return nil, errors.New("can't resolve ip") - } - metadata.DstIP = ip - } destination := M.SocksaddrFromNet(metadata.UDPAddr()) if t.option.UDPOverStreamVersion == uot.LegacyVersion { diff --git a/adapter/outbound/vless.go b/adapter/outbound/vless.go index 6412d4da..9b5c65f9 100644 --- a/adapter/outbound/vless.go +++ b/adapter/outbound/vless.go @@ -19,7 +19,6 @@ import ( "github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/ech" "github.com/metacubex/mihomo/component/proxydialer" - "github.com/metacubex/mihomo/component/resolver" tlsC "github.com/metacubex/mihomo/component/tls" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/transport/gun" @@ -277,13 +276,8 @@ func (v *Vless) DialContextWithDialer(ctx context.Context, dialer C.Dialer, meta // ListenPacketContext implements C.ProxyAdapter func (v *Vless) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) { - // vless use stream-oriented udp with a special address, so we need a net.UDPAddr - if !metadata.Resolved() { - ip, err := resolver.ResolveIP(ctx, metadata.Host) - if err != nil { - return nil, errors.New("can't resolve ip") - } - metadata.DstIP = ip + if err = v.ResolveUDP(ctx, metadata); err != nil { + return nil, err } var c net.Conn // gun transport @@ -315,13 +309,8 @@ func (v *Vless) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, met } } - // vless use stream-oriented udp with a special address, so we need a net.UDPAddr - if !metadata.Resolved() { - ip, err := resolver.ResolveIP(ctx, metadata.Host) - if err != nil { - return nil, errors.New("can't resolve ip") - } - metadata.DstIP = ip + if err = v.ResolveUDP(ctx, metadata); err != nil { + return nil, err } c, err := dialer.DialContext(ctx, "tcp", v.addr) @@ -347,13 +336,8 @@ func (v *Vless) SupportWithDialer() C.NetWork { // ListenPacketOnStreamConn implements C.ProxyAdapter func (v *Vless) ListenPacketOnStreamConn(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ C.PacketConn, err error) { - // vless use stream-oriented udp with a special address, so we need a net.UDPAddr - if !metadata.Resolved() { - ip, err := resolver.ResolveIP(ctx, metadata.Host) - if err != nil { - return nil, errors.New("can't resolve ip") - } - metadata.DstIP = ip + if err = v.ResolveUDP(ctx, metadata); err != nil { + return nil, err } if v.option.XUDP { diff --git a/adapter/outbound/vmess.go b/adapter/outbound/vmess.go index 8496efc1..fbf8c266 100644 --- a/adapter/outbound/vmess.go +++ b/adapter/outbound/vmess.go @@ -17,7 +17,6 @@ import ( "github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/ech" "github.com/metacubex/mihomo/component/proxydialer" - "github.com/metacubex/mihomo/component/resolver" tlsC "github.com/metacubex/mihomo/component/tls" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/ntp" @@ -330,13 +329,8 @@ func (v *Vmess) DialContextWithDialer(ctx context.Context, dialer C.Dialer, meta // ListenPacketContext implements C.ProxyAdapter func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) { - // vmess use stream-oriented udp with a special address, so we need a net.UDPAddr - if !metadata.Resolved() { - ip, err := resolver.ResolveIP(ctx, metadata.Host) - if err != nil { - return nil, errors.New("can't resolve ip") - } - metadata.DstIP = ip + if err = v.ResolveUDP(ctx, metadata); err != nil { + return nil, err } var c net.Conn // gun transport @@ -367,13 +361,8 @@ func (v *Vmess) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, met } } - // vmess use stream-oriented udp with a special address, so we need a net.UDPAddr - if !metadata.Resolved() { - ip, err := resolver.ResolveIP(ctx, metadata.Host) - if err != nil { - return nil, errors.New("can't resolve ip") - } - metadata.DstIP = ip + if err = v.ResolveUDP(ctx, metadata); err != nil { + return nil, err } c, err := dialer.DialContext(ctx, "tcp", v.addr) @@ -413,13 +402,8 @@ func (v *Vmess) Close() error { // ListenPacketOnStreamConn implements C.ProxyAdapter func (v *Vmess) ListenPacketOnStreamConn(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ C.PacketConn, err error) { - // vmess use stream-oriented udp with a special address, so we need a net.UDPAddr - if !metadata.Resolved() { - ip, err := resolver.ResolveIP(ctx, metadata.Host) - if err != nil { - return nil, errors.New("can't resolve ip") - } - metadata.DstIP = ip + if err = v.ResolveUDP(ctx, metadata); err != nil { + return nil, err } if pc, ok := c.(net.PacketConn); ok { diff --git a/adapter/outbound/wireguard.go b/adapter/outbound/wireguard.go index db0ef95a..84ba1dbb 100644 --- a/adapter/outbound/wireguard.go +++ b/adapter/outbound/wireguard.go @@ -4,7 +4,6 @@ import ( "context" "encoding/base64" "encoding/hex" - "errors" "fmt" "net" "net/netip" @@ -520,16 +519,8 @@ func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadat if err = w.init(ctx); err != nil { return nil, err } - if (!metadata.Resolved() || w.resolver != nil) && metadata.Host != "" { - r := resolver.DefaultResolver - if w.resolver != nil { - r = w.resolver - } - ip, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, r) - if err != nil { - return nil, errors.New("can't resolve ip") - } - metadata.DstIP = ip + if err = w.ResolveUDP(ctx, metadata); err != nil { + return nil, err } pc, err = w.tunDevice.ListenPacket(ctx, M.SocksaddrFrom(metadata.DstIP, metadata.DstPort).Unwrap()) if err != nil { @@ -541,6 +532,21 @@ func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadat return newPacketConn(pc, w), nil } +func (w *WireGuard) ResolveUDP(ctx context.Context, metadata *C.Metadata) error { + if (!metadata.Resolved() || w.resolver != nil) && metadata.Host != "" { + r := resolver.DefaultResolver + if w.resolver != nil { + r = w.resolver + } + ip, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, r) + if err != nil { + return fmt.Errorf("can't resolve ip: %w", err) + } + metadata.DstIP = ip + } + return nil +} + // IsL3Protocol implements C.ProxyAdapter func (w *WireGuard) IsL3Protocol(metadata *C.Metadata) bool { return true diff --git a/component/sniffer/quic_sniffer.go b/component/sniffer/quic_sniffer.go index a1b39f92..0eac36fa 100644 --- a/component/sniffer/quic_sniffer.go +++ b/component/sniffer/quic_sniffer.go @@ -76,9 +76,9 @@ func (sniffer *QuicSniffer) SniffData(b []byte) (string, error) { func (sniffer *QuicSniffer) WrapperSender(packetSender constant.PacketSender, override bool) constant.PacketSender { return &quicPacketSender{ - sender: packetSender, - chClose: make(chan struct{}), - override: override, + PacketSender: packetSender, + chClose: make(chan struct{}), + override: override, } } @@ -91,7 +91,7 @@ type quicPacketSender struct { result string override bool - sender constant.PacketSender + constant.PacketSender chClose chan struct{} closed bool @@ -100,7 +100,7 @@ type quicPacketSender struct { // Send will send PacketAdapter nonblocking // the implement must call UDPPacket.Drop() inside Send func (q *quicPacketSender) Send(current constant.PacketAdapter) { - defer q.sender.Send(current) + defer q.PacketSender.Send(current) q.lock.RLock() if q.closed { @@ -116,29 +116,24 @@ func (q *quicPacketSender) Send(current constant.PacketAdapter) { } } -// Process is a blocking loop to send PacketAdapter to PacketConn and update the WriteBackProxy -func (q *quicPacketSender) Process(conn constant.PacketConn, proxy constant.WriteBackProxy) { - q.sender.Process(conn, proxy) -} - -// ResolveUDP wait sniffer recv all fragments and update the domain -func (q *quicPacketSender) ResolveUDP(data *constant.Metadata) error { +// DoSniff wait sniffer recv all fragments and update the domain +func (q *quicPacketSender) DoSniff(metadata *constant.Metadata) error { select { case <-q.chClose: q.lock.RLock() - replaceDomain(data, q.result, q.override) + replaceDomain(metadata, q.result, q.override) q.lock.RUnlock() break case <-time.After(quicWaitConn): q.close() } - return q.sender.ResolveUDP(data) + return q.PacketSender.DoSniff(metadata) } // Close stop the Process loop func (q *quicPacketSender) Close() { - q.sender.Close() + q.PacketSender.Close() q.close() } diff --git a/component/sniffer/sniff_test.go b/component/sniffer/sniff_test.go index bc250d59..f8a5583a 100644 --- a/component/sniffer/sniff_test.go +++ b/component/sniffer/sniff_test.go @@ -12,7 +12,7 @@ import ( ) type fakeSender struct { - resultCh chan *constant.Metadata + constant.PacketSender } var _ constant.PacketSender = (*fakeSender)(nil) @@ -22,18 +22,7 @@ func (e *fakeSender) Send(packet constant.PacketAdapter) { packet.Drop() } -func (e *fakeSender) Process(constant.PacketConn, constant.WriteBackProxy) { - panic("not implemented") -} - -func (e *fakeSender) ResolveUDP(metadata *constant.Metadata) error { - e.resultCh <- metadata - return nil -} - -func (e *fakeSender) Close() { - panic("not implemented") -} +func (e *fakeSender) DoSniff(metadata *constant.Metadata) error { return nil } type fakeUDPPacket struct { data []byte @@ -85,16 +74,17 @@ func testQuicSniffer(data []string, async bool) (string, error) { } resultCh := make(chan *constant.Metadata, 1) - emptySender := &fakeSender{resultCh: resultCh} + emptySender := &fakeSender{} sender := q.WrapperSender(emptySender, true) go func() { meta := constant.Metadata{} - err = sender.ResolveUDP(&meta) + err := sender.DoSniff(&meta) if err != nil { panic(err) } + resultCh <- &meta }() for _, d := range data { diff --git a/constant/adapters.go b/constant/adapters.go index 280552ec..ab48c47d 100644 --- a/constant/adapters.go +++ b/constant/adapters.go @@ -92,8 +92,7 @@ type Conn interface { type PacketConn interface { N.EnhancePacketConn Connection - // Deprecate WriteWithMetadata because of remote resolve DNS cause TURN failed - // WriteWithMetadata(p []byte, metadata *Metadata) (n int, err error) + ResolveUDP(ctx context.Context, metadata *Metadata) error } type Dialer interface { @@ -319,10 +318,15 @@ type PacketSender interface { Send(PacketAdapter) // Process is a blocking loop to send PacketAdapter to PacketConn and update the WriteBackProxy Process(PacketConn, WriteBackProxy) - // ResolveUDP do a local resolve UDP dns blocking if metadata is not resolved - ResolveUDP(*Metadata) error // Close stop the Process loop Close() + // DoSniff will blocking after sniffer work done + DoSniff(*Metadata) error + // AddMapping add a destination NAT record + AddMapping(originMetadata *Metadata, metadata *Metadata) + // RestoreReadFrom restore destination NAT for ReadFrom + // the implement must ensure returned netip.Add is valid (or just return input addr) + RestoreReadFrom(addr netip.Addr) netip.Addr } type NatTable interface { diff --git a/constant/metadata.go b/constant/metadata.go index 3a8f4c79..72b9995b 100644 --- a/constant/metadata.go +++ b/constant/metadata.go @@ -261,6 +261,11 @@ func (m *Metadata) Pure() *Metadata { return m } +func (m *Metadata) Clone() *Metadata { + copyM := *m + return ©M +} + func (m *Metadata) AddrPort() netip.AddrPort { return netip.AddrPortFrom(m.DstIP.Unmap(), m.DstPort) } diff --git a/tunnel/connection.go b/tunnel/connection.go index 4884297e..3a66a76e 100644 --- a/tunnel/connection.go +++ b/tunnel/connection.go @@ -5,11 +5,10 @@ import ( "errors" "net" "net/netip" + "sync" "time" - "github.com/metacubex/mihomo/common/lru" N "github.com/metacubex/mihomo/common/net" - "github.com/metacubex/mihomo/component/resolver" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/log" ) @@ -18,7 +17,11 @@ type packetSender struct { ctx context.Context cancel context.CancelFunc ch chan C.PacketAdapter - cache *lru.LruCache[string, netip.Addr] + + // destination NAT mapping + originToTarget map[string]netip.Addr + targetToOrigin map[netip.Addr]netip.Addr + mappingMutex sync.RWMutex } // newPacketSender return a chan based C.PacketSender @@ -30,10 +33,74 @@ func newPacketSender() C.PacketSender { ctx: ctx, cancel: cancel, ch: ch, - cache: lru.New[string, netip.Addr](lru.WithSize[string, netip.Addr](senderCapacity)), + + originToTarget: make(map[string]netip.Addr), + targetToOrigin: make(map[netip.Addr]netip.Addr), } } +func (s *packetSender) AddMapping(originMetadata *C.Metadata, metadata *C.Metadata) { + s.mappingMutex.Lock() + defer s.mappingMutex.Unlock() + originKey := originMetadata.String() + originAddr := originMetadata.DstIP + targetAddr := metadata.DstIP + if addr := s.originToTarget[originKey]; !addr.IsValid() { // overwrite only if the record is illegal + s.originToTarget[originKey] = targetAddr + } + if addr := s.targetToOrigin[targetAddr]; !addr.IsValid() { // overwrite only if the record is illegal + s.targetToOrigin[targetAddr] = originAddr + } +} + +func (s *packetSender) RestoreReadFrom(addr netip.Addr) netip.Addr { + s.mappingMutex.RLock() + defer s.mappingMutex.RUnlock() + if originAddr := s.targetToOrigin[addr]; originAddr.IsValid() { + return originAddr + } + return addr +} + +func (s *packetSender) processPacket(pc C.PacketConn, packet C.PacketAdapter) { + defer packet.Drop() + metadata := packet.Metadata() + + var addr *net.UDPAddr + + s.mappingMutex.RLock() + targetAddr := s.originToTarget[metadata.String()] + s.mappingMutex.RUnlock() + + if targetAddr.IsValid() { + addr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(targetAddr, metadata.DstPort)) + } + + if addr == nil { + originMetadata := metadata // save origin metadata + metadata = metadata.Clone() // don't modify PacketAdapter's metadata + + _ = preHandleMetadata(metadata) // error was pre-checked + metadata = metadata.Pure() + if metadata.Host != "" { + // TODO: ResolveUDP may take a long time to block the Process loop + // but we want keep sequence sending so can't open a new goroutine + if err := pc.ResolveUDP(s.ctx, metadata); err != nil { + log.Warnln("[UDP] Resolve Ip error: %s", err) + return + } + } + + if !metadata.DstIP.IsValid() { + log.Warnln("[UDP] Destination ip not valid: %#v", metadata) + return + } + s.AddMapping(originMetadata, metadata) + addr = metadata.UDPAddr() + } + _ = handleUDPToRemote(packet, pc, addr) +} + func (s *packetSender) Process(pc C.PacketConn, proxy C.WriteBackProxy) { for { select { @@ -43,12 +110,7 @@ func (s *packetSender) Process(pc C.PacketConn, proxy C.WriteBackProxy) { if proxy != nil { proxy.UpdateWriteBack(packet) } - if err := s.ResolveUDP(packet.Metadata()); err != nil { - log.Warnln("[UDP] Resolve Ip error: %s", err) - } else { - _ = handleUDPToRemote(packet, pc, packet.Metadata()) - } - packet.Drop() + s.processPacket(pc, packet) } } } @@ -87,25 +149,9 @@ func (s *packetSender) Close() { s.dropAll() } -func (s *packetSender) ResolveUDP(metadata *C.Metadata) (err error) { - // local resolve UDP dns - if !metadata.Resolved() { - ip, ok := s.cache.Get(metadata.Host) - if !ok { - ip, err = resolver.ResolveIP(s.ctx, metadata.Host) - if err != nil { - return err - } - s.cache.Set(metadata.Host, ip) - } +func (s *packetSender) DoSniff(metadata *C.Metadata) error { return nil } - metadata.DstIP = ip - } - return nil -} - -func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata) error { - addr := metadata.UDPAddr() +func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, addr *net.UDPAddr) error { if addr == nil { return errors.New("udp addr invalid") } @@ -119,7 +165,7 @@ func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata return nil } -func handleUDPToLocal(writeBack C.WriteBack, pc N.EnhancePacketConn, sender C.PacketSender, key string, oAddrPort netip.AddrPort, fAddr netip.Addr) { +func handleUDPToLocal(writeBack C.WriteBack, pc C.PacketConn, sender C.PacketSender, key string, oAddrPort netip.AddrPort) { defer func() { sender.Close() _ = pc.Close() @@ -146,10 +192,8 @@ func handleUDPToLocal(writeBack C.WriteBack, pc N.EnhancePacketConn, sender C.Pa fromAddrPort := fromUDPAddr.AddrPort() fromAddr := fromAddrPort.Addr().Unmap() - // restore fakeip - if fAddr.IsValid() && (oAddrPort.Addr() == fromAddr) { // oAddrPort was Unmapped - fromAddr = fAddr.Unmap() - } + // restore DestinationNAT + fromAddr = sender.RestoreReadFrom(fromAddr).Unmap() fromAddrPort = netip.AddrPortFrom(fromAddr, fromAddrPort.Port()) diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index add12a05..9c38110d 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -366,15 +366,9 @@ func handleUDPConn(packet C.PacketAdapter) { return } - // make a fAddr if request ip is fakeip - var fAddr netip.Addr - if resolver.IsExistFakeIP(metadata.DstIP) { - fAddr = metadata.DstIP - } - - if err := preHandleMetadata(metadata); err != nil { + if err := preHandleMetadata(metadata.Clone()); err != nil { // precheck without modify metadata packet.Drop() - log.Debugln("[Metadata PreHandle] error: %s", err) + log.Warnln("[Metadata PreHandle] error: %s", err) return } @@ -388,21 +382,27 @@ func handleUDPConn(packet C.PacketAdapter) { }) if !loaded { dial := func() (C.PacketConn, C.WriteBackProxy, error) { - if err := sender.ResolveUDP(metadata); err != nil { - log.Warnln("[UDP] Resolve Ip error: %s", err) + originMetadata := metadata // save origin metadata + metadata = metadata.Clone() // don't modify PacketAdapter's metadata + + if err := sender.DoSniff(metadata); err != nil { + log.Warnln("[UDP] DoSniff error: %s", err.Error()) return nil, nil, err } + _ = preHandleMetadata(metadata) // error was pre-checked + proxy, rule, err := resolveMetadata(metadata) if err != nil { log.Warnln("[UDP] Parse metadata failed: %s", err.Error()) return nil, nil, err } + dialMetadata := metadata.Pure() ctx, cancel := context.WithTimeout(context.Background(), C.DefaultUDPTimeout) defer cancel() rawPc, err := retry(ctx, func(ctx context.Context) (C.PacketConn, error) { - return proxy.ListenPacketContext(ctx, metadata.Pure()) + return proxy.ListenPacketContext(ctx, dialMetadata) }, func(err error) { logMetadataErr(metadata, rule, proxy, err) }) @@ -413,10 +413,11 @@ func handleUDPConn(packet C.PacketAdapter) { pc := statistic.NewUDPTracker(rawPc, statistic.DefaultManager, metadata, rule, 0, 0, true) - oAddrPort := metadata.AddrPort() + sender.AddMapping(originMetadata, dialMetadata) + oAddrPort := dialMetadata.AddrPort() writeBackProxy := nat.NewWriteBackProxy(packet) - go handleUDPToLocal(writeBackProxy, pc, sender, key, oAddrPort, fAddr) + go handleUDPToLocal(writeBackProxy, pc, sender, key, oAddrPort) return pc, writeBackProxy, nil }