camper/pkg/app/media.go

96 lines
2.2 KiB
Go

/*
* SPDX-FileCopyrightText: 2023 jordi fita mas <jfita@peritasoft.com>
* SPDX-License-Identifier: AGPL-3.0-only
*/
package app
import (
"net/http"
"os"
"path"
"path/filepath"
"strings"
"dev.tandem.ws/tandem/camper/pkg/auth"
"dev.tandem.ws/tandem/camper/pkg/database"
"dev.tandem.ws/tandem/camper/pkg/hex"
httplib "dev.tandem.ws/tandem/camper/pkg/http"
)
// mediaHandler serves media files by content hash. It keeps an on-disk copy
// of each file under mediaDir, filled from the database on first request.
type mediaHandler struct {
	// mediaDir is the root directory of the on-disk media files.
	mediaDir    string
	// fileHandler serves files rooted at mediaDir.
	fileHandler http.Handler
}
// newMediaHandler ensures mediaDir exists (creating it and any missing
// parents with mode 0755) and returns a handler whose file server is rooted
// at that directory. It returns the directory-creation error, if any.
func newMediaHandler(mediaDir string) (*mediaHandler, error) {
	err := os.MkdirAll(mediaDir, 0755)
	if err != nil {
		return nil, err
	}
	h := new(mediaHandler)
	h.mediaDir = mediaDir
	h.fileHandler = http.FileServer(http.Dir(mediaDir))
	return h, nil
}
// Handler returns the HTTP handler for a single media request.
//
// The first path segment is taken as the media hash and removed from
// r.URL.Path. If it does not pass mediaHashValid, the response is 404.
// Methods other than GET are rejected through httplib.MethodNotAllowed.
// Otherwise the lower-cased hash is passed to serveMedia. user is not used here.
func (h *mediaHandler) Handler(user *auth.User, company *auth.Company, conn *database.Conn) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		hash, rest := httplib.ShiftPath(r.URL.Path)
		r.URL.Path = rest

		if !mediaHashValid(hash) {
			http.NotFound(w, r)
			return
		}
		if r.Method != http.MethodGet {
			httplib.MethodNotAllowed(w, r, http.MethodGet)
			return
		}
		h.serveMedia(w, r, company, conn, strings.ToLower(hash))
	}
}
// serveMedia writes the media file identified by hash (lower-case hex) to w.
//
// The file is served from the on-disk copy under h.mediaDir. When that copy
// does not exist yet, the content is loaded from the media table (filtered by
// company.ID and hash) and written to disk first. The response is 404 when the
// database has no matching row. Any other failure panics, matching the error
// style of this handler.
//
// NOTE(review): the on-disk copy is keyed by hash only, so once a file is
// cached it is served for any company requesting the same hash without
// re-checking company_id. Confirm this is intended.
func (h *mediaHandler) serveMedia(w http.ResponseWriter, r *http.Request, company *auth.Company, conn *database.Conn, hash string) {
	mediaPath := h.mediaPath(hash)
	if _, statErr := os.Stat(mediaPath); statErr != nil {
		// Only a missing file means "not cached yet"; other stat failures are
		// real errors and must not trigger a database refill.
		if !os.IsNotExist(statErr) {
			panic(statErr)
		}
		content, err := conn.GetBytes(r.Context(), "select content from media where company_id = $1 and hash = decode($2, 'hex')", company.ID, hash)
		if err != nil {
			if database.ErrorIsNotFound(err) {
				http.NotFound(w, r)
				return
			}
			panic(err)
		}
		if err = writeMediaFileAtomically(mediaPath, content); err != nil {
			panic(err)
		}
	}
	rel, err := filepath.Rel(h.mediaDir, mediaPath)
	if err != nil {
		panic(err)
	}
	// http.FileServer resolves r.URL.Path as a slash-separated URL path, so
	// convert the OS-specific relative path before handing the request over.
	r.URL.Path = path.Join("/", filepath.ToSlash(rel))
	h.fileHandler.ServeHTTP(w, r)
}

// writeMediaFileAtomically writes content to name with mode 0644, creating
// parent directories (mode 0755) as needed. It writes to a temporary file in
// the same directory and renames it over name, so a concurrent reader sees
// either no file or the complete file, never a partially written one.
func writeMediaFileAtomically(name string, content []byte) error {
	dir := filepath.Dir(name)
	if err := os.MkdirAll(dir, 0755); err != nil {
		return err
	}
	tmp, err := os.CreateTemp(dir, ".tmp-*")
	if err != nil {
		return err
	}
	tmpName := tmp.Name()
	renamed := false
	defer func() {
		if !renamed {
			_ = os.Remove(tmpName) // best-effort cleanup of an abandoned temp file
		}
	}()
	if _, err = tmp.Write(content); err != nil {
		_ = tmp.Close()
		return err
	}
	// CreateTemp creates the file with mode 0600. Widen it to the 0644 the
	// previous os.WriteFile call requested (Chmod is not subject to the umask).
	if err = tmp.Chmod(0644); err != nil {
		_ = tmp.Close()
		return err
	}
	if err = tmp.Close(); err != nil {
		return err
	}
	if err = os.Rename(tmpName, name); err != nil {
		return err
	}
	renamed = true
	return nil
}
// mediaPath returns the on-disk location for hash. The first two characters
// name a subdirectory of h.mediaDir, and the remainder is the file name.
// The caller must pass a hash at least two bytes long.
func (h *mediaHandler) mediaPath(hash string) string {
	prefix, rest := hash[:2], hash[2:]
	return filepath.Join(h.mediaDir, prefix, rest)
}
// mediaHashValid reports whether s is exactly 64 bytes long and every
// consecutive byte pair is accepted by hex.Valid.
// (Presumably a hex-encoded SHA-256 digest; confirm against the media table.)
func mediaHashValid(s string) bool {
	const hashLen = 64
	if len(s) != hashLen {
		return false
	}
	for hi := 1; hi < hashLen; hi += 2 {
		if !hex.Valid(s[hi-1], s[hi]) {
			return false
		}
	}
	return true
}