Skip to content
Snippets Groups Projects
Select Git revision
  • 11b47f5ad39dc0148d4d39f6be201b339654982e
  • main default protected
2 results

engine.go

Blame
  • engine.go 10.96 KiB
    package main
    
    import (
    	"bytes"
    	"context"
    	"encoding/base64"
    	"encoding/json"
    	"errors"
    	"git.sch.bme.hu/yass/utils"
    	"github.com/gorilla/websocket"
    	"log/slog"
    	"net/http"
    	"strings"
    	"sync"
    	"time"
    )
    
    // IConnection is the interface through which callers use an Engine.IO
    // connection. NewConnection returns a value of this type.
    type IConnection interface {
    	GetEIO() int                                                               // Get the version of the connection
    	GetSid() string                                                            // Get the session ID of the connection
    	GetMessageChannels() (inbound chan IEngineData, outbound chan IEngineData) // Get the message channel of the connection
    	Close(ctx context.Context) error                                           // Closes the connection
    	Done() <-chan struct{}                                                     // Returns a channel that closes when the connection is closed
    }
    
    // OpenResponse is the JSON document decoded from the first frame of the
    // handshake response in NewConnection. Its values are copied into the
    // Connection that NewConnection builds.
    type OpenResponse struct {
    	// Sid is the session ID appended to every later polling request.
    	Sid          string   `json:"sid"`
    	Upgrades     []string `json:"upgrades"`
    	// PingInterval and PingTimeout are treated as milliseconds by this file.
    	PingInterval int      `json:"pingInterval"`
    	PingTimeout  int      `json:"pingTimeout"`
    	MaxPayload   int      `json:"maxPayload"`
    }
    
    // EngineIOFrameType identifies the kind of an Engine.IO frame. The parsers in
    // this file obtain it by subtracting '0' from the first byte of a frame, so the
    // constants below correspond to the ASCII digits '0' through '6'.
    type EngineIOFrameType int
    
    const (
    	OPEN EngineIOFrameType = iota
    	CLOSE
    	PING
    	PONG
    	MESSAGE
    	UPGRADE
    	NOOP
    )
    
    // IEngineData is the payload carried by an EngineIOFrame. It exposes the same
    // payload as bytes (GetBindata) and as text (GetUTF8).
    //
    // This could be simplified by having a UTF8 type and a Binary type that is
    // the child of it.
    type IEngineData interface {
    	GetBindata() ([]byte, error)
    	GetUTF8() string
    	IsEmpty() bool
    }
    
    // EngineUTF8Data is a textual frame payload.
    type EngineUTF8Data struct {
    	Data string
    }

    // IsEmpty reports whether the payload holds no text at all.
    func (u *EngineUTF8Data) IsEmpty() bool {
    	return u.Data == ""
    }

    // GetUTF8 returns the payload text unchanged.
    func (u *EngineUTF8Data) GetUTF8() string {
    	text := u.Data
    	return text
    }

    // GetBindata returns the bytes of the payload text. The error is always nil.
    func (u *EngineUTF8Data) GetBindata() ([]byte, error) {
    	raw := []byte(u.Data)
    	return raw, nil
    }
    
    // EngineBinaryData is a binary frame payload. It embeds EngineUTF8Data, so
    // GetUTF8 is promoted from the embedded value and returns the embedded text,
    // while GetBindata and IsEmpty look at the raw bytes in Data.
    type EngineBinaryData struct {
    	EngineUTF8Data
    	Data []byte
    }

    // IsEmpty reports whether the binary payload has zero bytes.
    func (b *EngineBinaryData) IsEmpty() bool {
    	return len(b.Data) < 1
    }

    // GetBindata returns the raw payload bytes. The error is always nil.
    func (b *EngineBinaryData) GetBindata() ([]byte, error) {
    	raw := b.Data
    	return raw, nil
    }
    
    // EngineIOFrame is a single Engine.IO frame: its type plus its payload.
    type EngineIOFrame struct {
    	Type EngineIOFrameType
    	Data IEngineData
    }
    
    // parsePollingFrame splits an HTTP long-polling body into its frames.
    //
    // Frames are separated by the 0x1e record separator. The first byte of each
    // frame is the type digit (decoded as byte - '0'); the rest is the payload.
    // A payload starting with 'b' is treated as base64-encoded binary data. If
    // that decoding fails the payload is kept as text, because a text message may
    // legitimately begin with the letter 'b'.
    //
    // Empty frames (empty body, leading/trailing or doubled separators) are
    // skipped instead of being indexed, so the result may be empty or nil.
    func parsePollingFrame(data []byte) []EngineIOFrame {
    	var result []EngineIOFrame
    	for _, frame := range bytes.Split(data, []byte{0x1e}) {
    		if len(frame) == 0 {
    			// Nothing to decode: there is no type byte.
    			continue
    		}
    		frameType := EngineIOFrameType(frame[0] - '0')
    		payload := frame[1:]

    		if len(payload) > 0 && payload[0] == 'b' {
    			decoded, err := base64.StdEncoding.DecodeString(string(payload[1:]))
    			if err == nil {
    				result = append(result, EngineIOFrame{Type: frameType, Data: &EngineBinaryData{
    					Data: decoded,
    					// The text view keeps the payload as received, marker included.
    					EngineUTF8Data: EngineUTF8Data{string(payload)},
    				}})
    				continue
    			}
    			// Not valid base64: fall through and deliver it as plain text
    			// rather than losing the frame.
    		}

    		result = append(result, EngineIOFrame{Type: frameType, Data: &EngineUTF8Data{
    			Data: string(payload),
    		}})
    	}
    	return result
    }
    
    // parseWebSocketFrame decodes one WebSocket message into a frame.
    //
    // The first byte is the type digit (decoded as byte - '0'). The remainder is
    // tried as base64: when it decodes, the frame carries binary data and keeps
    // the undecoded text as its UTF-8 view; otherwise it carries plain text.
    //
    // An empty message has no type byte. It is treated as if it were "6", which is
    // the wire form of NOOP, instead of panicking on the index.
    func parseWebSocketFrame(data []byte) EngineIOFrame {
    	if len(data) == 0 {
    		data = []byte{'6'}
    	}
    	frameType := EngineIOFrameType(data[0] - '0')
    	payload := string(data[1:])

    	if decoded, err := base64.StdEncoding.DecodeString(payload); err == nil {
    		return EngineIOFrame{Type: frameType, Data: &EngineBinaryData{
    			Data:           decoded,
    			EngineUTF8Data: EngineUTF8Data{payload},
    		}}
    	}
    	return EngineIOFrame{Type: frameType, Data: &EngineUTF8Data{
    		Data: payload,
    	}}
    }
    
    // marshalPollingFrame encodes a frame for the wire.
    //
    // The output always starts with the type digit. A nil or empty payload yields
    // just that digit. A payload whose byte form is identical to its text form is
    // text and is appended as-is. Any other payload is binary and is appended as
    // 'b' followed by its standard base64 encoding, which is the form that
    // parsePollingFrame decodes.
    //
    // Previously every non-empty payload was written as 'b' plus raw bytes,
    // because GetBindata never fails for the payload types in this file; text
    // messages therefore carried a spurious 'b' and binary ones were not base64.
    //
    // NOTE(review): the Engine.IO v4 protocol text appears to encode binary
    // polling payloads as 'b' + base64 without a leading type digit. The digit is
    // kept here to stay symmetric with parsePollingFrame - confirm against the server.
    func marshalPollingFrame(frame EngineIOFrame) []byte {
    	header := byte(frame.Type + '0')
    	if frame.Data == nil || frame.Data.IsEmpty() {
    		return []byte{header}
    	}

    	text := frame.Data.GetUTF8()
    	raw, err := frame.Data.GetBindata()
    	if err != nil || string(raw) == text {
    		return append([]byte{header}, text...)
    	}

    	encoded := base64.StdEncoding.EncodeToString(raw)
    	out := make([]byte, 0, 2+len(encoded))
    	out = append(out, header, 'b')
    	return append(out, encoded...)
    }
    
    // SocketIOMessage describes a Socket.IO level message.
    // NOTE(review): nothing in the visible part of this file constructs or reads
    // it; presumably it is used by the Socket.IO layer elsewhere - confirm.
    type SocketIOMessage struct {
    	Type            int
    	Namespace       string
    	Payload         interface{}
    	Acknowledgement int
    }
    
    // Connection holds the state shared by every transport: session data from the
    // handshake, the message channels and the context that marks the connection's
    // lifetime. Its send and receive methods are stubs that return an error; the
    // embedding transports declare their own.
    //
    // This is only a client connection for now.
    // The official documentation lends itself to a timer based solution instead of a global clock.
    type Connection struct {
    	//PingAt         time.Time        // Time at which the last ping was received
    	//DeadAt       time.Time        // Time after which the connection is considered dead
    	Url          string             // URL of the server
    	Sid          string             // Session ID
    	Inbound      chan IEngineData   // Messages received
    	Outbound     chan IEngineData   // Messages to be sent
    	PingInterval int                // Interval at which to send ping messages milliseconds
    	PingTimeout  int                // Timeout for ping messages milliseconds
    	MaxPayload   int                // Maximum payload size, may not be used in websockets, idk
    	Context      context.Context    // Context of the connection
    	Cancellable  context.CancelFunc // Cancel function of the context
    }
    
    // Done returns the Done channel of the connection's context; it is closed
    // once that context is cancelled.
    func (c *Connection) Done() <-chan struct{} {
    	closed := c.Context.Done()
    	return closed
    }
    
    // GetMessageChannels returns the channel on which received payloads are
    // delivered (inbound) and the channel on which payloads to send are accepted
    // (outbound).
    func (c *Connection) GetMessageChannels() (inbound chan IEngineData, outbound chan IEngineData) {
    	inbound, outbound = c.Inbound, c.Outbound
    	return inbound, outbound
    }
    
    // GetEIO reports the Engine.IO protocol revision, which is always 4 here.
    func (c *Connection) GetEIO() int {
    	const protocolRevision = 4
    	return protocolRevision
    }
    // GetSid returns the session ID stored on the connection.
    func (c *Connection) GetSid() string {
    	sid := c.Sid
    	return sid
    }
    
    // send is the transport-less placeholder: it sends nothing and always returns
    // a "Not implemented" error. Embedding transports declare their own send.
    func (c *Connection) send(ctx context.Context, frame EngineIOFrame) error {
    	notImplemented := errors.New("Not implemented")
    	return notImplemented
    }
    
    // receive is the transport-less placeholder: it returns a zero frame together
    // with a "Not implemented" error. Embedding transports declare their own
    // receive.
    func (c *Connection) receive(ctx context.Context) (EngineIOFrame, error) {
    	var none EngineIOFrame
    	return none, errors.New("Not implemented")
    }
    
    // Close cancels the connection's context and then sends a CLOSE frame with an
    // empty payload, returning whatever send returns. The session ID is attached
    // to ctx as the "SID" attribute for that send.
    //
    // NOTE(review): Go resolves c.send statically, so this always calls
    // Connection.send, which returns "Not implemented", even when Close is
    // reached through a type that embeds Connection and declares its own send.
    func (c *Connection) Close(ctx context.Context) error {
    	c.Cancellable()
    	sendCtx := utils.AddAttr(ctx, slog.String("SID", c.Sid))
    	closeFrame := EngineIOFrame{Type: CLOSE, Data: &EngineUTF8Data{}}
    	return c.send(sendCtx, closeFrame)
    }
    
    // PollingConnection is the HTTP long-polling transport.
    type PollingConnection struct {
    	Connection
    	// frameBuffer holds frames from the last poll that receive has not handed out yet.
    	frameBuffer []EngineIOFrame
    	// sendMutex is held for the whole of send, so POSTs do not overlap.
    	sendMutex   sync.Mutex
    }
    
    // send POSTs one frame to the server's polling endpoint.
    //
    // Calls are serialised by sendMutex. The request is bound to ctx and limited
    // to PingTimeout+PingInterval milliseconds. A non-200 status is logged as a
    // warning but not turned into an error, matching receive. The returned error
    // is the request-construction or transport error, if any.
    //
    // The previous version issued the same POST twice (once via the client, once
    // via http.Post without context or timeout), discarded the first error and
    // never closed either response body.
    func (p *PollingConnection) send(ctx context.Context, frame EngineIOFrame) error {
    	p.sendMutex.Lock()
    	defer p.sendMutex.Unlock()

    	endpoint := p.Url + "/socket.io/?EIO=4&transport=polling&sid=" + p.Sid
    	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(marshalPollingFrame(frame)))
    	if err != nil {
    		return err
    	}
    	req.Header.Set("Content-Type", "text/plain")

    	utils.WithContext(ctx).With(slog.Int("TYPE", int(frame.Type))).Info("Sending message: " + frame.Data.GetUTF8())

    	client := http.Client{
    		Timeout: time.Duration(p.PingTimeout+p.PingInterval) * time.Millisecond,
    	}
    	resp, err := client.Do(req)
    	if err != nil {
    		return err
    	}
    	// Close the body so the underlying connection is released.
    	defer resp.Body.Close()

    	if resp.StatusCode != http.StatusOK {
    		utils.WithContext(ctx).Warn("Received status code: " + resp.Status)
    	}
    	return nil
    }
    
    // receive returns the next frame from the server.
    //
    // Frames left over from an earlier poll are handed out first. Otherwise one
    // GET is issued against the polling endpoint (bound to ctx, limited to
    // PingTimeout+PingInterval milliseconds), every frame in the body is logged,
    // the first is returned and the rest are buffered for later calls.
    //
    // A non-200 status is only logged; the body is still parsed, so the caller
    // sees whatever frame type the body decodes to. An error is returned when the
    // request fails, the body cannot be read, or the body contains no frames
    // (the previous version panicked on an empty result).
    func (p *PollingConnection) receive(ctx context.Context) (EngineIOFrame, error) {
    	if len(p.frameBuffer) > 0 {
    		frame := p.frameBuffer[0]
    		p.frameBuffer = p.frameBuffer[1:]
    		return frame, nil
    	}

    	endpoint := p.Url + "/socket.io/?EIO=4&transport=polling&sid=" + p.Sid
    	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
    	if err != nil {
    		return EngineIOFrame{}, err
    	}
    	client := http.Client{
    		Timeout: time.Duration(p.PingTimeout+p.PingInterval) * time.Millisecond,
    	}
    	resp, err := client.Do(req)
    	if err != nil {
    		return EngineIOFrame{}, err
    	}
    	defer resp.Body.Close()

    	if resp.StatusCode != http.StatusOK {
    		utils.WithContext(ctx).Warn("Received status code: " + resp.Status)
    	}

    	var body bytes.Buffer
    	if _, err := body.ReadFrom(resp.Body); err != nil {
    		return EngineIOFrame{}, err
    	}
    	if body.Len() == 0 {
    		return EngineIOFrame{}, errors.New("polling response was empty")
    	}

    	frames := parsePollingFrame(body.Bytes())
    	if len(frames) == 0 {
    		return EngineIOFrame{}, errors.New("polling response contained no frames")
    	}
    	for _, frame := range frames {
    		utils.WithContext(ctx).With(slog.Int("TYPE", int(frame.Type))).Info("Message received: " + frame.Data.GetUTF8())
    	}
    	p.frameBuffer = frames[1:]
    	return frames[0], nil
    }
    
    // NewConnection performs the polling handshake against url and returns a
    // running polling connection, or nil when the handshake fails.
    //
    // The handshake GET is bound to ctx. It fails (nil is returned and the reason
    // is logged) when the request errors, the status is not 200, the body has no
    // frames, the first frame is not valid OpenResponse JSON, or no session ID was
    // assigned.
    //
    // On success two goroutines are started, both tied to a context derived from
    // ctx that carries the "SID" attribute:
    //   - a receive loop that answers PING with PONG (re-arming the liveness
    //     timer), forwards MESSAGE payloads to Inbound, and stops on CLOSE, on a
    //     type-75 frame, or on cancellation; it closes Inbound when it exits;
    //   - a send loop that wraps everything read from Outbound in a MESSAGE frame;
    //     it closes Outbound when the context is cancelled.
    //
    // The liveness timer cancels the context if no PING arrives within
    // PingTimeout+PingInterval milliseconds.
    //
    // NOTE(review): Outbound is closed by its reader, so a caller that sends on it
    // after the connection ended will panic; callers should watch Done().
    func NewConnection(ctx context.Context, url string) IConnection {
    	logger := utils.WithContext(ctx)

    	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url+"/socket.io/?EIO=4&transport=polling", nil)
    	if err != nil {
    		logger.Error("Handshake failed: " + err.Error())
    		return nil
    	}
    	resp, err := http.DefaultClient.Do(req)
    	if err != nil {
    		logger.Error("Handshake failed: " + err.Error())
    		return nil
    	}
    	defer resp.Body.Close()
    	if resp.StatusCode != http.StatusOK {
    		logger.Error("Handshake failed: received status code " + resp.Status)
    		return nil
    	}

    	var body bytes.Buffer
    	if _, err := body.ReadFrom(resp.Body); err != nil {
    		logger.Error("Handshake failed: " + err.Error())
    		return nil
    	}
    	if body.Len() == 0 {
    		logger.Error("Handshake failed: empty response")
    		return nil
    	}
    	frames := parsePollingFrame(body.Bytes())
    	if len(frames) == 0 {
    		logger.Error("Handshake failed: response contained no frames")
    		return nil
    	}

    	var openResponse OpenResponse
    	if err := json.Unmarshal([]byte(frames[0].Data.GetUTF8()), &openResponse); err != nil {
    		logger.Error("Handshake failed: " + err.Error())
    		return nil
    	}
    	if openResponse.Sid == "" {
    		logger.Error("Handshake failed: server did not assign a session ID")
    		return nil
    	}

    	connection := &PollingConnection{
    		Connection: Connection{
    			Url:          url,
    			Sid:          openResponse.Sid,
    			Inbound:      make(chan IEngineData),
    			Outbound:     make(chan IEngineData),
    			PingInterval: openResponse.PingInterval,
    			PingTimeout:  openResponse.PingTimeout,
    			MaxPayload:   openResponse.MaxPayload,
    		},
    	}

    	logger.Info(`Creating connection with SID: ` + openResponse.Sid)

    	ct, cancel := context.WithCancel(utils.AddAttr(ctx, slog.String("SID", openResponse.Sid)))
    	connection.Context = ct
    	connection.Cancellable = cancel

    	pingDeadline := time.Duration(connection.PingTimeout+connection.PingInterval) * time.Millisecond
    	timer := time.AfterFunc(pingDeadline, func() {
    		if ct.Err() != nil {
    			// Already closed for another reason; nothing to report.
    			return
    		}
    		utils.WithContext(ct).Error("Connection timed out")
    		cancel()
    	})

    	// Receive loop.
    	go func() {
    		defer close(connection.Inbound)
    		defer timer.Stop()
    		for ct.Err() == nil {
    			frame, err := connection.receive(ct)
    			if err != nil {
    				utils.WithContext(ct).Error("Receive failed: " + err.Error())
    				continue
    			}
    			switch frame.Type {
    			case PING:
    				timer.Reset(pingDeadline)
    				if err := connection.send(ct, EngineIOFrame{Type: PONG, Data: &EngineUTF8Data{}}); err != nil {
    					utils.WithContext(ct).Error("Pong failed: " + err.Error())
    				}
    			case MESSAGE:
    				// Do not block forever on a consumer that went away.
    				select {
    				case connection.Inbound <- frame.Data:
    				case <-ct.Done():
    					return
    				}
    			case NOOP:
    				// Nothing to do.
    			case CLOSE:
    				utils.WithContext(ct).Warn("Connection closed by server")
    				cancel()
    				return
    			case 75:
    				// 75 is '{' - '0': the body started with '{' instead of a
    				// type digit (presumably a JSON error document - confirm).
    				utils.WithContext(ct).Error("Connection Died")
    				cancel() // let Done() and the send loop observe the failure
    				return
    			default:
    			}
    		}
    	}()

    	// Send loop.
    	go func() {
    		defer close(connection.Outbound)
    		for {
    			select {
    			case <-ct.Done():
    				return
    			case data := <-connection.Outbound:
    				if err := connection.send(ct, EngineIOFrame{Type: MESSAGE, Data: data}); err != nil {
    					utils.WithContext(ct).Error("Send failed: " + err.Error())
    				}
    			}
    		}
    	}()

    	return connection
    }
    
    // WebSocketConnection is the WebSocket transport. Only send is declared for it
    // in the visible part of this file.
    type WebSocketConnection struct {
    	Connection
    	conn  *websocket.Conn
    	// mutex is held for the whole of send, so writes on conn do not overlap.
    	mutex sync.Mutex
    }
    
    // marshalPollingFunction encodes a frame as its type digit followed directly
    // by the payload, without the 'b' marker that marshalPollingFrame inserts.
    // An empty payload yields only the type digit. The payload's byte form is used
    // when GetBindata succeeds; its text form is the fallback.
    func marshalPollingFunction(frame EngineIOFrame) []byte {
    	encoded := []byte{byte(frame.Type + '0')}
    	if frame.Data.IsEmpty() {
    		return encoded
    	}
    	if payload, err := frame.Data.GetBindata(); err == nil {
    		return append(encoded, payload...)
    	}
    	return append(encoded, frame.Data.GetUTF8()...)
    }
    
    // send logs the frame's text and writes the frame, encoded by
    // marshalPollingFrame, as one WebSocket text message. The mutex is held for
    // the whole call. The write error is returned unchanged.
    func (w *WebSocketConnection) send(ctx context.Context, frame EngineIOFrame) error {
    	w.mutex.Lock()
    	defer w.mutex.Unlock()

    	logger := utils.WithContext(ctx)
    	logger.Info("Sending frame: " + frame.Data.GetUTF8())

    	payload := marshalPollingFrame(frame)
    	return w.conn.WriteMessage(websocket.TextMessage, payload)
    }
    
    // UpgradeConnection runs the WebSocket upgrade handshake for an existing
    // polling connection.
    //
    // It cancels the polling connection, dials the WebSocket endpoint for the same
    // session ID (the first "http" in the URL is replaced by "ws"), sends
    // "2probe", expects "3probe" back and finally sends "5". Every failure is
    // logged and returned, and the socket is closed on any failure after the dial
    // succeeded; previously write errors were ignored and the socket leaked.
    //
    // NOTE(review): on success this still returns (nil, nil) and the dialled
    // socket is not handed to anyone; wrapping it in a WebSocketConnection and
    // running its read/write loops is not implemented yet. The polling connection
    // is also cancelled before the dial, so it is gone even when the upgrade fails.
    func UpgradeConnection(ctx context.Context, connection *PollingConnection) (*IConnection, error) {
    	connection.Cancellable()
    	url := strings.Replace(connection.Url, "http", "ws", 1)
    	conn, _, err := websocket.DefaultDialer.DialContext(ctx, url+"/socket.io/?EIO=4&transport=websocket&sid="+connection.Sid, nil)
    	if err != nil {
    		utils.WithContext(ctx).Error("Failed to upgrade connection: " + err.Error())
    		return nil, err
    	}

    	// fail releases the socket, logs the cause and reports it to the caller.
    	fail := func(cause error) (*IConnection, error) {
    		_ = conn.Close() // best effort; the upgrade has already failed
    		utils.WithContext(ctx).Error("Failed to upgrade connection: " + cause.Error())
    		return nil, cause
    	}

    	if err := conn.WriteMessage(websocket.TextMessage, []byte("2probe")); err != nil {
    		return fail(err)
    	}
    	_, data, err := conn.ReadMessage()
    	if err != nil {
    		return fail(err)
    	}
    	if string(data) != "3probe" {
    		return fail(errors.New("Invalid response"))
    	}
    	if err := conn.WriteMessage(websocket.TextMessage, []byte("5")); err != nil {
    		return fail(err)
    	}
    	return nil, nil
    }