diff --git a/internal/app/storage.go b/internal/app/storage.go index 018a483..732266d 100644 --- a/internal/app/storage.go +++ b/internal/app/storage.go @@ -70,56 +70,93 @@ func (app *App) saveChunk(uid string, idx int, src io.Reader) error { return nil } -func (app *App) getChunkDecryptors(uid string, total int) ([]io.ReadSeeker, func(), error) { - files := make([]*os.File, 0, total) - decryptors := make([]io.ReadSeeker, 0, total) - - closeAll := func() { - for _, f := range files { - _ = f.Close() - } +func (app *App) openChunkDecryptor(uid string, idx int) (io.ReadCloser, error) { + partPath := filepath.Join(app.Conf.StorageDir, TempDirName, uid, strconv.Itoa(idx)) + f, err := os.Open(partPath) + if err != nil { + return nil, fmt.Errorf("open chunk %d: %w", idx, err) } - for i := range total { - partPath := filepath.Join(app.Conf.StorageDir, TempDirName, uid, strconv.Itoa(i)) - f, err := os.Open(partPath) - if err != nil { - closeAll() - return nil, nil, fmt.Errorf("open chunk %d: %w", i, err) - } - files = append(files, f) - - key := make([]byte, crypto.KeySize) - if _, err := io.ReadFull(f, key); err != nil { - closeAll() - return nil, nil, fmt.Errorf("read chunk key %d: %w", i, err) - } - - info, err := f.Stat() - if err != nil { - closeAll() - return nil, nil, fmt.Errorf("stat chunk %d: %w", i, err) - } - - bodySize := info.Size() - int64(crypto.KeySize) - if bodySize < 0 { - closeAll() - return nil, nil, fmt.Errorf("invalid chunk size %d", i) - } - - bodyReader := io.NewSectionReader(f, int64(crypto.KeySize), bodySize) - - streamer, err := crypto.NewGCMStreamer(key) - if err != nil { - closeAll() - return nil, nil, fmt.Errorf("create streamer %d: %w", i, err) - } - - decryptor := crypto.NewDecryptor(bodyReader, streamer.AEAD, bodySize) - decryptors = append(decryptors, decryptor) + key := make([]byte, crypto.KeySize) + if _, err := io.ReadFull(f, key); err != nil { + _ = f.Close() + return nil, fmt.Errorf("read chunk key %d: %w", idx, err) } - return decryptors, closeAll, nil + info, err := f.Stat() + if err != nil { + _ = f.Close() + return nil, fmt.Errorf("stat chunk %d: %w", idx, err) + } + + bodySize := info.Size() - int64(crypto.KeySize) + if bodySize < 0 { + _ = f.Close() + return nil, fmt.Errorf("invalid chunk size %d", idx) + } + + bodyReader := io.NewSectionReader(f, int64(crypto.KeySize), bodySize) + + streamer, err := crypto.NewGCMStreamer(key) + if err != nil { + _ = f.Close() + return nil, fmt.Errorf("create streamer %d: %w", idx, err) + } + + decryptor := crypto.NewDecryptor(bodyReader, streamer.AEAD, bodySize) + + return &chunkReadCloser{Decryptor: decryptor, f: f}, nil +} + +type chunkReadCloser struct { + *crypto.Decryptor + f *os.File +} + +func (c *chunkReadCloser) Close() error { + return c.f.Close() +} + +type SequentialChunkReader struct { + app *App + uid string + total int + currentIdx int + currentRC io.ReadCloser +} + +func (s *SequentialChunkReader) Read(p []byte) (n int, err error) { + if s.currentRC == nil { + if s.currentIdx >= s.total { + return 0, io.EOF + } + rc, err := s.app.openChunkDecryptor(s.uid, s.currentIdx) + if err != nil { + return 0, err + } + s.currentRC = rc + } + + n, err = s.currentRC.Read(p) + if err == io.EOF { + _ = s.currentRC.Close() + s.currentRC = nil + s.currentIdx++ + + if n > 0 { + return n, nil + } + return s.Read(p) + } + + return n, err +} + +func (s *SequentialChunkReader) Close() error { + if s.currentRC != nil { + return s.currentRC.Close() + } + return nil } func (app *App) encryptAndSave(src io.Reader, key []byte, finalPath string) error { diff --git a/internal/app/storage_test.go b/internal/app/storage_test.go index 5906458..dde058d 100644 --- a/internal/app/storage_test.go +++ b/internal/app/storage_test.go @@ -172,7 +172,7 @@ func TestSaveChunk_EncryptsData(t *testing.T) { } } -func TestGetChunkDecryptors_RestoresData(t *testing.T) { +func TestSequentialChunkReader_RestoresData(t *testing.T) { tmpDir := t.TempDir() app := &App{ Conf: Config{StorageDir: tmpDir}, @@ -190,29 +190,24 @@ func TestGetChunkDecryptors_RestoresData(t *testing.T) { t.Fatal(err) } - decryptors, closeFn, err := app.getChunkDecryptors(uid, 2) - if err != nil { - t.Fatalf("getChunkDecryptors failed: %v", err) + reader := &SequentialChunkReader{ + app: app, + uid: uid, + total: 2, } - defer closeFn() + defer func() { + if err := reader.Close(); err != nil { + t.Errorf("Failed to close reader: %v", err) + } + }() - if len(decryptors) != 2 { - t.Fatalf("Expected 2 decryptors, got %d", len(decryptors)) + restored, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("ReadAll failed: %v", err) } - buf1, err := io.ReadAll(decryptors[0]) - if err != nil { - t.Fatalf("Failed to read decryptor 1: %v", err) - } - if !bytes.Equal(buf1, data1) { - t.Errorf("Chunk 1 mismatch. Want %s, got %s", data1, buf1) - } - - buf2, err := io.ReadAll(decryptors[1]) - if err != nil { - t.Fatalf("Failed to read decryptor 2: %v", err) - } - if !bytes.Equal(buf2, data2) { - t.Errorf("Chunk 2 mismatch. Want %s, got %s", data2, buf2) + expected := append(data1, data2...) + if !bytes.Equal(restored, expected) { + t.Errorf("Restored data mismatch.\nWant: %s\nGot: %s", expected, restored) } } diff --git a/internal/app/upload.go b/internal/app/upload.go index 6beb196..5460c34 100644 --- a/internal/app/upload.go +++ b/internal/app/upload.go @@ -3,12 +3,14 @@ package app import ( "crypto/rand" "crypto/sha256" + "errors" "io" "net/http" "os" "path/filepath" "regexp" "strconv" + "strings" "github.com/skidoodle/safebin/internal/crypto" ) @@ -19,20 +21,36 @@ func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request) limit := (app.Conf.MaxMB * MegaByte) + MegaByte request.Body = http.MaxBytesReader(writer, request.Body, limit) - file, header, err := request.FormFile("file") + mr, err := request.MultipartReader() if err != nil { - if err.Error() == "http: request body too large" { - app.SendError(writer, request, http.StatusRequestEntityTooLarge) - return - } app.SendError(writer, request, http.StatusBadRequest) return } - defer func() { - if closeErr := file.Close(); closeErr != nil { - app.Logger.Error("Failed to close upload file", "err", closeErr) + + var filename string + var partReader io.Reader + + for { + part, err := mr.NextPart() + if err == io.EOF { + break } - }() + if err != nil { + app.SendError(writer, request, http.StatusBadRequest) + return + } + + if part.FormName() == "file" { + filename = part.FileName() + partReader = part + break + } + } + + if partReader == nil { + app.SendError(writer, request, http.StatusBadRequest) + return + } tmp, err := os.CreateTemp(filepath.Join(app.Conf.StorageDir, TempDirName), "up_*") if err != nil { @@ -58,11 +76,10 @@ func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request) pr, pw := io.Pipe() hasher := sha256.New() - errChan := make(chan error, 1) go func() { - _, err := io.Copy(io.MultiWriter(hasher, pw), file) + _, err := io.Copy(io.MultiWriter(hasher, pw), partReader) _ = pw.CloseWithError(err) errChan <- err }() @@ -87,8 +104,12 @@ func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request) } if err := <-errChan; err != nil { - app.Logger.Error("Failed to read/hash upload", "err", err) - app.SendError(writer, request, http.StatusRequestEntityTooLarge) + if errors.Is(err, http.ErrMissingBoundary) || strings.Contains(err.Error(), "request body too large") { + app.SendError(writer, request, http.StatusRequestEntityTooLarge) + } else { + app.Logger.Error("Failed to read/hash upload", "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + } return } @@ -103,7 +124,7 @@ func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request) info, _ := tmp.Stat() decryptor := crypto.NewDecryptor(tmp, streamer.AEAD, info.Size()) - app.finalizeUpload(writer, request, decryptor, convergentKey, header.Filename) + app.finalizeUpload(writer, request, decryptor, convergentKey, filename) } func (app *App) HandleChunk(writer http.ResponseWriter, request *http.Request) { @@ -159,42 +180,41 @@ func (app *App) HandleFinish(writer http.ResponseWriter, request *http.Request) return } - decryptors, closeAll, err := app.getChunkDecryptors(uid, total) - if err != nil { - app.Logger.Error("Failed to open chunks", "err", err) - app.SendError(writer, request, http.StatusInternalServerError) - return - } defer func() { - closeAll() if err := os.RemoveAll(filepath.Join(app.Conf.StorageDir, TempDirName, uid)); err != nil { app.Logger.Error("Failed to remove chunk dir", "err", err) } }() - readers := make([]io.Reader, len(decryptors)) - for i, d := range decryptors { - readers[i] = d - } - hasher := sha256.New() - if _, err := io.Copy(hasher, io.MultiReader(readers...)); err != nil { - app.Logger.Error("Failed to hash chunks", "err", err) - app.SendError(writer, request, http.StatusInternalServerError) - return + for i := range total { + rc, err := app.openChunkDecryptor(uid, i) + if err != nil { + app.Logger.Error("Failed to open chunk for hashing", "index", i, "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + return + } + if _, err := io.Copy(hasher, rc); err != nil { + _ = rc.Close() + app.Logger.Error("Failed to hash chunk", "index", i, "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + return + } + _ = rc.Close() } convergentKey := hasher.Sum(nil)[:crypto.KeySize] - for _, d := range decryptors { - if _, err := d.Seek(0, io.SeekStart); err != nil { - app.Logger.Error("Failed to reset chunk decryptor", "err", err) - app.SendError(writer, request, http.StatusInternalServerError) - return - } + multiSrc := &SequentialChunkReader{ + app: app, + uid: uid, + total: total, } - - multiSrc := io.MultiReader(readers...) + defer func() { + if err := multiSrc.Close(); err != nil { + app.Logger.Error("Failed to close sequential reader", "uid", uid, "err", err) + } + }() app.finalizeUpload(writer, request, multiSrc, convergentKey, request.FormValue("filename")) }