diff --git a/Dockerfile b/Dockerfile index 4385b8d..bdd4ef2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM --platform=$BUILDPLATFORM golang:1.25.5 AS builder +FROM --platform=$BUILDPLATFORM golang:1.25.6 AS builder WORKDIR /app diff --git a/go.mod b/go.mod index 31ed70f..9580abc 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/skidoodle/safebin -go 1.25.5 +go 1.25.6 diff --git a/internal/app/config.go b/internal/app/config.go index 82d7ff7..a953fa3 100644 --- a/internal/app/config.go +++ b/internal/app/config.go @@ -21,36 +21,54 @@ type App struct { Logger *slog.Logger } -func LoadConfig() Config { - h := getEnv("SAFEBIN_HOST", "0.0.0.0") - p := getEnvInt("SAFEBIN_PORT", 8080) - s := getEnv("SAFEBIN_STORAGE", "./storage") - mDefault := int64(getEnvInt("SAFEBIN_MAX_MB", 512)) +const ( + defaultHost = "0.0.0.0" + defaultPort = 8080 + defaultStorage = "./storage" + defaultMaxMB = 512 +) - var m int64 - flag.StringVar(&h, "h", h, "Bind address") - flag.IntVar(&p, "p", p, "Port") - flag.StringVar(&s, "s", s, "Storage directory") - flag.Int64Var(&m, "m", mDefault, "Max file size in MB") +func LoadConfig() Config { + hostEnv := getEnv("SAFEBIN_HOST", defaultHost) + portEnv := getEnvInt("SAFEBIN_PORT", defaultPort) + storageEnv := getEnv("SAFEBIN_STORAGE", defaultStorage) + maxMBEnv := int64(getEnvInt("SAFEBIN_MAX_MB", defaultMaxMB)) + + var host string + var port int + var storage string + var maxMB int64 + + flag.StringVar(&host, "h", hostEnv, "Bind address") + flag.IntVar(&port, "p", portEnv, "Port") + flag.StringVar(&storage, "s", storageEnv, "Storage directory") + flag.Int64Var(&maxMB, "m", maxMBEnv, "Max file size in MB") flag.Parse() - return Config{Addr: fmt.Sprintf("%s:%d", h, p), StorageDir: s, MaxMB: m} -} - -func getEnv(k, f string) string { - if v, ok := os.LookupEnv(k); ok { - return v + return Config{ + Addr: fmt.Sprintf("%s:%d", host, port), + StorageDir: storage, + MaxMB: maxMB, } - return f } -func getEnvInt(k string, f int) int { - if v, ok := os.LookupEnv(k); ok { - if i, err := strconv.Atoi(v); err == nil { +func getEnv(key, fallback string) string { + if value, ok := os.LookupEnv(key); ok { + return value + } + + return fallback +} + +func getEnvInt(key string, fallback int) int { + if value, ok := os.LookupEnv(key); ok { + i, err := strconv.Atoi(value) + if err == nil { return i } } - return f + + return fallback } func ParseTemplates() *template.Template { diff --git a/internal/app/handlers.go b/internal/app/handlers.go index 9c93b40..7974762 100644 --- a/internal/app/handlers.go +++ b/internal/app/handlers.go @@ -2,6 +2,7 @@ package app import ( "encoding/base64" + "errors" "fmt" "io" "mime" @@ -14,187 +15,268 @@ import ( "github.com/skidoodle/safebin/internal/crypto" ) +const ( + uploadChunkSize = 8 << 20 + maxRequestOverhead = 10 << 20 + permUserRWX = 0o700 + slugLength = 22 + keyLength = 16 + megaByte = 1 << 20 + chunkSafetyMargin = 2 +) + var reUploadID = regexp.MustCompile(`^[a-zA-Z0-9]{10,50}$`) -func (app *App) HandleHome(w http.ResponseWriter, r *http.Request) { - err := app.Tmpl.ExecuteTemplate(w, "base", map[string]any{ +func (app *App) HandleHome(writer http.ResponseWriter, request *http.Request) { + err := app.Tmpl.ExecuteTemplate(writer, "base", map[string]any{ "MaxMB": app.Conf.MaxMB, - "Host": r.Host, + "Host": request.Host, }) + if err != nil { app.Logger.Error("Template error", "err", err) } } -func (app *App) HandleUpload(w http.ResponseWriter, r *http.Request) { - limit := (app.Conf.MaxMB << 20) + (1 << 20) - r.Body = http.MaxBytesReader(w, r.Body, limit) +func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request) { + limit := (app.Conf.MaxMB * megaByte) + megaByte + request.Body = http.MaxBytesReader(writer, request.Body, limit) + + file, header, err := request.FormFile("file") - file, header, err := r.FormFile("file") if err != nil { if err.Error() == "http: request body too large" { - app.SendError(w, r, http.StatusRequestEntityTooLarge) + app.SendError(writer, request, http.StatusRequestEntityTooLarge) return } - app.SendError(w, r, http.StatusBadRequest) + + app.SendError(writer, request, http.StatusBadRequest) + return } - defer file.Close() - tmpPath := filepath.Join(app.Conf.StorageDir, "tmp", fmt.Sprintf("up_%d", os.Getpid())) - tmp, err := os.Create(tmpPath) + defer func() { + if closeErr := file.Close(); closeErr != nil { + app.Logger.Error("Failed to close upload file", "err", closeErr) + } + }() + + tmp, err := os.CreateTemp(filepath.Join(app.Conf.StorageDir, "tmp"), "up_*") + if err != nil { app.Logger.Error("Failed to create temp file", "err", err) - app.SendError(w, r, http.StatusInternalServerError) + app.SendError(writer, request, http.StatusInternalServerError) + return } - defer os.Remove(tmpPath) - defer tmp.Close() + + tmpPath := tmp.Name() + + defer func() { + if removeErr := os.Remove(tmpPath); removeErr != nil && !os.IsNotExist(removeErr) { + app.Logger.Error("Failed to remove temp file", "err", removeErr) + } + }() + + defer func() { + if closeErr := tmp.Close(); closeErr != nil { + app.Logger.Error("Failed to close temp file", "err", closeErr) + } + }() if _, err := io.Copy(tmp, file); err != nil { app.Logger.Error("Failed to write temp file", "err", err) - app.SendError(w, r, http.StatusRequestEntityTooLarge) + app.SendError(writer, request, http.StatusRequestEntityTooLarge) + return } - app.FinalizeFile(w, r, tmp, header.Filename) + app.FinalizeFile(writer, request, tmp, header.Filename) } -func (app *App) HandleChunk(w http.ResponseWriter, r *http.Request) { - r.Body = http.MaxBytesReader(w, r.Body, 10<<20) +func (app *App) HandleChunk(writer http.ResponseWriter, request *http.Request) { + request.Body = http.MaxBytesReader(writer, request.Body, maxRequestOverhead) - uid := r.FormValue("upload_id") - idx, err := strconv.Atoi(r.FormValue("index")) + uid := request.FormValue("upload_id") + + idx, err := strconv.Atoi(request.FormValue("index")) if err != nil { - app.SendError(w, r, http.StatusBadRequest) + app.SendError(writer, request, http.StatusBadRequest) return } - const chunkSize = 8 << 20 - maxChunks := int((app.Conf.MaxMB<<20)/chunkSize) + 2 + maxChunks := int((app.Conf.MaxMB*megaByte)/uploadChunkSize) + chunkSafetyMargin if !reUploadID.MatchString(uid) || idx > maxChunks || idx < 0 { - app.SendError(w, r, http.StatusBadRequest) + app.SendError(writer, request, http.StatusBadRequest) return } - file, _, err := r.FormFile("chunk") + file, _, err := request.FormFile("chunk") + if err != nil { if err.Error() == "http: request body too large" { - app.SendError(w, r, http.StatusRequestEntityTooLarge) + app.SendError(writer, request, http.StatusRequestEntityTooLarge) return } - app.SendError(w, r, http.StatusBadRequest) + + app.SendError(writer, request, http.StatusBadRequest) + return } - defer file.Close() + defer func() { + if closeErr := file.Close(); closeErr != nil { + app.Logger.Error("Failed to close chunk file", "err", closeErr) + } + }() + + if err := app.saveChunk(uid, idx, file); err != nil { + app.Logger.Error("Failed to save chunk", "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + } +} + +func (app *App) saveChunk(uid string, idx int, src io.Reader) error { dir := filepath.Join(app.Conf.StorageDir, "tmp", uid) - if err := os.MkdirAll(dir, 0700); err != nil { - app.Logger.Error("Failed to create chunk dir", "err", err) - app.SendError(w, r, http.StatusInternalServerError) - return + + if err := os.MkdirAll(dir, permUserRWX); err != nil { + return fmt.Errorf("create chunk dir: %w", err) } dest, err := os.Create(filepath.Join(dir, strconv.Itoa(idx))) if err != nil { - app.Logger.Error("Failed to create chunk file", "err", err) - app.SendError(w, r, http.StatusInternalServerError) + return fmt.Errorf("create chunk file: %w", err) + } + + defer func() { + if closeErr := dest.Close(); closeErr != nil { + app.Logger.Error("Failed to close chunk dest", "err", closeErr) + } + }() + + if _, err := io.Copy(dest, src); err != nil { + return fmt.Errorf("copy chunk: %w", err) + } + + return nil +} + +func (app *App) HandleFinish(writer http.ResponseWriter, request *http.Request) { + uid := request.FormValue("upload_id") + + total, err := strconv.Atoi(request.FormValue("total")) + if err != nil { + app.SendError(writer, request, http.StatusBadRequest) return } - defer dest.Close() - if _, err := io.Copy(dest, file); err != nil { - app.Logger.Error("Failed to save chunk", "err", err) - app.SendError(w, r, http.StatusInternalServerError) + maxChunks := int((app.Conf.MaxMB*megaByte)/uploadChunkSize) + chunkSafetyMargin + + if !reUploadID.MatchString(uid) || total > maxChunks || total <= 0 { + app.SendError(writer, request, http.StatusBadRequest) return } + + mergedPath, err := app.mergeChunks(uid, total) + + if err != nil { + app.Logger.Error("Merge failed", "err", err) + + if errors.Is(err, io.ErrShortWrite) { + app.SendError(writer, request, http.StatusRequestEntityTooLarge) + } else { + app.SendError(writer, request, http.StatusInternalServerError) + } + + return + } + + defer func() { + if removeErr := os.Remove(mergedPath); removeErr != nil && !os.IsNotExist(removeErr) { + app.Logger.Error("Failed to remove merged file", "err", removeErr) + } + }() + + mergedRead, err := os.Open(mergedPath) + + if err != nil { + app.Logger.Error("Failed to open merged file", "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + + return + } + + defer func() { + if closeErr := mergedRead.Close(); closeErr != nil { + app.Logger.Error("Failed to close merged reader", "err", closeErr) + } + }() + + app.FinalizeFile(writer, request, mergedRead, request.FormValue("filename")) + + if err := os.RemoveAll(filepath.Join(app.Conf.StorageDir, "tmp", uid)); err != nil { + app.Logger.Error("Failed to remove chunk dir", "err", err) + } } -func (app *App) HandleFinish(w http.ResponseWriter, r *http.Request) { - uid := r.FormValue("upload_id") - total, err := strconv.Atoi(r.FormValue("total")) - if err != nil { - app.SendError(w, r, http.StatusBadRequest) - return - } - - const chunkSize = 8 << 20 - maxChunks := int((app.Conf.MaxMB<<20)/chunkSize) + 2 - - if !reUploadID.MatchString(uid) || total > maxChunks || total <= 0 { - app.SendError(w, r, http.StatusBadRequest) - return - } - +func (app *App) mergeChunks(uid string, total int) (string, error) { tmpPath := filepath.Join(app.Conf.StorageDir, "tmp", "m_"+uid) + merged, err := os.Create(tmpPath) if err != nil { - app.Logger.Error("Failed to create merge file", "err", err) - app.SendError(w, r, http.StatusInternalServerError) - return + return "", fmt.Errorf("create merge file: %w", err) } - defer os.Remove(tmpPath) - defer merged.Close() - limit := app.Conf.MaxMB << 20 + defer func() { + if closeErr := merged.Close(); closeErr != nil { + app.Logger.Error("Failed to close merged file", "err", closeErr) + } + }() + + limit := app.Conf.MaxMB * megaByte var written int64 for i := range total { partPath := filepath.Join(app.Conf.StorageDir, "tmp", uid, strconv.Itoa(i)) + part, err := os.Open(partPath) if err != nil { - app.Logger.Error("Missing chunk during merge", "uid", uid, "index", i, "err", err) - app.SendError(w, r, http.StatusBadRequest) - return + return "", fmt.Errorf("open chunk %d: %w", i, err) } n, err := io.Copy(merged, part) - part.Close() + + if closeErr := part.Close(); closeErr != nil { + app.Logger.Error("Failed to close chunk part", "err", closeErr) + } + if err != nil { - app.Logger.Error("Failed to append chunk", "err", err) - app.SendError(w, r, http.StatusInternalServerError) - return + return "", fmt.Errorf("append chunk %d: %w", i, err) } written += n if written > limit { - app.SendError(w, r, http.StatusRequestEntityTooLarge) - return + return "", io.ErrShortWrite } } - if err := merged.Close(); err != nil { - app.Logger.Error("Failed to close merged file", "err", err) - app.SendError(w, r, http.StatusInternalServerError) - return - } - - mergedRead, err := os.Open(tmpPath) - if err != nil { - app.Logger.Error("Failed to open merged file", "err", err) - app.SendError(w, r, http.StatusInternalServerError) - return - } - defer mergedRead.Close() - - app.FinalizeFile(w, r, mergedRead, r.FormValue("filename")) - os.RemoveAll(filepath.Join(app.Conf.StorageDir, "tmp", uid)) + return tmpPath, nil } -func (app *App) HandleGetFile(w http.ResponseWriter, r *http.Request) { - slug := r.PathValue("slug") - if len(slug) < 22 { - app.SendError(w, r, http.StatusBadRequest) +func (app *App) HandleGetFile(writer http.ResponseWriter, request *http.Request) { + slug := request.PathValue("slug") + if len(slug) < slugLength { + app.SendError(writer, request, http.StatusBadRequest) return } - keyBase64 := slug[:22] - ext := slug[22:] + keyBase64 := slug[:slugLength] + ext := slug[slugLength:] key, err := base64.RawURLEncoding.DecodeString(keyBase64) - if err != nil || len(key) != 16 { - app.SendError(w, r, http.StatusUnauthorized) + if err != nil || len(key) != keyLength { + app.SendError(writer, request, http.StatusUnauthorized) return } @@ -203,105 +285,133 @@ func (app *App) HandleGetFile(w http.ResponseWriter, r *http.Request) { info, err := os.Stat(path) if err != nil { - app.SendError(w, r, http.StatusNotFound) + app.SendError(writer, request, http.StatusNotFound) return } - f, err := os.Open(path) + file, err := os.Open(path) + if err != nil { app.Logger.Error("Failed to open file", "path", path, "err", err) - app.SendError(w, r, http.StatusInternalServerError) + app.SendError(writer, request, http.StatusInternalServerError) + return } - defer f.Close() + + defer func() { + if closeErr := file.Close(); closeErr != nil { + app.Logger.Error("Failed to close file", "err", closeErr) + } + }() streamer, err := crypto.NewGCMStreamer(key) + if err != nil { app.Logger.Error("Failed to create crypto streamer", "err", err) - app.SendError(w, r, http.StatusInternalServerError) + app.SendError(writer, request, http.StatusInternalServerError) + return } - decryptor := crypto.NewDecryptor(f, streamer.AEAD, info.Size()) + decryptor := crypto.NewDecryptor(file, streamer.AEAD, info.Size()) contentType := mime.TypeByExtension(ext) if contentType == "" { contentType = "application/octet-stream" } - w.Header().Set("Content-Type", contentType) - w.Header().Set("Content-Security-Policy", "default-src 'none'; img-src 'self' data:; media-src 'self' data:; style-src 'unsafe-inline'; sandbox allow-forms allow-scripts allow-downloads allow-same-origin") - w.Header().Set("X-Content-Type-Options", "nosniff") - w.Header().Set("Content-Disposition", fmt.Sprintf("inline; filename=%q", slug)) + csp := "default-src 'none'; img-src 'self' data:; media-src 'self' data:; " + + "style-src 'unsafe-inline'; sandbox allow-forms allow-scripts allow-downloads allow-same-origin" - http.ServeContent(w, r, slug, info.ModTime(), decryptor) + writer.Header().Set("Content-Type", contentType) + writer.Header().Set("Content-Security-Policy", csp) + writer.Header().Set("X-Content-Type-Options", "nosniff") + writer.Header().Set("Content-Disposition", fmt.Sprintf("inline; filename=%q", slug)) + + http.ServeContent(writer, request, slug, info.ModTime(), decryptor) } -func (app *App) FinalizeFile(w http.ResponseWriter, r *http.Request, src *os.File, filename string) { +func (app *App) FinalizeFile(writer http.ResponseWriter, request *http.Request, src *os.File, filename string) { if _, err := src.Seek(0, 0); err != nil { app.Logger.Error("Seek failed", "err", err) - app.SendError(w, r, http.StatusInternalServerError) + app.SendError(writer, request, http.StatusInternalServerError) + return } key, err := crypto.DeriveKey(src) + if err != nil { app.Logger.Error("Key derivation failed", "err", err) - app.SendError(w, r, http.StatusInternalServerError) + app.SendError(writer, request, http.StatusInternalServerError) + return } ext := filepath.Ext(filename) id := crypto.GetID(key, ext) - - if _, err := src.Seek(0, 0); err != nil { - app.Logger.Error("Seek failed", "err", err) - app.SendError(w, r, http.StatusInternalServerError) - return - } - finalPath := filepath.Join(app.Conf.StorageDir, id) if _, err := os.Stat(finalPath); err == nil { - app.RespondWithLink(w, r, key, filename) + app.RespondWithLink(writer, request, key, filename) return } - out, err := os.Create(finalPath + ".tmp") - if err != nil { - app.Logger.Error("Failed to create final file", "err", err) - app.SendError(w, r, http.StatusInternalServerError) + if _, err := src.Seek(0, 0); err != nil { + app.Logger.Error("Seek failed", "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + return } + + if err := app.encryptAndSave(src, key, finalPath); err != nil { + app.Logger.Error("Encryption failed", "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + + return + } + + app.RespondWithLink(writer, request, key, filename) +} + +func (app *App) encryptAndSave(src io.Reader, key []byte, finalPath string) error { + out, err := os.Create(finalPath + ".tmp") + if err != nil { + return fmt.Errorf("create final file: %w", err) + } + + var closed bool + defer func() { - out.Close() - os.Remove(finalPath + ".tmp") + if !closed { + if closeErr := out.Close(); closeErr != nil { + app.Logger.Error("Failed to close final file", "err", closeErr) + } + } + + if removeErr := os.Remove(finalPath + ".tmp"); removeErr != nil && !os.IsNotExist(removeErr) { + app.Logger.Error("Failed to remove temp final file", "err", removeErr) + } }() streamer, err := crypto.NewGCMStreamer(key) if err != nil { - app.Logger.Error("Failed to create streamer", "err", err) - app.SendError(w, r, http.StatusInternalServerError) - return + return fmt.Errorf("create streamer: %w", err) } if err := streamer.EncryptStream(out, src); err != nil { - app.Logger.Error("Encryption failed", "err", err) - app.SendError(w, r, http.StatusInternalServerError) - return + return fmt.Errorf("encrypt stream: %w", err) } if err := out.Close(); err != nil { - app.Logger.Error("Failed to close final file", "err", err) - app.SendError(w, r, http.StatusInternalServerError) - return + return fmt.Errorf("close final file: %w", err) } + closed = true + if err := os.Rename(finalPath+".tmp", finalPath); err != nil { - app.Logger.Error("Failed to rename final file", "err", err) - app.SendError(w, r, http.StatusInternalServerError) - return + return fmt.Errorf("rename final file: %w", err) } - app.RespondWithLink(w, r, key, filename) + return nil } diff --git a/internal/app/server.go b/internal/app/server.go index 953d410..28b2d22 100644 --- a/internal/app/server.go +++ b/internal/app/server.go @@ -9,22 +9,25 @@ import ( func (app *App) Routes() *http.ServeMux { mux := http.NewServeMux() - fs := http.FileServer(http.Dir("./web/static")) - mux.Handle("GET /static/", http.StripPrefix("/static/", fs)) + fileServer := http.FileServer(http.Dir("./web/static")) + + mux.Handle("GET /static/", http.StripPrefix("/static/", fileServer)) mux.HandleFunc("GET /{$}", app.HandleHome) mux.HandleFunc("POST /{$}", app.HandleUpload) mux.HandleFunc("POST /upload/chunk", app.HandleChunk) mux.HandleFunc("POST /upload/finish", app.HandleFinish) mux.HandleFunc("GET /{slug}", app.HandleGetFile) + return mux } -func (app *App) RespondWithLink(w http.ResponseWriter, r *http.Request, key []byte, originalName string) { +func (app *App) RespondWithLink(writer http.ResponseWriter, request *http.Request, key []byte, originalName string) { keySlug := base64.RawURLEncoding.EncodeToString(key) ext := filepath.Ext(originalName) - link := fmt.Sprintf("%s/%s%s", r.Host, keySlug, ext) - if r.Header.Get("X-Requested-With") == "XMLHttpRequest" { - fmt.Fprintf(w, ` + link := fmt.Sprintf("%s/%s%s", request.Host, keySlug, ext) + + if request.Header.Get("X-Requested-With") == "XMLHttpRequest" { + html := `