diff --git a/internal/app/storage.go b/internal/app/storage.go index f707787..f1be775 100644 --- a/internal/app/storage.go +++ b/internal/app/storage.go @@ -55,48 +55,26 @@ func (app *App) saveChunk(uid string, idx int, src io.Reader) error { return nil } -func (app *App) mergeChunks(uid string, total int) (string, error) { - tmpPath := filepath.Join(app.Conf.StorageDir, TempDirName, "m_"+uid) +func (app *App) openChunkFiles(uid string, total int) ([]*os.File, error) { + files := make([]*os.File, 0, total) - merged, err := os.Create(tmpPath) - if err != nil { - return "", fmt.Errorf("create merge file: %w", err) - } - - defer func() { - if closeErr := merged.Close(); closeErr != nil { - app.Logger.Error("Failed to close merged file", "err", closeErr) + closeAll := func() { + for _, f := range files { + _ = f.Close() } - }() - - limit := app.Conf.MaxMB * MegaByte - var written int64 + } for i := range total { partPath := filepath.Join(app.Conf.StorageDir, TempDirName, uid, strconv.Itoa(i)) - - part, err := os.Open(partPath) + f, err := os.Open(partPath) if err != nil { - return "", fmt.Errorf("open chunk %d: %w", i, err) - } - - n, err := io.Copy(merged, part) - - if closeErr := part.Close(); closeErr != nil { - app.Logger.Error("Failed to close chunk part", "err", closeErr) - } - - if err != nil { - return "", fmt.Errorf("append chunk %d: %w", i, err) - } - - written += n - if written > limit { - return "", io.ErrShortWrite + closeAll() + return nil, fmt.Errorf("open chunk %d: %w", i, err) } + files = append(files, f) } - return tmpPath, nil + return files, nil } func (app *App) encryptAndSave(src io.Reader, key []byte, finalPath string) error { diff --git a/internal/app/storage_test.go b/internal/app/storage_test.go index 24463c4..3600501 100644 --- a/internal/app/storage_test.go +++ b/internal/app/storage_test.go @@ -10,46 +10,6 @@ import ( "go.etcd.io/bbolt" ) -func TestCleanup_AbandonedMerge(t *testing.T) { - tmpDir := t.TempDir() - tmpStorage := filepath.Join(tmpDir, TempDirName) - if err := os.MkdirAll(tmpStorage, 0700); err != nil { - t.Fatalf("MkdirAll failed: %v", err) - } - - db, err := InitDB(tmpDir) - if err != nil { - t.Fatalf("InitDB failed: %v", err) - } - defer func() { - if err := db.Close(); err != nil { - t.Errorf("Failed to close DB: %v", err) - } - }() - - app := &App{ - Conf: Config{StorageDir: tmpDir}, - Logger: discardLogger(), - DB: db, - } - - abandonedFile := filepath.Join(tmpStorage, "m_abandoned_upload_id") - if err := os.WriteFile(abandonedFile, []byte("partial data"), 0600); err != nil { - t.Fatal(err) - } - - oldTime := time.Now().Add(-TempExpiry - time.Hour) - if err := os.Chtimes(abandonedFile, oldTime, oldTime); err != nil { - t.Fatal(err) - } - - app.CleanTemp(tmpStorage) - - if _, err := os.Stat(abandonedFile); !os.IsNotExist(err) { - t.Error("Cleanup failed to remove abandoned merge file from crashed session") - } -} - func TestCleanup_AbandonedChunks(t *testing.T) { tmpDir := t.TempDir() tmpStorage := filepath.Join(tmpDir, TempDirName) diff --git a/internal/app/upload.go b/internal/app/upload.go index bc710a7..dce2dad 100644 --- a/internal/app/upload.go +++ b/internal/app/upload.go @@ -1,7 +1,6 @@ package app import ( - "errors" "io" "net/http" "os" @@ -64,7 +63,26 @@ func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request) return } - app.FinalizeFile(writer, request, tmp, header.Filename) + if _, err := tmp.Seek(0, 0); err != nil { + app.Logger.Error("Seek failed", "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + return + } + + key, err := crypto.DeriveKey(tmp) + if err != nil { + app.Logger.Error("Key derivation failed", "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + return + } + + if _, err := tmp.Seek(0, 0); err != nil { + app.Logger.Error("Seek failed", "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + return + } + + app.finalizeUpload(writer, request, tmp, key, header.Filename) } func (app *App) HandleChunk(writer http.ResponseWriter, request *http.Request) { @@ -125,61 +143,48 @@ func (app *App) HandleFinish(writer http.ResponseWriter, request *http.Request) return } - mergedPath, err := app.mergeChunks(uid, total) - + files, err := app.openChunkFiles(uid, total) if err != nil { - app.Logger.Error("Merge failed", "err", err) - - if errors.Is(err, io.ErrShortWrite) { - app.SendError(writer, request, http.StatusRequestEntityTooLarge) - } else { - app.SendError(writer, request, http.StatusInternalServerError) - } - return - } - - defer func() { - if removeErr := os.Remove(mergedPath); removeErr != nil && !os.IsNotExist(removeErr) { - app.Logger.Error("Failed to remove merged file", "err", removeErr) - } - }() - - mergedRead, err := os.Open(mergedPath) - - if err != nil { - app.Logger.Error("Failed to open merged file", "err", err) + app.Logger.Error("Failed to open chunks", "err", err) app.SendError(writer, request, http.StatusInternalServerError) return } defer func() { - if closeErr := mergedRead.Close(); closeErr != nil { - app.Logger.Error("Failed to close merged reader", "err", closeErr) + for _, f := range files { + _ = f.Close() + } + if err := os.RemoveAll(filepath.Join(app.Conf.StorageDir, TempDirName, uid)); err != nil { + app.Logger.Error("Failed to remove chunk dir", "err", err) } }() - app.FinalizeFile(writer, request, mergedRead, request.FormValue("filename")) - - if err := os.RemoveAll(filepath.Join(app.Conf.StorageDir, TempDirName, uid)); err != nil { - app.Logger.Error("Failed to remove chunk dir", "err", err) - } -} - -func (app *App) FinalizeFile(writer http.ResponseWriter, request *http.Request, src *os.File, filename string) { - if _, err := src.Seek(0, 0); err != nil { - app.Logger.Error("Seek failed", "err", err) - app.SendError(writer, request, http.StatusInternalServerError) - return + readers := make([]io.Reader, len(files)) + for i, f := range files { + readers[i] = f } - key, err := crypto.DeriveKey(src) - + key, err := crypto.DeriveKey(io.MultiReader(readers...)) if err != nil { app.Logger.Error("Key derivation failed", "err", err) app.SendError(writer, request, http.StatusInternalServerError) return } + for _, f := range files { + if _, err := f.Seek(0, 0); err != nil { + app.Logger.Error("Failed to reset chunk", "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + return + } + } + + multiSrc := io.MultiReader(readers...) + + app.finalizeUpload(writer, request, multiSrc, key, request.FormValue("filename")) +} + +func (app *App) finalizeUpload(writer http.ResponseWriter, request *http.Request, src io.Reader, key []byte, filename string) { ext := filepath.Ext(filename) id := crypto.GetID(key, ext) finalPath := filepath.Join(app.Conf.StorageDir, id) @@ -192,12 +197,6 @@ func (app *App) FinalizeFile(writer http.ResponseWriter, request *http.Request, return } - if _, err := src.Seek(0, 0); err != nil { - app.Logger.Error("Seek failed", "err", err) - app.SendError(writer, request, http.StatusInternalServerError) - return - } - if err := app.encryptAndSave(src, key, finalPath); err != nil { app.Logger.Error("Encryption failed", "err", err) app.SendError(writer, request, http.StatusInternalServerError)