diff --git a/transport/anytls/client.go b/transport/anytls/client.go index 4e74c927..dcb679ff 100644 --- a/transport/anytls/client.go +++ b/transport/anytls/client.go @@ -46,7 +46,7 @@ func NewClient(ctx context.Context, config ClientConfig) *Client { } // Initialize the padding state of this client padding.UpdatePaddingScheme(padding.DefaultPaddingScheme, &c.padding) - c.sessionClient = session.NewClient(ctx, c.CreateOutboundTLSConnection, &c.padding, config.IdleSessionCheckInterval, config.IdleSessionTimeout, config.MinIdleSession) + c.sessionClient = session.NewClient(ctx, c.createOutboundTLSConnection, &c.padding, config.IdleSessionCheckInterval, config.IdleSessionTimeout, config.MinIdleSession) return c } @@ -63,7 +63,7 @@ func (c *Client) CreateProxy(ctx context.Context, destination M.Socksaddr) (net. return conn, nil } -func (c *Client) CreateOutboundTLSConnection(ctx context.Context) (net.Conn, error) { +func (c *Client) createOutboundTLSConnection(ctx context.Context) (net.Conn, error) { conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server) if err != nil { return nil, err diff --git a/transport/anytls/session/client.go b/transport/anytls/session/client.go index 0698bd50..b2d1cf7f 100644 --- a/transport/anytls/session/client.go +++ b/transport/anytls/session/client.go @@ -66,23 +66,21 @@ func (c *Client) CreateStream(ctx context.Context) (net.Conn, error) { var stream *Stream var err error - for i := 0; i < 3; i++ { - session, err = c.findSession(ctx) - if session == nil { - return nil, fmt.Errorf("failed to create session: %w", err) - } - stream, err = session.OpenStream() - if err != nil { - _ = session.Close() - continue - } - break + session = c.getIdleSession() + if session == nil { + session, err = c.createSession(ctx) } - if session == nil || stream == nil { - return nil, fmt.Errorf("too many closed session: %w", err) + if session == nil { + return nil, fmt.Errorf("failed to create session: %w", err) + } + stream, err = session.OpenStream() + if err != nil { + session.Close() + return nil, fmt.Errorf("failed to create stream: %w", err) } stream.dieHook = func() { + // If Session is not closed, put this Stream to pool if !session.IsClosed() { select { case <-c.die.Done(): @@ -100,9 +98,7 @@ func (c *Client) CreateStream(ctx context.Context) (net.Conn, error) { return stream, nil } -func (c *Client) findSession(ctx context.Context) (*Session, error) { - var idle *Session - +func (c *Client) getIdleSession() (idle *Session) { c.idleSessionLock.Lock() if !c.idleSession.IsEmpty() { it := c.idleSession.Iterate() @@ -110,12 +106,7 @@ func (c *Client) findSession(ctx context.Context) (*Session, error) { c.idleSession.Remove(it.Key()) } c.idleSessionLock.Unlock() - - if idle == nil { - s, err := c.createSession(ctx) - return s, err - } - return idle, nil + return } func (c *Client) createSession(ctx context.Context) (*Session, error) { @@ -127,7 +118,6 @@ func (c *Client) createSession(ctx context.Context) (*Session, error) { session := NewClientSession(underlying, c.padding) session.seq = c.sessionCounter.Add(1) session.dieHook = func() { - //logrus.Debugln("session died", session) c.idleSessionLock.Lock() c.idleSession.Remove(math.MaxUint64 - session.seq) c.idleSessionLock.Unlock() @@ -168,12 +158,11 @@ func (c *Client) idleCleanup() { } func (c *Client) idleCleanupExpTime(expTime time.Time) { - sessionToRemove := make([]*Session, 0, c.idleSession.Len()) + activeCount := 0 + sessionToClose := make([]*Session, 0, c.idleSession.Len()) c.idleSessionLock.Lock() it := c.idleSession.Iterate() - - activeCount := 0 for it.IsNotEnd() { session := it.Value() key := it.Key() @@ -190,12 +179,12 @@ func (c *Client) idleCleanupExpTime(expTime time.Time) { continue } - sessionToRemove = append(sessionToRemove, session) + sessionToClose = append(sessionToClose, session) c.idleSession.Remove(key) } c.idleSessionLock.Unlock() - for _, session := range sessionToRemove { + for _, session := range sessionToClose { session.Close() } } diff --git a/transport/anytls/session/session.go b/transport/anytls/session/session.go index 23c3a087..12c5b682 100644 --- a/transport/anytls/session/session.go +++ b/transport/anytls/session/session.go @@ -90,7 +90,7 @@ func (s *Session) Run() { f := newFrame(cmdSettings, 0) f.data = settings.ToBytes() s.buffering = true - s.writeFrame(f) + s.writeControlFrame(f) go s.recvLoop() } @@ -119,7 +119,7 @@ func (s *Session) Close() error { } s.streamLock.Lock() for _, stream := range s.streams { - stream.Close() + stream.closeLocally() } s.streams = make(map[uint32]*Stream) s.streamLock.Unlock() @@ -138,8 +138,6 @@ func (s *Session) OpenStream() (*Stream, error) { sid := s.streamId.Add(1) stream := newStream(sid, s) - //logrus.Debugln("stream open", sid, s.streams) - if sid >= 2 && s.peerVersion >= 2 { s.synDoneLock.Lock() if s.synDone != nil { @@ -151,7 +149,7 @@ func (s *Session) OpenStream() (*Stream, error) { s.synDoneLock.Unlock() } - if _, err := s.writeFrame(newFrame(cmdSYN, sid)); err != nil { + if _, err := s.writeControlFrame(newFrame(cmdSYN, sid)); err != nil { return nil, err } @@ -207,7 +205,7 @@ func (s *Session) recvLoop() error { if !s.isClient && !receivedSettingsFromClient { f := newFrame(cmdAlert, 0) f.data = []byte("client did not send its settings") - s.writeFrame(f) + s.writeControlFrame(f) return nil } s.streamLock.Lock() @@ -241,18 +239,18 @@ func (s *Session) recvLoop() error { stream, ok := s.streams[sid] s.streamLock.RUnlock() if ok { - stream.CloseWithError(fmt.Errorf("remote: %s", string(buffer))) + stream.closeWithError(fmt.Errorf("remote: %s", string(buffer))) } pool.Put(buffer) } case cmdFIN: - s.streamLock.RLock() + s.streamLock.Lock() stream, ok := s.streams[sid] - s.streamLock.RUnlock() + delete(s.streams, sid) + s.streamLock.Unlock() if ok { - stream.Close() + stream.closeLocally() } - //logrus.Debugln("stream fin", sid, s.streams) case cmdWaste: if hdr.Length() > 0 { buffer := pool.Get(int(hdr.Length())) @@ -274,10 +272,9 @@ func (s *Session) recvLoop() error { m := util.StringMapFromBytes(buffer) paddingF := s.padding.Load() if m["padding-md5"] != paddingF.Md5 { - // logrus.Debugln("remote md5 is", m["padding-md5"]) f := newFrame(cmdUpdatePaddingScheme, 0) f.data = paddingF.RawScheme - _, err = s.writeFrame(f) + _, err = s.writeControlFrame(f) if err != nil { pool.Put(buffer) return err @@ -291,7 +288,7 @@ func (s *Session) recvLoop() error { f.data = util.StringMap{ "v": "2", }.ToBytes() - _, err = s.writeFrame(f) + _, err = s.writeControlFrame(f) if err != nil { pool.Put(buffer) return err @@ -329,7 +326,7 @@ func (s *Session) recvLoop() error { } } case cmdHeartRequest: - if _, err := s.writeFrame(newFrame(cmdHeartResponse, sid)); err != nil { + if _, err := s.writeControlFrame(newFrame(cmdHeartResponse, sid)); err != nil { return err } case cmdHeartResponse: @@ -364,14 +361,31 @@ func (s *Session) streamClosed(sid uint32) error { if s.IsClosed() { return io.ErrClosedPipe } - _, err := s.writeFrame(newFrame(cmdFIN, sid)) + _, err := s.writeControlFrame(newFrame(cmdFIN, sid)) s.streamLock.Lock() delete(s.streams, sid) s.streamLock.Unlock() return err } -func (s *Session) writeFrame(frame frame) (int, error) { +func (s *Session) writeDataFrame(sid uint32, data []byte) (int, error) { + dataLen := len(data) + + buffer := buf.NewSize(dataLen + headerOverHeadSize) + buffer.WriteByte(cmdPSH) + binary.BigEndian.PutUint32(buffer.Extend(4), sid) + binary.BigEndian.PutUint16(buffer.Extend(2), uint16(dataLen)) + buffer.Write(data) + _, err := s.writeConn(buffer.Bytes()) + buffer.Release() + if err != nil { + return 0, err + } + + return dataLen, nil +} + +func (s *Session) writeControlFrame(frame frame) (int, error) { dataLen := len(frame.data) buffer := buf.NewSize(dataLen + headerOverHeadSize) @@ -379,12 +393,18 @@ func (s *Session) writeFrame(frame frame) (int, error) { binary.BigEndian.PutUint32(buffer.Extend(4), frame.sid) binary.BigEndian.PutUint16(buffer.Extend(2), uint16(dataLen)) buffer.Write(frame.data) + + s.conn.SetWriteDeadline(time.Now().Add(time.Second * 5)) + _, err := s.writeConn(buffer.Bytes()) buffer.Release() if err != nil { + s.Close() return 0, err } + s.conn.SetWriteDeadline(time.Time{}) + return dataLen, nil } diff --git a/transport/anytls/session/stream.go b/transport/anytls/session/stream.go index f7e8de67..9827b0b3 100644 --- a/transport/anytls/session/stream.go +++ b/transport/anytls/session/stream.go @@ -53,21 +53,35 @@ func (s *Stream) Write(b []byte) (n int, err error) { return 0, os.ErrDeadlineExceeded default: } - f := newFrame(cmdPSH, s.id) - f.data = b - n, err = s.sess.writeFrame(f) + if s.dieErr != nil { + return 0, s.dieErr + } + n, err = s.sess.writeDataFrame(s.id, b) return } // Close implements net.Conn func (s *Stream) Close() error { - return s.CloseWithError(io.ErrClosedPipe) + return s.closeWithError(io.ErrClosedPipe) } -func (s *Stream) CloseWithError(err error) error { - // if err != io.ErrClosedPipe { - // logrus.Debugln(err) - // } +// closeLocally only closes Stream and don't notify remote peer +func (s *Stream) closeLocally() { + var once bool + s.dieOnce.Do(func() { + s.dieErr = net.ErrClosed + s.pipeR.Close() + once = true + }) + if once { + if s.dieHook != nil { + s.dieHook() + s.dieHook = nil + } + } +} + +func (s *Stream) closeWithError(err error) error { var once bool s.dieOnce.Do(func() { s.dieErr = err @@ -128,7 +142,7 @@ func (s *Stream) HandshakeFailure(err error) error { if once && err != nil && s.sess.peerVersion >= 2 { f := newFrame(cmdSYNACK, s.id) f.data = []byte(err.Error()) - if _, err := s.sess.writeFrame(f); err != nil { + if _, err := s.sess.writeControlFrame(f); err != nil { return err } } @@ -142,7 +156,7 @@ func (s *Stream) HandshakeSuccess() error { once = true }) if once && s.sess.peerVersion >= 2 { - if _, err := s.sess.writeFrame(newFrame(cmdSYNACK, s.id)); err != nil { + if _, err := s.sess.writeControlFrame(newFrame(cmdSYNACK, s.id)); err != nil { return err } }