From 722dbaa6aac9fd8f1849da5e03145d6983aad2ef Mon Sep 17 00:00:00 2001 From: skidoodle Date: Sun, 18 Jan 2026 23:39:53 +0100 Subject: [PATCH] feat: implement encrypted chunked storage and convergent encryption Signed-off-by: skidoodle --- internal/app/server_test.go | 99 ++++++++++++++++++++++++++++++++++++ internal/app/storage.go | 55 ++++++++++++++++++-- internal/app/storage_test.go | 85 +++++++++++++++++++++++++++++++ internal/app/upload.go | 99 ++++++++++++++++++++---------------- 4 files changed, 289 insertions(+), 49 deletions(-) diff --git a/internal/app/server_test.go b/internal/app/server_test.go index 7f774df..651be2a 100644 --- a/internal/app/server_test.go +++ b/internal/app/server_test.go @@ -2,6 +2,7 @@ package app import ( "bytes" + "encoding/base64" "fmt" "io" "log/slog" @@ -12,6 +13,8 @@ import ( "path/filepath" "strings" "testing" + + "github.com/skidoodle/safebin/internal/crypto" ) func setupTestApp(t *testing.T) (*App, string) { @@ -176,6 +179,102 @@ func TestIntegration_ChunkedUpload(t *testing.T) { } } +func TestIntegration_ChunkedUpload_VerifyEncryption(t *testing.T) { + app, storageDir := setupTestApp(t) + server := httptest.NewServer(app.Routes()) + defer server.Close() + + uploadID := "securechunk123" + plaintext := []byte("This is a secret message that should be encrypted") + + uploadChunk(t, server.URL, uploadID, 0, plaintext) + + chunkPath := filepath.Join(storageDir, TempDirName, uploadID, "0") + encryptedData, err := os.ReadFile(chunkPath) + if err != nil { + t.Fatalf("Failed to read chunk file: %v", err) + } + + if bytes.Contains(encryptedData, plaintext) { + t.Fatal("Chunk file contains plaintext data!") + } + + if len(encryptedData) <= crypto.KeySize { + t.Fatalf("Chunk file too small: %d bytes", len(encryptedData)) + } + + key := encryptedData[:crypto.KeySize] + ciphertext := encryptedData[crypto.KeySize:] + + streamer, err := crypto.NewGCMStreamer(key) + if err != nil { + t.Fatalf("Failed to create streamer: %v", err) + } + + r := bytes.NewReader(ciphertext) + d := crypto.NewDecryptor(r, streamer.AEAD, int64(len(ciphertext))) + + decrypted, err := io.ReadAll(d) + if err != nil { + t.Fatalf("Failed to decrypt chunk: %v", err) + } + + if !bytes.Equal(decrypted, plaintext) { + t.Errorf("Decrypted data mismatch.\nWant: %s\nGot: %s", plaintext, decrypted) + } +} + +func TestIntegration_Upload_VerifyEncryption(t *testing.T) { + app, storageDir := setupTestApp(t) + server := httptest.NewServer(app.Routes()) + defer server.Close() + + plaintext := []byte("Sensitive Data For Full Upload") + + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + part, _ := writer.CreateFormFile("file", "secret.txt") + part.Write(plaintext) + writer.Close() + + req, _ := http.NewRequest("POST", server.URL+"/", body) + req.Header.Set("Content-Type", writer.FormDataContentType()) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + respBytes, _ := io.ReadAll(resp.Body) + slug := filepath.Base(strings.TrimSpace(string(respBytes))) + + if len(slug) < SlugLength { + t.Fatalf("Invalid slug: %s", slug) + } + keyBase64 := slug[:SlugLength] + key, _ := base64.RawURLEncoding.DecodeString(keyBase64) + ext := filepath.Ext("secret.txt") + id := crypto.GetID(key, ext) + + finalPath := filepath.Join(storageDir, id) + finalData, err := os.ReadFile(finalPath) + if err != nil { + t.Fatalf("Failed to read final file: %v", err) + } + + if bytes.Contains(finalData, plaintext) { + t.Fatal("Final file contains plaintext!") + } + + streamer, _ := crypto.NewGCMStreamer(key) + d := crypto.NewDecryptor(bytes.NewReader(finalData), streamer.AEAD, int64(len(finalData))) + decrypted, _ := io.ReadAll(d) + + if !bytes.Equal(decrypted, plaintext) { + t.Error("Final file decryption failed") + } +} + func uploadChunk(t *testing.T, baseURL, uid string, idx int, data []byte) { body := &bytes.Buffer{} writer := multipart.NewWriter(body) diff --git a/internal/app/storage.go b/internal/app/storage.go index d4ae71a..018a483 100644 --- a/internal/app/storage.go +++ b/internal/app/storage.go @@ -2,6 +2,7 @@ package app import ( "context" + "crypto/rand" "encoding/json" "fmt" "io" @@ -48,15 +49,30 @@ func (app *App) saveChunk(uid string, idx int, src io.Reader) error { } }() - if _, err := io.Copy(dest, src); err != nil { - return fmt.Errorf("copy chunk: %w", err) + key := make([]byte, crypto.KeySize) + if _, err := rand.Read(key); err != nil { + return fmt.Errorf("generate chunk key: %w", err) + } + + if _, err := dest.Write(key); err != nil { + return fmt.Errorf("write chunk key: %w", err) + } + + streamer, err := crypto.NewGCMStreamer(key) + if err != nil { + return fmt.Errorf("create streamer: %w", err) + } + + if err := streamer.EncryptStream(dest, src); err != nil { + return fmt.Errorf("encrypt chunk: %w", err) } return nil } -func (app *App) openChunkFiles(uid string, total int) ([]*os.File, error) { +func (app *App) getChunkDecryptors(uid string, total int) ([]io.ReadSeeker, func(), error) { files := make([]*os.File, 0, total) + decryptors := make([]io.ReadSeeker, 0, total) closeAll := func() { for _, f := range files { @@ -69,12 +85,41 @@ func (app *App) openChunkFiles(uid string, total int) ([]*os.File, error) { f, err := os.Open(partPath) if err != nil { closeAll() - return nil, fmt.Errorf("open chunk %d: %w", i, err) + return nil, nil, fmt.Errorf("open chunk %d: %w", i, err) } files = append(files, f) + + key := make([]byte, crypto.KeySize) + if _, err := io.ReadFull(f, key); err != nil { + closeAll() + return nil, nil, fmt.Errorf("read chunk key %d: %w", i, err) + } + + info, err := f.Stat() + if err != nil { + closeAll() + return nil, nil, fmt.Errorf("stat chunk %d: %w", i, err) + } + + bodySize := info.Size() - int64(crypto.KeySize) + if bodySize < 0 { + closeAll() + return nil, nil, fmt.Errorf("invalid chunk size %d", i) + } + + bodyReader := io.NewSectionReader(f, int64(crypto.KeySize), bodySize) + + streamer, err := crypto.NewGCMStreamer(key) + if err != nil { + closeAll() + return nil, nil, fmt.Errorf("create streamer %d: %w", i, err) + } + + decryptor := crypto.NewDecryptor(bodyReader, streamer.AEAD, bodySize) + decryptors = append(decryptors, decryptor) } - return files, nil + return decryptors, closeAll, nil } func (app *App) encryptAndSave(src io.Reader, key []byte, finalPath string) error { diff --git a/internal/app/storage_test.go b/internal/app/storage_test.go index 80c5c9e..5906458 100644 --- a/internal/app/storage_test.go +++ b/internal/app/storage_test.go @@ -1,12 +1,16 @@ package app import ( + "bytes" + "crypto/rand" "encoding/json" + "io" "os" "path/filepath" "testing" "time" + "github.com/skidoodle/safebin/internal/crypto" "go.etcd.io/bbolt" ) @@ -131,3 +135,84 @@ func TestCleanup_ExpiredStorage(t *testing.T) { t.Fatalf("DB View failed: %v", err) } } + +func TestSaveChunk_EncryptsData(t *testing.T) { + tmpDir := t.TempDir() + app := &App{ + Conf: Config{StorageDir: tmpDir}, + Logger: discardLogger(), + } + + uid := "test-encrypt-chunk" + plaintext := make([]byte, 1024) + if _, err := rand.Read(plaintext); err != nil { + t.Fatal(err) + } + + if err := app.saveChunk(uid, 0, bytes.NewReader(plaintext)); err != nil { + t.Fatalf("saveChunk failed: %v", err) + } + + path := filepath.Join(tmpDir, TempDirName, uid, "0") + fileData, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile failed: %v", err) + } + + if bytes.Equal(fileData, plaintext) { + t.Fatal("Chunk stored as plaintext!") + } + if bytes.Contains(fileData, plaintext) { + t.Fatal("Chunk contains plaintext!") + } + + expectedSize := crypto.KeySize + len(plaintext) + 16 + if len(fileData) != expectedSize { + t.Errorf("Unexpected file size. Want %d, got %d", expectedSize, len(fileData)) + } +} + +func TestGetChunkDecryptors_RestoresData(t *testing.T) { + tmpDir := t.TempDir() + app := &App{ + Conf: Config{StorageDir: tmpDir}, + Logger: discardLogger(), + } + + uid := "test-restore" + data1 := []byte("chunk one data") + data2 := []byte("chunk two data") + + if err := app.saveChunk(uid, 0, bytes.NewReader(data1)); err != nil { + t.Fatal(err) + } + if err := app.saveChunk(uid, 1, bytes.NewReader(data2)); err != nil { + t.Fatal(err) + } + + decryptors, closeFn, err := app.getChunkDecryptors(uid, 2) + if err != nil { + t.Fatalf("getChunkDecryptors failed: %v", err) + } + defer closeFn() + + if len(decryptors) != 2 { + t.Fatalf("Expected 2 decryptors, got %d", len(decryptors)) + } + + buf1, err := io.ReadAll(decryptors[0]) + if err != nil { + t.Fatalf("Failed to read decryptor 1: %v", err) + } + if !bytes.Equal(buf1, data1) { + t.Errorf("Chunk 1 mismatch. Want %s, got %s", data1, buf1) + } + + buf2, err := io.ReadAll(decryptors[1]) + if err != nil { + t.Fatalf("Failed to read decryptor 2: %v", err) + } + if !bytes.Equal(buf2, data2) { + t.Errorf("Chunk 2 mismatch. Want %s, got %s", data2, buf2) + } +} diff --git a/internal/app/upload.go b/internal/app/upload.go index dce2dad..978ee9d 100644 --- a/internal/app/upload.go +++ b/internal/app/upload.go @@ -1,6 +1,8 @@ package app import ( + "crypto/rand" + "crypto/sha256" "io" "net/http" "os" @@ -18,17 +20,14 @@ func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request) request.Body = http.MaxBytesReader(writer, request.Body, limit) file, header, err := request.FormFile("file") - if err != nil { if err.Error() == "http: request body too large" { app.SendError(writer, request, http.StatusRequestEntityTooLarge) return } - app.SendError(writer, request, http.StatusBadRequest) return } - defer func() { if closeErr := file.Close(); closeErr != nil { app.Logger.Error("Failed to close upload file", "err", closeErr) @@ -36,45 +35,60 @@ func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request) }() tmp, err := os.CreateTemp(filepath.Join(app.Conf.StorageDir, TempDirName), "up_*") - if err != nil { app.Logger.Error("Failed to create temp file", "err", err) app.SendError(writer, request, http.StatusInternalServerError) return } - tmpPath := tmp.Name() defer func() { + _ = tmp.Close() if removeErr := os.Remove(tmpPath); removeErr != nil && !os.IsNotExist(removeErr) { app.Logger.Error("Failed to remove temp file", "err", removeErr) } }() - defer func() { - if closeErr := tmp.Close(); closeErr != nil { - app.Logger.Error("Failed to close temp file", "err", closeErr) - } + ephemeralKey := make([]byte, crypto.KeySize) + if _, err := rand.Read(ephemeralKey); err != nil { + app.Logger.Error("Failed to generate ephemeral key", "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + return + } + + pr, pw := io.Pipe() + hasher := sha256.New() + + errChan := make(chan error, 1) + + go func() { + _, err := io.Copy(io.MultiWriter(hasher, pw), file) + _ = pw.CloseWithError(err) + errChan <- err }() - if _, err := io.Copy(tmp, file); err != nil { - app.Logger.Error("Failed to write temp file", "err", err) + defer pr.Close() + + streamer, err := crypto.NewGCMStreamer(ephemeralKey) + if err != nil { + app.Logger.Error("Failed to create streamer", "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + return + } + + if err := streamer.EncryptStream(tmp, pr); err != nil { + app.Logger.Error("Failed to encrypt stream", "err", err) + app.SendError(writer, request, http.StatusInternalServerError) + return + } + + if err := <-errChan; err != nil { + app.Logger.Error("Failed to read/hash upload", "err", err) app.SendError(writer, request, http.StatusRequestEntityTooLarge) return } - if _, err := tmp.Seek(0, 0); err != nil { - app.Logger.Error("Seek failed", "err", err) - app.SendError(writer, request, http.StatusInternalServerError) - return - } - - key, err := crypto.DeriveKey(tmp) - if err != nil { - app.Logger.Error("Key derivation failed", "err", err) - app.SendError(writer, request, http.StatusInternalServerError) - return - } + convergentKey := hasher.Sum(nil)[:crypto.KeySize] if _, err := tmp.Seek(0, 0); err != nil { app.Logger.Error("Seek failed", "err", err) @@ -82,14 +96,16 @@ func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request) return } - app.finalizeUpload(writer, request, tmp, key, header.Filename) + info, _ := tmp.Stat() + decryptor := crypto.NewDecryptor(tmp, streamer.AEAD, info.Size()) + + app.finalizeUpload(writer, request, decryptor, convergentKey, header.Filename) } func (app *App) HandleChunk(writer http.ResponseWriter, request *http.Request) { request.Body = http.MaxBytesReader(writer, request.Body, MaxRequestOverhead) uid := request.FormValue("upload_id") - idx, err := strconv.Atoi(request.FormValue("index")) if err != nil { app.SendError(writer, request, http.StatusBadRequest) @@ -104,17 +120,14 @@ func (app *App) HandleChunk(writer http.ResponseWriter, request *http.Request) { } file, _, err := request.FormFile("chunk") - if err != nil { if err.Error() == "http: request body too large" { app.SendError(writer, request, http.StatusRequestEntityTooLarge) return } - app.SendError(writer, request, http.StatusBadRequest) return } - defer func() { if closeErr := file.Close(); closeErr != nil { app.Logger.Error("Failed to close chunk file", "err", closeErr) @@ -129,7 +142,6 @@ func (app *App) HandleChunk(writer http.ResponseWriter, request *http.Request) { func (app *App) HandleFinish(writer http.ResponseWriter, request *http.Request) { uid := request.FormValue("upload_id") - total, err := strconv.Atoi(request.FormValue("total")) if err != nil { app.SendError(writer, request, http.StatusBadRequest) @@ -143,37 +155,36 @@ func (app *App) HandleFinish(writer http.ResponseWriter, request *http.Request) return } - files, err := app.openChunkFiles(uid, total) + decryptors, closeAll, err := app.getChunkDecryptors(uid, total) if err != nil { app.Logger.Error("Failed to open chunks", "err", err) app.SendError(writer, request, http.StatusInternalServerError) return } - defer func() { - for _, f := range files { - _ = f.Close() - } + closeAll() if err := os.RemoveAll(filepath.Join(app.Conf.StorageDir, TempDirName, uid)); err != nil { app.Logger.Error("Failed to remove chunk dir", "err", err) } }() - readers := make([]io.Reader, len(files)) - for i, f := range files { - readers[i] = f + readers := make([]io.Reader, len(decryptors)) + for i, d := range decryptors { + readers[i] = d } - key, err := crypto.DeriveKey(io.MultiReader(readers...)) - if err != nil { - app.Logger.Error("Key derivation failed", "err", err) + hasher := sha256.New() + if _, err := io.Copy(hasher, io.MultiReader(readers...)); err != nil { + app.Logger.Error("Failed to hash chunks", "err", err) app.SendError(writer, request, http.StatusInternalServerError) return } - for _, f := range files { - if _, err := f.Seek(0, 0); err != nil { - app.Logger.Error("Failed to reset chunk", "err", err) + convergentKey := hasher.Sum(nil)[:crypto.KeySize] + + for _, d := range decryptors { + if _, err := d.Seek(0, io.SeekStart); err != nil { + app.Logger.Error("Failed to reset chunk decryptor", "err", err) app.SendError(writer, request, http.StatusInternalServerError) return } @@ -181,7 +192,7 @@ func (app *App) HandleFinish(writer http.ResponseWriter, request *http.Request) multiSrc := io.MultiReader(readers...) - app.finalizeUpload(writer, request, multiSrc, key, request.FormValue("filename")) + app.finalizeUpload(writer, request, multiSrc, convergentKey, request.FormValue("filename")) } func (app *App) finalizeUpload(writer http.ResponseWriter, request *http.Request, src io.Reader, key []byte, filename string) {