numerus/pkg/company.go

866 lines
24 KiB
Go
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package pkg
import (
"context"
"errors"
"github.com/julienschmidt/httprouter"
"html/template"
"math"
"net/http"
"net/url"
"strconv"
)
// ContextCompanyKey is the request-context key under which CompanyHandler
// stores the *Company resolved from the URL slug, and from which getCompany
// reads it back.
//
// NOTE(review): go vet / staticcheck (SA1029) warn against plain string
// context keys; switching to an unexported key type would need all users of
// this exported name to be checked first.
const ContextCompanyKey = "numerus-company"
// Company is the company selected through the URL slug of the current
// request, as loaded by CompanyHandler.
type Company struct {
	// Id is the company's company_id.
	Id int
	// CurrencySymbol is currency.currency_symbol of the company's currency.
	CurrencySymbol string
	// DecimalDigits is currency.decimal_digits of the company's currency.
	DecimalDigits int
	// Slug is the URL path segment that selected this company.
	Slug string
}
// CompanyHandler wraps next so that it only runs for an existing company.
//
// The first route parameter is the company slug and the second one is the
// rest of the path. When the slug lookup fails for any reason, the client
// gets 404 Not Found. Otherwise the *Company is stored in the request context
// under ContextCompanyKey, and next is served a copy of the request whose URL
// path is the rest of the path.
func CompanyHandler(next http.Handler) httprouter.Handle {
	return func(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
		company := &Company{Slug: params[0].Value}
		row := getConn(r).QueryRow(r.Context(), "select company_id, currency_symbol, decimal_digits from company join currency using (currency_code) where slug = $1", company.Slug)
		if err := row.Scan(&company.Id, &company.CurrencySymbol, &company.DecimalDigits); err != nil {
			http.NotFound(w, r)
			return
		}

		// WithContext already returns a shallow copy of r. The URL is copied
		// separately so that rewriting its path leaves the original untouched.
		inner := r.WithContext(context.WithValue(r.Context(), ContextCompanyKey, company))
		rewritten := &url.URL{}
		*rewritten = *r.URL
		rewritten.Path = params[1].Value
		inner.URL = rewritten
		next.ServeHTTP(w, inner)
	}
}
// MinCents returns the smallest amount of the company's currency, that is
// 10^-DecimalDigits, computed by dividing 1 by ten DecimalDigits times.
// When DecimalDigits is zero or negative, the result is 1.
func (c Company) MinCents() float64 {
	step := 1.0
	for n := c.DecimalDigits; n > 0; n-- {
		step /= 10
	}
	return step
}
// getCompany returns the *Company that CompanyHandler stored in the request
// context, or nil when there is none. It panics if the context value under
// ContextCompanyKey is not a *Company.
func getCompany(r *http.Request) *Company {
	if value := r.Context().Value(ContextCompanyKey); value != nil {
		return value.(*Company)
	}
	return nil
}
// CurrencyOption holds a currency code and its symbol.
// NOTE(review): not referenced in this file; presumably used elsewhere to
// build currency choices — confirm.
type CurrencyOption struct {
	Code   string
	Symbol string
}

// CountryOption holds a country code and its name.
// NOTE(review): not referenced in this file; presumably used elsewhere to
// build country choices — confirm.
type CountryOption struct {
	Code string
	Name string
}

// Tax is a row of the company's tax list, as read by mustCollectTaxes.
type Tax struct {
	// Id is tax.tax_id.
	Id int
	// Name is tax.name.
	Name string
	// Class is the name of the tax's tax_class.
	Class string
	// Rate is tax.rate multiplied by 100 and cast to integer, i.e. a percentage.
	Rate int
}

