Add the company’s slug in the URL before company-dependent handlers

I really doubt that they are going to use more than a single company,
but the application is based on Numerus, that **does** have multiple
company, and followed the same architecture and philosophy: use the URL
to choose the company to manage, even if the user has a single company.

The reason i use the slug instead of the ID is because i do not want to
make the ID public in case the application is really used by employees
of many unrelated companies: they need not need to guess how many
companies there are based on the ID.

I validate this slug to be a valid UUID instead of relaying on the
query’s empty result because casting a string with a malformed value to
UUID results in an error other than data not found.  Not with that
select, but it would fail with a function parameter, and i want to add
that UUID check to all functions that do use slugs.

I based uuid.Valid function on Parse() from Google’s uuid package[0]
instead of using regular expression, as it was my first idea, because
that function is an order of magnitude faster in benchmarks:

  goos: linux
  goarch: amd64
  pkg: dev.tandem.ws/tandem/numerus/pkg
  cpu: Intel(R) Core(TM) i5-6200U CPU @ 2.30GHz
  BenchmarkValidUuid-4            36946050                29.37 ns/op
  BenchmarkValidUuid_Re-4          3633169               306.70 ns/op

The regular expression used for the benchmark was:

  var re = regexp.MustCompile("^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-4[a-fA-F0-9]{3}-[8|9|aA|bB][a-fA-F0-9]{3}-[a-fA-F0-9]{12}$")

And the input parameter for both functions was the following valid UUID,
because most of the time the passed UUID will be valid:

  "f47ac10b-58cc-0372-8567-0e02b2c3d479"

I did not use the uuid package as is, even though it is in Debian’s
repository, because i only need to check whether the value is valid,
not convert it to a byte array.  As far as i know, that package can not
do that.

Adding the Company struct into auth was not my intention, as it makes
little sense name-wise, but i need to have the Company when rendering
templates and the company package has templates to render, thus using
the company package for the Company struct would create a dependency
loop between template and company.  I’ve chosen the auth package only
because User is also there; User and Company are very much related in
this application, but not enough to include the company inside the user,
or vice versa, as the User comes from the cookie while the company from
the URL.

Finally, had to move methodNotAllowed to the http package, as an
exported function, because it is used now from other packages, namely
campsite.

[0]: https://github.com/google/uuid
This commit is contained in:
jordi fita mas 2023-07-31 18:51:50 +02:00
parent 5a2c8fea41
commit f65110824e
13 changed files with 273 additions and 35 deletions

View File

