camper/pkg/media/public.go

96 lines
2.2 KiB
Go

/*
* SPDX-FileCopyrightText: 2023 jordi fita mas <jfita@peritasoft.com>
* SPDX-License-Identifier: AGPL-3.0-only
*/
package media
import (
"net/http"
"os"
"path"
"path/filepath"
"strings"
"dev.tandem.ws/tandem/camper/pkg/auth"
"dev.tandem.ws/tandem/camper/pkg/database"
"dev.tandem.ws/tandem/camper/pkg/hex"
httplib "dev.tandem.ws/tandem/camper/pkg/http"
)
// PublicHandler serves media content looked up by its content hash. Content
// is read from the media_content table on first request and cached as a file
// under mediaDir, from where it is served.
type PublicHandler struct {
	// mediaDir is the root directory of the on-disk media cache.
	mediaDir string
	// fileHandler serves files rooted at mediaDir (an http.FileServer over
	// http.Dir(mediaDir), as built by NewPublicHandler).
	fileHandler http.Handler
}
// NewPublicHandler creates a PublicHandler whose on-disk cache lives in
// mediaDir. The directory, and any missing parents, are created with mode
// 0755 if they do not exist yet; the error from that step is returned as is.
func NewPublicHandler(mediaDir string) (*PublicHandler, error) {
	err := os.MkdirAll(mediaDir, 0755)
	if err != nil {
		return nil, err
	}
	return &PublicHandler{
		mediaDir:    mediaDir,
		fileHandler: http.FileServer(http.Dir(mediaDir)),
	}, nil
}
// Handler returns the http.Handler for public media requests.
//
// The first segment of the request path, as split off by httplib.ShiftPath,
// is taken as the media hash. The remainder becomes the request path.
// Requests whose hash fails hashValid get 404 Not Found. Methods other than
// GET get httplib.MethodNotAllowed. Valid GET requests are served by
// serveMedia with the hash lowercased. user and company are not consulted.
func (h *PublicHandler) Handler(user *auth.User, company *auth.Company, conn *database.Conn) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hash, rest := httplib.ShiftPath(r.URL.Path)
		r.URL.Path = rest

		if !hashValid(hash) {
			http.NotFound(w, r)
			return
		}
		if r.Method != http.MethodGet {
			httplib.MethodNotAllowed(w, r, http.MethodGet)
			return
		}

		h.serveMedia(w, r, conn, strings.ToLower(hash))
	})
}
// serveMedia writes the media content identified by hash to w.
//
// The content is cached on disk at h.mediaPath(hash). If that file cannot be
// stat'ed, the bytes are loaded from the media_content table, written to the
// cache, and then served. A hash not present in the database yields 404 Not
// Found. Any other failure panics.
//
// hash must be at least two characters long, because mediaPath slices it.
// Handler only calls this after hashValid succeeds and lowercases the hash.
func (h *PublicHandler) serveMedia(w http.ResponseWriter, r *http.Request, conn *database.Conn, hash string) {
	mediaPath := h.mediaPath(hash)
	if _, err := os.Stat(mediaPath); err != nil {
		content, err := conn.GetBytes(r.Context(), "select bytes from media_content where content_hash = decode($1, 'hex')", hash)
		if err != nil {
			if database.ErrorIsNotFound(err) {
				http.NotFound(w, r)
				return
			}
			panic(err)
		}
		// mediaPath is an OS path (built with filepath.Join), so its parent
		// directory must be computed with filepath, not the slash-only path.
		if err := os.MkdirAll(filepath.Dir(mediaPath), 0755); err != nil {
			panic(err)
		}
		// Write atomically. Otherwise a concurrent request could stat and
		// serve a partially written file, and an interrupted write would
		// leave a truncated file that the Stat check above accepts as cached.
		if err := writeFileAtomically(mediaPath, content, 0644); err != nil {
			panic(err)
		}
	}
	rel, err := filepath.Rel(h.mediaDir, mediaPath)
	if err != nil {
		panic(err)
	}
	// http.FileServer treats r.URL.Path as a slash-separated URL path (http.Dir
	// rejects OS separators on Windows), so convert the OS-relative path.
	r.URL.Path = path.Join("/", filepath.ToSlash(rel))
	h.fileHandler.ServeHTTP(w, r)
}

// writeFileAtomically writes data to name with permission bits perm by first
// writing a temporary file in the same directory and then renaming it over
// name. Readers therefore see either no file or the complete one. On failure
// the temporary file is removed and name is left untouched.
func writeFileAtomically(name string, data []byte, perm os.FileMode) (err error) {
	tmp, err := os.CreateTemp(filepath.Dir(name), "."+filepath.Base(name)+".*.tmp")
	if err != nil {
		return err
	}
	defer func() {
		if err != nil {
			// Best-effort cleanup: the caller cares about the original error,
			// so failures to close or remove the temporary file are ignored.
			_ = tmp.Close()
			_ = os.Remove(tmp.Name())
		}
	}()
	if _, err = tmp.Write(data); err != nil {
		return err
	}
	// CreateTemp uses mode 0600. Apply the requested permissions explicitly.
	if err = tmp.Chmod(perm); err != nil {
		return err
	}
	if err = tmp.Close(); err != nil {
		return err
	}
	return os.Rename(tmp.Name(), name)
}
// mediaPath returns the cache file location for hash. The first two
// characters select a subdirectory of h.mediaDir, and the remaining
// characters name the file inside it. hash must be at least two bytes long.
func (h *PublicHandler) mediaPath(hash string) string {
	prefix, rest := hash[:2], hash[2:]
	return filepath.Join(h.mediaDir, prefix, rest)
}
// hashValid reports whether s has the shape of a media content hash. s must
// be exactly 64 characters, and every consecutive pair of bytes must pass
// hex.Valid. NOTE(review): hex.Valid presumably checks that both bytes are
// hexadecimal digits, making s a 32-byte hex-encoded digest — confirm in
// pkg/hex.
func hashValid(s string) bool {
	const size = 64
	valid := len(s) == size
	for i := 0; valid && i+1 < len(s); i += 2 {
		valid = hex.Valid(s[i], s[i+1])
	}
	return valid
}