Refactor checking for pgx.ErrNoRows in a function

This commit is contained in:
jordi fita mas 2023-02-14 12:46:11 +01:00
parent 13fa1d6b89
commit 4db0a8fb5a
3 changed files with 20 additions and 28 deletions

View File

@ -2,7 +2,6 @@ package pkg
import ( import (
"context" "context"
"github.com/jackc/pgx/v4"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"html/template" "html/template"
"net/http" "net/http"
@ -38,14 +37,9 @@ func GetContactForm(w http.ResponseWriter, r *http.Request, params httprouter.Pa
mustRenderNewContactForm(w, r, form) mustRenderNewContactForm(w, r, form)
return return
} }
err := conn.QueryRow(r.Context(), "select business_name, substr(vatin::text, 3), trade_name, phone, email, web, address, city, province, postal_code, country_code from contact where slug = $1", slug).Scan(form.BusinessName, form.VATIN, form.TradeName, form.Phone, form.Email, form.Web, form.Address, form.City, form.Province, form.PostalCode, form.Country) if notFoundErrorOrPanic(conn.QueryRow(r.Context(), "select business_name, substr(vatin::text, 3), trade_name, phone, email, web, address, city, province, postal_code, country_code from contact where slug = $1", slug).Scan(form.BusinessName, form.VATIN, form.TradeName, form.Phone, form.Email, form.Web, form.Address, form.City, form.Province, form.PostalCode, form.Country)) {
if err != nil {
if err == pgx.ErrNoRows {
http.NotFound(w, r) http.NotFound(w, r)
return return
} else {
panic(err)
}
} }
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
mustRenderEditContactForm(w, r, form) mustRenderEditContactForm(w, r, form)

View File

@ -50,6 +50,16 @@ func NewDatabase(ctx context.Context, connString string) (*Db, error) {
return &Db{pool}, nil return &Db{pool}, nil
} }
func notFoundErrorOrPanic(err error) bool {
if err == pgx.ErrNoRows {
return true
}
if err != nil {
panic(err)
}
return false
}
func (db *Db) Acquire(ctx context.Context) (*Conn, error) { func (db *Db) Acquire(ctx context.Context) (*Conn, error) {
conn, err := db.Pool.Acquire(ctx) conn, err := db.Pool.Acquire(ctx)
if err != nil { if err != nil {
@ -80,12 +90,9 @@ func (c *Conn) MustBegin(ctx context.Context) *Tx {
func (c *Conn) MustGetText(ctx context.Context, def string, sql string, args ...interface{}) string { func (c *Conn) MustGetText(ctx context.Context, def string, sql string, args ...interface{}) string {
var result string var result string
if err := c.Conn.QueryRow(ctx, sql, args...).Scan(&result); err != nil { if notFoundErrorOrPanic(c.Conn.QueryRow(ctx, sql, args...).Scan(&result)) {
if err == pgx.ErrNoRows {
return def return def
} }
panic(err)
}
return result return result
} }
@ -138,12 +145,9 @@ func (tx *Tx) MustGetInteger(ctx context.Context, sql string, args ...interface{
func (tx *Tx) MustGetIntegerOrDefault(ctx context.Context, def int, sql string, args ...interface{}) int { func (tx *Tx) MustGetIntegerOrDefault(ctx context.Context, def int, sql string, args ...interface{}) int {
var result int var result int
if err := tx.QueryRow(ctx, sql, args...).Scan(&result); err != nil { if notFoundErrorOrPanic(tx.QueryRow(ctx, sql, args...).Scan(&result)) {
if err == pgx.ErrNoRows {
return def return def
} }
panic(err)
}
return result return result
} }

View File

@ -3,7 +3,6 @@ package pkg
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/jackc/pgx/v4"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"html/template" "html/template"
"math" "math"
@ -42,14 +41,9 @@ func GetProductForm(w http.ResponseWriter, r *http.Request, params httprouter.Pa
return return
} }
var productId int var productId int
err := conn.QueryRow(r.Context(), "select product_id, product.name, product.description, to_price(price, decimal_digits) from product join company using (company_id) join currency using (currency_code) where product.slug = $1", slug).Scan(&productId, form.Name, form.Description, form.Price) if notFoundErrorOrPanic(conn.QueryRow(r.Context(), "select product_id, product.name, product.description, to_price(price, decimal_digits) from product join company using (company_id) join currency using (currency_code) where product.slug = $1", slug).Scan(&productId, form.Name, form.Description, form.Price)) {
if err != nil {
if err == pgx.ErrNoRows {
http.NotFound(w, r) http.NotFound(w, r)
return return
} else {
panic(err)
}
} }
rows, err := conn.Query(r.Context(), "select tax_id from product_tax where product_id = $1", productId) rows, err := conn.Query(r.Context(), "select tax_id from product_tax where product_id = $1", productId)
if err != nil { if err != nil {