From 493ad1a723f9ade3bd049b156f0dc4d194f8fd3e Mon Sep 17 00:00:00 2001 From: Keuin Date: Thu, 8 Sep 2022 01:14:20 +0800 Subject: Completely fix timing of goroutines. Implement graceful shutdown correctly. --- bilibili/streaming.go | 6 +----- common/retry.go | 17 ++++++++++++++--- main.go | 38 +++++++++++++------------------------- recording/runner.go | 22 +++++++++++++++++----- 4 files changed, 45 insertions(+), 38 deletions(-) diff --git a/bilibili/streaming.go b/bilibili/streaming.go index 5407043..f6ffac9 100644 --- a/bilibili/streaming.go +++ b/bilibili/streaming.go @@ -62,11 +62,7 @@ func (b Bilibili) CopyLiveStream( // blocking copy n, err := common.Copy(ctx, out, resp.Body) - if errors.Is(err, context.Canceled) { - // cancelled by context - // this error is useless - err = nil - } else { + if !errors.Is(err, context.Canceled) { // real error happens b.error.Printf("Stream copying was interrupted unexpectedly: %v", err) } diff --git a/common/retry.go b/common/retry.go index 6b97ff3..435d407 100644 --- a/common/retry.go +++ b/common/retry.go @@ -1,6 +1,7 @@ package common import ( + "context" "log" "time" ) @@ -10,6 +11,7 @@ import ( // the last error will be returned. // If logger is not nil, retry information will be printed to it. func AutoRetry[T any]( + ctx context.Context, supplier func() (T, error), maxRetryTimes int, retryInterval time.Duration, @@ -22,8 +24,16 @@ func AutoRetry[T any]( logger.Printf("Try %v/%v (sleep %vs): %v\n", i, maxRetryTimes, retryInterval, err) } - time.Sleep(retryInterval) - continue + timer := time.NewTimer(retryInterval) + select { + case <-timer.C: + // time to have the next try + continue + case <-ctx.Done(): + // context is cancelled + var zero T + return zero, ctx.Err() + } } // success return ret, nil @@ -31,5 +41,6 @@ func AutoRetry[T any]( if logger != nil { logger.Printf("Max retry times reached, but it still fails. Last error: %v", err) } - return *new(T), err + var zero T + return zero, err } diff --git a/main.go b/main.go index 60df4d6..dc50251 100644 --- a/main.go +++ b/main.go @@ -120,37 +120,25 @@ func main() { logger := log.Default() logger.Printf("Starting tasks...") - chResult := make(chan recording.TaskResult) wg := sync.WaitGroup{} + defer func() { + wg.Wait() + logger.Println("Stopping YABR...") + }() ctx, cancelTasks := context.WithCancel(context.Background()) for _, task := range tasks { wg.Add(1) - go recording.RunTask( - ctx, - &wg, - &task, - chResult, - ) + go recording.RunTask(ctx, &wg, &task) } + // listen Ctrl-C chSigInt := make(chan os.Signal) signal.Notify(chSigInt, os.Interrupt) -loop: - for { - select { - case <-chSigInt: - logger.Println("YABR is stopped.") - cancelTasks() - break loop - case result := <-chResult: - err := result.Error - if err != nil { - logger.Printf("A task stopped with an error (room %v): %v\n", - result.Task.RoomId, result.Error) - } else { - logger.Printf("Task stopped (room %v): %v\n", - result.Task.RoomId, result.Task.String()) - } - } - } + go func() { + <-chSigInt + cancelTasks() + }() + + // block main goroutine on task goroutines + wg.Wait() } diff --git a/recording/runner.go b/recording/runner.go index bcea9a7..7b82a5c 100644 --- a/recording/runner.go +++ b/recording/runner.go @@ -27,12 +27,14 @@ type TaskResult struct { // RunTask start a monitor&download task and // put its execution result into a channel. -func RunTask(ctx context.Context, wg *sync.WaitGroup, task *TaskConfig, chTaskResult chan<- TaskResult) { +func RunTask(ctx context.Context, wg *sync.WaitGroup, task *TaskConfig) { defer wg.Done() err := doTask(ctx, task) - chTaskResult <- TaskResult{ - Task: task, - Error: err, + logger := log.Default() + if err != nil && !errors.Is(err, context.Canceled) { + logger.Printf("A task stopped with an error (room %v): %v\n", task.RoomId, err) + } else { + logger.Printf("Task stopped (room %v): %v\n", task.RoomId, task.String()) } } @@ -108,6 +110,7 @@ func record( logger.Printf("INFO: Getting room profile...\n") profile, err := common.AutoRetry( + ctx, func() (bilibili.RoomProfileResponse, error) { return bi.GetRoomProfile(task.RoomId) }, @@ -115,6 +118,10 @@ func record( time.Duration(task.Transport.RetryIntervalSeconds)*time.Second, logger, ) + if errors.Is(err, context.Canceled) { + cancelled = true + return + } if err != nil { // still error, abort logger.Printf("ERROR: Cannot get room information: %v. Stopping current task.\n", err) @@ -123,6 +130,7 @@ func record( } urlInfo, err := common.AutoRetry( + ctx, func() (bilibili.RoomUrlInfoResponse, error) { return bi.GetStreamingInfo(task.RoomId) }, @@ -130,6 +138,10 @@ func record( time.Duration(task.Transport.RetryIntervalSeconds)*time.Second, logger, ) + if errors.Is(err, context.Canceled) { + cancelled = true + return + } if err != nil { logger.Printf("ERROR: Cannot get streaming info: %v", err) cancelled = true @@ -162,7 +174,7 @@ func record( logger.Printf("Recording live stream to file \"%v\"...", filePath) err = bi.CopyLiveStream(ctx, task.RoomId, streamSource, file) - cancelled = err == nil + cancelled = err == nil || errors.Is(err, context.Canceled) if !cancelled { // real error happens logger.Printf("Error when copying live stream: %v\n", err) -- cgit v1.2.3