diff --git a/pkg/app/app.go b/pkg/app/app.go index a22ddb7..514fcc4 100644 --- a/pkg/app/app.go +++ b/pkg/app/app.go @@ -7,28 +7,21 @@ package app import ( "net/http" - "strings" "golang.org/x/text/language" "dev.tandem.ws/tandem/camper/pkg/auth" - "dev.tandem.ws/tandem/camper/pkg/campsite" + "dev.tandem.ws/tandem/camper/pkg/company" "dev.tandem.ws/tandem/camper/pkg/database" httplib "dev.tandem.ws/tandem/camper/pkg/http" "dev.tandem.ws/tandem/camper/pkg/locale" - "dev.tandem.ws/tandem/camper/pkg/template" ) -func methodNotAllowed(w http.ResponseWriter, _ *http.Request, allowed ...string) { - w.Header().Set("Allow", strings.Join(allowed, ", ")) - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) -} - type App struct { db *database.DB fileHandler http.Handler profile *profileHandler - campsite *campsite.Handler + company *company.Handler locales locale.Locales defaultLocale *locale.Locale languageMatcher language.Matcher @@ -45,7 +38,7 @@ func New(db *database.DB, avatarsDir string) (http.Handler, error) { db: db, fileHandler: static, profile: profile, - campsite: campsite.NewHandler(), + company: company.NewHandler(), locales: locales, defaultLocale: locales[language.Catalan], languageMatcher: language.NewMatcher(locales.Tags()), @@ -86,7 +79,7 @@ func (h *App) ServeHTTP(w http.ResponseWriter, r *http.Request) { case http.MethodPost: handleLogin(w, r, user, conn) default: - methodNotAllowed(w, r, http.MethodPost, http.MethodGet) + httplib.MethodNotAllowed(w, r, http.MethodPost, http.MethodGet) } } else { if !user.LoggedIn { @@ -98,14 +91,14 @@ func (h *App) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch head { case "me": h.profile.Handler(user, conn).ServeHTTP(w, r) - case "campsites": - h.campsite.Handler(user, conn).ServeHTTP(w, r) + case "company": + h.company.Handler(user, conn).ServeHTTP(w, r) case "": switch r.Method { case http.MethodGet: - h.serveDashboard(w, r, user) + redirectToMainCompany(w, r, conn) default: - methodNotAllowed(w, r, http.MethodGet) + httplib.MethodNotAllowed(w, r, http.MethodGet) } default: http.NotFound(w, r) @@ -114,6 +107,10 @@ func (h *App) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (h *App) serveDashboard(w http.ResponseWriter, r *http.Request, user *auth.User) { - template.MustRender(w, r, user, "dashboard.gohtml", nil) +func redirectToMainCompany(w http.ResponseWriter, r *http.Request, conn *database.Conn) { + co, err := auth.QueryMainCompany(r.Context(), conn) + if err != nil { + panic(err) + } + httplib.Relocate(w, r, co.URL(), http.StatusFound) } diff --git a/pkg/app/login.go b/pkg/app/login.go index a7bd26a..6cd6834 100644 --- a/pkg/app/login.go +++ b/pkg/app/login.go @@ -61,7 +61,7 @@ func (f *loginForm) Valid(l *locale.Locale) bool { } func (f *loginForm) MustRender(w http.ResponseWriter, r *http.Request, user *auth.User) { - template.MustRender(w, r, user, "login.gohtml", f) + template.MustRender(w, r, user, nil, "login.gohtml", f) } func serveLoginForm(w http.ResponseWriter, r *http.Request, user *auth.User, redirectPath string) { diff --git a/pkg/app/user.go b/pkg/app/user.go index a0248ad..da28d8d 100644 --- a/pkg/app/user.go +++ b/pkg/app/user.go @@ -91,14 +91,14 @@ func (h *profileHandler) Handler(user *auth.User, conn *database.Conn) http.Hand case http.MethodGet: h.serveAvatar(w, r, user) default: - methodNotAllowed(w, r, http.MethodGet) + httplib.MethodNotAllowed(w, r, http.MethodGet) } case "session": switch r.Method { case http.MethodDelete: handleLogout(w, r, user, conn) default: - methodNotAllowed(w, r, http.MethodDelete) + httplib.MethodNotAllowed(w, r, http.MethodDelete) } case "": switch r.Method { @@ -107,7 +107,7 @@ func (h *profileHandler) Handler(user *auth.User, conn *database.Conn) http.Hand case http.MethodPut: h.updateProfile(w, r, user, conn) default: - methodNotAllowed(w, r, http.MethodGet, http.MethodPut) + httplib.MethodNotAllowed(w, r, http.MethodGet, http.MethodPut) } default: http.NotFound(w, r) @@ -249,7 +249,7 @@ func (f *profileForm) Valid(l *locale.Locale) bool { } func (f *profileForm) MustRender(w http.ResponseWriter, r *http.Request, user *auth.User) { - template.MustRender(w, r, user, "profile.gohtml", f) + template.MustRender(w, r, user, nil, "profile.gohtml", f) } func (f *profileForm) HasAvatarFile() bool { diff --git a/pkg/auth/company.go b/pkg/auth/company.go new file mode 100644 index 0000000..d653645 --- /dev/null +++ b/pkg/auth/company.go @@ -0,0 +1,55 @@ +/* + * SPDX-FileCopyrightText: 2023 jordi fita mas + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package auth + +import ( + "context" + + "dev.tandem.ws/tandem/camper/pkg/database" +) + +type Company struct { + ID int + CurrencySymbol string + DecimalDigits int + Slug string +} + +func QueryMainCompany(ctx context.Context, conn *database.Conn) (*Company, error) { + slug, err := conn.GetText(ctx, "select slug::text from company order by company_id limit 1") + if err != nil { + return nil, err + } + return QueryBySlug(ctx, conn, slug) +} + +func QueryBySlug(ctx context.Context, conn *database.Conn, slug string) (*Company, error) { + company := &Company{ + Slug: slug, + } + if err := conn.QueryRow(ctx, ` + select company_id + , currency_symbol + , decimal_digits + from company + join currency using (currency_code) + where slug = $1 + `, company.Slug).Scan( + &company.ID, + &company.CurrencySymbol, + &company.DecimalDigits, + ); err != nil { + return nil, err + } + return company, nil +} + +func (c *Company) URL() string { + if c == nil { + return "" + } + return "/company/" + c.Slug +} diff --git a/pkg/campsite/handler.go b/pkg/campsite/http.go similarity index 79% rename from pkg/campsite/handler.go rename to pkg/campsite/http.go index ca7463f..dce90b6 100644 --- a/pkg/campsite/handler.go +++ b/pkg/campsite/http.go @@ -23,14 +23,14 @@ func NewHandler() *Handler { } } -func (h *Handler) Handler(user *auth.User, conn *database.Conn) http.HandlerFunc { +func (h *Handler) Handler(user *auth.User, company *auth.Company, conn *database.Conn) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var head string head, r.URL.Path = httplib.ShiftPath(r.URL.Path) switch head { case "types": - h.types.Handler(user, conn).ServeHTTP(w, r) + h.types.Handler(user, company, conn).ServeHTTP(w, r) default: http.NotFound(w, r) } diff --git a/pkg/campsite/type.go b/pkg/campsite/type.go index 5900d04..060a258 100644 --- a/pkg/campsite/type.go +++ b/pkg/campsite/type.go @@ -17,16 +17,21 @@ import ( type typeHandler struct { } -func (h *typeHandler) Handler(user *auth.User, conn *database.Conn) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func (h *typeHandler) Handler(user *auth.User, company *auth.Company, conn *database.Conn) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var head string head, r.URL.Path = httplib.ShiftPath(r.URL.Path) switch head { case "new": - template.MustRender(w, r, user, "campsite/type/new.gohtml", nil) + switch r.Method { + case http.MethodGet: + template.MustRender(w, r, user, company, "campsite/type/new.gohtml", nil) + default: + httplib.MethodNotAllowed(w, r, http.MethodGet) + } default: http.NotFound(w, r) } - } + }) } diff --git a/pkg/company/http.go b/pkg/company/http.go new file mode 100644 index 0000000..2758d46 --- /dev/null +++ b/pkg/company/http.go @@ -0,0 +1,64 @@ +/* + * SPDX-FileCopyrightText: 2023 jordi fita mas + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package company + +import ( + "dev.tandem.ws/tandem/camper/pkg/auth" + "dev.tandem.ws/tandem/camper/pkg/campsite" + "dev.tandem.ws/tandem/camper/pkg/database" + httplib "dev.tandem.ws/tandem/camper/pkg/http" + "dev.tandem.ws/tandem/camper/pkg/template" + "dev.tandem.ws/tandem/camper/pkg/uuid" + "net/http" +) + +type Handler struct { + campsite *campsite.Handler +} + +func NewHandler() *Handler { + return &Handler{ + campsite: campsite.NewHandler(), + } +} + +func (h *Handler) Handler(user *auth.User, conn *database.Conn) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var slug string + slug, r.URL.Path = httplib.ShiftPath(r.URL.Path) + if !uuid.Valid(slug) { + http.NotFound(w, r) + return + } + company, err := auth.QueryBySlug(r.Context(), conn, slug) + if database.ErrorIsNotFound(err) { + http.NotFound(w, r) + return + } else if err != nil { + panic(err) + } + + var head string + head, r.URL.Path = httplib.ShiftPath(r.URL.Path) + switch head { + case "campsites": + h.campsite.Handler(user, company, conn).ServeHTTP(w, r) + case "": + switch r.Method { + case http.MethodGet: + serveDashboard(w, r, user, company) + default: + httplib.MethodNotAllowed(w, r, http.MethodGet) + } + default: + http.NotFound(w, r) + } + }) +} + +func serveDashboard(w http.ResponseWriter, r *http.Request, user *auth.User, company *auth.Company) { + template.MustRender(w, r, user, company, "dashboard.gohtml", nil) +} diff --git a/pkg/database/db.go b/pkg/database/db.go index 956d942..fbeba6b 100644 --- a/pkg/database/db.go +++ b/pkg/database/db.go @@ -7,6 +7,7 @@ package database import ( "context" + "errors" "log" "github.com/jackc/pgconn" @@ -14,6 +15,10 @@ import ( "github.com/jackc/pgx/v4/pgxpool" ) +func ErrorIsNotFound(err error) bool { + return errors.Is(err, pgx.ErrNoRows) +} + func New(ctx context.Context, connString string) (*DB, error) { config, err := pgxpool.ParseConfig(connString) if err != nil { @@ -66,10 +71,18 @@ func (c *Conn) MustExec(ctx context.Context, sql string, args ...interface{}) pg return tag } -func (c *Conn) MustGetText(ctx context.Context, sql string, args ...interface{}) string { +func (c *Conn) GetText(ctx context.Context, sql string, args ...interface{}) (string, error) { var result string if err := c.QueryRow(ctx, sql, args...).Scan(&result); err != nil { + return "", err + } + return result, nil +} + +func (c *Conn) MustGetText(ctx context.Context, sql string, args ...interface{}) string { + if result, err := c.GetText(ctx, sql, args...); err == nil { + return result + } else { panic(err) } - return result } diff --git a/pkg/http/request.go b/pkg/http/request.go index acae37e..f4894b4 100644 --- a/pkg/http/request.go +++ b/pkg/http/request.go @@ -37,3 +37,8 @@ func ShiftPath(p string) (head, tail string) { return p[1:i], p[i:] } } + +func MethodNotAllowed(w http.ResponseWriter, _ *http.Request, allowed ...string) { + w.Header().Set("Allow", strings.Join(allowed, ", ")) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) +} diff --git a/pkg/template/render.go b/pkg/template/render.go index fb74ceb..e21139d 100644 --- a/pkg/template/render.go +++ b/pkg/template/render.go @@ -19,19 +19,20 @@ func templateFile(name string) string { return "web/templates/" + name } -func MustRender(w io.Writer, r *http.Request, user *auth.User, filename string, data interface{}) { +func MustRender(w io.Writer, r *http.Request, user *auth.User, company *auth.Company, filename string, data interface{}) { layout := "layout.gohtml" if httplib.IsHTMxRequest(r) { layout = "htmx.gohtml" } - mustRenderLayout(w, user, layout, filename, data) + mustRenderLayout(w, user, company, layout, filename, data) } -func mustRenderLayout(w io.Writer, user *auth.User, layout string, filename string, data interface{}) { +func mustRenderLayout(w io.Writer, user *auth.User, company *auth.Company, layout string, filename string, data interface{}) { t := template.New(filename) t.Funcs(template.FuncMap{ - "gettext": user.Locale.Get, - "pgettext": user.Locale.GetC, + "gettext": user.Locale.Get, + "pgettext": user.Locale.GetC, + "companyURL": company.URL, "currentLocale": func() string { return user.Locale.Language.String() }, diff --git a/pkg/uuid/uuid.go b/pkg/uuid/uuid.go new file mode 100644 index 0000000..3a06145 --- /dev/null +++ b/pkg/uuid/uuid.go @@ -0,0 +1,51 @@ +/* + * SPDX-FileCopyrightText: 2023 jordi fita mas + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package uuid + +func Valid(s string) bool { + if len(s) != 36 { + return false + } + // it must be of the form xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' { + return false + } + for _, x := range [16]int{ + 0, 2, 4, 6, + 9, 11, + 14, 16, + 19, 21, + 24, 26, 28, 30, 32, 34} { + if !validHex(s[x], s[x+1]) { + return false + } + } + return true +} + +// xvalues returns the value of a byte as a hexadecimal digit or 255. +var xvalues = [256]byte{ + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 255, 255, 255, 255, 255, 255, + 255, 10, 11, 12, 13, 14, 15, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 10, 11, 12, 13, 14, 15, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, +} + +func validHex(x1, x2 byte) bool { + return xvalues[x1] != 255 && xvalues[x2] != 255 +} diff --git a/pkg/uuid/uuid_test.go b/pkg/uuid/uuid_test.go new file mode 100644 index 0000000..d57174f --- /dev/null +++ b/pkg/uuid/uuid_test.go @@ -0,0 +1,38 @@ +/* + * SPDX-FileCopyrightText: 2023 jordi fita mas + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package uuid + +import ( + "strings" + "testing" +) + +type test struct { + in string + isUuid bool +} + +var tests = []test{ + {"f47ac10b-58cc-0372-8567-0e02b2c3d479", true}, + {"2bc1be74-169d-4300-a239-49a1196a045d", true}, + {"12bc1be74-169d-4300-a239-49a1196a045d", false}, + {"2bc1be74-169d-4300-a239-49a1196a045", false}, + {"2bc1be74-1x9d-4300-a239-49a1196a045d", false}, + {"2bc1be74-169d-4300-a239-49a1196ag45d", false}, +} + +func testValid(t *testing.T, in string, isUuid bool) { + if ok := Valid(in); ok != isUuid { + t.Errorf("Valid(%s) got %v expected %v", in, ok, isUuid) + } +} + +func TestUUID(t *testing.T) { + for _, tt := range tests { + testValid(t, tt.in, tt.isUuid) + testValid(t, strings.ToUpper(tt.in), tt.isUuid) + } +} diff --git a/web/templates/layout.gohtml b/web/templates/layout.gohtml index e31badb..96d0b0a 100644 --- a/web/templates/layout.gohtml +++ b/web/templates/layout.gohtml @@ -36,6 +36,15 @@ {{- end }} +{{ if isLoggedIn -}} + +{{- end }}
{{- template "content" . }}