diff --git a/common/contextutils/afterfunc_compact.go b/common/contextutils/afterfunc_compact.go new file mode 100644 index 00000000..2400ddb3 --- /dev/null +++ b/common/contextutils/afterfunc_compact.go @@ -0,0 +1,31 @@ +package contextutils + +import ( + "context" + "sync" +) + +func afterFunc(ctx context.Context, f func()) (stop func() bool) { + stopc := make(chan struct{}) + once := sync.Once{} // either starts running f or stops f from running + if ctx.Done() != nil { + go func() { + select { + case <-ctx.Done(): + once.Do(func() { + go f() + }) + case <-stopc: + } + }() + } + + return func() bool { + stopped := false + once.Do(func() { + stopped = true + close(stopc) + }) + return stopped + } +} diff --git a/common/contextutils/afterfunc_go120.go b/common/contextutils/afterfunc_go120.go new file mode 100644 index 00000000..6ff22bda --- /dev/null +++ b/common/contextutils/afterfunc_go120.go @@ -0,0 +1,11 @@ +//go:build !go1.21 + +package contextutils + +import ( + "context" +) + +func AfterFunc(ctx context.Context, f func()) (stop func() bool) { + return afterFunc(ctx, f) +} diff --git a/common/contextutils/afterfunc_go121.go b/common/contextutils/afterfunc_go121.go new file mode 100644 index 00000000..b9d4c1fa --- /dev/null +++ b/common/contextutils/afterfunc_go121.go @@ -0,0 +1,9 @@ +//go:build go1.21 + +package contextutils + +import "context" + +func AfterFunc(ctx context.Context, f func()) (stop func() bool) { + return context.AfterFunc(ctx, f) +} diff --git a/common/contextutils/afterfunc_test.go b/common/contextutils/afterfunc_test.go new file mode 100644 index 00000000..05f6a055 --- /dev/null +++ b/common/contextutils/afterfunc_test.go @@ -0,0 +1,100 @@ +package contextutils + +import ( + "context" + "testing" + "time" +) + +const ( + shortDuration = 1 * time.Millisecond // a reasonable duration to block in a test + veryLongDuration = 1000 * time.Hour // an arbitrary upper bound on the test's running time +) + +func TestAfterFuncCalledAfterCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + donec := make(chan struct{}) + stop := afterFunc(ctx, func() { + close(donec) + }) + select { + case <-donec: + t.Fatalf("AfterFunc called before context is done") + case <-time.After(shortDuration): + } + cancel() + select { + case <-donec: + case <-time.After(veryLongDuration): + t.Fatalf("AfterFunc not called after context is canceled") + } + if stop() { + t.Fatalf("stop() = true, want false") + } +} + +func TestAfterFuncCalledAfterTimeout(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), shortDuration) + defer cancel() + donec := make(chan struct{}) + afterFunc(ctx, func() { + close(donec) + }) + select { + case <-donec: + case <-time.After(veryLongDuration): + t.Fatalf("AfterFunc not called after context is canceled") + } +} + +func TestAfterFuncCalledImmediately(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + donec := make(chan struct{}) + afterFunc(ctx, func() { + close(donec) + }) + select { + case <-donec: + case <-time.After(veryLongDuration): + t.Fatalf("AfterFunc not called for already-canceled context") + } +} + +func TestAfterFuncNotCalledAfterStop(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + donec := make(chan struct{}) + stop := afterFunc(ctx, func() { + close(donec) + }) + if !stop() { + t.Fatalf("stop() = false, want true") + } + cancel() + select { + case <-donec: + t.Fatalf("AfterFunc called for already-canceled context") + case <-time.After(shortDuration): + } + if stop() { + t.Fatalf("stop() = true, want false") + } +} + +// This test verifies that canceling a context does not block waiting for AfterFuncs to finish. +func TestAfterFuncCalledAsynchronously(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + donec := make(chan struct{}) + stop := afterFunc(ctx, func() { + // The channel send blocks until donec is read from. + donec <- struct{}{} + }) + defer stop() + cancel() + // After cancel returns, read from donec and unblock the AfterFunc. + select { + case <-donec: + case <-time.After(veryLongDuration): + t.Fatalf("AfterFunc not called after context is canceled") + } +} diff --git a/common/net/context.go b/common/net/context.go index 917028d1..ef0e9faf 100644 --- a/common/net/context.go +++ b/common/net/context.go @@ -3,29 +3,26 @@ package net import ( "context" "net" + + "github.com/metacubex/mihomo/common/contextutils" ) // SetupContextForConn is a helper function that starts connection I/O interrupter goroutine. func SetupContextForConn(ctx context.Context, conn net.Conn) (done func(*error)) { - var ( - quit = make(chan struct{}) - interrupt = make(chan error, 1) - ) - go func() { - select { - case <-quit: - interrupt <- nil - case <-ctx.Done(): - // Close the connection, discarding the error - _ = conn.Close() - interrupt <- ctx.Err() - } - }() + stopc := make(chan struct{}) + stop := contextutils.AfterFunc(ctx, func() { + // Close the connection, discarding the error + _ = conn.Close() + close(stopc) + }) return func(inputErr *error) { - close(quit) - if ctxErr := <-interrupt; ctxErr != nil && inputErr != nil { - // Return context error to user. - inputErr = &ctxErr + if !stop() { + // The AfterFunc was started, wait for it to complete. + <-stopc + if ctxErr := ctx.Err(); ctxErr != nil && inputErr != nil { + // Return context error to user. + inputErr = &ctxErr + } } } }