Replace default router with github.com/julienschmidt/httprouter

I would fuck up handling URL parameters and this router has per-method
handlers, that are easier to work with, in some cases.
This commit is contained in:
jordi fita mas 2023-02-03 12:30:56 +01:00
parent 80f14d5818
commit 1ab48cfcbc
9 changed files with 307 additions and 240 deletions

1
debian/control vendored
View File

@ -8,6 +8,7 @@ Build-Depends:
gettext, gettext,
golang-any, golang-any,
golang-github-jackc-pgx-v4-dev, golang-github-jackc-pgx-v4-dev,
golang-github-julienschmidt-httprouter-dev,
golang-github-leonelquinteros-gotext-dev, golang-github-leonelquinteros-gotext-dev,
golang-golang-x-text-dev, golang-golang-x-text-dev,
postgresql-all (>= 217~), postgresql-all (>= 217~),

1
go.mod
View File

@ -4,6 +4,7 @@ go 1.18
require ( require (
github.com/jackc/pgx/v4 v4.17.2 github.com/jackc/pgx/v4 v4.17.2
github.com/julienschmidt/httprouter v1.3.0
github.com/leonelquinteros/gotext v1.5.1 github.com/leonelquinteros/gotext v1.5.1
golang.org/x/text v0.3.8 golang.org/x/text v0.3.8
) )

2
go.sum
View File

@ -62,6 +62,8 @@ github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0f
github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/jackc/puddle v1.3.0 h1:eHK/5clGOatcjX3oWGBO/MpxpbHzSwud5EWTSCI+MX0= github.com/jackc/puddle v1.3.0 h1:eHK/5clGOatcjX3oWGBO/MpxpbHzSwud5EWTSCI+MX0=
github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=

View File

