From 12c856a10cb9c59ae97504ce0fcd9fdb044bdd14 Mon Sep 17 00:00:00 2001 From: Keuin Date: Thu, 15 Sep 2022 01:36:12 +0800 Subject: Use io.CopyN to utilize zero copy technique. --- bilibili/streaming.go | 29 ++++++++++++++++++++++------- main.go | 12 +++++++++--- recording/config.go | 2 +- recording/runner.go | 10 +--------- 4 files changed, 33 insertions(+), 20 deletions(-) diff --git a/bilibili/streaming.go b/bilibili/streaming.go index 0bda347..21bca01 100644 --- a/bilibili/streaming.go +++ b/bilibili/streaming.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "github.com/keuin/slbr/common" + "io" "net/http" "os" "strings" @@ -16,8 +17,7 @@ func (b Bilibili) CopyLiveStream( roomId common.RoomId, stream StreamingUrlInfo, out *os.File, - buffer []byte, - readChunkSize int, + bufSize int64, ) (err error) { url := stream.URL if !strings.HasPrefix(url, "https://") && @@ -53,15 +53,30 @@ func (b Bilibili) CopyLiveStream( defer func() { _ = resp.Body.Close() }() b.logger.Info("Copying live stream...") - // blocking copy - n, err := common.CopyToFileWithBuffer(ctx, out, resp.Body, buffer, false, uint(len(buffer)/readChunkSize)) - if err != nil && !errors.Is(err, context.Canceled) { - b.logger.Error("Stream copying was interrupted unexpectedly: %v", err) + var n int64 + + // blocking copy +copyLoop: + for err == nil { + select { + case <-ctx.Done(): + // cancelled + err = ctx.Err() + break copyLoop + default: + var sz int64 + sz, err = io.CopyN(out, resp.Body, bufSize) + n += sz + } } - if err == nil { + if errors.Is(err, context.Canceled) { + b.logger.Info("Stop copying...") + } else if errors.Is(err, io.EOF) { b.logger.Info("The live is ended. (room %v)", roomId) + } else { + b.logger.Error("Stream copying was interrupted unexpectedly: %v", err) } b.logger.Info("Total downloaded: %v", common.PrettyBytes(uint64(n))) diff --git a/main.go b/main.go index c130301..84f8a5c 100644 --- a/main.go +++ b/main.go @@ -23,6 +23,8 @@ import ( "syscall" ) +const defaultDiskBufSize = uint64(1024 * 1024) // 1MiB + var globalConfig *GlobalConfig func getTasks() (tasks []recording.TaskConfig) { @@ -67,7 +69,8 @@ func getTasks() (tasks []recording.TaskConfig) { &argparse.Options{ Required: false, Help: "Specify disk write buffer size (bytes). " + - "The real minimum buffer size is determined by OS", + "The real minimum buffer size is determined by OS. " + + "Setting this to a large value may make stopping take a long time", Default: 4194304, }, ) @@ -129,13 +132,16 @@ func getTasks() (tasks []recording.TaskConfig) { taskCount := len(*rooms) tasks = make([]recording.TaskConfig, taskCount) saveTo := common.Zeroable[string](*saveToPtr).OrElse(".") - diskBufSize := *diskBufSizePtr + diskBufSize := uint64(*diskBufSizePtr) + if *diskBufSizePtr <= 0 { + diskBufSize = defaultDiskBufSize + } for i := 0; i < taskCount; i++ { tasks[i] = recording.TaskConfig{ RoomId: common.RoomId((*rooms)[i]), Transport: recording.DefaultTransportConfig(), Download: recording.DownloadConfig{ - DiskWriteBufferBytes: diskBufSize, + DiskWriteBufferBytes: int64(diskBufSize), SaveDirectory: saveTo, }, } diff --git a/recording/config.go b/recording/config.go index ff3ae2a..7dbf189 100644 --- a/recording/config.go +++ b/recording/config.go @@ -22,7 +22,7 @@ type TransportConfig struct { type DownloadConfig struct { SaveDirectory string `mapstructure:"save_directory"` - DiskWriteBufferBytes int `mapstructure:"disk_write_buffer_bytes"` + DiskWriteBufferBytes int64 `mapstructure:"disk_write_buffer_bytes"` UseSpecialExtNameBeforeFinishing bool `mapstructure:"use_special_ext_name_when_downloading"` } diff --git a/recording/runner.go b/recording/runner.go index 7fdffac..47c6764 100644 --- a/recording/runner.go +++ b/recording/runner.go @@ -26,7 +26,6 @@ type TaskResult struct { Error error } -const kReadChunkSize = 1024 * 1024 const kSpecialExtName = "partial" var errLiveEnded = NewRecoverableTaskError("live is ended", nil) @@ -310,16 +309,9 @@ func record( defer func() { _ = file.Close() }() writeBufferSize := task.Download.DiskWriteBufferBytes - if writeBufferSize < kReadChunkSize { - writeBufferSize = kReadChunkSize - } - if mod := writeBufferSize % kReadChunkSize; mod != 0 { - writeBufferSize += kReadChunkSize - mod - } - writeBuffer := make([]byte, writeBufferSize) logger.Info("Write buffer size: %v byte", writeBufferSize) logger.Info("Recording live stream to file \"%v\"...", filePath) - err = bi.CopyLiveStream(ctx, task.RoomId, streamSource, file, writeBuffer, kReadChunkSize) + err = bi.CopyLiveStream(ctx, task.RoomId, streamSource, file, writeBufferSize) if errors.Is(err, context.Canceled) || err == nil { return err } -- cgit v1.2.3