Acquire the database connection in App.ServeHTTP with cookie set

Almost all request will need the database connection, either because
they will perform queries or because they need to check if the user is
logged in, for instance, so it should be acquired as far above as
possible to avoid acquiring multiple connections.

This is specially true since i have to pass the cookie to the database
to switch role and set request.user.email and request.user.cookie
config variables.  I do not want to do that many times per request.
This commit is contained in:
jordi fita mas 2023-07-26 00:48:58 +02:00
parent 01526bff1a
commit 9fccd5f81d
3 changed files with 38 additions and 20 deletions

View File

@ -57,16 +57,22 @@ func New(db *database.DB) http.Handler {
func (h *App) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *App) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var head string var head string
head, r.URL.Path = shiftPath(r.URL.Path) head, r.URL.Path = shiftPath(r.URL.Path)
switch head { if head == "static" {
case "static":
h.fileHandler.ServeHTTP(w, r) h.fileHandler.ServeHTTP(w, r)
case "login": } else {
cookie := getSessionCookie(r)
conn := h.db.MustAcquire(r.Context(), cookie)
defer conn.Release()
if head == "login" {
switch r.Method { switch r.Method {
case http.MethodPost: case http.MethodPost:
h.handleLogin(w, r) h.handleLogin(w, r, conn)
default: default:
methodNotAllowed(w, r, http.MethodPost) methodNotAllowed(w, r, http.MethodPost)
} }
} else {
switch head {
case "": case "":
switch r.Method { switch r.Method {
case http.MethodGet: case http.MethodGet:
@ -77,4 +83,6 @@ func (h *App) ServeHTTP(w http.ResponseWriter, r *http.Request) {
default: default:
http.NotFound(w, r) http.NotFound(w, r)
} }
}
}
} }

View File

@ -10,6 +10,7 @@ import (
"net/http" "net/http"
"time" "time"
"dev.tandem.ws/tandem/camper/pkg/database"
"dev.tandem.ws/tandem/camper/pkg/form" "dev.tandem.ws/tandem/camper/pkg/form"
httplib "dev.tandem.ws/tandem/camper/pkg/http" httplib "dev.tandem.ws/tandem/camper/pkg/http"
"dev.tandem.ws/tandem/camper/pkg/locale" "dev.tandem.ws/tandem/camper/pkg/locale"
@ -65,7 +66,7 @@ func (h *App) matchLocale(r *http.Request) *locale.Locale {
return locale.Match(r, h.locales, h.defaultLocale, h.languageMatcher) return locale.Match(r, h.locales, h.defaultLocale, h.languageMatcher)
} }
func (h *App) handleLogin(w http.ResponseWriter, r *http.Request) { func (h *App) handleLogin(w http.ResponseWriter, r *http.Request, conn *database.Conn) {
login := newLoginForm() login := newLoginForm()
if err := login.Parse(r); err != nil { if err := login.Parse(r); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
@ -73,7 +74,6 @@ func (h *App) handleLogin(w http.ResponseWriter, r *http.Request) {
} }
l := h.matchLocale(r) l := h.matchLocale(r)
if login.Valid(l) { if login.Valid(l) {
conn := h.db.MustAcquire(r.Context())
cookie := conn.MustGetText(r.Context(), "select login($1, $2, $3)", login.Email, login.Password, httplib.RemoteAddr(r)) cookie := conn.MustGetText(r.Context(), "select login($1, $2, $3)", login.Email, login.Password, httplib.RemoteAddr(r))
if cookie != "" { if cookie != "" {
setSessionCookie(w, cookie) setSessionCookie(w, cookie)
@ -92,6 +92,13 @@ func setSessionCookie(w http.ResponseWriter, cookie string) {
http.SetCookie(w, createSessionCookie(cookie, 8766*24*time.Hour)) http.SetCookie(w, createSessionCookie(cookie, 8766*24*time.Hour))
} }
func getSessionCookie(r *http.Request) string {
if cookie, err := r.Cookie(sessionCookie); err == nil {
return cookie.Value
}
return ""
}
func createSessionCookie(value string, duration time.Duration) *http.Cookie { func createSessionCookie(value string, duration time.Duration) *http.Cookie {
return &http.Cookie{ return &http.Cookie{
Name: sessionCookie, Name: sessionCookie,

View File

@ -53,11 +53,14 @@ func (db *DB) Acquire(ctx context.Context) (*Conn, error) {
return &Conn{conn}, nil return &Conn{conn}, nil
} }
func (db *DB) MustAcquire(ctx context.Context) *Conn { func (db *DB) MustAcquire(ctx context.Context, cookie string) *Conn {
conn, err := db.Acquire(ctx) conn, err := db.Acquire(ctx)
if err != nil { if err != nil {
panic(err) panic(err)
} }
if _, err = conn.Exec(ctx, "select set_cookie($1)", cookie); err != nil {
panic(false)
}
return conn return conn
} }