Refactor checking for pgx.ErrNoRows in a function
This commit is contained in:
parent
13fa1d6b89
commit
4db0a8fb5a
|
@ -2,7 +2,6 @@ package pkg
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"github.com/jackc/pgx/v4"
|
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
"html/template"
|
"html/template"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -38,14 +37,9 @@ func GetContactForm(w http.ResponseWriter, r *http.Request, params httprouter.Pa
|
||||||
mustRenderNewContactForm(w, r, form)
|
mustRenderNewContactForm(w, r, form)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err := conn.QueryRow(r.Context(), "select business_name, substr(vatin::text, 3), trade_name, phone, email, web, address, city, province, postal_code, country_code from contact where slug = $1", slug).Scan(form.BusinessName, form.VATIN, form.TradeName, form.Phone, form.Email, form.Web, form.Address, form.City, form.Province, form.PostalCode, form.Country)
|
if notFoundErrorOrPanic(conn.QueryRow(r.Context(), "select business_name, substr(vatin::text, 3), trade_name, phone, email, web, address, city, province, postal_code, country_code from contact where slug = $1", slug).Scan(form.BusinessName, form.VATIN, form.TradeName, form.Phone, form.Email, form.Web, form.Address, form.City, form.Province, form.PostalCode, form.Country)) {
|
||||||
if err != nil {
|
|
||||||
if err == pgx.ErrNoRows {
|
|
||||||
http.NotFound(w, r)
|
http.NotFound(w, r)
|
||||||
return
|
return
|
||||||
} else {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
mustRenderEditContactForm(w, r, form)
|
mustRenderEditContactForm(w, r, form)
|
||||||
|
|
20
pkg/db.go
20
pkg/db.go
|
@ -50,6 +50,16 @@ func NewDatabase(ctx context.Context, connString string) (*Db, error) {
|
||||||
return &Db{pool}, nil
|
return &Db{pool}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func notFoundErrorOrPanic(err error) bool {
|
||||||
|
if err == pgx.ErrNoRows {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (db *Db) Acquire(ctx context.Context) (*Conn, error) {
|
func (db *Db) Acquire(ctx context.Context) (*Conn, error) {
|
||||||
conn, err := db.Pool.Acquire(ctx)
|
conn, err := db.Pool.Acquire(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -80,12 +90,9 @@ func (c *Conn) MustBegin(ctx context.Context) *Tx {
|
||||||
|
|
||||||
func (c *Conn) MustGetText(ctx context.Context, def string, sql string, args ...interface{}) string {
|
func (c *Conn) MustGetText(ctx context.Context, def string, sql string, args ...interface{}) string {
|
||||||
var result string
|
var result string
|
||||||
if err := c.Conn.QueryRow(ctx, sql, args...).Scan(&result); err != nil {
|
if notFoundErrorOrPanic(c.Conn.QueryRow(ctx, sql, args...).Scan(&result)) {
|
||||||
if err == pgx.ErrNoRows {
|
|
||||||
return def
|
return def
|
||||||
}
|
}
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
@ -138,12 +145,9 @@ func (tx *Tx) MustGetInteger(ctx context.Context, sql string, args ...interface{
|
||||||
|
|
||||||
func (tx *Tx) MustGetIntegerOrDefault(ctx context.Context, def int, sql string, args ...interface{}) int {
|
func (tx *Tx) MustGetIntegerOrDefault(ctx context.Context, def int, sql string, args ...interface{}) int {
|
||||||
var result int
|
var result int
|
||||||
if err := tx.QueryRow(ctx, sql, args...).Scan(&result); err != nil {
|
if notFoundErrorOrPanic(tx.QueryRow(ctx, sql, args...).Scan(&result)) {
|
||||||
if err == pgx.ErrNoRows {
|
|
||||||
return def
|
return def
|
||||||
}
|
}
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,6 @@ package pkg
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/jackc/pgx/v4"
|
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
"html/template"
|
"html/template"
|
||||||
"math"
|
"math"
|
||||||
|
@ -42,14 +41,9 @@ func GetProductForm(w http.ResponseWriter, r *http.Request, params httprouter.Pa
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var productId int
|
var productId int
|
||||||
err := conn.QueryRow(r.Context(), "select product_id, product.name, product.description, to_price(price, decimal_digits) from product join company using (company_id) join currency using (currency_code) where product.slug = $1", slug).Scan(&productId, form.Name, form.Description, form.Price)
|
if notFoundErrorOrPanic(conn.QueryRow(r.Context(), "select product_id, product.name, product.description, to_price(price, decimal_digits) from product join company using (company_id) join currency using (currency_code) where product.slug = $1", slug).Scan(&productId, form.Name, form.Description, form.Price)) {
|
||||||
if err != nil {
|
|
||||||
if err == pgx.ErrNoRows {
|
|
||||||
http.NotFound(w, r)
|
http.NotFound(w, r)
|
||||||
return
|
return
|
||||||
} else {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
rows, err := conn.Query(r.Context(), "select tax_id from product_tax where product_id = $1", productId)
|
rows, err := conn.Query(r.Context(), "select tax_id from product_tax where product_id = $1", productId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Reference in New Issue