chore: simplify tuic client

This commit is contained in:
wwqgtxx 2025-12-02 21:07:51 +08:00
parent ac90543548
commit bc719eb96d
6 changed files with 114 additions and 156 deletions

View File

@ -70,12 +70,7 @@ type TuicOption struct {
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
return t.DialContextWithDialer(ctx, dialer.NewDialer(t.DialOptions()...), metadata) conn, err := t.client.DialContext(ctx, metadata)
}
// DialContextWithDialer implements C.ProxyAdapter
func (t *Tuic) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (C.Conn, error) {
conn, err := t.client.DialContextWithDialer(ctx, metadata, dialer, t.dialWithDialer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -84,11 +79,6 @@ func (t *Tuic) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metad
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) { func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) {
return t.ListenPacketWithDialer(ctx, dialer.NewDialer(t.DialOptions()...), metadata)
}
// ListenPacketWithDialer implements C.ProxyAdapter
func (t *Tuic) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
if err = t.ResolveUDP(ctx, metadata); err != nil { if err = t.ResolveUDP(ctx, metadata); err != nil {
return nil, err return nil, err
} }
@ -98,7 +88,7 @@ func (t *Tuic) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, meta
uotMetadata := *metadata uotMetadata := *metadata
uotMetadata.Host = uotDestination.Fqdn uotMetadata.Host = uotDestination.Fqdn
uotMetadata.DstPort = uotDestination.Port uotMetadata.DstPort = uotDestination.Port
c, err := t.DialContextWithDialer(ctx, dialer, &uotMetadata) c, err := t.DialContext(ctx, &uotMetadata)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -112,21 +102,17 @@ func (t *Tuic) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, meta
return newPacketConn(uot.NewLazyConn(c, uot.Request{Destination: destination}), t), nil return newPacketConn(uot.NewLazyConn(c, uot.Request{Destination: destination}), t), nil
} }
} }
pc, err := t.client.ListenPacketWithDialer(ctx, metadata, dialer, t.dialWithDialer) pc, err := t.client.ListenPacket(ctx, metadata)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newPacketConn(pc, t), nil return newPacketConn(pc, t), nil
} }
// SupportWithDialer implements C.ProxyAdapter func (t *Tuic) dial(ctx context.Context) (transport *quic.Transport, addr net.Addr, err error) {
func (t *Tuic) SupportWithDialer() C.NetWork { var cDialer C.Dialer = dialer.NewDialer(t.DialOptions()...)
return C.ALLNet
}
func (t *Tuic) dialWithDialer(ctx context.Context, dialer C.Dialer) (transport *quic.Transport, addr net.Addr, err error) {
if len(t.option.DialerProxy) > 0 { if len(t.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(t.option.DialerProxy, dialer) cDialer, err = proxydialer.NewByName(t.option.DialerProxy, cDialer)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -141,7 +127,7 @@ func (t *Tuic) dialWithDialer(ctx context.Context, dialer C.Dialer) (transport *
} }
addr = udpAddr addr = udpAddr
var pc net.PacketConn var pc net.PacketConn
pc, err = dialer.ListenPacket(ctx, "udp", "", udpAddr.AddrPort()) pc, err = cDialer.ListenPacket(ctx, "udp", "", udpAddr.AddrPort())
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -313,7 +299,7 @@ func NewTuic(option TuicOption) (*Tuic, error) {
CWND: option.CWND, CWND: option.CWND,
} }
t.client = tuic.NewPoolClientV4(clientOption) t.client = tuic.NewPoolClientV4(clientOption, t.dial)
} else { } else {
maxUdpRelayPacketSize := option.MaxUdpRelayPacketSize maxUdpRelayPacketSize := option.MaxUdpRelayPacketSize
if maxUdpRelayPacketSize > tuic.MaxFragSizeV5 { if maxUdpRelayPacketSize > tuic.MaxFragSizeV5 {
@ -332,7 +318,7 @@ func NewTuic(option TuicOption) (*Tuic, error) {
CWND: option.CWND, CWND: option.CWND,
} }
t.client = tuic.NewPoolClientV5(clientOption) t.client = tuic.NewPoolClientV5(clientOption, t.dial)
} }
return t, nil return t, nil

View File

