From 6d2fca696e40ca9a99be99a2e748d7cd8596d3c8 Mon Sep 17 00:00:00 2001 From: Fritz Heiden Date: Sun, 6 Apr 2025 18:47:51 +0200 Subject: [PATCH] feat: add thread safety to web socket list --- server/websocket_server.go | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/server/websocket_server.go b/server/websocket_server.go index 93bbc79..f2dd4ce 100644 --- a/server/websocket_server.go +++ b/server/websocket_server.go @@ -2,6 +2,7 @@ package server import ( "encoding/json" + "sync" "github.com/gorilla/websocket" "github.com/labstack/echo/v4" @@ -14,12 +15,14 @@ type WebsocketHandler interface { } type WebsocketServer struct { - router *echo.Echo - sockets map[string]*websocket.Conn - handlers []WebsocketHandler + router *echo.Echo + sockets map[string]*websocket.Conn + handlers []WebsocketHandler + socketsMutex *sync.Mutex } func (s *WebsocketServer) Initialize(authenticator *Authenticator) { + s.socketsMutex = &sync.Mutex{} s.sockets = make(map[string]*websocket.Conn) s.router.Use(authenticator.Authenticate("/ws", []string{})) s.router.GET("/ws", s.handle) @@ -29,13 +32,17 @@ func (s *WebsocketServer) handle(context echo.Context) error { authContext := context.(AuthContext) senderId := getAuthenticatedId(authContext) ws, err := upgrader.Upgrade(context.Response(), context.Request(), nil) + s.socketsMutex.Lock() s.sockets[senderId] = ws + s.socketsMutex.Unlock() if err != nil { return err } defer func() { ws.Close() + s.socketsMutex.Lock() delete(s.sockets, senderId) + s.socketsMutex.Unlock() }() for {