Replace default router with github.com/julienschmidt/httprouter

I would fuck up handling URL parameters and this router has per-method
handlers, that are easier to work with, in some cases.
This commit is contained in:
jordi fita mas 2023-02-03 12:30:56 +01:00
parent 80f14d5818
commit 1ab48cfcbc
9 changed files with 307 additions and 240 deletions

1
debian/control vendored
View File

@ -8,6 +8,7 @@ Build-Depends:
gettext,
golang-any,
golang-github-jackc-pgx-v4-dev,
golang-github-julienschmidt-httprouter-dev,
golang-github-leonelquinteros-gotext-dev,
golang-golang-x-text-dev,
postgresql-all (>= 217~),

1
go.mod
View File

@ -4,6 +4,7 @@ go 1.18
require (
github.com/jackc/pgx/v4 v4.17.2
github.com/julienschmidt/httprouter v1.3.0
github.com/leonelquinteros/gotext v1.5.1
golang.org/x/text v0.3.8
)

2
go.sum
View File

@ -62,6 +62,8 @@ github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0f
github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/jackc/puddle v1.3.0 h1:eHK/5clGOatcjX3oWGBO/MpxpbHzSwud5EWTSCI+MX0=
github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=

View File

@ -3,11 +3,11 @@ package pkg
import (
"context"
"errors"
"github.com/julienschmidt/httprouter"
"html/template"
"net/http"
"net/url"
"strconv"
"strings"
)
const (
@ -19,44 +19,26 @@ type Company struct {
Slug string
}
func CompanyHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
slug := r.URL.Path
if idx := strings.IndexByte(slug, '/'); idx >= 0 {
slug = slug[:idx]
}
conn := getConn(r)
func CompanyHandler(next http.Handler) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
company := &Company{
Slug: slug,
Slug: params[0].Value,
}
err := conn.QueryRow(r.Context(), "select company_id from company where slug = $1", slug).Scan(&company.Id)
conn := getConn(r)
err := conn.QueryRow(r.Context(), "select company_id from company where slug = $1", company.Slug).Scan(&company.Id)
if err != nil {
http.NotFound(w, r)
return
}
ctx := context.WithValue(r.Context(), ContextCompanyKey, company)
r = r.WithContext(ctx)
// Same as StripPrefix
p := strings.TrimPrefix(r.URL.Path, slug)
rp := strings.TrimPrefix(r.URL.RawPath, slug)
if len(p) < len(r.URL.Path) && (r.URL.RawPath == "" || len(rp) < len(r.URL.RawPath)) {
r2 := new(http.Request)
*r2 = *r
r2.URL = new(url.URL)
*r2.URL = *r.URL
if p == "" {
r2.URL.Path = "/"
} else {
r2.URL.Path = p
}
r2.URL.RawPath = rp
r2.URL.Path = params[1].Value
next.ServeHTTP(w, r2)
} else {
http.NotFound(w, r)
}
})
}
func getCompany(r *http.Request) *Company {
@ -124,21 +106,30 @@ func (form *taxDetailsForm) mustFillFromDatabase(ctx context.Context, conn *Conn
type TaxDetailsPage struct {
DetailsForm *taxDetailsForm
NewTaxForm *newTaxForm
NewTaxForm *taxForm
Taxes []*Tax
}
func CompanyTaxDetailsHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
func GetCompanyTaxDetailsForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
mustRenderTaxDetailsForm(w, r, newTaxDetailsFormFromDatabase(r))
}
func newTaxDetailsFormFromDatabase(r *http.Request) *taxDetailsForm {
locale := getLocale(r)
conn := getConn(r)
page := &TaxDetailsPage{
DetailsForm: newTaxDetailsForm(r.Context(), conn, locale),
NewTaxForm: newNewTaxForm(locale),
}
form := newTaxDetailsForm(r.Context(), conn, locale)
company := mustGetCompany(r)
if r.Method == "POST" {
if err := page.DetailsForm.Parse(r); err != nil {
form.mustFillFromDatabase(r.Context(), conn, company)
return form
}
func HandleCompanyTaxDetailsForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
locale := getLocale(r)
conn := getConn(r)
form := newTaxDetailsForm(r.Context(), conn, locale)
if err := form.Parse(r); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
@ -146,19 +137,39 @@ func CompanyTaxDetailsHandler() http.Handler {
http.Error(w, err.Error(), http.StatusForbidden)
return
}
if ok := page.DetailsForm.Validate(r.Context(), conn); ok {
form := page.DetailsForm
if ok := form.Validate(r.Context(), conn); !ok {
w.WriteHeader(http.StatusUnprocessableEntity)
mustRenderTaxDetailsForm(w, r, form)
return
}
company := mustGetCompany(r)
conn.MustExec(r.Context(), "update company set business_name = $1, vatin = ($11 || $2)::vatin, trade_name = $3, phone = parse_packed_phone_number($4, $11), email = $5, web = $6, address = $7, city = $8, province = $9, postal_code = $10, country_code = $11, currency_code = $12 where company_id = $13", form.BusinessName, form.VATIN, form.TradeName, form.Phone, form.Email, form.Web, form.Address, form.City, form.Province, form.PostalCode, form.Country, form.Currency, company.Id)
http.Redirect(w, r, "/company/"+company.Slug+"/tax-details", http.StatusSeeOther)
return
}
w.WriteHeader(http.StatusUnprocessableEntity)
} else {
page.DetailsForm.mustFillFromDatabase(r.Context(), conn, company)
func mustRenderTaxDetailsForm(w http.ResponseWriter, r *http.Request, form *taxDetailsForm) {
locale := getLocale(r)
page := &TaxDetailsPage{
DetailsForm: form,
NewTaxForm: newTaxForm(locale),
}
mustRenderTexDetailsPage(w, r, page)
}
func mustRenderTaxForm(w http.ResponseWriter, r *http.Request, form *taxForm) {
page := &TaxDetailsPage{
DetailsForm: newTaxDetailsFormFromDatabase(r),
NewTaxForm: form,
}
mustRenderTexDetailsPage(w, r, page)
}
func mustRenderTexDetailsPage(w http.ResponseWriter, r *http.Request, page *TaxDetailsPage) {
conn := getConn(r)
company := mustGetCompany(r)
page.Taxes = mustGetTaxes(r.Context(), conn, company)
mustRenderAppTemplate(w, r, "tax-details.gohtml", page)
})
}
func mustGetCompany(r *http.Request) *Company {
@ -192,14 +203,14 @@ func mustGetTaxes(ctx context.Context, conn *Conn, company *Company) []*Tax {
return taxes
}
type newTaxForm struct {
type taxForm struct {
locale *Locale
Name *InputField
Rate *InputField
}
func newNewTaxForm(locale *Locale) *newTaxForm {
return &newTaxForm{
func newTaxForm(locale *Locale) *taxForm {
return &taxForm{
locale: locale,
Name: &InputField{
Name: "tax_name",
@ -220,7 +231,7 @@ func newNewTaxForm(locale *Locale) *newTaxForm {
}
}
func (form *newTaxForm) Parse(r *http.Request) error {
func (form *taxForm) Parse(r *http.Request) error {
if err := r.ParseForm(); err != nil {
return err
}
@ -229,7 +240,7 @@ func (form *newTaxForm) Parse(r *http.Request) error {
return nil
}
func (form *newTaxForm) Validate() bool {
func (form *taxForm) Validate() bool {
validator := newFormValidator()
validator.CheckRequiredInput(form.Name, gettext("Tax name can not be empty.", form.locale))
if validator.CheckRequiredInput(form.Rate, gettext("Tax rate can not be empty.", form.locale)) {
@ -238,24 +249,25 @@ func (form *newTaxForm) Validate() bool {
return validator.AllOK()
}
func CompanyTaxHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
param := r.URL.Path
if idx := strings.LastIndexByte(param, '/'); idx >= 0 {
param = param[idx+1:]
func HandleDeleteCompanyTax(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
taxId, err := strconv.Atoi(params[0].Value)
if err != nil {
http.NotFound(w, r)
return
}
conn := getConn(r)
company := mustGetCompany(r)
if taxId, err := strconv.Atoi(param); err == nil {
if err := verifyCsrfTokenValid(r); err != nil {
http.Error(w, err.Error(), http.StatusForbidden)
return
}
conn := getConn(r)
company := mustGetCompany(r)
conn.MustExec(r.Context(), "delete from tax where tax_id = $1", taxId)
http.Redirect(w, r, "/company/"+company.Slug+"/tax-details", http.StatusSeeOther)
} else {
}
func HandleAddCompanyTax(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
locale := getLocale(r)
form := newNewTaxForm(locale)
form := newTaxForm(locale)
if err := form.Parse(r); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
@ -264,18 +276,13 @@ func CompanyTaxHandler() http.Handler {
http.Error(w, err.Error(), http.StatusForbidden)
return
}
if form.Validate() {
if !form.Validate() {
w.WriteHeader(http.StatusUnprocessableEntity)
mustRenderTaxForm(w, r, form)
return
}
conn := getConn(r)
company := mustGetCompany(r)
conn.MustExec(r.Context(), "insert into tax (company_id, name, rate) values ($1, $2, $3 / 100::decimal)", company.Id, form.Name, form.Rate.Integer())
http.Redirect(w, r, "/company/"+company.Slug+"/tax-details", http.StatusSeeOther)
} else {
w.WriteHeader(http.StatusUnprocessableEntity)
}
page := &TaxDetailsPage{
DetailsForm: newTaxDetailsForm(r.Context(), conn, locale).mustFillFromDatabase(r.Context(), conn, company),
NewTaxForm: form,
Taxes: mustGetTaxes(r.Context(), conn, company),
}
mustRenderAppTemplate(w, r, "tax-details.gohtml", page)
}
})
}

View File

@ -2,6 +2,7 @@ package pkg
import (
"context"
"github.com/julienschmidt/httprouter"
"html/template"
"net/http"
)
@ -16,11 +17,28 @@ type ContactsIndexPage struct {
Contacts []*ContactEntry
}
func ContactsHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
func IndexContacts(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
conn := getConn(r)
company := getCompany(r)
if r.Method == "POST" {
page := &ContactsIndexPage{
Contacts: mustGetContactEntries(r.Context(), conn, company),
}
mustRenderAppTemplate(w, r, "contacts-index.gohtml", page)
}
func GetNewContactForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
locale := getLocale(r)
conn := getConn(r)
form := newContactForm(r.Context(), conn, locale)
mustRenderContactForm(w, r, form)
}
func mustRenderContactForm(w http.ResponseWriter, r *http.Request, form *contactForm) {
mustRenderAppTemplate(w, r, "contacts-new.gohtml", form)
}
func HandleAddContact(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
conn := getConn(r)
locale := getLocale(r)
form := newContactForm(r.Context(), conn, locale)
if err := form.Parse(r); err != nil {
@ -31,19 +49,13 @@ func ContactsHandler() http.Handler {
http.Error(w, err.Error(), http.StatusForbidden)
return
}
if form.Validate(r.Context(), conn) {
if !form.Validate(r.Context(), conn) {
mustRenderContactForm(w, r, form)
return
}
company := getCompany(r)
conn.MustExec(r.Context(), "insert into contact (company_id, business_name, vatin, trade_name, phone, email, web, address, province, city, postal_code, country_code) values ($1, $2, ($12 || $3)::vatin, $4, parse_packed_phone_number($5, $12), $6, $7, $8, $9, $10, $11, $12)", company.Id, form.BusinessName, form.VATIN, form.TradeName, form.Phone, form.Email, form.Web, form.Address, form.City, form.Province, form.PostalCode, form.Country)
http.Redirect(w, r, "/company/"+company.Slug+"/contacts", http.StatusSeeOther)
} else {
mustRenderAppTemplate(w, r, "contacts-new.gohtml", form)
}
} else {
page := &ContactsIndexPage{
Contacts: mustGetContactEntries(r.Context(), conn, company),
}
mustRenderAppTemplate(w, r, "contacts-index.gohtml", page)
}
})
}
func mustGetContactEntries(ctx context.Context, conn *Conn, company *Company) []*ContactEntry {
@ -217,12 +229,3 @@ func (form *contactForm) Validate(ctx context.Context, conn *Conn) bool {
validator.CheckValidSelectOption(form.Country, gettext("Selected country is not valid.", form.locale))
return validator.AllOK()
}
func NewContactHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
locale := getLocale(r)
conn := getConn(r)
form := newContactForm(r.Context(), conn, locale)
mustRenderAppTemplate(w, r, "contacts-new.gohtml", form)
})
}

View File

@ -22,7 +22,7 @@ func NewLocale(lang language.Tag) *Locale {
}
}
func SetLocale(db *Db, next http.Handler) http.Handler {
func LocaleSetter(db *Db, next http.Handler) http.Handler {
availableLanguages := mustGetAvailableLanguages(db)
var matcher = language.NewMatcher(availableLanguages)

View File

@ -3,6 +3,7 @@ package pkg
import (
"context"
"errors"
"github.com/julienschmidt/httprouter"
"html/template"
"net"
"net/http"
@ -72,8 +73,19 @@ func (form *loginForm) Validate() bool {
return validator.AllOK()
}
func LoginHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
func GetLoginForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
user := getUser(r)
if user.LoggedIn {
http.Redirect(w, r, "/", http.StatusSeeOther)
return
}
locale := getLocale(r)
form := newLoginForm(locale)
w.WriteHeader(http.StatusOK)
mustRenderLoginForm(w, r, form)
}
func HandleLoginForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
user := getUser(r)
if user.LoggedIn {
http.Redirect(w, r, "/", http.StatusSeeOther)
@ -81,7 +93,6 @@ func LoginHandler() http.Handler {
}
locale := getLocale(r)
form := newLoginForm(locale)
if r.Method == "POST" {
if err := form.Parse(r); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
@ -99,13 +110,14 @@ func LoginHandler() http.Handler {
} else {
w.WriteHeader(http.StatusUnprocessableEntity)
}
}
mustRenderWebTemplate(w, r, "login.gohtml", form)
})
mustRenderLoginForm(w, r, form)
}
func LogoutHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
func mustRenderLoginForm(w http.ResponseWriter, r *http.Request, form *loginForm) {
mustRenderWebTemplate(w, r, "login.gohtml", form)
}
func HandleLogout(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
if err := verifyCsrfTokenValid(r); err != nil {
http.Error(w, err.Error(), http.StatusForbidden)
return
@ -114,7 +126,6 @@ func LogoutHandler() http.Handler {
conn.MustExec(r.Context(), "select logout()")
http.SetCookie(w, createSessionCookie("", -24*time.Hour))
http.Redirect(w, r, "/login", http.StatusSeeOther)
})
}
func remoteAddr(r *http.Request) string {
@ -148,7 +159,7 @@ type AppUser struct {
CsrfToken string
}
func CheckLogin(db *Db, next http.Handler) http.Handler {
func LoginChecker(db *Db, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var ctx = r.Context()
if cookie, err := r.Cookie(sessionCookie); err == nil {
@ -195,13 +206,13 @@ func getConn(r *http.Request) *Conn {
return r.Context().Value(ContextConnKey).(*Conn)
}
func Authenticated(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
func Authenticated(next httprouter.Handle) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
user := getUser(r)
if user.LoggedIn {
next.ServeHTTP(w, r)
next(w, r, params)
} else {
http.Redirect(w, r, "/login", http.StatusSeeOther)
}
})
}
}

View File

@ -2,6 +2,7 @@ package pkg
import (
"context"
"github.com/julienschmidt/httprouter"
"html/template"
"net/http"
)
@ -93,13 +94,23 @@ func (form *profileForm) Validate() bool {
return validator.AllOK()
}
func ProfileHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
func GetProfileForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
user := getUser(r)
conn := getConn(r)
locale := getLocale(r)
form := newProfileForm(r.Context(), conn, locale)
if r.Method == "POST" {
form.Name.Val = conn.MustGetText(r.Context(), "", "select name from user_profile")
form.Email.Val = user.Email
form.Language.Selected = user.Language.String()
w.WriteHeader(http.StatusOK)
mustRenderProfileForm(w, r, form)
}
func HandleProfileForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
conn := getConn(r)
locale := getLocale(r)
form := newProfileForm(r.Context(), conn, locale)
if err := form.Parse(r); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
@ -108,7 +119,11 @@ func ProfileHandler() http.Handler {
http.Error(w, err.Error(), http.StatusForbidden)
return
}
if ok := form.Validate(); ok {
if ok := form.Validate(); !ok {
w.WriteHeader(http.StatusUnprocessableEntity)
mustRenderProfileForm(w, r, form)
return
}
//goland:noinspection SqlWithoutWhere
cookie := conn.MustGetText(r.Context(), "", "update user_profile set name = $1, email = $2, lang_tag = $3 returning build_cookie()", form.Name, form.Email, form.Language)
setSessionCookie(w, cookie)
@ -117,14 +132,8 @@ func ProfileHandler() http.Handler {
}
company := getCompany(r)
http.Redirect(w, r, "/company/"+company.Slug+"/profile", http.StatusSeeOther)
return
}
w.WriteHeader(http.StatusUnprocessableEntity)
} else {
form.Name.Val = conn.MustGetText(r.Context(), "", "select name from user_profile")
form.Email.Val = user.Email
form.Language.Selected = user.Language.String()
}
func mustRenderProfileForm(w http.ResponseWriter, r *http.Request, form *profileForm) {
mustRenderAppTemplate(w, r, "profile.gohtml", form)
})
}

View File

@ -2,39 +2,72 @@ package pkg
import (
"net/http"
"github.com/julienschmidt/httprouter"
)
func NewRouter(db *Db) http.Handler {
companyRouter := http.NewServeMux()
companyRouter.Handle("/tax-details", CompanyTaxDetailsHandler())
companyRouter.Handle("/tax/", CompanyTaxHandler())
companyRouter.Handle("/tax", CompanyTaxHandler())
companyRouter.Handle("/profile", ProfileHandler())
companyRouter.Handle("/contacts/new", NewContactHandler())
companyRouter.Handle("/contacts", ContactsHandler())
companyRouter.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
companyRouter := httprouter.New()
companyRouter.GET("/profile", GetProfileForm)
companyRouter.POST("/profile", HandleProfileForm)
companyRouter.GET("/tax-details", GetCompanyTaxDetailsForm)
companyRouter.POST("/tax-details", HandleCompanyTaxDetailsForm)
companyRouter.POST("/tax", HandleAddCompanyTax)
companyRouter.DELETE("/tax/:taxId", HandleDeleteCompanyTax)
companyRouter.GET("/contacts", IndexContacts)
companyRouter.POST("/contacts", HandleAddContact)
companyRouter.GET("/contacts/new", GetNewContactForm)
companyRouter.GET("/", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
mustRenderAppTemplate(w, r, "dashboard.gohtml", nil)
})
router := http.NewServeMux()
router.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("web/static"))))
router.Handle("/login", LoginHandler())
router.Handle("/logout", Authenticated(LogoutHandler()))
router.Handle("/company/", Authenticated(http.StripPrefix("/company/", CompanyHandler(companyRouter))))
router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
router := httprouter.New()
router.ServeFiles("/static/*filepath", http.Dir("web/static"))
router.GET("/login", GetLoginForm)
router.POST("/login", HandleLoginForm)
router.POST("/logout", Authenticated(HandleLogout))
companyHandler := Authenticated(CompanyHandler(companyRouter))
router.GET("/company/:slug/*rest", companyHandler)
router.POST("/company/:slug/*rest", companyHandler)
router.PUT("/company/:slug/*rest", companyHandler)
router.DELETE("/company/:slug/*rest", companyHandler)
router.GET("/", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
user := getUser(r)
if user.LoggedIn {
conn := getConn(r)
slug := conn.MustGetText(r.Context(), "", "select slug::text from company order by company_id limit 1")
http.Redirect(w, r, "/company/"+slug, http.StatusFound)
http.Redirect(w, r, "/company/"+slug+"/", http.StatusFound)
} else {
http.Redirect(w, r, "/login", http.StatusSeeOther)
}
})
var handler http.Handler = router
handler = SetLocale(db, handler)
handler = CheckLogin(db, handler)
handler = MethodOverrider(handler)
handler = LocaleSetter(db, handler)
handler = LoginChecker(db, handler)
handler = Recoverer(handler)
handler = Logger(handler)
return handler
}
func MethodOverrider(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPost {
if err := r.ParseForm(); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
override := r.FormValue("_method")
if override == http.MethodDelete || override == http.MethodPut {
r2 := new(http.Request)
*r2 = *r
r2.Method = override
r = r2
}
}
next.ServeHTTP(w, r)
})
}