summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bilibili/netprobe.go26
-rw-r--r--main.go14
2 files changed, 35 insertions, 5 deletions
diff --git a/bilibili/netprobe.go b/bilibili/netprobe.go
index ac11a77..419cc20 100644
--- a/bilibili/netprobe.go
+++ b/bilibili/netprobe.go
@@ -2,17 +2,35 @@ package bilibili
import (
"context"
+ "fmt"
"net"
)
type IpNetType string
var (
- IPv6Net IpNetType = "tcp6"
- IPv4Net IpNetType = "tcp4"
- IP64 IpNetType = "tcp"
+ IPv6Net IpNetType = "ipv6"
+ IPv4Net IpNetType = "ipv4"
+ IP64 IpNetType = "any"
)
+// GetDialNetString returns the string accepted by net.Dialer::DialContext
+func (t IpNetType) GetDialNetString() string {
+ switch t {
+ case IPv4Net:
+ return "tcp4"
+ case IPv6Net:
+ return "tcp6"
+ case IP64:
+ return "tcp"
+ }
+ return ""
+}
+
+func (t IpNetType) String() string {
+ return fmt.Sprintf("%s(%s)", string(t), t.GetDialNetString())
+}
+
type netContext = func(context.Context, string, string) (net.Conn, error)
type netProbe struct {
@@ -36,6 +54,6 @@ func (p *netProbe) NextNetworkType(dialer net.Dialer) (netContext, IpNetType) {
network := p.list[p.i]
p.i++
return func(ctx context.Context, _, addr string) (net.Conn, error) {
- return dialer.DialContext(ctx, string(network), addr)
+ return dialer.DialContext(ctx, network.GetDialNetString(), addr)
}, network
}
diff --git a/main.go b/main.go
index d6a433e..c130301 100644
--- a/main.go
+++ b/main.go
@@ -9,13 +9,16 @@ import (
"context"
"fmt"
"github.com/akamensky/argparse"
+ "github.com/keuin/slbr/bilibili"
"github.com/keuin/slbr/common"
"github.com/keuin/slbr/logging"
"github.com/keuin/slbr/recording"
+ "github.com/mitchellh/mapstructure"
"github.com/spf13/viper"
"log"
"os"
"os/signal"
+ "reflect"
"sync"
"syscall"
)
@@ -104,7 +107,16 @@ func getTasks() (tasks []recording.TaskConfig) {
return
}
var gc GlobalConfig
- err = viper.Unmarshal(&gc)
+ netType := reflect.TypeOf(bilibili.IP64)
+ err = viper.Unmarshal(&gc, func(conf *mapstructure.DecoderConfig) {
+ conf.DecodeHook = func(from reflect.Value, to reflect.Value) (interface{}, error) {
+ if to.Type() == netType &&
+ bilibili.IpNetType(from.String()).GetDialNetString() == "" {
+ return nil, fmt.Errorf("invalid IpNetType: %v", from.String())
+ }
+ return from.Interface(), nil
+ }
+ })
if err != nil {
err = fmt.Errorf("cannot parse config file \"%v\": %w", configFile, err)
return