I need the actual remote address to add fail2ban rules for it, but I also do not want everyone to be able to fake the X-Forwarded-For HTTP header. That header can contain multiple IP addresses, by the way, so I have to take only the first one, as the others will be the proxies that the request has been (re)forwarded through.
229 lines
5.7 KiB
Go
229 lines
5.7 KiB
Go
package pkg
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"github.com/julienschmidt/httprouter"
|
|
"html/template"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"golang.org/x/text/language"
|
|
)
|
|
|
|
// Keys under which LoginChecker stores per-request values in the request's
// context.
const (
	// ContextUserKey is the context key holding the request's *AppUser.
	ContextUserKey = "numerus-user"
	// ContextCookieKey is the context key holding the session cookie's value.
	ContextCookieKey = "numerus-cookie"
	// ContextConnKey is the context key holding the request's database
	// connection.
	ContextConnKey = "numerus-database"
)

const (
	// sessionCookie is the name of the HTTP cookie that carries the session.
	sessionCookie = "numerus-session"
	// defaultRole is the role given to a user before the profile is read from
	// the database.
	defaultRole = "guest"
	// csrfTokenField is the name of the form field from which
	// verifyCsrfTokenValid reads the submitted token.
	csrfTokenField = "csfrToken"
)
|
|
|
|
type loginForm struct {
|
|
locale *Locale
|
|
Errors []error
|
|
Email *InputField
|
|
Password *InputField
|
|
}
|
|
|
|
func newLoginForm(locale *Locale) *loginForm {
|
|
return &loginForm{
|
|
locale: locale,
|
|
Email: &InputField{
|
|
Name: "email",
|
|
Label: pgettext("input", "Email", locale),
|
|
Type: "email",
|
|
Required: true,
|
|
Attributes: []template.HTMLAttr{
|
|
`autofocus="autofocus"`,
|
|
`autocomplete="username"`,
|
|
`autocapitalize="none"`,
|
|
},
|
|
},
|
|
Password: &InputField{
|
|
Name: "password",
|
|
Label: pgettext("input", "Password", locale),
|
|
Type: "password",
|
|
Required: true,
|
|
Attributes: []template.HTMLAttr{
|
|
`autocomplete="current-password"`,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (form *loginForm) Parse(r *http.Request) error {
|
|
err := r.ParseForm()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
form.Email.FillValue(r)
|
|
form.Password.FillValue(r)
|
|
return nil
|
|
}
|
|
|
|
func (form *loginForm) Validate() bool {
|
|
validator := newFormValidator()
|
|
if validator.CheckRequiredInput(form.Email, gettext("Email can not be empty.", form.locale)) {
|
|
validator.CheckValidEmailInput(form.Email, gettext("This value is not a valid email. It should be like name@domain.com.", form.locale))
|
|
}
|
|
validator.CheckRequiredInput(form.Password, gettext("Password can not be empty.", form.locale))
|
|
return validator.AllOK()
|
|
}
|
|
|
|
func GetLoginForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
|
user := getUser(r)
|
|
if user.LoggedIn {
|
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
|
return
|
|
}
|
|
locale := getLocale(r)
|
|
form := newLoginForm(locale)
|
|
w.WriteHeader(http.StatusOK)
|
|
mustRenderLoginForm(w, r, form)
|
|
}
|
|
|
|
func HandleLoginForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
|
user := getUser(r)
|
|
if user.LoggedIn {
|
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
|
return
|
|
}
|
|
locale := getLocale(r)
|
|
form := newLoginForm(locale)
|
|
if err := form.Parse(r); err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
if form.Validate() {
|
|
conn := getConn(r)
|
|
cookie := conn.MustGetText(r.Context(), "", "select login($1, $2, $3)", form.Email, form.Password, remoteAddr(r))
|
|
if cookie != "" {
|
|
setSessionCookie(w, cookie)
|
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
|
return
|
|
}
|
|
form.Errors = append(form.Errors, errors.New(gettext("Invalid user or password.", locale)))
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
} else {
|
|
w.WriteHeader(http.StatusUnprocessableEntity)
|
|
}
|
|
mustRenderLoginForm(w, r, form)
|
|
}
|
|
|
|
func mustRenderLoginForm(w http.ResponseWriter, r *http.Request, form *loginForm) {
|
|
mustRenderWebTemplate(w, r, "login.gohtml", form)
|
|
}
|
|
|
|
func HandleLogout(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
|
if err := verifyCsrfTokenValid(r); err != nil {
|
|
http.Error(w, err.Error(), http.StatusForbidden)
|
|
return
|
|
}
|
|
conn := getConn(r)
|
|
conn.MustExec(r.Context(), "select logout()")
|
|
http.SetCookie(w, createSessionCookie("", -24*time.Hour))
|
|
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
|
}
|
|
|
|
// remoteAddr returns the address of the client that made the request.
//
// This is the host part of r.RemoteAddr, unless the request comes from the
// local host ("localhost", "127.0.0.1", or "::1"). In that case the request is
// taken to come from a local reverse proxy, and the first entry of the
// X-Forwarded-For header is returned instead. The header is ignored for any
// other peer, so that remote clients can not fake their address with it.
//
// The header's entries are separated by commas, with optional whitespace
// around them. When the header is missing or its first entry is empty, the
// peer's address is returned.
//
// NOTE(review): a proxy that appends to a client-supplied X-Forwarded-For,
// rather than overwriting it, leaves the first entry under the client's
// control. Confirm that the local proxy overwrites the header.
func remoteAddr(r *http.Request) string {
	address, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// RemoteAddr has no port; use it as is rather than an empty string.
		address = r.RemoteAddr
	}

	switch address {
	case "localhost", "127.0.0.1", "::1":
		// Local peer: fall through to look at the forwarded address.
	default:
		return address
	}

	first := r.Header.Get("X-Forwarded-For")
	if comma := strings.IndexByte(first, ','); comma >= 0 {
		// The entries after the first are the proxies the request went through.
		first = first[:comma]
	}
	first = strings.TrimSpace(first)
	if first == "" {
		return address
	}
	return first
}
|
|
|
|
func setSessionCookie(w http.ResponseWriter, cookie string) {
|
|
http.SetCookie(w, createSessionCookie(cookie, 8766*24*time.Hour))
|
|
}
|
|
|
|
func createSessionCookie(value string, duration time.Duration) *http.Cookie {
|
|
return &http.Cookie{
|
|
Name: sessionCookie,
|
|
Value: value,
|
|
Path: "/",
|
|
Expires: time.Now().Add(duration),
|
|
HttpOnly: true,
|
|
SameSite: http.SameSiteLaxMode,
|
|
}
|
|
}
|
|
|
|
type AppUser struct {
|
|
Email string
|
|
LoggedIn bool
|
|
Role string
|
|
Language language.Tag
|
|
CsrfToken string
|
|
}
|
|
|
|
func LoginChecker(db *Db, next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var ctx = r.Context()
|
|
if cookie, err := r.Cookie(sessionCookie); err == nil {
|
|
ctx = context.WithValue(ctx, ContextCookieKey, cookie.Value)
|
|
}
|
|
|
|
conn := db.MustAcquire(ctx)
|
|
defer conn.Release()
|
|
ctx = context.WithValue(ctx, ContextConnKey, conn)
|
|
|
|
user := &AppUser{
|
|
Email: "",
|
|
LoggedIn: false,
|
|
Role: defaultRole,
|
|
}
|
|
row := conn.QueryRow(ctx, "select coalesce(email, ''), role, lang_tag, csrf_token from user_profile")
|
|
var langTag string
|
|
if err := row.Scan(&user.Email, &user.Role, &langTag, &user.CsrfToken); err != nil {
|
|
panic(err)
|
|
}
|
|
user.LoggedIn = user.Email != ""
|
|
user.Language, _ = language.Parse(langTag)
|
|
ctx = context.WithValue(ctx, ContextUserKey, user)
|
|
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
func verifyCsrfTokenValid(r *http.Request) error {
|
|
user := getUser(r)
|
|
token := r.FormValue(csrfTokenField)
|
|
if user.CsrfToken == token {
|
|
return nil
|
|
}
|
|
locale := getLocale(r)
|
|
return errors.New(locale.Get("Cross-site request forgery detected."))
|
|
}
|
|
|
|
func getUser(r *http.Request) *AppUser {
|
|
return r.Context().Value(ContextUserKey).(*AppUser)
|
|
}
|
|
|
|
func getConn(r *http.Request) *Conn {
|
|
return r.Context().Value(ContextConnKey).(*Conn)
|
|
}
|
|
|
|
func Authenticated(next httprouter.Handle) httprouter.Handle {
|
|
return func(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
|
|
user := getUser(r)
|
|
if user.LoggedIn {
|
|
next(w, r, params)
|
|
} else {
|
|
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
|
}
|
|
}
|
|
}
|