@ -18,13 +18,12 @@ var (
TooManyOpenStreams = errors.New("tuic: too many open streams") TooManyOpenStreams = errors.New("tuic: too many open streams")
) )
type DialFunc func(ctx context.Context, dialer C.Dialer) (transport *quic.Transport, addr net.Addr, err error) type DialFunc func(ctx context.Context) (transport *quic.Transport, addr net.Addr, err error)
type Client interface { type Client interface {
DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.Conn, error) DialContext(ctx context.Context, metadata *C.Metadata) (net.Conn, error)
ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.PacketConn, error) ListenPacket(ctx context.Context, metadata *C.Metadata) (net.PacketConn, error)
OpenStreams() int64 OpenStreams() int64
DialerRef() C.Dialer
LastVisited() time.Time LastVisited() time.Time
SetLastVisited(last time.Time) SetLastVisited(last time.Time)
Close() Close()

View File

@ -6,6 +6,7 @@ import (
"net" "net"
"runtime" "runtime"
"sync" "sync"
"sync/atomic"
"time" "time"
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
@ -17,30 +18,21 @@ import (
list "github.com/bahlo/generic-list-go" list "github.com/bahlo/generic-list-go"
) )
type dialResult struct {
transport *quic.Transport
addr net.Addr
err error
}
type PoolClient struct { type PoolClient struct {
newClientOptionV4 *ClientOptionV4 newClientOptionV4 *ClientOptionV4
newClientOptionV5 *ClientOptionV5 newClientOptionV5 *ClientOptionV5
dialResultMap map[C.Dialer]dialResult
dialResultMutex *sync.Mutex dialHelper *poolDialHelper
tcpClients *list.List[Client] tcpClients list.List[Client]
tcpClientsMutex *sync.Mutex tcpClientsMutex sync.Mutex
udpClients *list.List[Client] udpClients list.List[Client]
udpClientsMutex *sync.Mutex udpClientsMutex sync.Mutex
} }
func (t *PoolClient) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.Conn, error) { func (t *PoolClient) DialContext(ctx context.Context, metadata *C.Metadata) (net.Conn, error) {
newDialFn := func(ctx context.Context, dialer C.Dialer) (transport *quic.Transport, addr net.Addr, err error) { conn, err := t.getClient(false).DialContext(ctx, metadata)
return t.dial(ctx, dialer, dialFn)
}
conn, err := t.getClient(false, dialer).DialContextWithDialer(ctx, metadata, dialer, newDialFn)
if errors.Is(err, TooManyOpenStreams) { if errors.Is(err, TooManyOpenStreams) {
conn, err = t.newClient(false, dialer).DialContextWithDialer(ctx, metadata, dialer, newDialFn) conn, err = t.newClient(false).DialContext(ctx, metadata)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@ -48,13 +40,10 @@ func (t *PoolClient) DialContextWithDialer(ctx context.Context, metadata *C.Meta
return N.NewRefConn(conn, t), err return N.NewRefConn(conn, t), err
} }
func (t *PoolClient) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.PacketConn, error) { func (t *PoolClient) ListenPacket(ctx context.Context, metadata *C.Metadata) (net.PacketConn, error) {
newDialFn := func(ctx context.Context, dialer C.Dialer) (transport *quic.Transport, addr net.Addr, err error) { pc, err := t.getClient(true).ListenPacket(ctx, metadata)
return t.dial(ctx, dialer, dialFn)
}
pc, err := t.getClient(true, dialer).ListenPacketWithDialer(ctx, metadata, dialer, newDialFn)
if errors.Is(err, TooManyOpenStreams) { if errors.Is(err, TooManyOpenStreams) {
pc, err = t.newClient(true, dialer).ListenPacketWithDialer(ctx, metadata, dialer, newDialFn) pc, err = t.newClient(true).ListenPacket(ctx, metadata)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@ -62,58 +51,63 @@ func (t *PoolClient) ListenPacketWithDialer(ctx context.Context, metadata *C.Met
return N.NewRefPacketConn(pc, t), nil return N.NewRefPacketConn(pc, t), nil
} }
func (t *PoolClient) dial(ctx context.Context, dialer C.Dialer, dialFn DialFunc) (transport *quic.Transport, addr net.Addr, err error) { // poolDialHelper is a helper for dialFn
t.dialResultMutex.Lock() // using a standalone struct to let finalizer working
dr, ok := t.dialResultMap[dialer] type poolDialHelper struct {
t.dialResultMutex.Unlock() dialFn DialFunc
if ok { dialResult atomic.Pointer[dialResult]
return dr.transport, dr.addr, dr.err
} }
transport, addr, err = dialFn(ctx, dialer) type dialResult struct {
transport *quic.Transport
addr net.Addr
}
func (t *poolDialHelper) dial(ctx context.Context) (transport *quic.Transport, addr net.Addr, err error) {
if dr := t.dialResult.Load(); dr != nil {
return dr.transport, dr.addr, nil
}
transport, addr, err = t.dialFn(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if _, ok := transport.Conn.(*net.UDPConn); ok { // only cache the system's UDPConn if _, ok := transport.Conn.(*net.UDPConn); ok { // only cache the system's UDPConn
transport.SetSingleUse(false) // don't close transport in each dial transport.SetSingleUse(false) // don't close transport in each dial
dr.transport, dr.addr, dr.err = transport, addr, err
t.dialResultMutex.Lock() dr := &dialResult{transport: transport, addr: addr}
t.dialResultMap[dialer] = dr t.dialResult.Store(dr)
t.dialResultMutex.Unlock()
} }
return transport, addr, err return transport, addr, err
} }
func (t *PoolClient) forceClose() { func (t *poolDialHelper) forceClose() {
t.dialResultMutex.Lock() if dr := t.dialResult.Swap(nil); dr != nil {
defer t.dialResultMutex.Unlock() transport := dr.transport
for key := range t.dialResultMap {
transport := t.dialResultMap[key].transport
if transport != nil { if transport != nil {
_ = transport.Close() _ = transport.Close()
} }
delete(t.dialResultMap, key)
} }
} }
func (t *PoolClient) newClient(udp bool, dialer C.Dialer) (client Client) { func (t *PoolClient) newClient(udp bool) (client Client) {
clients := t.tcpClients clients := &t.tcpClients
clientsMutex := t.tcpClientsMutex clientsMutex := &t.tcpClientsMutex
if udp { if udp {
clients = t.udpClients clients = &t.udpClients
clientsMutex = t.udpClientsMutex clientsMutex = &t.udpClientsMutex
} }
clientsMutex.Lock() clientsMutex.Lock()
defer clientsMutex.Unlock() defer clientsMutex.Unlock()
dialHelper := t.dialHelper
if t.newClientOptionV4 != nil { if t.newClientOptionV4 != nil {
client = NewClientV4(t.newClientOptionV4, udp, dialer) client = NewClientV4(t.newClientOptionV4, udp, dialHelper.dial)
} else { } else {
client = NewClientV5(t.newClientOptionV5, udp, dialer) client = NewClientV5(t.newClientOptionV5, udp, dialHelper.dial)
} }
client.SetLastVisited(time.Now()) client.SetLastVisited(time.Now())
@ -122,12 +116,12 @@ func (t *PoolClient) newClient(udp bool, dialer C.Dialer) (client Client) {
return client return client
} }
func (t *PoolClient) getClient(udp bool, dialer C.Dialer) Client { func (t *PoolClient) getClient(udp bool) Client {
clients := t.tcpClients clients := &t.tcpClients
clientsMutex := t.tcpClientsMutex clientsMutex := &t.tcpClientsMutex
if udp { if udp {
clients = t.udpClients clients = &t.udpClients
clientsMutex = t.udpClientsMutex clientsMutex = &t.udpClientsMutex
} }
var bestClient Client var bestClient Client
@ -142,7 +136,6 @@ func (t *PoolClient) getClient(udp bool, dialer C.Dialer) Client {
it = next it = next
continue continue
} }
if client.DialerRef() == dialer {
if bestClient == nil { if bestClient == nil {
bestClient = client bestClient = client
} else { } else {
@ -150,10 +143,8 @@ func (t *PoolClient) getClient(udp bool, dialer C.Dialer) Client {
bestClient = client bestClient = client
} }
} }
}
it = it.Next() it = it.Next()
} }
}()
for it := clients.Front(); it != nil; { for it := clients.Front(); it != nil; {
client := it.Value client := it.Value
if client != bestClient && client.OpenStreams() == 0 && time.Now().Sub(client.LastVisited()) > 30*time.Minute { if client != bestClient && client.OpenStreams() == 0 && time.Now().Sub(client.LastVisited()) > 30*time.Minute {
@ -165,23 +156,19 @@ func (t *PoolClient) getClient(udp bool, dialer C.Dialer) Client {
} }
it = it.Next() it = it.Next()
} }
}()
if bestClient == nil { if bestClient == nil {
return t.newClient(udp, dialer) return t.newClient(udp)
} else { } else {
bestClient.SetLastVisited(time.Now()) bestClient.SetLastVisited(time.Now())
return bestClient return bestClient
} }
} }
func NewPoolClientV4(clientOption *ClientOptionV4) *PoolClient { func NewPoolClientV4(clientOption *ClientOptionV4, dialFn DialFunc) *PoolClient {
p := &PoolClient{ p := &PoolClient{
dialResultMap: make(map[C.Dialer]dialResult), dialHelper: &poolDialHelper{dialFn: dialFn},
dialResultMutex: &sync.Mutex{},
tcpClients: list.New[Client](),
tcpClientsMutex: &sync.Mutex{},
udpClients: list.New[Client](),
udpClientsMutex: &sync.Mutex{},
} }
newClientOption := *clientOption newClientOption := *clientOption
p.newClientOptionV4 = &newClientOption p.newClientOptionV4 = &newClientOption
@ -190,14 +177,9 @@ func NewPoolClientV4(clientOption *ClientOptionV4) *PoolClient {
return p return p
} }
func NewPoolClientV5(clientOption *ClientOptionV5) *PoolClient { func NewPoolClientV5(clientOption *ClientOptionV5, dialFn DialFunc) *PoolClient {
p := &PoolClient{ p := &PoolClient{
dialResultMap: make(map[C.Dialer]dialResult), dialHelper: &poolDialHelper{dialFn: dialFn},
dialResultMutex: &sync.Mutex{},
tcpClients: list.New[Client](),
tcpClientsMutex: &sync.Mutex{},
udpClients: list.New[Client](),
udpClientsMutex: &sync.Mutex{},
} }
newClientOption := *clientOption newClientOption := *clientOption
p.newClientOptionV5 = &newClientOption p.newClientOptionV5 = &newClientOption
@ -208,5 +190,5 @@ func NewPoolClientV5(clientOption *ClientOptionV5) *PoolClient {
func closeClientPool(client *PoolClient) { func closeClientPool(client *PoolClient) {
log.Debugln("Close Tuic PoolClient at %p", client) log.Debugln("Close Tuic PoolClient at %p", client)
client.forceClose() client.dialHelper.forceClose()
} }

View File

@ -1,7 +1,6 @@
package tuic package tuic
import ( import (
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/tuic/common" "github.com/metacubex/mihomo/transport/tuic/common"
v4 "github.com/metacubex/mihomo/transport/tuic/v4" v4 "github.com/metacubex/mihomo/transport/tuic/v4"
v5 "github.com/metacubex/mihomo/transport/tuic/v5" v5 "github.com/metacubex/mihomo/transport/tuic/v5"
@ -12,12 +11,12 @@ type ClientOptionV5 = v5.ClientOption
type Client = common.Client type Client = common.Client
func NewClientV4(clientOption *ClientOptionV4, udp bool, dialerRef C.Dialer) Client { func NewClientV4(clientOption *ClientOptionV4, udp bool, dialFn DialFunc) Client {
return v4.NewClient(clientOption, udp, dialerRef) return v4.NewClient(clientOption, udp, dialFn)
} }
func NewClientV5(clientOption *ClientOptionV5, udp bool, dialerRef C.Dialer) Client { func NewClientV5(clientOption *ClientOptionV5, udp bool, dialFn DialFunc) Client {
return v5.NewClient(clientOption, udp, dialerRef) return v5.NewClient(clientOption, udp, dialFn)
} }
type DialFunc = common.DialFunc type DialFunc = common.DialFunc

View File

@ -40,6 +40,7 @@ type ClientOption struct {
type clientImpl struct { type clientImpl struct {
*ClientOption *ClientOption
dialFn common.DialFunc
udp bool udp bool
quicConn *quic.Conn quicConn *quic.Conn
@ -51,7 +52,6 @@ type clientImpl struct {
udpInputMap xsync.Map[uint32, net.Conn] udpInputMap xsync.Map[uint32, net.Conn]
// only ready for PoolClient // only ready for PoolClient
dialerRef C.Dialer
lastVisited atomic2.TypedValue[time.Time] lastVisited atomic2.TypedValue[time.Time]
} }
@ -59,10 +59,6 @@ func (t *clientImpl) OpenStreams() int64 {
return t.openStreams.Load() return t.openStreams.Load()
} }
func (t *clientImpl) DialerRef() C.Dialer {
return t.dialerRef
}
func (t *clientImpl) LastVisited() time.Time { func (t *clientImpl) LastVisited() time.Time {
return t.lastVisited.Load() return t.lastVisited.Load()
} }
@ -71,13 +67,13 @@ func (t *clientImpl) SetLastVisited(last time.Time) {
t.lastVisited.Store(last) t.lastVisited.Store(last)
} }
func (t *clientImpl) getQuicConn(ctx context.Context, dialer C.Dialer, dialFn common.DialFunc) (*quic.Conn, error) { func (t *clientImpl) getQuicConn(ctx context.Context) (*quic.Conn, error) {
t.connMutex.Lock() t.connMutex.Lock()
defer t.connMutex.Unlock() defer t.connMutex.Unlock()
if t.quicConn != nil { if t.quicConn != nil {
return t.quicConn, nil return t.quicConn, nil
} }
transport, addr, err := dialFn(ctx, dialer) transport, addr, err := t.dialFn(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -262,7 +258,7 @@ func (t *clientImpl) forceClose(quicConn *quic.Conn, err error) {
if quicConn != nil { if quicConn != nil {
_ = quicConn.CloseWithError(ProtocolError, errStr) _ = quicConn.CloseWithError(ProtocolError, errStr)
} }
udpInputMap := t.udpInputMap udpInputMap := &t.udpInputMap
udpInputMap.Range(func(key uint32, value net.Conn) bool { udpInputMap.Range(func(key uint32, value net.Conn) bool {
conn := value conn := value
_ = conn.Close() _ = conn.Close()
@ -278,8 +274,8 @@ func (t *clientImpl) Close() {
} }
} }
func (t *clientImpl) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.Conn, error) { func (t *clientImpl) DialContext(ctx context.Context, metadata *C.Metadata) (net.Conn, error) {
quicConn, err := t.getQuicConn(ctx, dialer, dialFn) quicConn, err := t.getQuicConn(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -353,8 +349,8 @@ func (t *clientImpl) DialContextWithDialer(ctx context.Context, metadata *C.Meta
return bufConn, nil return bufConn, nil
} }
func (t *clientImpl) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.PacketConn, error) { func (t *clientImpl) ListenPacket(ctx context.Context, metadata *C.Metadata) (net.PacketConn, error) {
quicConn, err := t.getQuicConn(ctx, dialer, dialFn) quicConn, err := t.getQuicConn(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -397,16 +393,16 @@ type Client struct {
*clientImpl // use an independent pointer to let Finalizer can work no matter somewhere handle an influence in clientImpl inner *clientImpl // use an independent pointer to let Finalizer can work no matter somewhere handle an influence in clientImpl inner
} }
func (t *Client) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.Conn, error) { func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata) (net.Conn, error) {
conn, err := t.clientImpl.DialContextWithDialer(ctx, metadata, dialer, dialFn) conn, err := t.clientImpl.DialContext(ctx, metadata)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return N.NewRefConn(conn, t), err return N.NewRefConn(conn, t), err
} }
func (t *Client) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.PacketConn, error) { func (t *Client) ListenPacket(ctx context.Context, metadata *C.Metadata) (net.PacketConn, error) {
pc, err := t.clientImpl.ListenPacketWithDialer(ctx, metadata, dialer, dialFn) pc, err := t.clientImpl.ListenPacket(ctx, metadata)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -417,11 +413,11 @@ func (t *Client) forceClose() {
t.clientImpl.forceClose(nil, common.ClientClosed) t.clientImpl.forceClose(nil, common.ClientClosed)
} }
func NewClient(clientOption *ClientOption, udp bool, dialerRef C.Dialer) *Client { func NewClient(clientOption *ClientOption, udp bool, dialFn common.DialFunc) *Client {
ci := &clientImpl{ ci := &clientImpl{
ClientOption: clientOption, ClientOption: clientOption,
dialFn: dialFn,
udp: udp, udp: udp,
dialerRef: dialerRef,
} }
c := &Client{ci} c := &Client{ci}
runtime.SetFinalizer(c, closeClient) runtime.SetFinalizer(c, closeClient)

View File

@ -39,6 +39,7 @@ type ClientOption struct {
type clientImpl struct { type clientImpl struct {
*ClientOption *ClientOption
dialFn common.DialFunc
udp bool udp bool
quicConn *quic.Conn quicConn *quic.Conn
@ -50,7 +51,6 @@ type clientImpl struct {
udpInputMap xsync.Map[uint16, net.Conn] udpInputMap xsync.Map[uint16, net.Conn]
// only ready for PoolClient // only ready for PoolClient
dialerRef C.Dialer
lastVisited atomic2.TypedValue[time.Time] lastVisited atomic2.TypedValue[time.Time]
} }
@ -58,10 +58,6 @@ func (t *clientImpl) OpenStreams() int64 {
return t.openStreams.Load() return t.openStreams.Load()
} }
func (t *clientImpl) DialerRef() C.Dialer {
return t.dialerRef
}
func (t *clientImpl) LastVisited() time.Time { func (t *clientImpl) LastVisited() time.Time {
return t.lastVisited.Load() return t.lastVisited.Load()
} }
@ -70,13 +66,13 @@ func (t *clientImpl) SetLastVisited(last time.Time) {
t.lastVisited.Store(last) t.lastVisited.Store(last)
} }
func (t *clientImpl) getQuicConn(ctx context.Context, dialer C.Dialer, dialFn common.DialFunc) (*quic.Conn, error) { func (t *clientImpl) getQuicConn(ctx context.Context) (*quic.Conn, error) {
t.connMutex.Lock() t.connMutex.Lock()
defer t.connMutex.Unlock() defer t.connMutex.Unlock()
if t.quicConn != nil { if t.quicConn != nil {
return t.quicConn, nil return t.quicConn, nil
} }
transport, addr, err := dialFn(ctx, dialer) transport, addr, err := t.dialFn(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -270,7 +266,7 @@ func (t *clientImpl) forceClose(quicConn *quic.Conn, err error) {
if quicConn != nil { if quicConn != nil {
_ = quicConn.CloseWithError(ProtocolError, errStr) _ = quicConn.CloseWithError(ProtocolError, errStr)
} }
udpInputMap := t.udpInputMap udpInputMap := &t.udpInputMap
udpInputMap.Range(func(key uint16, value net.Conn) bool { udpInputMap.Range(func(key uint16, value net.Conn) bool {
conn := value conn := value
_ = conn.Close() _ = conn.Close()
@ -286,8 +282,8 @@ func (t *clientImpl) Close() {
} }
} }
func (t *clientImpl) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.Conn, error) { func (t *clientImpl) DialContext(ctx context.Context, metadata *C.Metadata) (net.Conn, error) {
quicConn, err := t.getQuicConn(ctx, dialer, dialFn) quicConn, err := t.getQuicConn(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -337,8 +333,8 @@ func (t *clientImpl) DialContextWithDialer(ctx context.Context, metadata *C.Meta
return stream, nil return stream, nil
} }
func (t *clientImpl) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.PacketConn, error) { func (t *clientImpl) ListenPacket(ctx context.Context, metadata *C.Metadata) (net.PacketConn, error) {
quicConn, err := t.getQuicConn(ctx, dialer, dialFn) quicConn, err := t.getQuicConn(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -381,16 +377,16 @@ type Client struct {
*clientImpl // use an independent pointer to let Finalizer can work no matter somewhere handle an influence in clientImpl inner *clientImpl // use an independent pointer to let Finalizer can work no matter somewhere handle an influence in clientImpl inner
} }
func (t *Client) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.Conn, error) { func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata) (net.Conn, error) {
conn, err := t.clientImpl.DialContextWithDialer(ctx, metadata, dialer, dialFn) conn, err := t.clientImpl.DialContext(ctx, metadata)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return N.NewRefConn(conn, t), err return N.NewRefConn(conn, t), err
} }
func (t *Client) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.PacketConn, error) { func (t *Client) ListenPacket(ctx context.Context, metadata *C.Metadata) (net.PacketConn, error) {
pc, err := t.clientImpl.ListenPacketWithDialer(ctx, metadata, dialer, dialFn) pc, err := t.clientImpl.ListenPacket(ctx, metadata)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -401,11 +397,11 @@ func (t *Client) forceClose() {
t.clientImpl.forceClose(nil, common.ClientClosed) t.clientImpl.forceClose(nil, common.ClientClosed)
} }
func NewClient(clientOption *ClientOption, udp bool, dialerRef C.Dialer) *Client { func NewClient(clientOption *ClientOption, udp bool, dialFn common.DialFunc) *Client {
ci := &clientImpl{ ci := &clientImpl{
ClientOption: clientOption, ClientOption: clientOption,
dialFn: dialFn,
udp: udp, udp: udp,
dialerRef: dialerRef,
} }
c := &Client{ci} c := &Client{ci}
runtime.SetFinalizer(c, closeClient) runtime.SetFinalizer(c, closeClient)