feat: implement encrypted chunked storage and convergent encryption

Signed-off-by: skidoodle <contact@albert.lol>
This commit is contained in:
2026-01-18 23:39:53 +01:00
parent 2d6a3ab216
commit 722dbaa6aa
4 changed files with 289 additions and 49 deletions
+99
View File
@@ -2,6 +2,7 @@ package app
import ( import (
"bytes" "bytes"
"encoding/base64"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
@@ -12,6 +13,8 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"testing" "testing"
"github.com/skidoodle/safebin/internal/crypto"
) )
func setupTestApp(t *testing.T) (*App, string) { func setupTestApp(t *testing.T) (*App, string) {
@@ -176,6 +179,102 @@ func TestIntegration_ChunkedUpload(t *testing.T) {
} }
} }
func TestIntegration_ChunkedUpload_VerifyEncryption(t *testing.T) {
app, storageDir := setupTestApp(t)
server := httptest.NewServer(app.Routes())
defer server.Close()
uploadID := "securechunk123"
plaintext := []byte("This is a secret message that should be encrypted")
uploadChunk(t, server.URL, uploadID, 0, plaintext)
chunkPath := filepath.Join(storageDir, TempDirName, uploadID, "0")
encryptedData, err := os.ReadFile(chunkPath)
if err != nil {
t.Fatalf("Failed to read chunk file: %v", err)
}
if bytes.Contains(encryptedData, plaintext) {
t.Fatal("Chunk file contains plaintext data!")
}
if len(encryptedData) <= crypto.KeySize {
t.Fatalf("Chunk file too small: %d bytes", len(encryptedData))
}
key := encryptedData[:crypto.KeySize]
ciphertext := encryptedData[crypto.KeySize:]
streamer, err := crypto.NewGCMStreamer(key)
if err != nil {
t.Fatalf("Failed to create streamer: %v", err)
}
r := bytes.NewReader(ciphertext)
d := crypto.NewDecryptor(r, streamer.AEAD, int64(len(ciphertext)))
decrypted, err := io.ReadAll(d)
if err != nil {
t.Fatalf("Failed to decrypt chunk: %v", err)
}
if !bytes.Equal(decrypted, plaintext) {
t.Errorf("Decrypted data mismatch.\nWant: %s\nGot: %s", plaintext, decrypted)
}
}
func TestIntegration_Upload_VerifyEncryption(t *testing.T) {
app, storageDir := setupTestApp(t)
server := httptest.NewServer(app.Routes())
defer server.Close()
plaintext := []byte("Sensitive Data For Full Upload")
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
part, _ := writer.CreateFormFile("file", "secret.txt")
part.Write(plaintext)
writer.Close()
req, _ := http.NewRequest("POST", server.URL+"/", body)
req.Header.Set("Content-Type", writer.FormDataContentType())
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
respBytes, _ := io.ReadAll(resp.Body)
slug := filepath.Base(strings.TrimSpace(string(respBytes)))
if len(slug) < SlugLength {
t.Fatalf("Invalid slug: %s", slug)
}
keyBase64 := slug[:SlugLength]
key, _ := base64.RawURLEncoding.DecodeString(keyBase64)
ext := filepath.Ext("secret.txt")
id := crypto.GetID(key, ext)
finalPath := filepath.Join(storageDir, id)
finalData, err := os.ReadFile(finalPath)
if err != nil {
t.Fatalf("Failed to read final file: %v", err)
}
if bytes.Contains(finalData, plaintext) {
t.Fatal("Final file contains plaintext!")
}
streamer, _ := crypto.NewGCMStreamer(key)
d := crypto.NewDecryptor(bytes.NewReader(finalData), streamer.AEAD, int64(len(finalData)))
decrypted, _ := io.ReadAll(d)
if !bytes.Equal(decrypted, plaintext) {
t.Error("Final file decryption failed")
}
}
func uploadChunk(t *testing.T, baseURL, uid string, idx int, data []byte) { func uploadChunk(t *testing.T, baseURL, uid string, idx int, data []byte) {
body := &bytes.Buffer{} body := &bytes.Buffer{}
writer := multipart.NewWriter(body) writer := multipart.NewWriter(body)
+50 -5
View File
@@ -2,6 +2,7 @@ package app
import ( import (
"context" "context"
"crypto/rand"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@@ -48,15 +49,30 @@ func (app *App) saveChunk(uid string, idx int, src io.Reader) error {
} }
}() }()
if _, err := io.Copy(dest, src); err != nil { key := make([]byte, crypto.KeySize)
return fmt.Errorf("copy chunk: %w", err) if _, err := rand.Read(key); err != nil {
return fmt.Errorf("generate chunk key: %w", err)
}
if _, err := dest.Write(key); err != nil {
return fmt.Errorf("write chunk key: %w", err)
}
streamer, err := crypto.NewGCMStreamer(key)
if err != nil {
return fmt.Errorf("create streamer: %w", err)
}
if err := streamer.EncryptStream(dest, src); err != nil {
return fmt.Errorf("encrypt chunk: %w", err)
} }
return nil return nil
} }
func (app *App) openChunkFiles(uid string, total int) ([]*os.File, error) { func (app *App) getChunkDecryptors(uid string, total int) ([]io.ReadSeeker, func(), error) {
files := make([]*os.File, 0, total) files := make([]*os.File, 0, total)
decryptors := make([]io.ReadSeeker, 0, total)
closeAll := func() { closeAll := func() {
for _, f := range files { for _, f := range files {
@@ -69,12 +85,41 @@ func (app *App) openChunkFiles(uid string, total int) ([]*os.File, error) {
f, err := os.Open(partPath) f, err := os.Open(partPath)
if err != nil { if err != nil {
closeAll() closeAll()
return nil, fmt.Errorf("open chunk %d: %w", i, err) return nil, nil, fmt.Errorf("open chunk %d: %w", i, err)
} }
files = append(files, f) files = append(files, f)
key := make([]byte, crypto.KeySize)
if _, err := io.ReadFull(f, key); err != nil {
closeAll()
return nil, nil, fmt.Errorf("read chunk key %d: %w", i, err)
} }
return files, nil info, err := f.Stat()
if err != nil {
closeAll()
return nil, nil, fmt.Errorf("stat chunk %d: %w", i, err)
}
bodySize := info.Size() - int64(crypto.KeySize)
if bodySize < 0 {
closeAll()
return nil, nil, fmt.Errorf("invalid chunk size %d", i)
}
bodyReader := io.NewSectionReader(f, int64(crypto.KeySize), bodySize)
streamer, err := crypto.NewGCMStreamer(key)
if err != nil {
closeAll()
return nil, nil, fmt.Errorf("create streamer %d: %w", i, err)
}
decryptor := crypto.NewDecryptor(bodyReader, streamer.AEAD, bodySize)
decryptors = append(decryptors, decryptor)
}
return decryptors, closeAll, nil
} }
func (app *App) encryptAndSave(src io.Reader, key []byte, finalPath string) error { func (app *App) encryptAndSave(src io.Reader, key []byte, finalPath string) error {
+85
View File
@@ -1,12 +1,16 @@
package app package app
import ( import (
"bytes"
"crypto/rand"
"encoding/json" "encoding/json"
"io"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"time" "time"
"github.com/skidoodle/safebin/internal/crypto"
"go.etcd.io/bbolt" "go.etcd.io/bbolt"
) )
@@ -131,3 +135,84 @@ func TestCleanup_ExpiredStorage(t *testing.T) {
t.Fatalf("DB View failed: %v", err) t.Fatalf("DB View failed: %v", err)
} }
} }
func TestSaveChunk_EncryptsData(t *testing.T) {
tmpDir := t.TempDir()
app := &App{
Conf: Config{StorageDir: tmpDir},
Logger: discardLogger(),
}
uid := "test-encrypt-chunk"
plaintext := make([]byte, 1024)
if _, err := rand.Read(plaintext); err != nil {
t.Fatal(err)
}
if err := app.saveChunk(uid, 0, bytes.NewReader(plaintext)); err != nil {
t.Fatalf("saveChunk failed: %v", err)
}
path := filepath.Join(tmpDir, TempDirName, uid, "0")
fileData, err := os.ReadFile(path)
if err != nil {
t.Fatalf("ReadFile failed: %v", err)
}
if bytes.Equal(fileData, plaintext) {
t.Fatal("Chunk stored as plaintext!")
}
if bytes.Contains(fileData, plaintext) {
t.Fatal("Chunk contains plaintext!")
}
expectedSize := crypto.KeySize + len(plaintext) + 16
if len(fileData) != expectedSize {
t.Errorf("Unexpected file size. Want %d, got %d", expectedSize, len(fileData))
}
}
func TestGetChunkDecryptors_RestoresData(t *testing.T) {
tmpDir := t.TempDir()
app := &App{
Conf: Config{StorageDir: tmpDir},
Logger: discardLogger(),
}
uid := "test-restore"
data1 := []byte("chunk one data")
data2 := []byte("chunk two data")
if err := app.saveChunk(uid, 0, bytes.NewReader(data1)); err != nil {
t.Fatal(err)
}
if err := app.saveChunk(uid, 1, bytes.NewReader(data2)); err != nil {
t.Fatal(err)
}
decryptors, closeFn, err := app.getChunkDecryptors(uid, 2)
if err != nil {
t.Fatalf("getChunkDecryptors failed: %v", err)
}
defer closeFn()
if len(decryptors) != 2 {
t.Fatalf("Expected 2 decryptors, got %d", len(decryptors))
}
buf1, err := io.ReadAll(decryptors[0])
if err != nil {
t.Fatalf("Failed to read decryptor 1: %v", err)
}
if !bytes.Equal(buf1, data1) {
t.Errorf("Chunk 1 mismatch. Want %s, got %s", data1, buf1)
}
buf2, err := io.ReadAll(decryptors[1])
if err != nil {
t.Fatalf("Failed to read decryptor 2: %v", err)
}
if !bytes.Equal(buf2, data2) {
t.Errorf("Chunk 2 mismatch. Want %s, got %s", data2, buf2)
}
}
+54 -43
View File
@@ -1,6 +1,8 @@
package app package app
import ( import (
"crypto/rand"
"crypto/sha256"
"io" "io"
"net/http" "net/http"
"os" "os"
@@ -18,17 +20,14 @@ func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request)
request.Body = http.MaxBytesReader(writer, request.Body, limit) request.Body = http.MaxBytesReader(writer, request.Body, limit)
file, header, err := request.FormFile("file") file, header, err := request.FormFile("file")
if err != nil { if err != nil {
if err.Error() == "http: request body too large" { if err.Error() == "http: request body too large" {
app.SendError(writer, request, http.StatusRequestEntityTooLarge) app.SendError(writer, request, http.StatusRequestEntityTooLarge)
return return
} }
app.SendError(writer, request, http.StatusBadRequest) app.SendError(writer, request, http.StatusBadRequest)
return return
} }
defer func() { defer func() {
if closeErr := file.Close(); closeErr != nil { if closeErr := file.Close(); closeErr != nil {
app.Logger.Error("Failed to close upload file", "err", closeErr) app.Logger.Error("Failed to close upload file", "err", closeErr)
@@ -36,45 +35,60 @@ func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request)
}() }()
tmp, err := os.CreateTemp(filepath.Join(app.Conf.StorageDir, TempDirName), "up_*") tmp, err := os.CreateTemp(filepath.Join(app.Conf.StorageDir, TempDirName), "up_*")
if err != nil { if err != nil {
app.Logger.Error("Failed to create temp file", "err", err) app.Logger.Error("Failed to create temp file", "err", err)
app.SendError(writer, request, http.StatusInternalServerError) app.SendError(writer, request, http.StatusInternalServerError)
return return
} }
tmpPath := tmp.Name() tmpPath := tmp.Name()
defer func() { defer func() {
_ = tmp.Close()
if removeErr := os.Remove(tmpPath); removeErr != nil && !os.IsNotExist(removeErr) { if removeErr := os.Remove(tmpPath); removeErr != nil && !os.IsNotExist(removeErr) {
app.Logger.Error("Failed to remove temp file", "err", removeErr) app.Logger.Error("Failed to remove temp file", "err", removeErr)
} }
}() }()
defer func() { ephemeralKey := make([]byte, crypto.KeySize)
if closeErr := tmp.Close(); closeErr != nil { if _, err := rand.Read(ephemeralKey); err != nil {
app.Logger.Error("Failed to close temp file", "err", closeErr) app.Logger.Error("Failed to generate ephemeral key", "err", err)
app.SendError(writer, request, http.StatusInternalServerError)
return
} }
pr, pw := io.Pipe()
hasher := sha256.New()
errChan := make(chan error, 1)
go func() {
_, err := io.Copy(io.MultiWriter(hasher, pw), file)
_ = pw.CloseWithError(err)
errChan <- err
}() }()
if _, err := io.Copy(tmp, file); err != nil { defer pr.Close()
app.Logger.Error("Failed to write temp file", "err", err)
streamer, err := crypto.NewGCMStreamer(ephemeralKey)
if err != nil {
app.Logger.Error("Failed to create streamer", "err", err)
app.SendError(writer, request, http.StatusInternalServerError)
return
}
if err := streamer.EncryptStream(tmp, pr); err != nil {
app.Logger.Error("Failed to encrypt stream", "err", err)
app.SendError(writer, request, http.StatusInternalServerError)
return
}
if err := <-errChan; err != nil {
app.Logger.Error("Failed to read/hash upload", "err", err)
app.SendError(writer, request, http.StatusRequestEntityTooLarge) app.SendError(writer, request, http.StatusRequestEntityTooLarge)
return return
} }
if _, err := tmp.Seek(0, 0); err != nil { convergentKey := hasher.Sum(nil)[:crypto.KeySize]
app.Logger.Error("Seek failed", "err", err)
app.SendError(writer, request, http.StatusInternalServerError)
return
}
key, err := crypto.DeriveKey(tmp)
if err != nil {
app.Logger.Error("Key derivation failed", "err", err)
app.SendError(writer, request, http.StatusInternalServerError)
return
}
if _, err := tmp.Seek(0, 0); err != nil { if _, err := tmp.Seek(0, 0); err != nil {
app.Logger.Error("Seek failed", "err", err) app.Logger.Error("Seek failed", "err", err)
@@ -82,14 +96,16 @@ func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request)
return return
} }
app.finalizeUpload(writer, request, tmp, key, header.Filename) info, _ := tmp.Stat()
decryptor := crypto.NewDecryptor(tmp, streamer.AEAD, info.Size())
app.finalizeUpload(writer, request, decryptor, convergentKey, header.Filename)
} }
func (app *App) HandleChunk(writer http.ResponseWriter, request *http.Request) { func (app *App) HandleChunk(writer http.ResponseWriter, request *http.Request) {
request.Body = http.MaxBytesReader(writer, request.Body, MaxRequestOverhead) request.Body = http.MaxBytesReader(writer, request.Body, MaxRequestOverhead)
uid := request.FormValue("upload_id") uid := request.FormValue("upload_id")
idx, err := strconv.Atoi(request.FormValue("index")) idx, err := strconv.Atoi(request.FormValue("index"))
if err != nil { if err != nil {
app.SendError(writer, request, http.StatusBadRequest) app.SendError(writer, request, http.StatusBadRequest)
@@ -104,17 +120,14 @@ func (app *App) HandleChunk(writer http.ResponseWriter, request *http.Request) {
} }
file, _, err := request.FormFile("chunk") file, _, err := request.FormFile("chunk")
if err != nil { if err != nil {
if err.Error() == "http: request body too large" { if err.Error() == "http: request body too large" {
app.SendError(writer, request, http.StatusRequestEntityTooLarge) app.SendError(writer, request, http.StatusRequestEntityTooLarge)
return return
} }
app.SendError(writer, request, http.StatusBadRequest) app.SendError(writer, request, http.StatusBadRequest)
return return
} }
defer func() { defer func() {
if closeErr := file.Close(); closeErr != nil { if closeErr := file.Close(); closeErr != nil {
app.Logger.Error("Failed to close chunk file", "err", closeErr) app.Logger.Error("Failed to close chunk file", "err", closeErr)
@@ -129,7 +142,6 @@ func (app *App) HandleChunk(writer http.ResponseWriter, request *http.Request) {
func (app *App) HandleFinish(writer http.ResponseWriter, request *http.Request) { func (app *App) HandleFinish(writer http.ResponseWriter, request *http.Request) {
uid := request.FormValue("upload_id") uid := request.FormValue("upload_id")
total, err := strconv.Atoi(request.FormValue("total")) total, err := strconv.Atoi(request.FormValue("total"))
if err != nil { if err != nil {
app.SendError(writer, request, http.StatusBadRequest) app.SendError(writer, request, http.StatusBadRequest)
@@ -143,37 +155,36 @@ func (app *App) HandleFinish(writer http.ResponseWriter, request *http.Request)
return return
} }
files, err := app.openChunkFiles(uid, total) decryptors, closeAll, err := app.getChunkDecryptors(uid, total)
if err != nil { if err != nil {
app.Logger.Error("Failed to open chunks", "err", err) app.Logger.Error("Failed to open chunks", "err", err)
app.SendError(writer, request, http.StatusInternalServerError) app.SendError(writer, request, http.StatusInternalServerError)
return return
} }
defer func() { defer func() {
for _, f := range files { closeAll()
_ = f.Close()
}
if err := os.RemoveAll(filepath.Join(app.Conf.StorageDir, TempDirName, uid)); err != nil { if err := os.RemoveAll(filepath.Join(app.Conf.StorageDir, TempDirName, uid)); err != nil {
app.Logger.Error("Failed to remove chunk dir", "err", err) app.Logger.Error("Failed to remove chunk dir", "err", err)
} }
}() }()
readers := make([]io.Reader, len(files)) readers := make([]io.Reader, len(decryptors))
for i, f := range files { for i, d := range decryptors {
readers[i] = f readers[i] = d
} }
key, err := crypto.DeriveKey(io.MultiReader(readers...)) hasher := sha256.New()
if err != nil { if _, err := io.Copy(hasher, io.MultiReader(readers...)); err != nil {
app.Logger.Error("Key derivation failed", "err", err) app.Logger.Error("Failed to hash chunks", "err", err)
app.SendError(writer, request, http.StatusInternalServerError) app.SendError(writer, request, http.StatusInternalServerError)
return return
} }
for _, f := range files { convergentKey := hasher.Sum(nil)[:crypto.KeySize]
if _, err := f.Seek(0, 0); err != nil {
app.Logger.Error("Failed to reset chunk", "err", err) for _, d := range decryptors {
if _, err := d.Seek(0, io.SeekStart); err != nil {
app.Logger.Error("Failed to reset chunk decryptor", "err", err)
app.SendError(writer, request, http.StatusInternalServerError) app.SendError(writer, request, http.StatusInternalServerError)
return return
} }
@@ -181,7 +192,7 @@ func (app *App) HandleFinish(writer http.ResponseWriter, request *http.Request)
multiSrc := io.MultiReader(readers...) multiSrc := io.MultiReader(readers...)
app.finalizeUpload(writer, request, multiSrc, key, request.FormValue("filename")) app.finalizeUpload(writer, request, multiSrc, convergentKey, request.FormValue("filename"))
} }
func (app *App) finalizeUpload(writer http.ResponseWriter, request *http.Request, src io.Reader, key []byte, filename string) { func (app *App) finalizeUpload(writer http.ResponseWriter, request *http.Request, src io.Reader, key []byte, filename string) {