diff --git a/deploy/user_profile.sql b/deploy/user_profile.sql index 10fb691..b7788a1 100644 --- a/deploy/user_profile.sql +++ b/deploy/user_profile.sql @@ -16,6 +16,7 @@ select user_id , name , role , lang_tag + , left(cookie, 10) as csrf_token from auth."user" where email = current_user_email() and cookie = current_user_cookie() @@ -27,6 +28,7 @@ select 0 , '' , 'guest'::name , 'und' + , '' where not exists ( select 1 from auth."user" diff --git a/pkg/company.go b/pkg/company.go index c2e19d8..52af1d6 100644 --- a/pkg/company.go +++ b/pkg/company.go @@ -142,6 +142,10 @@ func CompanyTaxDetailsHandler() http.Handler { http.Error(w, err.Error(), http.StatusBadRequest) return } + if err := verifyCsrfTokenValid(r); err != nil { + http.Error(w, err.Error(), http.StatusForbidden) + return + } if ok := page.DetailsForm.Validate(r.Context(), conn); ok { form := page.DetailsForm conn.MustExec(r.Context(), "update company set business_name = $1, vatin = ($11 || $2)::vatin, trade_name = $3, phone = parse_packed_phone_number($4, $11), email = $5, web = $6, address = $7, city = $8, province = $9, postal_code = $10, country_code = $11, currency_code = $12 where company_id = $13", form.BusinessName, form.VATIN, form.TradeName, form.Phone, form.Email, form.Web, form.Address, form.City, form.Province, form.PostalCode, form.Country, form.Currency, company.Id) @@ -243,6 +247,10 @@ func CompanyTaxHandler() http.Handler { conn := getConn(r) company := mustGetCompany(r) if taxId, err := strconv.Atoi(param); err == nil { + if err := verifyCsrfTokenValid(r); err != nil { + http.Error(w, err.Error(), http.StatusForbidden) + return + } conn.MustExec(r.Context(), "delete from tax where tax_id = $1", taxId) http.Redirect(w, r, "/company/"+company.Slug+"/tax-details", http.StatusSeeOther) } else { @@ -252,6 +260,10 @@ func CompanyTaxHandler() http.Handler { http.Error(w, err.Error(), http.StatusBadRequest) return } + if err := verifyCsrfTokenValid(r); err != nil { + http.Error(w, err.Error(), http.StatusForbidden) + return + } if form.Validate() { conn.MustExec(r.Context(), "insert into tax (company_id, name, rate) values ($1, $2, $3 / 100::decimal)", company.Id, form.Name, form.Rate.Integer()) http.Redirect(w, r, "/company/"+company.Slug+"/tax-details", http.StatusSeeOther) diff --git a/pkg/contacts.go b/pkg/contacts.go index 861e0eb..c38d75a 100644 --- a/pkg/contacts.go +++ b/pkg/contacts.go @@ -27,6 +27,10 @@ func ContactsHandler() http.Handler { http.Error(w, err.Error(), http.StatusBadRequest) return } + if err := verifyCsrfTokenValid(r); err != nil { + http.Error(w, err.Error(), http.StatusForbidden) + return + } if form.Validate(r.Context(), conn) { conn.MustExec(r.Context(), "insert into contact (company_id, business_name, vatin, trade_name, phone, email, web, address, province, city, postal_code, country_code) values ($1, $2, ($12 || $3)::vatin, $4, parse_packed_phone_number($5, $12), $6, $7, $8, $9, $10, $11, $12)", company.Id, form.BusinessName, form.VATIN, form.TradeName, form.Phone, form.Email, form.Web, form.Address, form.City, form.Province, form.PostalCode, form.Country) http.Redirect(w, r, "/company/"+company.Slug+"/contacts", http.StatusSeeOther) diff --git a/pkg/login.go b/pkg/login.go index 4960957..bd54b21 100644 --- a/pkg/login.go +++ b/pkg/login.go @@ -17,6 +17,7 @@ const ( ContextConnKey = "numerus-database" sessionCookie = "numerus-session" defaultRole = "guest" + csrfTokenField = "csfrToken" ) type loginForm struct { @@ -105,6 +106,10 @@ func LoginHandler() http.Handler { func LogoutHandler() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := verifyCsrfTokenValid(r); err != nil { + http.Error(w, err.Error(), http.StatusForbidden) + return + } conn := getConn(r) conn.MustExec(r.Context(), "select logout()") http.SetCookie(w, createSessionCookie("", -24*time.Hour)) @@ -136,10 +141,11 @@ func createSessionCookie(value string, duration time.Duration) *http.Cookie { } type AppUser struct { - Email string - LoggedIn bool - Role string - Language language.Tag + Email string + LoggedIn bool + Role string + Language language.Tag + CsrfToken string } func CheckLogin(db *Db, next http.Handler) http.Handler { @@ -158,9 +164,9 @@ func CheckLogin(db *Db, next http.Handler) http.Handler { LoggedIn: false, Role: defaultRole, } - row := conn.QueryRow(ctx, "select coalesce(email, ''), role, lang_tag from user_profile") + row := conn.QueryRow(ctx, "select coalesce(email, ''), role, lang_tag, csrf_token from user_profile") var langTag string - if err := row.Scan(&user.Email, &user.Role, &langTag); err != nil { + if err := row.Scan(&user.Email, &user.Role, &langTag, &user.CsrfToken); err != nil { panic(err) } user.LoggedIn = user.Email != "" @@ -171,6 +177,16 @@ func CheckLogin(db *Db, next http.Handler) http.Handler { }) } +func verifyCsrfTokenValid(r *http.Request) error { + user := getUser(r) + token := r.FormValue(csrfTokenField) + if user.CsrfToken == token { + return nil + } + locale := getLocale(r) + return errors.New(locale.Get("Cross-site request forgery detected.")) +} + func getUser(r *http.Request) *AppUser { return r.Context().Value(ContextUserKey).(*AppUser) } diff --git a/pkg/profile.go b/pkg/profile.go index 3d88301..0bd472b 100644 --- a/pkg/profile.go +++ b/pkg/profile.go @@ -104,6 +104,10 @@ func ProfileHandler() http.Handler { http.Error(w, err.Error(), http.StatusBadRequest) return } + if err := verifyCsrfTokenValid(r); err != nil { + http.Error(w, err.Error(), http.StatusForbidden) + return + } if ok := form.Validate(); ok { //goland:noinspection SqlWithoutWhere cookie := conn.MustGetText(r.Context(), "", "update user_profile set name = $1, email = $2, lang_tag = $3 returning build_cookie()", form.Name, form.Email, form.Language) diff --git a/pkg/template.go b/pkg/template.go index f9a5e83..7698976 100644 --- a/pkg/template.go +++ b/pkg/template.go @@ -1,6 +1,7 @@ package pkg import ( + "fmt" "html/template" "io" "net/http" @@ -13,6 +14,7 @@ func templateFile(name string) string { func mustRenderTemplate(wr io.Writer, r *http.Request, layout string, filename string, data interface{}) { locale := getLocale(r) company := getCompany(r) + user := getUser(r) t := template.New(filename) t.Funcs(template.FuncMap{ "gettext": locale.Get, @@ -26,6 +28,9 @@ func mustRenderTemplate(wr io.Writer, r *http.Request, layout string, filename s } return "/company/" + company.Slug + uri }, + "csrfToken": func() template.HTML { + return template.HTML(fmt.Sprintf(``, csrfTokenField, user.CsrfToken)) + }, "addInputAttr": func(attr string, field *InputField) *InputField { field.Attributes = append(field.Attributes, template.HTMLAttr(attr)) return field diff --git a/test/user_profile.sql b/test/user_profile.sql index b40f597..ad02071 100644 --- a/test/user_profile.sql +++ b/test/user_profile.sql @@ -5,7 +5,7 @@ reset client_min_messages; begin; -select plan(47); +select plan(53); set search_path to numerus, auth, public; @@ -50,6 +50,13 @@ select column_privs_are('user_profile', 'lang_tag', 'invoicer', array['SELECT', select column_privs_are('user_profile', 'lang_tag', 'admin', array['SELECT', 'UPDATE']); select column_privs_are('user_profile', 'lang_tag', 'authenticator', array[]::text[]); +select has_column('user_profile', 'csrf_token'); +select col_type_is('user_profile', 'csrf_token', 'text'); +select column_privs_are('user_profile', 'csrf_token', 'guest', array ['SELECT']); +select column_privs_are('user_profile', 'csrf_token', 'invoicer', array['SELECT']); +select column_privs_are('user_profile', 'csrf_token', 'admin', array['SELECT']); +select column_privs_are('user_profile', 'csrf_token', 'authenticator', array[]::text[]); + set client_min_messages to warning; truncate auth."user" cascade; @@ -62,14 +69,14 @@ values (1, 'demo@tandem.blog', 'Demo', 'test', 'invoicer', '44facbb30d8a419dfd4b ; prepare profile as -select user_id, email, name, role, lang_tag +select user_id, email, name, role, lang_tag, csrf_token from user_profile; select set_config('request.user.cookie', '', false); select results_eq( 'profile', - $$ values (0, null::email, '', 'guest'::name, 'und') $$, + $$ values (0, null::email, '', 'guest'::name, 'und', '') $$, 'Should be set up with the guest user when no user logged in yet.' ); @@ -77,7 +84,7 @@ select set_cookie( '44facbb30d8a419dfd4bfbc44a4b5539d4970148dfc84bed0e/demo@tand select results_eq( 'profile', - $$ values (1, 'demo@tandem.blog'::email, 'Demo', 'invoicer'::name, 'ca') $$, + $$ values (1, 'demo@tandem.blog'::email, 'Demo', 'invoicer'::name, 'ca', '44facbb30d') $$, 'Should only see the profile of the first user' ); @@ -104,7 +111,7 @@ select throws_ok( select results_eq( 'profile', - $$ values (1, 'demo+update@tandem.blog'::email, 'Demo Update', 'invoicer'::name, 'es') $$, + $$ values (1, 'demo+update@tandem.blog'::email, 'Demo Update', 'invoicer'::name, 'es', '44facbb30d') $$, 'Should see the changed profile of the first user' ); @@ -114,7 +121,7 @@ select set_cookie( '12af4c88b528c2ad4222e3740496ecbc58e76e26f087657524/admin@tan select results_eq( 'profile', - $$ values (5, 'admin@tandem.blog'::email, 'Admin', 'admin'::name, 'es') $$, + $$ values (5, 'admin@tandem.blog'::email, 'Admin', 'admin'::name, 'es', '12af4c88b5') $$, 'Should only see the profile of the second user' ); @@ -141,7 +148,7 @@ select throws_ok( select results_eq( 'profile', - $$ values (5, 'admin+update@tandem.blog'::email, 'Admin Update', 'admin'::name, 'ca') $$, + $$ values (5, 'admin+update@tandem.blog'::email, 'Admin Update', 'admin'::name, 'ca', '12af4c88b5') $$, 'Should see the changed profile of the first user' ); diff --git a/verify/user_profile.sql b/verify/user_profile.sql index 6425b20..08216c7 100644 --- a/verify/user_profile.sql +++ b/verify/user_profile.sql @@ -8,6 +8,7 @@ select , name , role , lang_tag +, csrf_token from numerus.user_profile where false; diff --git a/web/template/app.gohtml b/web/template/app.gohtml index 313e584..b9ccf99 100644 --- a/web/template/app.gohtml +++ b/web/template/app.gohtml @@ -28,6 +28,7 @@
  • + {{ csrfToken }}