From f65110824e9891cfc33c8843f8fca3efe4a4c3eb Mon Sep 17 00:00:00 2001 From: jordi fita mas Date: Mon, 31 Jul 2023 18:51:50 +0200 Subject: [PATCH] =?UTF-8?q?Add=20the=20company=E2=80=99s=20slug=20in=20the?= =?UTF-8?q?=20URL=20before=20company-dependent=20handlers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I really doubt that they are going to use more than a single company, but the application is based on Numerus, that **does** have multiple company, and followed the same architecture and philosophy: use the URL to choose the company to manage, even if the user has a single company. The reason i use the slug instead of the ID is because i do not want to make the ID public in case the application is really used by employees of many unrelated companies: they need not need to guess how many companies there are based on the ID. I validate this slug to be a valid UUID instead of relaying on the query’s empty result because casting a string with a malformed value to UUID results in an error other than data not found. Not with that select, but it would fail with a function parameter, and i want to add that UUID check to all functions that do use slugs. I based uuid.Valid function on Parse() from Google’s uuid package[0] instead of using regular expression, as it was my first idea, because that function is an order of magnitude faster in benchmarks: goos: linux goarch: amd64 pkg: dev.tandem.ws/tandem/numerus/pkg cpu: Intel(R) Core(TM) i5-6200U CPU @ 2.30GHz BenchmarkValidUuid-4 36946050 29.37 ns/op BenchmarkValidUuid_Re-4 3633169 306.70 ns/op The regular expression used for the benchmark was: var re = regexp.MustCompile("^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-4[a-fA-F0-9]{3}-[8|9|aA|bB][a-fA-F0-9]{3}-[a-fA-F0-9]{12}$") And the input parameter for both functions was the following valid UUID, because most of the time the passed UUID will be valid: "f47ac10b-58cc-0372-8567-0e02b2c3d479" I did not use the uuid package as is, even though it is in Debian’s repository, because i only need to check whether the value is valid, not convert it to a byte array. As far as i know, that package can not do that. Adding the Company struct into auth was not my intention, as it makes little sense name-wise, but i need to have the Company when rendering templates and the company package has templates to render, thus using the company package for the Company struct would create a dependency loop between template and company. I’ve chosen the auth package only because User is also there; User and Company are very much related in this application, but not enough to include the company inside the user, or vice versa, as the User comes from the cookie while the company from the URL. Finally, had to move methodNotAllowed to the http package, as an exported function, because it is used now from other packages, namely campsite. [0]: https://github.com/google/uuid --- pkg/app/app.go | 31 ++++++-------- pkg/app/login.go | 2 +- pkg/app/user.go | 8 ++-- pkg/auth/company.go | 55 ++++++++++++++++++++++++ pkg/campsite/{handler.go => http.go} | 4 +- pkg/campsite/type.go | 13 ++++-- pkg/company/http.go | 64 ++++++++++++++++++++++++++++ pkg/database/db.go | 17 +++++++- pkg/http/request.go | 5 +++ pkg/template/render.go | 11 ++--- pkg/uuid/uuid.go | 51 ++++++++++++++++++++++ pkg/uuid/uuid_test.go | 38 +++++++++++++++++ web/templates/layout.gohtml | 9 ++++ 13 files changed, 273 insertions(+), 35 deletions(-) create mode 100644 pkg/auth/company.go rename pkg/campsite/{handler.go => http.go} (79%) create mode 100644 pkg/company/http.go create mode 100644 pkg/uuid/uuid.go create mode 100644 pkg/uuid/uuid_test.go diff --git a/pkg/app/app.go b/pkg/app/app.go index a22ddb7..514fcc4 100644 --- a/pkg/app/app.go +++ b/pkg/app/app.go @@ -7,28 +7,21 @@ package app import ( "net/http" - "strings" "golang.org/x/text/language" "dev.tandem.ws/tandem/camper/pkg/auth" - "dev.tandem.ws/tandem/camper/pkg/campsite" + "dev.tandem.ws/tandem/camper/pkg/company" "dev.tandem.ws/tandem/camper/pkg/database" httplib "dev.tandem.ws/tandem/camper/pkg/http" "dev.tandem.ws/tandem/camper/pkg/locale" - "dev.tandem.ws/tandem/camper/pkg/template" ) -func methodNotAllowed(w http.ResponseWriter, _ *http.Request, allowed ...string) { - w.Header().Set("Allow", strings.Join(allowed, ", ")) - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) -} - type App struct { db *database.DB fileHandler http.Handler profile *profileHandler - campsite *campsite.Handler + company *company.Handler locales locale.Locales defaultLocale *locale.Locale languageMatcher language.Matcher @@ -45,7 +38,7 @@ func New(db *database.DB, avatarsDir string) (http.Handler, error) { db: db, fileHandler: static, profile: profile, - campsite: campsite.NewHandler(), + company: company.NewHandler(), locales: locales, defaultLocale: locales[language.Catalan], languageMatcher: language.NewMatcher(locales.Tags()), @@ -86,7 +79,7 @@ func (h *App) ServeHTTP(w http.ResponseWriter, r *http.Request) { case http.MethodPost: handleLogin(w, r, user, conn) default: - methodNotAllowed(w, r, http.MethodPost, http.MethodGet) + httplib.MethodNotAllowed(w, r, http.MethodPost, http.MethodGet) } } else { if !user.LoggedIn { @@ -98,14 +91,14 @@ func (h *App) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch head { case "me": h.profile.Handler(user, conn).ServeHTTP(w, r) - case "campsites": - h.campsite.Handler(user, conn).ServeHTTP(w, r) + case "company": + h.company.Handler(user, conn).ServeHTTP(w, r) case "": switch r.Method { case http.MethodGet: - h.serveDashboard(w, r, user) + redirectToMainCompany(w, r, conn) default: - methodNotAllowed(w, r, http.MethodGet) + httplib.MethodNotAllowed(w, r, http.MethodGet) } default: http.NotFound(w, r) @@ -114,6 +107,10 @@ func (h *App) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (h *App) serveDashboard(w http.ResponseWriter, r *http.Request, user *auth.User) { - template.MustRender(w, r, user, "dashboard.gohtml", nil) +func redirectToMainCompany(w http.ResponseWriter, r *http.Request, conn *database.Conn) { + co, err := auth.QueryMainCompany(r.Context(), conn) + if err != nil { + panic(err) + } + httplib.Relocate(w, r, co.URL(), http.StatusFound) } diff --git a/pkg/app/login.go b/pkg/app/login.go index a7bd26a..6cd6834 100644 --- a/pkg/app/login.go +++ b/pkg/app/login.go @@ -61,7 +61,7 @@ func (f *loginForm) Valid(l *locale.Locale) bool { } func (f *loginForm) MustRender(w http.ResponseWriter, r *http.Request, user *auth.User) { - template.MustRender(w, r, user, "login.gohtml", f) + template.MustRender(w, r, user, nil, "login.gohtml", f) } func serveLoginForm(w http.ResponseWriter, r *http.Request, user *auth.User, redirectPath string) { diff --git a/pkg/app/user.go b/pkg/app/user.go index a0248ad..da28d8d 100644 --- a/pkg/app/user.go +++ b/pkg/app/user.go @@ -91,14 +91,14 @@ func (h *profileHandler) Handler(user *auth.User, conn *database.Conn) http.Hand case http.MethodGet: h.serveAvatar(w, r, user) default: - methodNotAllowed(w, r, http.MethodGet) + httplib.MethodNotAllowed(w, r, http.MethodGet) } case "session": switch r.Method { case http.MethodDelete: handleLogout(w, r, user, conn) default: - methodNotAllowed(w, r, http.MethodDelete) + httplib.MethodNotAllowed(w, r, http.MethodDelete) } case "": switch r.Method { @@ -107,7 +107,7 @@ func (h *profileHandler) Handler(user *auth.User, conn *database.Conn) http.Hand case http.MethodPut: h.updateProfile(w, r, user, conn) default: - methodNotAllowed(w, r, http.MethodGet, http.MethodPut) + httplib.MethodNotAllowed(w, r, http.MethodGet, http.MethodPut) } default: http.NotFound(w, r) @@ -249,7 +249,7 @@ func (f *profileForm) Valid(l *locale.Locale) bool { } func (f *profileForm) MustRender(w http.ResponseWriter, r *http.Request, user *auth.User) { - template.MustRender(w, r, user, "profile.gohtml", f) + template.MustRender(w, r, user, nil, "profile.gohtml", f) } func (f *profileForm) HasAvatarFile() bool { diff --git a/pkg/auth/company.go b/pkg/auth/company.go new file mode 100644 index 0000000..d653645 --- /dev/null +++ b/pkg/auth/company.go @@ -0,0 +1,55 @@ +/* + * SPDX-FileCopyrightText: 2023 jordi fita mas + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package auth + +import ( + "context" + + "dev.tandem.ws/tandem/camper/pkg/database" +) + +type Company struct { + ID int + CurrencySymbol string + DecimalDigits int + Slug string +} + +func QueryMainCompany(ctx context.Context, conn *database.Conn) (*Company, error) { + slug, err := conn.GetText(ctx, "select slug::text from company order by company_id limit 1") + if err != nil { + return nil, err + } + return QueryBySlug(ctx, conn, slug) +} + +func QueryBySlug(ctx context.Context, conn *database.Conn, slug string) (*Company, error) { + company := &Company{ + Slug: slug, + } + if err := conn.QueryRow(ctx, ` + select company_id + , currency_symbol + , decimal_digits + from company + join currency using (currency_code) + where slug = $1 + `, company.Slug).Scan( + &company.ID, + &company.CurrencySymbol, + &company.DecimalDigits, + ); err != nil { + return nil, err + } + return company, nil +} + +func (c *Company) URL() string { + if c == nil { + return "" + } + return "/company/" + c.Slug +} diff --git a/pkg/campsite/handler.go b/pkg/campsite/http.go similarity index 79% rename from pkg/campsite/handler.go rename to pkg/campsite/http.go index ca7463f..dce90b6 100644 --- a/pkg/campsite/handler.go +++ b/pkg/campsite/http.go @@ -23,14 +23,14 @@ func NewHandler() *Handler { } } -func (h *Handler) Handler(user *auth.User, conn *database.Conn) http.HandlerFunc { +func (h *Handler) Handler(user *auth.User, company *auth.Company, conn *database.Conn) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var head string head, r.URL.Path = httplib.ShiftPath(r.URL.Path) switch head { case "types": - h.types.Handler(user, conn).ServeHTTP(w, r) + h.types.Handler(user, company, conn).ServeHTTP(w, r) default: http.NotFound(w, r) } diff --git a/pkg/campsite/type.go b/pkg/campsite/type.go index 5900d04..060a258 100644 --- a/pkg/campsite/type.go +++ b/pkg/campsite/type.go @@ -17,16 +17,21 @@ import ( type typeHandler struct { } -func (h *typeHandler) Handler(user *auth.User, conn *database.Conn) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func (h *typeHandler) Handler(user *auth.User, company *auth.Company, conn *database.Conn) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var head string head, r.URL.Path = httplib.ShiftPath(r.URL.Path) switch head { case "new": - template.MustRender(w, r, user, "campsite/type/new.gohtml", nil) + switch r.Method { + case http.MethodGet: + template.MustRender(w, r, user, company, "campsite/type/new.gohtml", nil) + default: + httplib.MethodNotAllowed(w, r, http.MethodGet) + } default: http.NotFound(w, r) } - } + }) } diff --git a/pkg/company/http.go b/pkg/company/http.go new file mode 100644 index 0000000..2758d46 --- /dev/null +++ b/pkg/company/http.go @@ -0,0 +1,64 @@ +/* + * SPDX-FileCopyrightText: 2023 jordi fita mas + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package company + +import ( + "dev.tandem.ws/tandem/camper/pkg/auth" + "dev.tandem.ws/tandem/camper/pkg/campsite" + "dev.tandem.ws/tandem/camper/pkg/database" + httplib "dev.tandem.ws/tandem/camper/pkg/http" + "dev.tandem.ws/tandem/camper/pkg/template" + "dev.tandem.ws/tandem/camper/pkg/uuid" + "net/http" +) + +type Handler struct { + campsite *campsite.Handler +} + +func NewHandler() *Handler { + return &Handler{ + campsite: campsite.NewHandler(), + } +} + +func (h *Handler) Handler(user *auth.User, conn *database.Conn) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var slug string + slug, r.URL.Path = httplib.ShiftPath(r.URL.Path) + if !uuid.Valid(slug) { + http.NotFound(w, r) + return + } + company, err := auth.QueryBySlug(r.Context(), conn, slug) + if database.ErrorIsNotFound(err) { + http.NotFound(w, r) + return + } else if err != nil { + panic(err) + } + + var head string + head, r.URL.Path = httplib.ShiftPath(r.URL.Path) + switch head { + case "campsites": + h.campsite.Handler(user, company, conn).ServeHTTP(w, r) + case "": + switch r.Method { + case http.MethodGet: + serveDashboard(w, r, user, company) + default: + httplib.MethodNotAllowed(w, r, http.MethodGet) + } + default: + http.NotFound(w, r) + } + }) +} + +func serveDashboard(w http.ResponseWriter, r *http.Request, user *auth.User, company *auth.Company) { + template.MustRender(w, r, user, company, "dashboard.gohtml", nil) +} diff --git a/pkg/database/db.go b/pkg/database/db.go index 956d942..fbeba6b 100644 --- a/pkg/database/db.go +++ b/pkg/database/db.go @@ -7,6 +7,7 @@ package database import ( "context" + "errors" "log" "github.com/jackc/pgconn" @@ -14,6 +15,10 @@ import ( "github.com/jackc/pgx/v4/pgxpool" ) +func ErrorIsNotFound(err error) bool { + return errors.Is(err, pgx.ErrNoRows) +} + func New(ctx context.Context, connString string) (*DB, error) { config, err := pgxpool.ParseConfig(connString) if err != nil { @@ -66,10 +71,18 @@ func (c *Conn) MustExec(ctx context.Context, sql string, args ...interface{}) pg return tag } -func (c *Conn) MustGetText(ctx context.Context, sql string, args ...interface{}) string { +func (c *Conn) GetText(ctx context.Context, sql string, args ...interface{}) (string, error) { var result string if err := c.QueryRow(ctx, sql, args...).Scan(&result); err != nil { + return "", err + } + return result, nil +} + +func (c *Conn) MustGetText(ctx context.Context, sql string, args ...interface{}) string { + if result, err := c.GetText(ctx, sql, args...); err == nil { + return result + } else { panic(err) } - return result } diff --git a/pkg/http/request.go b/pkg/http/request.go index acae37e..f4894b4 100644 --- a/pkg/http/request.go +++ b/pkg/http/request.go @@ -37,3 +37,8 @@ func ShiftPath(p string) (head, tail string) { return p[1:i], p[i:] } } + +func MethodNotAllowed(w http.ResponseWriter, _ *http.Request, allowed ...string) { + w.Header().Set("Allow", strings.Join(allowed, ", ")) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) +} diff --git a/pkg/template/render.go b/pkg/template/render.go index fb74ceb..e21139d 100644 --- a/pkg/template/render.go +++ b/pkg/template/render.go @@ -19,19 +19,20 @@ func templateFile(name string) string { return "web/templates/" + name } -func MustRender(w io.Writer, r *http.Request, user *auth.User, filename string, data interface{}) { +func MustRender(w io.Writer, r *http.Request, user *auth.User, company *auth.Company, filename string, data interface{}) { layout := "layout.gohtml" if httplib.IsHTMxRequest(r) { layout = "htmx.gohtml" } - mustRenderLayout(w, user, layout, filename, data) + mustRenderLayout(w, user, company, layout, filename, data) } -func mustRenderLayout(w io.Writer, user *auth.User, layout string, filename string, data interface{}) { +func mustRenderLayout(w io.Writer, user *auth.User, company *auth.Company, layout string, filename string, data interface{}) { t := template.New(filename) t.Funcs(template.FuncMap{ - "gettext": user.Locale.Get, - "pgettext": user.Locale.GetC, + "gettext": user.Locale.Get, + "pgettext": user.Locale.GetC, + "companyURL": company.URL, "currentLocale": func() string { return user.Locale.Language.String() }, diff --git a/pkg/uuid/uuid.go b/pkg/uuid/uuid.go new file mode 100644 index 0000000..3a06145 --- /dev/null +++ b/pkg/uuid/uuid.go @@ -0,0 +1,51 @@ +/* + * SPDX-FileCopyrightText: 2023 jordi fita mas + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package uuid + +func Valid(s string) bool { + if len(s) != 36 { + return false + } + // it must be of the form xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' { + return false + } + for _, x := range [16]int{ + 0, 2, 4, 6, + 9, 11, + 14, 16, + 19, 21, + 24, 26, 28, 30, 32, 34} { + if !validHex(s[x], s[x+1]) { + return false + } + } + return true +} + +// xvalues returns the value of a byte as a hexadecimal digit or 255. +var xvalues = [256]byte{ + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 255, 255, 255, 255, 255, 255, + 255, 10, 11, 12, 13, 14, 15, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 10, 11, 12, 13, 14, 15, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, +} + +func validHex(x1, x2 byte) bool { + return xvalues[x1] != 255 && xvalues[x2] != 255 +} diff --git a/pkg/uuid/uuid_test.go b/pkg/uuid/uuid_test.go new file mode 100644 index 0000000..d57174f --- /dev/null +++ b/pkg/uuid/uuid_test.go @@ -0,0 +1,38 @@ +/* + * SPDX-FileCopyrightText: 2023 jordi fita mas + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package uuid + +import ( + "strings" + "testing" +) + +type test struct { + in string + isUuid bool +} + +var tests = []test{ + {"f47ac10b-58cc-0372-8567-0e02b2c3d479", true}, + {"2bc1be74-169d-4300-a239-49a1196a045d", true}, + {"12bc1be74-169d-4300-a239-49a1196a045d", false}, + {"2bc1be74-169d-4300-a239-49a1196a045", false}, + {"2bc1be74-1x9d-4300-a239-49a1196a045d", false}, + {"2bc1be74-169d-4300-a239-49a1196ag45d", false}, +} + +func testValid(t *testing.T, in string, isUuid bool) { + if ok := Valid(in); ok != isUuid { + t.Errorf("Valid(%s) got %v expected %v", in, ok, isUuid) + } +} + +func TestUUID(t *testing.T) { + for _, tt := range tests { + testValid(t, tt.in, tt.isUuid) + testValid(t, strings.ToUpper(tt.in), tt.isUuid) + } +} diff --git a/web/templates/layout.gohtml b/web/templates/layout.gohtml index e31badb..96d0b0a 100644 --- a/web/templates/layout.gohtml +++ b/web/templates/layout.gohtml @@ -36,6 +36,15 @@ {{- end }} +{{ if isLoggedIn -}} + +{{- end }}
{{- template "content" . }}