feat: add thread safety to web socket list

This commit is contained in:
Fritz Heiden 2025-04-06 18:47:51 +02:00
parent 4e639fb387
commit 6d2fca696e

View File

@ -2,6 +2,7 @@ package server
import ( import (
"encoding/json" "encoding/json"
"sync"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
@ -14,12 +15,14 @@ type WebsocketHandler interface {
} }
type WebsocketServer struct { type WebsocketServer struct {
router *echo.Echo router *echo.Echo
sockets map[string]*websocket.Conn sockets map[string]*websocket.Conn
handlers []WebsocketHandler handlers []WebsocketHandler
socketsMutex *sync.Mutex
} }
func (s *WebsocketServer) Initialize(authenticator *Authenticator) { func (s *WebsocketServer) Initialize(authenticator *Authenticator) {
s.socketsMutex = &sync.Mutex{}
s.sockets = make(map[string]*websocket.Conn) s.sockets = make(map[string]*websocket.Conn)
s.router.Use(authenticator.Authenticate("/ws", []string{})) s.router.Use(authenticator.Authenticate("/ws", []string{}))
s.router.GET("/ws", s.handle) s.router.GET("/ws", s.handle)
@ -29,13 +32,17 @@ func (s *WebsocketServer) handle(context echo.Context) error {
authContext := context.(AuthContext) authContext := context.(AuthContext)
senderId := getAuthenticatedId(authContext) senderId := getAuthenticatedId(authContext)
ws, err := upgrader.Upgrade(context.Response(), context.Request(), nil) ws, err := upgrader.Upgrade(context.Response(), context.Request(), nil)
s.socketsMutex.Lock()
s.sockets[senderId] = ws s.sockets[senderId] = ws
s.socketsMutex.Unlock()
if err != nil { if err != nil {
return err return err
} }
defer func() { defer func() {
ws.Close() ws.Close()
s.socketsMutex.Lock()
delete(s.sockets, senderId) delete(s.sockets, senderId)
s.socketsMutex.Unlock()
}() }()
for { for {