From aca72673015b5b484920b8dfb255f882cdb21249 Mon Sep 17 00:00:00 2001 From: skidoodle Date: Sat, 17 Jan 2026 22:58:38 +0100 Subject: [PATCH] refactor: internals Signed-off-by: skidoodle --- Dockerfile | 2 +- go.mod | 2 +- internal/app/config.go | 60 +++--- internal/app/handlers.go | 380 ++++++++++++++++++++++++-------------- internal/app/server.go | 50 +++-- internal/app/storage.go | 77 ++++++-- internal/crypto/crypto.go | 58 +++--- internal/crypto/reader.go | 61 +++--- main.go | 28 ++- 9 files changed, 476 insertions(+), 242 deletions(-) diff --git a/Dockerfile b/Dockerfile index 4385b8d..bdd4ef2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM --platform=$BUILDPLATFORM golang:1.25.5 AS builder +FROM --platform=$BUILDPLATFORM golang:1.25.6 AS builder WORKDIR /app diff --git a/go.mod b/go.mod index 31ed70f..9580abc 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/skidoodle/safebin -go 1.25.5 +go 1.25.6 diff --git a/internal/app/config.go b/internal/app/config.go index 82d7ff7..a953fa3 100644 --- a/internal/app/config.go +++ b/internal/app/config.go @@ -21,36 +21,54 @@ type App struct { Logger *slog.Logger } -func LoadConfig() Config { - h := getEnv("SAFEBIN_HOST", "0.0.0.0") - p := getEnvInt("SAFEBIN_PORT", 8080) - s := getEnv("SAFEBIN_STORAGE", "./storage") - mDefault := int64(getEnvInt("SAFEBIN_MAX_MB", 512)) +const ( + defaultHost = "0.0.0.0" + defaultPort = 8080 + defaultStorage = "./storage" + defaultMaxMB = 512 +) - var m int64 - flag.StringVar(&h, "h", h, "Bind address") - flag.IntVar(&p, "p", p, "Port") - flag.StringVar(&s, "s", s, "Storage directory") - flag.Int64Var(&m, "m", mDefault, "Max file size in MB") +func LoadConfig() Config { + hostEnv := getEnv("SAFEBIN_HOST", defaultHost) + portEnv := getEnvInt("SAFEBIN_PORT", defaultPort) + storageEnv := getEnv("SAFEBIN_STORAGE", defaultStorage) + maxMBEnv := int64(getEnvInt("SAFEBIN_MAX_MB", defaultMaxMB)) + + var host string + var port int + var storage string + var maxMB int64 + + flag.StringVar(&host, "h", hostEnv, "Bind address") + flag.IntVar(&port, "p", portEnv, "Port") + flag.StringVar(&storage, "s", storageEnv, "Storage directory") + flag.Int64Var(&maxMB, "m", maxMBEnv, "Max file size in MB") flag.Parse() - return Config{Addr: fmt.Sprintf("%s:%d", h, p), StorageDir: s, MaxMB: m} -} - -func getEnv(k, f string) string { - if v, ok := os.LookupEnv(k); ok { - return v + return Config{ + Addr: fmt.Sprintf("%s:%d", host, port), + StorageDir: storage, + MaxMB: maxMB, } - return f } -func getEnvInt(k string, f int) int { - if v, ok := os.LookupEnv(k); ok { - if i, err := strconv.Atoi(v); err == nil { +func getEnv(key, fallback string) string { + if value, ok := os.LookupEnv(key); ok { + return value + } + + return fallback +} + +func getEnvInt(key string, fallback int) int { + if value, ok := os.LookupEnv(key); ok { + i, err := strconv.Atoi(value) + if err == nil { return i } } - return f + + return fallback } func ParseTemplates() *template.Template { diff --git a/internal/app/handlers.go b/internal/app/handlers.go index 9c93b40..7974762 100644 --- a/internal/app/handlers.go +++ b/internal/app/handlers.go @@ -2,6 +2,7 @@ package app import ( "encoding/base64" + "errors" "fmt" "io" "mime" @@ -14,187 +15,268 @@ import ( "github.com/skidoodle/safebin/internal/crypto" ) +const ( + uploadChunkSize = 8 << 20 + maxRequestOverhead = 10 << 20 + permUserRWX = 0o700 + slugLength = 22 + keyLength = 16 + megaByte = 1 << 20 + chunkSafetyMargin = 2 +) + var reUploadID = regexp.MustCompile(`^[a-zA-Z0-9]{10,50}$`) -func (app *App) HandleHome(w http.ResponseWriter, r *http.Request) { - err := app.Tmpl.ExecuteTemplate(w, "base", map[string]any{ +func (app *App) HandleHome(writer http.ResponseWriter, request *http.Request) { + err := app.Tmpl.ExecuteTemplate(writer, "base", map[string]any{ "MaxMB": app.Conf.MaxMB, - "Host": r.Host, + "Host": request.Host, }) + if err != nil { app.Logger.Error("Template error", "err", err) } } -func (app *App) HandleUpload(w http.ResponseWriter, r *http.Request) { - limit := (app.Conf.MaxMB << 20) + (1 << 20) - r.Body = http.MaxBytesReader(w, r.Body, limit) +func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request) { + limit := (app.Conf.MaxMB * megaByte) + megaByte + request.Body = http.MaxBytesReader(writer, request.Body, limit) + + file, header, err := request.FormFile("file") - file, header, err := r.FormFile("file") if err != nil { if err.Error() == "http: request body too large" { - app.SendError(w, r, http.StatusRequestEntityTooLarge) + app.SendError(writer, request, http.StatusRequestEntityTooLarge) return } - app.SendError(w, r, http.StatusBadRequest) + + app.SendError(writer, request, http.StatusBadRequest) + return } - defer file.Close() - tmpPath := filepath.Join(app.Conf.StorageDir, "tmp", fmt.Sprintf("up_%d", os.Getpid())) - tmp, err := os.Create(tmpPath) + defer func() { + if closeErr := file.Close(); closeErr != nil { + app.Logger.Error("Failed to close upload file", "err", closeErr) + } + }() + + tmp, err := os.CreateTemp(filepath.Join(app.Conf.StorageDir, "tmp"), "up_*") + if err != nil { app.Logger.Error("Failed to create temp file", "err", err) - app.SendError(w, r, http.StatusInternalServerError) + app.SendError(writer, request, http.StatusInternalServerError) + return } - defer os.Remove(tmpPath) - defer tmp.Close() + + tmpPath := tmp.Name() + + defer func() { + if removeErr := os.Remove(tmpPath); removeErr != nil && !os.IsNotExist(removeErr) { + app.Logger.Error("Failed to remove temp file", "err", removeErr) + } + }() + + defer func() { + if closeErr := tmp.Close(); closeErr != nil { + app.Logger.Error("Failed to close temp file", "err", closeErr) + } + }() if _, err := io.Copy(tmp, file); err != nil { app.Logger.Error("Failed to write temp file", "err", err) - app.SendError(w, r, http.StatusRequestEntityTooLarge) + app.SendError(writer, request, http.StatusRequestEntityTooLarge) + return } - app.FinalizeFile(w, r, tmp, header.Filename) + app.FinalizeFile(writer, request, tmp, header.Filename) } -func (app *App) HandleChunk(w http.ResponseWriter, r *http.Request) { - r.Body = http.MaxBytesReader(w, r.Body, 10<<20) +func (app *App) HandleChunk(writer http.ResponseWriter, request *http.Request) { + request.Body = http.MaxBytesReader(writer, request.Body, maxRequestOverhead) - uid := r.FormValue("upload_id") - idx, err := strconv.Atoi(r.FormValue("index")) + uid := request.FormValue("upload_id") + + idx, err := strconv.Atoi(request.FormValue("index")) if err != nil { - app.SendError(w, r, http.StatusBadRequest) + app.SendError(writer, request, http.StatusBadRequest) return } - const chunkSize = 8 << 20 - maxChunks := int((app.Conf.MaxMB<<20)/chunkSize) + 2 + maxChunks := int((app.Conf.MaxMB*megaByte)/uploadChunkSize) + chunkSafetyMargin if !reUploadID.MatchString(uid) || idx > maxChunks || idx < 0 { - app.SendError(w, r, http.StatusBadRequest) + app.SendError(writer, request, http.StatusBadRequest) return } - file, _, err := r.FormFile("chunk") + file, _, err := request.FormFile("chunk") + if err != nil { if err.Error() == "http: request body too large" { - app.SendError(w, r, http.StatusRequestEntityTooLarge) + app.SendError(writer, request, http.StatusRequestEntityTooLarge) return } - app.SendError(w, r, http.StatusBadRequest) + + app.SendError(writer, request, http.StatusBadRequest) + return } - defer file.Close() + defer func() { + if closeErr := file.Close(); closeErr != nil { + app.Logger.Error("Failed to close chunk file", "err", closeErr) + } + }() + + if err := app.saveChunk(uid, idx, file); err != nil { + app.Logger.Error("Failed to save chunk", "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + } +} + +func (app *App) saveChunk(uid string, idx int, src io.Reader) error { dir := filepath.Join(app.Conf.StorageDir, "tmp", uid) - if err := os.MkdirAll(dir, 0700); err != nil { - app.Logger.Error("Failed to create chunk dir", "err", err) - app.SendError(w, r, http.StatusInternalServerError) - return + + if err := os.MkdirAll(dir, permUserRWX); err != nil { + return fmt.Errorf("create chunk dir: %w", err) } dest, err := os.Create(filepath.Join(dir, strconv.Itoa(idx))) if err != nil { - app.Logger.Error("Failed to create chunk file", "err", err) - app.SendError(w, r, http.StatusInternalServerError) + return fmt.Errorf("create chunk file: %w", err) + } + + defer func() { + if closeErr := dest.Close(); closeErr != nil { + app.Logger.Error("Failed to close chunk dest", "err", closeErr) + } + }() + + if _, err := io.Copy(dest, src); err != nil { + return fmt.Errorf("copy chunk: %w", err) + } + + return nil +} + +func (app *App) HandleFinish(writer http.ResponseWriter, request *http.Request) { + uid := request.FormValue("upload_id") + + total, err := strconv.Atoi(request.FormValue("total")) + if err != nil { + app.SendError(writer, request, http.StatusBadRequest) return } - defer dest.Close() - if _, err := io.Copy(dest, file); err != nil { - app.Logger.Error("Failed to save chunk", "err", err) - app.SendError(w, r, http.StatusInternalServerError) + maxChunks := int((app.Conf.MaxMB*megaByte)/uploadChunkSize) + chunkSafetyMargin + + if !reUploadID.MatchString(uid) || total > maxChunks || total <= 0 { + app.SendError(writer, request, http.StatusBadRequest) return } + + mergedPath, err := app.mergeChunks(uid, total) + + if err != nil { + app.Logger.Error("Merge failed", "err", err) + + if errors.Is(err, io.ErrShortWrite) { + app.SendError(writer, request, http.StatusRequestEntityTooLarge) + } else { + app.SendError(writer, request, http.StatusInternalServerError) + } + + return + } + + defer func() { + if removeErr := os.Remove(mergedPath); removeErr != nil && !os.IsNotExist(removeErr) { + app.Logger.Error("Failed to remove merged file", "err", removeErr) + } + }() + + mergedRead, err := os.Open(mergedPath) + + if err != nil { + app.Logger.Error("Failed to open merged file", "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + + return + } + + defer func() { + if closeErr := mergedRead.Close(); closeErr != nil { + app.Logger.Error("Failed to close merged reader", "err", closeErr) + } + }() + + app.FinalizeFile(writer, request, mergedRead, request.FormValue("filename")) + + if err := os.RemoveAll(filepath.Join(app.Conf.StorageDir, "tmp", uid)); err != nil { + app.Logger.Error("Failed to remove chunk dir", "err", err) + } } -func (app *App) HandleFinish(w http.ResponseWriter, r *http.Request) { - uid := r.FormValue("upload_id") - total, err := strconv.Atoi(r.FormValue("total")) - if err != nil { - app.SendError(w, r, http.StatusBadRequest) - return - } - - const chunkSize = 8 << 20 - maxChunks := int((app.Conf.MaxMB<<20)/chunkSize) + 2 - - if !reUploadID.MatchString(uid) || total > maxChunks || total <= 0 { - app.SendError(w, r, http.StatusBadRequest) - return - } - +func (app *App) mergeChunks(uid string, total int) (string, error) { tmpPath := filepath.Join(app.Conf.StorageDir, "tmp", "m_"+uid) + merged, err := os.Create(tmpPath) if err != nil { - app.Logger.Error("Failed to create merge file", "err", err) - app.SendError(w, r, http.StatusInternalServerError) - return + return "", fmt.Errorf("create merge file: %w", err) } - defer os.Remove(tmpPath) - defer merged.Close() - limit := app.Conf.MaxMB << 20 + defer func() { + if closeErr := merged.Close(); closeErr != nil { + app.Logger.Error("Failed to close merged file", "err", closeErr) + } + }() + + limit := app.Conf.MaxMB * megaByte var written int64 for i := range total { partPath := filepath.Join(app.Conf.StorageDir, "tmp", uid, strconv.Itoa(i)) + part, err := os.Open(partPath) if err != nil { - app.Logger.Error("Missing chunk during merge", "uid", uid, "index", i, "err", err) - app.SendError(w, r, http.StatusBadRequest) - return + return "", fmt.Errorf("open chunk %d: %w", i, err) } n, err := io.Copy(merged, part) - part.Close() + + if closeErr := part.Close(); closeErr != nil { + app.Logger.Error("Failed to close chunk part", "err", closeErr) + } + if err != nil { - app.Logger.Error("Failed to append chunk", "err", err) - app.SendError(w, r, http.StatusInternalServerError) - return + return "", fmt.Errorf("append chunk %d: %w", i, err) } written += n if written > limit { - app.SendError(w, r, http.StatusRequestEntityTooLarge) - return + return "", io.ErrShortWrite } } - if err := merged.Close(); err != nil { - app.Logger.Error("Failed to close merged file", "err", err) - app.SendError(w, r, http.StatusInternalServerError) - return - } - - mergedRead, err := os.Open(tmpPath) - if err != nil { - app.Logger.Error("Failed to open merged file", "err", err) - app.SendError(w, r, http.StatusInternalServerError) - return - } - defer mergedRead.Close() - - app.FinalizeFile(w, r, mergedRead, r.FormValue("filename")) - os.RemoveAll(filepath.Join(app.Conf.StorageDir, "tmp", uid)) + return tmpPath, nil } -func (app *App) HandleGetFile(w http.ResponseWriter, r *http.Request) { - slug := r.PathValue("slug") - if len(slug) < 22 { - app.SendError(w, r, http.StatusBadRequest) +func (app *App) HandleGetFile(writer http.ResponseWriter, request *http.Request) { + slug := request.PathValue("slug") + if len(slug) < slugLength { + app.SendError(writer, request, http.StatusBadRequest) return } - keyBase64 := slug[:22] - ext := slug[22:] + keyBase64 := slug[:slugLength] + ext := slug[slugLength:] key, err := base64.RawURLEncoding.DecodeString(keyBase64) - if err != nil || len(key) != 16 { - app.SendError(w, r, http.StatusUnauthorized) + if err != nil || len(key) != keyLength { + app.SendError(writer, request, http.StatusUnauthorized) return } @@ -203,105 +285,133 @@ func (app *App) HandleGetFile(w http.ResponseWriter, r *http.Request) { info, err := os.Stat(path) if err != nil { - app.SendError(w, r, http.StatusNotFound) + app.SendError(writer, request, http.StatusNotFound) return } - f, err := os.Open(path) + file, err := os.Open(path) + if err != nil { app.Logger.Error("Failed to open file", "path", path, "err", err) - app.SendError(w, r, http.StatusInternalServerError) + app.SendError(writer, request, http.StatusInternalServerError) + return } - defer f.Close() + + defer func() { + if closeErr := file.Close(); closeErr != nil { + app.Logger.Error("Failed to close file", "err", closeErr) + } + }() streamer, err := crypto.NewGCMStreamer(key) + if err != nil { app.Logger.Error("Failed to create crypto streamer", "err", err) - app.SendError(w, r, http.StatusInternalServerError) + app.SendError(writer, request, http.StatusInternalServerError) + return } - decryptor := crypto.NewDecryptor(f, streamer.AEAD, info.Size()) + decryptor := crypto.NewDecryptor(file, streamer.AEAD, info.Size()) contentType := mime.TypeByExtension(ext) if contentType == "" { contentType = "application/octet-stream" } - w.Header().Set("Content-Type", contentType) - w.Header().Set("Content-Security-Policy", "default-src 'none'; img-src 'self' data:; media-src 'self' data:; style-src 'unsafe-inline'; sandbox allow-forms allow-scripts allow-downloads allow-same-origin") - w.Header().Set("X-Content-Type-Options", "nosniff") - w.Header().Set("Content-Disposition", fmt.Sprintf("inline; filename=%q", slug)) + csp := "default-src 'none'; img-src 'self' data:; media-src 'self' data:; " + + "style-src 'unsafe-inline'; sandbox allow-forms allow-scripts allow-downloads allow-same-origin" - http.ServeContent(w, r, slug, info.ModTime(), decryptor) + writer.Header().Set("Content-Type", contentType) + writer.Header().Set("Content-Security-Policy", csp) + writer.Header().Set("X-Content-Type-Options", "nosniff") + writer.Header().Set("Content-Disposition", fmt.Sprintf("inline; filename=%q", slug)) + + http.ServeContent(writer, request, slug, info.ModTime(), decryptor) } -func (app *App) FinalizeFile(w http.ResponseWriter, r *http.Request, src *os.File, filename string) { +func (app *App) FinalizeFile(writer http.ResponseWriter, request *http.Request, src *os.File, filename string) { if _, err := src.Seek(0, 0); err != nil { app.Logger.Error("Seek failed", "err", err) - app.SendError(w, r, http.StatusInternalServerError) + app.SendError(writer, request, http.StatusInternalServerError) + return } key, err := crypto.DeriveKey(src) + if err != nil { app.Logger.Error("Key derivation failed", "err", err) - app.SendError(w, r, http.StatusInternalServerError) + app.SendError(writer, request, http.StatusInternalServerError) + return } ext := filepath.Ext(filename) id := crypto.GetID(key, ext) - - if _, err := src.Seek(0, 0); err != nil { - app.Logger.Error("Seek failed", "err", err) - app.SendError(w, r, http.StatusInternalServerError) - return - } - finalPath := filepath.Join(app.Conf.StorageDir, id) if _, err := os.Stat(finalPath); err == nil { - app.RespondWithLink(w, r, key, filename) + app.RespondWithLink(writer, request, key, filename) return } - out, err := os.Create(finalPath + ".tmp") - if err != nil { - app.Logger.Error("Failed to create final file", "err", err) - app.SendError(w, r, http.StatusInternalServerError) + if _, err := src.Seek(0, 0); err != nil { + app.Logger.Error("Seek failed", "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + return } + + if err := app.encryptAndSave(src, key, finalPath); err != nil { + app.Logger.Error("Encryption failed", "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + + return + } + + app.RespondWithLink(writer, request, key, filename) +} + +func (app *App) encryptAndSave(src io.Reader, key []byte, finalPath string) error { + out, err := os.Create(finalPath + ".tmp") + if err != nil { + return fmt.Errorf("create final file: %w", err) + } + + var closed bool + defer func() { - out.Close() - os.Remove(finalPath + ".tmp") + if !closed { + if closeErr := out.Close(); closeErr != nil { + app.Logger.Error("Failed to close final file", "err", closeErr) + } + } + + if removeErr := os.Remove(finalPath + ".tmp"); removeErr != nil && !os.IsNotExist(removeErr) { + app.Logger.Error("Failed to remove temp final file", "err", removeErr) + } }() streamer, err := crypto.NewGCMStreamer(key) if err != nil { - app.Logger.Error("Failed to create streamer", "err", err) - app.SendError(w, r, http.StatusInternalServerError) - return + return fmt.Errorf("create streamer: %w", err) } if err := streamer.EncryptStream(out, src); err != nil { - app.Logger.Error("Encryption failed", "err", err) - app.SendError(w, r, http.StatusInternalServerError) - return + return fmt.Errorf("encrypt stream: %w", err) } if err := out.Close(); err != nil { - app.Logger.Error("Failed to close final file", "err", err) - app.SendError(w, r, http.StatusInternalServerError) - return + return fmt.Errorf("close final file: %w", err) } + closed = true + if err := os.Rename(finalPath+".tmp", finalPath); err != nil { - app.Logger.Error("Failed to rename final file", "err", err) - app.SendError(w, r, http.StatusInternalServerError) - return + return fmt.Errorf("rename final file: %w", err) } - app.RespondWithLink(w, r, key, filename) + return nil } diff --git a/internal/app/server.go b/internal/app/server.go index 953d410..28b2d22 100644 --- a/internal/app/server.go +++ b/internal/app/server.go @@ -9,22 +9,25 @@ import ( func (app *App) Routes() *http.ServeMux { mux := http.NewServeMux() - fs := http.FileServer(http.Dir("./web/static")) - mux.Handle("GET /static/", http.StripPrefix("/static/", fs)) + fileServer := http.FileServer(http.Dir("./web/static")) + + mux.Handle("GET /static/", http.StripPrefix("/static/", fileServer)) mux.HandleFunc("GET /{$}", app.HandleHome) mux.HandleFunc("POST /{$}", app.HandleUpload) mux.HandleFunc("POST /upload/chunk", app.HandleChunk) mux.HandleFunc("POST /upload/finish", app.HandleFinish) mux.HandleFunc("GET /{slug}", app.HandleGetFile) + return mux } -func (app *App) RespondWithLink(w http.ResponseWriter, r *http.Request, key []byte, originalName string) { +func (app *App) RespondWithLink(writer http.ResponseWriter, request *http.Request, key []byte, originalName string) { keySlug := base64.RawURLEncoding.EncodeToString(key) ext := filepath.Ext(originalName) - link := fmt.Sprintf("%s/%s%s", r.Host, keySlug, ext) - if r.Header.Get("X-Requested-With") == "XMLHttpRequest" { - fmt.Fprintf(w, ` + link := fmt.Sprintf("%s/%s%s", request.Host, keySlug, ext) + + if request.Header.Get("X-Requested-With") == "XMLHttpRequest" { + html := `
Upload Complete:
@@ -34,27 +37,44 @@ func (app *App) RespondWithLink(w http.ResponseWriter, r *http.Request, key []by
-
`, link) +
` + + if _, err := fmt.Fprintf(writer, html, link); err != nil { + app.Logger.Error("Failed to write response", "err", err) + } + return } + scheme := "https" - if r.TLS == nil { + + if request.TLS == nil { scheme = "http" } - fmt.Fprintf(w, "%s://%s\n", scheme, link) + + if _, err := fmt.Fprintf(writer, "%s://%s\n", scheme, link); err != nil { + app.Logger.Error("Failed to write response", "err", err) + } } -func (app *App) SendError(w http.ResponseWriter, r *http.Request, code int) { - if r.Header.Get("X-Requested-With") == "XMLHttpRequest" { - w.WriteHeader(code) - fmt.Fprintf(w, ` +func (app *App) SendError(writer http.ResponseWriter, request *http.Request, code int) { + if request.Header.Get("X-Requested-With") == "XMLHttpRequest" { + writer.WriteHeader(code) + + html := `
Error %d
-
`, code) + ` + + if _, err := fmt.Fprintf(writer, html, code); err != nil { + app.Logger.Error("Failed to write error response", "err", err) + } + return } - http.Error(w, http.StatusText(code), code) + + http.Error(writer, http.StatusText(code), code) } diff --git a/internal/app/storage.go b/internal/app/storage.go index 6f30153..493a881 100644 --- a/internal/app/storage.go +++ b/internal/app/storage.go @@ -8,43 +8,82 @@ import ( "time" ) +const ( + cleanupInterval = 1 * time.Hour + tempExpiry = 4 * time.Hour + minRetention = 24 * time.Hour + maxRetention = 365 * 24 * time.Hour + bytesInMB = 1 << 20 +) + func (app *App) StartCleanupTask(ctx context.Context) { - ticker := time.NewTicker(1 * time.Hour) + ticker := time.NewTicker(cleanupInterval) + for { select { case <-ctx.Done(): + ticker.Stop() return case <-ticker.C: - app.CleanDir(app.Conf.StorageDir, false) - app.CleanDir(filepath.Join(app.Conf.StorageDir, "tmp"), true) + app.CleanStorage(app.Conf.StorageDir) + app.CleanTemp(filepath.Join(app.Conf.StorageDir, "tmp")) } } } -func (app *App) CleanDir(path string, isTmp bool) { - entries, _ := os.ReadDir(path) +func (app *App) CleanStorage(path string) { + entries, err := os.ReadDir(path) + if err != nil { + app.Logger.Error("Failed to read storage dir", "err", err) + return + } + for _, entry := range entries { - info, _ := entry.Info() - expiry := 4 * time.Hour - if !isTmp { - expiry = CalculateRetention(info.Size(), app.Conf.MaxMB) + info, err := entry.Info() + if err != nil { + continue } + expiry := CalculateRetention(info.Size(), app.Conf.MaxMB) + if time.Since(info.ModTime()) > expiry { - os.RemoveAll(filepath.Join(path, entry.Name())) + if err := os.RemoveAll(filepath.Join(path, entry.Name())); err != nil { + app.Logger.Error("Failed to remove expired file", "path", entry.Name(), "err", err) + } } } } -func CalculateRetention(fileSize int64, maxMB int64) time.Duration { - const ( - minAge = 24 * time.Hour - maxAge = 365 * 24 * time.Hour - ) - ratio := math.Max(0, math.Min(1, float64(fileSize)/float64(maxMB<<20))) - retention := float64(maxAge) * math.Pow(1.0-ratio, 3) - if retention < float64(minAge) { - return minAge +func (app *App) CleanTemp(path string) { + entries, err := os.ReadDir(path) + if err != nil { + app.Logger.Error("Failed to read temp dir", "err", err) + return } + + for _, entry := range entries { + info, err := entry.Info() + if err != nil { + continue + } + + if time.Since(info.ModTime()) > tempExpiry { + if err := os.RemoveAll(filepath.Join(path, entry.Name())); err != nil { + app.Logger.Error("Failed to remove expired temp file", "path", entry.Name(), "err", err) + } + } + } +} + +func CalculateRetention(fileSize, maxMB int64) time.Duration { + ratio := math.Max(0, math.Min(1, float64(fileSize)/float64(maxMB*bytesInMB))) + + invRatio := 1.0 - ratio + retention := float64(maxRetention) * (invRatio * invRatio * invRatio) + + if retention < float64(minRetention) { + return minRetention + } + return time.Duration(retention) } diff --git a/internal/crypto/crypto.go b/internal/crypto/crypto.go index 1c59bf4..a6fc9d8 100644 --- a/internal/crypto/crypto.go +++ b/internal/crypto/crypto.go @@ -6,27 +6,34 @@ import ( "crypto/sha256" "encoding/base64" "encoding/binary" + "errors" + "fmt" "io" ) const ( GCMChunkSize = 64 * 1024 NonceSize = 12 + KeySize = 16 + IDSize = 9 ) -func DeriveKey(r io.Reader) ([]byte, error) { - h := sha256.New() - if _, err := io.Copy(h, r); err != nil { - return nil, err +func DeriveKey(reader io.Reader) ([]byte, error) { + hasher := sha256.New() + + if _, err := io.Copy(hasher, reader); err != nil { + return nil, fmt.Errorf("failed to copy to hasher: %w", err) } - return h.Sum(nil)[:16], nil + + return hasher.Sum(nil)[:KeySize], nil } func GetID(key []byte, ext string) string { - h := sha256.New() - h.Write(key) - h.Write([]byte(ext)) - return base64.RawURLEncoding.EncodeToString(h.Sum(nil)[:9]) + hasher := sha256.New() + hasher.Write(key) + hasher.Write([]byte(ext)) + + return base64.RawURLEncoding.EncodeToString(hasher.Sum(nil)[:IDSize]) } type GCMStreamer struct { @@ -34,37 +41,46 @@ type GCMStreamer struct { } func NewGCMStreamer(key []byte) (*GCMStreamer, error) { - b, err := aes.NewCipher(key) + block, err := aes.NewCipher(key) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create cipher: %w", err) } - g, err := cipher.NewGCM(b) + + gcm, err := cipher.NewGCM(block) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create GCM: %w", err) } - return &GCMStreamer{AEAD: g}, nil + + return &GCMStreamer{AEAD: gcm}, nil } func (g *GCMStreamer) EncryptStream(dst io.Writer, src io.Reader) error { buf := make([]byte, GCMChunkSize) - var chunkIdx uint64 = 0 + var chunkIdx uint64 + for { - n, err := io.ReadFull(src, buf) - if n > 0 { + bytesRead, err := io.ReadFull(src, buf) + if bytesRead > 0 { nonce := make([]byte, NonceSize) binary.BigEndian.PutUint64(nonce[4:], chunkIdx) - ciphertext := g.AEAD.Seal(nil, nonce, buf[:n], nil) + + ciphertext := g.AEAD.Seal(nil, nonce, buf[:bytesRead], nil) + if _, werr := dst.Write(ciphertext); werr != nil { - return werr + return fmt.Errorf("failed to write ciphertext: %w", werr) } + chunkIdx++ } - if err == io.EOF || err == io.ErrUnexpectedEOF { + + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { break } + if err != nil { - return err + return fmt.Errorf("failed to read source: %w", err) } } + return nil } diff --git a/internal/crypto/reader.go b/internal/crypto/reader.go index 7aad6f5..c0dd702 100644 --- a/internal/crypto/reader.go +++ b/internal/crypto/reader.go @@ -4,34 +4,41 @@ import ( "crypto/cipher" "encoding/binary" "errors" + "fmt" "io" ) +var ErrInvalidWhence = errors.New("invalid whence") +var ErrNegativeBias = errors.New("negative bias") + type Decryptor struct { - rs io.ReadSeeker - aead cipher.AEAD - size int64 - offset int64 + readSeeker io.ReadSeeker + aead cipher.AEAD + size int64 + offset int64 } -func NewDecryptor(rs io.ReadSeeker, aead cipher.AEAD, encryptedSize int64) *Decryptor { +func NewDecryptor(readSeeker io.ReadSeeker, aead cipher.AEAD, encryptedSize int64) *Decryptor { overhead := int64(aead.Overhead()) - fullBlocks := encryptedSize / (GCMChunkSize + overhead) - remainder := encryptedSize % (GCMChunkSize + overhead) + chunkWithOverhead := int64(GCMChunkSize) + overhead - plainSize := (fullBlocks * GCMChunkSize) + fullBlocks := encryptedSize / chunkWithOverhead + remainder := encryptedSize % chunkWithOverhead + + plainSize := fullBlocks * GCMChunkSize if remainder > overhead { plainSize += (remainder - overhead) } return &Decryptor{ - rs: rs, - aead: aead, - size: plainSize, + readSeeker: readSeeker, + aead: aead, + size: plainSize, + offset: 0, } } -func (d *Decryptor) Read(p []byte) (int, error) { +func (d *Decryptor) Read(buf []byte) (int, error) { if d.offset >= d.size { return 0, io.EOF } @@ -40,25 +47,29 @@ func (d *Decryptor) Read(p []byte) (int, error) { overhang := d.offset % GCMChunkSize overhead := int64(d.aead.Overhead()) - actualChunkSize := int64(GCMChunkSize + overhead) + actualChunkSize := int64(GCMChunkSize) + overhead - _, err := d.rs.Seek(chunkIdx*actualChunkSize, io.SeekStart) + _, err := d.readSeeker.Seek(chunkIdx*actualChunkSize, io.SeekStart) if err != nil { - return 0, err + return 0, fmt.Errorf("failed to seek: %w", err) } encrypted := make([]byte, actualChunkSize) - n, err := io.ReadFull(d.rs, encrypted) - if err != nil && err != io.ErrUnexpectedEOF { - return 0, err + + bytesRead, err := io.ReadFull(d.readSeeker, encrypted) + if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) { + return 0, fmt.Errorf("failed to read encrypted data: %w", err) } nonce := make([]byte, NonceSize) + if chunkIdx < 0 { + return 0, fmt.Errorf("invalid chunk index") + } binary.BigEndian.PutUint64(nonce[4:], uint64(chunkIdx)) - plaintext, err := d.aead.Open(nil, nonce, encrypted[:n], nil) + plaintext, err := d.aead.Open(nil, nonce, encrypted[:bytesRead], nil) if err != nil { - return 0, err + return 0, fmt.Errorf("failed to decrypt: %w", err) } if overhang >= int64(len(plaintext)) { @@ -66,7 +77,7 @@ func (d *Decryptor) Read(p []byte) (int, error) { } available := plaintext[overhang:] - nCopied := copy(p, available) + nCopied := copy(buf, available) d.offset += int64(nCopied) return nCopied, nil @@ -74,6 +85,7 @@ func (d *Decryptor) Read(p []byte) (int, error) { func (d *Decryptor) Seek(offset int64, whence int) (int64, error) { var abs int64 + switch whence { case io.SeekStart: abs = offset @@ -82,11 +94,14 @@ func (d *Decryptor) Seek(offset int64, whence int) (int64, error) { case io.SeekEnd: abs = d.size + offset default: - return 0, errors.New("invalid whence") + return 0, ErrInvalidWhence } + if abs < 0 { - return 0, errors.New("negative bias") + return 0, ErrNegativeBias } + d.offset = abs + return abs, nil } diff --git a/main.go b/main.go index 3f592d6..90f22b9 100644 --- a/main.go +++ b/main.go @@ -2,27 +2,39 @@ package main import ( "context" + "errors" "fmt" "log/slog" "net/http" "os" "os/signal" + "path/filepath" "syscall" "time" "github.com/skidoodle/safebin/internal/app" ) +const ( + permUserRWX = 0o700 + serverTimeout = 10 * time.Minute + shutdownTimeout = 10 * time.Second +) + func main() { cfg := app.LoadConfig() - logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug})) + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: slog.LevelDebug, + AddSource: true, + })) logger.Info("Initializing Safebin Server", "storage_dir", cfg.StorageDir, "max_file_size", fmt.Sprintf("%dMB", cfg.MaxMB), ) - if err := os.MkdirAll(fmt.Sprintf("%s/tmp", cfg.StorageDir), 0700); err != nil { + tmpDir := filepath.Join(cfg.StorageDir, "tmp") + if err := os.MkdirAll(tmpDir, permUserRWX); err != nil { logger.Error("Failed to initialize storage directory", "err", err) os.Exit(1) } @@ -41,13 +53,15 @@ func main() { srv := &http.Server{ Addr: cfg.Addr, Handler: application.Routes(), - ReadTimeout: 10 * time.Minute, - WriteTimeout: 10 * time.Minute, + ReadTimeout: serverTimeout, + WriteTimeout: serverTimeout, + IdleTimeout: serverTimeout, } go func() { application.Logger.Info("Server is ready and listening", "addr", cfg.Addr) - if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + + if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { application.Logger.Error("Server failed to start", "err", err) os.Exit(1) } @@ -56,10 +70,12 @@ func main() { <-ctx.Done() application.Logger.Info("Shutting down gracefully...") - shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) defer cancel() + if err := srv.Shutdown(shutdownCtx); err != nil { application.Logger.Error("Forced shutdown", "err", err) } + application.Logger.Info("Server stopped") }