Add the where company_id filter to accounts and payments queries

I actually did not forget them, and i did not add them on purpose,
mistakenly believing that PostgreSQL’s row-level policies would project
only rows from the current company.  That is actually how Camper works,
but that’s because we use the request’s domain name to select the
company; here we use the path, and the row-level policy would return
rows from all companies the user belongs to.
This commit is contained in:
jordi fita mas 2024-08-15 02:59:46 +02:00
parent f95936c523
commit 7f21a2131e
2 changed files with 15 additions and 11 deletions

View File

@ -19,9 +19,10 @@ const (
func servePaymentAccountIndex(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { func servePaymentAccountIndex(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
conn := getConn(r) conn := getConn(r)
company := mustGetCompany(r)
locale := getLocale(r) locale := getLocale(r)
page := NewPaymentAccountIndexPage(r.Context(), conn, locale) page := NewPaymentAccountIndexPage(r.Context(), conn, company, locale)
page.MustRender(w, r) page.MustRender(w, r)
} }
@ -29,9 +30,9 @@ type PaymentAccountIndexPage struct {
Accounts []*PaymentAccountEntry Accounts []*PaymentAccountEntry
} }
func NewPaymentAccountIndexPage(ctx context.Context, conn *Conn, locale *Locale) *PaymentAccountIndexPage { func NewPaymentAccountIndexPage(ctx context.Context, conn *Conn, company *Company, locale *Locale) *PaymentAccountIndexPage {
return &PaymentAccountIndexPage{ return &PaymentAccountIndexPage{
Accounts: mustCollectPaymentAccountEntries(ctx, conn, locale), Accounts: mustCollectPaymentAccountEntries(ctx, conn, company, locale),
} }
} }
@ -50,7 +51,7 @@ type PaymentAccountEntry struct {
ExpirationDate string ExpirationDate string
} }
func mustCollectPaymentAccountEntries(ctx context.Context, conn *Conn, locale *Locale) []*PaymentAccountEntry { func mustCollectPaymentAccountEntries(ctx context.Context, conn *Conn, company *Company, locale *Locale) []*PaymentAccountEntry {
rows := conn.MustQuery(ctx, ` rows := conn.MustQuery(ctx, `
select payment_account_id select payment_account_id
, slug , slug
@ -65,8 +66,9 @@ func mustCollectPaymentAccountEntries(ctx context.Context, conn *Conn, locale *L
left join payment_account_card using (payment_account_id, payment_account_type) left join payment_account_card using (payment_account_id, payment_account_type)
join payment_account_type using (payment_account_type) join payment_account_type using (payment_account_type)
left join payment_account_type_i18n as i18n on payment_account_type.payment_account_type = i18n.payment_account_type and i18n.lang_tag = $1 left join payment_account_type_i18n as i18n on payment_account_type.payment_account_type = i18n.payment_account_type and i18n.lang_tag = $1
where company_id = $2
order by payment_account_id order by payment_account_id
`, locale.Language.String()) `, locale.Language.String(), company.Id)
defer rows.Close() defer rows.Close()
var entries []*PaymentAccountEntry var entries []*PaymentAccountEntry

View File

@ -12,9 +12,10 @@ import (
func servePaymentIndex(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { func servePaymentIndex(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
conn := getConn(r) conn := getConn(r)
company := mustGetCompany(r)
locale := getLocale(r) locale := getLocale(r)
page := NewPaymentIndexPage(r.Context(), conn, locale) page := NewPaymentIndexPage(r.Context(), conn, company, locale)
page.MustRender(w, r) page.MustRender(w, r)
} }
@ -22,9 +23,9 @@ type PaymentIndexPage struct {
Payments []*PaymentEntry Payments []*PaymentEntry
} }
func NewPaymentIndexPage(ctx context.Context, conn *Conn, locale *Locale) *PaymentIndexPage { func NewPaymentIndexPage(ctx context.Context, conn *Conn, company *Company, locale *Locale) *PaymentIndexPage {
return &PaymentIndexPage{ return &PaymentIndexPage{
Payments: mustCollectPaymentEntries(ctx, conn, locale), Payments: mustCollectPaymentEntries(ctx, conn, company, locale),
} }
} }
@ -44,7 +45,7 @@ type PaymentEntry struct {
StatusLabel string StatusLabel string
} }
func mustCollectPaymentEntries(ctx context.Context, conn *Conn, locale *Locale) []*PaymentEntry { func mustCollectPaymentEntries(ctx context.Context, conn *Conn, company *Company, locale *Locale) []*PaymentEntry {
rows := conn.MustQuery(ctx, ` rows := conn.MustQuery(ctx, `
select payment_id select payment_id
, payment.slug , payment.slug
@ -59,8 +60,9 @@ func mustCollectPaymentEntries(ctx context.Context, conn *Conn, locale *Locale)
join payment_status_i18n psi18n on payment.payment_status = psi18n.payment_status and psi18n.lang_tag = $1 join payment_status_i18n psi18n on payment.payment_status = psi18n.payment_status and psi18n.lang_tag = $1
join currency using (currency_code) join currency using (currency_code)
left join payment_attachment as attachment using (payment_id) left join payment_attachment as attachment using (payment_id)
where company_id = $2
order by payment_date desc, total desc order by payment_date desc, total desc
`, locale.Language) `, locale.Language, company.Id)
defer rows.Close() defer rows.Close()
var entries []*PaymentEntry var entries []*PaymentEntry
@ -132,7 +134,7 @@ func newPaymentForm(ctx context.Context, conn *Conn, locale *Locale, company *Co
Name: "payment_account", Name: "payment_account",
Label: pgettext("input", "Account", locale), Label: pgettext("input", "Account", locale),
Required: true, Required: true,
Options: MustGetOptions(ctx, conn, "select payment_account_id::text, name from payment_account order by name"), Options: MustGetOptions(ctx, conn, "select payment_account_id::text, name from payment_account where company_id = $1 order by name", company.Id),
}, },
Amount: &InputField{ Amount: &InputField{
Name: "amount", Name: "amount",