Refactor checking for pgx.ErrNoRows in a function
This commit is contained in:
parent
13fa1d6b89
commit
4db0a8fb5a
|
@ -2,7 +2,6 @@ package pkg
|
|||
|
||||
import (
|
||||
"context"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"html/template"
|
||||
"net/http"
|
||||
|
@ -38,14 +37,9 @@ func GetContactForm(w http.ResponseWriter, r *http.Request, params httprouter.Pa
|
|||
mustRenderNewContactForm(w, r, form)
|
||||
return
|
||||
}
|
||||
err := conn.QueryRow(r.Context(), "select business_name, substr(vatin::text, 3), trade_name, phone, email, web, address, city, province, postal_code, country_code from contact where slug = $1", slug).Scan(form.BusinessName, form.VATIN, form.TradeName, form.Phone, form.Email, form.Web, form.Address, form.City, form.Province, form.PostalCode, form.Country)
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
} else {
|
||||
panic(err)
|
||||
}
|
||||
if notFoundErrorOrPanic(conn.QueryRow(r.Context(), "select business_name, substr(vatin::text, 3), trade_name, phone, email, web, address, city, province, postal_code, country_code from contact where slug = $1", slug).Scan(form.BusinessName, form.VATIN, form.TradeName, form.Phone, form.Email, form.Web, form.Address, form.City, form.Province, form.PostalCode, form.Country)) {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
mustRenderEditContactForm(w, r, form)
|
||||
|
|
24
pkg/db.go
24
pkg/db.go
|
@ -50,6 +50,16 @@ func NewDatabase(ctx context.Context, connString string) (*Db, error) {
|
|||
return &Db{pool}, nil
|
||||
}
|
||||
|
||||
func notFoundErrorOrPanic(err error) bool {
|
||||
if err == pgx.ErrNoRows {
|
||||
return true
|
||||
}
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (db *Db) Acquire(ctx context.Context) (*Conn, error) {
|
||||
conn, err := db.Pool.Acquire(ctx)
|
||||
if err != nil {
|
||||
|
@ -80,11 +90,8 @@ func (c *Conn) MustBegin(ctx context.Context) *Tx {
|
|||
|
||||
func (c *Conn) MustGetText(ctx context.Context, def string, sql string, args ...interface{}) string {
|
||||
var result string
|
||||
if err := c.Conn.QueryRow(ctx, sql, args...).Scan(&result); err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return def
|
||||
}
|
||||
panic(err)
|
||||
if notFoundErrorOrPanic(c.Conn.QueryRow(ctx, sql, args...).Scan(&result)) {
|
||||
return def
|
||||
}
|
||||
|
||||
return result
|
||||
|
@ -138,11 +145,8 @@ func (tx *Tx) MustGetInteger(ctx context.Context, sql string, args ...interface{
|
|||
|
||||
func (tx *Tx) MustGetIntegerOrDefault(ctx context.Context, def int, sql string, args ...interface{}) int {
|
||||
var result int
|
||||
if err := tx.QueryRow(ctx, sql, args...).Scan(&result); err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return def
|
||||
}
|
||||
panic(err)
|
||||
if notFoundErrorOrPanic(tx.QueryRow(ctx, sql, args...).Scan(&result)) {
|
||||
return def
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package pkg
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"html/template"
|
||||
"math"
|
||||
|
@ -42,14 +41,9 @@ func GetProductForm(w http.ResponseWriter, r *http.Request, params httprouter.Pa
|
|||
return
|
||||
}
|
||||
var productId int
|
||||
err := conn.QueryRow(r.Context(), "select product_id, product.name, product.description, to_price(price, decimal_digits) from product join company using (company_id) join currency using (currency_code) where product.slug = $1", slug).Scan(&productId, form.Name, form.Description, form.Price)
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
} else {
|
||||
panic(err)
|
||||
}
|
||||
if notFoundErrorOrPanic(conn.QueryRow(r.Context(), "select product_id, product.name, product.description, to_price(price, decimal_digits) from product join company using (company_id) join currency using (currency_code) where product.slug = $1", slug).Scan(&productId, form.Name, form.Description, form.Price)) {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
rows, err := conn.Query(r.Context(), "select tax_id from product_tax where product_id = $1", productId)
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in New Issue