diff --git a/transport/gun/gun.go b/transport/gun/gun.go index ce9d0279..08598b25 100644 --- a/transport/gun/gun.go +++ b/transport/gun/gun.go @@ -18,7 +18,6 @@ import ( "sync" "time" - "github.com/metacubex/mihomo/common/atomic" "github.com/metacubex/mihomo/common/buf" "github.com/metacubex/mihomo/common/pool" "github.com/metacubex/mihomo/component/ech" @@ -42,16 +41,19 @@ type DialFn = func(ctx context.Context, network, addr string) (net.Conn, error) type Conn struct { initFn func() (io.ReadCloser, netAddr, error) - writer io.Writer + writer io.Writer // writer must not nil closer io.Closer netAddr - reader io.ReadCloser - once sync.Once - closed atomic.Bool - err error - remain int - br *bufio.Reader + initOnce sync.Once + initErr error + reader io.ReadCloser + br *bufio.Reader + remain int + + closeMutex sync.Mutex + closed bool + // deadlines deadline *time.Timer } @@ -65,7 +67,7 @@ type Config struct { func (g *Conn) initReader() { reader, addr, err := g.initFn() if err != nil { - g.err = err + g.initErr = err if closer, ok := g.writer.(io.Closer); ok { closer.Close() } @@ -73,17 +75,21 @@ func (g *Conn) initReader() { } g.netAddr = addr - if !g.closed.Load() { - g.reader = reader - g.br = bufio.NewReader(reader) - } else { - reader.Close() + g.closeMutex.Lock() + defer g.closeMutex.Unlock() + if g.closed { // if g.Close() be called between g.initFn(), direct close the initFn returned reader + _ = reader.Close() + g.initErr = net.ErrClosed + return } + + g.reader = reader + g.br = bufio.NewReader(reader) } func (g *Conn) Init() error { - g.once.Do(g.initReader) - return g.err + g.initOnce.Do(g.initReader) + return g.initErr } func (g *Conn) Read(b []byte) (n int, err error) { @@ -100,8 +106,6 @@ func (g *Conn) Read(b []byte) (n int, err error) { n, err = io.ReadFull(g.br, b[:size]) g.remain -= n return - } else if g.reader == nil { - return 0, net.ErrClosed } // 0x00 grpclength(uint32) 0x0A uleb128 payload @@ -147,8 +151,8 @@ func (g *Conn) Write(b []byte) (n int, err error) { buf.Write(b) _, err = g.writer.Write(buf.Bytes()) - if err == io.ErrClosedPipe && g.err != nil { - err = g.err + if err == io.ErrClosedPipe && g.initErr != nil { + err = g.initErr } if flusher, ok := g.writer.(http.Flusher); ok { @@ -170,8 +174,8 @@ func (g *Conn) WriteBuffer(buffer *buf.Buffer) error { binary.PutUvarint(header[6:], uint64(dataLen)) _, err := g.writer.Write(buffer.Bytes()) - if err == io.ErrClosedPipe && g.err != nil { - err = g.err + if err == io.ErrClosedPipe && g.initErr != nil { + err = g.initErr } if flusher, ok := g.writer.(http.Flusher); ok { @@ -186,7 +190,17 @@ func (g *Conn) FrontHeadroom() int { } func (g *Conn) Close() error { - g.closed.Store(true) + g.initOnce.Do(func() { // if initReader not called, it should not be run anymore + g.initErr = net.ErrClosed + }) + + g.closeMutex.Lock() + defer g.closeMutex.Unlock() + if g.closed { + return nil + } + g.closed = true + var errorArr []error if reader := g.reader; reader != nil {