// PaymentMethod is a row of the company's payment methods, as read by
// mustCollectPaymentMethods.
type PaymentMethod struct {
	Id           int
	Name         string
	Instructions string
}
// taxDetailsForm is the form that edits a company's tax details: its names,
// VAT number, contact information, postal address, country, and currency.
type taxDetailsForm struct {
	// locale translates the labels and the validation messages.
	locale       *Locale
	TradeName    *InputField
	BusinessName *InputField
	// VATIN holds the VAT number without the two-character country prefix
	// that is stored in the database (see mustFillFromDatabase).
	VATIN      *InputField
	Phone      *InputField
	Email      *InputField
	Web        *InputField
	Address    *InputField
	City       *InputField
	Province   *InputField
	PostalCode *InputField
	Country    *SelectField
	Currency   *SelectField
}
// newTaxDetailsForm returns an empty tax details form whose labels are
// translated to locale. The country and currency choices are queried through
// conn; country "ES" and currency "EUR" are preselected.
func newTaxDetailsForm(ctx context.Context, conn *Conn, locale *Locale) *taxDetailsForm {
	// Fields follow the order of the taxDetailsForm declaration.
	return &taxDetailsForm{
		locale: locale,
		TradeName: &InputField{
			Name:  "trade_name",
			Label: pgettext("input", "Trade name", locale),
			Type:  "text",
		},
		BusinessName: &InputField{
			Name:     "business_name",
			Label:    pgettext("input", "Business name", locale),
			Type:     "text",
			Required: true,
			Attributes: []template.HTMLAttr{
				`autocomplete="organization"`,
				`minlength="2"`,
			},
		},
		VATIN: &InputField{
			Name:     "vatin",
			Label:    pgettext("input", "VAT number", locale),
			Type:     "text",
			Required: true,
		},
		Phone: &InputField{
			Name:     "phone",
			Label:    pgettext("input", "Phone", locale),
			Type:     "tel",
			Required: true,
			Attributes: []template.HTMLAttr{
				`autocomplete="tel"`,
			},
		},
		Email: &InputField{
			Name:     "email",
			Label:    pgettext("input", "Email", locale),
			Type:     "email",
			Required: true,
			Attributes: []template.HTMLAttr{
				`autocomplete="email"`,
			},
		},
		Web: &InputField{
			Name:  "web",
			Label: pgettext("input", "Web", locale),
			Type:  "url",
			Attributes: []template.HTMLAttr{
				`autocomplete="url"`,
			},
		},
		Address: &InputField{
			Name:     "address",
			Label:    pgettext("input", "Address", locale),
			Type:     "text",
			Required: true,
			Attributes: []template.HTMLAttr{
				`autocomplete="address-line1"`,
			},
		},
		City: &InputField{
			Name:     "city",
			Label:    pgettext("input", "City", locale),
			Type:     "text",
			Required: true,
		},
		Province: &InputField{
			Name:     "province",
			Label:    pgettext("input", "Province", locale),
			Type:     "text",
			Required: true,
		},
		PostalCode: &InputField{
			Name:     "postal_code",
			Label:    pgettext("input", "Postal code", locale),
			Type:     "text",
			Required: true,
			Attributes: []template.HTMLAttr{
				`autocomplete="postal-code"`,
			},
		},
		// The two database-backed selects; Country is queried before Currency.
		Country: &SelectField{
			Name:     "country",
			Label:    pgettext("input", "Country", locale),
			Options:  mustGetCountryOptions(ctx, conn, locale),
			Required: true,
			Selected: []string{"ES"},
			Attributes: []template.HTMLAttr{
				`autocomplete="country"`,
			},
		},
		Currency: &SelectField{
			Name:     "currency",
			Label:    pgettext("input", "Currency", locale),
			Options:  MustGetOptions(ctx, conn, "select currency_code, currency_symbol from currency order by currency_code"),
			Required: true,
			Selected: []string{"EUR"},
		},
	}
}
// Parse fills every field of the tax details form from the submitted request
// form. If r.ParseForm fails, its error is returned and no field is filled.
func (form *taxDetailsForm) Parse(r *http.Request) error {
	err := r.ParseForm()
	if err != nil {
		return err
	}

	// Text inputs, in declaration order.
	form.TradeName.FillValue(r)
	form.BusinessName.FillValue(r)
	form.VATIN.FillValue(r)
	form.Phone.FillValue(r)
	form.Email.FillValue(r)
	form.Web.FillValue(r)
	form.Address.FillValue(r)
	form.City.FillValue(r)
	form.Province.FillValue(r)
	form.PostalCode.FillValue(r)

	// Select inputs.
	form.Country.FillValue(r)
	form.Currency.FillValue(r)
	return nil
}
// Validate runs every check of the tax details form through a fresh form
// validator, passing each check its localized message, and reports whether
// all of them passed.
//
// The country is checked first because the VAT number, phone, and postal
// code checks depend on it. When the selected country is not valid, those
// checks receive an empty country code.
func (form *taxDetailsForm) Validate(ctx context.Context, conn *Conn) bool {
	validator := newFormValidator()
	country := ""
	if validator.CheckValidSelectOption(form.Country, gettext("Selected country is not valid.", form.locale)) {
		country = form.Country.Selected[0]
	}
	// Check the minimum length only for a non-empty name, like the other
	// required fields below, so an empty name reports a single error.
	if validator.CheckRequiredInput(form.BusinessName, gettext("Business name can not be empty.", form.locale)) {
		validator.CheckInputMinLength(form.BusinessName, 2, gettext("Business name must have at least two letters.", form.locale))
	}
	if validator.CheckRequiredInput(form.VATIN, gettext("VAT number can not be empty.", form.locale)) {
		validator.CheckValidVATINInput(ctx, conn, form.VATIN, country, gettext("This value is not a valid VAT number.", form.locale))
	}
	if validator.CheckRequiredInput(form.Phone, gettext("Phone can not be empty.", form.locale)) {
		validator.CheckValidPhoneInput(ctx, conn, form.Phone, country, gettext("This value is not a valid phone number.", form.locale))
	}
	if validator.CheckRequiredInput(form.Email, gettext("Email can not be empty.", form.locale)) {
		validator.CheckValidEmailInput(form.Email, gettext("This value is not a valid email. It should be like name@domain.com.", form.locale))
	}
	// The web address is optional; it is checked only when given.
	if form.Web.Val != "" {
		validator.CheckValidURL(form.Web, gettext("This value is not a valid web address. It should be like https://domain.com/.", form.locale))
	}
	validator.CheckRequiredInput(form.Address, gettext("Address can not be empty.", form.locale))
	validator.CheckRequiredInput(form.City, gettext("City can not be empty.", form.locale))
	validator.CheckRequiredInput(form.Province, gettext("Province can not be empty.", form.locale))
	if validator.CheckRequiredInput(form.PostalCode, gettext("Postal code can not be empty.", form.locale)) {
		validator.CheckValidPostalCode(ctx, conn, form.PostalCode, country, gettext("This value is not a valid postal code.", form.locale))
	}
	validator.CheckValidSelectOption(form.Currency, gettext("Selected currency is not valid.", form.locale))
	return validator.AllOK()
}
// mustFillFromDatabase loads the stored tax details of company into the
// form's fields and returns the same form. It panics if the query or the
// scan fails.
//
// The VAT number is read without its first two characters because
// HandleCompanyTaxDetailsForm stores it prefixed with the country code.
func (form *taxDetailsForm) mustFillFromDatabase(ctx context.Context, conn *Conn, company *Company) *taxDetailsForm {
	row := conn.QueryRow(ctx, `
select business_name
, substr(vatin::text, 3)
, trade_name
, phone
, email
, web
, address
, city
, province
, postal_code
, country_code
, currency_code
from company
where company.company_id = $1`, company.Id)
	// The form fields themselves are the scan destinations.
	if err := row.Scan(
		form.BusinessName,
		form.VATIN,
		form.TradeName,
		form.Phone,
		form.Email,
		form.Web,
		form.Address,
		form.City,
		form.Province,
		form.PostalCode,
		form.Country,
		form.Currency,
	); err != nil {
		panic(err)
	}
	return form
}
// TaxDetailsPage is the template data of the company's tax details page.
type TaxDetailsPage struct {
	DetailsForm *taxDetailsForm
}

