package pkg import ( "context" "log" "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/pgxpool" ) type Db struct { *pgxpool.Pool } func NewDatabase(ctx context.Context, connString string) (*Db, error) { config, err := pgxpool.ParseConfig(connString) if err != nil { log.Fatal(err) } config.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { if _, err := conn.Exec(ctx, "SET search_path TO numerus, public"); err != nil { return err } return registerPgTypes(ctx, conn) } config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool { cookie := "" if value, ok := ctx.Value(ContextCookieKey).(string); ok { cookie = value } if _, err := conn.Exec(ctx, "select set_cookie($1)", cookie); err != nil { log.Printf("ERROR - Failed to set role: %v", err) return false } return true } config.AfterRelease = func(conn *pgx.Conn) bool { if _, err := conn.Exec(context.Background(), "RESET ROLE"); err != nil { log.Printf("ERROR - Failed to reset role: %v", err) return false } return true } pool, err := pgxpool.ConnectConfig(ctx, config) if err != nil { return nil, err } return &Db{pool}, nil } func notFoundErrorOrPanic(err error) bool { if err == pgx.ErrNoRows { return true } if err != nil { panic(err) } return false } func (db *Db) Acquire(ctx context.Context) (*Conn, error) { conn, err := db.Pool.Acquire(ctx) if err != nil { return nil, err } return &Conn{conn}, nil } func (db *Db) MustAcquire(ctx context.Context) *Conn { conn, err := db.Acquire(ctx) if err != nil { panic(err) } return conn } type Conn struct { *pgxpool.Conn } func (c *Conn) MustBegin(ctx context.Context) *Tx { tx, err := c.Begin(ctx) if err != nil { panic(err) } return &Tx{tx} } func (c *Conn) MustGetText(ctx context.Context, def string, sql string, args ...interface{}) string { var result string if notFoundErrorOrPanic(c.Conn.QueryRow(ctx, sql, args...).Scan(&result)) { return def } return result } func (c *Conn) MustGetBool(ctx context.Context, sql string, args ...interface{}) bool { var result bool if err := c.Conn.QueryRow(ctx, sql, args...).Scan(&result); err != nil { panic(err) } return result } func (c *Conn) MustExec(ctx context.Context, sql string, args ...interface{}) { if _, err := c.Conn.Exec(ctx, sql, args...); err != nil { panic(err) } } func (c *Conn) MustQuery(ctx context.Context, sql string, args ...interface{}) pgx.Rows { rows, err := c.Conn.Query(ctx, sql, args...) if err != nil { panic(err) } return rows } type Tx struct { pgx.Tx } func (tx *Tx) MustCommit(ctx context.Context) { if err := tx.Commit(ctx); err != nil { panic(err) } } func (tx *Tx) MustRollback(ctx context.Context) { if err := tx.Rollback(ctx); err != nil { panic(err) } } func (tx *Tx) MustExec(ctx context.Context, sql string, args ...interface{}) { if _, err := tx.Exec(ctx, sql, args...); err != nil { panic(err) } } func (tx *Tx) MustGetText(ctx context.Context, sql string, args ...interface{}) string { var result string if err := tx.QueryRow(ctx, sql, args...).Scan(&result); err != nil { panic(err) } return result } func (tx *Tx) MustGetInteger(ctx context.Context, sql string, args ...interface{}) int { var result int if err := tx.QueryRow(ctx, sql, args...).Scan(&result); err != nil { panic(err) } return result } func (tx *Tx) MustGetIntegerOrDefault(ctx context.Context, def int, sql string, args ...interface{}) int { var result int if notFoundErrorOrPanic(tx.QueryRow(ctx, sql, args...).Scan(&result)) { return def } return result } func (tx *Tx) MustCopyFrom(ctx context.Context, tableName string, columns []string, length int, next func(int) ([]interface{}, error)) int64 { copied, err := tx.CopyFrom(ctx, pgx.Identifier{tableName}, columns, pgx.CopyFromSlice(length, next)) if err != nil { panic(err) } return copied }