package pkg import ( "context" "errors" "github.com/julienschmidt/httprouter" "html/template" "net" "net/http" "time" "golang.org/x/text/language" ) const ( ContextUserKey = "numerus-user" ContextCookieKey = "numerus-cookie" ContextConnKey = "numerus-database" sessionCookie = "numerus-session" defaultRole = "guest" csrfTokenField = "csfrToken" ) type loginForm struct { locale *Locale Errors []error Email *InputField Password *InputField } func newLoginForm(locale *Locale) *loginForm { return &loginForm{ locale: locale, Email: &InputField{ Name: "email", Label: pgettext("input", "Email", locale), Type: "email", Required: true, Attributes: []template.HTMLAttr{ `autofocus="autofocus"`, `autocomplete="username"`, `autocapitalize="none"`, }, }, Password: &InputField{ Name: "password", Label: pgettext("input", "Password", locale), Type: "password", Required: true, Attributes: []template.HTMLAttr{ `autocomplete="current-password"`, }, }, } } func (form *loginForm) Parse(r *http.Request) error { err := r.ParseForm() if err != nil { return err } form.Email.FillValue(r) form.Password.FillValue(r) return nil } func (form *loginForm) Validate() bool { validator := newFormValidator() if validator.CheckRequiredInput(form.Email, gettext("Email can not be empty.", form.locale)) { validator.CheckValidEmailInput(form.Email, gettext("This value is not a valid email. It should be like name@domain.com.", form.locale)) } validator.CheckRequiredInput(form.Password, gettext("Password can not be empty.", form.locale)) return validator.AllOK() } func GetLoginForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { user := getUser(r) if user.LoggedIn { http.Redirect(w, r, "/", http.StatusSeeOther) return } locale := getLocale(r) form := newLoginForm(locale) w.WriteHeader(http.StatusOK) mustRenderLoginForm(w, r, form) } func HandleLoginForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { user := getUser(r) if user.LoggedIn { http.Redirect(w, r, "/", http.StatusSeeOther) return } locale := getLocale(r) form := newLoginForm(locale) if err := form.Parse(r); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } if form.Validate() { conn := getConn(r) cookie := conn.MustGetText(r.Context(), "", "select login($1, $2, $3)", form.Email, form.Password, remoteAddr(r)) if cookie != "" { setSessionCookie(w, cookie) http.Redirect(w, r, "/", http.StatusSeeOther) return } form.Errors = append(form.Errors, errors.New(gettext("Invalid user or password.", locale))) w.WriteHeader(http.StatusUnauthorized) } else { w.WriteHeader(http.StatusUnprocessableEntity) } mustRenderLoginForm(w, r, form) } func mustRenderLoginForm(w http.ResponseWriter, r *http.Request, form *loginForm) { mustRenderWebTemplate(w, r, "login.gohtml", form) } func HandleLogout(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { if err := verifyCsrfTokenValid(r); err != nil { http.Error(w, err.Error(), http.StatusForbidden) return } conn := getConn(r) conn.MustExec(r.Context(), "select logout()") http.SetCookie(w, createSessionCookie("", -24*time.Hour)) http.Redirect(w, r, "/login", http.StatusSeeOther) } func remoteAddr(r *http.Request) string { address := r.Header.Get("X-Forwarded-For") if address == "" { address, _, _ = net.SplitHostPort(r.RemoteAddr) } return address } func setSessionCookie(w http.ResponseWriter, cookie string) { http.SetCookie(w, createSessionCookie(cookie, 8766*24*time.Hour)) } func createSessionCookie(value string, duration time.Duration) *http.Cookie { return &http.Cookie{ Name: sessionCookie, Value: value, Path: "/", Expires: time.Now().Add(duration), HttpOnly: true, SameSite: http.SameSiteLaxMode, } } type AppUser struct { Email string LoggedIn bool Role string Language language.Tag CsrfToken string } func LoginChecker(db *Db, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var ctx = r.Context() if cookie, err := r.Cookie(sessionCookie); err == nil { ctx = context.WithValue(ctx, ContextCookieKey, cookie.Value) } conn := db.MustAcquire(ctx) defer conn.Release() ctx = context.WithValue(ctx, ContextConnKey, conn) user := &AppUser{ Email: "", LoggedIn: false, Role: defaultRole, } row := conn.QueryRow(ctx, "select coalesce(email, ''), role, lang_tag, csrf_token from user_profile") var langTag string if err := row.Scan(&user.Email, &user.Role, &langTag, &user.CsrfToken); err != nil { panic(err) } user.LoggedIn = user.Email != "" user.Language, _ = language.Parse(langTag) ctx = context.WithValue(ctx, ContextUserKey, user) next.ServeHTTP(w, r.WithContext(ctx)) }) } func verifyCsrfTokenValid(r *http.Request) error { user := getUser(r) token := r.FormValue(csrfTokenField) if user.CsrfToken == token { return nil } locale := getLocale(r) return errors.New(locale.Get("Cross-site request forgery detected.")) } func getUser(r *http.Request) *AppUser { return r.Context().Value(ContextUserKey).(*AppUser) } func getConn(r *http.Request) *Conn { return r.Context().Value(ContextConnKey).(*Conn) } func Authenticated(next httprouter.Handle) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, params httprouter.Params) { user := getUser(r) if user.LoggedIn { next(w, r, params) } else { http.Redirect(w, r, "/login", http.StatusSeeOther) } } }