// MustRender renders the tax details page with the main template.
func (page *TaxDetailsPage) MustRender(w http.ResponseWriter, r *http.Request) {
	const templateName = "company/tax-details.gohtml"
	mustRenderMainTemplate(w, r, templateName, page)
}

// GetCompanyTaxDetailsForm shows the tax details form of the current company,
// prefilled with its stored values.
func GetCompanyTaxDetailsForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
	form := newTaxDetailsFormFromDatabase(r)
	mustRenderTaxDetailsForm(w, r, form)
}

// newTaxDetailsFormFromDatabase builds a tax details form for the request's
// locale and fills it with the current company's stored values.
func newTaxDetailsFormFromDatabase(r *http.Request) *taxDetailsForm {
	locale := getLocale(r)
	conn := getConn(r)
	ctx := r.Context()
	form := newTaxDetailsForm(ctx, conn, locale)
	// mustFillFromDatabase hands back the same form it fills.
	return form.mustFillFromDatabase(ctx, conn, mustGetCompany(r))
}
// HandleCompanyTaxDetailsForm processes a submission of the tax details form.
//
// It answers 400 Bad Request when the form cannot be parsed and 403 Forbidden
// when the CSRF token is not valid. When validation fails, it answers 422
// Unprocessable Entity and renders the form again. Otherwise it updates the
// current company and sends an htmx redirect to its tax details page.
func HandleCompanyTaxDetailsForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
	locale := getLocale(r)
	conn := getConn(r)
	ctx := r.Context()
	form := newTaxDetailsForm(ctx, conn, locale)
	if err := form.Parse(r); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	if err := verifyCsrfTokenValid(r); err != nil {
		http.Error(w, err.Error(), http.StatusForbidden)
		return
	}
	if !form.Validate(ctx, conn) {
		w.WriteHeader(http.StatusUnprocessableEntity)
		mustRenderTaxDetailsForm(w, r, form)
		return
	}

	company := mustGetCompany(r)
	// The country code ($11) is also prepended to the VAT number and passed
	// to parse_packed_phone_number together with the phone.
	conn.MustExec(ctx, `
update company
set business_name = $1
, vatin = ($11 || $2)::vatin
, trade_name = $3
, phone = parse_packed_phone_number($4, $11)
, email = $5
, web = $6
, address = $7
, city = $8
, province = $9
, postal_code = $10
, country_code = $11
, currency_code = $12
where company_id = $13
`,
		form.BusinessName,
		form.VATIN,
		form.TradeName,
		form.Phone,
		form.Email,
		form.Web,
		form.Address,
		form.City,
		form.Province,
		form.PostalCode,
		form.Country,
		form.Currency,
		company.Id,
	)
	htmxRedirect(w, r, companyURI(company, "/tax-details"))
}
// mustRenderTaxDetailsForm renders the tax details page around form.
func mustRenderTaxDetailsForm(w http.ResponseWriter, r *http.Request, form *taxDetailsForm) {
	page := TaxDetailsPage{DetailsForm: form}
	page.MustRender(w, r)
}

