From ab6c0079c9276b8461fe352af64bd60c4b962984 Mon Sep 17 00:00:00 2001 From: jordi fita mas Date: Tue, 17 Jan 2023 14:46:22 +0100 Subject: [PATCH] Set search_path on each new connection, and role on each acquisition MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The whole application will need the same search_path, so it is wasteful to do that in each handler. It is possible to pass the search path as a parameter to the database’s connection string, but then everyone would need to remember to do that, and update the configuration in case i add another schema. Similarly, i need to change the user’s role to match her permissions—which are not in yet—, but this time i need it each time a handler requests a connection from the pool, because each time the connection is returned to the pool i reset the role back to the initial, that hopefully will be authenticator. --- cmd/numerus/main.go | 4 +--- pkg/db.go | 39 +++++++++++++++++++++++++++++++++++++++ pkg/router.go | 10 ++-------- 3 files changed, 42 insertions(+), 11 deletions(-) create mode 100644 pkg/db.go diff --git a/cmd/numerus/main.go b/cmd/numerus/main.go index 5ee3c68..0abd24f 100644 --- a/cmd/numerus/main.go +++ b/cmd/numerus/main.go @@ -9,13 +9,11 @@ import ( "syscall" "time" - "github.com/jackc/pgx/v4/pgxpool" - numerus "dev.tandem.ws/tandem/numerus/pkg" ) func main() { - dbpool, err := pgxpool.Connect(context.Background(), os.Getenv("DATABASE_URL")) + dbpool, err := numerus.ConnectToDatabase(context.Background(), os.Getenv("NUMERUS_DATABASE_URL")) if err != nil { log.Fatal(err) } diff --git a/pkg/db.go b/pkg/db.go new file mode 100644 index 0000000..550c02d --- /dev/null +++ b/pkg/db.go @@ -0,0 +1,39 @@ +package pkg + +import ( + "context" + "log" + + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgxpool" +) + +func ConnectToDatabase(ctx context.Context, connString string) (*pgxpool.Pool, error) { + config, err := pgxpool.ParseConfig(connString) + if err != nil { + log.Fatal(err) + } + + config.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { + _, err := conn.Exec(context.Background(), "SET search_path TO numerus, public") + return err + } + + config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool { + if _, err := conn.Exec(ctx, "select set_config('role', $1, false)", "guest"); err != nil { + log.Printf("ERROR - Failed to set role: %v", err) + return false + } + return true + } + + config.AfterRelease = func(conn *pgx.Conn) bool { + if _, err := conn.Exec(context.Background(), "RESET ROLE"); err != nil { + log.Printf("ERROR - Failed to reset role: %v", err) + return false + } + return true + } + + return pgxpool.ConnectConfig(ctx, config) +} diff --git a/pkg/router.go b/pkg/router.go index 46b03db..239736f 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -1,7 +1,6 @@ package pkg import ( - "context" "html/template" "log" "net/http" @@ -17,14 +16,9 @@ func NewRouter(db *pgxpool.Pool) http.Handler { email := r.FormValue("email") password := r.FormValue("password") var role string - if _, err := db.Exec(context.Background(), "select set_config('search_path', 'numerus, public', false)"); err != nil { - log.Printf("ERROR - %s", err.Error()) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - err := db.QueryRow(context.Background(), "select login($1, $2)", email, password).Scan(&role) + err := db.QueryRow(r.Context(), "select login($1, $2)", email, password).Scan(&role) if err != nil { - log.Printf("ERROR - %s", err.Error()) + log.Printf("ERROR - %v for %q", err, email) http.Error(w, err.Error(), http.StatusInternalServerError) return }