From 1ab48cfcbcf9338baa6b1efa9c61d2b62776fb36 Mon Sep 17 00:00:00 2001 From: jordi fita mas Date: Fri, 3 Feb 2023 12:30:56 +0100 Subject: [PATCH] Replace default router with github.com/julienschmidt/httprouter I would fuck up handling URL parameters and this router has per-method handlers, that are easier to work with, in some cases. --- debian/control | 1 + go.mod | 1 + go.sum | 2 + pkg/company.go | 223 +++++++++++++++++++++++++----------------------- pkg/contacts.go | 77 +++++++++-------- pkg/locale.go | 2 +- pkg/login.go | 97 +++++++++++---------- pkg/profile.go | 77 +++++++++-------- pkg/router.go | 67 +++++++++++---- 9 files changed, 307 insertions(+), 240 deletions(-) diff --git a/debian/control b/debian/control index f2be840..0c8afac 100644 --- a/debian/control +++ b/debian/control @@ -8,6 +8,7 @@ Build-Depends: gettext, golang-any, golang-github-jackc-pgx-v4-dev, + golang-github-julienschmidt-httprouter-dev, golang-github-leonelquinteros-gotext-dev, golang-golang-x-text-dev, postgresql-all (>= 217~), diff --git a/go.mod b/go.mod index 755b0c4..bb8435a 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.18 require ( github.com/jackc/pgx/v4 v4.17.2 + github.com/julienschmidt/httprouter v1.3.0 github.com/leonelquinteros/gotext v1.5.1 golang.org/x/text v0.3.8 ) diff --git a/go.sum b/go.sum index eaef7b4..4a7c508 100644 --- a/go.sum +++ b/go.sum @@ -62,6 +62,8 @@ github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0f github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.3.0 h1:eHK/5clGOatcjX3oWGBO/MpxpbHzSwud5EWTSCI+MX0= github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= +github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= diff --git a/pkg/company.go b/pkg/company.go index 52af1d6..6acdd78 100644 --- a/pkg/company.go +++ b/pkg/company.go @@ -3,11 +3,11 @@ package pkg import ( "context" "errors" + "github.com/julienschmidt/httprouter" "html/template" "net/http" "net/url" "strconv" - "strings" ) const ( @@ -19,44 +19,26 @@ type Company struct { Slug string } -func CompanyHandler(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - slug := r.URL.Path - if idx := strings.IndexByte(slug, '/'); idx >= 0 { - slug = slug[:idx] - } - - conn := getConn(r) +func CompanyHandler(next http.Handler) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, params httprouter.Params) { company := &Company{ - Slug: slug, + Slug: params[0].Value, } - err := conn.QueryRow(r.Context(), "select company_id from company where slug = $1", slug).Scan(&company.Id) + conn := getConn(r) + err := conn.QueryRow(r.Context(), "select company_id from company where slug = $1", company.Slug).Scan(&company.Id) if err != nil { http.NotFound(w, r) return } ctx := context.WithValue(r.Context(), ContextCompanyKey, company) r = r.WithContext(ctx) - - // Same as StripPrefix - p := strings.TrimPrefix(r.URL.Path, slug) - rp := strings.TrimPrefix(r.URL.RawPath, slug) - if len(p) < len(r.URL.Path) && (r.URL.RawPath == "" || len(rp) < len(r.URL.RawPath)) { - r2 := new(http.Request) - *r2 = *r - r2.URL = new(url.URL) - *r2.URL = *r.URL - if p == "" { - r2.URL.Path = "/" - } else { - r2.URL.Path = p - } - r2.URL.RawPath = rp - next.ServeHTTP(w, r2) - } else { - http.NotFound(w, r) - } - }) + r2 := new(http.Request) + *r2 = *r + r2.URL = new(url.URL) + *r2.URL = *r.URL + r2.URL.Path = params[1].Value + next.ServeHTTP(w, r2) + } } func getCompany(r *http.Request) *Company { @@ -124,41 +106,70 @@ func (form *taxDetailsForm) mustFillFromDatabase(ctx context.Context, conn *Conn type TaxDetailsPage struct { DetailsForm *taxDetailsForm - NewTaxForm *newTaxForm + NewTaxForm *taxForm Taxes []*Tax } -func CompanyTaxDetailsHandler() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - locale := getLocale(r) - conn := getConn(r) - page := &TaxDetailsPage{ - DetailsForm: newTaxDetailsForm(r.Context(), conn, locale), - NewTaxForm: newNewTaxForm(locale), - } - company := mustGetCompany(r) - if r.Method == "POST" { - if err := page.DetailsForm.Parse(r); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - if err := verifyCsrfTokenValid(r); err != nil { - http.Error(w, err.Error(), http.StatusForbidden) - return - } - if ok := page.DetailsForm.Validate(r.Context(), conn); ok { - form := page.DetailsForm - conn.MustExec(r.Context(), "update company set business_name = $1, vatin = ($11 || $2)::vatin, trade_name = $3, phone = parse_packed_phone_number($4, $11), email = $5, web = $6, address = $7, city = $8, province = $9, postal_code = $10, country_code = $11, currency_code = $12 where company_id = $13", form.BusinessName, form.VATIN, form.TradeName, form.Phone, form.Email, form.Web, form.Address, form.City, form.Province, form.PostalCode, form.Country, form.Currency, company.Id) - http.Redirect(w, r, "/company/"+company.Slug+"/tax-details", http.StatusSeeOther) - return - } - w.WriteHeader(http.StatusUnprocessableEntity) - } else { - page.DetailsForm.mustFillFromDatabase(r.Context(), conn, company) - } - page.Taxes = mustGetTaxes(r.Context(), conn, company) - mustRenderAppTemplate(w, r, "tax-details.gohtml", page) - }) +func GetCompanyTaxDetailsForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + mustRenderTaxDetailsForm(w, r, newTaxDetailsFormFromDatabase(r)) +} + +func newTaxDetailsFormFromDatabase(r *http.Request) *taxDetailsForm { + locale := getLocale(r) + conn := getConn(r) + form := newTaxDetailsForm(r.Context(), conn, locale) + + company := mustGetCompany(r) + form.mustFillFromDatabase(r.Context(), conn, company) + + return form +} + +func HandleCompanyTaxDetailsForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + locale := getLocale(r) + conn := getConn(r) + form := newTaxDetailsForm(r.Context(), conn, locale) + if err := form.Parse(r); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := verifyCsrfTokenValid(r); err != nil { + http.Error(w, err.Error(), http.StatusForbidden) + return + } + if ok := form.Validate(r.Context(), conn); !ok { + w.WriteHeader(http.StatusUnprocessableEntity) + mustRenderTaxDetailsForm(w, r, form) + return + } + company := mustGetCompany(r) + conn.MustExec(r.Context(), "update company set business_name = $1, vatin = ($11 || $2)::vatin, trade_name = $3, phone = parse_packed_phone_number($4, $11), email = $5, web = $6, address = $7, city = $8, province = $9, postal_code = $10, country_code = $11, currency_code = $12 where company_id = $13", form.BusinessName, form.VATIN, form.TradeName, form.Phone, form.Email, form.Web, form.Address, form.City, form.Province, form.PostalCode, form.Country, form.Currency, company.Id) + http.Redirect(w, r, "/company/"+company.Slug+"/tax-details", http.StatusSeeOther) + return +} + +func mustRenderTaxDetailsForm(w http.ResponseWriter, r *http.Request, form *taxDetailsForm) { + locale := getLocale(r) + page := &TaxDetailsPage{ + DetailsForm: form, + NewTaxForm: newTaxForm(locale), + } + mustRenderTexDetailsPage(w, r, page) +} + +func mustRenderTaxForm(w http.ResponseWriter, r *http.Request, form *taxForm) { + page := &TaxDetailsPage{ + DetailsForm: newTaxDetailsFormFromDatabase(r), + NewTaxForm: form, + } + mustRenderTexDetailsPage(w, r, page) +} + +func mustRenderTexDetailsPage(w http.ResponseWriter, r *http.Request, page *TaxDetailsPage) { + conn := getConn(r) + company := mustGetCompany(r) + page.Taxes = mustGetTaxes(r.Context(), conn, company) + mustRenderAppTemplate(w, r, "tax-details.gohtml", page) } func mustGetCompany(r *http.Request) *Company { @@ -192,14 +203,14 @@ func mustGetTaxes(ctx context.Context, conn *Conn, company *Company) []*Tax { return taxes } -type newTaxForm struct { +type taxForm struct { locale *Locale Name *InputField Rate *InputField } -func newNewTaxForm(locale *Locale) *newTaxForm { - return &newTaxForm{ +func newTaxForm(locale *Locale) *taxForm { + return &taxForm{ locale: locale, Name: &InputField{ Name: "tax_name", @@ -220,7 +231,7 @@ func newNewTaxForm(locale *Locale) *newTaxForm { } } -func (form *newTaxForm) Parse(r *http.Request) error { +func (form *taxForm) Parse(r *http.Request) error { if err := r.ParseForm(); err != nil { return err } @@ -229,7 +240,7 @@ func (form *newTaxForm) Parse(r *http.Request) error { return nil } -func (form *newTaxForm) Validate() bool { +func (form *taxForm) Validate() bool { validator := newFormValidator() validator.CheckRequiredInput(form.Name, gettext("Tax name can not be empty.", form.locale)) if validator.CheckRequiredInput(form.Rate, gettext("Tax rate can not be empty.", form.locale)) { @@ -238,44 +249,40 @@ func (form *newTaxForm) Validate() bool { return validator.AllOK() } -func CompanyTaxHandler() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - param := r.URL.Path - if idx := strings.LastIndexByte(param, '/'); idx >= 0 { - param = param[idx+1:] - } - conn := getConn(r) - company := mustGetCompany(r) - if taxId, err := strconv.Atoi(param); err == nil { - if err := verifyCsrfTokenValid(r); err != nil { - http.Error(w, err.Error(), http.StatusForbidden) - return - } - conn.MustExec(r.Context(), "delete from tax where tax_id = $1", taxId) - http.Redirect(w, r, "/company/"+company.Slug+"/tax-details", http.StatusSeeOther) - } else { - locale := getLocale(r) - form := newNewTaxForm(locale) - if err := form.Parse(r); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - if err := verifyCsrfTokenValid(r); err != nil { - http.Error(w, err.Error(), http.StatusForbidden) - return - } - if form.Validate() { - conn.MustExec(r.Context(), "insert into tax (company_id, name, rate) values ($1, $2, $3 / 100::decimal)", company.Id, form.Name, form.Rate.Integer()) - http.Redirect(w, r, "/company/"+company.Slug+"/tax-details", http.StatusSeeOther) - } else { - w.WriteHeader(http.StatusUnprocessableEntity) - } - page := &TaxDetailsPage{ - DetailsForm: newTaxDetailsForm(r.Context(), conn, locale).mustFillFromDatabase(r.Context(), conn, company), - NewTaxForm: form, - Taxes: mustGetTaxes(r.Context(), conn, company), - } - mustRenderAppTemplate(w, r, "tax-details.gohtml", page) - } - }) +func HandleDeleteCompanyTax(w http.ResponseWriter, r *http.Request, params httprouter.Params) { + taxId, err := strconv.Atoi(params[0].Value) + if err != nil { + http.NotFound(w, r) + return + } + if err := verifyCsrfTokenValid(r); err != nil { + http.Error(w, err.Error(), http.StatusForbidden) + return + } + conn := getConn(r) + company := mustGetCompany(r) + conn.MustExec(r.Context(), "delete from tax where tax_id = $1", taxId) + http.Redirect(w, r, "/company/"+company.Slug+"/tax-details", http.StatusSeeOther) +} + +func HandleAddCompanyTax(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + locale := getLocale(r) + form := newTaxForm(locale) + if err := form.Parse(r); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := verifyCsrfTokenValid(r); err != nil { + http.Error(w, err.Error(), http.StatusForbidden) + return + } + if !form.Validate() { + w.WriteHeader(http.StatusUnprocessableEntity) + mustRenderTaxForm(w, r, form) + return + } + conn := getConn(r) + company := mustGetCompany(r) + conn.MustExec(r.Context(), "insert into tax (company_id, name, rate) values ($1, $2, $3 / 100::decimal)", company.Id, form.Name, form.Rate.Integer()) + http.Redirect(w, r, "/company/"+company.Slug+"/tax-details", http.StatusSeeOther) } diff --git a/pkg/contacts.go b/pkg/contacts.go index c38d75a..dc9a3fe 100644 --- a/pkg/contacts.go +++ b/pkg/contacts.go @@ -2,6 +2,7 @@ package pkg import ( "context" + "github.com/julienschmidt/httprouter" "html/template" "net/http" ) @@ -16,34 +17,45 @@ type ContactsIndexPage struct { Contacts []*ContactEntry } -func ContactsHandler() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn := getConn(r) - company := getCompany(r) - if r.Method == "POST" { - locale := getLocale(r) - form := newContactForm(r.Context(), conn, locale) - if err := form.Parse(r); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - if err := verifyCsrfTokenValid(r); err != nil { - http.Error(w, err.Error(), http.StatusForbidden) - return - } - if form.Validate(r.Context(), conn) { - conn.MustExec(r.Context(), "insert into contact (company_id, business_name, vatin, trade_name, phone, email, web, address, province, city, postal_code, country_code) values ($1, $2, ($12 || $3)::vatin, $4, parse_packed_phone_number($5, $12), $6, $7, $8, $9, $10, $11, $12)", company.Id, form.BusinessName, form.VATIN, form.TradeName, form.Phone, form.Email, form.Web, form.Address, form.City, form.Province, form.PostalCode, form.Country) - http.Redirect(w, r, "/company/"+company.Slug+"/contacts", http.StatusSeeOther) - } else { - mustRenderAppTemplate(w, r, "contacts-new.gohtml", form) - } - } else { - page := &ContactsIndexPage{ - Contacts: mustGetContactEntries(r.Context(), conn, company), - } - mustRenderAppTemplate(w, r, "contacts-index.gohtml", page) - } - }) +func IndexContacts(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + conn := getConn(r) + company := getCompany(r) + page := &ContactsIndexPage{ + Contacts: mustGetContactEntries(r.Context(), conn, company), + } + mustRenderAppTemplate(w, r, "contacts-index.gohtml", page) +} + +func GetNewContactForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + locale := getLocale(r) + conn := getConn(r) + form := newContactForm(r.Context(), conn, locale) + mustRenderContactForm(w, r, form) +} + +func mustRenderContactForm(w http.ResponseWriter, r *http.Request, form *contactForm) { + mustRenderAppTemplate(w, r, "contacts-new.gohtml", form) +} + +func HandleAddContact(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + conn := getConn(r) + locale := getLocale(r) + form := newContactForm(r.Context(), conn, locale) + if err := form.Parse(r); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := verifyCsrfTokenValid(r); err != nil { + http.Error(w, err.Error(), http.StatusForbidden) + return + } + if !form.Validate(r.Context(), conn) { + mustRenderContactForm(w, r, form) + return + } + company := getCompany(r) + conn.MustExec(r.Context(), "insert into contact (company_id, business_name, vatin, trade_name, phone, email, web, address, province, city, postal_code, country_code) values ($1, $2, ($12 || $3)::vatin, $4, parse_packed_phone_number($5, $12), $6, $7, $8, $9, $10, $11, $12)", company.Id, form.BusinessName, form.VATIN, form.TradeName, form.Phone, form.Email, form.Web, form.Address, form.City, form.Province, form.PostalCode, form.Country) + http.Redirect(w, r, "/company/"+company.Slug+"/contacts", http.StatusSeeOther) } func mustGetContactEntries(ctx context.Context, conn *Conn, company *Company) []*ContactEntry { @@ -217,12 +229,3 @@ func (form *contactForm) Validate(ctx context.Context, conn *Conn) bool { validator.CheckValidSelectOption(form.Country, gettext("Selected country is not valid.", form.locale)) return validator.AllOK() } - -func NewContactHandler() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - locale := getLocale(r) - conn := getConn(r) - form := newContactForm(r.Context(), conn, locale) - mustRenderAppTemplate(w, r, "contacts-new.gohtml", form) - }) -} diff --git a/pkg/locale.go b/pkg/locale.go index eea6d09..7363147 100644 --- a/pkg/locale.go +++ b/pkg/locale.go @@ -22,7 +22,7 @@ func NewLocale(lang language.Tag) *Locale { } } -func SetLocale(db *Db, next http.Handler) http.Handler { +func LocaleSetter(db *Db, next http.Handler) http.Handler { availableLanguages := mustGetAvailableLanguages(db) var matcher = language.NewMatcher(availableLanguages) diff --git a/pkg/login.go b/pkg/login.go index bd54b21..66ad121 100644 --- a/pkg/login.go +++ b/pkg/login.go @@ -3,6 +3,7 @@ package pkg import ( "context" "errors" + "github.com/julienschmidt/httprouter" "html/template" "net" "net/http" @@ -72,49 +73,59 @@ func (form *loginForm) Validate() bool { return validator.AllOK() } -func LoginHandler() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - user := getUser(r) - if user.LoggedIn { +func GetLoginForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + user := getUser(r) + if user.LoggedIn { + http.Redirect(w, r, "/", http.StatusSeeOther) + return + } + locale := getLocale(r) + form := newLoginForm(locale) + w.WriteHeader(http.StatusOK) + mustRenderLoginForm(w, r, form) +} + +func HandleLoginForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + user := getUser(r) + if user.LoggedIn { + http.Redirect(w, r, "/", http.StatusSeeOther) + return + } + locale := getLocale(r) + form := newLoginForm(locale) + if err := form.Parse(r); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if form.Validate() { + conn := getConn(r) + cookie := conn.MustGetText(r.Context(), "", "select login($1, $2, $3)", form.Email, form.Password, remoteAddr(r)) + if cookie != "" { + setSessionCookie(w, cookie) http.Redirect(w, r, "/", http.StatusSeeOther) return } - locale := getLocale(r) - form := newLoginForm(locale) - if r.Method == "POST" { - if err := form.Parse(r); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - if form.Validate() { - conn := getConn(r) - cookie := conn.MustGetText(r.Context(), "", "select login($1, $2, $3)", form.Email, form.Password, remoteAddr(r)) - if cookie != "" { - setSessionCookie(w, cookie) - http.Redirect(w, r, "/", http.StatusSeeOther) - return - } - form.Errors = append(form.Errors, errors.New(gettext("Invalid user or password.", locale))) - w.WriteHeader(http.StatusUnauthorized) - } else { - w.WriteHeader(http.StatusUnprocessableEntity) - } - } - mustRenderWebTemplate(w, r, "login.gohtml", form) - }) + form.Errors = append(form.Errors, errors.New(gettext("Invalid user or password.", locale))) + w.WriteHeader(http.StatusUnauthorized) + } else { + w.WriteHeader(http.StatusUnprocessableEntity) + } + mustRenderLoginForm(w, r, form) } -func LogoutHandler() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if err := verifyCsrfTokenValid(r); err != nil { - http.Error(w, err.Error(), http.StatusForbidden) - return - } - conn := getConn(r) - conn.MustExec(r.Context(), "select logout()") - http.SetCookie(w, createSessionCookie("", -24*time.Hour)) - http.Redirect(w, r, "/login", http.StatusSeeOther) - }) +func mustRenderLoginForm(w http.ResponseWriter, r *http.Request, form *loginForm) { + mustRenderWebTemplate(w, r, "login.gohtml", form) +} + +func HandleLogout(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + if err := verifyCsrfTokenValid(r); err != nil { + http.Error(w, err.Error(), http.StatusForbidden) + return + } + conn := getConn(r) + conn.MustExec(r.Context(), "select logout()") + http.SetCookie(w, createSessionCookie("", -24*time.Hour)) + http.Redirect(w, r, "/login", http.StatusSeeOther) } func remoteAddr(r *http.Request) string { @@ -148,7 +159,7 @@ type AppUser struct { CsrfToken string } -func CheckLogin(db *Db, next http.Handler) http.Handler { +func LoginChecker(db *Db, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var ctx = r.Context() if cookie, err := r.Cookie(sessionCookie); err == nil { @@ -195,13 +206,13 @@ func getConn(r *http.Request) *Conn { return r.Context().Value(ContextConnKey).(*Conn) } -func Authenticated(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +func Authenticated(next httprouter.Handle) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, params httprouter.Params) { user := getUser(r) if user.LoggedIn { - next.ServeHTTP(w, r) + next(w, r, params) } else { http.Redirect(w, r, "/login", http.StatusSeeOther) } - }) + } } diff --git a/pkg/profile.go b/pkg/profile.go index 0bd472b..1c1b314 100644 --- a/pkg/profile.go +++ b/pkg/profile.go @@ -2,6 +2,7 @@ package pkg import ( "context" + "github.com/julienschmidt/httprouter" "html/template" "net/http" ) @@ -93,38 +94,46 @@ func (form *profileForm) Validate() bool { return validator.AllOK() } -func ProfileHandler() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - user := getUser(r) - conn := getConn(r) - locale := getLocale(r) - form := newProfileForm(r.Context(), conn, locale) - if r.Method == "POST" { - if err := form.Parse(r); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - if err := verifyCsrfTokenValid(r); err != nil { - http.Error(w, err.Error(), http.StatusForbidden) - return - } - if ok := form.Validate(); ok { - //goland:noinspection SqlWithoutWhere - cookie := conn.MustGetText(r.Context(), "", "update user_profile set name = $1, email = $2, lang_tag = $3 returning build_cookie()", form.Name, form.Email, form.Language) - setSessionCookie(w, cookie) - if form.Password.Val != "" { - conn.MustExec(r.Context(), "select change_password($1)", form.Password) - } - company := getCompany(r) - http.Redirect(w, r, "/company/"+company.Slug+"/profile", http.StatusSeeOther) - return - } - w.WriteHeader(http.StatusUnprocessableEntity) - } else { - form.Name.Val = conn.MustGetText(r.Context(), "", "select name from user_profile") - form.Email.Val = user.Email - form.Language.Selected = user.Language.String() - } - mustRenderAppTemplate(w, r, "profile.gohtml", form) - }) +func GetProfileForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + user := getUser(r) + conn := getConn(r) + locale := getLocale(r) + form := newProfileForm(r.Context(), conn, locale) + form.Name.Val = conn.MustGetText(r.Context(), "", "select name from user_profile") + form.Email.Val = user.Email + form.Language.Selected = user.Language.String() + + w.WriteHeader(http.StatusOK) + mustRenderProfileForm(w, r, form) +} + +func HandleProfileForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + conn := getConn(r) + locale := getLocale(r) + form := newProfileForm(r.Context(), conn, locale) + if err := form.Parse(r); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := verifyCsrfTokenValid(r); err != nil { + http.Error(w, err.Error(), http.StatusForbidden) + return + } + if ok := form.Validate(); !ok { + w.WriteHeader(http.StatusUnprocessableEntity) + mustRenderProfileForm(w, r, form) + return + } + //goland:noinspection SqlWithoutWhere + cookie := conn.MustGetText(r.Context(), "", "update user_profile set name = $1, email = $2, lang_tag = $3 returning build_cookie()", form.Name, form.Email, form.Language) + setSessionCookie(w, cookie) + if form.Password.Val != "" { + conn.MustExec(r.Context(), "select change_password($1)", form.Password) + } + company := getCompany(r) + http.Redirect(w, r, "/company/"+company.Slug+"/profile", http.StatusSeeOther) +} + +func mustRenderProfileForm(w http.ResponseWriter, r *http.Request, form *profileForm) { + mustRenderAppTemplate(w, r, "profile.gohtml", form) } diff --git a/pkg/router.go b/pkg/router.go index ac0c2dc..468b3a6 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -2,39 +2,72 @@ package pkg import ( "net/http" + + "github.com/julienschmidt/httprouter" ) func NewRouter(db *Db) http.Handler { - companyRouter := http.NewServeMux() - companyRouter.Handle("/tax-details", CompanyTaxDetailsHandler()) - companyRouter.Handle("/tax/", CompanyTaxHandler()) - companyRouter.Handle("/tax", CompanyTaxHandler()) - companyRouter.Handle("/profile", ProfileHandler()) - companyRouter.Handle("/contacts/new", NewContactHandler()) - companyRouter.Handle("/contacts", ContactsHandler()) - companyRouter.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + companyRouter := httprouter.New() + companyRouter.GET("/profile", GetProfileForm) + companyRouter.POST("/profile", HandleProfileForm) + companyRouter.GET("/tax-details", GetCompanyTaxDetailsForm) + companyRouter.POST("/tax-details", HandleCompanyTaxDetailsForm) + companyRouter.POST("/tax", HandleAddCompanyTax) + companyRouter.DELETE("/tax/:taxId", HandleDeleteCompanyTax) + companyRouter.GET("/contacts", IndexContacts) + companyRouter.POST("/contacts", HandleAddContact) + companyRouter.GET("/contacts/new", GetNewContactForm) + companyRouter.GET("/", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { mustRenderAppTemplate(w, r, "dashboard.gohtml", nil) }) - router := http.NewServeMux() - router.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("web/static")))) - router.Handle("/login", LoginHandler()) - router.Handle("/logout", Authenticated(LogoutHandler())) - router.Handle("/company/", Authenticated(http.StripPrefix("/company/", CompanyHandler(companyRouter)))) - router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + router := httprouter.New() + router.ServeFiles("/static/*filepath", http.Dir("web/static")) + router.GET("/login", GetLoginForm) + router.POST("/login", HandleLoginForm) + router.POST("/logout", Authenticated(HandleLogout)) + + companyHandler := Authenticated(CompanyHandler(companyRouter)) + router.GET("/company/:slug/*rest", companyHandler) + router.POST("/company/:slug/*rest", companyHandler) + router.PUT("/company/:slug/*rest", companyHandler) + router.DELETE("/company/:slug/*rest", companyHandler) + + router.GET("/", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { user := getUser(r) if user.LoggedIn { conn := getConn(r) slug := conn.MustGetText(r.Context(), "", "select slug::text from company order by company_id limit 1") - http.Redirect(w, r, "/company/"+slug, http.StatusFound) + http.Redirect(w, r, "/company/"+slug+"/", http.StatusFound) } else { http.Redirect(w, r, "/login", http.StatusSeeOther) } }) + var handler http.Handler = router - handler = SetLocale(db, handler) - handler = CheckLogin(db, handler) + handler = MethodOverrider(handler) + handler = LocaleSetter(db, handler) + handler = LoginChecker(db, handler) handler = Recoverer(handler) handler = Logger(handler) return handler } + +func MethodOverrider(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + override := r.FormValue("_method") + if override == http.MethodDelete || override == http.MethodPut { + r2 := new(http.Request) + *r2 = *r + r2.Method = override + r = r2 + } + } + next.ServeHTTP(w, r) + }) +}