diff options
author | Keuin <[email protected]> | 2024-03-08 13:39:53 +0800 |
---|---|---|
committer | Keuin <[email protected]> | 2024-03-08 13:42:11 +0800 |
commit | 6ee1bbbd1c491f1a6972fd62cf8ab652d4e8a942 (patch) | |
tree | 14aec57c64271db5bba33b8bc85f399e499edda8 /protocol |
Diffstat (limited to 'protocol')
-rw-r--r-- | protocol/client_mode.go | 103 | ||||
-rw-r--r-- | protocol/client_protocol.go | 186 | ||||
-rw-r--r-- | protocol/client_recv.go | 66 |
3 files changed, 355 insertions, 0 deletions
diff --git a/protocol/client_mode.go b/protocol/client_mode.go new file mode 100644 index 0000000..b9ca057 --- /dev/null +++ b/protocol/client_mode.go @@ -0,0 +1,103 @@ +package protocol + +import ( + "bufio" + "errors" + "fmt" + "io" +) + +type ModeType uint16 + +const ( + ModePublish ModeType = 1 + ModeSubscribe ModeType = 2 +) + +type ClientMode interface { + Type() ModeType + // Consume a message from remote peer. This method should panic if not supported. + Consume(reader io.Reader, length int64) + sendSelectionRequest(c *Client) error +} + +func Publish(topicID string) ClientMode { + return modePublish{ + topicID: topicID, + } +} + +type modePublish struct { + topicID string +} + +func (m modePublish) Consume(io.Reader, int64) { + // it's the user's responsibility to filter out unsuitable messages + panic(errors.New("remote peer should not send data message to publisher")) +} + +func (m modePublish) Type() ModeType { + return ModePublish +} + +func (m modePublish) sendSelectionRequest(c *Client) error { + if m.topicID == "" { + panic(fmt.Errorf("empty subscription topic ID")) + } + return c.writeFlush("PUB" + m.topicID + "\x00") +} + +func Subscribe(topicIDPattern string, messageConsumer func(reader io.Reader, length int64)) ClientMode { + return modeSubscribe{ + topicIDPattern: topicIDPattern, + messageConsumer: messageConsumer, + } +} + +type modeSubscribe struct { + topicIDPattern string + messageConsumer func(reader io.Reader, length int64) +} + +func (m modeSubscribe) Consume(reader io.Reader, length int64) { + m.messageConsumer(reader, length) +} + +func (m modeSubscribe) Type() ModeType { + return ModeSubscribe +} + +func (m modeSubscribe) sendSelectionRequest(c *Client) error { + return c.writeFlush("SUB\x00\x00\x00\x00" + m.topicIDPattern + "\x00") +} + +func readModeSelectionResponse(r *bufio.Reader) error { + msg, err := r.ReadString('\x00') + if err != nil { + return err + } + msg = msg[:len(msg)-1] + if msg == "OK" { + return nil + } + // failed or other protocol errors + if msg == "FAILED" { + msg2, err := r.ReadString('\x00') + if err != nil { + return err + } + msg += ": " + msg2[:len(msg2)-1] + } + return protocolError{ + phase: PhaseModeSelection, + message: msg, + } +} + +func (c *Client) SelectMode(mode ClientMode) error { + err := mode.sendSelectionRequest(c) + if err != nil { + return err + } + return readModeSelectionResponse(c.rx) +} diff --git a/protocol/client_protocol.go b/protocol/client_protocol.go new file mode 100644 index 0000000..3128527 --- /dev/null +++ b/protocol/client_protocol.go @@ -0,0 +1,186 @@ +package protocol + +import ( + "bufio" + "encoding/binary" + "errors" + "fmt" + "io" + "reflect" +) + +type Command string + +const ( + CmdMsg Command = "MSG" + CmdNop Command = "NOP" + CmdBye Command = "BYE" + CmdNil Command = "NIL" +) + +type ConnectionPhase uint16 + +func (c ConnectionPhase) String() string { + s, ok := phaseStrings[c] + if ok { + return s + } + return fmt.Sprintf("<Unknown ConnectionPhase %v>", int(c)) +} + +const ( + PhaseHandshake ConnectionPhase = 1 + PhaseModeSelection ConnectionPhase = 2 +) + +var phaseStrings = map[ConnectionPhase]string{ + PhaseHandshake: "Handshake", + PhaseModeSelection: "ModeSelection", +} + +type protocolError struct { + phase ConnectionPhase + message string +} + +func (p protocolError) Error() string { + return fmt.Sprintf("protocol error, phase=%v, server message=%v", p.phase, p.message) +} + +func NewClient(tx *bufio.Writer, rx *bufio.Reader) Client { + return Client{ + tx: tx, + rx: rx, + } +} + +type Client struct { + tx *bufio.Writer + rx *bufio.Reader +} + +func (c *Client) read(v interface{}) error { + return binary.Read(c.rx, binary.BigEndian, v) +} + +// write given data in write format. +// Note: this method DOES NOT flush the write buffer! +func (c *Client) write(v interface{}) error { + if v == nil { + panic(errors.New("writing nil")) + } + vv := reflect.ValueOf(v) + if vv.Kind() == reflect.String { + _, err := c.tx.Write([]byte(vv.String())) + return err + } + if vv.Kind() >= reflect.Int && vv.Kind() <= reflect.Uint64 { + return binary.Write(c.tx, binary.BigEndian, v) + } + panic(fmt.Errorf("unsupported type to write: %v, kind: %v", vv.Type(), vv.Kind())) +} + +func (c *Client) writeFlush(v interface{}) error { + err := c.write(v) + if err != nil { + return err + } + return c.flush() +} + +func (c *Client) flush() error { + return c.tx.Flush() +} + +const handshakeSequence = "PSMB" + +func (c *Client) Handshake() error { + err := c.write(handshakeSequence) + if err != nil { + return err + } + const ( + protocolVersion uint32 = 2 + protocolOptions uint32 = 0 + ) + err = c.write(protocolVersion) + if err != nil { + return err + } + err = c.write(protocolOptions) + if err != nil { + return err + } + err = c.flush() + if err != nil { + return err + } + msg, err := c.rx.ReadString('\x00') + if err != nil { + return err + } + msg = msg[:len(msg)-1] + if msg != "OK" { + return protocolError{ + phase: PhaseHandshake, + message: msg, + } + } + var serverOptions uint32 + err = c.read(&serverOptions) + if err != nil { + return err + } + if serverOptions != 0 { + return fmt.Errorf("invalid server options: %v", serverOptions) + } + return nil +} + +func (c *Client) Publish(msg io.Reader, n int64) error { + err := c.write(CmdMsg) + if err != nil { + return err + } + err = c.write(uint64(n)) + if err != nil { + return err + } + _, err = io.CopyN(c.tx, msg, n) + if err != nil { + return err + } + err = c.flush() + if err != nil { + return err + } + return err +} + +func (c *Client) PublishBytes(msg []byte) error { + err := c.write(CmdMsg) + if err != nil { + return err + } + err = c.write(uint64(len(msg))) + if err != nil { + return err + } + err = c.write(msg) + if err != nil { + return err + } + return c.flush() +} + +func (c *Client) Nop() error { + return c.writeFlush(CmdNop) +} + +func (c *Client) Bye() error { + return c.writeFlush(CmdBye) +} + +func (c *Client) Nil() error { + return c.writeFlush(CmdNil) +} diff --git a/protocol/client_recv.go b/protocol/client_recv.go new file mode 100644 index 0000000..eb406c2 --- /dev/null +++ b/protocol/client_recv.go @@ -0,0 +1,66 @@ +package protocol + +import ( + "bytes" + "fmt" + "io" +) + +func (c *Client) Receive() (Message, error) { + var cmd [3]byte + _, err := io.ReadFull(c.rx, cmd[:]) + if err != nil { + return nil, err + } + typ := Command(cmd[:]) + switch typ { + case CmdMsg: + var length uint64 + err = c.read(&length) + if err != nil { + return nil, err + } + data := make([]byte, length) + _, err = io.ReadFull(c.rx, data) + if err != nil { + return nil, err + } + return commandMsg{ + data: data, + }, nil + case CmdNop: + fallthrough + case CmdBye: + fallthrough + case CmdNil: + return trivialCommand(typ), nil + default: + return nil, fmt.Errorf("unknown command from server: %v", string(cmd[:])) + } +} + +type Message interface { + Command() Command + // Consume calls messageConsumer is this message contains a data message to the upper level. + Consume(messageConsumer func(reader io.Reader, length int64)) +} + +type trivialCommand Command + +func (t trivialCommand) Consume(func(reader io.Reader, length int64)) {} + +func (t trivialCommand) Command() Command { + return Command(t) +} + +type commandMsg struct { + data []byte +} + +func (c commandMsg) Consume(messageConsumer func(reader io.Reader, length int64)) { + messageConsumer(bytes.NewReader(c.data), int64(len(c.data))) +} + +func (c commandMsg) Command() Command { + return CmdMsg +} |