// mustGetCompany returns the request's company like getCompany does, but
// panics when the request has none.
func mustGetCompany(r *http.Request) *Company {
	if company := getCompany(r); company != nil {
		return company
	}
	panic(errors.New("company: required but not found"))
}
// serveCompanyInvoicingForm shows the invoicing settings form of the current
// company, prefilled with its stored values.
func serveCompanyInvoicingForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
	// Arguments are evaluated left to right: connection, company, locale.
	form := newInvoicingFormFromDatabase(r.Context(), getConn(r), mustGetCompany(r), getLocale(r))
	form.MustRender(w, r)
}
// InvoicingForm is the form that edits a company's invoicing settings: the
// invoice and quotation number formats, the next invoice and quotation
// numbers for the current year, and the legal disclaimer.
type InvoicingForm struct {
	// locale translates the labels and the validation messages.
	locale              *Locale
	InvoiceNumberFormat *InputField
	NextInvoiceNumber   *InputField
	QuoteNumberFormat   *InputField
	NextQuoteNumber     *InputField
	LegalDisclaimer     *InputField
}
// newInvoicingForm returns an empty invoicing form whose labels are
// translated to locale.
func newInvoicingForm(locale *Locale) *InvoicingForm {
	form := &InvoicingForm{locale: locale}

	form.InvoiceNumberFormat = &InputField{
		Name:     "invoice_number_format",
		Label:    pgettext("input", "Invoice number format", locale),
		Type:     "text",
		Required: true,
	}
	form.NextInvoiceNumber = &InputField{
		Name:       "next_invoice_number",
		Label:      pgettext("input", "Next invoice number", locale),
		Type:       "number",
		Required:   true,
		Attributes: []template.HTMLAttr{"min=1"},
	}
	form.QuoteNumberFormat = &InputField{
		Name:     "quote_number_format",
		Label:    pgettext("input", "Quotation number format", locale),
		Type:     "text",
		Required: true,
	}
	form.NextQuoteNumber = &InputField{
		Name:       "next_quotation_number",
		Label:      pgettext("input", "Next quotation number", locale),
		Type:       "number",
		Required:   true,
		Attributes: []template.HTMLAttr{"min=1"},
	}
	form.LegalDisclaimer = &InputField{
		Name:  "legal_disclaimer",
		Label: pgettext("input", "Legal disclaimer", locale),
		Type:  "textarea",
	}

	return form
}
// newInvoicingFormFromDatabase builds an invoicing form for locale and fills
// it with company's stored settings.
func newInvoicingFormFromDatabase(ctx context.Context, conn *Conn, company *Company, locale *Locale) *InvoicingForm {
	f := newInvoicingForm(locale)
	f.mustFillFromDatabase(ctx, conn, company)
	return f
}

// mustFillFromDatabase loads company's invoicing settings into the form. It
// panics if the query or the scan fails.
//
// The next invoice and quotation numbers are one more than the counters of
// the current year, or 1 when a counter has no row for that year.
func (form *InvoicingForm) mustFillFromDatabase(ctx context.Context, conn *Conn, company *Company) {
	row := conn.QueryRow(ctx, `
select invoice_number_format
, quote_number_format
, legal_disclaimer
, coalesce(invoice_number_counter.currval, 0) + 1
, coalesce(quote_number_counter.currval, 0) + 1
from company
left join invoice_number_counter
on invoice_number_counter.company_id = company.company_id
and invoice_number_counter.year = date_part('year', current_date)
left join quote_number_counter
on quote_number_counter.company_id = company.company_id
and quote_number_counter.year = date_part('year', current_date)
where company.company_id = $1`, company.Id)
	if err := row.Scan(
		form.InvoiceNumberFormat,
		form.QuoteNumberFormat,
		form.LegalDisclaimer,
		form.NextInvoiceNumber,
		form.NextQuoteNumber,
	); err != nil {
		panic(err)
	}
}