@ -3,11 +3,11 @@ package pkg
import ( import (
"context" "context"
"errors" "errors"
"github.com/julienschmidt/httprouter"
"html/template" "html/template"
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
"strings"
) )
const ( const (
@ -19,44 +19,26 @@ type Company struct {
Slug string Slug string
} }
func CompanyHandler(next http.Handler) http.Handler { func CompanyHandler(next http.Handler) httprouter.Handle {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
slug := r.URL.Path
if idx := strings.IndexByte(slug, '/'); idx >= 0 {
slug = slug[:idx]
}
conn := getConn(r)
company := &Company{ company := &Company{
Slug: slug, Slug: params[0].Value,
} }
err := conn.QueryRow(r.Context(), "select company_id from company where slug = $1", slug).Scan(&company.Id) conn := getConn(r)
err := conn.QueryRow(r.Context(), "select company_id from company where slug = $1", company.Slug).Scan(&company.Id)
if err != nil { if err != nil {
http.NotFound(w, r) http.NotFound(w, r)
return return
} }
ctx := context.WithValue(r.Context(), ContextCompanyKey, company) ctx := context.WithValue(r.Context(), ContextCompanyKey, company)
r = r.WithContext(ctx) r = r.WithContext(ctx)
r2 := new(http.Request)
// Same as StripPrefix *r2 = *r
p := strings.TrimPrefix(r.URL.Path, slug) r2.URL = new(url.URL)
rp := strings.TrimPrefix(r.URL.RawPath, slug) *r2.URL = *r.URL
if len(p) < len(r.URL.Path) && (r.URL.RawPath == "" || len(rp) < len(r.URL.RawPath)) { r2.URL.Path = params[1].Value
r2 := new(http.Request) next.ServeHTTP(w, r2)
*r2 = *r }
r2.URL = new(url.URL)
*r2.URL = *r.URL
if p == "" {
r2.URL.Path = "/"
} else {
r2.URL.Path = p
}
r2.URL.RawPath = rp
next.ServeHTTP(w, r2)
} else {
http.NotFound(w, r)
}
})
} }
func getCompany(r *http.Request) *Company { func getCompany(r *http.Request) *Company {
@ -124,41 +106,70 @@ func (form *taxDetailsForm) mustFillFromDatabase(ctx context.Context, conn *Conn
type TaxDetailsPage struct { type TaxDetailsPage struct {
DetailsForm *taxDetailsForm DetailsForm *taxDetailsForm
NewTaxForm *newTaxForm NewTaxForm *taxForm
Taxes []*Tax Taxes []*Tax
} }
func CompanyTaxDetailsHandler() http.Handler { func GetCompanyTaxDetailsForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mustRenderTaxDetailsForm(w, r, newTaxDetailsFormFromDatabase(r))
locale := getLocale(r) }
conn := getConn(r)
page := &TaxDetailsPage{ func newTaxDetailsFormFromDatabase(r *http.Request) *taxDetailsForm {
DetailsForm: newTaxDetailsForm(r.Context(), conn, locale), locale := getLocale(r)
NewTaxForm: newNewTaxForm(locale), conn := getConn(r)
} form := newTaxDetailsForm(r.Context(), conn, locale)
company := mustGetCompany(r)
if r.Method == "POST" { company := mustGetCompany(r)
if err := page.DetailsForm.Parse(r); err != nil { form.mustFillFromDatabase(r.Context(), conn, company)
http.Error(w, err.Error(), http.StatusBadRequest)
return return form
} }
if err := verifyCsrfTokenValid(r); err != nil {
http.Error(w, err.Error(), http.StatusForbidden) func HandleCompanyTaxDetailsForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
return locale := getLocale(r)
} conn := getConn(r)
if ok := page.DetailsForm.Validate(r.Context(), conn); ok { form := newTaxDetailsForm(r.Context(), conn, locale)
form := page.DetailsForm if err := form.Parse(r); err != nil {
conn.MustExec(r.Context(), "update company set business_name = $1, vatin = ($11 || $2)::vatin, trade_name = $3, phone = parse_packed_phone_number($4, $11), email = $5, web = $6, address = $7, city = $8, province = $9, postal_code = $10, country_code = $11, currency_code = $12 where company_id = $13", form.BusinessName, form.VATIN, form.TradeName, form.Phone, form.Email, form.Web, form.Address, form.City, form.Province, form.PostalCode, form.Country, form.Currency, company.Id) http.Error(w, err.Error(), http.StatusBadRequest)
http.Redirect(w, r, "/company/"+company.Slug+"/tax-details", http.StatusSeeOther) return
return }
} if err := verifyCsrfTokenValid(r); err != nil {
w.WriteHeader(http.StatusUnprocessableEntity) http.Error(w, err.Error(), http.StatusForbidden)
} else { return
page.DetailsForm.mustFillFromDatabase(r.Context(), conn, company) }
} if ok := form.Validate(r.Context(), conn); !ok {
page.Taxes = mustGetTaxes(r.Context(), conn, company) w.WriteHeader(http.StatusUnprocessableEntity)
mustRenderAppTemplate(w, r, "tax-details.gohtml", page) mustRenderTaxDetailsForm(w, r, form)
}) return
}
company := mustGetCompany(r)
conn.MustExec(r.Context(), "update company set business_name = $1, vatin = ($11 || $2)::vatin, trade_name = $3, phone = parse_packed_phone_number($4, $11), email = $5, web = $6, address = $7, city = $8, province = $9, postal_code = $10, country_code = $11, currency_code = $12 where company_id = $13", form.BusinessName, form.VATIN, form.TradeName, form.Phone, form.Email, form.Web, form.Address, form.City, form.Province, form.PostalCode, form.Country, form.Currency, company.Id)
http.Redirect(w, r, "/company/"+company.Slug+"/tax-details", http.StatusSeeOther)
return
}
func mustRenderTaxDetailsForm(w http.ResponseWriter, r *http.Request, form *taxDetailsForm) {
locale := getLocale(r)
page := &TaxDetailsPage{
DetailsForm: form,
NewTaxForm: newTaxForm(locale),
}
mustRenderTexDetailsPage(w, r, page)
}
func mustRenderTaxForm(w http.ResponseWriter, r *http.Request, form *taxForm) {
page := &TaxDetailsPage{
DetailsForm: newTaxDetailsFormFromDatabase(r),
NewTaxForm: form,
}
mustRenderTexDetailsPage(w, r, page)
}
func mustRenderTexDetailsPage(w http.ResponseWriter, r *http.Request, page *TaxDetailsPage) {
conn := getConn(r)
company := mustGetCompany(r)
page.Taxes = mustGetTaxes(r.Context(), conn, company)
mustRenderAppTemplate(w, r, "tax-details.gohtml", page)
} }
func mustGetCompany(r *http.Request) *Company { func mustGetCompany(r *http.Request) *Company {
@ -192,14 +203,14 @@ func mustGetTaxes(ctx context.Context, conn *Conn, company *Company) []*Tax {
return taxes return taxes
} }
type newTaxForm struct { type taxForm struct {
locale *Locale locale *Locale
Name *InputField Name *InputField
Rate *InputField Rate *InputField
} }
func newNewTaxForm(locale *Locale) *newTaxForm { func newTaxForm(locale *Locale) *taxForm {
return &newTaxForm{ return &taxForm{
locale: locale, locale: locale,
Name: &InputField{ Name: &InputField{
Name: "tax_name", Name: "tax_name",
@ -220,7 +231,7 @@ func newNewTaxForm(locale *Locale) *newTaxForm {
} }
} }
func (form *newTaxForm) Parse(r *http.Request) error { func (form *taxForm) Parse(r *http.Request) error {
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
return err return err
} }
@ -229,7 +240,7 @@ func (form *newTaxForm) Parse(r *http.Request) error {
return nil return nil
} }
func (form *newTaxForm) Validate() bool { func (form *taxForm) Validate() bool {
validator := newFormValidator() validator := newFormValidator()
validator.CheckRequiredInput(form.Name, gettext("Tax name can not be empty.", form.locale)) validator.CheckRequiredInput(form.Name, gettext("Tax name can not be empty.", form.locale))
if validator.CheckRequiredInput(form.Rate, gettext("Tax rate can not be empty.", form.locale)) { if validator.CheckRequiredInput(form.Rate, gettext("Tax rate can not be empty.", form.locale)) {
@ -238,44 +249,40 @@ func (form *newTaxForm) Validate() bool {
return validator.AllOK() return validator.AllOK()
} }
func CompanyTaxHandler() http.Handler { func HandleDeleteCompanyTax(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { taxId, err := strconv.Atoi(params[0].Value)
param := r.URL.Path if err != nil {
if idx := strings.LastIndexByte(param, '/'); idx >= 0 { http.NotFound(w, r)
param = param[idx+1:] return
} }
conn := getConn(r) if err := verifyCsrfTokenValid(r); err != nil {
company := mustGetCompany(r) http.Error(w, err.Error(), http.StatusForbidden)
if taxId, err := strconv.Atoi(param); err == nil { return
if err := verifyCsrfTokenValid(r); err != nil { }
http.Error(w, err.Error(), http.StatusForbidden) conn := getConn(r)
return company := mustGetCompany(r)
} conn.MustExec(r.Context(), "delete from tax where tax_id = $1", taxId)
conn.MustExec(r.Context(), "delete from tax where tax_id = $1", taxId) http.Redirect(w, r, "/company/"+company.Slug+"/tax-details", http.StatusSeeOther)
http.Redirect(w, r, "/company/"+company.Slug+"/tax-details", http.StatusSeeOther) }
} else {
locale := getLocale(r) func HandleAddCompanyTax(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
form := newNewTaxForm(locale) locale := getLocale(r)
if err := form.Parse(r); err != nil { form := newTaxForm(locale)
http.Error(w, err.Error(), http.StatusBadRequest) if err := form.Parse(r); err != nil {
return http.Error(w, err.Error(), http.StatusBadRequest)
} return
if err := verifyCsrfTokenValid(r); err != nil { }
http.Error(w, err.Error(), http.StatusForbidden) if err := verifyCsrfTokenValid(r); err != nil {
return http.Error(w, err.Error(), http.StatusForbidden)
} return
if form.Validate() { }
conn.MustExec(r.Context(), "insert into tax (company_id, name, rate) values ($1, $2, $3 / 100::decimal)", company.Id, form.Name, form.Rate.Integer()) if !form.Validate() {
http.Redirect(w, r, "/company/"+company.Slug+"/tax-details", http.StatusSeeOther) w.WriteHeader(http.StatusUnprocessableEntity)
} else { mustRenderTaxForm(w, r, form)
w.WriteHeader(http.StatusUnprocessableEntity) return
} }
page := &TaxDetailsPage{ conn := getConn(r)
DetailsForm: newTaxDetailsForm(r.Context(), conn, locale).mustFillFromDatabase(r.Context(), conn, company), company := mustGetCompany(r)
NewTaxForm: form, conn.MustExec(r.Context(), "insert into tax (company_id, name, rate) values ($1, $2, $3 / 100::decimal)", company.Id, form.Name, form.Rate.Integer())
Taxes: mustGetTaxes(r.Context(), conn, company), http.Redirect(w, r, "/company/"+company.Slug+"/tax-details", http.StatusSeeOther)
}
mustRenderAppTemplate(w, r, "tax-details.gohtml", page)
}
})
} }

View File

@ -2,6 +2,7 @@ package pkg
import ( import (
"context" "context"
"github.com/julienschmidt/httprouter"
"html/template" "html/template"
"net/http" "net/http"
) )
@ -16,34 +17,45 @@ type ContactsIndexPage struct {
Contacts []*ContactEntry Contacts []*ContactEntry
} }
func ContactsHandler() http.Handler { func IndexContacts(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn := getConn(r)
conn := getConn(r) company := getCompany(r)
company := getCompany(r) page := &ContactsIndexPage{
if r.Method == "POST" { Contacts: mustGetContactEntries(r.Context(), conn, company),
locale := getLocale(r) }
form := newContactForm(r.Context(), conn, locale) mustRenderAppTemplate(w, r, "contacts-index.gohtml", page)
if err := form.Parse(r); err != nil { }
http.Error(w, err.Error(), http.StatusBadRequest)
return func GetNewContactForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
} locale := getLocale(r)
if err := verifyCsrfTokenValid(r); err != nil { conn := getConn(r)
http.Error(w, err.Error(), http.StatusForbidden) form := newContactForm(r.Context(), conn, locale)
return mustRenderContactForm(w, r, form)
} }
if form.Validate(r.Context(), conn) {
conn.MustExec(r.Context(), "insert into contact (company_id, business_name, vatin, trade_name, phone, email, web, address, province, city, postal_code, country_code) values ($1, $2, ($12 || $3)::vatin, $4, parse_packed_phone_number($5, $12), $6, $7, $8, $9, $10, $11, $12)", company.Id, form.BusinessName, form.VATIN, form.TradeName, form.Phone, form.Email, form.Web, form.Address, form.City, form.Province, form.PostalCode, form.Country) func mustRenderContactForm(w http.ResponseWriter, r *http.Request, form *contactForm) {
http.Redirect(w, r, "/company/"+company.Slug+"/contacts", http.StatusSeeOther) mustRenderAppTemplate(w, r, "contacts-new.gohtml", form)
} else { }
mustRenderAppTemplate(w, r, "contacts-new.gohtml", form)
} func HandleAddContact(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
} else { conn := getConn(r)
page := &ContactsIndexPage{ locale := getLocale(r)
Contacts: mustGetContactEntries(r.Context(), conn, company), form := newContactForm(r.Context(), conn, locale)
} if err := form.Parse(r); err != nil {
mustRenderAppTemplate(w, r, "contacts-index.gohtml", page) http.Error(w, err.Error(), http.StatusBadRequest)
} return
}) }
if err := verifyCsrfTokenValid(r); err != nil {
http.Error(w, err.Error(), http.StatusForbidden)
return
}
if !form.Validate(r.Context(), conn) {
mustRenderContactForm(w, r, form)
return
}
company := getCompany(r)
conn.MustExec(r.Context(), "insert into contact (company_id, business_name, vatin, trade_name, phone, email, web, address, province, city, postal_code, country_code) values ($1, $2, ($12 || $3)::vatin, $4, parse_packed_phone_number($5, $12), $6, $7, $8, $9, $10, $11, $12)", company.Id, form.BusinessName, form.VATIN, form.TradeName, form.Phone, form.Email, form.Web, form.Address, form.City, form.Province, form.PostalCode, form.Country)
http.Redirect(w, r, "/company/"+company.Slug+"/contacts", http.StatusSeeOther)
} }
func mustGetContactEntries(ctx context.Context, conn *Conn, company *Company) []*ContactEntry { func mustGetContactEntries(ctx context.Context, conn *Conn, company *Company) []*ContactEntry {
@ -217,12 +229,3 @@ func (form *contactForm) Validate(ctx context.Context, conn *Conn) bool {
validator.CheckValidSelectOption(form.Country, gettext("Selected country is not valid.", form.locale)) validator.CheckValidSelectOption(form.Country, gettext("Selected country is not valid.", form.locale))
return validator.AllOK() return validator.AllOK()
} }
func NewContactHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
locale := getLocale(r)
conn := getConn(r)
form := newContactForm(r.Context(), conn, locale)
mustRenderAppTemplate(w, r, "contacts-new.gohtml", form)
})
}

