From 8254e12da0622954b234bdf8d4b51255563cc9cc Mon Sep 17 00:00:00 2001 From: Fritz Heiden Date: Mon, 31 Mar 2025 13:21:06 +0200 Subject: [PATCH] feat: add proper session authentication for user and integration --- data/device_database.go | 42 +++++++++++---------- data/device_session.go | 9 ----- data/user_database.go | 27 +++++++++++++- main/main.go | 1 + management/device_manager.go | 24 ++++-------- server/authenticator.go | 71 +++++++++++++++++++++++++++--------- server/device_api_handler.go | 19 ++++++++-- server/websocket_server.go | 26 +++++++++++-- 8 files changed, 148 insertions(+), 71 deletions(-) delete mode 100644 data/device_session.go diff --git a/data/device_database.go b/data/device_database.go index 09982a9..224942a 100644 --- a/data/device_database.go +++ b/data/device_database.go @@ -110,12 +110,7 @@ func (db *DeviceDatabase) CreateIntegration(name, token string) (string, error) return "", err } - hashed_token, err := hashPassword(token) - if err != nil { - return "", err - } - - _, err = db.Connection.Exec("INSERT INTO integrations (id, name, token) VALUES (?, ?, ?)", id, name, hashed_token) + _, err = db.Connection.Exec("INSERT INTO integrations (id, name, token) VALUES (?, ?, ?)", id, name, token) if err != nil { return "", err } @@ -163,6 +158,22 @@ func (db *DeviceDatabase) GetIntegrations() ([]Integration, error) { return integrations, nil } +func (db *DeviceDatabase) GetIntegrationByToken(token string) (*Integration, error) { + exists, err := db.IntegrationTokenExists(token) + if err != nil { + return nil, err + } + if !exists { + return nil, nil + } + var integration Integration + err = db.Connection.QueryRow("SELECT id, name FROM integrations WHERE token = ?", token).Scan(&integration.ID, &integration.Name) + if err != nil { + return nil, err + } + return &integration, nil +} + func (db *DeviceDatabase) DeleteIntegration(id string) error { _, err := db.Connection.Exec("DELETE FROM integrations WHERE id = ?", id) return err @@ -177,22 +188,13 @@ func (db *DeviceDatabase) IntegrationNameExists(name string) (bool, error) { return exists, nil } -func (db *DeviceDatabase) GetSession(sessionToken string) (*DeviceSession, error) { - var session DeviceSession - row := db.Connection.QueryRow("SELECT token, device_id, expiry_date FROM sessions WHERE token = ?", sessionToken) - err := row.Scan(&session.Token, &session.DeviceID, &session.ExpiryDate) +func (db *DeviceDatabase) IntegrationTokenExists(token string) (bool, error) { + var exists bool + err := db.Connection.QueryRow("SELECT EXISTS(SELECT 1 FROM integrations WHERE token = ?)", token).Scan(&exists) if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, err + return false, err } - return &session, nil -} - -func (db *DeviceDatabase) DeleteSessionByToken(token string) error { - _, err := db.Connection.Exec("DELETE FROM sessions WHERE token = ?", token) - return err + return exists, nil } func (db *DeviceDatabase) SetDirectory(directory string) { diff --git a/data/device_session.go b/data/device_session.go deleted file mode 100644 index 2f55a33..0000000 --- a/data/device_session.go +++ /dev/null @@ -1,9 +0,0 @@ -package data - -import "time" - -type DeviceSession struct { - DeviceID string - Token string - ExpiryDate time.Time -} diff --git a/data/user_database.go b/data/user_database.go index 6348bf5..e5472b9 100644 --- a/data/user_database.go +++ b/data/user_database.go @@ -86,8 +86,15 @@ func (db *UserDatabase) GetUserByUsername(username string) (*User, error) { } func (db *UserDatabase) GetUserById(id string) (*User, error) { + exists, err := db.UserIdExists(id) + if err != nil { + return nil, err + } + if !exists { + return nil, nil + } var user User - err := db.Connection.QueryRow("SELECT id, username, password, is_admin FROM users WHERE id = ?", id).Scan(&user.ID, &user.Username, &user.Password, &user.IsAdmin) + err = db.Connection.QueryRow("SELECT id, username, password, is_admin FROM users WHERE id = ?", id).Scan(&user.ID, &user.Username, &user.Password, &user.IsAdmin) if err != nil { return nil, err } @@ -145,9 +152,16 @@ func (db *UserDatabase) CheckCredentials(username, password string) (bool, error } func (db *UserDatabase) GetSession(sessionToken string) (*UserSession, error) { + exists, err := db.SessionTokenExists(sessionToken) + if err != nil { + return nil, err + } + if !exists { + return nil, nil + } var session UserSession row := db.Connection.QueryRow("SELECT token, user_id, expiry_date FROM sessions WHERE token = ?", sessionToken) - err := row.Scan(&session.Token, &session.UserID, &session.ExpiryDate) + err = row.Scan(&session.Token, &session.UserID, &session.ExpiryDate) if err != nil { if err == sql.ErrNoRows { return nil, nil @@ -157,6 +171,15 @@ func (db *UserDatabase) GetSession(sessionToken string) (*UserSession, error) { return &session, nil } +func (db *UserDatabase) SessionTokenExists(token string) (bool, error) { + var exists bool + err := db.Connection.QueryRow("SELECT EXISTS(SELECT 1 FROM sessions WHERE token = ?)", token).Scan(&exists) + if err != nil { + return false, err + } + return exists, nil +} + func (db *UserDatabase) DeleteSessionByToken(token string) error { _, err := db.Connection.Exec("DELETE FROM sessions WHERE token = ?", token) return err diff --git a/main/main.go b/main/main.go index 1da68d7..7e56730 100644 --- a/main/main.go +++ b/main/main.go @@ -59,6 +59,7 @@ func main() { authenticator := server.Authenticator{} authenticator.SetUserManager(&userManager) + authenticator.SetDeviceManager(&deviceManager) userApiHandler := server.UsersApiHandler{} userApiHandler.SetUserManager(&userManager) diff --git a/management/device_manager.go b/management/device_manager.go index dc4ec98..5ca16fc 100644 --- a/management/device_manager.go +++ b/management/device_manager.go @@ -49,14 +49,6 @@ func (dm *DeviceManager) DeviceIdExists(id string) (bool, error) { // return token, nil //} -func (dm *DeviceManager) GetSession(sessionToken string) (*d.DeviceSession, error) { - session, error := dm.deviceDatabase.GetSession(sessionToken) - if error != nil { - return nil, error - } - return session, nil -} - func (dm *DeviceManager) GetDeviceById(id string) (*d.PlaybackDevice, error) { device, error := dm.deviceDatabase.GetDeviceById(id) if error != nil { @@ -70,14 +62,6 @@ func (dm *DeviceManager) UpdateDevice(device *d.PlaybackDevice) error { return error } -func (dm *DeviceManager) DeleteSession(token string) error { - error := dm.deviceDatabase.DeleteSessionByToken(token) - if error != nil { - return error - } - return nil -} - func (dm *DeviceManager) GetDevices() (*[]d.PlaybackDevice, error) { users, error := dm.deviceDatabase.GetDevices() return users, error @@ -141,6 +125,14 @@ func (dm *DeviceManager) GetIntegrations() ([]d.Integration, error) { return integrations, nil } +func (dm *DeviceManager) GetIntegrationByToken(token string) (*d.Integration, error) { + integration, err := dm.deviceDatabase.GetIntegrationByToken(token) + if err != nil { + return nil, err + } + return integration, nil +} + func (dm *DeviceManager) DeleteIntegration(id string) error { error := dm.deviceDatabase.DeleteIntegration(id) return error diff --git a/server/authenticator.go b/server/authenticator.go index a1bec72..551ac1f 100644 --- a/server/authenticator.go +++ b/server/authenticator.go @@ -12,18 +12,24 @@ import ( type AuthContext struct { echo.Context - User *d.User - Session *d.UserSession + User *d.User + Session *d.UserSession + Integration *d.Integration } type Authenticator struct { - userManager *m.UserManager + userManager *m.UserManager + deviceManager *m.DeviceManager } func (r *Authenticator) SetUserManager(userManager *m.UserManager) { r.userManager = userManager } +func (r *Authenticator) SetDeviceManager(deviceManager *m.DeviceManager) { + r.deviceManager = deviceManager +} + func (r *Authenticator) Authenticate(path string, exceptions []string) func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(context echo.Context) error { @@ -37,30 +43,61 @@ func (r *Authenticator) Authenticate(path string, exceptions []string) func(next } } cookie, err := context.Cookie("token") + fmt.Println(context.Cookies()) if err != nil { - SendError(401, context, "no session token found") + SendError(401, context, "no cookie for session token found") return err } - session, error := r.userManager.GetSession(cookie.Value) - if error != nil || session == nil { - SendError(401, context, fmt.Sprintf("session not found: %s", cookie.Value)) - return fmt.Errorf("session not found: %s", cookie.Value) - } - - user, error := r.userManager.GetUserById(session.UserID) + token := cookie.Value + user, session, error := r.getUserAndSession(token) if error != nil { - log.Error().Err(error).Msg("error getting user by id") - SendError(401, context, "no user found for given session") + log.Error().Err(error).Msg("error authenticating user") + SendError(500, context, fmt.Sprintf("error authenticating user: %s", error)) return error } - if user == nil { - SendError(401, context, "no user found for given session") - return fmt.Errorf("no user found for session '%s'", cookie.Value) + + integration, error := r.getIntegration(token) + if error != nil { + log.Error().Err(error).Msg("error getting integration") + SendError(500, context, fmt.Sprintf("error getting integration: %s", error)) + return error } - authContext := AuthContext{Context: context, User: user, Session: session} + if integration == nil && user == nil { + log.Error().Msg("no integration or user found for given token") + SendError(401, context, "no integration or user found for given token") + return fmt.Errorf("no integration or user found for given token") + } + + fmt.Println("user:", user, "session:", session, "integration:", integration) + + authContext := AuthContext{Context: context, User: user, Session: session, Integration: integration} return next(authContext) } } } + +func (r *Authenticator) getUserAndSession(token string) (*d.User, *d.UserSession, error) { + session, error := r.userManager.GetSession(token) + if error != nil { + return nil, nil, error + } + if session == nil { + return nil, nil, nil + } + + user, error := r.userManager.GetUserById(session.UserID) + if error != nil { + return nil, nil, error + } + return user, session, nil +} + +func (r *Authenticator) getIntegration(token string) (*d.Integration, error) { + integration, error := r.deviceManager.GetIntegrationByToken(token) + if error != nil { + return nil, error + } + return integration, nil +} diff --git a/server/device_api_handler.go b/server/device_api_handler.go index e0bbc38..4ce510e 100644 --- a/server/device_api_handler.go +++ b/server/device_api_handler.go @@ -2,6 +2,7 @@ package server import ( "fmt" + "net/http" d "playback-device-server/data" m "playback-device-server/management" @@ -23,15 +24,17 @@ func (r *DeviceApiHandler) Initialize(authenticator *Authenticator) { devicesApi.GET("", r.handleGetDevices) devicesApi.POST("", r.handleCreateDevice) devicesApi.DELETE("/:id", r.handleDeleteDevice) + + r.router.Use(authenticator.Authenticate("/api/integrations", []string{"/api/integrations/register"})) integrationsApi := r.router.Group("/api/integrations") - integrationsApi.GET("/register", r.handleIntegrationRegistration) + integrationsApi.POST("/register", r.handleIntegrationRegistration) integrationsApi.POST("", r.handleCreateIntegration) integrationsApi.GET("", r.handleGetIntegrations) integrationsApi.GET("/:id", r.handleGetIntegration) integrationsApi.DELETE("/:id", r.handleDeleteIntegration) } -func (r *DeviceApiHandler) handleIntegrationRegistration(context echo.Context) error { +func (r *DeviceApiHandler) handleCreateIntegration(context echo.Context) error { code, error := r.deviceManager.GetRegistrationCode() if error != nil { @@ -48,7 +51,7 @@ func (r *DeviceApiHandler) handleIntegrationRegistration(context echo.Context) e return context.JSON(200, response) } -func (r *DeviceApiHandler) handleCreateIntegration(context echo.Context) error { +func (r *DeviceApiHandler) handleIntegrationRegistration(context echo.Context) error { var data struct { Name string `json:"name"` Code string `json:"code"` @@ -77,6 +80,16 @@ func (r *DeviceApiHandler) handleCreateIntegration(context echo.Context) error { Token: integration.Token, } + thirdyDays := 30 * 24 * 60 * 60 + cookie := &http.Cookie{ + Name: "token", + Value: integration.Token, + Path: "/", + HttpOnly: true, + MaxAge: thirdyDays, + } + context.SetCookie(cookie) + return context.JSON(200, integrationData) } diff --git a/server/websocket_server.go b/server/websocket_server.go index 869bc01..023263b 100644 --- a/server/websocket_server.go +++ b/server/websocket_server.go @@ -11,21 +11,28 @@ import ( var upgrader = websocket.Upgrader{} type WebsocketServer struct { - router *echo.Echo + router *echo.Echo + sockets map[string]*websocket.Conn } func (s *WebsocketServer) Initialize(authenticator *Authenticator) { - //s.router.Use(authenticator.Authenticate("/ws/terminal", []string{})) + s.sockets = make(map[string]*websocket.Conn) + s.router.Use(authenticator.Authenticate("/ws", []string{})) s.router.GET("/ws", s.handle) } func (s *WebsocketServer) handle(context echo.Context) error { - //authContext := context.(AuthContext) + authContext := context.(AuthContext) + senderId := getAuthenticatedId(authContext) ws, err := upgrader.Upgrade(context.Response(), context.Request(), nil) + s.sockets[senderId] = ws if err != nil { return err } - defer ws.Close() + defer func() { + ws.Close() + delete(s.sockets, senderId) + }() for { messageType, messageBytes, err := ws.ReadMessage() @@ -35,11 +42,22 @@ func (s *WebsocketServer) handle(context echo.Context) error { if messageType == websocket.TextMessage { var messageObject map[string]any json.Unmarshal(messageBytes, &messageObject) + fmt.Println("Received message from authenticated user", senderId) fmt.Println(messageObject) } } } +func getAuthenticatedId(authContext AuthContext) string { + if authContext.User != nil { + return authContext.User.ID + } + if authContext.Integration != nil { + return authContext.Integration.ID + } + return "" +} + func (s *WebsocketServer) SetRouter(router *echo.Echo) { s.router = router }