// MustRender renders the invoicing form with the main template.
func (form *InvoicingForm) MustRender(w http.ResponseWriter, r *http.Request) {
	const templateName = "company/invoicing.gohtml"
	mustRenderMainTemplate(w, r, templateName, form)
}
// Parse fills the invoicing form from the submitted request form. If
// r.ParseForm fails, its error is returned and no field is filled.
func (form *InvoicingForm) Parse(r *http.Request) error {
	err := r.ParseForm()
	if err != nil {
		return err
	}
	form.InvoiceNumberFormat.FillValue(r)
	form.NextInvoiceNumber.FillValue(r)
	form.QuoteNumberFormat.FillValue(r)
	form.NextQuoteNumber.FillValue(r)
	form.LegalDisclaimer.FillValue(r)
	return nil
}

// Validate reports whether both number formats are non-empty and both next
// numbers are integers between 1 and math.MaxInt32. The legal disclaimer is
// not checked.
func (form *InvoicingForm) Validate() bool {
	v := newFormValidator()
	l := form.locale
	v.CheckRequiredInput(form.InvoiceNumberFormat, gettext("Invoice number format can not be empty.", l))
	v.CheckValidInteger(form.NextInvoiceNumber, 1, math.MaxInt32, gettext("Next invoice number must be a number greater than zero.", l))
	v.CheckRequiredInput(form.QuoteNumberFormat, gettext("Quotation number format can not be empty.", l))
	v.CheckValidInteger(form.NextQuoteNumber, 1, math.MaxInt32, gettext("Next quotation number must be a number greater than zero.", l))
	return v.AllOK()
}
// handleCompanyInvoicingForm processes a submission of the invoicing form.
//
// It answers 400 Bad Request when the form cannot be parsed and 403 Forbidden
// when the CSRF token is not valid. When validation fails, it answers 422
// Unprocessable Entity and renders the form again. Otherwise, in a single
// transaction, it updates the company's formats and disclaimer and stores
// next-1 as this year's invoice and quotation counters (insert or overwrite).
// It then sends an htmx redirect to the invoicing page.
func handleCompanyInvoicingForm(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
	locale := getLocale(r)
	conn := getConn(r)
	form := newInvoicingForm(locale)
	if err := form.Parse(r); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	if err := verifyCsrfTokenValid(r); err != nil {
		http.Error(w, err.Error(), http.StatusForbidden)
		return
	}
	if !form.Validate() {
		w.WriteHeader(http.StatusUnprocessableEntity)
		form.MustRender(w, r)
		return
	}

	company := mustGetCompany(r)
	ctx := r.Context()
	tx := conn.MustBegin(ctx)
	defer tx.MustRollback(ctx)

	tx.MustExec(ctx, `
update company
set invoice_number_format = $1
, quote_number_format = $2
, legal_disclaimer = $3
where company_id = $4
`,
		form.InvoiceNumberFormat,
		form.QuoteNumberFormat,
		form.LegalDisclaimer,
		company.Id,
	)
	// The counters store the last number used; the form shows the next one.
	tx.MustExec(ctx, `
insert into invoice_number_counter (company_id, year, currval)
values ($1, date_part('year', current_date), $2)
on conflict (company_id, year) do update
set currval = excluded.currval
`,
		company.Id,
		form.NextInvoiceNumber.Integer()-1,
	)
	tx.MustExec(ctx, `
insert into quote_number_counter (company_id, year, currval)
values ($1, date_part('year', current_date), $2)
on conflict (company_id, year) do update
set currval = excluded.currval
`,
		company.Id,
		form.NextQuoteNumber.Integer()-1,
	)
	tx.MustCommit(ctx)

	htmxRedirect(w, r, companyURI(company, "/invoicing"))
}
// serveCompanyTaxes renders the current company's taxes page: its list of
// taxes plus an empty form for adding one.
func serveCompanyTaxes(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
	page := newTaxesPage(r.Context(), getConn(r), mustGetCompany(r), getLocale(r))
	page.MustRender(w, r)
}

// TaxesPage is the template data of the taxes page.
type TaxesPage struct {
	Taxes []*Tax
	Form  *taxForm
}

// newTaxesPage builds the taxes page with a fresh, empty tax form.
func newTaxesPage(ctx context.Context, conn *Conn, company *Company, locale *Locale) *TaxesPage {
	return newTaxesPageWithForm(ctx, conn, company, newTaxForm(ctx, conn, company, locale))
}

