diff --git a/.gitignore b/.gitignore index 1e61f38..459afc7 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -users.db \ No newline at end of file +users.db +start \ No newline at end of file diff --git a/main/main.go b/main/main.go index 8b3a7d7..a22af19 100644 --- a/main/main.go +++ b/main/main.go @@ -1,15 +1,11 @@ package main import ( - "crypto/rand" - "encoding/base64" "fmt" "os" "playback-device-server/users" ) -const DEFAULT_USERNAME = "admin" -const MIN_PASSWORD_LENGTH = 8 const USER_DATABASE_DIR = "." func main() { @@ -22,49 +18,9 @@ func main() { } defer userDatabase.Close() - exists, error := userDatabase.UsernameExists(DEFAULT_USERNAME) - if error != nil { - fmt.Println(err.Error()) - os.Exit(1) - } - - if !exists { - password, error := generateRandomPassword(MIN_PASSWORD_LENGTH) - if error != nil { - fmt.Println(err.Error()) - os.Exit(1) - } - - fmt.Println() - fmt.Println("Creating default admin user:") - fmt.Printf("Username: %s\n", DEFAULT_USERNAME) - fmt.Printf("Password: %s\n", password) - fmt.Println() - - user := users.User{Username: DEFAULT_USERNAME, Password: password, IsAdmin: true} - _, error = userDatabase.CreateUser(user.Username, user.Password, user.IsAdmin) - if error != nil { - fmt.Println(err.Error()) - os.Exit(1) - } - } - -} - -func generateRandomPassword(length int) (string, error) { - numBytes := length * 3 / 4 // Base64 encoding increases length by 4/3 - - randomBytes := make([]byte, numBytes) - _, err := rand.Read(randomBytes) + userManager := users.UserManager{} + err = userManager.Initialize(&userDatabase) if err != nil { - return "", err + fmt.Println("failed to initialize user manager") } - - password := base64.URLEncoding.EncodeToString(randomBytes) - - if len(password) > length { - password = password[:length] - } - - return password, nil } diff --git a/start b/start deleted file mode 100755 index 3c658ad..0000000 Binary files a/start and /dev/null differ diff --git a/users/user_manager.go b/users/user_manager.go new file mode 100644 index 0000000..9ee7df0 --- /dev/null +++ b/users/user_manager.go @@ -0,0 +1,192 @@ +package users + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "time" +) + +const DEFAULT_USERNAME = "admin" +const MIN_PASSWORD_LENGTH = 6 + +type UserManager struct { + userDatabase *UserDatabase +} + +func (um *UserManager) Initialize(userDatabase *UserDatabase) error { + um.userDatabase = userDatabase + + exists, error := um.UsernameExists(DEFAULT_USERNAME) + if error != nil { + return error + } + + if !exists { + password, error := generateRandomPassword(MIN_PASSWORD_LENGTH) + if error != nil { + return error + } + + fmt.Println() + fmt.Println("Creating default admin user:") + fmt.Printf("Username: %s\n", DEFAULT_USERNAME) + fmt.Printf("Password: %s\n", password) + fmt.Println() + + user := User{Username: DEFAULT_USERNAME, Password: password, IsAdmin: true} + _, error = um.CreateUser(&user) + if error != nil { + return error + } + } + + return nil +} + +func (um *UserManager) CreateUser(user *User) (string, error) { + exists, error := um.UsernameExists(user.Username) + if error != nil { + return "", error + } + if exists { + return "", fmt.Errorf("User '%s' already exists", user.Username) + } + + if !isValidPassword(user.Password) { + return "", fmt.Errorf("invalid password") + } + + id, error := um.userDatabase.CreateUser(user.Username, user.Password, user.IsAdmin) + if error != nil { + return "", error + } + return id, nil +} + +func (um *UserManager) UsernameExists(username string) (bool, error) { + exists, error := um.userDatabase.UsernameExists(username) + if error != nil { + return false, error + } + return exists, nil +} + +func (um *UserManager) UserIdExists(id string) (bool, error) { + exists, error := um.userDatabase.UserIdExists(id) + if error != nil { + return false, error + } + return exists, nil +} + +func (um *UserManager) Login(username, password string) (string, error) { + exists, error := um.UsernameExists(username) + if error != nil { + return "", error + } + if !exists { + return "", fmt.Errorf("user '%s' doesn't exist", username) + } + + correct, error := um.userDatabase.CheckCredentials(username, password) + if error != nil { + return "", error + } + if !correct { + return "", fmt.Errorf("wrong password") + } + + user, error := um.userDatabase.GetUserByUsername(username) + if error != nil { + return "", error + } + + expiryDate := time.Now().AddDate(0, 0, 30) + token, error := um.userDatabase.CreateSession(user.ID, expiryDate) + if error != nil { + return "", error + } + + return token, nil +} + +func (um *UserManager) GetSession(sessionToken string) (*Session, error) { + session, error := um.userDatabase.GetSession(sessionToken) + if error != nil { + return nil, error + } + return session, nil +} + +func (um *UserManager) GetUserById(id string) (*User, error) { + user, error := um.userDatabase.GetUserById(id) + if error != nil { + return nil, error + } + return user, nil +} + +func (um *UserManager) UpdateUser(user *User) error { + error := um.userDatabase.UpdateUser(user) + return error +} + +func (um *UserManager) UpdatePassword(currentPassword string, newPassword string, user *User) error { + correct, error := um.userDatabase.CheckCredentials(user.Username, currentPassword) + if error != nil { + return error + } + if !correct { + return fmt.Errorf("wrong password") + } + + if !isValidPassword(user.Password) { + return fmt.Errorf("invalid password") + } + + error = um.userDatabase.UpdatePassword(user.ID, newPassword) + return error +} + +func (um *UserManager) DeleteSession(token string) error { + error := um.userDatabase.DeleteSessionByToken(token) + if error != nil { + return error + } + return nil +} + +func (um *UserManager) GetUsers() (*[]User, error) { + users, error := um.userDatabase.GetUsers() + + return users, error +} + +func (um *UserManager) DeleteUser(ID string) error { + error := um.userDatabase.DeleteUser(ID) + + return error +} + +func isValidPassword(password string) bool { + return len(password) >= MIN_PASSWORD_LENGTH +} + +func generateRandomPassword(length int) (string, error) { + numBytes := length * 3 / 4 // Base64 encoding increases length by 4/3 + + randomBytes := make([]byte, numBytes) + _, err := rand.Read(randomBytes) + if err != nil { + return "", err + } + + password := base64.URLEncoding.EncodeToString(randomBytes) + + if len(password) > length { + password = password[:length] + } + + return password, nil +}