summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKeuin <[email protected]>2022-09-08 01:14:20 +0800
committerKeuin <[email protected]>2022-09-08 01:14:20 +0800
commit493ad1a723f9ade3bd049b156f0dc4d194f8fd3e (patch)
treef9d76845a7ae5382e58b13936d732de2b3e98598
parent1009e88ff752525966708c56190c2dfa32bc9537 (diff)
Completely fix timing of goroutines. Implement graceful shutdown correctly.
-rw-r--r--bilibili/streaming.go6
-rw-r--r--common/retry.go17
-rw-r--r--main.go38
-rw-r--r--recording/runner.go22
4 files changed, 45 insertions, 38 deletions
diff --git a/bilibili/streaming.go b/bilibili/streaming.go
index 5407043..f6ffac9 100644
--- a/bilibili/streaming.go
+++ b/bilibili/streaming.go
@@ -62,11 +62,7 @@ func (b Bilibili) CopyLiveStream(
// blocking copy
n, err := common.Copy(ctx, out, resp.Body)
- if errors.Is(err, context.Canceled) {
- // cancelled by context
- // this error is useless
- err = nil
- } else {
+ if !errors.Is(err, context.Canceled) {
// real error happens
b.error.Printf("Stream copying was interrupted unexpectedly: %v", err)
}
diff --git a/common/retry.go b/common/retry.go
index 6b97ff3..435d407 100644
--- a/common/retry.go
+++ b/common/retry.go
@@ -1,6 +1,7 @@
package common
import (
+ "context"
"log"
"time"
)
@@ -10,6 +11,7 @@ import (
// the last error will be returned.
// If logger is not nil, retry information will be printed to it.
func AutoRetry[T any](
+ ctx context.Context,
supplier func() (T, error),
maxRetryTimes int,
retryInterval time.Duration,
@@ -22,8 +24,16 @@ func AutoRetry[T any](
logger.Printf("Try %v/%v (sleep %vs): %v\n",
i, maxRetryTimes, retryInterval, err)
}
- time.Sleep(retryInterval)
- continue
+ timer := time.NewTimer(retryInterval)
+ select {
+ case <-timer.C:
+ // time to have the next try
+ continue
+ case <-ctx.Done():
+ // context is cancelled
+ var zero T
+ return zero, ctx.Err()
+ }
}
// success
return ret, nil
@@ -31,5 +41,6 @@ func AutoRetry[T any](
if logger != nil {
logger.Printf("Max retry times reached, but it still fails. Last error: %v", err)
}
- return *new(T), err
+ var zero T
+ return zero, err
}
diff --git a/main.go b/main.go
index 60df4d6..dc50251 100644
--- a/main.go
+++ b/main.go
@@ -120,37 +120,25 @@ func main() {
logger := log.Default()
logger.Printf("Starting tasks...")
- chResult := make(chan recording.TaskResult)
wg := sync.WaitGroup{}
+ defer func() {
+ wg.Wait()
+ logger.Println("Stopping YABR...")
+ }()
ctx, cancelTasks := context.WithCancel(context.Background())
for _, task := range tasks {
wg.Add(1)
- go recording.RunTask(
- ctx,
- &wg,
- &task,
- chResult,
- )
+ go recording.RunTask(ctx, &wg, &task)
}
+ // listen Ctrl-C
chSigInt := make(chan os.Signal)
signal.Notify(chSigInt, os.Interrupt)
-loop:
- for {
- select {
- case <-chSigInt:
- logger.Println("YABR is stopped.")
- cancelTasks()
- break loop
- case result := <-chResult:
- err := result.Error
- if err != nil {
- logger.Printf("A task stopped with an error (room %v): %v\n",
- result.Task.RoomId, result.Error)
- } else {
- logger.Printf("Task stopped (room %v): %v\n",
- result.Task.RoomId, result.Task.String())
- }
- }
- }
+ go func() {
+ <-chSigInt
+ cancelTasks()
+ }()
+
+ // block main goroutine on task goroutines
+ wg.Wait()
}
diff --git a/recording/runner.go b/recording/runner.go
index bcea9a7..7b82a5c 100644
--- a/recording/runner.go
+++ b/recording/runner.go
@@ -27,12 +27,14 @@ type TaskResult struct {
// RunTask start a monitor&download task and
// put its execution result into a channel.
-func RunTask(ctx context.Context, wg *sync.WaitGroup, task *TaskConfig, chTaskResult chan<- TaskResult) {
+func RunTask(ctx context.Context, wg *sync.WaitGroup, task *TaskConfig) {
defer wg.Done()
err := doTask(ctx, task)
- chTaskResult <- TaskResult{
- Task: task,
- Error: err,
+ logger := log.Default()
+ if err != nil && !errors.Is(err, context.Canceled) {
+ logger.Printf("A task stopped with an error (room %v): %v\n", task.RoomId, err)
+ } else {
+ logger.Printf("Task stopped (room %v): %v\n", task.RoomId, task.String())
}
}
@@ -108,6 +110,7 @@ func record(
logger.Printf("INFO: Getting room profile...\n")
profile, err := common.AutoRetry(
+ ctx,
func() (bilibili.RoomProfileResponse, error) {
return bi.GetRoomProfile(task.RoomId)
},
@@ -115,6 +118,10 @@ func record(
time.Duration(task.Transport.RetryIntervalSeconds)*time.Second,
logger,
)
+ if errors.Is(err, context.Canceled) {
+ cancelled = true
+ return
+ }
if err != nil {
// still error, abort
logger.Printf("ERROR: Cannot get room information: %v. Stopping current task.\n", err)
@@ -123,6 +130,7 @@ func record(
}
urlInfo, err := common.AutoRetry(
+ ctx,
func() (bilibili.RoomUrlInfoResponse, error) {
return bi.GetStreamingInfo(task.RoomId)
},
@@ -130,6 +138,10 @@ func record(
time.Duration(task.Transport.RetryIntervalSeconds)*time.Second,
logger,
)
+ if errors.Is(err, context.Canceled) {
+ cancelled = true
+ return
+ }
if err != nil {
logger.Printf("ERROR: Cannot get streaming info: %v", err)
cancelled = true
@@ -162,7 +174,7 @@ func record(
logger.Printf("Recording live stream to file \"%v\"...", filePath)
err = bi.CopyLiveStream(ctx, task.RoomId, streamSource, file)
- cancelled = err == nil
+ cancelled = err == nil || errors.Is(err, context.Canceled)
if !cancelled {
// real error happens
logger.Printf("Error when copying live stream: %v\n", err)