From 29f3ff3ea13bc685403506f9c94a0946932ba836 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?T=C3=B3th=204?= <toth4@sch.bme.hu> Date: Sun, 9 Feb 2025 00:46:46 +0100 Subject: [PATCH] engineio server v1 --- src/engine.go | 454 ------------------------------- src/engineio/client.go | 407 +++++++++++++++++++++++++++ src/engineio/engineio.go | 168 ++++++++++++ src/engineio/server.go | 558 ++++++++++++++++++++++++++++++++++++++ src/go.mod | 7 +- src/go.sum | 11 + src/main.go | 33 ++- src/test/engineio_test.go | 231 ++++++++++++++++ 8 files changed, 1400 insertions(+), 469 deletions(-) delete mode 100644 src/engine.go create mode 100644 src/engineio/client.go create mode 100644 src/engineio/engineio.go create mode 100644 src/engineio/server.go create mode 100644 src/test/engineio_test.go diff --git a/src/engine.go b/src/engine.go deleted file mode 100644 index d1c35f2..0000000 --- a/src/engine.go +++ /dev/null @@ -1,454 +0,0 @@ -package main - -import ( - "bytes" - "context" - "encoding/base64" - "encoding/json" - "errors" - "git.sch.bme.hu/yass/utils" - "github.com/gorilla/websocket" - "log/slog" - "net/http" - "strings" - "sync" - "time" -) - -type IConnection interface { - GetEIO() int // Get the version of the connection - GetSid() string // Get the session ID of the connection - GetMessageChannels() (inbound chan IEngineData, outbound chan IEngineData) // Get the message channel of the connection - Close(ctx context.Context) error // Closes the connection - Done() <-chan struct{} // Returns a channel that closes when the connection is closed -} - -type OpenResponse struct { - Sid string `json:"sid"` - Upgrades []string `json:"upgrades"` - PingInterval int `json:"pingInterval"` - PingTimeout int `json:"pingTimeout"` - MaxPayload int `json:"maxPayload"` -} - -// EngineIOFrame types - -type EngineIOFrameType int - -const ( - OPEN EngineIOFrameType = iota - CLOSE - PING - PONG - MESSAGE - UPGRADE - NOOP -) - -// This could be simplified by having a UTF8 and a Bindary that is the child of it -type IEngineData interface { - GetBindata() ([]byte, error) - GetUTF8() string - IsEmpty() bool -} - -type EngineUTF8Data struct { - Data string -} - -func (e *EngineUTF8Data) GetUTF8() string { - return e.Data -} - -func (e *EngineUTF8Data) GetBindata() ([]byte, error) { - return []byte(e.Data), errors.New("Data is not binary") -} - -func (e *EngineUTF8Data) IsEmpty() bool { - return len(e.Data) == 0 -} - -type EngineBinaryData struct { - EngineUTF8Data - Data []byte -} - -func (e *EngineBinaryData) GetBindata() ([]byte, error) { - return e.Data, nil -} - -func (e *EngineBinaryData) IsEmpty() bool { - return len(e.Data) == 0 -} - -type EngineIOFrame struct { - Type EngineIOFrameType - Data IEngineData -} - -func parsePollingFrame(data []byte) []EngineIOFrame { - frames := bytes.Split(data, []byte{0x1e}) - var result []EngineIOFrame - for _, frame := range frames { - if len(frame) == 1 { - result = append(result, EngineIOFrame{Type: EngineIOFrameType(frame[0] - '0'), Data: &EngineUTF8Data{}}) - continue - } - if frame[1] == 'b' { - data, err := base64.StdEncoding.DecodeString(string(frame[2:])) - if err == nil { - result = append(result, EngineIOFrame{Type: EngineIOFrameType(frame[0] - '0'), Data: &EngineBinaryData{ - Data: data, - EngineUTF8Data: EngineUTF8Data{string(frame[1:])}, - }}) - continue - } - } else { - result = append(result, EngineIOFrame{Type: EngineIOFrameType(frame[0] - '0'), Data: &EngineUTF8Data{ - Data: string(frame[1:]), - }}) - } - } - return result -} - -func parseWebSocketFrame(data []byte) EngineIOFrame { - decoded, err := base64.StdEncoding.DecodeString(string(data[1:])) - if err == nil { - return EngineIOFrame{Type: EngineIOFrameType(data[0] - '0'), Data: &EngineBinaryData{ - Data: decoded, - EngineUTF8Data: EngineUTF8Data{string(data[1:])}, - }} - } - return EngineIOFrame{Type: EngineIOFrameType(data[0] - '0'), Data: &EngineUTF8Data{ - Data: string(data[1:]), - }} -} - -func marshalPollingFrame(frame EngineIOFrame) []byte { - if frame.Data.IsEmpty() { - return []byte{byte(frame.Type + '0')} - } - data, err := frame.Data.GetBindata() - if err == nil { - return append([]byte{byte(frame.Type + '0'), 'b'}, data...) - } - return append([]byte{byte(frame.Type + '0')}, []byte(frame.Data.GetUTF8())...) -} - -type SocketIOMessage struct { - Type int - Namespace string - Payload interface{} - Acknowledgement int -} - -// This is only a client connection for now -// The official documentation gives itself to a timer based solution instead of a global clock -type Connection struct { - //PingAt time.Time // Time at which the last ping was received - //DeadAt time.Time // Time after which the connection is considered dead - Url string // URL of the server - Sid string // Session ID - Inbound chan IEngineData // Messages received - Outbound chan IEngineData // Messages to be sent - PingInterval int // Interval at which to send ping messages milliseconds - PingTimeout int // Timeout for ping messages milliseconds - MaxPayload int // Maximum payload size, may not be used in websockets, idk - Context context.Context // Context of the connection - Cancellable context.CancelFunc // Cancel function of the context - timer time.Timer // Timer for the ping messages - send func(ctx context.Context, frame EngineIOFrame) error - receive func(ctx context.Context) (EngineIOFrame, error) -} - -func (c *Connection) Done() <-chan struct{} { - return c.Context.Done() -} - -func (c *Connection) GetMessageChannels() (inbound chan IEngineData, outbound chan IEngineData) { - return (c.Inbound), (c.Outbound) -} - -func (c *Connection) GetEIO() int { - return 4 -} -func (c *Connection) GetSid() string { - return c.Sid -} - -//func (c *Connection) send(ctx context.Context, frame EngineIOFrame) error { -// return errors.New("Not implemented") -//} -// -//func (c *Connection) receive(ctx context.Context) (EngineIOFrame, error) { -// //return EngineIOFrame{}, errors.New("Not implemented") -// return EngineIOFrame{}, nil -//} - -func (c *Connection) Close(ctx context.Context) error { - c.Cancellable() - return c.send(utils.AddAttr(ctx, slog.Attr(slog.String("SID", c.Sid))), EngineIOFrame{Type: CLOSE, Data: &EngineUTF8Data{}}) -} - -func (c *Connection) receiveLoop() { -out: - for { - select { - case <-c.Context.Done(): - break out - default: - frame, err := c.receive(c.Context) - if err != nil { - utils.WithContext(c.Context).Error("Receive failed: " + err.Error()) - continue - } - switch frame.Type { - case PING: - c.timer.Reset(time.Duration(c.PingTimeout+c.PingInterval) * time.Millisecond) - c.send(c.Context, EngineIOFrame{Type: PONG, Data: &EngineUTF8Data{}}) - case MESSAGE: - c.Inbound <- frame.Data - case NOOP: - case CLOSE: - utils.WithContext(c.Context).Warn("Connection closed by server") - c.Cancellable() - return - default: - } - } - } - close(c.Inbound) -} - -func (c *Connection) sendLoop() { -out: - for { - select { - case <-c.Context.Done(): - break out - case data := <-c.Outbound: - c.send(c.Context, EngineIOFrame{Type: MESSAGE, Data: data}) - } - } - close(c.Outbound) -} - -type PollingConnection struct { - Connection - frameBuffer []EngineIOFrame - sendMutex sync.Mutex -} - -func (p *PollingConnection) sender(ctx context.Context, frame EngineIOFrame) error { - p.sendMutex.Lock() - defer p.sendMutex.Unlock() - client := http.Client{ - Timeout: time.Duration(p.PingTimeout+p.PingInterval) * time.Millisecond, - } - req, err := http.NewRequest("POST", p.Url+"/socket.io/?EIO=4&transport=polling&sid="+p.Sid, bytes.NewReader(marshalPollingFrame(frame))) - if err != nil { - return err - } - req.Header.Set("Content-Type", "text/plain") - _, err = client.Do(req.WithContext(ctx)) - utils.WithContext(ctx).With(slog.Int("TYPE", int(frame.Type))).Info("Sending Frame: " + string(frame.Data.GetUTF8())) - _, err = http.Post(p.Url+"/socket.io/?EIO=4&transport=polling&sid="+p.Sid, "text/plain", bytes.NewReader(marshalPollingFrame(frame))) - return err -} - -func (p *PollingConnection) receiver(ctx context.Context) (EngineIOFrame, error) { - if len(p.frameBuffer) > 0 { - frame := p.frameBuffer[0] - p.frameBuffer = p.frameBuffer[1:] - return frame, nil - } - client := http.Client{ - Timeout: time.Duration(p.PingTimeout+p.PingInterval) * time.Millisecond, - } - req, err := http.NewRequest("GET", p.Url+"/socket.io/?EIO=4&transport=polling&sid="+p.Sid, nil) - if err != nil { - return EngineIOFrame{}, err - } - resp, err := client.Do(req.WithContext(ctx)) - if err != nil { - return EngineIOFrame{}, err - } - if resp.StatusCode != 200 { - utils.WithContext(ctx).Warn("Received status code: " + resp.Status) - } - defer resp.Body.Close() - buf := new(bytes.Buffer) - buf.ReadFrom(resp.Body) - data := buf.Bytes() - //println(string(data)) - frames := parsePollingFrame(data) - for _, frame := range frames { - utils.WithContext(ctx).With(slog.Int("TYPE", int(frame.Type))).Info("Received Frame: " + string(frame.Data.GetUTF8())) - } - p.frameBuffer = frames[1:] - return frames[0], nil -} - -func NewPollingConnection(ctx context.Context, url string) IConnection { - resp, err := http.Get(url + "/socket.io/?EIO=4&transport=polling") - if err != nil { - return nil - } - defer resp.Body.Close() - // Read all the data from the response - buf := new(bytes.Buffer) - buf.ReadFrom(resp.Body) - data := buf.Bytes() - // Parse the response - frames := parsePollingFrame(data) - print(frames) - message := frames[0].Data.GetUTF8() - var openResponse OpenResponse - json.Unmarshal([]byte(message), &openResponse) - - inbound := make(chan IEngineData) - outbound := make(chan IEngineData) - - connection := PollingConnection{Connection{ - Sid: openResponse.Sid, - Inbound: inbound, - Outbound: outbound, - PingInterval: openResponse.PingInterval, - PingTimeout: openResponse.PingTimeout, - MaxPayload: openResponse.MaxPayload, - Url: url, - }, - []EngineIOFrame{}, - sync.Mutex{}, - } - - connection.send = connection.sender - connection.receive = connection.receiver - - utils.WithContext(ctx).Info(`Creating connection with SID: ` + openResponse.Sid) - - ct, canceller := context.WithCancel(utils.AddAttr(ctx, slog.String("SID", openResponse.Sid))) - - timer := time.AfterFunc(time.Duration(connection.PingTimeout+connection.PingInterval)*time.Millisecond, func() { - select { - case <-ct.Done(): - return - default: - utils.WithContext(ct).Error("Connection timed out") - canceller() - } - }) - - connection.Context = ct - connection.Cancellable = canceller - connection.timer = *timer - - go connection.receiveLoop() - - go connection.sendLoop() - - // Set the session ID - return &connection -} - -type WebSocketConnection struct { - Connection - conn *websocket.Conn - mutex sync.Mutex -} - -// Visitor pattern be like -func marshalPollingFunction(frame EngineIOFrame) []byte { - if frame.Data.IsEmpty() { - return []byte{byte(frame.Type + '0')} - } - data, err := frame.Data.GetBindata() - if err == nil { - return append([]byte{byte(frame.Type + '0')}, data...) - } - return append([]byte{byte(frame.Type + '0')}, []byte(frame.Data.GetUTF8())...) -} - -func (w *WebSocketConnection) sender(ctx context.Context, frame EngineIOFrame) error { - w.mutex.Lock() - defer w.mutex.Unlock() - utils.WithContext(ctx).Info("Sending frame: " + frame.Data.GetUTF8()) - return w.conn.WriteMessage(websocket.TextMessage, marshalPollingFrame(frame)) -} - -func (w *WebSocketConnection) receiver(ctx context.Context) (EngineIOFrame, error) { - _, data, err := w.conn.ReadMessage() - if err == nil { - utils.WithContext(ctx).Info("Received frame: " + string(data)) - } - return parseWebSocketFrame(data), err -} - -// This is a background connection -func UpgradeConnection(ctx context.Context, connection *PollingConnection) (IConnection, error) { - connection.Cancellable() - - utils.WithContext(ctx).Info(`Upgrading connection with SID: ` + connection.Sid) - - ct, canceller := context.WithCancel(utils.AddAttr(ctx, slog.String("SID", connection.Sid))) - - timer := time.AfterFunc(time.Duration(connection.PingTimeout+connection.PingInterval)*time.Millisecond, func() { - select { - case <-ct.Done(): - return - default: - utils.WithContext(ct).Error("Connection timed out") - canceller() - } - }) - - url := strings.Replace(connection.Url, "http", "ws", 1) - - conn, piss, err := websocket.DefaultDialer.DialContext(ctx, url+"/socket.io/?EIO=4&transport=websocket&sid="+connection.Sid, nil) - if err != nil { - utils.WithContext(ctx).Error("Failed to upgrade connection: " + err.Error() + " " + piss.Status) - return nil, err - } - - sock := WebSocketConnection{ - Connection: Connection{ - Url: url, - Sid: connection.Sid, - Inbound: make(chan IEngineData), - Outbound: make(chan IEngineData), - PingInterval: connection.PingInterval, - PingTimeout: connection.PingTimeout, - MaxPayload: connection.MaxPayload, - Context: ct, - Cancellable: canceller, - timer: *timer, - }, - conn: conn, - mutex: sync.Mutex{}, - } - - sock.send = sock.sender - sock.receive = sock.receiver - - err = sock.send(ct, EngineIOFrame{Type: PING, Data: &EngineUTF8Data{"probe"}}) - if err != nil { - utils.WithContext(ct).Error("Failed to upgrade connection: " + err.Error()) - return nil, err - } - resp, err := sock.receive(ct) - if err != nil || resp.Data.GetUTF8() != "probe" { - utils.WithContext(ct).Error("Failed to upgrade connection: " + err.Error()) - return nil, err - } - err = sock.send(ct, EngineIOFrame{Type: UPGRADE, Data: &EngineUTF8Data{}}) - if err != nil { - utils.WithContext(ct).Error("Failed to upgrade connection: " + err.Error()) - return nil, err - } - - go sock.receiveLoop() - go sock.sendLoop() - - return &sock, nil -} diff --git a/src/engineio/client.go b/src/engineio/client.go new file mode 100644 index 0000000..1768719 --- /dev/null +++ b/src/engineio/client.go @@ -0,0 +1,407 @@ +package engineio + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "git.sch.bme.hu/yass/utils" + "github.com/gorilla/websocket" + "log/slog" + "net/http" + "strings" + "sync" + "time" +) + +// This is mostly used for testing +// It's not as thread safe as it should be +// but fixing it is not a priority for now + +// This is only a client connection for now +// The official documentation gives itself to a timer based solution instead of a global clock +type clientConnection struct { + url string // URL of the server + sid string // Session ID + inbound chan IEngineData // Messages received + outbound chan IEngineData // Messages to be sent + pingInterval int // Interval at which to send ping messages milliseconds + pingTimeout int // Timeout for ping messages milliseconds + maxPayload int // Maximum payload size, may not be used in websockets, idk + context context.Context /* context of the connection, used for cancellation for both the + user facing channels, and the goroutines handling the send, receive and ping timer loops */ + cancellable context.CancelFunc // Cancel function of the context + timer *time.Timer // Timer for the ping messages + send func(ctx context.Context, frame EngineIOFrame) error // function pointer fuckery because go doesn't have virtual functions + receive func(ctx context.Context) (EngineIOFrame, error) // function pointer fuckery because go doesn't have virtual functions + close func(ctx context.Context) error // function pointer fuckery because go doesn't have virtual functions +} + +func (c *clientConnection) Done() <-chan struct{} { + return c.context.Done() +} + +func (c *clientConnection) Inbound() <-chan IEngineData { + return c.inbound +} + +func (c *clientConnection) Outbound() chan<- IEngineData { + return c.outbound +} + +func (c *clientConnection) GetEIO() int { + return 4 +} +func (c *clientConnection) GetSid() string { + return c.sid +} + +func (c *clientConnection) Send(data IEngineData) error { + select { + case <-c.context.Done(): + return errors.New("clientConnection closed") + default: + c.outbound <- data + return nil + } +} + +// This consumes the message, and cannot be used with the getchannels method concurrently, either use receives, or channels +func (c *clientConnection) Receive() (IEngineData, error) { + select { + case <-c.context.Done(): + return nil, errors.New("clientConnection closed") + case data := <-c.inbound: + return data, nil + } +} + +func (c *clientConnection) Close(ctx context.Context) error { + c.cancellable() + return c.send(utils.AddAttr(ctx, slog.Attr(slog.String("SID", c.sid))), EngineIOFrame{Type: CLOSE, IEngineData: &EngineUTF8Data{}}) +} + +// Receive loop shared by both polling and websocket connections +func (c *clientConnection) receiveLoop() { +out: + for { + select { + case <-c.context.Done(): + break out + default: + frame, err := c.receive(c.context) + if err != nil { + utils.WithContext(c.context).Error("Receive failed: " + err.Error()) + // Backoff time + time.Sleep(1 * time.Second) + continue + } + switch frame.Type { + case PING: + //c.timer.Stop() + if !c.timer.Stop() { + <-c.timer.C + } + c.timer.Reset(time.Duration(c.pingTimeout+c.pingInterval) * time.Millisecond) + c.send(c.context, EngineIOFrame{Type: PONG, IEngineData: &EngineUTF8Data{}}) + case MESSAGE: + c.inbound <- frame.IEngineData + case NOOP: + case CLOSE: + utils.WithContext(c.context).Warn("clientConnection closed by server") + c.cancellable() + return + default: + } + } + } + c.close(c.context) +} + +// Send loop shared by both polling and websocket connections +func (c *clientConnection) sendLoop() { +out: + for { + select { + case <-c.context.Done(): + break out + case data := <-c.outbound: + c.send(c.context, EngineIOFrame{Type: MESSAGE, IEngineData: data}) + } + } + c.close(c.context) +} + +type pollingClientConnection struct { + clientConnection + frameBuffer []EngineIOFrame /* Because of generalization, the read function only reads one frame at a time + But the server can send multiple frames at once, so a buffer is needed */ + sendMutex sync.Mutex // Mutex for sending polling messages because only one POST request can be sent at a time +} + +// The send function for polling clients +func (p *pollingClientConnection) sender(ctx context.Context, frame EngineIOFrame) error { + p.sendMutex.Lock() + defer p.sendMutex.Unlock() + client := http.Client{ + Timeout: time.Duration(p.pingTimeout+p.pingInterval) * time.Millisecond, + } + req, err := http.NewRequest("POST", p.url+"?EIO=4&transport=polling&sid="+p.sid, bytes.NewReader(marshalPollingFrame(frame))) + if err != nil { + return err + } + req.Header.Set("Content-Type", "text/plain") + _, err = client.Do(req.WithContext(ctx)) + utils.WithContext(ctx).With(slog.String("TYPE", frame.Type.String())).Info("Sending Frame: " + string(marshalPollingFrame(frame))) + _, err = http.Post(p.url+"?EIO=4&transport=polling&sid="+p.sid, "text/plain", bytes.NewReader(marshalPollingFrame(frame))) + return err +} + +// The receive function for polling clients +func (p *pollingClientConnection) receiver(ctx context.Context) (EngineIOFrame, error) { + if len(p.frameBuffer) > 0 { + frame := p.frameBuffer[0] + p.frameBuffer = p.frameBuffer[1:] + return frame, nil + } + client := http.Client{ + Timeout: time.Duration(p.pingTimeout+p.pingInterval) * time.Millisecond, + } + req, err := http.NewRequest("GET", p.url+"?EIO=4&transport=polling&sid="+p.sid, nil) + if err != nil { + return EngineIOFrame{}, err + } + resp, err := client.Do(req.WithContext(ctx)) + if err != nil { + return EngineIOFrame{}, err + } + if resp.StatusCode != 200 { + utils.WithContext(ctx).Warn("Received status code: " + resp.Status) + } + defer resp.Body.Close() + buf := new(bytes.Buffer) + buf.ReadFrom(resp.Body) + data := buf.Bytes() + //println(string(data)) + frames := parsePollingFrame(data) + for _, frame := range frames { + utils.WithContext(ctx).With(slog.String("TYPE", frame.Type.String())).Info("Received Frame: " + string(frame.GetUTF8())) + } + p.frameBuffer = frames[1:] + return frames[0], nil +} + +func NewPollingConnection(ctx context.Context, url string) (IConnection, error) { + resp, err := http.Get(url + "?EIO=4&transport=polling") + if err != nil { + return nil, err + } + defer resp.Body.Close() + // Read all the data from the response + buf := new(bytes.Buffer) + buf.ReadFrom(resp.Body) + data := buf.Bytes() + + // Parse the response + frames := parsePollingFrame(data) + print(frames) + message := frames[0].GetUTF8() + var openResponse OpenResponse + json.Unmarshal([]byte(message), &openResponse) + + inbound := make(chan IEngineData) + outbound := make(chan IEngineData) + + connection := pollingClientConnection{clientConnection{ + sid: openResponse.Sid, + inbound: inbound, + outbound: outbound, + pingInterval: openResponse.PingInterval, + pingTimeout: openResponse.PingTimeout, + maxPayload: openResponse.MaxPayload, + url: url, + }, + []EngineIOFrame{}, + sync.Mutex{}, + } + + // send, receive and close are dynamically set, because golang doesn't have virtual functions + connection.send = connection.sender + connection.receive = connection.receiver + connection.close = func(c context.Context) error { + connection.sendMutex.Lock() + select { + case <-connection.context.Done(): + default: + connection.cancellable() + if !connection.timer.Stop() { + <-connection.timer.C + } + close(connection.inbound) + close(connection.outbound) + } + connection.sendMutex.Unlock() + return nil + } + + utils.WithContext(ctx).Info(`Creating connection with SID: ` + openResponse.Sid) + + ct, canceller := context.WithCancel(utils.AddAttr(ctx, slog.String("SID", openResponse.Sid))) + + timer := time.AfterFunc(time.Duration(connection.pingTimeout+connection.pingInterval)*time.Millisecond, func() { + select { + case <-ct.Done(): + return + default: + utils.WithContext(ct).Error("clientConnection timed out") + canceller() + } + }) + + //timer.Reset(time.Duration(connection.pingTimeout+connection.pingInterval) * time.Millisecond) + + connection.context = ct + connection.cancellable = canceller + connection.timer = timer + + go connection.receiveLoop() + + go connection.sendLoop() + + // Set the session ID + return &connection, nil +} + +type webSocketClientConnection struct { + clientConnection + conn *websocket.Conn + mutex sync.Mutex // Mutex for sending websocket messages because they are not thread safe +} + +func marshalWebsocketFrame(frame EngineIOFrame) []byte { + if frame.IEngineData == nil || frame.IsEmpty() { + return []byte{byte(frame.Type + '0')} + } + data, err := frame.GetBindata() + if err == nil { + return append([]byte{byte(frame.Type + '0')}, data...) + } + return append([]byte{byte(frame.Type + '0')}, []byte(frame.GetUTF8())...) +} + +func (w *webSocketClientConnection) sender(ctx context.Context, frame EngineIOFrame) error { + w.mutex.Lock() + defer w.mutex.Unlock() + utils.WithContext(ctx).Info("Sending frame: " + string(marshalWebsocketFrame(frame))) + return w.conn.WriteMessage(websocket.TextMessage, marshalWebsocketFrame(frame)) +} + +func (w *webSocketClientConnection) receiver(ctx context.Context) (EngineIOFrame, error) { + _, data, err := w.conn.ReadMessage() + if err == nil { + utils.WithContext(ctx).Info("Received frame: " + string(data)) + } + return parseWebSocketFrame(data), err +} + +// This is a background connection + +func NewWebsocketConnection(ctx context.Context, url string) (IConnection, error) { + conn, err := NewPollingConnection(ctx, url) + if conn == nil { + return nil, err + } + return UpgradeConnection(ctx, conn.(*pollingClientConnection)) +} + +// This function deletes the connection passed to it, and returns a new connection with websocket transport +// Its basically only used to test late upgrades for the server +func UpgradeConnection(ctx context.Context, c IConnection) (IConnection, error) { + connection := c.(*pollingClientConnection) + if connection == nil { + return nil, errors.New("Invalid connection") + } + connection.cancellable() + + utils.WithContext(ctx).Info(`Upgrading connection with SID: ` + connection.sid) + + ct, canceller := context.WithCancel(utils.AddAttr(ctx, slog.String("SID", connection.sid))) + + // This supports both http and https, goodie + url := strings.Replace(connection.url, "http", "ws", 1) + + conn, piss, err := websocket.DefaultDialer.DialContext(ctx, url+"?EIO=4&transport=websocket&sid="+connection.sid, nil) + if err != nil { + utils.WithContext(ctx).Error("Failed to upgrade connection: " + err.Error() + " " + piss.Status) + return nil, err + } + + timer := time.AfterFunc(time.Duration(connection.pingTimeout+connection.pingInterval)*time.Millisecond, func() { + select { + case <-ct.Done(): + return + default: + utils.WithContext(ct).Error("clientConnection timed out") + canceller() + } + }) + + sock := webSocketClientConnection{ + clientConnection: clientConnection{ + url: url, + sid: connection.sid, + inbound: make(chan IEngineData, 3), + outbound: make(chan IEngineData, 3), + pingInterval: connection.pingInterval, + pingTimeout: connection.pingTimeout, + maxPayload: connection.maxPayload, + context: ct, + cancellable: canceller, + timer: timer, + }, + conn: conn, + mutex: sync.Mutex{}, + } + + // send, receive and close are dynamically set, because golang doesn't have virtual functions + sock.send = sock.sender + sock.receive = sock.receiver + sock.close = func(c context.Context) error { + sock.mutex.Lock() + select { + case <-sock.context.Done(): + default: + sock.cancellable() + if !sock.timer.Stop() { + <-sock.timer.C + } + close(sock.inbound) + close(sock.outbound) + sock.conn.Close() + } + sock.mutex.Unlock() + return nil + } + + // Upgrade sequence + err = sock.send(ct, EngineIOFrame{Type: PING, IEngineData: &EngineUTF8Data{"probe"}}) + if err != nil { + utils.WithContext(ct).Error("Failed to upgrade connection: " + err.Error()) + return nil, err + } + resp, err := sock.receive(ct) + if err != nil || resp.GetUTF8() != "probe" { + utils.WithContext(ct).Error("Failed to upgrade connection: " + err.Error()) + return nil, err + } + err = sock.send(ct, EngineIOFrame{Type: UPGRADE, IEngineData: &EngineUTF8Data{}}) + if err != nil { + utils.WithContext(ct).Error("Failed to upgrade connection: " + err.Error()) + return nil, err + } + + go sock.receiveLoop() + go sock.sendLoop() + + return &sock, nil +} diff --git a/src/engineio/engineio.go b/src/engineio/engineio.go new file mode 100644 index 0000000..4faba90 --- /dev/null +++ b/src/engineio/engineio.go @@ -0,0 +1,168 @@ +package engineio + +import ( + "bytes" + "context" + "encoding/base64" + "errors" +) + +// The main interface for engine.io connections +// Represents a single connection to a server +// Can be used either with the send and receive methods or with the Inbound and Outbound channels +// When used with the channels, the Done() channel can be used to detect when the connection is closed +type IConnection interface { + GetEIO() int // Get the version of the connection + GetSid() string // Get the session ID of the connection + Inbound() <-chan IEngineData // Returns a channel for incoming messages + Outbound() chan<- IEngineData // Returns a channel for outgoing messages + Close(ctx context.Context) error // Closes the connection + Done() <-chan struct{} // Returns a channel that closes when the connection is closed context.Context style + Send(data IEngineData) error // Sends a message + Receive() (IEngineData, error) // Receives a message +} + +// The json response format of an open response from the server +type OpenResponse struct { + Sid string `json:"sid"` + Upgrades []string `json:"upgrades"` + PingInterval int `json:"pingInterval"` + PingTimeout int `json:"pingTimeout"` + MaxPayload int `json:"maxPayload"` +} + +// The defined frame types for engine.io +type EngineIOFrameType int + +// generate string from EngineIOFrameType +func (e EngineIOFrameType) String() string { + switch e { + case OPEN: + return "OPEN" + case CLOSE: + return "CLOSE" + case PING: + return "PING" + case PONG: + return "PONG" + case MESSAGE: + return "MESSAGE" + case UPGRADE: + return "UPGRADE" + case NOOP: + return "NOOP" + default: + return "UNKNOWN" + } +} + +const ( + OPEN EngineIOFrameType = iota + CLOSE + PING + PONG + MESSAGE + UPGRADE + NOOP +) + +// This could be simplified by having a UTF8 and a Bindary that is the child of it +type IEngineData interface { + GetBindata() ([]byte, error) // Binary data is represented differently in polling frames, and may actually be utf8 data, because the protocol specification is not that good + GetUTF8() string // Can always be used to either get the underlying utf8 data of the frame + IsEmpty() bool +} + +type EngineUTF8Data struct { + Data string +} + +func (e *EngineUTF8Data) GetUTF8() string { + return e.Data +} + +func (e *EngineUTF8Data) GetBindata() ([]byte, error) { + return []byte(e.Data), errors.New("Data is not binary") +} + +func (e *EngineUTF8Data) IsEmpty() bool { + return len(e.Data) == 0 +} + +type EngineBinaryData struct { + EngineUTF8Data + Data []byte +} + +func (e *EngineBinaryData) GetBindata() ([]byte, error) { + return e.Data, nil +} + +func (e *EngineBinaryData) IsEmpty() bool { + return len(e.Data) == 0 +} + +type EngineIOFrame struct { + Type EngineIOFrameType + IEngineData +} + +// Multiple frames can be sent in a single polling response +func parsePollingFrame(data []byte) []EngineIOFrame { + frames := bytes.Split(data, []byte{0x1e}) + var result []EngineIOFrame + for _, frame := range frames { + if len(frame) == 1 { + result = append(result, EngineIOFrame{Type: EngineIOFrameType(frame[0] - '0'), IEngineData: &EngineUTF8Data{}}) + continue + } + if frame[1] == 'b' { + data, err := base64.StdEncoding.DecodeString(string(frame[2:])) + if err == nil { + result = append(result, EngineIOFrame{Type: EngineIOFrameType(frame[0] - '0'), IEngineData: &EngineBinaryData{ + Data: data, + EngineUTF8Data: EngineUTF8Data{string(frame[1:])}, + }}) + continue + } + } else { + result = append(result, EngineIOFrame{Type: EngineIOFrameType(frame[0] - '0'), IEngineData: &EngineUTF8Data{ + Data: string(frame[1:]), + }}) + } + } + return result +} + +func parseWebSocketFrame(data []byte) EngineIOFrame { + decoded, err := base64.StdEncoding.DecodeString(string(data[1:])) + if err == nil { + return EngineIOFrame{Type: EngineIOFrameType(data[0] - '0'), IEngineData: &EngineBinaryData{ + Data: decoded, + EngineUTF8Data: EngineUTF8Data{string(data[1:])}, + }} + } + return EngineIOFrame{Type: EngineIOFrameType(data[0] - '0'), IEngineData: &EngineUTF8Data{ + Data: string(data[1:]), + }} +} + +func marshalPollingFrame(frame EngineIOFrame) []byte { + if frame.IEngineData == nil || frame.IsEmpty() { + return []byte{byte(frame.Type + '0')} + } + data, err := frame.GetBindata() + if err == nil { + return append([]byte{byte(frame.Type + '0'), 'b'}, data...) + } + return append([]byte{byte(frame.Type + '0')}, []byte(frame.GetUTF8())...) +} + +func marshalPollingFrames(frames []EngineIOFrame) []byte { + var result []byte + for _, frame := range frames { + result = append(result, marshalPollingFrame(frame)...) + result = append(result, 0x1e) + } + return result[:len(result)-1] +} diff --git a/src/engineio/server.go b/src/engineio/server.go new file mode 100644 index 0000000..40f4792 --- /dev/null +++ b/src/engineio/server.go @@ -0,0 +1,558 @@ +package engineio + +import ( + "context" + "crypto/md5" + "encoding/base64" + "encoding/binary" + "encoding/json" + "errors" + utils "git.sch.bme.hu/yass/utils" + "github.com/gorilla/websocket" + "io" + "log/slog" + "math/rand" + "net/http" + "sync" + "sync/atomic" + "time" +) + +type EngineIOServerOptions struct { + Websockets bool + // PingTimeout ping timeout + PingTimeout int + // PingInterval ping interval + PingInterval int + // MaxHttpBufferSize max http buffer size + MaxHttpBufferSize int + // AllowRequest allow request +} + +const ( + DEFAULT_PING_TIMEOUT = 20000 + DEFAULT_PING_INTERVAL = 25000 + DEFAULT_MAX_BUFFER = 1000000 +) + +var DEFAULT_OPTIONS = EngineIOServerOptions{ + Websockets: true, + PingTimeout: DEFAULT_PING_TIMEOUT, + PingInterval: DEFAULT_PING_INTERVAL, + MaxHttpBufferSize: DEFAULT_MAX_BUFFER, +} + +type EngineIOServer struct { + EngineIOServerOptions + connections map[string]*ServerConnection + logger slog.Logger + connectionsMutex sync.Mutex // Mutex for the connections map + stableMutex sync.Mutex // Mutex for outside reads of the connections map + /* + Ideas: + - broadcast send and receive + - per connection lambda function + */ +} + +// The serverside has a single connection class for both polling and websocket, although it may be wise to split them up +// Upgrading from polling to websocket creates a new ServerConnection object +// But the IConnection interface remains the same i.e. returns the same channels, and the same data +// So the client can upgrade the connection without the socketio server knowing +type ServerConnection struct { + Sid string + outbound chan IEngineData + inbound chan IEngineData + // Polling specific + sendChannel chan EngineIOFrame // channel for buffering messages to be send in polling mode + + channelGuard sync.WaitGroup /* Guard for the channels to ensure no writes on closed channels happen + TODO:: reads on closed channels are not yet safe */ + + ponged atomic.Bool // sets whether a pong has been received since the last check + pinged atomic.Bool // sets whether a ping timer has already been started + + Context context.Context // Context for the connection, used only for inside goroutines + Cancel context.CancelFunc // Cancel function for the context + // Polling can't have concurrent GET and POST requests for a single connection + ReceiveMutex sync.Mutex // Mutex for POST requests + SendMutex sync.Mutex // Mutex for GET requests + // Timer for the ping interval + Timer *time.Ticker // Timer for the ping interval + // Done channel for the connection + done chan struct{} // This channel signals to the outside that the connection with this SID is closed +} + +func NewEngineIOServer(options EngineIOServerOptions) *EngineIOServer { + return &EngineIOServer{ + EngineIOServerOptions: options, + connections: make(map[string]*ServerConnection), + logger: *utils.WithContext(context.Background()), + connectionsMutex: sync.Mutex{}, + stableMutex: sync.Mutex{}, + } +} + +func (s *EngineIOServer) SetLogger(logger slog.Logger) { + s.logger = logger +} + +func (s *EngineIOServer) GetConnection(sid string) IConnection { + s.stableMutex.Lock() + defer s.stableMutex.Unlock() + s.connectionsMutex.Lock() + defer s.connectionsMutex.Unlock() + return s.connections[sid] +} + +func (s *EngineIOServer) getConnection(sid string) *ServerConnection { + s.connectionsMutex.Lock() + defer s.connectionsMutex.Unlock() + return s.connections[sid] +} + +// Removes a connection from the servers connection list +// if killInterfaceChannels is false, then it will not close the Inbound and Outbound channels +// So they can be reused i.e. during an upgrade to websocket connection fully opaquely to the user +func (s *EngineIOServer) removeConnection(sid string, killInterfaceChannels bool) error { + //TODO:: Maybe return with the messages left in the SendChannel to be sent later + s.connectionsMutex.Lock() + defer s.connectionsMutex.Unlock() + connection := s.connections[sid] + if connection != nil { + select { + case <-connection.Context.Done(): + return errors.New("clientConnection already closed") + default: + } + // We cancel the context, this signals to the goroutines to stop + connection.Cancel() + //connection.ReceiveMutex.Lock() + //defer connection.ReceiveMutex.Unlock() + //connection.SendMutex.Lock() + //defer connection.SendMutex.Unlock() + connection.Timer.Stop() + select { + case <-connection.Timer.C: + default: + } + + if killInterfaceChannels { + // Signals to the outside that this connection is closed + close(connection.done) + } + + // We wait for all started channel operations to finish + connection.channelGuard.Wait() + + if killInterfaceChannels { + close(connection.inbound) + close(connection.outbound) + } + close(connection.sendChannel) + //close(connection.PongChannel) + delete(s.connections, sid) + return nil + } + return nil +} + +func (s *EngineIOServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + //s.logger.Info("Handling request") + if isWebSocketRequest(r) { + s.serveWebSocket(w, r) + } else { + s.servePolling(w, r) + } +} + +func (s *EngineIOServer) servePolling(w http.ResponseWriter, r *http.Request) { + values := r.URL.Query() + transport := values.Get("transport") + eio := values.Get("EIO") + sid := values.Get("sid") + connection := s.getConnection(sid) + + if sid != "" && (connection == nil) { + w.Write([]byte("{\"code\":1,\"message\":\"Session ID unknown\"}")) + w.WriteHeader(400) + return + } + if transport != "polling" { + w.Write([]byte("{\"code\":3,\"message\":\"Bad request\"}")) + w.WriteHeader(400) + return + } + if eio != "4" { + w.Write([]byte("{\"code\":5,\"message\":\"Unsupported protocol version\"}")) + w.WriteHeader(400) + return + } + + if sid == "" { + // IF SID is empty, create a new one + var upgrades []string = []string{} + if s.Websockets { + upgrades = append(upgrades, "websocket") + } + newSid := s.CreateNewConnection() + resp := OpenResponse{ + Sid: newSid, + Upgrades: upgrades, + PingInterval: s.PingInterval, + PingTimeout: s.PingTimeout, + MaxPayload: s.MaxHttpBufferSize, + } + + s.logger.Info("New polling connection created with SID: " + newSid) + + bytes, _ := json.Marshal(resp) + w.Write(marshalPollingFrame(EngineIOFrame{ + Type: OPEN, + IEngineData: &EngineUTF8Data{ + Data: string(bytes), + }, + })) + return + } else { + connection.channelGuard.Add(1) + defer connection.channelGuard.Done() + // if POST request + if r.Method == "POST" { + if connection.ReceiveMutex.TryLock() { + // Handle the request + bytes, err := io.ReadAll(r.Body) + if err != nil { + w.Write([]byte("{\"code\":4,\"message\":\"Bad request\"}")) + w.WriteHeader(400) + return + } + frames := parsePollingFrame(bytes) + for _, frame := range frames { + s.logger.Info("Received frame: " + frame.Type.String() + " from " + sid) + switch frame.Type { + case PONG: + connection.ponged.Store(true) + case MESSAGE: + connection.inbound <- frame.IEngineData + case CLOSE: + connection.ReceiveMutex.Unlock() + s.logger.Info("Closing connection with SID: " + sid) + connection.channelGuard.Done() + s.removeConnection(sid, true) + connection.channelGuard.Add(1) // Defer is LIFO so this is the easiest way + w.Write([]byte("ok")) + return + default: + } + } + w.Write([]byte("ok")) + connection.ReceiveMutex.Unlock() + } else { + w.Write([]byte("{\"code\":6,\"message\":\"Server overload\"}")) + w.WriteHeader(400) + connection.channelGuard.Done() + s.removeConnection(sid, true) + connection.channelGuard.Add(1) // Defer is LIFO so this is the easiest way + return + } + } else { + // if GET request + s.logger.Info("Polling connection with SID: " + sid) + sendlist := []EngineIOFrame{} + if connection.SendMutex.TryLock() { + select { + case data := <-connection.sendChannel: + // Multiple frames can be sent in a single polling response + // So we collect the buffered frames + sendlist = append(sendlist, data) + forLoop: + for { + select { + case buf := <-connection.sendChannel: + // if there is more frame immediately available, we add them to our slice + sendlist = append(sendlist, buf) + default: + // else we move on without blocking + break forLoop + } + } + case <-connection.Context.Done(): + sendlist = append(sendlist, EngineIOFrame{ + Type: NOOP, + }) + } + w.Write(marshalPollingFrames(sendlist)) + connection.SendMutex.Unlock() + return + } else { + w.Write([]byte("{\"code\":6,\"message\":\"Server overload\"}")) + w.WriteHeader(400) + connection.channelGuard.Done() + s.removeConnection(sid, true) + connection.channelGuard.Add(1) // Defer is LIFO so this is the easiest way + return + } + } + } + +} + +func (s *ServerConnection) GetEIO() int { + return 4 +} // Get the version of the connection +func (s *ServerConnection) GetSid() string { + return s.Sid +} // Get the session ID of the connection +func (s *ServerConnection) Inbound() <-chan IEngineData { + return s.inbound +} +func (s *ServerConnection) Outbound() chan<- IEngineData { + return s.outbound +} +func (s *ServerConnection) Close(ctx context.Context) error { + return errors.New("Not implemented") +} // Closes the connection, serverconnections can only be closed by the server object +func (s *ServerConnection) Done() <-chan struct{} { + return s.done +} + +func (s *ServerConnection) Send(data IEngineData) error { + select { + case <-s.done: + return errors.New("clientConnection is closed") + default: + s.outbound <- data + return nil + } +} +func (s *ServerConnection) Receive() (IEngineData, error) { + select { + case <-s.done: + return nil, errors.New("clientConnection is closed") + case data := <-s.inbound: + return data, nil + } +} + +func (s *EngineIOServer) addconnection(sid string) *ServerConnection { + s.connectionsMutex.Lock() + defer s.connectionsMutex.Unlock() + if s.connections[sid] != nil { + errors.New("clientConnection already exists") + } + ctx, cancel := context.WithCancel(context.Background()) + timer := time.NewTicker(time.Duration(s.PingInterval) * time.Millisecond) + newConnection := &ServerConnection{ + Sid: sid, + outbound: make(chan IEngineData, 3), + inbound: make(chan IEngineData, 3), + sendChannel: make(chan EngineIOFrame, 5), + pinged: atomic.Bool{}, + ponged: atomic.Bool{}, + done: make(chan struct{}), + Context: ctx, + Cancel: cancel, + Timer: timer, + } + s.connections[sid] = newConnection + return s.connections[sid] +} + +func (s *EngineIOServer) CreateNewConnection() string { + // Generate a random SID + hash := md5.New() + hash.Write([]byte("salt")) + // turn a random float into a byte slice + // and hash it + var buf = make([]byte, 8) + binary.LittleEndian.PutUint64(buf[:], rand.Uint64()) + sid := hash.Sum(buf) + sidString := base64.RawURLEncoding.EncodeToString(sid) + //ctx, cancel := context.WithCancel(context.Background()) + // + //timer := time.NewTicker(time.Duration(s.pingInterval) * time.Millisecond) + //timeoutTimer := time.NewTimer(time.Duration(s.pingTimeout) * time.Millisecond).Stop() + + newConnection := s.addconnection(sidString) + + go func() { + for { + newConnection.channelGuard.Add(1) + select { + case <-newConnection.Context.Done(): + newConnection.channelGuard.Done() + return + case data := <-newConnection.outbound: + newConnection.sendChannel <- EngineIOFrame{ + Type: MESSAGE, + IEngineData: data, + } + newConnection.channelGuard.Done() + case <-newConnection.Timer.C: + s.logger.Info("Pinging " + sidString) + newConnection.sendChannel <- EngineIOFrame{ + Type: PING, + } + newConnection.channelGuard.Done() + if !newConnection.pinged.Swap(true) { + time.AfterFunc(time.Duration(s.PingTimeout)*time.Millisecond, func() { + select { + case <-newConnection.Context.Done(): + return + default: + if !newConnection.ponged.Swap(false) { + s.logger.Warn("clientConnection timed out SID: " + sidString) + s.removeConnection(sidString, true) + } + } + }) + } + } + } + }() + + return sidString +} + +func (s *EngineIOServer) serveWebSocket(w http.ResponseWriter, r *http.Request) { + values := r.URL.Query() + transport := values.Get("transport") + eio := values.Get("EIO") + sid := values.Get("sid") + connection := s.getConnection(sid) + if sid != "" && (connection == nil) { + w.Write([]byte("{\"code\":1,\"message\":\"Session ID unknown\"}")) + w.WriteHeader(400) + return + } + if transport != "websocket" { + w.Write([]byte("{\"code\":3,\"message\":\"Bad request\"}")) + w.WriteHeader(400) + return + } + if eio != "4" { + w.Write([]byte("{\"code\":5,\"message\":\"Unsupported protocol version\"}")) + w.WriteHeader(400) + return + } + if sid == "" { + // TODO:: Implement websockets without upgrade + w.Write([]byte("{\"code\":2,\"message\":\"Websocket sessions currently can only be created by upgrading polling\"}")) + w.WriteHeader(400) + return + } + // We lock the stable mutex to ensure that the connection is not Get()ed while it is nil + s.stableMutex.Lock() + defer s.stableMutex.Unlock() + s.removeConnection(sid, false) + upgrader := websocket.Upgrader{} + conn, err := upgrader.Upgrade(w, r, nil) + logger := s.logger.With(slog.String("sid", sid)) + if err != nil { + logger.Error("Failed to upgrade connection") + return + } + _, probe, err := conn.ReadMessage() + if err != nil { + logger.Error("Failed to read probe message from websocket") + return + } + if string(probe) != "2probe" { + logger.Error("Invalid probe message") + return + } + conn.WriteMessage(websocket.TextMessage, []byte("3probe")) + _, probe, err = conn.ReadMessage() + if err != nil { + logger.Error("Failed to read upgrade message from websocket") + return + } + if string(probe) != "5" { + logger.Error("Invalid upgrade message") + return + } + + s.logger.Info("Upgraded connection with SID: " + sid) + + newconnection := s.addconnection(sid) + + // We copy the channels from the old connection to the new one + // So that the user can continue to use the same channels + newconnection.inbound = connection.inbound + newconnection.outbound = connection.outbound + newconnection.done = connection.done + + go func() { + for { + _, message, err := conn.ReadMessage() + if err != nil { + newconnection.channelGuard.Done() + s.logger.Error("Failed to read message from websocket") + return + } + newconnection.channelGuard.Add(1) + select { + case <-newconnection.Context.Done(): + return + default: + } + frame := parseWebSocketFrame(message) + s.logger.With(slog.String("SID", sid)).Info("Received frame: " + string(message)) + switch frame.Type { + case PONG: + newconnection.ponged.Store(true) + case MESSAGE: + newconnection.inbound <- frame.IEngineData + case CLOSE: + newconnection.channelGuard.Done() + s.removeConnection(sid, true) + return + default: + } + newconnection.channelGuard.Done() + } + conn.Close() + }() + + go func() { + for { + newconnection.channelGuard.Add(1) + select { + case <-newconnection.Context.Done(): + newconnection.channelGuard.Done() + return + case data := <-newconnection.outbound: + newconnection.channelGuard.Done() + s.logger.Info("Sending message to " + sid) + conn.WriteMessage(websocket.TextMessage, marshalWebsocketFrame(EngineIOFrame{ + Type: MESSAGE, + IEngineData: data, + })) + case <-newconnection.Timer.C: + newconnection.sendChannel <- EngineIOFrame{ + Type: PING, + } + s.logger.Info("Pinging " + sid) + newconnection.channelGuard.Done() + if !newconnection.pinged.Swap(true) { + time.AfterFunc(time.Duration(s.PingTimeout)*time.Millisecond, func() { + select { + case <-newconnection.Context.Done(): + return + default: + if !newconnection.ponged.Swap(false) { + s.logger.Warn("clientConnection timed out SID: " + sid) + s.removeConnection(sid, true) + } + } + }) + } + } + } + }() +} + +func isWebSocketRequest(r *http.Request) bool { + return r.Header.Get("Upgrade") == "websocket" && r.Header.Get("Connection") == "Upgrade" +} // ... diff --git a/src/go.mod b/src/go.mod index a8e31f2..7657591 100644 --- a/src/go.mod +++ b/src/go.mod @@ -2,4 +2,9 @@ module git.sch.bme.hu/yass go 1.23.5 -require github.com/gorilla/websocket v1.5.3 // indirect +require ( + github.com/alvaroloes/enumer v1.1.2 // indirect + github.com/gorilla/websocket v1.5.3 // indirect + github.com/pascaldekloe/name v0.0.0-20180628100202-0fd16699aae1 // indirect + golang.org/x/tools v0.0.0-20190524210228-3d17549cdc6b // indirect +) diff --git a/src/go.sum b/src/go.sum index 25a9fc4..7f45ce7 100644 --- a/src/go.sum +++ b/src/go.sum @@ -1,2 +1,13 @@ +github.com/alvaroloes/enumer v1.1.2 h1:5khqHB33TZy1GWCO/lZwcroBFh7u+0j40T83VUbfAMY= +github.com/alvaroloes/enumer v1.1.2/go.mod h1:FxrjvuXoDAx9isTJrv4c+T410zFi0DtXIT0m65DJ+Wo= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/pascaldekloe/name v0.0.0-20180628100202-0fd16699aae1 h1:/I3lTljEEDNYLho3/FUB7iD/oc2cEFgVmbHzV+O0PtU= +github.com/pascaldekloe/name v0.0.0-20180628100202-0fd16699aae1/go.mod h1:eD5JxqMiuNYyFNmyY9rkJ/slN8y59oEu4Ei7F8OoKWQ= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20190524210228-3d17549cdc6b h1:iEAPfYPbYbxG/2lNN4cMOHkmgKNsCuUwkxlDCK46UlU= +golang.org/x/tools v0.0.0-20190524210228-3d17549cdc6b/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= diff --git a/src/main.go b/src/main.go index 5464011..7970c6f 100644 --- a/src/main.go +++ b/src/main.go @@ -1,18 +1,23 @@ package main -import ( - "context" -) - func main() { - connection := NewPollingConnection(context.Background(), "http://localhost:3000") - connect := connection.(*PollingConnection) - sock, err := UpgradeConnection(context.Background(), connect) - if err != nil { - panic(err) - } - in, out := sock.GetMessageChannels() - data := EngineUTF8Data{Data: "0"} - out <- &data - <-in + //connection := engineio.NewPollingConnection(context.Background(), "http://localhost:3000/socket.io/") + //connect := connection.(*engineio.PollingConnection) + //sock, err := engineio.UpgradeConnection(context.Background(), connect) + //if err != nil { + // panic(err) + //} + //in, out := sock.GetMessageChannels() + //data := engineio.EngineUTF8Data{Data: "0"} + //out <- &data + //<-in + //out <- &engineio.EngineUTF8Data{Data: "2[\"chat message\", \"Hello\"]"} + //<-in + //time.Sleep(60 * time.Second) + + channel := make(chan int) + channel <- 1 + <-channel + close(channel) + <-channel } diff --git a/src/test/engineio_test.go b/src/test/engineio_test.go new file mode 100644 index 0000000..0e53926 --- /dev/null +++ b/src/test/engineio_test.go @@ -0,0 +1,231 @@ +package test + +import ( + "context" + "git.sch.bme.hu/yass/utils" + "log/slog" + "net/http/httptest" + "testing" + "time" +) +import "git.sch.bme.hu/yass/engineio" + +func TestCreateServerConnection(t *testing.T) { + // Create a new server connection + server := engineio.NewEngineIOServer(engineio.DEFAULT_OPTIONS) + server.SetLogger(*(utils.WithContext(context.Background()).With(slog.String("context", "server")))) + + testserver := httptest.NewServer(server) + + connection, _ := engineio.NewPollingConnection(context.Background(), testserver.URL) + if connection == nil { + t.Error("clientConnection is nil") + } +} + +func TestPollingPing(t *testing.T) { + options := engineio.EngineIOServerOptions{ + Websockets: true, + PingTimeout: 1000, + PingInterval: 3000, + MaxHttpBufferSize: 1000000, + } + server := engineio.NewEngineIOServer(options) + + server.SetLogger(*(utils.WithContext(context.Background()).With(slog.String("context", "server")))) + + testServer := httptest.NewServer(server) + + connection, _ := engineio.NewPollingConnection(context.Background(), testServer.URL) + + time.Sleep(10 * time.Second) + select { + case <-connection.Done(): + t.Error("clientConnection is timed out") + default: + break + } + +} + +func TestPollingConnection(t *testing.T) { + options := engineio.EngineIOServerOptions{ + Websockets: true, + PingTimeout: 25000, + PingInterval: 3000, + MaxHttpBufferSize: 1000000, + } + server := engineio.NewEngineIOServer(options) + + server.SetLogger(*(utils.WithContext(context.Background()).With(slog.String("context", "server")))) + + testServer := httptest.NewServer(server) + + connection, _ := engineio.NewPollingConnection(context.Background(), testServer.URL) + + //// Upgrade the connection + //_, err := engineio.UpgradeConnection(context.Background(), connection.(*engineio.pollingClientConnection)) + //if err != nil { + // t.Error("Failed to upgrade connection") + //} + + connection.Send(&engineio.EngineUTF8Data{Data: "2[\"chat message\", \"Hello\"]"}) + + message := <-server.GetConnection(connection.GetSid()).Inbound() + if message.GetUTF8() != "2[\"chat message\", \"Hello\"]" { + t.Error("Message is not correct") + } + + server.GetConnection(connection.GetSid()).Outbound() <- &engineio.EngineUTF8Data{Data: "2[\"chat message\", \"World\"]"} + message = <-connection.Inbound() + if message.GetUTF8() != "2[\"chat message\", \"World\"]" { + t.Error("Message is not correct") + } +} + +func TestConnectionClose(t *testing.T) { + options := engineio.EngineIOServerOptions{ + Websockets: true, + PingTimeout: 1000, + PingInterval: 1000, + MaxHttpBufferSize: 1000000, + } + server := engineio.NewEngineIOServer(options) + + server.SetLogger(*(utils.WithContext(context.Background()).With(slog.String("context", "server")))) + + testServer := httptest.NewServer(server) + + connection, _ := engineio.NewPollingConnection(context.Background(), testServer.URL) + + // Close the connection + connection.Close(context.Background()) + + time.Sleep(3 * time.Second) + + connection, _ = engineio.NewPollingConnection(context.Background(), testServer.URL) + + testServer.Close() + + time.Sleep(3 * time.Second) +} + +func TestPollingConnectionUpgrade(t *testing.T) { + options := engineio.EngineIOServerOptions{ + Websockets: true, + PingTimeout: 25000, + PingInterval: 20000, + MaxHttpBufferSize: 1000000, + } + server := engineio.NewEngineIOServer(options) + + server.SetLogger(*(utils.WithContext(context.Background()).With(slog.String("context", "server")))) + + testServer := httptest.NewServer(server) + + connection, _ := engineio.NewPollingConnection(context.Background(), testServer.URL) + + // Upgrade the connection + newconnection, err := engineio.UpgradeConnection(context.Background(), connection) + if err != nil { + t.Error("Failed to upgrade connection") + } + + newconnection.Send(&engineio.EngineUTF8Data{Data: "2[\"chat message\", \"Hello\"]"}) + + serverConn := server.GetConnection(newconnection.GetSid()) + message := <-serverConn.Inbound() + if message.GetUTF8() != "2[\"chat message\", \"Hello\"]" { + t.Error("Message is not correct") + } + + serverConn.Outbound() <- &engineio.EngineUTF8Data{Data: "2[\"chat message\", \"World\"]"} + message = <-newconnection.Inbound() + if message.GetUTF8() != "2[\"chat message\", \"World\"]" { + t.Error("Message is not correct") + } +} + +func TestWebsocketClose(t *testing.T) { + options := engineio.EngineIOServerOptions{ + Websockets: true, + PingTimeout: 1000, + PingInterval: 1000, + MaxHttpBufferSize: 1000000, + } + server := engineio.NewEngineIOServer(options) + + server.SetLogger(*(utils.WithContext(context.Background()).With(slog.String("context", "server")))) + + testServer := httptest.NewServer(server) + + connection, _ := engineio.NewPollingConnection(context.Background(), testServer.URL) + + conn, err := engineio.UpgradeConnection(context.Background(), connection) + + conn.Close(context.Background()) + + connection, _ = engineio.NewPollingConnection(context.Background(), testServer.URL) + + // Upgrade the connection + _, err = engineio.UpgradeConnection(context.Background(), connection) + if err != nil { + t.Error("Failed to upgrade connection") + } + + testServer.Close() + + time.Sleep(3 * time.Second) +} + +func TestInvisibleUpgrade(t *testing.T) { + options := engineio.EngineIOServerOptions{ + Websockets: true, + PingTimeout: 25000, + PingInterval: 20000, + MaxHttpBufferSize: 1000000, + } + server := engineio.NewEngineIOServer(options) + + server.SetLogger(*(utils.WithContext(context.Background()).With(slog.String("context", "server")))) + + testServer := httptest.NewServer(server) + + connection, _ := engineio.NewPollingConnection(context.Background(), testServer.URL) + + conn := server.GetConnection(connection.GetSid()) + + go func() { + for { + select { + case <-conn.Done(): + return + case data := <-conn.Inbound(): + if data.GetUTF8() == "hello" { + conn.Outbound() <- &engineio.EngineUTF8Data{Data: "world"} + } else { + t.Error("Message is not correct") + } + case <-time.After(3 * time.Second): + t.Error("Timeout") + return + } + } + }() + + // Upgrade the connection + ws, err := engineio.UpgradeConnection(context.Background(), connection) + if err != nil { + t.Error("Failed to upgrade connection") + } + + ws.Send(&engineio.EngineUTF8Data{Data: "hello"}) + data, _ := ws.Receive() + + if data.GetUTF8() != "world" { + t.Error("Message is not correct") + } + + ws.Close(context.Background()) + time.Sleep(5 * time.Second) +} -- GitLab