// newTaxesPageWithForm builds the taxes page around form, which may hold a
// rejected submission, loading the company's current taxes.
func newTaxesPageWithForm(ctx context.Context, conn *Conn, company *Company, form *taxForm) *TaxesPage {
	page := &TaxesPage{Form: form}
	page.Taxes = mustCollectTaxes(ctx, conn, company)
	return page
}

// MustRender renders the taxes page with the main template.
func (page *TaxesPage) MustRender(w http.ResponseWriter, r *http.Request) {
	const templateName = "company/taxes.gohtml"
	mustRenderMainTemplate(w, r, templateName, page)
}
// mustCollectTaxes returns company's taxes ordered by rate and then by name,
// each with its class name and its rate multiplied by 100 as an integer.
// The result is nil when the company has no taxes. It panics on any query,
// scan, or iteration error.
func mustCollectTaxes(ctx context.Context, conn *Conn, company *Company) []*Tax {
	rows, err := conn.Query(ctx, "select tax_id, tax.name, tax_class.name, (rate * 100)::integer from tax join tax_class using (tax_class_id) where tax.company_id = $1 order by rate, tax.name", company.Id)
	if err != nil {
		panic(err)
	}
	defer rows.Close()

	var taxes []*Tax
	for rows.Next() {
		// Declared inside the loop so each row gets its own Tax.
		var tax Tax
		if err := rows.Scan(&tax.Id, &tax.Name, &tax.Class, &tax.Rate); err != nil {
			panic(err)
		}
		taxes = append(taxes, &tax)
	}
	if err := rows.Err(); err != nil {
		panic(err)
	}
	return taxes
}
// taxForm is the form that adds a tax to the company.
type taxForm struct {
	// locale translates the labels and the validation messages.
	locale *Locale
	Name   *InputField
	Class  *SelectField
	// Rate is entered as an integer percentage (-99 to 99).
	Rate *InputField
}
// newTaxForm returns an empty tax form whose labels are translated to locale.
// The tax class choices are company's tax classes, queried through conn and
// ordered by name.
func newTaxForm(ctx context.Context, conn *Conn, company *Company, locale *Locale) *taxForm {
	form := &taxForm{locale: locale}
	form.Name = &InputField{
		Name:     "tax_name",
		Label:    pgettext("input", "Tax name", locale),
		Type:     "text",
		Required: true,
	}
	form.Class = &SelectField{
		Name:       "tax_class",
		Label:      pgettext("input", "Tax Class", locale),
		Options:    MustGetOptions(ctx, conn, "select tax_class_id::text, name from tax_class where company_id = $1 order by name", company.Id),
		Required:   true,
		EmptyLabel: gettext("Select a tax class", locale),
	}
	form.Rate = &InputField{
		Name:     "tax_rate",
		Label:    pgettext("input", "Rate (%)", locale),
		Type:     "number",
		Required: true,
		Attributes: []template.HTMLAttr{
			"min=-99",
			"max=99",
		},
	}
	return form
}
// Parse fills the tax form from the submitted request form. If r.ParseForm
// fails, its error is returned and no field is filled.
func (form *taxForm) Parse(r *http.Request) error {
	err := r.ParseForm()
	if err != nil {
		return err
	}
	form.Name.FillValue(r)
	form.Class.FillValue(r)
	form.Rate.FillValue(r)
	return nil
}

