From 2b1e0dbeab6227f67e2ced780f325938202d51c3 Mon Sep 17 00:00:00 2001 From: Keuin Date: Mon, 12 Sep 2022 04:42:22 +0800 Subject: Improve config `allowed_network_types`. - Rename values to "ipv4", "ipv6" and "any". - Validate them when parsing. --- bilibili/netprobe.go | 26 ++++++++++++++++++++++---- main.go | 14 +++++++++++++- 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/bilibili/netprobe.go b/bilibili/netprobe.go index ac11a77..419cc20 100644 --- a/bilibili/netprobe.go +++ b/bilibili/netprobe.go @@ -2,17 +2,35 @@ package bilibili import ( "context" + "fmt" "net" ) type IpNetType string var ( - IPv6Net IpNetType = "tcp6" - IPv4Net IpNetType = "tcp4" - IP64 IpNetType = "tcp" + IPv6Net IpNetType = "ipv6" + IPv4Net IpNetType = "ipv4" + IP64 IpNetType = "any" ) +// GetDialNetString returns the string accepted by net.Dialer::DialContext +func (t IpNetType) GetDialNetString() string { + switch t { + case IPv4Net: + return "tcp4" + case IPv6Net: + return "tcp6" + case IP64: + return "tcp" + } + return "" +} + +func (t IpNetType) String() string { + return fmt.Sprintf("%s(%s)", string(t), t.GetDialNetString()) +} + type netContext = func(context.Context, string, string) (net.Conn, error) type netProbe struct { @@ -36,6 +54,6 @@ func (p *netProbe) NextNetworkType(dialer net.Dialer) (netContext, IpNetType) { network := p.list[p.i] p.i++ return func(ctx context.Context, _, addr string) (net.Conn, error) { - return dialer.DialContext(ctx, string(network), addr) + return dialer.DialContext(ctx, network.GetDialNetString(), addr) }, network } diff --git a/main.go b/main.go index d6a433e..c130301 100644 --- a/main.go +++ b/main.go @@ -9,13 +9,16 @@ import ( "context" "fmt" "github.com/akamensky/argparse" + "github.com/keuin/slbr/bilibili" "github.com/keuin/slbr/common" "github.com/keuin/slbr/logging" "github.com/keuin/slbr/recording" + "github.com/mitchellh/mapstructure" "github.com/spf13/viper" "log" "os" "os/signal" + "reflect" "sync" "syscall" ) @@ -104,7 +107,16 @@ func getTasks() (tasks []recording.TaskConfig) { return } var gc GlobalConfig - err = viper.Unmarshal(&gc) + netType := reflect.TypeOf(bilibili.IP64) + err = viper.Unmarshal(&gc, func(conf *mapstructure.DecoderConfig) { + conf.DecodeHook = func(from reflect.Value, to reflect.Value) (interface{}, error) { + if to.Type() == netType && + bilibili.IpNetType(from.String()).GetDialNetString() == "" { + return nil, fmt.Errorf("invalid IpNetType: %v", from.String()) + } + return from.Interface(), nil + } + }) if err != nil { err = fmt.Errorf("cannot parse config file \"%v\": %w", configFile, err) return -- cgit v1.2.3