diff options
author | Keuin <[email protected]> | 2024-03-09 20:19:35 +0800 |
---|---|---|
committer | Keuin <[email protected]> | 2024-03-09 20:20:20 +0800 |
commit | b933083d20b3db4a3d6a8134efe312eb6ff3d8e2 (patch) | |
tree | 2a86151c77b1eed1596b5f50ce824659e3c1fc22 /observatory/server.go |
Diffstat (limited to 'observatory/server.go')
-rw-r--r-- | observatory/server.go | 261 |
1 files changed, 261 insertions, 0 deletions
diff --git a/observatory/server.go b/observatory/server.go new file mode 100644 index 0000000..ea338f0 --- /dev/null +++ b/observatory/server.go @@ -0,0 +1,261 @@ +package observatory + +import ( + "cmp" + "context" + "encoding/json" + "fmt" + "github.com/go-logr/logr" + "github.com/gorilla/websocket" + "github.com/hit-mc/observatory/config" + "github.com/hit-mc/observatory/protocol" + "io" + "net" + "net/http" + "slices" + "sync" + "time" +) + +func NewCollector( + listen string, + handshakeTimeout time.Duration, + logger logr.Logger, + token string, + targets []config.Target, +) *Collector { + targets2 := make([]protocol.Target, len(targets)) + for i := range targets { + targets2[i] = protocol.Target{ + Host: targets[i].Host, + Port: uint16(targets[i].Port), + } + } + return &Collector{ + listen: listen, + handshakeTimeout: handshakeTimeout, + logger: logger, + stats: newServerStatus(), + token: token, + targets: targets2, + } +} + +type Collector struct { + listen string + handshakeTimeout time.Duration + logger logr.Logger + stats serverStatus + token string + targets []protocol.Target +} + +func (c *Collector) Run(ctx context.Context) error { + ssc, err := json.Marshal(protocol.ServerPushInfo{ + Version: protocol.CurrentServerPushInfoVersion, + Targets: c.targets, + }) + if err != nil { + return fmt.Errorf("error marshalling ServerPushInfo: %w", err) + } + + up := websocket.Upgrader{ + HandshakeTimeout: c.handshakeTimeout, + } + mux := http.NewServeMux() + mux.HandleFunc("/stats", func(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + w.WriteHeader(405) + return + } + stats := c.stats.GetStats() + data, err := json.Marshal(stats) + if err != nil { + w.WriteHeader(500) + _, _ = io.WriteString(w, fmt.Sprintf("error marshalling JSON: %v", err)) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(data) + }) + mux.HandleFunc("/collect", func(w http.ResponseWriter, r *http.Request) { + conn, err := up.Upgrade(w, r, nil) + if err != nil { + c.logger.Error(err, "error upgrading connection protocol to websocket") + return + } + defer func() { + _ = conn.Close() + }() + token := r.Header.Get(protocol.TokenHeader) + if token != c.token { + c.logger.Error(nil, "reject connection with invalid token", "token", token) + return + } + observerID := r.Header.Get(protocol.ObserverIDHeader) + if observerID == "" { + return + } + c.logger.Info("observer connected successfully", + "endpoint", r.RemoteAddr, "observer_id", observerID) + + const pingInterval = 10 * time.Second + const maxPingWait = 20 * time.Second + conn.SetPongHandler(func(string) error { + _ = conn.SetReadDeadline(time.Now().Add(pingInterval + maxPingWait)) + return nil + }) + + writeMu := &sync.Mutex{} + writeMessage := func(typ int, data []byte) error { + writeMu.Lock() + defer writeMu.Unlock() + return conn.WriteMessage(websocket.BinaryMessage, ssc) + } + + wg := &sync.WaitGroup{} + // send keepalive packets periodically + wg.Add(1) + go func() { + for { + err := writeMessage(websocket.PingMessage, nil) + if err != nil { + c.logger.Error(err, "error pinging client, disconnect", + "observer_id", observerID) + _ = conn.Close() + return + } + sleep(r.Context(), pingInterval) + } + }() + defer wg.Wait() + + // send server-side config + err = writeMessage(websocket.BinaryMessage, ssc) + if err != nil { + c.logger.Error(err, "error writing server-side-config") + return + } + + for { + typ, data, err := conn.ReadMessage() + if err != nil { + c.logger.Error(err, "error reading websocket message") + return + } + if typ != websocket.BinaryMessage { + continue + } + var observation protocol.Observation + err = json.Unmarshal(data, &observation) + if err != nil { + c.logger.Error(err, "error unmarshalling message from client") + return + } + c.stats.Put(&observation, observerID) + } + }) + server := http.Server{ + Addr: c.listen, + Handler: mux, + ReadTimeout: 10 * time.Second, + ReadHeaderTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + IdleTimeout: 300 * time.Second, + BaseContext: func(net.Listener) context.Context { + return ctx + }, + } + go func() { + <-ctx.Done() + _ = server.Close() + }() + err = server.ListenAndServe() + if err != nil { + return fmt.Errorf("error starting HTTP server: %w", err) + } + return nil +} + +func newServerStatus() serverStatus { + return serverStatus{ + status: make(map[protocol.Target][]protocol.SourcedObservation), + } +} + +type serverStatus struct { + mu sync.RWMutex + status map[protocol.Target][]protocol.SourcedObservation +} + +func (s *serverStatus) Put(observation *protocol.Observation, source string) { + entry := protocol.SourcedObservation{ + Observation: *observation, + Source: source, + } + + s.mu.Lock() + defer s.mu.Unlock() + // replace existing entry with the same observer + for i, ob := range s.status[observation.Target] { + if ob.Source == source { + s.status[observation.Target][i] = entry + goto done + } + } + // add new entry if not exist + s.status[observation.Target] = append(s.status[observation.Target], entry) + slices.SortFunc(s.status[observation.Target], func(a, b protocol.SourcedObservation) int { + return cmp.Compare(a.Source, b.Source) + }) +done: + s.purge() +} + +// purge removes state entries lived longer than 180s +// Note: caller should hold write lock +func (s *serverStatus) purge() { + const ttl = 180 * time.Second + t0 := time.Now().Add(-ttl) + for target := range s.status { + ss := s.status[target] + purged := false + for i := range ss { + if time.Time(ss[i].Time).Before(t0) { + ss[i].Time = protocol.TimeStamp{} + purged = true + } + } + if purged { + s.status[target] = shrink(ss) + } + } +} + +func shrink(s []protocol.SourcedObservation) []protocol.SourcedObservation { + i := 0 // slow pointer + j := 0 // fast pointer + for j < len(s) { + if !time.Time(s[j].Time).IsZero() { + i++ + } + j++ + if j < len(s) { + s[i] = s[j] + } + } + if i != j { + return s[:i] + } + return s +} + +func (s *serverStatus) GetStats() protocol.TargetStats { + s.mu.RLock() + defer s.mu.RUnlock() + ret := make(protocol.TargetStats, len(s.status)) + for t, os := range s.status { + ret[t.String()] = os + } + return ret +} |