View File

@ -22,7 +22,7 @@ func NewLocale(lang language.Tag) *Locale {
} }
} }
func SetLocale(db *Db, next http.Handler) http.Handler { func LocaleSetter(db *Db, next http.Handler) http.Handler {
availableLanguages := mustGetAvailableLanguages(db) availableLanguages := mustGetAvailableLanguages(db)
var matcher = language.NewMatcher(availableLanguages) var matcher = language.NewMatcher(availableLanguages)

View File

@ -3,6 +3,7 @@ package pkg
import ( import (
"context" "context"
"errors" "errors"
"github.com/julienschmidt/httprouter"
"html/template" "html/template"
"net" "net"
"net/http" "net/http"
@ -72,49 +73,59 @@ func (form *loginForm) Validate() bool {
return validator.AllOK() return validator.AllOK()
} }
func LoginHandler() http.Handler { func GetLoginForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user := getUser(r)
user := getUser(r) if user.LoggedIn {
if user.LoggedIn { http.Redirect(w, r, "/", http.StatusSeeOther)
return
}
locale := getLocale(r)
form := newLoginForm(locale)
w.WriteHeader(http.StatusOK)
mustRenderLoginForm(w, r, form)
}
func HandleLoginForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
user := getUser(r)
if user.LoggedIn {
http.Redirect(w, r, "/", http.StatusSeeOther)
return
}
locale := getLocale(r)
form := newLoginForm(locale)
if err := form.Parse(r); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if form.Validate() {
conn := getConn(r)
cookie := conn.MustGetText(r.Context(), "", "select login($1, $2, $3)", form.Email, form.Password, remoteAddr(r))
if cookie != "" {
setSessionCookie(w, cookie)
http.Redirect(w, r, "/", http.StatusSeeOther) http.Redirect(w, r, "/", http.StatusSeeOther)
return return
} }
locale := getLocale(r) form.Errors = append(form.Errors, errors.New(gettext("Invalid user or password.", locale)))
form := newLoginForm(locale) w.WriteHeader(http.StatusUnauthorized)
if r.Method == "POST" { } else {
if err := form.Parse(r); err != nil { w.WriteHeader(http.StatusUnprocessableEntity)
http.Error(w, err.Error(), http.StatusBadRequest) }
return mustRenderLoginForm(w, r, form)
}
if form.Validate() {
conn := getConn(r)
cookie := conn.MustGetText(r.Context(), "", "select login($1, $2, $3)", form.Email, form.Password, remoteAddr(r))
if cookie != "" {
setSessionCookie(w, cookie)
http.Redirect(w, r, "/", http.StatusSeeOther)
return
}
form.Errors = append(form.Errors, errors.New(gettext("Invalid user or password.", locale)))
w.WriteHeader(http.StatusUnauthorized)
} else {
w.WriteHeader(http.StatusUnprocessableEntity)
}
}
mustRenderWebTemplate(w, r, "login.gohtml", form)
})
} }
func LogoutHandler() http.Handler { func mustRenderLoginForm(w http.ResponseWriter, r *http.Request, form *loginForm) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mustRenderWebTemplate(w, r, "login.gohtml", form)
if err := verifyCsrfTokenValid(r); err != nil { }
http.Error(w, err.Error(), http.StatusForbidden)
return func HandleLogout(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
} if err := verifyCsrfTokenValid(r); err != nil {
conn := getConn(r) http.Error(w, err.Error(), http.StatusForbidden)
conn.MustExec(r.Context(), "select logout()") return
http.SetCookie(w, createSessionCookie("", -24*time.Hour)) }
http.Redirect(w, r, "/login", http.StatusSeeOther) conn := getConn(r)
}) conn.MustExec(r.Context(), "select logout()")
http.SetCookie(w, createSessionCookie("", -24*time.Hour))
http.Redirect(w, r, "/login", http.StatusSeeOther)
} }
func remoteAddr(r *http.Request) string { func remoteAddr(r *http.Request) string {
@ -148,7 +159,7 @@ type AppUser struct {
CsrfToken string CsrfToken string
} }
func CheckLogin(db *Db, next http.Handler) http.Handler { func LoginChecker(db *Db, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var ctx = r.Context() var ctx = r.Context()
if cookie, err := r.Cookie(sessionCookie); err == nil { if cookie, err := r.Cookie(sessionCookie); err == nil {
@ -195,13 +206,13 @@ func getConn(r *http.Request) *Conn {
return r.Context().Value(ContextConnKey).(*Conn) return r.Context().Value(ContextConnKey).(*Conn)
} }
func Authenticated(next http.Handler) http.Handler { func Authenticated(next httprouter.Handle) httprouter.Handle {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
user := getUser(r) user := getUser(r)
if user.LoggedIn { if user.LoggedIn {
next.ServeHTTP(w, r) next(w, r, params)
} else { } else {
http.Redirect(w, r, "/login", http.StatusSeeOther) http.Redirect(w, r, "/login", http.StatusSeeOther)
} }
}) }
} }

View File

@ -2,6 +2,7 @@ package pkg
import ( import (
"context" "context"
"github.com/julienschmidt/httprouter"
"html/template" "html/template"
"net/http" "net/http"
) )
@ -93,38 +94,46 @@ func (form *profileForm) Validate() bool {
return validator.AllOK() return validator.AllOK()
} }
func ProfileHandler() http.Handler { func GetProfileForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user := getUser(r)
user := getUser(r) conn := getConn(r)
conn := getConn(r) locale := getLocale(r)
locale := getLocale(r) form := newProfileForm(r.Context(), conn, locale)
form := newProfileForm(r.Context(), conn, locale) form.Name.Val = conn.MustGetText(r.Context(), "", "select name from user_profile")
if r.Method == "POST" { form.Email.Val = user.Email
if err := form.Parse(r); err != nil { form.Language.Selected = user.Language.String()
http.Error(w, err.Error(), http.StatusBadRequest)
return w.WriteHeader(http.StatusOK)
} mustRenderProfileForm(w, r, form)
if err := verifyCsrfTokenValid(r); err != nil { }
http.Error(w, err.Error(), http.StatusForbidden)
return func HandleProfileForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
} conn := getConn(r)
if ok := form.Validate(); ok { locale := getLocale(r)
//goland:noinspection SqlWithoutWhere form := newProfileForm(r.Context(), conn, locale)
cookie := conn.MustGetText(r.Context(), "", "update user_profile set name = $1, email = $2, lang_tag = $3 returning build_cookie()", form.Name, form.Email, form.Language) if err := form.Parse(r); err != nil {
setSessionCookie(w, cookie) http.Error(w, err.Error(), http.StatusBadRequest)
if form.Password.Val != "" { return
conn.MustExec(r.Context(), "select change_password($1)", form.Password) }
} if err := verifyCsrfTokenValid(r); err != nil {
company := getCompany(r) http.Error(w, err.Error(), http.StatusForbidden)
http.Redirect(w, r, "/company/"+company.Slug+"/profile", http.StatusSeeOther) return
return }
} if ok := form.Validate(); !ok {
w.WriteHeader(http.StatusUnprocessableEntity) w.WriteHeader(http.StatusUnprocessableEntity)
} else { mustRenderProfileForm(w, r, form)
form.Name.Val = conn.MustGetText(r.Context(), "", "select name from user_profile") return
form.Email.Val = user.Email }
form.Language.Selected = user.Language.String() //goland:noinspection SqlWithoutWhere
} cookie := conn.MustGetText(r.Context(), "", "update user_profile set name = $1, email = $2, lang_tag = $3 returning build_cookie()", form.Name, form.Email, form.Language)
mustRenderAppTemplate(w, r, "profile.gohtml", form) setSessionCookie(w, cookie)
}) if form.Password.Val != "" {
conn.MustExec(r.Context(), "select change_password($1)", form.Password)
}
company := getCompany(r)
http.Redirect(w, r, "/company/"+company.Slug+"/profile", http.StatusSeeOther)
}
func mustRenderProfileForm(w http.ResponseWriter, r *http.Request, form *profileForm) {
mustRenderAppTemplate(w, r, "profile.gohtml", form)
} }