@ -7,28 +7,21 @@ package app
import (
"net/http"
"strings"
"golang.org/x/text/language"
"dev.tandem.ws/tandem/camper/pkg/auth"
"dev.tandem.ws/tandem/camper/pkg/campsite"
"dev.tandem.ws/tandem/camper/pkg/company"
"dev.tandem.ws/tandem/camper/pkg/database"
httplib "dev.tandem.ws/tandem/camper/pkg/http"
"dev.tandem.ws/tandem/camper/pkg/locale"
"dev.tandem.ws/tandem/camper/pkg/template"
)
func methodNotAllowed(w http.ResponseWriter, _ *http.Request, allowed ...string) {
w.Header().Set("Allow", strings.Join(allowed, ", "))
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
type App struct {
db *database.DB
fileHandler http.Handler
profile *profileHandler
campsite *campsite.Handler
company *company.Handler
locales locale.Locales
defaultLocale *locale.Locale
languageMatcher language.Matcher
@ -45,7 +38,7 @@ func New(db *database.DB, avatarsDir string) (http.Handler, error) {
db: db,
fileHandler: static,
profile: profile,
campsite: campsite.NewHandler(),
company: company.NewHandler(),
locales: locales,
defaultLocale: locales[language.Catalan],
languageMatcher: language.NewMatcher(locales.Tags()),
@ -86,7 +79,7 @@ func (h *App) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case http.MethodPost:
handleLogin(w, r, user, conn)
default:
methodNotAllowed(w, r, http.MethodPost, http.MethodGet)
httplib.MethodNotAllowed(w, r, http.MethodPost, http.MethodGet)
}
} else {
if !user.LoggedIn {
@ -98,14 +91,14 @@ func (h *App) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch head {
case "me":
h.profile.Handler(user, conn).ServeHTTP(w, r)
case "campsites":
h.campsite.Handler(user, conn).ServeHTTP(w, r)
case "company":
h.company.Handler(user, conn).ServeHTTP(w, r)
case "":
switch r.Method {
case http.MethodGet:
h.serveDashboard(w, r, user)
redirectToMainCompany(w, r, conn)
default:
methodNotAllowed(w, r, http.MethodGet)
httplib.MethodNotAllowed(w, r, http.MethodGet)
}
default:
http.NotFound(w, r)
@ -114,6 +107,10 @@ func (h *App) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
func (h *App) serveDashboard(w http.ResponseWriter, r *http.Request, user *auth.User) {
template.MustRender(w, r, user, "dashboard.gohtml", nil)
func redirectToMainCompany(w http.ResponseWriter, r *http.Request, conn *database.Conn) {
co, err := auth.QueryMainCompany(r.Context(), conn)
if err != nil {
panic(err)
}
httplib.Relocate(w, r, co.URL(), http.StatusFound)
}

View File

@ -61,7 +61,7 @@ func (f *loginForm) Valid(l *locale.Locale) bool {
}
func (f *loginForm) MustRender(w http.ResponseWriter, r *http.Request, user *auth.User) {
template.MustRender(w, r, user, "login.gohtml", f)
template.MustRender(w, r, user, nil, "login.gohtml", f)
}
func serveLoginForm(w http.ResponseWriter, r *http.Request, user *auth.User, redirectPath string) {

View File

@ -91,14 +91,14 @@ func (h *profileHandler) Handler(user *auth.User, conn *database.Conn) http.Hand
case http.MethodGet:
h.serveAvatar(w, r, user)
default:
methodNotAllowed(w, r, http.MethodGet)
httplib.MethodNotAllowed(w, r, http.MethodGet)
}
case "session":
switch r.Method {
case http.MethodDelete:
handleLogout(w, r, user, conn)
default:
methodNotAllowed(w, r, http.MethodDelete)
httplib.MethodNotAllowed(w, r, http.MethodDelete)
}
case "":
switch r.Method {
@ -107,7 +107,7 @@ func (h *profileHandler) Handler(user *auth.User, conn *database.Conn) http.Hand
case http.MethodPut:
h.updateProfile(w, r, user, conn)
default:
methodNotAllowed(w, r, http.MethodGet, http.MethodPut)
httplib.MethodNotAllowed(w, r, http.MethodGet, http.MethodPut)
}
default:
http.NotFound(w, r)
@ -249,7 +249,7 @@ func (f *profileForm) Valid(l *locale.Locale) bool {
}
func (f *profileForm) MustRender(w http.ResponseWriter, r *http.Request, user *auth.User) {
template.MustRender(w, r, user, "profile.gohtml", f)
template.MustRender(w, r, user, nil, "profile.gohtml", f)
}
func (f *profileForm) HasAvatarFile() bool {

55
pkg/auth/company.go Normal file
View File

@ -0,0 +1,55 @@
/*
* SPDX-FileCopyrightText: 2023 jordi fita mas <jfita@peritasoft.com>
* SPDX-License-Identifier: AGPL-3.0-only
*/
package auth
import (
"context"
"dev.tandem.ws/tandem/camper/pkg/database"
)
type Company struct {
ID int
CurrencySymbol string
DecimalDigits int
Slug string
}
func QueryMainCompany(ctx context.Context, conn *database.Conn) (*Company, error) {
slug, err := conn.GetText(ctx, "select slug::text from company order by company_id limit 1")
if err != nil {
return nil, err
}
return QueryBySlug(ctx, conn, slug)
}
func QueryBySlug(ctx context.Context, conn *database.Conn, slug string) (*Company, error) {
company := &Company{
Slug: slug,
}
if err := conn.QueryRow(ctx, `
select company_id
, currency_symbol
, decimal_digits
from company
join currency using (currency_code)
where slug = $1
`, company.Slug).Scan(
&company.ID,
&company.CurrencySymbol,
&company.DecimalDigits,
); err != nil {
return nil, err
}
return company, nil
}
func (c *Company) URL() string {
if c == nil {
return ""
}
return "/company/" + c.Slug
}

View File

@ -23,14 +23,14 @@ func NewHandler() *Handler {
}
}
func (h *Handler) Handler(user *auth.User, conn *database.Conn) http.HandlerFunc {
func (h *Handler) Handler(user *auth.User, company *auth.Company, conn *database.Conn) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var head string
head, r.URL.Path = httplib.ShiftPath(r.URL.Path)
switch head {
case "types":
h.types.Handler(user, conn).ServeHTTP(w, r)
h.types.Handler(user, company, conn).ServeHTTP(w, r)
default:
http.NotFound(w, r)
}

View File

@ -17,16 +17,21 @@ import (
type typeHandler struct {
}
func (h *typeHandler) Handler(user *auth.User, conn *database.Conn) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
func (h *typeHandler) Handler(user *auth.User, company *auth.Company, conn *database.Conn) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var head string
head, r.URL.Path = httplib.ShiftPath(r.URL.Path)
switch head {
case "new":
template.MustRender(w, r, user, "campsite/type/new.gohtml", nil)
switch r.Method {
case http.MethodGet:
template.MustRender(w, r, user, company, "campsite/type/new.gohtml", nil)
default:
httplib.MethodNotAllowed(w, r, http.MethodGet)
}
default:
http.NotFound(w, r)
}
}
})
}

64
pkg/company/http.go Normal file
View File

@ -0,0 +1,64 @@
/*
* SPDX-FileCopyrightText: 2023 jordi fita mas <jfita@peritasoft.com>
* SPDX-License-Identifier: AGPL-3.0-only
*/
package company
import (
"dev.tandem.ws/tandem/camper/pkg/auth"
"dev.tandem.ws/tandem/camper/pkg/campsite"
"dev.tandem.ws/tandem/camper/pkg/database"
httplib "dev.tandem.ws/tandem/camper/pkg/http"
"dev.tandem.ws/tandem/camper/pkg/template"
"dev.tandem.ws/tandem/camper/pkg/uuid"
"net/http"
)
type Handler struct {
campsite *campsite.Handler
}
func NewHandler() *Handler {
return &Handler{
campsite: campsite.NewHandler(),
}
}
func (h *Handler) Handler(user *auth.User, conn *database.Conn) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var slug string
slug, r.URL.Path = httplib.ShiftPath(r.URL.Path)
if !uuid.Valid(slug) {
http.NotFound(w, r)
return
}
company, err := auth.QueryBySlug(r.Context(), conn, slug)
if database.ErrorIsNotFound(err) {
http.NotFound(w, r)
return
} else if err != nil {
panic(err)
}
var head string
head, r.URL.Path = httplib.ShiftPath(r.URL.Path)
switch head {
case "campsites":
h.campsite.Handler(user, company, conn).ServeHTTP(w, r)
case "":
switch r.Method {
case http.MethodGet:
serveDashboard(w, r, user, company)
default:
httplib.MethodNotAllowed(w, r, http.MethodGet)
}
default:
http.NotFound(w, r)
}
})
}
func serveDashboard(w http.ResponseWriter, r *http.Request, user *auth.User, company *auth.Company) {
template.MustRender(w, r, user, company, "dashboard.gohtml", nil)
}

