From 9fccd5f81db046a83e0e3f450c10457d54522890 Mon Sep 17 00:00:00 2001 From: jordi fita mas Date: Wed, 26 Jul 2023 00:48:58 +0200 Subject: [PATCH] Acquire the database connection in App.ServeHTTP with cookie set Almost all request will need the database connection, either because they will perform queries or because they need to check if the user is logged in, for instance, so it should be acquired as far above as possible to avoid acquiring multiple connections. This is specially true since i have to pass the cookie to the database to switch role and set request.user.email and request.user.cookie config variables. I do not want to do that many times per request. --- pkg/app/app.go | 42 +++++++++++++++++++++++++----------------- pkg/app/login.go | 11 +++++++++-- pkg/database/db.go | 5 ++++- 3 files changed, 38 insertions(+), 20 deletions(-) diff --git a/pkg/app/app.go b/pkg/app/app.go index c80b0b2..f5915f7 100644 --- a/pkg/app/app.go +++ b/pkg/app/app.go @@ -57,24 +57,32 @@ func New(db *database.DB) http.Handler { func (h *App) ServeHTTP(w http.ResponseWriter, r *http.Request) { var head string head, r.URL.Path = shiftPath(r.URL.Path) - switch head { - case "static": + if head == "static" { h.fileHandler.ServeHTTP(w, r) - case "login": - switch r.Method { - case http.MethodPost: - h.handleLogin(w, r) - default: - methodNotAllowed(w, r, http.MethodPost) + } else { + cookie := getSessionCookie(r) + conn := h.db.MustAcquire(r.Context(), cookie) + defer conn.Release() + + if head == "login" { + switch r.Method { + case http.MethodPost: + h.handleLogin(w, r, conn) + default: + methodNotAllowed(w, r, http.MethodPost) + } + } else { + switch head { + case "": + switch r.Method { + case http.MethodGet: + h.handleGet(w, r) + default: + methodNotAllowed(w, r, http.MethodGet) + } + default: + http.NotFound(w, r) + } } - case "": - switch r.Method { - case http.MethodGet: - h.handleGet(w, r) - default: - methodNotAllowed(w, r, http.MethodGet) - } - default: - http.NotFound(w, r) } } diff --git a/pkg/app/login.go b/pkg/app/login.go index 3f24ea5..9646ceb 100644 --- a/pkg/app/login.go +++ b/pkg/app/login.go @@ -10,6 +10,7 @@ import ( "net/http" "time" + "dev.tandem.ws/tandem/camper/pkg/database" "dev.tandem.ws/tandem/camper/pkg/form" httplib "dev.tandem.ws/tandem/camper/pkg/http" "dev.tandem.ws/tandem/camper/pkg/locale" @@ -65,7 +66,7 @@ func (h *App) matchLocale(r *http.Request) *locale.Locale { return locale.Match(r, h.locales, h.defaultLocale, h.languageMatcher) } -func (h *App) handleLogin(w http.ResponseWriter, r *http.Request) { +func (h *App) handleLogin(w http.ResponseWriter, r *http.Request, conn *database.Conn) { login := newLoginForm() if err := login.Parse(r); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) @@ -73,7 +74,6 @@ func (h *App) handleLogin(w http.ResponseWriter, r *http.Request) { } l := h.matchLocale(r) if login.Valid(l) { - conn := h.db.MustAcquire(r.Context()) cookie := conn.MustGetText(r.Context(), "select login($1, $2, $3)", login.Email, login.Password, httplib.RemoteAddr(r)) if cookie != "" { setSessionCookie(w, cookie) @@ -92,6 +92,13 @@ func setSessionCookie(w http.ResponseWriter, cookie string) { http.SetCookie(w, createSessionCookie(cookie, 8766*24*time.Hour)) } +func getSessionCookie(r *http.Request) string { + if cookie, err := r.Cookie(sessionCookie); err == nil { + return cookie.Value + } + return "" +} + func createSessionCookie(value string, duration time.Duration) *http.Cookie { return &http.Cookie{ Name: sessionCookie, diff --git a/pkg/database/db.go b/pkg/database/db.go index 3f8fe25..14e8d8a 100644 --- a/pkg/database/db.go +++ b/pkg/database/db.go @@ -53,11 +53,14 @@ func (db *DB) Acquire(ctx context.Context) (*Conn, error) { return &Conn{conn}, nil } -func (db *DB) MustAcquire(ctx context.Context) *Conn { +func (db *DB) MustAcquire(ctx context.Context, cookie string) *Conn { conn, err := db.Acquire(ctx) if err != nil { panic(err) } + if _, err = conn.Exec(ctx, "select set_cookie($1)", cookie); err != nil { + panic(false) + } return conn }