diff --git a/cmd/user/main.go b/cmd/user/main.go new file mode 100644 index 0000000..8c2e2b6 --- /dev/null +++ b/cmd/user/main.go @@ -0,0 +1,195 @@ +// Package main provides local user administration commands. +package main + +import ( + "bufio" + "context" + "database/sql" + "errors" + "fmt" + "io" + "mal/internal" + "mal/internal/config" + "mal/internal/db" + "mal/internal/observability" + "os" + "strings" + "time" + + "github.com/google/uuid" + "github.com/joho/godotenv" + "golang.org/x/crypto/bcrypt" + "golang.org/x/term" +) + +func main() { + if err := godotenv.Load(); err != nil { + observability.Warn("env_file_load_failed", "user", "", nil, err) + } + + if err := run(os.Args[1:]); err != nil { + fmt.Fprintf(os.Stderr, "ERROR: %v\n", err) + os.Exit(1) + } +} + +func run(args []string) error { + if len(args) == 1 && args[0] == "run-fixes" { + return runFixes() + } + + if len(args) != 1 && len(args) != 2 { + return errors.New("usage: create-user [password]") + } + + username := strings.TrimSpace(args[0]) + password := "" + if len(args) == 2 { + password = args[1] + } + if username == "" { + return errors.New("username must not be empty") + } + + sqlDB, err := openDatabase() + if err != nil { + return err + } + defer sqlDB.Close() + + if err := internal.RunMigrationsAndFixes(sqlDB); err != nil { + return fmt.Errorf("prepare database: %w", err) + } + + return createOrUpdateUser(sqlDB, username, password) +} + +func runFixes() error { + sqlDB, err := openDatabase() + if err != nil { + return err + } + defer sqlDB.Close() + + if err := internal.RunMigrationsAndFixes(sqlDB); err != nil { + return fmt.Errorf("run migrations and fixes: %w", err) + } + fmt.Println("Database migrations and fixes complete") + return nil +} + +func openDatabase() (*sql.DB, error) { + cfg, err := config.Load() + if err != nil { + return nil, fmt.Errorf("load config: %w", err) + } + + sqlDB, err := db.Open(cfg.DatabaseFile) + if err != nil { + return nil, fmt.Errorf("open database: %w", err) + } + return sqlDB, nil +} + +func createOrUpdateUser(sqlDB *sql.DB, username, password string) error { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + var userID string + err := sqlDB.QueryRowContext(ctx, `SELECT id FROM user WHERE username = ? LIMIT 1`, username).Scan(&userID) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("check user: %w", err) + } + userExists := err == nil + + if !userExists { + return createUser(ctx, sqlDB, username, password) + } + + update, err := confirmPasswordUpdate(username) + if err != nil { + return err + } + if !update { + fmt.Println("No changes made") + return nil + } + + return updateUserPassword(ctx, sqlDB, userID, username, password) +} + +func createUser(ctx context.Context, sqlDB *sql.DB, username, password string) error { + password, err := resolvePassword(password) + if err != nil { + return err + } + + passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("hash password: %w", err) + } + + _, err = sqlDB.ExecContext( + ctx, + `INSERT INTO user (id, username, password_hash, avatar_url) VALUES (?, ?, ?, ?)`, + uuid.NewString(), username, string(passwordHash), internal.DefaultAvatarURL(username), + ) + if err != nil { + return fmt.Errorf("create user: %w", err) + } + fmt.Printf("Created user %q\n", username) + return nil +} + +func updateUserPassword(ctx context.Context, sqlDB *sql.DB, userID, username, password string) error { + password, err := resolvePassword(password) + if err != nil { + return err + } + + passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("hash password: %w", err) + } + + if _, err := sqlDB.ExecContext(ctx, `UPDATE user SET password_hash = ? WHERE id = ?`, string(passwordHash), userID); err != nil { + return fmt.Errorf("update password: %w", err) + } + + fmt.Printf("Updated password for user %q\n", username) + return nil +} + +func resolvePassword(password string) (string, error) { + if password != "" { + return password, nil + } + + fmt.Print("Password: ") + passwordBytes, err := term.ReadPassword(int(os.Stdin.Fd())) + fmt.Println() + if err != nil { + return "", fmt.Errorf("read password: %w", err) + } + if len(passwordBytes) == 0 { + return "", errors.New("password must not be empty") + } + return string(passwordBytes), nil +} + +func confirmPasswordUpdate(username string) (bool, error) { + fmt.Printf("User %q already exists. Change password? [Y/n] ", username) + answer, err := bufio.NewReader(os.Stdin).ReadString('\n') + if err != nil && !errors.Is(err, io.EOF) { + return false, fmt.Errorf("read confirmation: %w", err) + } + + switch strings.ToLower(strings.TrimSpace(answer)) { + case "", "y", "yes": + return true, nil + case "n", "no": + return false, nil + default: + return false, errors.New("invalid response; enter y or n") + } +} diff --git a/create-user b/create-user new file mode 100755 index 0000000..5aadeaf --- /dev/null +++ b/create-user @@ -0,0 +1,4 @@ +#!/bin/sh +set -e + +exec go run ./cmd/user "$@" diff --git a/go.mod b/go.mod index 454809b..68605a6 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/gin-gonic/gin v1.12.0 github.com/pressly/goose/v3 v3.27.1 go.uber.org/fx v1.24.0 + golang.org/x/term v0.43.0 ) require ( @@ -56,6 +57,6 @@ require ( github.com/andybalholm/cascadia v1.3.3 // indirect github.com/klauspost/compress v1.18.5 // indirect golang.org/x/sync v0.20.0 // direct - golang.org/x/sys v0.43.0 // indirect + golang.org/x/sys v0.44.0 // indirect golang.org/x/text v0.36.0 // indirect ) diff --git a/go.sum b/go.sum index cf4f875..ec8746f 100644 --- a/go.sum +++ b/go.sum @@ -158,8 +158,8 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= -golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= +golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -169,6 +169,8 @@ golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= +golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4= +golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=