Add direct middleware mode

This commit is contained in:
2025-11-18 11:35:19 +01:00
parent d6759fb1b8
commit 24d7ecec8f
8 changed files with 119 additions and 52 deletions
+47 -29
View File
@@ -30,11 +30,12 @@ type Config struct {
DefaultImageQuality int
ClientTimeout time.Duration
LogLevel slog.Level
BaseURL string
}
type App struct {
Config *Config
Cache *ristretto.Cache
Cache *ristretto.Cache[string, CacheEntry]
Client *http.Client
Logger *slog.Logger
}
@@ -50,10 +51,11 @@ func main() {
MaxAllowedSize: getEnvInt64("MAX_ALLOWED_SIZE", 1024*1024*50),
DefaultImageQuality: getEnvInt("DEFAULT_IMAGE_QUALITY", 80),
ClientTimeout: getEnvDuration("CLIENT_TIMEOUT", 2*time.Minute),
BaseURL: getEnvString("BASE_URL", ""),
LogLevel: logLevel,
}
cache, err := ristretto.NewCache(&ristretto.Config{
cache, err := ristretto.NewCache(&ristretto.Config[string, CacheEntry]{
NumCounters: 1e7, // Number of keys to track frequency of (10M).
MaxCost: 1 << 30, // Maximum cost of cache (1GB).
BufferItems: 64, // Number of keys per Get buffer.
@@ -111,13 +113,19 @@ func main() {
func (app *App) handleProxy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
logger := ctx.Value(loggerKey).(*slog.Logger)
mediaURL := r.URL.Path[1:]
if strings.HasPrefix(mediaURL, "https:/") && !strings.HasPrefix(mediaURL, "https://") {
mediaURL = "https://" + mediaURL[6:]
}
if strings.HasPrefix(mediaURL, "http:/") && !strings.HasPrefix(mediaURL, "http://") {
mediaURL = "http://" + mediaURL[5:]
var mediaURL string
if app.Config.BaseURL != "" {
baseURL := strings.TrimSuffix(app.Config.BaseURL, "/")
mediaURL = "https://" + baseURL + r.URL.Path
} else {
mediaURL = r.URL.Path[1:]
if strings.HasPrefix(mediaURL, "https:/") && !strings.HasPrefix(mediaURL, "https://") {
mediaURL = "https://" + mediaURL[6:]
}
if strings.HasPrefix(mediaURL, "http:/") && !strings.HasPrefix(mediaURL, "http://") {
mediaURL = "http://" + mediaURL[5:]
}
}
logger = logger.With("media_url", mediaURL)
@@ -134,19 +142,17 @@ func (app *App) handleProxy(w http.ResponseWriter, r *http.Request) {
return
}
if !isAllowedDomain(parsedURL.Host, app.Config.AllowedDomains) {
if app.Config.BaseURL == "" && !isAllowedDomain(parsedURL.Host, app.Config.AllowedDomains) {
logger.Warn("Domain not allowed", "domain", parsedURL.Host)
http.Error(w, "Domain not allowed", http.StatusForbidden)
return
}
if value, found := app.Cache.Get(mediaURL); found {
if cachedEntry, ok := value.(CacheEntry); ok {
logger.Debug("Serving image from cache")
w.Header().Set("Content-Type", cachedEntry.ContentType)
w.Write(cachedEntry.Data)
return
}
if cachedEntry, found := app.Cache.Get(mediaURL); found {
logger.Debug("Serving from cache")
w.Header().Set("Content-Type", cachedEntry.ContentType)
w.Write(cachedEntry.Data)
return
}
logger.Debug("Cache miss, performing HEAD request to origin")
@@ -156,12 +162,13 @@ func (app *App) handleProxy(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Could not fetch media metadata", http.StatusInternalServerError)
return
}
defer headResp.Body.Close()
if headResp.StatusCode != http.StatusOK {
logger.Warn("Origin server returned non-200 status for HEAD request", "status", headResp.StatusCode)
http.Error(w, fmt.Sprintf("Media source returned status: %d", headResp.StatusCode), headResp.StatusCode)
return
}
defer headResp.Body.Close()
headerContentType := headResp.Header.Get("Content-Type")
if !isAllowedType(headerContentType, true) {
@@ -174,19 +181,18 @@ func (app *App) handleProxy(w http.ResponseWriter, r *http.Request) {
switch mediaTypeCategory {
case "image":
logger.Debug("Delegating to image handler")
app.handleImage(w, r)
app.handleImage(w, r, mediaURL)
case "video", "audio":
logger.Debug("Delegating to stream handler")
app.handleStream(w, r)
app.handleStream(w, r, mediaURL)
default:
logger.Warn("Media type passed initial checks but is not an image, video, or audio", "category", mediaTypeCategory)
http.Error(w, "Unsupported media type", http.StatusUnsupportedMediaType)
}
}
func (app *App) handleImage(w http.ResponseWriter, r *http.Request) {
func (app *App) handleImage(w http.ResponseWriter, r *http.Request, mediaURL string) {
logger := r.Context().Value(loggerKey).(*slog.Logger)
mediaURL := r.URL.Path[1:]
resp, err := app.Client.Get(mediaURL)
if err != nil {
@@ -196,11 +202,16 @@ func (app *App) handleImage(w http.ResponseWriter, r *http.Request) {
}
defer resp.Body.Close()
r.Body = http.MaxBytesReader(w, r.Body, app.Config.MaxAllowedSize)
mediaData, err := io.ReadAll(resp.Body)
limitedReader := &io.LimitedReader{R: resp.Body, N: app.Config.MaxAllowedSize}
mediaData, err := io.ReadAll(limitedReader)
if err != nil {
logger.Error("Could not read image data", "error", err)
http.Error(w, "Could not read image data", http.StatusRequestEntityTooLarge)
http.Error(w, "Could not read image data", http.StatusInternalServerError)
return
}
if limitedReader.N == 0 {
logger.Error("Image exceeds max allowed size", "limit", app.Config.MaxAllowedSize)
http.Error(w, "Image exceeds max allowed size", http.StatusRequestEntityTooLarge)
return
}
@@ -233,25 +244,32 @@ func (app *App) handleImage(w http.ResponseWriter, r *http.Request) {
entryToCache = CacheEntry{ContentType: "image/webp", Data: optimizedImage}
}
app.Cache.SetWithTTL(mediaURL, entryToCache, 1, app.Config.CacheTTL)
app.Cache.SetWithTTL(mediaURL, entryToCache, int64(len(entryToCache.Data)), app.Config.CacheTTL)
app.Cache.Wait()
w.Header().Set("Content-Type", entryToCache.ContentType)
w.Write(entryToCache.Data)
}
func (app *App) handleStream(w http.ResponseWriter, r *http.Request) {
func (app *App) handleStream(w http.ResponseWriter, r *http.Request, mediaURL string) {
logger := r.Context().Value(loggerKey).(*slog.Logger)
mediaURL := r.URL.Path[1:]
originReq, err := http.NewRequestWithContext(r.Context(), r.Method, mediaURL, r.Body)
originReq, err := http.NewRequestWithContext(r.Context(), r.Method, mediaURL, nil)
if err != nil {
logger.Error("Failed to create origin request", "error", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
originReq.Header = r.Header.Clone()
if rangeHeader := r.Header.Get("Range"); rangeHeader != "" {
originReq.Header.Set("Range", rangeHeader)
}
if acceptHeader := r.Header.Get("Accept"); acceptHeader != "" {
originReq.Header.Set("Accept", acceptHeader)
}
if userAgentHeader := r.Header.Get("User-Agent"); userAgentHeader != "" {
originReq.Header.Set("User-Agent", userAgentHeader)
}
originResp, err := app.Client.Do(originReq)
if err != nil {