playback-device-server/data/user_database.go

227 lines
5.8 KiB
Go

package data
import (
"database/sql"
"fmt"
"path/filepath"
"time"
gonanoid "github.com/matoous/go-nanoid"
_ "github.com/mattn/go-sqlite3"
"golang.org/x/crypto/bcrypt"
)
// UserDatabase stores user accounts and login sessions in a SQLite database
// file named users.db inside databaseDirectory.
//
// Call SetDirectory and then Initialize before using any other method; every
// query goes through Connection, which Initialize assigns.
type UserDatabase struct {
	// Connection is the database handle assigned by Initialize (nil before that).
	Connection *sql.DB
	// databaseDirectory is the directory that holds users.db; set via SetDirectory.
	databaseDirectory string
}
// Initialize opens a database handle for users.db inside the directory set by
// SetDirectory and creates the users and sessions tables if they do not exist.
//
// The handle is stored in Connection as soon as it is opened, even if table
// creation later fails, so Close can still release it. Table-creation errors
// wrap the underlying driver error (use errors.Is / errors.As to inspect it).
func (db *UserDatabase) Initialize() error {
	conn, err := sql.Open("sqlite3", filepath.Join(db.databaseDirectory, "users.db"))
	if err != nil {
		return err
	}
	db.Connection = conn

	if _, err := conn.Exec(`CREATE TABLE IF NOT EXISTS users (
		id TEXT PRIMARY KEY,
		username TEXT UNIQUE,
		password TEXT,
		is_admin INTEGER
	)`); err != nil {
		return fmt.Errorf("error creating users table: %w", err)
	}

	// user_id must use the same TEXT type as users.id: IDs are nanoid strings,
	// and under INTEGER affinity SQLite may coerce numeric-looking IDs into
	// numbers. Databases created earlier keep their old schema (IF NOT EXISTS).
	if _, err := conn.Exec(`CREATE TABLE IF NOT EXISTS sessions (
		token TEXT PRIMARY KEY,
		user_id TEXT,
		expiry_date TIMESTAMP,
		FOREIGN KEY (user_id) REFERENCES users(id)
	)`); err != nil {
		return fmt.Errorf("error creating sessions table: %w", err)
	}
	return nil
}
// Close releases the database handle opened by Initialize.
//
// It is a no-op returning nil when Initialize was never called (Connection is
// nil), instead of panicking on the nil *sql.DB.
func (db *UserDatabase) Close() error {
	if db.Connection == nil {
		return nil
	}
	return db.Connection.Close()
}
// CreateUser inserts a new user with a random 8-character ID and a bcrypt hash
// of password, and returns the generated ID.
//
// On any failure the returned ID is empty. Failures include a duplicate
// username, which violates the UNIQUE constraint on users.username. The driver
// error is returned unwrapped so callers can still inspect its concrete type.
func (db *UserDatabase) CreateUser(username, password string, isAdmin bool) (string, error) {
	userID, err := gonanoid.Nanoid(8)
	if err != nil {
		return "", err
	}
	hashedPassword, err := hashPassword(password)
	if err != nil {
		return "", err
	}
	if _, err := db.Connection.Exec(
		"INSERT INTO users (id, username, password, is_admin) VALUES (?, ?, ?, ?)",
		userID, username, hashedPassword, isAdmin,
	); err != nil {
		// The ID was never persisted, so don't hand it back.
		return "", err
	}
	return userID, nil
}
// CreateSession records a session for userID that expires at expiryDate and
// returns the randomly generated 16-character token identifying it.
// On failure the token is "" and the error is returned unchanged.
func (db *UserDatabase) CreateSession(userID string, expiryDate time.Time) (string, error) {
	token, err := gonanoid.Nanoid(16)
	if err == nil {
		_, err = db.Connection.Exec(
			"INSERT INTO sessions (user_id, token, expiry_date) VALUES (?, ?, ?)",
			userID, token, expiryDate,
		)
	}
	if err != nil {
		return "", err
	}
	return token, nil
}
// GetUserByUsername loads the user row (including the password hash) whose
// username matches exactly.
//
// When no row matches, the Scan error (sql.ErrNoRows) is returned as-is with a
// nil user.
func (db *UserDatabase) GetUserByUsername(username string) (*User, error) {
	u := new(User)
	row := db.Connection.QueryRow("SELECT id, username, password, is_admin FROM users WHERE username = ?", username)
	if scanErr := row.Scan(&u.ID, &u.Username, &u.Password, &u.IsAdmin); scanErr != nil {
		return nil, scanErr
	}
	return u, nil
}
// GetUserById loads the user row (including the password hash) with the given
// ID.
//
// It returns (nil, nil) when no user has that ID and a non-nil error only for
// query or scan failures. It uses a single query rather than an EXISTS
// pre-check followed by a SELECT, which avoids a second round trip and a
// check-then-act race with concurrent deletes.
func (db *UserDatabase) GetUserById(id string) (*User, error) {
	var user User
	err := db.Connection.QueryRow(
		"SELECT id, username, password, is_admin FROM users WHERE id = ?", id,
	).Scan(&user.ID, &user.Username, &user.Password, &user.IsAdmin)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return &user, nil
}
// UsernameExists reports whether any user row has exactly this username.
// On query failure it returns false together with the error.
func (db *UserDatabase) UsernameExists(username string) (bool, error) {
	var found bool
	row := db.Connection.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE username = ?)", username)
	if scanErr := row.Scan(&found); scanErr != nil {
		return false, scanErr
	}
	return found, nil
}
// UserIdExists reports whether a user row with the given ID is stored.
// On query failure it returns false together with the error.
func (db *UserDatabase) UserIdExists(id string) (bool, error) {
	var found bool
	row := db.Connection.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE id = ?)", id)
	if scanErr := row.Scan(&found); scanErr != nil {
		return false, scanErr
	}
	return found, nil
}
// UpdateUser writes user.Username and user.IsAdmin to the row whose id equals
// user.ID. The stored password hash is not modified (UpdatePassword does
// that), and the number of affected rows is not checked, so an unknown ID is
// not reported as an error.
func (db *UserDatabase) UpdateUser(user *User) error {
	const stmt = "UPDATE users SET username = ?, is_admin = ? WHERE id = ?"
	if _, execErr := db.Connection.Exec(stmt, user.Username, user.IsAdmin, user.ID); execErr != nil {
		return execErr
	}
	return nil
}
// UpdatePassword replaces the stored password of the user with the given ID
// by a bcrypt hash of newPassword. Affected rows are not checked, so an
// unknown ID is not reported as an error.
func (db *UserDatabase) UpdatePassword(userId string, newPassword string) error {
	digest, hashErr := hashPassword(newPassword)
	if hashErr != nil {
		return hashErr
	}
	const stmt = "UPDATE users SET password = ? WHERE id = ?"
	if _, execErr := db.Connection.Exec(stmt, digest, userId); execErr != nil {
		return execErr
	}
	return nil
}
// CheckCredentials reports whether password matches the bcrypt hash stored for
// username.
//
// An unknown username and a wrong password both yield (false, nil); a non-nil
// error is returned only when the database lookup itself fails. Any failure of
// the bcrypt comparison, including a malformed stored hash, is treated as a
// failed login.
func (db *UserDatabase) CheckCredentials(username, password string) (bool, error) {
	var hashedPassword string
	err := db.Connection.QueryRow("SELECT password FROM users WHERE username = ?", username).Scan(&hashedPassword)
	if err == sql.ErrNoRows {
		// Spend bcrypt work comparable to a real comparison so that the response
		// time does not reveal whether the username exists. The result is
		// deliberately discarded.
		_, _ = bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
		return false, nil
	}
	if err != nil {
		return false, err
	}
	return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password)) == nil, nil
}
// GetSession loads the session identified by sessionToken.
//
// It returns (nil, nil) when no session has that token. The query does not
// filter on expiry_date, so expired sessions are returned as-is.
//
// A single query is used; the former SessionTokenExists pre-check cost an
// extra round trip and raced with concurrent deletes, and the sql.ErrNoRows
// branch already covers the missing-token case.
func (db *UserDatabase) GetSession(sessionToken string) (*UserSession, error) {
	var session UserSession
	err := db.Connection.QueryRow(
		"SELECT token, user_id, expiry_date FROM sessions WHERE token = ?", sessionToken,
	).Scan(&session.Token, &session.UserID, &session.ExpiryDate)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return &session, nil
}
// SessionTokenExists reports whether a session row with this token is stored,
// regardless of its expiry date. On query failure it returns false and the
// error.
func (db *UserDatabase) SessionTokenExists(token string) (bool, error) {
	var found bool
	row := db.Connection.QueryRow("SELECT EXISTS(SELECT 1 FROM sessions WHERE token = ?)", token)
	if scanErr := row.Scan(&found); scanErr != nil {
		return false, scanErr
	}
	return found, nil
}
// DeleteSessionByToken removes the session with the given token. Deleting a
// token that is not stored is not an error.
func (db *UserDatabase) DeleteSessionByToken(token string) error {
	const stmt = "DELETE FROM sessions WHERE token = ?"
	if _, execErr := db.Connection.Exec(stmt, token); execErr != nil {
		return execErr
	}
	return nil
}
// GetUsers returns every user's ID, username and admin flag. The Password
// field is left empty because the hash is not selected.
//
// With no users, it returns a pointer to a nil slice and a nil error. Any error
// raised while iterating the result set is reported, not just errors from the
// initial query or from individual scans, so a partially read table is never
// returned as if it were complete.
func (db *UserDatabase) GetUsers() (*[]User, error) {
	rows, err := db.Connection.Query("SELECT id, username, is_admin FROM users")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var users []User
	for rows.Next() {
		var user User
		if err := rows.Scan(&user.ID, &user.Username, &user.IsAdmin); err != nil {
			return nil, err
		}
		users = append(users, user)
	}
	// rows.Next returns false both at the end and on failure; only Err tells them apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return &users, nil
}
// DeleteUser removes the user with the given ID together with all of that
// user's sessions, atomically in one transaction.
//
// The sessions are deleted explicitly because SQLite does not enforce the
// declared foreign key unless foreign-key support is enabled on the
// connection. Otherwise a deleted user's session tokens would stay valid.
// Deleting an unknown ID is not an error.
func (db *UserDatabase) DeleteUser(ID string) error {
	tx, err := db.Connection.Begin()
	if err != nil {
		return err
	}
	// Undo both deletes on any early return; after a successful Commit this is
	// a no-op (it returns sql.ErrTxDone, which is intentionally ignored).
	defer func() { _ = tx.Rollback() }()

	if _, err := tx.Exec("DELETE FROM sessions WHERE user_id = ?", ID); err != nil {
		return err
	}
	if _, err := tx.Exec("DELETE FROM users WHERE id = ?", ID); err != nil {
		return err
	}
	return tx.Commit()
}
// hashPassword returns the bcrypt hash of password, computed at
// bcrypt.DefaultCost, as a string ready to store in the users table.
func hashPassword(password string) (string, error) {
	plain := []byte(password)
	digest, hashErr := bcrypt.GenerateFromPassword(plain, bcrypt.DefaultCost)
	if hashErr != nil {
		return "", hashErr
	}
	return string(digest), nil
}
// SetDirectory sets the directory in which Initialize opens users.db. It only
// takes effect if called before Initialize, which reads it once.
func (db *UserDatabase) SetDirectory(directory string) {
	db.databaseDirectory = directory
}