Chore: adjust mitm proxy

This commit is contained in:
yaling888 2022-04-15 00:29:21 +08:00
parent ca76e5cf0e
commit 6327cf7434
4 changed files with 58 additions and 58 deletions

View File

@ -13,8 +13,6 @@ import (
"github.com/Dreamacro/clash/transport/socks5" "github.com/Dreamacro/clash/transport/socks5"
) )
var ErrCertUnsupported = errors.New("tls: client cert unsupported")
func newClient(source net.Addr, userAgent string, in chan<- C.ConnContext) *http.Client { func newClient(source net.Addr, userAgent string, in chan<- C.ConnContext) *http.Client {
return &http.Client{ return &http.Client{
Transport: &http.Transport{ Transport: &http.Transport{

View File

@ -13,7 +13,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/Dreamacro/clash/adapter/inbound"
"github.com/Dreamacro/clash/common/cache" "github.com/Dreamacro/clash/common/cache"
N "github.com/Dreamacro/clash/common/net" N "github.com/Dreamacro/clash/common/net"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
@ -63,7 +62,7 @@ readLoop:
session := newSession(conn, request, response) session := newSession(conn, request, response)
source = parseSourceAddress(session.request, c, source) source = parseSourceAddress(session.request, c.RemoteAddr(), source)
session.request.RemoteAddr = source.String() session.request.RemoteAddr = source.String()
if !trusted { if !trusted {
@ -80,42 +79,45 @@ readLoop:
break readLoop // close connection break readLoop // close connection
} }
if couldBeWithManInTheMiddleAttack(session.request.URL.Host, opt) { if strings.HasSuffix(session.request.URL.Host, ":80") {
b := make([]byte, 1) goto readLoop
if _, err = session.conn.Read(b); err != nil { }
handleError(opt, session, err)
b := make([]byte, 1)
if _, err = session.conn.Read(b); err != nil {
handleError(opt, session, err)
break readLoop // close connection
}
buff := make([]byte, session.conn.(*N.BufferedConn).Buffered())
if _, err = session.conn.Read(buff); err != nil {
handleError(opt, session, err)
break readLoop // close connection
}
mrConn := &multiReaderConn{
Conn: session.conn,
reader: io.MultiReader(bytes.NewReader(b), bytes.NewReader(buff), session.conn),
}
// TLS handshake.
if b[0] == 0x16 {
// TODO serve by generic host name maybe better?
tlsConn := tls.Server(mrConn, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Host))
// Handshake with the local client
if err = tlsConn.Handshake(); err != nil {
session.response = session.NewErrorResponse(fmt.Errorf("handshake failed: %w", err))
_ = writeResponse(session, false)
break readLoop // close connection break readLoop // close connection
} }
buff := make([]byte, session.conn.(*N.BufferedConn).Buffered()) c = tlsConn
_, _ = session.conn.Read(buff) } else {
c = mrConn
mrc := &multiReaderConn{
Conn: session.conn,
reader: io.MultiReader(bytes.NewReader(b), bytes.NewReader(buff), session.conn),
}
// TLS handshake.
if b[0] == 0x16 {
// TODO serve by generic host name maybe better?
tlsConn := tls.Server(mrc, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Host))
// Handshake with the local client
if err = tlsConn.Handshake(); err != nil {
handleError(opt, session, err)
break readLoop // close connection
}
c = tlsConn
goto startOver // hijack and decrypt tls connection
}
// maybe it's the others encrypted connection
in <- inbound.NewHTTPS(session.request, mrc)
} }
// maybe it's a http connection goto startOver
goto readLoop
} }
prepareRequest(c, session.request) prepareRequest(c, session.request)
@ -149,7 +151,7 @@ readLoop:
session.request.RequestURI = "" session.request.RequestURI = ""
if session.request.URL.Host == "" { if session.request.URL.Host == "" {
session.response = session.NewErrorResponse(errors.New("invalid URL")) session.response = session.NewErrorResponse(ErrInvalidURL)
} else { } else {
client = newClientBySourceAndUserAgentIfNil(client, session.request, source, in) client = newClientBySourceAndUserAgentIfNil(client, session.request, source, in)
@ -202,9 +204,7 @@ func writeResponse(session *Session, keepAlive bool) error {
session.response.Header.Set("Keep-Alive", "timeout=25") session.response.Header.Set("Keep-Alive", "timeout=25")
} }
// session.response.Close = !keepAlive // let handler do it return session.writeResponse()
return session.response.Write(session.conn)
} }
func handleApiRequest(session *Session, opt *Option) error { func handleApiRequest(session *Session, opt *Option) error {
@ -224,7 +224,7 @@ func handleApiRequest(session *Session, opt *Option) error {
session.response.Header.Set("Content-Type", "application/x-x509-ca-cert") session.response.Header.Set("Content-Type", "application/x-x509-ca-cert")
session.response.ContentLength = int64(len(b)) session.response.ContentLength = int64(len(b))
return session.response.Write(session.conn) return session.writeResponse()
} }
b := `<!DOCTYPE HTML PUBLIC "-//IETF//DTD HTML 2.0//EN"> b := `<!DOCTYPE HTML PUBLIC "-//IETF//DTD HTML 2.0//EN">
@ -254,7 +254,7 @@ func handleApiRequest(session *Session, opt *Option) error {
session.response.Header.Set("Content-Type", "text/html;charset=utf-8") session.response.Header.Set("Content-Type", "text/html;charset=utf-8")
session.response.ContentLength = int64(len(b)) session.response.ContentLength = int64(len(b))
return session.response.Write(session.conn) return session.writeResponse()
} }
func handleError(opt *Option, session *Session, err error) { func handleError(opt *Option, session *Session, err error) {
@ -292,38 +292,26 @@ func prepareRequest(conn net.Conn, request *http.Request) {
H.RemoveExtraHTTPHostPort(request) H.RemoveExtraHTTPHostPort(request)
} }
func couldBeWithManInTheMiddleAttack(hostname string, opt *Option) bool { func parseSourceAddress(req *http.Request, connSource, source net.Addr) net.Addr {
if opt.CertConfig == nil {
return false
}
if _, port, err := net.SplitHostPort(hostname); err == nil && (port == "443" || port == "8443") {
return true
}
return false
}
func parseSourceAddress(req *http.Request, c net.Conn, source net.Addr) net.Addr {
if source != nil { if source != nil {
return source return source
} }
sourceAddress := req.Header.Get("Origin-Request-Source-Address") sourceAddress := req.Header.Get("Origin-Request-Source-Address")
if sourceAddress == "" { if sourceAddress == "" {
return c.RemoteAddr() return connSource
} }
req.Header.Del("Origin-Request-Source-Address") req.Header.Del("Origin-Request-Source-Address")
host, port, err := net.SplitHostPort(sourceAddress) host, port, err := net.SplitHostPort(sourceAddress)
if err != nil { if err != nil {
return c.RemoteAddr() return connSource
} }
p, err := strconv.ParseUint(port, 10, 16) p, err := strconv.ParseUint(port, 10, 16)
if err != nil { if err != nil {
return c.RemoteAddr() return connSource
} }
if ip := net.ParseIP(host); ip != nil { if ip := net.ParseIP(host); ip != nil {
@ -333,7 +321,7 @@ func parseSourceAddress(req *http.Request, c net.Conn, source net.Addr) net.Addr
} }
} }
return c.RemoteAddr() return connSource
} }
func newClientBySourceAndUserAgentIfNil(cli *http.Client, req *http.Request, source net.Addr, in chan<- C.ConnContext) *http.Client { func newClientBySourceAndUserAgentIfNil(cli *http.Client, req *http.Request, source net.Addr, in chan<- C.ConnContext) *http.Client {

View File

@ -39,6 +39,13 @@ func (s *Session) NewErrorResponse(err error) *http.Response {
return NewErrorResponse(s.request, err) return NewErrorResponse(s.request, err)
} }
func (s *Session) writeResponse() error {
if s.response == nil {
return ErrInvalidResponse
}
return s.response.Write(s.conn)
}
func newSession(conn net.Conn, request *http.Request, response *http.Response) *Session { func newSession(conn net.Conn, request *http.Request, response *http.Response) *Session {
return &Session{ return &Session{
conn: conn, conn: conn,

View File

@ -3,6 +3,7 @@ package mitm
import ( import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -14,6 +15,12 @@ import (
"golang.org/x/text/transform" "golang.org/x/text/transform"
) )
var (
ErrCertUnsupported = errors.New("tls: client cert unsupported")
ErrInvalidResponse = errors.New("invalid response")
ErrInvalidURL = errors.New("invalid URL")
)
type multiReaderConn struct { type multiReaderConn struct {
net.Conn net.Conn
reader io.Reader reader io.Reader