View File

@ -7,6 +7,7 @@ package database
import (
"context"
"errors"
"log"
"github.com/jackc/pgconn"
@ -14,6 +15,10 @@ import (
"github.com/jackc/pgx/v4/pgxpool"
)
func ErrorIsNotFound(err error) bool {
return errors.Is(err, pgx.ErrNoRows)
}
func New(ctx context.Context, connString string) (*DB, error) {
config, err := pgxpool.ParseConfig(connString)
if err != nil {
@ -66,10 +71,18 @@ func (c *Conn) MustExec(ctx context.Context, sql string, args ...interface{}) pg
return tag
}
func (c *Conn) MustGetText(ctx context.Context, sql string, args ...interface{}) string {
func (c *Conn) GetText(ctx context.Context, sql string, args ...interface{}) (string, error) {
var result string
if err := c.QueryRow(ctx, sql, args...).Scan(&result); err != nil {
return "", err
}
return result, nil
}
func (c *Conn) MustGetText(ctx context.Context, sql string, args ...interface{}) string {
if result, err := c.GetText(ctx, sql, args...); err == nil {
return result
} else {
panic(err)
}
return result
}

View File

@ -37,3 +37,8 @@ func ShiftPath(p string) (head, tail string) {
return p[1:i], p[i:]
}
}
func MethodNotAllowed(w http.ResponseWriter, _ *http.Request, allowed ...string) {
w.Header().Set("Allow", strings.Join(allowed, ", "))
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}

View File

@ -19,19 +19,20 @@ func templateFile(name string) string {
return "web/templates/" + name
}
func MustRender(w io.Writer, r *http.Request, user *auth.User, filename string, data interface{}) {
func MustRender(w io.Writer, r *http.Request, user *auth.User, company *auth.Company, filename string, data interface{}) {
layout := "layout.gohtml"
if httplib.IsHTMxRequest(r) {
layout = "htmx.gohtml"
}
mustRenderLayout(w, user, layout, filename, data)
mustRenderLayout(w, user, company, layout, filename, data)
}
func mustRenderLayout(w io.Writer, user *auth.User, layout string, filename string, data interface{}) {
func mustRenderLayout(w io.Writer, user *auth.User, company *auth.Company, layout string, filename string, data interface{}) {
t := template.New(filename)
t.Funcs(template.FuncMap{
"gettext": user.Locale.Get,
"pgettext": user.Locale.GetC,
"companyURL": company.URL,
"currentLocale": func() string {
return user.Locale.Language.String()
},

51
pkg/uuid/uuid.go Normal file
View File

@ -0,0 +1,51 @@
/*
* SPDX-FileCopyrightText: 2023 jordi fita mas <jfita@peritasoft.com>
* SPDX-License-Identifier: AGPL-3.0-only
*/
package uuid
func Valid(s string) bool {
if len(s) != 36 {
return false
}
// it must be of the form xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' {
return false
}
for _, x := range [16]int{
0, 2, 4, 6,
9, 11,
14, 16,
19, 21,
24, 26, 28, 30, 32, 34} {
if !validHex(s[x], s[x+1]) {
return false
}
}
return true
}
// xvalues returns the value of a byte as a hexadecimal digit or 255.
var xvalues = [256]byte{
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 255, 255, 255, 255, 255, 255,
255, 10, 11, 12, 13, 14, 15, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 10, 11, 12, 13, 14, 15, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
}
func validHex(x1, x2 byte) bool {
return xvalues[x1] != 255 && xvalues[x2] != 255
}

38
pkg/uuid/uuid_test.go Normal file
View File

@ -0,0 +1,38 @@
/*
* SPDX-FileCopyrightText: 2023 jordi fita mas <jfita@peritasoft.com>
* SPDX-License-Identifier: AGPL-3.0-only
*/
package uuid
import (
"strings"
"testing"
)
type test struct {
in string
isUuid bool
}
var tests = []test{
{"f47ac10b-58cc-0372-8567-0e02b2c3d479", true},
{"2bc1be74-169d-4300-a239-49a1196a045d", true},
{"12bc1be74-169d-4300-a239-49a1196a045d", false},
{"2bc1be74-169d-4300-a239-49a1196a045", false},
{"2bc1be74-1x9d-4300-a239-49a1196a045d", false},
{"2bc1be74-169d-4300-a239-49a1196ag45d", false},
}
func testValid(t *testing.T, in string, isUuid bool) {
if ok := Valid(in); ok != isUuid {
t.Errorf("Valid(%s) got %v expected %v", in, ok, isUuid)
}
}
func TestUUID(t *testing.T) {
for _, tt := range tests {
testValid(t, tt.in, tt.isUuid)
testValid(t, strings.ToUpper(tt.in), tt.isUuid)
}
}

View File

@ -36,6 +36,15 @@
</nav>
{{- end }}
</header>
{{ if isLoggedIn -}}
<nav>
<ul role="menu">
<li role="presentation">
<a role="menuitem" href="{{ companyURL }}/">{{( pgettext "Dashboard" "title" )}}</a>
</li>
</ul>
</nav>
{{- end }}
<main id="main">
{{- template "content" . }}
</main>