summaryrefslogtreecommitdiff
path: root/protocol
diff options
context:
space:
mode:
authorKeuin <[email protected]>2024-03-08 13:39:53 +0800
committerKeuin <[email protected]>2024-03-08 13:42:11 +0800
commit6ee1bbbd1c491f1a6972fd62cf8ab652d4e8a942 (patch)
tree14aec57c64271db5bba33b8bc85f399e499edda8 /protocol
first open-source versionHEADmaster
Diffstat (limited to 'protocol')
-rw-r--r--protocol/client_mode.go103
-rw-r--r--protocol/client_protocol.go186
-rw-r--r--protocol/client_recv.go66
3 files changed, 355 insertions, 0 deletions
diff --git a/protocol/client_mode.go b/protocol/client_mode.go
new file mode 100644
index 0000000..b9ca057
--- /dev/null
+++ b/protocol/client_mode.go
@@ -0,0 +1,103 @@
+package protocol
+
+import (
+ "bufio"
+ "errors"
+ "fmt"
+ "io"
+)
+
+type ModeType uint16
+
+const (
+ ModePublish ModeType = 1
+ ModeSubscribe ModeType = 2
+)
+
+type ClientMode interface {
+ Type() ModeType
+ // Consume a message from remote peer. This method should panic if not supported.
+ Consume(reader io.Reader, length int64)
+ sendSelectionRequest(c *Client) error
+}
+
+func Publish(topicID string) ClientMode {
+ return modePublish{
+ topicID: topicID,
+ }
+}
+
+type modePublish struct {
+ topicID string
+}
+
+func (m modePublish) Consume(io.Reader, int64) {
+ // it's the user's responsibility to filter out unsuitable messages
+ panic(errors.New("remote peer should not send data message to publisher"))
+}
+
+func (m modePublish) Type() ModeType {
+ return ModePublish
+}
+
+func (m modePublish) sendSelectionRequest(c *Client) error {
+ if m.topicID == "" {
+ panic(fmt.Errorf("empty subscription topic ID"))
+ }
+ return c.writeFlush("PUB" + m.topicID + "\x00")
+}
+
+func Subscribe(topicIDPattern string, messageConsumer func(reader io.Reader, length int64)) ClientMode {
+ return modeSubscribe{
+ topicIDPattern: topicIDPattern,
+ messageConsumer: messageConsumer,
+ }
+}
+
+type modeSubscribe struct {
+ topicIDPattern string
+ messageConsumer func(reader io.Reader, length int64)
+}
+
+func (m modeSubscribe) Consume(reader io.Reader, length int64) {
+ m.messageConsumer(reader, length)
+}
+
+func (m modeSubscribe) Type() ModeType {
+ return ModeSubscribe
+}
+
+func (m modeSubscribe) sendSelectionRequest(c *Client) error {
+ return c.writeFlush("SUB\x00\x00\x00\x00" + m.topicIDPattern + "\x00")
+}
+
+func readModeSelectionResponse(r *bufio.Reader) error {
+ msg, err := r.ReadString('\x00')
+ if err != nil {
+ return err
+ }
+ msg = msg[:len(msg)-1]
+ if msg == "OK" {
+ return nil
+ }
+ // failed or other protocol errors
+ if msg == "FAILED" {
+ msg2, err := r.ReadString('\x00')
+ if err != nil {
+ return err
+ }
+ msg += ": " + msg2[:len(msg2)-1]
+ }
+ return protocolError{
+ phase: PhaseModeSelection,
+ message: msg,
+ }
+}
+
+func (c *Client) SelectMode(mode ClientMode) error {
+ err := mode.sendSelectionRequest(c)
+ if err != nil {
+ return err
+ }
+ return readModeSelectionResponse(c.rx)
+}
diff --git a/protocol/client_protocol.go b/protocol/client_protocol.go
new file mode 100644
index 0000000..3128527
--- /dev/null
+++ b/protocol/client_protocol.go
@@ -0,0 +1,186 @@
+package protocol
+
+import (
+ "bufio"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "reflect"
+)
+
+type Command string
+
+const (
+ CmdMsg Command = "MSG"
+ CmdNop Command = "NOP"
+ CmdBye Command = "BYE"
+ CmdNil Command = "NIL"
+)
+
+type ConnectionPhase uint16
+
+func (c ConnectionPhase) String() string {
+ s, ok := phaseStrings[c]
+ if ok {
+ return s
+ }
+ return fmt.Sprintf("<Unknown ConnectionPhase %v>", int(c))
+}
+
+const (
+ PhaseHandshake ConnectionPhase = 1
+ PhaseModeSelection ConnectionPhase = 2
+)
+
+var phaseStrings = map[ConnectionPhase]string{
+ PhaseHandshake: "Handshake",
+ PhaseModeSelection: "ModeSelection",
+}
+
+type protocolError struct {
+ phase ConnectionPhase
+ message string
+}
+
+func (p protocolError) Error() string {
+ return fmt.Sprintf("protocol error, phase=%v, server message=%v", p.phase, p.message)
+}
+
+func NewClient(tx *bufio.Writer, rx *bufio.Reader) Client {
+ return Client{
+ tx: tx,
+ rx: rx,
+ }
+}
+
+type Client struct {
+ tx *bufio.Writer
+ rx *bufio.Reader
+}
+
+func (c *Client) read(v interface{}) error {
+ return binary.Read(c.rx, binary.BigEndian, v)
+}
+
+// write given data in write format.
+// Note: this method DOES NOT flush the write buffer!
+func (c *Client) write(v interface{}) error {
+ if v == nil {
+ panic(errors.New("writing nil"))
+ }
+ vv := reflect.ValueOf(v)
+ if vv.Kind() == reflect.String {
+ _, err := c.tx.Write([]byte(vv.String()))
+ return err
+ }
+ if vv.Kind() >= reflect.Int && vv.Kind() <= reflect.Uint64 {
+ return binary.Write(c.tx, binary.BigEndian, v)
+ }
+ panic(fmt.Errorf("unsupported type to write: %v, kind: %v", vv.Type(), vv.Kind()))
+}
+
+func (c *Client) writeFlush(v interface{}) error {
+ err := c.write(v)
+ if err != nil {
+ return err
+ }
+ return c.flush()
+}
+
+func (c *Client) flush() error {
+ return c.tx.Flush()
+}
+
+const handshakeSequence = "PSMB"
+
+func (c *Client) Handshake() error {
+ err := c.write(handshakeSequence)
+ if err != nil {
+ return err
+ }
+ const (
+ protocolVersion uint32 = 2
+ protocolOptions uint32 = 0
+ )
+ err = c.write(protocolVersion)
+ if err != nil {
+ return err
+ }
+ err = c.write(protocolOptions)
+ if err != nil {
+ return err
+ }
+ err = c.flush()
+ if err != nil {
+ return err
+ }
+ msg, err := c.rx.ReadString('\x00')
+ if err != nil {
+ return err
+ }
+ msg = msg[:len(msg)-1]
+ if msg != "OK" {
+ return protocolError{
+ phase: PhaseHandshake,
+ message: msg,
+ }
+ }
+ var serverOptions uint32
+ err = c.read(&serverOptions)
+ if err != nil {
+ return err
+ }
+ if serverOptions != 0 {
+ return fmt.Errorf("invalid server options: %v", serverOptions)
+ }
+ return nil
+}
+
+func (c *Client) Publish(msg io.Reader, n int64) error {
+ err := c.write(CmdMsg)
+ if err != nil {
+ return err
+ }
+ err = c.write(uint64(n))
+ if err != nil {
+ return err
+ }
+ _, err = io.CopyN(c.tx, msg, n)
+ if err != nil {
+ return err
+ }
+ err = c.flush()
+ if err != nil {
+ return err
+ }
+ return err
+}
+
+func (c *Client) PublishBytes(msg []byte) error {
+ err := c.write(CmdMsg)
+ if err != nil {
+ return err
+ }
+ err = c.write(uint64(len(msg)))
+ if err != nil {
+ return err
+ }
+ err = c.write(msg)
+ if err != nil {
+ return err
+ }
+ return c.flush()
+}
+
+func (c *Client) Nop() error {
+ return c.writeFlush(CmdNop)
+}
+
+func (c *Client) Bye() error {
+ return c.writeFlush(CmdBye)
+}
+
+func (c *Client) Nil() error {
+ return c.writeFlush(CmdNil)
+}
diff --git a/protocol/client_recv.go b/protocol/client_recv.go
new file mode 100644
index 0000000..eb406c2
--- /dev/null
+++ b/protocol/client_recv.go
@@ -0,0 +1,66 @@
+package protocol
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+)
+
+func (c *Client) Receive() (Message, error) {
+ var cmd [3]byte
+ _, err := io.ReadFull(c.rx, cmd[:])
+ if err != nil {
+ return nil, err
+ }
+ typ := Command(cmd[:])
+ switch typ {
+ case CmdMsg:
+ var length uint64
+ err = c.read(&length)
+ if err != nil {
+ return nil, err
+ }
+ data := make([]byte, length)
+ _, err = io.ReadFull(c.rx, data)
+ if err != nil {
+ return nil, err
+ }
+ return commandMsg{
+ data: data,
+ }, nil
+ case CmdNop:
+ fallthrough
+ case CmdBye:
+ fallthrough
+ case CmdNil:
+ return trivialCommand(typ), nil
+ default:
+ return nil, fmt.Errorf("unknown command from server: %v", string(cmd[:]))
+ }
+}
+
+type Message interface {
+ Command() Command
+ // Consume calls messageConsumer is this message contains a data message to the upper level.
+ Consume(messageConsumer func(reader io.Reader, length int64))
+}
+
+type trivialCommand Command
+
+func (t trivialCommand) Consume(func(reader io.Reader, length int64)) {}
+
+func (t trivialCommand) Command() Command {
+ return Command(t)
+}
+
+type commandMsg struct {
+ data []byte
+}
+
+func (c commandMsg) Consume(messageConsumer func(reader io.Reader, length int64)) {
+ messageConsumer(bytes.NewReader(c.data), int64(len(c.data)))
+}
+
+func (c commandMsg) Command() Command {
+ return CmdMsg
+}