feat: add proper session authentication for user and integration

This commit is contained in:
Fritz Heiden 2025-03-31 13:21:06 +02:00
parent e4384bdbfb
commit 811e348697
10 changed files with 154 additions and 74 deletions

View File

@ -110,12 +110,7 @@ func (db *DeviceDatabase) CreateIntegration(name, token string) (string, error)
return "", err
}
hashed_token, err := hashPassword(token)
if err != nil {
return "", err
}
_, err = db.Connection.Exec("INSERT INTO integrations (id, name, token) VALUES (?, ?, ?)", id, name, hashed_token)
_, err = db.Connection.Exec("INSERT INTO integrations (id, name, token) VALUES (?, ?, ?)", id, name, token)
if err != nil {
return "", err
}
@ -163,6 +158,22 @@ func (db *DeviceDatabase) GetIntegrations() ([]Integration, error) {
return integrations, nil
}
func (db *DeviceDatabase) GetIntegrationByToken(token string) (*Integration, error) {
exists, err := db.IntegrationTokenExists(token)
if err != nil {
return nil, err
}
if !exists {
return nil, nil
}
var integration Integration
err = db.Connection.QueryRow("SELECT id, name FROM integrations WHERE token = ?", token).Scan(&integration.ID, &integration.Name)
if err != nil {
return nil, err
}
return &integration, nil
}
func (db *DeviceDatabase) DeleteIntegration(id string) error {
_, err := db.Connection.Exec("DELETE FROM integrations WHERE id = ?", id)
return err
@ -177,22 +188,13 @@ func (db *DeviceDatabase) IntegrationNameExists(name string) (bool, error) {
return exists, nil
}
func (db *DeviceDatabase) GetSession(sessionToken string) (*DeviceSession, error) {
var session DeviceSession
row := db.Connection.QueryRow("SELECT token, device_id, expiry_date FROM sessions WHERE token = ?", sessionToken)
err := row.Scan(&session.Token, &session.DeviceID, &session.ExpiryDate)
func (db *DeviceDatabase) IntegrationTokenExists(token string) (bool, error) {
var exists bool
err := db.Connection.QueryRow("SELECT EXISTS(SELECT 1 FROM integrations WHERE token = ?)", token).Scan(&exists)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
return false, err
}
return &session, nil
}
func (db *DeviceDatabase) DeleteSessionByToken(token string) error {
_, err := db.Connection.Exec("DELETE FROM sessions WHERE token = ?", token)
return err
return exists, nil
}
func (db *DeviceDatabase) SetDirectory(directory string) {

View File

@ -1,9 +0,0 @@
package data
import "time"
type DeviceSession struct {
DeviceID string
Token string
ExpiryDate time.Time
}

View File

@ -86,8 +86,15 @@ func (db *UserDatabase) GetUserByUsername(username string) (*User, error) {
}
func (db *UserDatabase) GetUserById(id string) (*User, error) {
exists, err := db.UserIdExists(id)
if err != nil {
return nil, err
}
if !exists {
return nil, nil
}
var user User
err := db.Connection.QueryRow("SELECT id, username, password, is_admin FROM users WHERE id = ?", id).Scan(&user.ID, &user.Username, &user.Password, &user.IsAdmin)
err = db.Connection.QueryRow("SELECT id, username, password, is_admin FROM users WHERE id = ?", id).Scan(&user.ID, &user.Username, &user.Password, &user.IsAdmin)
if err != nil {
return nil, err
}
@ -145,9 +152,16 @@ func (db *UserDatabase) CheckCredentials(username, password string) (bool, error
}
func (db *UserDatabase) GetSession(sessionToken string) (*UserSession, error) {
exists, err := db.SessionTokenExists(sessionToken)
if err != nil {
return nil, err
}
if !exists {
return nil, nil
}
var session UserSession
row := db.Connection.QueryRow("SELECT token, user_id, expiry_date FROM sessions WHERE token = ?", sessionToken)
err := row.Scan(&session.Token, &session.UserID, &session.ExpiryDate)
err = row.Scan(&session.Token, &session.UserID, &session.ExpiryDate)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
@ -157,6 +171,15 @@ func (db *UserDatabase) GetSession(sessionToken string) (*UserSession, error) {
return &session, nil
}
func (db *UserDatabase) SessionTokenExists(token string) (bool, error) {
var exists bool
err := db.Connection.QueryRow("SELECT EXISTS(SELECT 1 FROM sessions WHERE token = ?)", token).Scan(&exists)
if err != nil {
return false, err
}
return exists, nil
}
func (db *UserDatabase) DeleteSessionByToken(token string) error {
_, err := db.Connection.Exec("DELETE FROM sessions WHERE token = ?", token)
return err

View File

@ -59,6 +59,7 @@ func main() {
authenticator := server.Authenticator{}
authenticator.SetUserManager(&userManager)
authenticator.SetDeviceManager(&deviceManager)
userApiHandler := server.UsersApiHandler{}
userApiHandler.SetUserManager(&userManager)

View File

@ -49,14 +49,6 @@ func (dm *DeviceManager) DeviceIdExists(id string) (bool, error) {
// return token, nil
//}
func (dm *DeviceManager) GetSession(sessionToken string) (*d.DeviceSession, error) {
session, error := dm.deviceDatabase.GetSession(sessionToken)
if error != nil {
return nil, error
}
return session, nil
}
func (dm *DeviceManager) GetDeviceById(id string) (*d.PlaybackDevice, error) {
device, error := dm.deviceDatabase.GetDeviceById(id)
if error != nil {
@ -70,14 +62,6 @@ func (dm *DeviceManager) UpdateDevice(device *d.PlaybackDevice) error {
return error
}
func (dm *DeviceManager) DeleteSession(token string) error {
error := dm.deviceDatabase.DeleteSessionByToken(token)
if error != nil {
return error
}
return nil
}
func (dm *DeviceManager) GetDevices() (*[]d.PlaybackDevice, error) {
users, error := dm.deviceDatabase.GetDevices()
return users, error
@ -141,6 +125,14 @@ func (dm *DeviceManager) GetIntegrations() ([]d.Integration, error) {
return integrations, nil
}
func (dm *DeviceManager) GetIntegrationByToken(token string) (*d.Integration, error) {
integration, err := dm.deviceDatabase.GetIntegrationByToken(token)
if err != nil {
return nil, err
}
return integration, nil
}
func (dm *DeviceManager) DeleteIntegration(id string) error {
error := dm.deviceDatabase.DeleteIntegration(id)
return error

View File

@ -12,18 +12,24 @@ import (
type AuthContext struct {
echo.Context
User *d.User
Session *d.UserSession
User *d.User
Session *d.UserSession
Integration *d.Integration
}
type Authenticator struct {
userManager *m.UserManager
userManager *m.UserManager
deviceManager *m.DeviceManager
}
func (r *Authenticator) SetUserManager(userManager *m.UserManager) {
r.userManager = userManager
}
func (r *Authenticator) SetDeviceManager(deviceManager *m.DeviceManager) {
r.deviceManager = deviceManager
}
func (r *Authenticator) Authenticate(path string, exceptions []string) func(next echo.HandlerFunc) echo.HandlerFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(context echo.Context) error {
@ -37,30 +43,61 @@ func (r *Authenticator) Authenticate(path string, exceptions []string) func(next
}
}
cookie, err := context.Cookie("token")
fmt.Println(context.Cookies())
if err != nil {
SendError(401, context, "no session token found")
SendError(401, context, "no cookie for session token found")
return err
}
session, error := r.userManager.GetSession(cookie.Value)
if error != nil || session == nil {
SendError(401, context, fmt.Sprintf("session not found: %s", cookie.Value))
return fmt.Errorf("session not found: %s", cookie.Value)
}
user, error := r.userManager.GetUserById(session.UserID)
token := cookie.Value
user, session, error := r.getUserAndSession(token)
if error != nil {
log.Error().Err(error).Msg("error getting user by id")
SendError(401, context, "no user found for given session")
log.Error().Err(error).Msg("error authenticating user")
SendError(500, context, fmt.Sprintf("error authenticating user: %s", error))
return error
}
if user == nil {
SendError(401, context, "no user found for given session")
return fmt.Errorf("no user found for session '%s'", cookie.Value)
integration, error := r.getIntegration(token)
if error != nil {
log.Error().Err(error).Msg("error getting integration")
SendError(500, context, fmt.Sprintf("error getting integration: %s", error))
return error
}
authContext := AuthContext{Context: context, User: user, Session: session}
if integration == nil && user == nil {
log.Error().Msg("no integration or user found for given token")
SendError(401, context, "no integration or user found for given token")
return fmt.Errorf("no integration or user found for given token")
}
fmt.Println("user:", user, "session:", session, "integration:", integration)
authContext := AuthContext{Context: context, User: user, Session: session, Integration: integration}
return next(authContext)
}
}
}
func (r *Authenticator) getUserAndSession(token string) (*d.User, *d.UserSession, error) {
session, error := r.userManager.GetSession(token)
if error != nil {
return nil, nil, error
}
if session == nil {
return nil, nil, nil
}
user, error := r.userManager.GetUserById(session.UserID)
if error != nil {
return nil, nil, error
}
return user, session, nil
}
func (r *Authenticator) getIntegration(token string) (*d.Integration, error) {
integration, error := r.deviceManager.GetIntegrationByToken(token)
if error != nil {
return nil, error
}
return integration, nil
}

View File

@ -2,6 +2,7 @@ package server
import (
"fmt"
"net/http"
d "playback-device-server/data"
m "playback-device-server/management"
@ -23,15 +24,17 @@ func (r *DeviceApiHandler) Initialize(authenticator *Authenticator) {
devicesApi.GET("", r.handleGetDevices)
devicesApi.POST("", r.handleCreateDevice)
devicesApi.DELETE("/:id", r.handleDeleteDevice)
r.router.Use(authenticator.Authenticate("/api/integrations", []string{"/api/integrations/register"}))
integrationsApi := r.router.Group("/api/integrations")
integrationsApi.GET("/register", r.handleIntegrationRegistration)
integrationsApi.POST("/register", r.handleIntegrationRegistration)
integrationsApi.POST("", r.handleCreateIntegration)
integrationsApi.GET("", r.handleGetIntegrations)
integrationsApi.GET("/:id", r.handleGetIntegration)
integrationsApi.DELETE("/:id", r.handleDeleteIntegration)
}
func (r *DeviceApiHandler) handleIntegrationRegistration(context echo.Context) error {
func (r *DeviceApiHandler) handleCreateIntegration(context echo.Context) error {
code, error := r.deviceManager.GetRegistrationCode()
if error != nil {
@ -48,7 +51,7 @@ func (r *DeviceApiHandler) handleIntegrationRegistration(context echo.Context) e
return context.JSON(200, response)
}
func (r *DeviceApiHandler) handleCreateIntegration(context echo.Context) error {
func (r *DeviceApiHandler) handleIntegrationRegistration(context echo.Context) error {
var data struct {
Name string `json:"name"`
Code string `json:"code"`
@ -77,6 +80,16 @@ func (r *DeviceApiHandler) handleCreateIntegration(context echo.Context) error {
Token: integration.Token,
}
thirdyDays := 30 * 24 * 60 * 60
cookie := &http.Cookie{
Name: "token",
Value: integration.Token,
Path: "/",
HttpOnly: true,
MaxAge: thirdyDays,
}
context.SetCookie(cookie)
return context.JSON(200, integrationData)
}

View File

@ -11,21 +11,28 @@ import (
var upgrader = websocket.Upgrader{}
type WebsocketServer struct {
router *echo.Echo
router *echo.Echo
sockets map[string]*websocket.Conn
}
func (s *WebsocketServer) Initialize(authenticator *Authenticator) {
//s.router.Use(authenticator.Authenticate("/ws/terminal", []string{}))
s.sockets = make(map[string]*websocket.Conn)
s.router.Use(authenticator.Authenticate("/ws", []string{}))
s.router.GET("/ws", s.handle)
}
func (s *WebsocketServer) handle(context echo.Context) error {
//authContext := context.(AuthContext)
authContext := context.(AuthContext)
senderId := getAuthenticatedId(authContext)
ws, err := upgrader.Upgrade(context.Response(), context.Request(), nil)
s.sockets[senderId] = ws
if err != nil {
return err
}
defer ws.Close()
defer func() {
ws.Close()
delete(s.sockets, senderId)
}()
for {
messageType, messageBytes, err := ws.ReadMessage()
@ -35,11 +42,22 @@ func (s *WebsocketServer) handle(context echo.Context) error {
if messageType == websocket.TextMessage {
var messageObject map[string]any
json.Unmarshal(messageBytes, &messageObject)
fmt.Println("Received message from authenticated user", senderId)
fmt.Println(messageObject)
}
}
}
func getAuthenticatedId(authContext AuthContext) string {
if authContext.User != nil {
return authContext.User.ID
}
if authContext.Integration != nil {
return authContext.Integration.ID
}
return ""
}
func (s *WebsocketServer) SetRouter(router *echo.Echo) {
s.router = router
}

View File

@ -66,8 +66,8 @@ function DeviceService() {
async function getRegistrationCode() {
let response = await Net.sendRequest({
method: "GET",
url: "/api/integrations/register",
method: "POST",
url: "/api/integrations",
});
if (response.status !== 200) {

View File

@ -157,7 +157,10 @@ function DevicesView(props) {
<>
<button
class="btn btn-outline-secondary me-2"
onClick={() => handleDeleteIntegration(integration)}
onClick={(event) => {
event.stopPropagation();
handleDeleteIntegration(integration);
}}
>
<i class="bi bi-trash-fill"></i>
</button>