Prefix with “Must” all functions that panic

Just following what the standard library does.
This commit is contained in:
jordi fita mas 2023-01-22 20:37:43 +01:00
parent 7e5e6121ac
commit fa6ddc70b3
6 changed files with 25 additions and 20 deletions

View File

@ -29,9 +29,9 @@ func NewDatabase(ctx context.Context, connString string) (*Db, error) {
cookie = value cookie = value
} }
if _, err := conn.Exec(ctx, "select set_cookie($1)", cookie); err != nil { if _, err := conn.Exec(ctx, "select set_cookie($1)", cookie); err != nil {
log.Printf("ERROR - Failed to set role: %v", err) log.Printf("ERROR - Failed to set role: %v", err)
return false return false
} }
return true return true
} }
@ -58,11 +58,19 @@ func (db *Db) Acquire(ctx context.Context) (*Conn, error) {
return &Conn{conn}, nil return &Conn{conn}, nil
} }
func (db *Db) MustAcquire(ctx context.Context) *Conn {
conn, err := db.Acquire(ctx)
if err != nil {
panic(err)
}
return conn
}
type Conn struct { type Conn struct {
*pgxpool.Conn *pgxpool.Conn
} }
func (c *Conn) Text(ctx context.Context, def string, sql string, args ...interface{}) string { func (c *Conn) MustGetText(ctx context.Context, def string, sql string, args ...interface{}) string {
var result string var result string
if err := c.Conn.QueryRow(ctx, sql, args...).Scan(&result); err != nil { if err := c.Conn.QueryRow(ctx, sql, args...).Scan(&result); err != nil {
if err == pgx.ErrNoRows { if err == pgx.ErrNoRows {
@ -74,7 +82,7 @@ func (c *Conn) Text(ctx context.Context, def string, sql string, args ...interfa
return result return result
} }
func (c *Conn) Exec(ctx context.Context, sql string, args ...interface{}) { func (c *Conn) MustExec(ctx context.Context, sql string, args ...interface{}) {
if _, err := c.Conn.Exec(ctx, sql, args...); err != nil { if _, err := c.Conn.Exec(ctx, sql, args...); err != nil {
panic(err) panic(err)
} }

View File

@ -11,7 +11,7 @@ import (
const contextLocaleKey = "numerus-locale" const contextLocaleKey = "numerus-locale"
func Locale(db *Db, next http.Handler) http.Handler { func Locale(db *Db, next http.Handler) http.Handler {
availableLanguages := getAvailableLanguages(db) availableLanguages := mustGetAvailableLanguages(db)
var matcher = language.NewMatcher(availableLanguages) var matcher = language.NewMatcher(availableLanguages)
locales := map[language.Tag]*gotext.Locale{} locales := map[language.Tag]*gotext.Locale{}
@ -46,7 +46,7 @@ func getLocale(r *http.Request) *gotext.Locale {
return r.Context().Value(contextLocaleKey).(*gotext.Locale) return r.Context().Value(contextLocaleKey).(*gotext.Locale)
} }
func getAvailableLanguages(db *Db) []language.Tag { func mustGetAvailableLanguages(db *Db) []language.Tag {
rows, err := db.Query(context.Background(), "select lang_tag from language where selectable") rows, err := db.Query(context.Background(), "select lang_tag from language where selectable")
if err != nil { if err != nil {
panic(err) panic(err)

View File

@ -41,7 +41,7 @@ func LoginHandler() http.Handler {
} }
if r.Method == "POST" { if r.Method == "POST" {
conn := getConn(r) conn := getConn(r)
cookie := conn.Text(r.Context(), "", "select login($1, $2, $3)", page.Email, page.Password, remoteAddr(r)) cookie := conn.MustGetText(r.Context(), "", "select login($1, $2, $3)", page.Email, page.Password, remoteAddr(r))
if cookie != "" { if cookie != "" {
http.SetCookie(w, createSessionCookie(cookie, 8766*24*time.Hour)) http.SetCookie(w, createSessionCookie(cookie, 8766*24*time.Hour))
http.Redirect(w, r, "/", http.StatusSeeOther) http.Redirect(w, r, "/", http.StatusSeeOther)
@ -52,7 +52,7 @@ func LoginHandler() http.Handler {
} else { } else {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
} }
renderTemplate(w, r, "login.html", page) mustRenderTemplate(w, r, "login.html", page)
}) })
} }
@ -61,7 +61,7 @@ func LogoutHandler() http.Handler {
user := getUser(r) user := getUser(r)
if user.LoggedIn { if user.LoggedIn {
conn := getConn(r) conn := getConn(r)
conn.Exec(r.Context(), "select logout()") conn.MustExec(r.Context(), "select logout()")
http.SetCookie(w, createSessionCookie("", -24*time.Hour)) http.SetCookie(w, createSessionCookie("", -24*time.Hour))
} }
http.Redirect(w, r, "/login", http.StatusSeeOther) http.Redirect(w, r, "/login", http.StatusSeeOther)
@ -94,10 +94,7 @@ func CheckLogin(db *Db, next http.Handler) http.Handler {
ctx = context.WithValue(ctx, ContextCookieKey, cookie.Value) ctx = context.WithValue(ctx, ContextCookieKey, cookie.Value)
} }
conn, err := db.Acquire(ctx) conn := db.MustAcquire(ctx)
if err != nil {
panic(err)
}
defer conn.Release() defer conn.Release()
ctx = context.WithValue(ctx, ContextConnKey, conn) ctx = context.WithValue(ctx, ContextConnKey, conn)

View File

@ -29,7 +29,7 @@ func ProfileHandler() http.Handler {
conn := getConn(r) conn := getConn(r)
page := ProfilePage{ page := ProfilePage{
Email: user.Email, Email: user.Email,
Languages: getLanguageOptions(r.Context(), conn), Languages: mustGetLanguageOptions(r.Context(), conn),
} }
if r.Method == "POST" { if r.Method == "POST" {
r.ParseForm() r.ParseForm()
@ -38,17 +38,17 @@ func ProfileHandler() http.Handler {
page.Password = r.FormValue("password") page.Password = r.FormValue("password")
page.PasswordConfirm = r.FormValue("password_confirm") page.PasswordConfirm = r.FormValue("password_confirm")
page.Language = r.FormValue("language") page.Language = r.FormValue("language")
conn.Exec(r.Context(), "update user_profile set name = $1, email = $2, lang_tag = $3", page.Name, page.Email, page.Language); conn.MustExec(r.Context(), "update user_profile set name = $1, email = $2, lang_tag = $3", page.Name, page.Email, page.Language)
} else { } else {
if err := conn.QueryRow(r.Context(), "select name, lang_tag from user_profile").Scan(&page.Name, &page.Language); err != nil { if err := conn.QueryRow(r.Context(), "select name, lang_tag from user_profile").Scan(&page.Name, &page.Language); err != nil {
panic(nil) panic(nil)
} }
} }
renderTemplate(w, r, "profile.html", page) mustRenderTemplate(w, r, "profile.html", page)
}) })
} }
func getLanguageOptions(ctx context.Context, conn *Conn) []LanguageOption { func mustGetLanguageOptions(ctx context.Context, conn *Conn) []LanguageOption {
rows, err := conn.Query(ctx, "select lang_tag, endonym from language where selectable") rows, err := conn.Query(ctx, "select lang_tag, endonym from language where selectable")
if err != nil { if err != nil {
panic(err) panic(err)

View File

@ -13,7 +13,7 @@ func NewRouter(db *Db) http.Handler {
router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
user := getUser(r) user := getUser(r)
if user.LoggedIn { if user.LoggedIn {
renderTemplate(w, r, "index.html", nil) mustRenderTemplate(w, r, "index.html", nil)
} else { } else {
http.Redirect(w, r, "/login", http.StatusSeeOther) http.Redirect(w, r, "/login", http.StatusSeeOther)
} }

View File

@ -6,7 +6,7 @@ import (
"net/http" "net/http"
) )
func renderTemplate(wr io.Writer, r *http.Request, filename string, data interface{}) { func mustRenderTemplate(wr io.Writer, r *http.Request, filename string, data interface{}) {
locale := getLocale(r) locale := getLocale(r)
t := template.New(filename) t := template.New(filename)
t.Funcs(template.FuncMap{ t.Funcs(template.FuncMap{