feat: add proper session authentication for user and integration

This commit is contained in:
Fritz Heiden 2025-03-31 13:21:06 +02:00
parent e4384bdbfb
commit 8254e12da0
8 changed files with 148 additions and 71 deletions

View File

@ -110,12 +110,7 @@ func (db *DeviceDatabase) CreateIntegration(name, token string) (string, error)
return "", err return "", err
} }
hashed_token, err := hashPassword(token) _, err = db.Connection.Exec("INSERT INTO integrations (id, name, token) VALUES (?, ?, ?)", id, name, token)
if err != nil {
return "", err
}
_, err = db.Connection.Exec("INSERT INTO integrations (id, name, token) VALUES (?, ?, ?)", id, name, hashed_token)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -163,6 +158,22 @@ func (db *DeviceDatabase) GetIntegrations() ([]Integration, error) {
return integrations, nil return integrations, nil
} }
func (db *DeviceDatabase) GetIntegrationByToken(token string) (*Integration, error) {
exists, err := db.IntegrationTokenExists(token)
if err != nil {
return nil, err
}
if !exists {
return nil, nil
}
var integration Integration
err = db.Connection.QueryRow("SELECT id, name FROM integrations WHERE token = ?", token).Scan(&integration.ID, &integration.Name)
if err != nil {
return nil, err
}
return &integration, nil
}
func (db *DeviceDatabase) DeleteIntegration(id string) error { func (db *DeviceDatabase) DeleteIntegration(id string) error {
_, err := db.Connection.Exec("DELETE FROM integrations WHERE id = ?", id) _, err := db.Connection.Exec("DELETE FROM integrations WHERE id = ?", id)
return err return err
@ -177,22 +188,13 @@ func (db *DeviceDatabase) IntegrationNameExists(name string) (bool, error) {
return exists, nil return exists, nil
} }
func (db *DeviceDatabase) GetSession(sessionToken string) (*DeviceSession, error) { func (db *DeviceDatabase) IntegrationTokenExists(token string) (bool, error) {
var session DeviceSession var exists bool
row := db.Connection.QueryRow("SELECT token, device_id, expiry_date FROM sessions WHERE token = ?", sessionToken) err := db.Connection.QueryRow("SELECT EXISTS(SELECT 1 FROM integrations WHERE token = ?)", token).Scan(&exists)
err := row.Scan(&session.Token, &session.DeviceID, &session.ExpiryDate)
if err != nil { if err != nil {
if err == sql.ErrNoRows { return false, err
return nil, nil
} }
return nil, err return exists, nil
}
return &session, nil
}
func (db *DeviceDatabase) DeleteSessionByToken(token string) error {
_, err := db.Connection.Exec("DELETE FROM sessions WHERE token = ?", token)
return err
} }
func (db *DeviceDatabase) SetDirectory(directory string) { func (db *DeviceDatabase) SetDirectory(directory string) {

View File

@ -1,9 +0,0 @@
package data
import "time"
type DeviceSession struct {
DeviceID string
Token string
ExpiryDate time.Time
}

View File

@ -86,8 +86,15 @@ func (db *UserDatabase) GetUserByUsername(username string) (*User, error) {
} }
func (db *UserDatabase) GetUserById(id string) (*User, error) { func (db *UserDatabase) GetUserById(id string) (*User, error) {
exists, err := db.UserIdExists(id)
if err != nil {
return nil, err
}
if !exists {
return nil, nil
}
var user User var user User
err := db.Connection.QueryRow("SELECT id, username, password, is_admin FROM users WHERE id = ?", id).Scan(&user.ID, &user.Username, &user.Password, &user.IsAdmin) err = db.Connection.QueryRow("SELECT id, username, password, is_admin FROM users WHERE id = ?", id).Scan(&user.ID, &user.Username, &user.Password, &user.IsAdmin)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -145,9 +152,16 @@ func (db *UserDatabase) CheckCredentials(username, password string) (bool, error
} }
func (db *UserDatabase) GetSession(sessionToken string) (*UserSession, error) { func (db *UserDatabase) GetSession(sessionToken string) (*UserSession, error) {
exists, err := db.SessionTokenExists(sessionToken)
if err != nil {
return nil, err
}
if !exists {
return nil, nil
}
var session UserSession var session UserSession
row := db.Connection.QueryRow("SELECT token, user_id, expiry_date FROM sessions WHERE token = ?", sessionToken) row := db.Connection.QueryRow("SELECT token, user_id, expiry_date FROM sessions WHERE token = ?", sessionToken)
err := row.Scan(&session.Token, &session.UserID, &session.ExpiryDate) err = row.Scan(&session.Token, &session.UserID, &session.ExpiryDate)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
@ -157,6 +171,15 @@ func (db *UserDatabase) GetSession(sessionToken string) (*UserSession, error) {
return &session, nil return &session, nil
} }
func (db *UserDatabase) SessionTokenExists(token string) (bool, error) {
var exists bool
err := db.Connection.QueryRow("SELECT EXISTS(SELECT 1 FROM sessions WHERE token = ?)", token).Scan(&exists)
if err != nil {
return false, err
}
return exists, nil
}
func (db *UserDatabase) DeleteSessionByToken(token string) error { func (db *UserDatabase) DeleteSessionByToken(token string) error {
_, err := db.Connection.Exec("DELETE FROM sessions WHERE token = ?", token) _, err := db.Connection.Exec("DELETE FROM sessions WHERE token = ?", token)
return err return err

View File

@ -59,6 +59,7 @@ func main() {
authenticator := server.Authenticator{} authenticator := server.Authenticator{}
authenticator.SetUserManager(&userManager) authenticator.SetUserManager(&userManager)
authenticator.SetDeviceManager(&deviceManager)
userApiHandler := server.UsersApiHandler{} userApiHandler := server.UsersApiHandler{}
userApiHandler.SetUserManager(&userManager) userApiHandler.SetUserManager(&userManager)

View File

@ -49,14 +49,6 @@ func (dm *DeviceManager) DeviceIdExists(id string) (bool, error) {
// return token, nil // return token, nil
//} //}
func (dm *DeviceManager) GetSession(sessionToken string) (*d.DeviceSession, error) {
session, error := dm.deviceDatabase.GetSession(sessionToken)
if error != nil {
return nil, error
}
return session, nil
}
func (dm *DeviceManager) GetDeviceById(id string) (*d.PlaybackDevice, error) { func (dm *DeviceManager) GetDeviceById(id string) (*d.PlaybackDevice, error) {
device, error := dm.deviceDatabase.GetDeviceById(id) device, error := dm.deviceDatabase.GetDeviceById(id)
if error != nil { if error != nil {
@ -70,14 +62,6 @@ func (dm *DeviceManager) UpdateDevice(device *d.PlaybackDevice) error {
return error return error
} }
func (dm *DeviceManager) DeleteSession(token string) error {
error := dm.deviceDatabase.DeleteSessionByToken(token)
if error != nil {
return error
}
return nil
}
func (dm *DeviceManager) GetDevices() (*[]d.PlaybackDevice, error) { func (dm *DeviceManager) GetDevices() (*[]d.PlaybackDevice, error) {
users, error := dm.deviceDatabase.GetDevices() users, error := dm.deviceDatabase.GetDevices()
return users, error return users, error
@ -141,6 +125,14 @@ func (dm *DeviceManager) GetIntegrations() ([]d.Integration, error) {
return integrations, nil return integrations, nil
} }
func (dm *DeviceManager) GetIntegrationByToken(token string) (*d.Integration, error) {
integration, err := dm.deviceDatabase.GetIntegrationByToken(token)
if err != nil {
return nil, err
}
return integration, nil
}
func (dm *DeviceManager) DeleteIntegration(id string) error { func (dm *DeviceManager) DeleteIntegration(id string) error {
error := dm.deviceDatabase.DeleteIntegration(id) error := dm.deviceDatabase.DeleteIntegration(id)
return error return error

View File

@ -14,16 +14,22 @@ type AuthContext struct {
echo.Context echo.Context
User *d.User User *d.User
Session *d.UserSession Session *d.UserSession
Integration *d.Integration
} }
type Authenticator struct { type Authenticator struct {
userManager *m.UserManager userManager *m.UserManager
deviceManager *m.DeviceManager
} }
func (r *Authenticator) SetUserManager(userManager *m.UserManager) { func (r *Authenticator) SetUserManager(userManager *m.UserManager) {
r.userManager = userManager r.userManager = userManager
} }
func (r *Authenticator) SetDeviceManager(deviceManager *m.DeviceManager) {
r.deviceManager = deviceManager
}
func (r *Authenticator) Authenticate(path string, exceptions []string) func(next echo.HandlerFunc) echo.HandlerFunc { func (r *Authenticator) Authenticate(path string, exceptions []string) func(next echo.HandlerFunc) echo.HandlerFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(context echo.Context) error { return func(context echo.Context) error {
@ -37,30 +43,61 @@ func (r *Authenticator) Authenticate(path string, exceptions []string) func(next
} }
} }
cookie, err := context.Cookie("token") cookie, err := context.Cookie("token")
fmt.Println(context.Cookies())
if err != nil { if err != nil {
SendError(401, context, "no session token found") SendError(401, context, "no cookie for session token found")
return err return err
} }
session, error := r.userManager.GetSession(cookie.Value) token := cookie.Value
if error != nil || session == nil { user, session, error := r.getUserAndSession(token)
SendError(401, context, fmt.Sprintf("session not found: %s", cookie.Value))
return fmt.Errorf("session not found: %s", cookie.Value)
}
user, error := r.userManager.GetUserById(session.UserID)
if error != nil { if error != nil {
log.Error().Err(error).Msg("error getting user by id") log.Error().Err(error).Msg("error authenticating user")
SendError(401, context, "no user found for given session") SendError(500, context, fmt.Sprintf("error authenticating user: %s", error))
return error return error
} }
if user == nil {
SendError(401, context, "no user found for given session") integration, error := r.getIntegration(token)
return fmt.Errorf("no user found for session '%s'", cookie.Value) if error != nil {
log.Error().Err(error).Msg("error getting integration")
SendError(500, context, fmt.Sprintf("error getting integration: %s", error))
return error
} }
authContext := AuthContext{Context: context, User: user, Session: session} if integration == nil && user == nil {
log.Error().Msg("no integration or user found for given token")
SendError(401, context, "no integration or user found for given token")
return fmt.Errorf("no integration or user found for given token")
}
fmt.Println("user:", user, "session:", session, "integration:", integration)
authContext := AuthContext{Context: context, User: user, Session: session, Integration: integration}
return next(authContext) return next(authContext)
} }
} }
} }
func (r *Authenticator) getUserAndSession(token string) (*d.User, *d.UserSession, error) {
session, error := r.userManager.GetSession(token)
if error != nil {
return nil, nil, error
}
if session == nil {
return nil, nil, nil
}
user, error := r.userManager.GetUserById(session.UserID)
if error != nil {
return nil, nil, error
}
return user, session, nil
}
func (r *Authenticator) getIntegration(token string) (*d.Integration, error) {
integration, error := r.deviceManager.GetIntegrationByToken(token)
if error != nil {
return nil, error
}
return integration, nil
}

View File

@ -2,6 +2,7 @@ package server
import ( import (
"fmt" "fmt"
"net/http"
d "playback-device-server/data" d "playback-device-server/data"
m "playback-device-server/management" m "playback-device-server/management"
@ -23,15 +24,17 @@ func (r *DeviceApiHandler) Initialize(authenticator *Authenticator) {
devicesApi.GET("", r.handleGetDevices) devicesApi.GET("", r.handleGetDevices)
devicesApi.POST("", r.handleCreateDevice) devicesApi.POST("", r.handleCreateDevice)
devicesApi.DELETE("/:id", r.handleDeleteDevice) devicesApi.DELETE("/:id", r.handleDeleteDevice)
r.router.Use(authenticator.Authenticate("/api/integrations", []string{"/api/integrations/register"}))
integrationsApi := r.router.Group("/api/integrations") integrationsApi := r.router.Group("/api/integrations")
integrationsApi.GET("/register", r.handleIntegrationRegistration) integrationsApi.POST("/register", r.handleIntegrationRegistration)
integrationsApi.POST("", r.handleCreateIntegration) integrationsApi.POST("", r.handleCreateIntegration)
integrationsApi.GET("", r.handleGetIntegrations) integrationsApi.GET("", r.handleGetIntegrations)
integrationsApi.GET("/:id", r.handleGetIntegration) integrationsApi.GET("/:id", r.handleGetIntegration)
integrationsApi.DELETE("/:id", r.handleDeleteIntegration) integrationsApi.DELETE("/:id", r.handleDeleteIntegration)
} }
func (r *DeviceApiHandler) handleIntegrationRegistration(context echo.Context) error { func (r *DeviceApiHandler) handleCreateIntegration(context echo.Context) error {
code, error := r.deviceManager.GetRegistrationCode() code, error := r.deviceManager.GetRegistrationCode()
if error != nil { if error != nil {
@ -48,7 +51,7 @@ func (r *DeviceApiHandler) handleIntegrationRegistration(context echo.Context) e
return context.JSON(200, response) return context.JSON(200, response)
} }
func (r *DeviceApiHandler) handleCreateIntegration(context echo.Context) error { func (r *DeviceApiHandler) handleIntegrationRegistration(context echo.Context) error {
var data struct { var data struct {
Name string `json:"name"` Name string `json:"name"`
Code string `json:"code"` Code string `json:"code"`
@ -77,6 +80,16 @@ func (r *DeviceApiHandler) handleCreateIntegration(context echo.Context) error {
Token: integration.Token, Token: integration.Token,
} }
thirdyDays := 30 * 24 * 60 * 60
cookie := &http.Cookie{
Name: "token",
Value: integration.Token,
Path: "/",
HttpOnly: true,
MaxAge: thirdyDays,
}
context.SetCookie(cookie)
return context.JSON(200, integrationData) return context.JSON(200, integrationData)
} }

View File

@ -12,20 +12,27 @@ var upgrader = websocket.Upgrader{}
type WebsocketServer struct { type WebsocketServer struct {
router *echo.Echo router *echo.Echo
sockets map[string]*websocket.Conn
} }
func (s *WebsocketServer) Initialize(authenticator *Authenticator) { func (s *WebsocketServer) Initialize(authenticator *Authenticator) {
//s.router.Use(authenticator.Authenticate("/ws/terminal", []string{})) s.sockets = make(map[string]*websocket.Conn)
s.router.Use(authenticator.Authenticate("/ws", []string{}))
s.router.GET("/ws", s.handle) s.router.GET("/ws", s.handle)
} }
func (s *WebsocketServer) handle(context echo.Context) error { func (s *WebsocketServer) handle(context echo.Context) error {
//authContext := context.(AuthContext) authContext := context.(AuthContext)
senderId := getAuthenticatedId(authContext)
ws, err := upgrader.Upgrade(context.Response(), context.Request(), nil) ws, err := upgrader.Upgrade(context.Response(), context.Request(), nil)
s.sockets[senderId] = ws
if err != nil { if err != nil {
return err return err
} }
defer ws.Close() defer func() {
ws.Close()
delete(s.sockets, senderId)
}()
for { for {
messageType, messageBytes, err := ws.ReadMessage() messageType, messageBytes, err := ws.ReadMessage()
@ -35,11 +42,22 @@ func (s *WebsocketServer) handle(context echo.Context) error {
if messageType == websocket.TextMessage { if messageType == websocket.TextMessage {
var messageObject map[string]any var messageObject map[string]any
json.Unmarshal(messageBytes, &messageObject) json.Unmarshal(messageBytes, &messageObject)
fmt.Println("Received message from authenticated user", senderId)
fmt.Println(messageObject) fmt.Println(messageObject)
} }
} }
} }
func getAuthenticatedId(authContext AuthContext) string {
if authContext.User != nil {
return authContext.User.ID
}
if authContext.Integration != nil {
return authContext.Integration.ID
}
return ""
}
func (s *WebsocketServer) SetRouter(router *echo.Echo) { func (s *WebsocketServer) SetRouter(router *echo.Echo) {
s.router = router s.router = router
} }