// Validate reports whether the name is non-empty, the class is a valid
// choice, and the rate is an integer between -99 and 99.
func (form *taxForm) Validate() bool {
	v := newFormValidator()
	l := form.locale
	v.CheckRequiredInput(form.Name, gettext("Tax name can not be empty.", l))
	v.CheckValidSelectOption(form.Class, gettext("Selected tax class is not valid.", l))
	// The range check only makes sense once a rate was given.
	if v.CheckRequiredInput(form.Rate, gettext("Tax rate can not be empty.", l)) {
		v.CheckValidInteger(form.Rate, -99, 99, gettext("Tax rate must be an integer between -99 and 99.", l))
	}
	return v.AllOK()
}
// HandleAddCompanyTax processes a submission of the new-tax form.
//
// It answers 400 Bad Request when the form cannot be parsed and 403 Forbidden
// when the CSRF token is not valid. When validation fails, it answers 422
// Unprocessable Entity and renders the taxes page with the rejected form.
// Otherwise it inserts the tax for the current company and sends an htmx
// redirect to the taxes page.
func HandleAddCompanyTax(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
	locale := getLocale(r)
	conn := getConn(r)
	company := mustGetCompany(r)
	ctx := r.Context()
	form := newTaxForm(ctx, conn, company, locale)
	if err := form.Parse(r); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	if err := verifyCsrfTokenValid(r); err != nil {
		http.Error(w, err.Error(), http.StatusForbidden)
		return
	}
	if ok := form.Validate(); !ok {
		w.WriteHeader(http.StatusUnprocessableEntity)
		newTaxesPageWithForm(ctx, conn, company, form).MustRender(w, r)
		return
	}
	// The rate arrives as an integer percentage and is stored divided by 100.
	conn.MustExec(ctx, "insert into tax (company_id, tax_class_id, name, rate) values ($1, $2, $3, $4 / 100::decimal)", company.Id, form.Class, form.Name, form.Rate.Integer())
	htmxRedirect(w, r, companyURI(company, "/taxes"))
}
// HandleDeleteCompanyTax deletes the current company's tax whose id is the
// first route parameter, then sends an htmx redirect to the taxes page.
//
// A non-numeric id yields 404 Not Found, and an invalid CSRF token yields 403
// Forbidden. An id that does not belong to the current company deletes
// nothing.
func HandleDeleteCompanyTax(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
	taxId, err := strconv.Atoi(params[0].Value)
	if err != nil {
		http.NotFound(w, r)
		return
	}
	if err := verifyCsrfTokenValid(r); err != nil {
		http.Error(w, err.Error(), http.StatusForbidden)
		return
	}
	company := mustGetCompany(r)
	conn := getConn(r)
	// Scope the delete to the company in the URL. Otherwise a tax id from a
	// different company could be deleted through this company's route.
	conn.MustExec(r.Context(), "delete from tax where tax_id = $1 and company_id = $2", taxId, company.Id)
	htmxRedirect(w, r, companyURI(company, "/taxes"))
}
// servePaymentMethods renders the current company's payment methods page:
// its list of payment methods plus an empty form for adding one.
func servePaymentMethods(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
	page := newPaymentMethodsPage(r.Context(), getConn(r), mustGetCompany(r), getLocale(r))
	page.MustRender(w, r)
}

// PaymentMethodsPage is the template data of the payment methods page.
type PaymentMethodsPage struct {
	PaymentMethods []*PaymentMethod
	Form           *paymentMethodForm
}

// newPaymentMethodsPage builds the page with a fresh, empty form.
func newPaymentMethodsPage(ctx context.Context, conn *Conn, company *Company, locale *Locale) *PaymentMethodsPage {
	return newPaymentMethodsPageWithForm(ctx, conn, company, newPaymentMethodForm(locale))
}

// newPaymentMethodsPageWithForm builds the page around form, which may hold
// a rejected submission, loading the company's current payment methods.
func newPaymentMethodsPageWithForm(ctx context.Context, conn *Conn, company *Company, form *paymentMethodForm) *PaymentMethodsPage {
	page := &PaymentMethodsPage{Form: form}
	page.PaymentMethods = mustCollectPaymentMethods(ctx, conn, company)
	return page
}

// MustRender renders the payment methods page with the main template.
func (page *PaymentMethodsPage) MustRender(w http.ResponseWriter, r *http.Request) {
	const templateName = "company/payment_methods.gohtml"
	mustRenderMainTemplate(w, r, templateName, page)
}
// mustCollectPaymentMethods returns company's payment methods ordered by
// name. The result is nil when there are none. It panics on any query, scan,
// or iteration error.
func mustCollectPaymentMethods(ctx context.Context, conn *Conn, company *Company) []*PaymentMethod {
	rows, err := conn.Query(ctx, "select payment_method_id, name, instructions from payment_method where company_id = $1 order by name", company.Id)
	if err != nil {
		panic(err)
	}
	defer rows.Close()

	var methods []*PaymentMethod
	for rows.Next() {
		// Declared inside the loop so each row gets its own PaymentMethod.
		var method PaymentMethod
		if err := rows.Scan(&method.Id, &method.Name, &method.Instructions); err != nil {
			panic(err)
		}
		methods = append(methods, &method)
	}
	if err := rows.Err(); err != nil {
		panic(err)
	}
	return methods
}
// paymentMethodForm is the form that adds a payment method to the company.
type paymentMethodForm struct {
	// locale translates the labels and the validation messages.
	locale       *Locale
	Name         *InputField
	Instructions *InputField
}
// newPaymentMethodForm returns an empty payment method form whose labels are
// translated to locale. Both fields are required.
func newPaymentMethodForm(locale *Locale) *paymentMethodForm {
	form := &paymentMethodForm{locale: locale}
	form.Name = &InputField{
		Name:     "method_name",
		Label:    pgettext("input", "Payment method name", locale),
		Type:     "text",
		Required: true,
	}
	form.Instructions = &InputField{
		Name:     "method_instructions",
		Label:    pgettext("input", "Instructions", locale),
		Type:     "textarea",
		Required: true,
	}
	return form
}
// Parse fills the payment method form from the submitted request form. If
// r.ParseForm fails, its error is returned and no field is filled.
func (form *paymentMethodForm) Parse(r *http.Request) error {
	err := r.ParseForm()
	if err != nil {
		return err
	}
	form.Name.FillValue(r)
	form.Instructions.FillValue(r)
	return nil
}

