From 2799fdb8db416f67eeac63339d7478da3434ae2e Mon Sep 17 00:00:00 2001 From: jordi fita mas Date: Sat, 4 Feb 2023 10:43:42 +0100 Subject: [PATCH] Add companyURI for Go code too, not just templates --- pkg/company.go | 7 +++---- pkg/contacts.go | 9 ++++----- pkg/profile.go | 4 ++-- pkg/router.go | 6 ++++-- pkg/template.go | 12 ++++++++---- 5 files changed, 21 insertions(+), 17 deletions(-) diff --git a/pkg/company.go b/pkg/company.go index 6acdd78..80016b5 100644 --- a/pkg/company.go +++ b/pkg/company.go @@ -144,7 +144,7 @@ func HandleCompanyTaxDetailsForm(w http.ResponseWriter, r *http.Request, _ httpr } company := mustGetCompany(r) conn.MustExec(r.Context(), "update company set business_name = $1, vatin = ($11 || $2)::vatin, trade_name = $3, phone = parse_packed_phone_number($4, $11), email = $5, web = $6, address = $7, city = $8, province = $9, postal_code = $10, country_code = $11, currency_code = $12 where company_id = $13", form.BusinessName, form.VATIN, form.TradeName, form.Phone, form.Email, form.Web, form.Address, form.City, form.Province, form.PostalCode, form.Country, form.Currency, company.Id) - http.Redirect(w, r, "/company/"+company.Slug+"/tax-details", http.StatusSeeOther) + http.Redirect(w, r, companyURI(company, "/tax-details"), http.StatusSeeOther) return } @@ -260,9 +260,8 @@ func HandleDeleteCompanyTax(w http.ResponseWriter, r *http.Request, params httpr return } conn := getConn(r) - company := mustGetCompany(r) conn.MustExec(r.Context(), "delete from tax where tax_id = $1", taxId) - http.Redirect(w, r, "/company/"+company.Slug+"/tax-details", http.StatusSeeOther) + http.Redirect(w, r, companyURI(mustGetCompany(r), "/tax-details"), http.StatusSeeOther) } func HandleAddCompanyTax(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { @@ -284,5 +283,5 @@ func HandleAddCompanyTax(w http.ResponseWriter, r *http.Request, _ httprouter.Pa conn := getConn(r) company := mustGetCompany(r) conn.MustExec(r.Context(), "insert into tax (company_id, name, rate) values ($1, $2, $3 / 100::decimal)", company.Id, form.Name, form.Rate.Integer()) - http.Redirect(w, r, "/company/"+company.Slug+"/tax-details", http.StatusSeeOther) + http.Redirect(w, r, companyURI(company, "/tax-details"), http.StatusSeeOther) } diff --git a/pkg/contacts.go b/pkg/contacts.go index 2fa1595..5e30bd1 100644 --- a/pkg/contacts.go +++ b/pkg/contacts.go @@ -21,7 +21,7 @@ type ContactsIndexPage struct { func IndexContacts(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { conn := getConn(r) - company := getCompany(r) + company := mustGetCompany(r) page := &ContactsIndexPage{ Contacts: mustGetContactEntries(r.Context(), conn, company), } @@ -75,9 +75,9 @@ func HandleAddContact(w http.ResponseWriter, r *http.Request, _ httprouter.Param mustRenderNewContactForm(w, r, form) return } - company := getCompany(r) + company := mustGetCompany(r) conn.MustExec(r.Context(), "insert into contact (company_id, business_name, vatin, trade_name, phone, email, web, address, city, province, postal_code, country_code) values ($1, $2, ($12 || $3)::vatin, $4, parse_packed_phone_number($5, $12), $6, $7, $8, $9, $10, $11, $12)", company.Id, form.BusinessName, form.VATIN, form.TradeName, form.Phone, form.Email, form.Web, form.Address, form.City, form.Province, form.PostalCode, form.Country) - http.Redirect(w, r, "/company/"+company.Slug+"/contacts", http.StatusSeeOther) + http.Redirect(w, r, companyURI(company, "/contacts"), http.StatusSeeOther) } func HandleUpdateContact(w http.ResponseWriter, r *http.Request, params httprouter.Params) { @@ -100,8 +100,7 @@ func HandleUpdateContact(w http.ResponseWriter, r *http.Request, params httprout if slug == "" { http.NotFound(w, r) } - company := getCompany(r) - http.Redirect(w, r, "/company/"+company.Slug+"/contacts/"+slug, http.StatusSeeOther) + http.Redirect(w, r, companyURI(mustGetCompany(r), "/contacts/"+slug), http.StatusSeeOther) } func mustGetContactEntries(ctx context.Context, conn *Conn, company *Company) []*ContactEntry { diff --git a/pkg/profile.go b/pkg/profile.go index 1c1b314..4cb8ed1 100644 --- a/pkg/profile.go +++ b/pkg/profile.go @@ -130,8 +130,8 @@ func HandleProfileForm(w http.ResponseWriter, r *http.Request, _ httprouter.Para if form.Password.Val != "" { conn.MustExec(r.Context(), "select change_password($1)", form.Password) } - company := getCompany(r) - http.Redirect(w, r, "/company/"+company.Slug+"/profile", http.StatusSeeOther) + company := mustGetCompany(r) + http.Redirect(w, r, companyURI(company, "/profile"), http.StatusSeeOther) } func mustRenderProfileForm(w http.ResponseWriter, r *http.Request, form *profileForm) { diff --git a/pkg/router.go b/pkg/router.go index ab7807c..0300cac 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -38,8 +38,10 @@ func NewRouter(db *Db) http.Handler { user := getUser(r) if user.LoggedIn { conn := getConn(r) - slug := conn.MustGetText(r.Context(), "", "select slug::text from company order by company_id limit 1") - http.Redirect(w, r, "/company/"+slug+"/", http.StatusFound) + company := &Company{ + Slug: conn.MustGetText(r.Context(), "", "select slug::text from company order by company_id limit 1"), + } + http.Redirect(w, r, companyURI(company, "/"), http.StatusFound) } else { http.Redirect(w, r, "/login", http.StatusSeeOther) } diff --git a/pkg/template.go b/pkg/template.go index ae28dcd..f127f3c 100644 --- a/pkg/template.go +++ b/pkg/template.go @@ -25,10 +25,7 @@ func mustRenderTemplate(wr io.Writer, r *http.Request, layout string, filename s return locale.Language.String() }, "companyURI": func(uri string) string { - if company == nil { - return uri - } - return "/company/" + company.Slug + uri + return companyURI(company, uri) }, "csrfToken": func() template.HTML { return template.HTML(fmt.Sprintf(``, csrfTokenField, user.CsrfToken)) @@ -56,6 +53,13 @@ func mustRenderTemplate(wr io.Writer, r *http.Request, layout string, filename s } } +func companyURI(company *Company, uri string) string { + if company == nil { + return uri + } + return "/company/" + company.Slug + uri +} + func overrideMethodField(method string) template.HTML { return template.HTML(fmt.Sprintf(``, overrideMethodName, method)) }