View File

@ -2,39 +2,72 @@ package pkg
import ( import (
"net/http" "net/http"
"github.com/julienschmidt/httprouter"
) )
func NewRouter(db *Db) http.Handler { func NewRouter(db *Db) http.Handler {
companyRouter := http.NewServeMux() companyRouter := httprouter.New()
companyRouter.Handle("/tax-details", CompanyTaxDetailsHandler()) companyRouter.GET("/profile", GetProfileForm)
companyRouter.Handle("/tax/", CompanyTaxHandler()) companyRouter.POST("/profile", HandleProfileForm)
companyRouter.Handle("/tax", CompanyTaxHandler()) companyRouter.GET("/tax-details", GetCompanyTaxDetailsForm)
companyRouter.Handle("/profile", ProfileHandler()) companyRouter.POST("/tax-details", HandleCompanyTaxDetailsForm)
companyRouter.Handle("/contacts/new", NewContactHandler()) companyRouter.POST("/tax", HandleAddCompanyTax)
companyRouter.Handle("/contacts", ContactsHandler()) companyRouter.DELETE("/tax/:taxId", HandleDeleteCompanyTax)
companyRouter.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { companyRouter.GET("/contacts", IndexContacts)
companyRouter.POST("/contacts", HandleAddContact)
companyRouter.GET("/contacts/new", GetNewContactForm)
companyRouter.GET("/", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
mustRenderAppTemplate(w, r, "dashboard.gohtml", nil) mustRenderAppTemplate(w, r, "dashboard.gohtml", nil)
}) })
router := http.NewServeMux() router := httprouter.New()
router.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("web/static")))) router.ServeFiles("/static/*filepath", http.Dir("web/static"))
router.Handle("/login", LoginHandler()) router.GET("/login", GetLoginForm)
router.Handle("/logout", Authenticated(LogoutHandler())) router.POST("/login", HandleLoginForm)
router.Handle("/company/", Authenticated(http.StripPrefix("/company/", CompanyHandler(companyRouter)))) router.POST("/logout", Authenticated(HandleLogout))
router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
companyHandler := Authenticated(CompanyHandler(companyRouter))
router.GET("/company/:slug/*rest", companyHandler)
router.POST("/company/:slug/*rest", companyHandler)
router.PUT("/company/:slug/*rest", companyHandler)
router.DELETE("/company/:slug/*rest", companyHandler)
router.GET("/", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
user := getUser(r) user := getUser(r)
if user.LoggedIn { if user.LoggedIn {
conn := getConn(r) conn := getConn(r)
slug := conn.MustGetText(r.Context(), "", "select slug::text from company order by company_id limit 1") slug := conn.MustGetText(r.Context(), "", "select slug::text from company order by company_id limit 1")
http.Redirect(w, r, "/company/"+slug, http.StatusFound) http.Redirect(w, r, "/company/"+slug+"/", http.StatusFound)
} else { } else {
http.Redirect(w, r, "/login", http.StatusSeeOther) http.Redirect(w, r, "/login", http.StatusSeeOther)
} }
}) })
var handler http.Handler = router var handler http.Handler = router
handler = SetLocale(db, handler) handler = MethodOverrider(handler)
handler = CheckLogin(db, handler) handler = LocaleSetter(db, handler)
handler = LoginChecker(db, handler)
handler = Recoverer(handler) handler = Recoverer(handler)
handler = Logger(handler) handler = Logger(handler)
return handler return handler
} }
func MethodOverrider(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPost {
if err := r.ParseForm(); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
override := r.FormValue("_method")
if override == http.MethodDelete || override == http.MethodPut {
r2 := new(http.Request)
*r2 = *r
r2.Method = override
r = r2
}
}
next.ServeHTTP(w, r)
})
}