// Validate reports whether both the name and the instructions are non-empty.
func (form *paymentMethodForm) Validate() bool {
	v := newFormValidator()
	l := form.locale
	v.CheckRequiredInput(form.Name, gettext("Payment method name can not be empty.", l))
	v.CheckRequiredInput(form.Instructions, gettext("Payment instructions can not be empty.", l))
	return v.AllOK()
}
// HandleAddPaymentMethod processes a submission of the new-payment-method form.
//
// It answers 400 Bad Request when the form cannot be parsed and 403 Forbidden
// when the CSRF token is not valid. When validation fails, it answers 422
// Unprocessable Entity and renders the page with the rejected form.
// Otherwise it inserts the payment method for the current company and sends
// an htmx redirect to the payment methods page.
func HandleAddPaymentMethod(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
	locale := getLocale(r)
	conn := getConn(r)
	company := mustGetCompany(r)
	ctx := r.Context()
	form := newPaymentMethodForm(locale)
	if err := form.Parse(r); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	if err := verifyCsrfTokenValid(r); err != nil {
		http.Error(w, err.Error(), http.StatusForbidden)
		return
	}
	if ok := form.Validate(); !ok {
		w.WriteHeader(http.StatusUnprocessableEntity)
		newPaymentMethodsPageWithForm(ctx, conn, company, form).MustRender(w, r)
		return
	}
	conn.MustExec(ctx, "insert into payment_method (company_id, name, instructions) values ($1, $2, $3)", company.Id, form.Name, form.Instructions)
	htmxRedirect(w, r, companyURI(company, "/payment-methods"))
}
// HandleDeletePaymentMethod deletes the current company's payment method
// whose id is the first route parameter, then sends an htmx redirect to the
// payment methods page.
//
// A non-numeric id yields 404 Not Found, and an invalid CSRF token yields 403
// Forbidden. An id that does not belong to the current company deletes
// nothing.
func HandleDeletePaymentMethod(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
	paymentMethodId, err := strconv.Atoi(params[0].Value)
	if err != nil {
		http.NotFound(w, r)
		return
	}
	if err := verifyCsrfTokenValid(r); err != nil {
		http.Error(w, err.Error(), http.StatusForbidden)
		return
	}
	company := mustGetCompany(r)
	conn := getConn(r)
	// Scope the delete to the company in the URL. Otherwise a payment method
	// of a different company could be deleted through this company's route.
	conn.MustExec(r.Context(), "delete from payment_method where payment_method_id = $1 and company_id = $2", paymentMethodId, company.Id)
	htmxRedirect(w, r, companyURI(company, "/payment-methods"))
}
// GetCompanySwitcher renders the company switcher modal with the companies
// returned by mustCollectUserCompanies.
func GetCompanySwitcher(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
	companies := mustCollectUserCompanies(r.Context(), getConn(r))
	page := &CompanySwitchPage{Companies: companies}
	page.MustRender(w, r)
}

// CompanySwitchPage is the template data of the company switcher modal.
type CompanySwitchPage struct {
	Companies []*UserCompany
}

// UserCompany is an entry of the company switcher: a company's business name
// and slug.
type UserCompany struct {
	Name string
	Slug string
}

// MustRender renders the company switcher with the modal template.
func (page *CompanySwitchPage) MustRender(w http.ResponseWriter, r *http.Request) {
	const templateName = "company/switch.gohtml"
	mustRenderModalTemplate(w, r, templateName, page)
}
// mustCollectUserCompanies returns the business name and slug of every
// company row the query can see, ordered by business name. The result is nil
// when there are none. It panics on any query, scan, or iteration error.
//
// NOTE(review): the query has no user filter; presumably row-level security
// restricts it to the current user's companies — confirm.
func mustCollectUserCompanies(ctx context.Context, conn *Conn) []*UserCompany {
	rows, err := conn.Query(ctx, "select business_name::text, slug::text from company order by business_name")
	if err != nil {
		panic(err)
	}
	defer rows.Close()

	var companies []*UserCompany
	for rows.Next() {
		// Declared inside the loop so each row gets its own UserCompany.
		var entry UserCompany
		if err := rows.Scan(&entry.Name, &entry.Slug); err != nil {
			panic(err)
		}
		companies = append(companies, &entry)
	}
	if err := rows.Err(); err != nil {
		panic(err)
	}
	return companies
}