mihomo/transport/kcptun/client.go

174 lines
4.2 KiB
Go

package kcptun
import (
"context"
"crypto/rand"
"encoding/binary"
"net"
"sync"
"time"
"github.com/metacubex/mihomo/log"
"github.com/metacubex/kcp-go"
"github.com/metacubex/smux"
)
const Mode = "kcptun"
type DialFn func(ctx context.Context) (net.PacketConn, net.Addr, error)
type Client struct {
once sync.Once
config Config
block kcp.BlockCrypt
ctx context.Context
cancel context.CancelFunc
numconn uint16
muxes []timedSession
rr uint16
connMu sync.Mutex
chScavenger chan timedSession
}
func NewClient(config Config) *Client {
config.FillDefaults()
block := config.NewBlock()
ctx, cancel := context.WithCancel(context.Background())
return &Client{
config: config,
block: block,
ctx: ctx,
cancel: cancel,
}
}
func (c *Client) Close() error {
c.cancel()
return nil
}
func (c *Client) createConn(ctx context.Context, dial DialFn) (*smux.Session, error) {
conn, addr, err := dial(ctx)
if err != nil {
return nil, err
}
config := c.config
var convid uint32
binary.Read(rand.Reader, binary.LittleEndian, &convid)
kcpconn, err := kcp.NewConn4(convid, addr, c.block, config.DataShard, config.ParityShard, true, conn)
if err != nil {
return nil, err
}
kcpconn.SetStreamMode(true)
kcpconn.SetWriteDelay(false)
kcpconn.SetNoDelay(config.NoDelay, config.Interval, config.Resend, config.NoCongestion)
kcpconn.SetWindowSize(config.SndWnd, config.RcvWnd)
kcpconn.SetMtu(config.MTU)
kcpconn.SetACKNoDelay(config.AckNodelay)
_ = kcpconn.SetDSCP(config.DSCP)
_ = kcpconn.SetReadBuffer(config.SockBuf)
_ = kcpconn.SetWriteBuffer(config.SockBuf)
smuxConfig := smux.DefaultConfig()
smuxConfig.Version = config.SmuxVer
smuxConfig.MaxReceiveBuffer = config.SmuxBuf
smuxConfig.MaxStreamBuffer = config.StreamBuf
smuxConfig.KeepAliveInterval = time.Duration(config.KeepAlive) * time.Second
if smuxConfig.KeepAliveInterval >= smuxConfig.KeepAliveTimeout {
smuxConfig.KeepAliveTimeout = 3 * smuxConfig.KeepAliveInterval
}
if err := smux.VerifyConfig(smuxConfig); err != nil {
return nil, err
}
var netConn net.Conn = kcpconn
if !config.NoComp {
netConn = NewCompStream(netConn)
}
// stream multiplex
return smux.Client(netConn, smuxConfig)
}
func (c *Client) OpenStream(ctx context.Context, dial DialFn) (*smux.Stream, error) {
c.once.Do(func() {
// start scavenger if autoexpire is set
c.chScavenger = make(chan timedSession, 128)
if c.config.AutoExpire > 0 {
go scavenger(c.ctx, c.chScavenger, &c.config)
}
c.numconn = uint16(c.config.Conn)
c.muxes = make([]timedSession, c.config.Conn)
c.rr = uint16(0)
})
c.connMu.Lock()
idx := c.rr % c.numconn
// do auto expiration && reconnection
if c.muxes[idx].session == nil || c.muxes[idx].session.IsClosed() ||
(c.config.AutoExpire > 0 && time.Now().After(c.muxes[idx].expiryDate)) {
var err error
c.muxes[idx].session, err = c.createConn(ctx, dial)
if err != nil {
c.connMu.Unlock()
return nil, err
}
c.muxes[idx].expiryDate = time.Now().Add(time.Duration(c.config.AutoExpire) * time.Second)
if c.config.AutoExpire > 0 { // only when autoexpire set
c.chScavenger <- c.muxes[idx]
}
}
c.rr++
session := c.muxes[idx].session
c.connMu.Unlock()
return session.OpenStream()
}
// timedSession is a wrapper for smux.Session with expiry date
type timedSession struct {
session *smux.Session
expiryDate time.Time
}
// scavenger goroutine is used to close expired sessions
func scavenger(ctx context.Context, ch chan timedSession, config *Config) {
ticker := time.NewTicker(scavengePeriod * time.Second)
defer ticker.Stop()
var sessionList []timedSession
for {
select {
case item := <-ch:
sessionList = append(sessionList, timedSession{
item.session,
item.expiryDate.Add(time.Duration(config.ScavengeTTL) * time.Second)})
case <-ticker.C:
var newList []timedSession
for k := range sessionList {
s := sessionList[k]
if s.session.IsClosed() {
log.Debugln("scavenger: session normally closed: %s", s.session.LocalAddr())
} else if time.Now().After(s.expiryDate) {
s.session.Close()
log.Debugln("scavenger: session closed due to ttl: %s", s.session.LocalAddr())
} else {
newList = append(newList, sessionList[k])
}
}
sessionList = newList
case <-ctx.Done():
return
}
}
}