diff --git a/cmd/numerus/main.go b/cmd/numerus/main.go index 5ee3c68..0abd24f 100644 --- a/cmd/numerus/main.go +++ b/cmd/numerus/main.go @@ -9,13 +9,11 @@ import ( "syscall" "time" - "github.com/jackc/pgx/v4/pgxpool" - numerus "dev.tandem.ws/tandem/numerus/pkg" ) func main() { - dbpool, err := pgxpool.Connect(context.Background(), os.Getenv("DATABASE_URL")) + dbpool, err := numerus.ConnectToDatabase(context.Background(), os.Getenv("NUMERUS_DATABASE_URL")) if err != nil { log.Fatal(err) } diff --git a/pkg/db.go b/pkg/db.go new file mode 100644 index 0000000..550c02d --- /dev/null +++ b/pkg/db.go @@ -0,0 +1,39 @@ +package pkg + +import ( + "context" + "log" + + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgxpool" +) + +func ConnectToDatabase(ctx context.Context, connString string) (*pgxpool.Pool, error) { + config, err := pgxpool.ParseConfig(connString) + if err != nil { + log.Fatal(err) + } + + config.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { + _, err := conn.Exec(context.Background(), "SET search_path TO numerus, public") + return err + } + + config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool { + if _, err := conn.Exec(ctx, "select set_config('role', $1, false)", "guest"); err != nil { + log.Printf("ERROR - Failed to set role: %v", err) + return false + } + return true + } + + config.AfterRelease = func(conn *pgx.Conn) bool { + if _, err := conn.Exec(context.Background(), "RESET ROLE"); err != nil { + log.Printf("ERROR - Failed to reset role: %v", err) + return false + } + return true + } + + return pgxpool.ConnectConfig(ctx, config) +} diff --git a/pkg/router.go b/pkg/router.go index 46b03db..239736f 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -1,7 +1,6 @@ package pkg import ( - "context" "html/template" "log" "net/http" @@ -17,14 +16,9 @@ func NewRouter(db *pgxpool.Pool) http.Handler { email := r.FormValue("email") password := r.FormValue("password") var role string - if _, err := db.Exec(context.Background(), "select set_config('search_path', 'numerus, public', false)"); err != nil { - log.Printf("ERROR - %s", err.Error()) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - err := db.QueryRow(context.Background(), "select login($1, $2)", email, password).Scan(&role) + err := db.QueryRow(r.Context(), "select login($1, $2)", email, password).Scan(&role) if err != nil { - log.Printf("ERROR - %s", err.Error()) + log.Printf("ERROR - %v for %q", err, email) http.Error(w, err.Error(), http.StatusInternalServerError) return }