From 835e52dbcb8055e599871dc3c02c14d9e113d7a1 Mon Sep 17 00:00:00 2001 From: jordi fita mas Date: Mon, 17 Jul 2023 11:46:11 +0200 Subject: [PATCH] =?UTF-8?q?Return=20HTTP=C2=A0404=20instead=20of=20500=20f?= =?UTF-8?q?or=20invalid=20UUID=20values=20in=20URL?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Since most of PL/pgSQL functions accept a `uuid` domain, we get an error if the value is not valid, forcing us to return an HTTP 500, as we can not detect that the error was due to that. Instead, i now validate that the slug is indeed a valid UUID before attempting to send it to the database, returning the correct HTTP error code and avoiding useless calls to the database. I based the validation function of Parse() from Google’s uuid package[0] because this function is an order or magnitude faster in benchmarks: goos: linux goarch: amd64 pkg: dev.tandem.ws/tandem/numerus/pkg cpu: Intel(R) Core(TM) i5-6200U CPU @ 2.30GHz BenchmarkValidUuid-4 36946050 29.37 ns/op BenchmarkValidUuid_Re-4 3633169 306.70 ns/op The regular expression used for the benchmark was: var re = regexp.MustCompile("^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-4[a-fA-F0-9]{3}-[8|9|aA|bB][a-fA-F0-9]{3}-[a-fA-F0-9]{12}$") And the input parameter for both functions was the following valid UUID, because most of the time the passed UUID will be valid: "f47ac10b-58cc-0372-8567-0e02b2c3d479" I did not use the uuid package, even though it is in Debian’s repository, because i only need to check whether the value is valid, not convert it to a byte array. As far as i know, that package can not do that. [0]: https://github.com/google/uuid --- pkg/contacts.go | 19 ++++++++++++++++++- pkg/expenses.go | 30 ++++++++++++++++++++++++++++-- pkg/invoices.go | 37 +++++++++++++++++++++++++++++++++---- pkg/products.go | 16 ++++++++++++++++ pkg/quote.go | 31 ++++++++++++++++++++++++++++--- pkg/uuid.go | 46 ++++++++++++++++++++++++++++++++++++++++++++++ pkg/uuid_test.go | 33 +++++++++++++++++++++++++++++++++ 7 files changed, 202 insertions(+), 10 deletions(-) create mode 100644 pkg/uuid.go create mode 100644 pkg/uuid_test.go diff --git a/pkg/contacts.go b/pkg/contacts.go index f132b99..0cd478a 100644 --- a/pkg/contacts.go +++ b/pkg/contacts.go @@ -53,6 +53,10 @@ func GetContactForm(w http.ResponseWriter, r *http.Request, params httprouter.Pa mustRenderNewContactForm(w, r, form) return } + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } if !form.MustFillFromDatabase(r.Context(), conn, slug) { http.NotFound(w, r) return @@ -120,7 +124,12 @@ func HandleUpdateContact(w http.ResponseWriter, r *http.Request, params httprout mustRenderEditContactForm(w, r, params[0].Value, form) return } - slug := conn.MustGetText(r.Context(), "", "select edit_contact($1, $2, $3, $4, $5, $6, $7, $8, $9)", params[0].Value, form.Name, form.Phone, form.Email, form.Web, form.TaxDetails(), form.IBAN, form.BIC, form.Tags) + slug := params[0].Value + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } + slug = conn.MustGetText(r.Context(), "", "select edit_contact($1, $2, $3, $4, $5, $6, $7, $8, $9)", slug, form.Name, form.Phone, form.Email, form.Web, form.TaxDetails(), form.IBAN, form.BIC, form.Tags) if slug == "" { http.NotFound(w, r) } @@ -486,6 +495,10 @@ func ServeEditContactTags(w http.ResponseWriter, r *http.Request, params httprou locale := getLocale(r) company := getCompany(r) slug := params[0].Value + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } form := newTagsForm(companyURI(company, "/contacts/"+slug+"/tags"), slug, locale) if notFoundErrorOrPanic(conn.QueryRow(r.Context(), `select tags from contact where slug = $1`, form.Slug).Scan(form.Tags)) { http.NotFound(w, r) @@ -499,6 +512,10 @@ func HandleUpdateContactTags(w http.ResponseWriter, r *http.Request, params http conn := getConn(r) company := getCompany(r) slug := params[0].Value + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } form := newTagsForm(companyURI(company, "/contacts/"+slug+"/tags/edit"), slug, locale) if err := form.Parse(r); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) diff --git a/pkg/expenses.go b/pkg/expenses.go index b428ab7..d882164 100644 --- a/pkg/expenses.go +++ b/pkg/expenses.go @@ -155,6 +155,10 @@ func ServeExpenseForm(w http.ResponseWriter, r *http.Request, params httprouter. mustRenderNewExpenseForm(w, r, form) return } + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } if !form.MustFillFromDatabase(r.Context(), conn, slug) { http.NotFound(w, r) return @@ -364,14 +368,19 @@ func HandleUpdateExpense(w http.ResponseWriter, r *http.Request, params httprout http.Error(w, err.Error(), http.StatusForbidden) return } + slug := params[0].Value + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } if r.FormValue("quick") == "status" { - slug := conn.MustGetText(r.Context(), "", "update expense set expense_status = $1 where slug = $2 returning slug", form.ExpenseStatus, params[0].Value) + slug = conn.MustGetText(r.Context(), "", "update expense set expense_status = $1 where slug = $2 returning slug", form.ExpenseStatus, slug) if slug == "" { http.NotFound(w, r) + return } htmxRedirect(w, r, companyURI(mustGetCompany(r), "/expenses")) } else { - slug := params[0].Value if !form.Validate() { if !IsHTMxRequest(r) { w.WriteHeader(http.StatusUnprocessableEntity) @@ -520,6 +529,10 @@ func ServeEditExpenseTags(w http.ResponseWriter, r *http.Request, params httprou locale := getLocale(r) company := getCompany(r) slug := params[0].Value + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } form := newTagsForm(companyURI(company, "/expenses/"+slug+"/tags"), slug, locale) if notFoundErrorOrPanic(conn.QueryRow(r.Context(), `select tags from expense where slug = $1`, form.Slug).Scan(form.Tags)) { http.NotFound(w, r) @@ -533,6 +546,10 @@ func HandleUpdateExpenseTags(w http.ResponseWriter, r *http.Request, params http conn := getConn(r) company := getCompany(r) slug := params[0].Value + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } form := newTagsForm(companyURI(company, "/expenses/"+slug+"/tags/edit"), slug, locale) if err := form.Parse(r); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) @@ -544,12 +561,17 @@ func HandleUpdateExpenseTags(w http.ResponseWriter, r *http.Request, params http } if conn.MustGetText(r.Context(), "", "update expense set tags = $1 where slug = $2 returning slug", form.Tags, form.Slug) == "" { http.NotFound(w, r) + return } mustRenderStandaloneTemplate(w, r, "tags/view.gohtml", form) } func ServeExpenseAttachment(w http.ResponseWriter, r *http.Request, params httprouter.Params) { slug := params[0].Value + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } conn := getConn(r) var contentType string var content []byte @@ -571,6 +593,10 @@ func ServeExpenseAttachment(w http.ResponseWriter, r *http.Request, params httpr func HandleEditExpenseAction(w http.ResponseWriter, r *http.Request, params httprouter.Params) { slug := params[0].Value + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } actionUri := fmt.Sprintf("/invoices/%s/edit", slug) handleExpenseAction(w, r, actionUri, func(w http.ResponseWriter, r *http.Request, form *expenseForm) { mustRenderEditExpenseForm(w, r, slug, form) diff --git a/pkg/invoices.go b/pkg/invoices.go index 78e17a2..59689de 100644 --- a/pkg/invoices.go +++ b/pkg/invoices.go @@ -257,10 +257,10 @@ func ServeInvoice(w http.ResponseWriter, r *http.Request, params httprouter.Para case "new": locale := getLocale(r) form := newInvoiceForm(r.Context(), conn, locale, company) - if invoiceToDuplicate := r.URL.Query().Get("duplicate"); invoiceToDuplicate != "" { + if invoiceToDuplicate := r.URL.Query().Get("duplicate"); ValidUuid(invoiceToDuplicate) { form.MustFillFromDatabase(r.Context(), conn, invoiceToDuplicate) form.InvoiceStatus.Selected = []string{"created"} - } else if quoteToInvoice := r.URL.Query().Get("quote"); quoteToInvoice != "" { + } else if quoteToInvoice := r.URL.Query().Get("quote"); ValidUuid(quoteToInvoice) { form.MustFillFromQuote(r.Context(), conn, quoteToInvoice) } form.Date.Val = time.Now().Format("2006-01-02") @@ -289,6 +289,10 @@ func ServeInvoice(w http.ResponseWriter, r *http.Request, params httprouter.Para pdf = true slug = slug[:len(slug)-len(".pdf")] } + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } inv := mustGetInvoice(r.Context(), conn, company, slug) if inv == nil { http.NotFound(w, r) @@ -1151,14 +1155,19 @@ func HandleUpdateInvoice(w http.ResponseWriter, r *http.Request, params httprout http.Error(w, err.Error(), http.StatusForbidden) return } + slug := params[0].Value + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } if r.FormValue("quick") == "status" { - slug := conn.MustGetText(r.Context(), "", "update invoice set invoice_status = $1 where slug = $2 returning slug", form.InvoiceStatus, params[0].Value) + slug = conn.MustGetText(r.Context(), "", "update invoice set invoice_status = $1 where slug = $2 returning slug", form.InvoiceStatus, slug) if slug == "" { http.NotFound(w, r) + return } htmxRedirect(w, r, companyURI(mustGetCompany(r), "/invoices")) } else { - slug := params[0].Value if !form.Validate() { if !IsHTMxRequest(r) { w.WriteHeader(http.StatusUnprocessableEntity) @@ -1194,6 +1203,10 @@ func ServeEditInvoice(w http.ResponseWriter, r *http.Request, params httprouter. conn := getConn(r) company := mustGetCompany(r) slug := params[0].Value + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } locale := getLocale(r) form := newInvoiceForm(r.Context(), conn, locale, company) if !form.MustFillFromDatabase(r.Context(), conn, slug) { @@ -1225,6 +1238,10 @@ func mustRenderEditInvoiceForm(w http.ResponseWriter, r *http.Request, slug stri func HandleEditInvoiceAction(w http.ResponseWriter, r *http.Request, params httprouter.Params) { slug := params[0].Value + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } actionUri := fmt.Sprintf("/invoices/%s/edit", slug) handleInvoiceAction(w, r, actionUri, func(w http.ResponseWriter, r *http.Request, form *invoiceForm) { conn := getConn(r) @@ -1319,6 +1336,10 @@ func ServeEditInvoiceTags(w http.ResponseWriter, r *http.Request, params httprou locale := getLocale(r) company := getCompany(r) slug := params[0].Value + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } form := newTagsForm(companyURI(company, "/invoices/"+slug+"/tags"), slug, locale) if notFoundErrorOrPanic(conn.QueryRow(r.Context(), `select tags from invoice where slug = $1`, form.Slug).Scan(form.Tags)) { http.NotFound(w, r) @@ -1332,6 +1353,10 @@ func HandleUpdateInvoiceTags(w http.ResponseWriter, r *http.Request, params http conn := getConn(r) company := getCompany(r) slug := params[0].Value + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } form := newTagsForm(companyURI(company, "/invoices/"+slug+"/tags/edit"), slug, locale) if err := form.Parse(r); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) @@ -1349,6 +1374,10 @@ func HandleUpdateInvoiceTags(w http.ResponseWriter, r *http.Request, params http func ServeInvoiceAttachment(w http.ResponseWriter, r *http.Request, params httprouter.Params) { slug := params[0].Value + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } conn := getConn(r) var contentType string var content []byte diff --git a/pkg/products.go b/pkg/products.go index 9caafa2..5913456 100644 --- a/pkg/products.go +++ b/pkg/products.go @@ -50,6 +50,10 @@ func GetProductForm(w http.ResponseWriter, r *http.Request, params httprouter.Pa mustRenderNewProductForm(w, r, form) return } + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } if !form.MustFillFromDatabase(r.Context(), conn, slug) { http.NotFound(w, r) return @@ -136,6 +140,10 @@ func HandleUpdateProduct(w http.ResponseWriter, r *http.Request, params httprout return } slug := params[0].Value + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } if !form.Validate() { if !IsHTMxRequest(r) { w.WriteHeader(http.StatusUnprocessableEntity) @@ -363,6 +371,10 @@ func ServeEditProductTags(w http.ResponseWriter, r *http.Request, params httprou locale := getLocale(r) company := getCompany(r) slug := params[0].Value + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } form := newTagsForm(companyURI(company, "/products/"+slug+"/tags"), slug, locale) if notFoundErrorOrPanic(conn.QueryRow(r.Context(), `select tags from product where slug = $1`, form.Slug).Scan(form.Tags)) { http.NotFound(w, r) @@ -376,6 +388,10 @@ func HandleUpdateProductTags(w http.ResponseWriter, r *http.Request, params http conn := getConn(r) company := getCompany(r) slug := params[0].Value + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } form := newTagsForm(companyURI(company, "/products/"+slug+"/tags/edit"), slug, locale) if err := form.Parse(r); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) diff --git a/pkg/quote.go b/pkg/quote.go index 20775eb..dc0190d 100644 --- a/pkg/quote.go +++ b/pkg/quote.go @@ -258,7 +258,7 @@ func ServeQuote(w http.ResponseWriter, r *http.Request, params httprouter.Params case "new": locale := getLocale(r) form := newQuoteForm(r.Context(), conn, locale, company) - if quoteToDuplicate := r.URL.Query().Get("duplicate"); quoteToDuplicate != "" { + if quoteToDuplicate := r.URL.Query().Get("duplicate"); ValidUuid(quoteToDuplicate) { form.MustFillFromDatabase(r.Context(), conn, quoteToDuplicate) form.QuoteStatus.Selected = []string{"created"} } @@ -288,6 +288,10 @@ func ServeQuote(w http.ResponseWriter, r *http.Request, params httprouter.Params pdf = true slug = slug[:len(slug)-len(".pdf")] } + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } quo := mustGetQuote(r.Context(), conn, company, slug) if quo == nil { http.NotFound(w, r) @@ -1036,14 +1040,19 @@ func HandleUpdateQuote(w http.ResponseWriter, r *http.Request, params httprouter http.Error(w, err.Error(), http.StatusForbidden) return } + slug := params[0].Value + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } if r.FormValue("quick") == "status" { - slug := conn.MustGetText(r.Context(), "", "update quote set quote_status = $1 where slug = $2 returning slug", form.QuoteStatus, params[0].Value) + slug = conn.MustGetText(r.Context(), "", "update quote set quote_status = $1 where slug = $2 returning slug", form.QuoteStatus, slug) if slug == "" { http.NotFound(w, r) + return } htmxRedirect(w, r, companyURI(mustGetCompany(r), "/quotes")) } else { - slug := params[0].Value if !form.Validate() { if !IsHTMxRequest(r) { w.WriteHeader(http.StatusUnprocessableEntity) @@ -1064,6 +1073,10 @@ func ServeEditQuote(w http.ResponseWriter, r *http.Request, params httprouter.Pa conn := getConn(r) company := mustGetCompany(r) slug := params[0].Value + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } locale := getLocale(r) form := newQuoteForm(r.Context(), conn, locale, company) if !form.MustFillFromDatabase(r.Context(), conn, slug) { @@ -1095,6 +1108,10 @@ func mustRenderEditQuoteForm(w http.ResponseWriter, r *http.Request, slug string func HandleEditQuoteAction(w http.ResponseWriter, r *http.Request, params httprouter.Params) { slug := params[0].Value + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } actionUri := fmt.Sprintf("/quotes/%s/edit", slug) handleQuoteAction(w, r, actionUri, func(w http.ResponseWriter, r *http.Request, form *quoteForm) { conn := getConn(r) @@ -1164,6 +1181,10 @@ func ServeEditQuoteTags(w http.ResponseWriter, r *http.Request, params httproute locale := getLocale(r) company := getCompany(r) slug := params[0].Value + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } form := newTagsForm(companyURI(company, "/quotes/"+slug+"/tags"), slug, locale) if notFoundErrorOrPanic(conn.QueryRow(r.Context(), `select tags from quote where slug = $1`, form.Slug).Scan(form.Tags)) { http.NotFound(w, r) @@ -1177,6 +1198,10 @@ func HandleUpdateQuoteTags(w http.ResponseWriter, r *http.Request, params httpro conn := getConn(r) company := getCompany(r) slug := params[0].Value + if !ValidUuid(slug) { + http.NotFound(w, r) + return + } form := newTagsForm(companyURI(company, "/quotes/"+slug+"/tags/edit"), slug, locale) if err := form.Parse(r); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) diff --git a/pkg/uuid.go b/pkg/uuid.go new file mode 100644 index 0000000..db99f37 --- /dev/null +++ b/pkg/uuid.go @@ -0,0 +1,46 @@ +package pkg + +func ValidUuid(s string) bool { + if len(s) != 36 { + return false + } + // it must be of the form xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' { + return false + } + for _, x := range [16]int{ + 0, 2, 4, 6, + 9, 11, + 14, 16, + 19, 21, + 24, 26, 28, 30, 32, 34} { + if !validHex(s[x], s[x+1]) { + return false + } + } + return true +} + +// xvalues returns the value of a byte as a hexadecimal digit or 255. +var xvalues = [256]byte{ + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 255, 255, 255, 255, 255, 255, + 255, 10, 11, 12, 13, 14, 15, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 10, 11, 12, 13, 14, 15, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, +} + +func validHex(x1, x2 byte) bool { + return xvalues[x1] != 255 && xvalues[x2] != 255 +} diff --git a/pkg/uuid_test.go b/pkg/uuid_test.go new file mode 100644 index 0000000..c2390ed --- /dev/null +++ b/pkg/uuid_test.go @@ -0,0 +1,33 @@ +package pkg + +import ( + "strings" + "testing" +) + +type test struct { + in string + isUuid bool +} + +var tests = []test{ + {"f47ac10b-58cc-0372-8567-0e02b2c3d479", true}, + {"2bc1be74-169d-4300-a239-49a1196a045d", true}, + {"12bc1be74-169d-4300-a239-49a1196a045d", false}, + {"2bc1be74-169d-4300-a239-49a1196a045", false}, + {"2bc1be74-1x9d-4300-a239-49a1196a045d", false}, + {"2bc1be74-169d-4300-a239-49a1196ag45d", false}, +} + +func testValidUuid(t *testing.T, in string, isUuid bool) { + if ok := ValidUuid(in); ok != isUuid { + t.Errorf("ValidUuid(%s) got %v expected %v", in, ok, isUuid) + } +} + +func TestUUID(t *testing.T) { + for _, tt := range tests { + testValidUuid(t, tt.in, tt.isUuid) + testValidUuid(t, strings.ToUpper(tt.in), tt.isUuid) + } +}