diff --git a/internal/app/config.go b/internal/app/config.go index a953fa3..b22e05b 100644 --- a/internal/app/config.go +++ b/internal/app/config.go @@ -7,6 +7,30 @@ import ( "log/slog" "os" "strconv" + "time" +) + +const ( + DefaultHost = "0.0.0.0" + DefaultPort = 8080 + DefaultStorage = "./storage" + DefaultMaxMB = 512 + ServerTimeout = 10 * time.Minute + ShutdownTimeout = 10 * time.Second + + UploadChunkSize = 8 << 20 + MaxRequestOverhead = 10 << 20 + PermUserRWX = 0o700 + MegaByte = 1 << 20 + ChunkSafetyMargin = 2 + + SlugLength = 22 + KeyLength = 16 + + CleanupInterval = 1 * time.Hour + TempExpiry = 4 * time.Hour + MinRetention = 24 * time.Hour + MaxRetention = 365 * 24 * time.Hour ) type Config struct { @@ -21,18 +45,11 @@ type App struct { Logger *slog.Logger } -const ( - defaultHost = "0.0.0.0" - defaultPort = 8080 - defaultStorage = "./storage" - defaultMaxMB = 512 -) - func LoadConfig() Config { - hostEnv := getEnv("SAFEBIN_HOST", defaultHost) - portEnv := getEnvInt("SAFEBIN_PORT", defaultPort) - storageEnv := getEnv("SAFEBIN_STORAGE", defaultStorage) - maxMBEnv := int64(getEnvInt("SAFEBIN_MAX_MB", defaultMaxMB)) + hostEnv := getEnv("SAFEBIN_HOST", DefaultHost) + portEnv := getEnvInt("SAFEBIN_PORT", DefaultPort) + storageEnv := getEnv("SAFEBIN_STORAGE", DefaultStorage) + maxMBEnv := int64(getEnvInt("SAFEBIN_MAX_MB", DefaultMaxMB)) var host string var port int @@ -56,7 +73,6 @@ func getEnv(key, fallback string) string { if value, ok := os.LookupEnv(key); ok { return value } - return fallback } @@ -67,7 +83,6 @@ func getEnvInt(key string, fallback int) int { return i } } - return fallback } diff --git a/internal/app/download.go b/internal/app/download.go new file mode 100644 index 0000000..4951ed0 --- /dev/null +++ b/internal/app/download.go @@ -0,0 +1,77 @@ +package app + +import ( + "encoding/base64" + "fmt" + "mime" + "net/http" + "os" + "path/filepath" + + "github.com/skidoodle/safebin/internal/crypto" +) + +func (app *App) HandleGetFile(writer http.ResponseWriter, request *http.Request) { + slug := request.PathValue("slug") + if len(slug) < SlugLength { + app.SendError(writer, request, http.StatusBadRequest) + return + } + + keyBase64 := slug[:SlugLength] + ext := slug[SlugLength:] + + key, err := base64.RawURLEncoding.DecodeString(keyBase64) + if err != nil || len(key) != KeyLength { + app.SendError(writer, request, http.StatusUnauthorized) + return + } + + id := crypto.GetID(key, ext) + path := filepath.Join(app.Conf.StorageDir, id) + + info, err := os.Stat(path) + if err != nil { + app.SendError(writer, request, http.StatusNotFound) + return + } + + file, err := os.Open(path) + + if err != nil { + app.Logger.Error("Failed to open file", "path", path, "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + return + } + + defer func() { + if closeErr := file.Close(); closeErr != nil { + app.Logger.Error("Failed to close file", "err", closeErr) + } + }() + + streamer, err := crypto.NewGCMStreamer(key) + + if err != nil { + app.Logger.Error("Failed to create crypto streamer", "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + return + } + + decryptor := crypto.NewDecryptor(file, streamer.AEAD, info.Size()) + + contentType := mime.TypeByExtension(ext) + if contentType == "" { + contentType = "application/octet-stream" + } + + csp := "default-src 'none'; img-src 'self' data:; media-src 'self' data:; " + + "style-src 'unsafe-inline'; sandbox allow-forms allow-scripts allow-downloads allow-same-origin" + + writer.Header().Set("Content-Type", contentType) + writer.Header().Set("Content-Security-Policy", csp) + writer.Header().Set("X-Content-Type-Options", "nosniff") + writer.Header().Set("Content-Disposition", fmt.Sprintf("inline; filename=%q", slug)) + + http.ServeContent(writer, request, slug, info.ModTime(), decryptor) +} diff --git a/internal/app/handlers.go b/internal/app/handlers.go deleted file mode 100644 index 7974762..0000000 --- a/internal/app/handlers.go +++ /dev/null @@ -1,417 +0,0 @@ -package app - -import ( - "encoding/base64" - "errors" - "fmt" - "io" - "mime" - "net/http" - "os" - "path/filepath" - "regexp" - "strconv" - - "github.com/skidoodle/safebin/internal/crypto" -) - -const ( - uploadChunkSize = 8 << 20 - maxRequestOverhead = 10 << 20 - permUserRWX = 0o700 - slugLength = 22 - keyLength = 16 - megaByte = 1 << 20 - chunkSafetyMargin = 2 -) - -var reUploadID = regexp.MustCompile(`^[a-zA-Z0-9]{10,50}$`) - -func (app *App) HandleHome(writer http.ResponseWriter, request *http.Request) { - err := app.Tmpl.ExecuteTemplate(writer, "base", map[string]any{ - "MaxMB": app.Conf.MaxMB, - "Host": request.Host, - }) - - if err != nil { - app.Logger.Error("Template error", "err", err) - } -} - -func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request) { - limit := (app.Conf.MaxMB * megaByte) + megaByte - request.Body = http.MaxBytesReader(writer, request.Body, limit) - - file, header, err := request.FormFile("file") - - if err != nil { - if err.Error() == "http: request body too large" { - app.SendError(writer, request, http.StatusRequestEntityTooLarge) - return - } - - app.SendError(writer, request, http.StatusBadRequest) - - return - } - - defer func() { - if closeErr := file.Close(); closeErr != nil { - app.Logger.Error("Failed to close upload file", "err", closeErr) - } - }() - - tmp, err := os.CreateTemp(filepath.Join(app.Conf.StorageDir, "tmp"), "up_*") - - if err != nil { - app.Logger.Error("Failed to create temp file", "err", err) - app.SendError(writer, request, http.StatusInternalServerError) - - return - } - - tmpPath := tmp.Name() - - defer func() { - if removeErr := os.Remove(tmpPath); removeErr != nil && !os.IsNotExist(removeErr) { - app.Logger.Error("Failed to remove temp file", "err", removeErr) - } - }() - - defer func() { - if closeErr := tmp.Close(); closeErr != nil { - app.Logger.Error("Failed to close temp file", "err", closeErr) - } - }() - - if _, err := io.Copy(tmp, file); err != nil { - app.Logger.Error("Failed to write temp file", "err", err) - app.SendError(writer, request, http.StatusRequestEntityTooLarge) - - return - } - - app.FinalizeFile(writer, request, tmp, header.Filename) -} - -func (app *App) HandleChunk(writer http.ResponseWriter, request *http.Request) { - request.Body = http.MaxBytesReader(writer, request.Body, maxRequestOverhead) - - uid := request.FormValue("upload_id") - - idx, err := strconv.Atoi(request.FormValue("index")) - if err != nil { - app.SendError(writer, request, http.StatusBadRequest) - return - } - - maxChunks := int((app.Conf.MaxMB*megaByte)/uploadChunkSize) + chunkSafetyMargin - - if !reUploadID.MatchString(uid) || idx > maxChunks || idx < 0 { - app.SendError(writer, request, http.StatusBadRequest) - return - } - - file, _, err := request.FormFile("chunk") - - if err != nil { - if err.Error() == "http: request body too large" { - app.SendError(writer, request, http.StatusRequestEntityTooLarge) - return - } - - app.SendError(writer, request, http.StatusBadRequest) - - return - } - - defer func() { - if closeErr := file.Close(); closeErr != nil { - app.Logger.Error("Failed to close chunk file", "err", closeErr) - } - }() - - if err := app.saveChunk(uid, idx, file); err != nil { - app.Logger.Error("Failed to save chunk", "err", err) - app.SendError(writer, request, http.StatusInternalServerError) - } -} - -func (app *App) saveChunk(uid string, idx int, src io.Reader) error { - dir := filepath.Join(app.Conf.StorageDir, "tmp", uid) - - if err := os.MkdirAll(dir, permUserRWX); err != nil { - return fmt.Errorf("create chunk dir: %w", err) - } - - dest, err := os.Create(filepath.Join(dir, strconv.Itoa(idx))) - if err != nil { - return fmt.Errorf("create chunk file: %w", err) - } - - defer func() { - if closeErr := dest.Close(); closeErr != nil { - app.Logger.Error("Failed to close chunk dest", "err", closeErr) - } - }() - - if _, err := io.Copy(dest, src); err != nil { - return fmt.Errorf("copy chunk: %w", err) - } - - return nil -} - -func (app *App) HandleFinish(writer http.ResponseWriter, request *http.Request) { - uid := request.FormValue("upload_id") - - total, err := strconv.Atoi(request.FormValue("total")) - if err != nil { - app.SendError(writer, request, http.StatusBadRequest) - return - } - - maxChunks := int((app.Conf.MaxMB*megaByte)/uploadChunkSize) + chunkSafetyMargin - - if !reUploadID.MatchString(uid) || total > maxChunks || total <= 0 { - app.SendError(writer, request, http.StatusBadRequest) - return - } - - mergedPath, err := app.mergeChunks(uid, total) - - if err != nil { - app.Logger.Error("Merge failed", "err", err) - - if errors.Is(err, io.ErrShortWrite) { - app.SendError(writer, request, http.StatusRequestEntityTooLarge) - } else { - app.SendError(writer, request, http.StatusInternalServerError) - } - - return - } - - defer func() { - if removeErr := os.Remove(mergedPath); removeErr != nil && !os.IsNotExist(removeErr) { - app.Logger.Error("Failed to remove merged file", "err", removeErr) - } - }() - - mergedRead, err := os.Open(mergedPath) - - if err != nil { - app.Logger.Error("Failed to open merged file", "err", err) - app.SendError(writer, request, http.StatusInternalServerError) - - return - } - - defer func() { - if closeErr := mergedRead.Close(); closeErr != nil { - app.Logger.Error("Failed to close merged reader", "err", closeErr) - } - }() - - app.FinalizeFile(writer, request, mergedRead, request.FormValue("filename")) - - if err := os.RemoveAll(filepath.Join(app.Conf.StorageDir, "tmp", uid)); err != nil { - app.Logger.Error("Failed to remove chunk dir", "err", err) - } -} - -func (app *App) mergeChunks(uid string, total int) (string, error) { - tmpPath := filepath.Join(app.Conf.StorageDir, "tmp", "m_"+uid) - - merged, err := os.Create(tmpPath) - if err != nil { - return "", fmt.Errorf("create merge file: %w", err) - } - - defer func() { - if closeErr := merged.Close(); closeErr != nil { - app.Logger.Error("Failed to close merged file", "err", closeErr) - } - }() - - limit := app.Conf.MaxMB * megaByte - var written int64 - - for i := range total { - partPath := filepath.Join(app.Conf.StorageDir, "tmp", uid, strconv.Itoa(i)) - - part, err := os.Open(partPath) - if err != nil { - return "", fmt.Errorf("open chunk %d: %w", i, err) - } - - n, err := io.Copy(merged, part) - - if closeErr := part.Close(); closeErr != nil { - app.Logger.Error("Failed to close chunk part", "err", closeErr) - } - - if err != nil { - return "", fmt.Errorf("append chunk %d: %w", i, err) - } - - written += n - if written > limit { - return "", io.ErrShortWrite - } - } - - return tmpPath, nil -} - -func (app *App) HandleGetFile(writer http.ResponseWriter, request *http.Request) { - slug := request.PathValue("slug") - if len(slug) < slugLength { - app.SendError(writer, request, http.StatusBadRequest) - return - } - - keyBase64 := slug[:slugLength] - ext := slug[slugLength:] - - key, err := base64.RawURLEncoding.DecodeString(keyBase64) - if err != nil || len(key) != keyLength { - app.SendError(writer, request, http.StatusUnauthorized) - return - } - - id := crypto.GetID(key, ext) - path := filepath.Join(app.Conf.StorageDir, id) - - info, err := os.Stat(path) - if err != nil { - app.SendError(writer, request, http.StatusNotFound) - return - } - - file, err := os.Open(path) - - if err != nil { - app.Logger.Error("Failed to open file", "path", path, "err", err) - app.SendError(writer, request, http.StatusInternalServerError) - - return - } - - defer func() { - if closeErr := file.Close(); closeErr != nil { - app.Logger.Error("Failed to close file", "err", closeErr) - } - }() - - streamer, err := crypto.NewGCMStreamer(key) - - if err != nil { - app.Logger.Error("Failed to create crypto streamer", "err", err) - app.SendError(writer, request, http.StatusInternalServerError) - - return - } - - decryptor := crypto.NewDecryptor(file, streamer.AEAD, info.Size()) - - contentType := mime.TypeByExtension(ext) - if contentType == "" { - contentType = "application/octet-stream" - } - - csp := "default-src 'none'; img-src 'self' data:; media-src 'self' data:; " + - "style-src 'unsafe-inline'; sandbox allow-forms allow-scripts allow-downloads allow-same-origin" - - writer.Header().Set("Content-Type", contentType) - writer.Header().Set("Content-Security-Policy", csp) - writer.Header().Set("X-Content-Type-Options", "nosniff") - writer.Header().Set("Content-Disposition", fmt.Sprintf("inline; filename=%q", slug)) - - http.ServeContent(writer, request, slug, info.ModTime(), decryptor) -} - -func (app *App) FinalizeFile(writer http.ResponseWriter, request *http.Request, src *os.File, filename string) { - if _, err := src.Seek(0, 0); err != nil { - app.Logger.Error("Seek failed", "err", err) - app.SendError(writer, request, http.StatusInternalServerError) - - return - } - - key, err := crypto.DeriveKey(src) - - if err != nil { - app.Logger.Error("Key derivation failed", "err", err) - app.SendError(writer, request, http.StatusInternalServerError) - - return - } - - ext := filepath.Ext(filename) - id := crypto.GetID(key, ext) - finalPath := filepath.Join(app.Conf.StorageDir, id) - - if _, err := os.Stat(finalPath); err == nil { - app.RespondWithLink(writer, request, key, filename) - return - } - - if _, err := src.Seek(0, 0); err != nil { - app.Logger.Error("Seek failed", "err", err) - app.SendError(writer, request, http.StatusInternalServerError) - - return - } - - if err := app.encryptAndSave(src, key, finalPath); err != nil { - app.Logger.Error("Encryption failed", "err", err) - app.SendError(writer, request, http.StatusInternalServerError) - - return - } - - app.RespondWithLink(writer, request, key, filename) -} - -func (app *App) encryptAndSave(src io.Reader, key []byte, finalPath string) error { - out, err := os.Create(finalPath + ".tmp") - if err != nil { - return fmt.Errorf("create final file: %w", err) - } - - var closed bool - - defer func() { - if !closed { - if closeErr := out.Close(); closeErr != nil { - app.Logger.Error("Failed to close final file", "err", closeErr) - } - } - - if removeErr := os.Remove(finalPath + ".tmp"); removeErr != nil && !os.IsNotExist(removeErr) { - app.Logger.Error("Failed to remove temp final file", "err", removeErr) - } - }() - - streamer, err := crypto.NewGCMStreamer(key) - if err != nil { - return fmt.Errorf("create streamer: %w", err) - } - - if err := streamer.EncryptStream(out, src); err != nil { - return fmt.Errorf("encrypt stream: %w", err) - } - - if err := out.Close(); err != nil { - return fmt.Errorf("close final file: %w", err) - } - - closed = true - - if err := os.Rename(finalPath+".tmp", finalPath); err != nil { - return fmt.Errorf("rename final file: %w", err) - } - - return nil -} diff --git a/internal/app/server.go b/internal/app/server.go index 28b2d22..c10c8bf 100644 --- a/internal/app/server.go +++ b/internal/app/server.go @@ -21,6 +21,17 @@ func (app *App) Routes() *http.ServeMux { return mux } +func (app *App) HandleHome(writer http.ResponseWriter, request *http.Request) { + err := app.Tmpl.ExecuteTemplate(writer, "base", map[string]any{ + "MaxMB": app.Conf.MaxMB, + "Host": request.Host, + }) + + if err != nil { + app.Logger.Error("Template error", "err", err) + } +} + func (app *App) RespondWithLink(writer http.ResponseWriter, request *http.Request, key []byte, originalName string) { keySlug := base64.RawURLEncoding.EncodeToString(key) ext := filepath.Ext(originalName) @@ -42,12 +53,10 @@ func (app *App) RespondWithLink(writer http.ResponseWriter, request *http.Reques if _, err := fmt.Fprintf(writer, html, link); err != nil { app.Logger.Error("Failed to write response", "err", err) } - return } scheme := "https" - if request.TLS == nil { scheme = "http" } @@ -72,7 +81,6 @@ func (app *App) SendError(writer http.ResponseWriter, request *http.Request, cod if _, err := fmt.Fprintf(writer, html, code); err != nil { app.Logger.Error("Failed to write error response", "err", err) } - return } diff --git a/internal/app/storage.go b/internal/app/storage.go index 493a881..d37e1ae 100644 --- a/internal/app/storage.go +++ b/internal/app/storage.go @@ -2,22 +2,19 @@ package app import ( "context" + "fmt" + "io" "math" "os" "path/filepath" + "strconv" "time" -) -const ( - cleanupInterval = 1 * time.Hour - tempExpiry = 4 * time.Hour - minRetention = 24 * time.Hour - maxRetention = 365 * 24 * time.Hour - bytesInMB = 1 << 20 + "github.com/skidoodle/safebin/internal/crypto" ) func (app *App) StartCleanupTask(ctx context.Context) { - ticker := time.NewTicker(cleanupInterval) + ticker := time.NewTicker(CleanupInterval) for { select { @@ -31,6 +28,117 @@ func (app *App) StartCleanupTask(ctx context.Context) { } } +func (app *App) saveChunk(uid string, idx int, src io.Reader) error { + dir := filepath.Join(app.Conf.StorageDir, "tmp", uid) + + if err := os.MkdirAll(dir, PermUserRWX); err != nil { + return fmt.Errorf("create chunk dir: %w", err) + } + + dest, err := os.Create(filepath.Join(dir, strconv.Itoa(idx))) + if err != nil { + return fmt.Errorf("create chunk file: %w", err) + } + + defer func() { + if closeErr := dest.Close(); closeErr != nil { + app.Logger.Error("Failed to close chunk dest", "err", closeErr) + } + }() + + if _, err := io.Copy(dest, src); err != nil { + return fmt.Errorf("copy chunk: %w", err) + } + + return nil +} + +func (app *App) mergeChunks(uid string, total int) (string, error) { + tmpPath := filepath.Join(app.Conf.StorageDir, "tmp", "m_"+uid) + + merged, err := os.Create(tmpPath) + if err != nil { + return "", fmt.Errorf("create merge file: %w", err) + } + + defer func() { + if closeErr := merged.Close(); closeErr != nil { + app.Logger.Error("Failed to close merged file", "err", closeErr) + } + }() + + limit := app.Conf.MaxMB * MegaByte + var written int64 + + for i := range total { + partPath := filepath.Join(app.Conf.StorageDir, "tmp", uid, strconv.Itoa(i)) + + part, err := os.Open(partPath) + if err != nil { + return "", fmt.Errorf("open chunk %d: %w", i, err) + } + + n, err := io.Copy(merged, part) + + if closeErr := part.Close(); closeErr != nil { + app.Logger.Error("Failed to close chunk part", "err", closeErr) + } + + if err != nil { + return "", fmt.Errorf("append chunk %d: %w", i, err) + } + + written += n + if written > limit { + return "", io.ErrShortWrite + } + } + + return tmpPath, nil +} + +func (app *App) encryptAndSave(src io.Reader, key []byte, finalPath string) error { + out, err := os.Create(finalPath + ".tmp") + if err != nil { + return fmt.Errorf("create final file: %w", err) + } + + var closed bool + + defer func() { + if !closed { + if closeErr := out.Close(); closeErr != nil { + app.Logger.Error("Failed to close final file", "err", closeErr) + } + } + + if removeErr := os.Remove(finalPath + ".tmp"); removeErr != nil && !os.IsNotExist(removeErr) { + app.Logger.Error("Failed to remove temp final file", "err", removeErr) + } + }() + + streamer, err := crypto.NewGCMStreamer(key) + if err != nil { + return fmt.Errorf("create streamer: %w", err) + } + + if err := streamer.EncryptStream(out, src); err != nil { + return fmt.Errorf("encrypt stream: %w", err) + } + + if err := out.Close(); err != nil { + return fmt.Errorf("close final file: %w", err) + } + + closed = true + + if err := os.Rename(finalPath+".tmp", finalPath); err != nil { + return fmt.Errorf("rename final file: %w", err) + } + + return nil +} + func (app *App) CleanStorage(path string) { entries, err := os.ReadDir(path) if err != nil { @@ -67,7 +175,7 @@ func (app *App) CleanTemp(path string) { continue } - if time.Since(info.ModTime()) > tempExpiry { + if time.Since(info.ModTime()) > TempExpiry { if err := os.RemoveAll(filepath.Join(path, entry.Name())); err != nil { app.Logger.Error("Failed to remove expired temp file", "path", entry.Name(), "err", err) } @@ -76,13 +184,13 @@ func (app *App) CleanTemp(path string) { } func CalculateRetention(fileSize, maxMB int64) time.Duration { - ratio := math.Max(0, math.Min(1, float64(fileSize)/float64(maxMB*bytesInMB))) + ratio := math.Max(0, math.Min(1, float64(fileSize)/float64(maxMB*MegaByte))) invRatio := 1.0 - ratio - retention := float64(maxRetention) * (invRatio * invRatio * invRatio) + retention := float64(MaxRetention) * (invRatio * invRatio * invRatio) - if retention < float64(minRetention) { - return minRetention + if retention < float64(MinRetention) { + return MinRetention } return time.Duration(retention) diff --git a/internal/app/upload.go b/internal/app/upload.go new file mode 100644 index 0000000..91be669 --- /dev/null +++ b/internal/app/upload.go @@ -0,0 +1,205 @@ +package app + +import ( + "errors" + "io" + "net/http" + "os" + "path/filepath" + "regexp" + "strconv" + + "github.com/skidoodle/safebin/internal/crypto" +) + +var reUploadID = regexp.MustCompile(`^[a-zA-Z0-9]{10,50}$`) + +func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request) { + limit := (app.Conf.MaxMB * MegaByte) + MegaByte + request.Body = http.MaxBytesReader(writer, request.Body, limit) + + file, header, err := request.FormFile("file") + + if err != nil { + if err.Error() == "http: request body too large" { + app.SendError(writer, request, http.StatusRequestEntityTooLarge) + return + } + + app.SendError(writer, request, http.StatusBadRequest) + return + } + + defer func() { + if closeErr := file.Close(); closeErr != nil { + app.Logger.Error("Failed to close upload file", "err", closeErr) + } + }() + + tmp, err := os.CreateTemp(filepath.Join(app.Conf.StorageDir, "tmp"), "up_*") + + if err != nil { + app.Logger.Error("Failed to create temp file", "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + return + } + + tmpPath := tmp.Name() + + defer func() { + if removeErr := os.Remove(tmpPath); removeErr != nil && !os.IsNotExist(removeErr) { + app.Logger.Error("Failed to remove temp file", "err", removeErr) + } + }() + + defer func() { + if closeErr := tmp.Close(); closeErr != nil { + app.Logger.Error("Failed to close temp file", "err", closeErr) + } + }() + + if _, err := io.Copy(tmp, file); err != nil { + app.Logger.Error("Failed to write temp file", "err", err) + app.SendError(writer, request, http.StatusRequestEntityTooLarge) + return + } + + app.FinalizeFile(writer, request, tmp, header.Filename) +} + +func (app *App) HandleChunk(writer http.ResponseWriter, request *http.Request) { + request.Body = http.MaxBytesReader(writer, request.Body, MaxRequestOverhead) + + uid := request.FormValue("upload_id") + + idx, err := strconv.Atoi(request.FormValue("index")) + if err != nil { + app.SendError(writer, request, http.StatusBadRequest) + return + } + + maxChunks := int((app.Conf.MaxMB*MegaByte)/UploadChunkSize) + ChunkSafetyMargin + + if !reUploadID.MatchString(uid) || idx > maxChunks || idx < 0 { + app.SendError(writer, request, http.StatusBadRequest) + return + } + + file, _, err := request.FormFile("chunk") + + if err != nil { + if err.Error() == "http: request body too large" { + app.SendError(writer, request, http.StatusRequestEntityTooLarge) + return + } + + app.SendError(writer, request, http.StatusBadRequest) + return + } + + defer func() { + if closeErr := file.Close(); closeErr != nil { + app.Logger.Error("Failed to close chunk file", "err", closeErr) + } + }() + + if err := app.saveChunk(uid, idx, file); err != nil { + app.Logger.Error("Failed to save chunk", "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + } +} + +func (app *App) HandleFinish(writer http.ResponseWriter, request *http.Request) { + uid := request.FormValue("upload_id") + + total, err := strconv.Atoi(request.FormValue("total")) + if err != nil { + app.SendError(writer, request, http.StatusBadRequest) + return + } + + maxChunks := int((app.Conf.MaxMB*MegaByte)/UploadChunkSize) + ChunkSafetyMargin + + if !reUploadID.MatchString(uid) || total > maxChunks || total <= 0 { + app.SendError(writer, request, http.StatusBadRequest) + return + } + + mergedPath, err := app.mergeChunks(uid, total) + + if err != nil { + app.Logger.Error("Merge failed", "err", err) + + if errors.Is(err, io.ErrShortWrite) { + app.SendError(writer, request, http.StatusRequestEntityTooLarge) + } else { + app.SendError(writer, request, http.StatusInternalServerError) + } + return + } + + defer func() { + if removeErr := os.Remove(mergedPath); removeErr != nil && !os.IsNotExist(removeErr) { + app.Logger.Error("Failed to remove merged file", "err", removeErr) + } + }() + + mergedRead, err := os.Open(mergedPath) + + if err != nil { + app.Logger.Error("Failed to open merged file", "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + return + } + + defer func() { + if closeErr := mergedRead.Close(); closeErr != nil { + app.Logger.Error("Failed to close merged reader", "err", closeErr) + } + }() + + app.FinalizeFile(writer, request, mergedRead, request.FormValue("filename")) + + if err := os.RemoveAll(filepath.Join(app.Conf.StorageDir, "tmp", uid)); err != nil { + app.Logger.Error("Failed to remove chunk dir", "err", err) + } +} + +func (app *App) FinalizeFile(writer http.ResponseWriter, request *http.Request, src *os.File, filename string) { + if _, err := src.Seek(0, 0); err != nil { + app.Logger.Error("Seek failed", "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + return + } + + key, err := crypto.DeriveKey(src) + + if err != nil { + app.Logger.Error("Key derivation failed", "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + return + } + + ext := filepath.Ext(filename) + id := crypto.GetID(key, ext) + finalPath := filepath.Join(app.Conf.StorageDir, id) + + if _, err := os.Stat(finalPath); err == nil { + app.RespondWithLink(writer, request, key, filename) + return + } + + if _, err := src.Seek(0, 0); err != nil { + app.Logger.Error("Seek failed", "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + return + } + + if err := app.encryptAndSave(src, key, finalPath); err != nil { + app.Logger.Error("Encryption failed", "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + return + } + + app.RespondWithLink(writer, request, key, filename) +} diff --git a/main.go b/main.go index 90f22b9..ed1fdb3 100644 --- a/main.go +++ b/main.go @@ -10,17 +10,10 @@ import ( "os/signal" "path/filepath" "syscall" - "time" "github.com/skidoodle/safebin/internal/app" ) -const ( - permUserRWX = 0o700 - serverTimeout = 10 * time.Minute - shutdownTimeout = 10 * time.Second -) - func main() { cfg := app.LoadConfig() logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ @@ -34,7 +27,7 @@ func main() { ) tmpDir := filepath.Join(cfg.StorageDir, "tmp") - if err := os.MkdirAll(tmpDir, permUserRWX); err != nil { + if err := os.MkdirAll(tmpDir, app.PermUserRWX); err != nil { logger.Error("Failed to initialize storage directory", "err", err) os.Exit(1) } @@ -53,9 +46,9 @@ func main() { srv := &http.Server{ Addr: cfg.Addr, Handler: application.Routes(), - ReadTimeout: serverTimeout, - WriteTimeout: serverTimeout, - IdleTimeout: serverTimeout, + ReadTimeout: app.ServerTimeout, + WriteTimeout: app.ServerTimeout, + IdleTimeout: app.ServerTimeout, } go func() { @@ -70,7 +63,7 @@ func main() { <-ctx.Done() application.Logger.Info("Shutting down gracefully...") - shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) + shutdownCtx, cancel := context.WithTimeout(context.Background(), app.ShutdownTimeout) defer cancel() if err := srv.Shutdown(shutdownCtx); err != nil {