/*
 * SPDX-FileCopyrightText: 2023 jordi fita mas <jfita@peritasoft.com>
 * SPDX-License-Identifier: AGPL-3.0-only
 */

package media

import (
	"net/http"
	"os"
	"path"
	"path/filepath"
	"strings"

	"dev.tandem.ws/tandem/camper/pkg/auth"
	"dev.tandem.ws/tandem/camper/pkg/database"
	"dev.tandem.ws/tandem/camper/pkg/hex"
	httplib "dev.tandem.ws/tandem/camper/pkg/http"
)

// PublicHandler serves media files over HTTP from a directory on disk.
//
// mediaDir is the directory where the media files are stored, and
// fileHandler is the http.Handler that serves the files found under it.
type PublicHandler struct {
	mediaDir    string
	fileHandler http.Handler
}

// NewPublicHandler returns a PublicHandler that keeps and serves its files
// under mediaDir.  The directory, along with any missing parent, is created
// with mode 0755; if that fails, the error from os.MkdirAll is returned and
// the handler is nil.
func NewPublicHandler(mediaDir string) (*PublicHandler, error) {
	err := os.MkdirAll(mediaDir, 0755)
	if err != nil {
		return nil, err
	}
	return &PublicHandler{
		mediaDir:    mediaDir,
		fileHandler: http.FileServer(http.Dir(mediaDir)),
	}, nil
}

// Handler returns the http.Handler that serves the media file whose content
// hash is the first segment of the request's path; that segment is shifted
// out of r.URL.Path before the request is processed any further.
//
// It replies with 404 Not Found when the segment is not a valid hash, and
// with 405 Method Not Allowed for any method other than GET.  The hash is
// lowercased before looking the media up.
//
// user and company are not used by the returned handler.
func (h *PublicHandler) Handler(user *auth.User, company *auth.Company, conn *database.Conn) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hash, rest := httplib.ShiftPath(r.URL.Path)
		r.URL.Path = rest

		// The hash is checked before the method, thus a request with an
		// invalid hash gets a 404 regardless of its method.
		if !hashValid(hash) {
			http.NotFound(w, r)
			return
		}
		if r.Method != http.MethodGet {
			httplib.MethodNotAllowed(w, r, http.MethodGet)
			return
		}
		h.serveMedia(w, r, conn, strings.ToLower(hash))
	})
}

// serveMedia writes to w the media file identified by hash.
//
// Files are cached on disk under h.mediaDir.  When the file for hash can not
// be stat'ed, its content is read from the media_content relation and stored
// on disk before handing the request over to the file server.  It replies
// with 404 Not Found if the database has no content with that hash; any other
// failure panics.
//
// hash must have already been validated with hashValid, because it is used to
// build the file's path.  r.URL.Path is overwritten with the file's path
// relative to h.mediaDir, which is what the file server expects.
func (h *PublicHandler) serveMedia(w http.ResponseWriter, r *http.Request, conn *database.Conn, hash string) {
	mediaPath := h.mediaPath(hash)
	if _, statErr := os.Stat(mediaPath); statErr != nil {
		content, err := conn.GetBytes(r.Context(), "select bytes from media_content where content_hash = decode($1, 'hex')", hash)
		if err != nil {
			if database.ErrorIsNotFound(err) {
				http.NotFound(w, r)
				return
			}
			panic(err)
		}
		// mediaPath is a file system path, hence filepath rather than path.
		if err := os.MkdirAll(filepath.Dir(mediaPath), 0755); err != nil {
			panic(err)
		}
		if err := writeMediaFile(mediaPath, content); err != nil {
			panic(err)
		}
	}
	rel, err := filepath.Rel(h.mediaDir, mediaPath)
	if err != nil {
		panic(err)
	}
	// URL paths are always slash-separated, whatever the OS separator is.
	r.URL.Path = path.Join("/", filepath.ToSlash(rel))
	h.fileHandler.ServeHTTP(w, r)
}

// writeMediaFile stores content in the file called name, with mode 0644.
//
// The content is first written to a temporary file in name's directory that
// is then renamed to name, so that a concurrent request never finds a
// partially written file at name.  The temporary file is removed on failure.
func writeMediaFile(name string, content []byte) (err error) {
	tmp, err := os.CreateTemp(filepath.Dir(name), filepath.Base(name)+".tmp-*")
	if err != nil {
		return err
	}
	defer func() {
		if err != nil {
			// Best effort: there is nothing else to do if the removal fails.
			_ = os.Remove(tmp.Name())
		}
	}()
	if _, err = tmp.Write(content); err != nil {
		_ = tmp.Close()
		return err
	}
	// os.CreateTemp creates the file with mode 0600.
	if err = tmp.Chmod(0644); err != nil {
		_ = tmp.Close()
		return err
	}
	if err = tmp.Close(); err != nil {
		return err
	}
	return os.Rename(tmp.Name(), name)
}

// mediaPath returns the location on disk of the media file with the given
// hash: a file named after everything but the first two characters of hash,
// inside a subdirectory of h.mediaDir named after those two characters.
//
// hash must be at least two bytes long, or the slicing panics.
func (h *PublicHandler) mediaPath(hash string) string {
	subdir, name := hash[:2], hash[2:]
	return filepath.Join(h.mediaDir, subdir, name)
}

// hashValid reports whether s has the shape of a media content hash: exactly
// 64 bytes of which every consecutive pair is accepted by hex.Valid.
func hashValid(s string) bool {
	// NOTE(review): 64 is presumably the length of a hex-encoded SHA-256
	// digest; confirm against the definition of media_content.content_hash.
	const hashLen = 64
	if len(s) != hashLen {
		return false
	}
	for pos := 0; pos < hashLen; pos += 2 {
		first, second := s[pos], s[pos+1]
		if !hex.Valid(first, second) {
			return false
		}
	}
	return true
}