feat: implement sequential chunk reading and decryption

Signed-off-by: skidoodle <contact@albert.lol>
This commit is contained in:
2026-01-22 04:37:20 +01:00
parent 5c13d24736
commit 577c4b67f6
3 changed files with 157 additions and 105 deletions
+83 -46
View File
@@ -70,56 +70,93 @@ func (app *App) saveChunk(uid string, idx int, src io.Reader) error {
return nil return nil
} }
func (app *App) getChunkDecryptors(uid string, total int) ([]io.ReadSeeker, func(), error) { func (app *App) openChunkDecryptor(uid string, idx int) (io.ReadCloser, error) {
files := make([]*os.File, 0, total) partPath := filepath.Join(app.Conf.StorageDir, TempDirName, uid, strconv.Itoa(idx))
decryptors := make([]io.ReadSeeker, 0, total) f, err := os.Open(partPath)
if err != nil {
closeAll := func() { return nil, fmt.Errorf("open chunk %d: %w", idx, err)
for _, f := range files {
_ = f.Close()
}
} }
for i := range total { key := make([]byte, crypto.KeySize)
partPath := filepath.Join(app.Conf.StorageDir, TempDirName, uid, strconv.Itoa(i)) if _, err := io.ReadFull(f, key); err != nil {
f, err := os.Open(partPath) _ = f.Close()
if err != nil { return nil, fmt.Errorf("read chunk key %d: %w", idx, err)
closeAll()
return nil, nil, fmt.Errorf("open chunk %d: %w", i, err)
}
files = append(files, f)
key := make([]byte, crypto.KeySize)
if _, err := io.ReadFull(f, key); err != nil {
closeAll()
return nil, nil, fmt.Errorf("read chunk key %d: %w", i, err)
}
info, err := f.Stat()
if err != nil {
closeAll()
return nil, nil, fmt.Errorf("stat chunk %d: %w", i, err)
}
bodySize := info.Size() - int64(crypto.KeySize)
if bodySize < 0 {
closeAll()
return nil, nil, fmt.Errorf("invalid chunk size %d", i)
}
bodyReader := io.NewSectionReader(f, int64(crypto.KeySize), bodySize)
streamer, err := crypto.NewGCMStreamer(key)
if err != nil {
closeAll()
return nil, nil, fmt.Errorf("create streamer %d: %w", i, err)
}
decryptor := crypto.NewDecryptor(bodyReader, streamer.AEAD, bodySize)
decryptors = append(decryptors, decryptor)
} }
return decryptors, closeAll, nil info, err := f.Stat()
if err != nil {
_ = f.Close()
return nil, fmt.Errorf("stat chunk %d: %w", idx, err)
}
bodySize := info.Size() - int64(crypto.KeySize)
if bodySize < 0 {
_ = f.Close()
return nil, fmt.Errorf("invalid chunk size %d", idx)
}
bodyReader := io.NewSectionReader(f, int64(crypto.KeySize), bodySize)
streamer, err := crypto.NewGCMStreamer(key)
if err != nil {
_ = f.Close()
return nil, fmt.Errorf("create streamer %d: %w", idx, err)
}
decryptor := crypto.NewDecryptor(bodyReader, streamer.AEAD, bodySize)
return &chunkReadCloser{Decryptor: decryptor, f: f}, nil
}
type chunkReadCloser struct {
*crypto.Decryptor
f *os.File
}
func (c *chunkReadCloser) Close() error {
return c.f.Close()
}
type SequentialChunkReader struct {
app *App
uid string
total int
currentIdx int
currentRC io.ReadCloser
}
func (s *SequentialChunkReader) Read(p []byte) (n int, err error) {
if s.currentRC == nil {
if s.currentIdx >= s.total {
return 0, io.EOF
}
rc, err := s.app.openChunkDecryptor(s.uid, s.currentIdx)
if err != nil {
return 0, err
}
s.currentRC = rc
}
n, err = s.currentRC.Read(p)
if err == io.EOF {
_ = s.currentRC.Close()
s.currentRC = nil
s.currentIdx++
if n > 0 {
return n, nil
}
return s.Read(p)
}
return n, err
}
func (s *SequentialChunkReader) Close() error {
if s.currentRC != nil {
return s.currentRC.Close()
}
return nil
} }
func (app *App) encryptAndSave(src io.Reader, key []byte, finalPath string) error { func (app *App) encryptAndSave(src io.Reader, key []byte, finalPath string) error {
+16 -21
View File
@@ -172,7 +172,7 @@ func TestSaveChunk_EncryptsData(t *testing.T) {
} }
} }
func TestGetChunkDecryptors_RestoresData(t *testing.T) { func TestSequentialChunkReader_RestoresData(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
app := &App{ app := &App{
Conf: Config{StorageDir: tmpDir}, Conf: Config{StorageDir: tmpDir},
@@ -190,29 +190,24 @@ func TestGetChunkDecryptors_RestoresData(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
decryptors, closeFn, err := app.getChunkDecryptors(uid, 2) reader := &SequentialChunkReader{
if err != nil { app: app,
t.Fatalf("getChunkDecryptors failed: %v", err) uid: uid,
total: 2,
} }
defer closeFn() defer func() {
if err := reader.Close(); err != nil {
t.Errorf("Failed to close reader: %v", err)
}
}()
if len(decryptors) != 2 { restored, err := io.ReadAll(reader)
t.Fatalf("Expected 2 decryptors, got %d", len(decryptors)) if err != nil {
t.Fatalf("ReadAll failed: %v", err)
} }
buf1, err := io.ReadAll(decryptors[0]) expected := append(data1, data2...)
if err != nil { if !bytes.Equal(restored, expected) {
t.Fatalf("Failed to read decryptor 1: %v", err) t.Errorf("Restored data mismatch.\nWant: %s\nGot: %s", expected, restored)
}
if !bytes.Equal(buf1, data1) {
t.Errorf("Chunk 1 mismatch. Want %s, got %s", data1, buf1)
}
buf2, err := io.ReadAll(decryptors[1])
if err != nil {
t.Fatalf("Failed to read decryptor 2: %v", err)
}
if !bytes.Equal(buf2, data2) {
t.Errorf("Chunk 2 mismatch. Want %s, got %s", data2, buf2)
} }
} }
+58 -38
View File
@@ -3,12 +3,14 @@ package app
import ( import (
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"errors"
"io" "io"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
"strconv" "strconv"
"strings"
"github.com/skidoodle/safebin/internal/crypto" "github.com/skidoodle/safebin/internal/crypto"
) )
@@ -19,20 +21,36 @@ func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request)
limit := (app.Conf.MaxMB * MegaByte) + MegaByte limit := (app.Conf.MaxMB * MegaByte) + MegaByte
request.Body = http.MaxBytesReader(writer, request.Body, limit) request.Body = http.MaxBytesReader(writer, request.Body, limit)
file, header, err := request.FormFile("file") mr, err := request.MultipartReader()
if err != nil { if err != nil {
if err.Error() == "http: request body too large" {
app.SendError(writer, request, http.StatusRequestEntityTooLarge)
return
}
app.SendError(writer, request, http.StatusBadRequest) app.SendError(writer, request, http.StatusBadRequest)
return return
} }
defer func() {
if closeErr := file.Close(); closeErr != nil { var filename string
app.Logger.Error("Failed to close upload file", "err", closeErr) var partReader io.Reader
for {
part, err := mr.NextPart()
if err == io.EOF {
break
} }
}() if err != nil {
app.SendError(writer, request, http.StatusBadRequest)
return
}
if part.FormName() == "file" {
filename = part.FileName()
partReader = part
break
}
}
if partReader == nil {
app.SendError(writer, request, http.StatusBadRequest)
return
}
tmp, err := os.CreateTemp(filepath.Join(app.Conf.StorageDir, TempDirName), "up_*") tmp, err := os.CreateTemp(filepath.Join(app.Conf.StorageDir, TempDirName), "up_*")
if err != nil { if err != nil {
@@ -58,11 +76,10 @@ func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request)
pr, pw := io.Pipe() pr, pw := io.Pipe()
hasher := sha256.New() hasher := sha256.New()
errChan := make(chan error, 1) errChan := make(chan error, 1)
go func() { go func() {
_, err := io.Copy(io.MultiWriter(hasher, pw), file) _, err := io.Copy(io.MultiWriter(hasher, pw), partReader)
_ = pw.CloseWithError(err) _ = pw.CloseWithError(err)
errChan <- err errChan <- err
}() }()
@@ -87,8 +104,12 @@ func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request)
} }
if err := <-errChan; err != nil { if err := <-errChan; err != nil {
app.Logger.Error("Failed to read/hash upload", "err", err) if errors.Is(err, http.ErrMissingBoundary) || strings.Contains(err.Error(), "request body too large") {
app.SendError(writer, request, http.StatusRequestEntityTooLarge) app.SendError(writer, request, http.StatusRequestEntityTooLarge)
} else {
app.Logger.Error("Failed to read/hash upload", "err", err)
app.SendError(writer, request, http.StatusInternalServerError)
}
return return
} }
@@ -103,7 +124,7 @@ func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request)
info, _ := tmp.Stat() info, _ := tmp.Stat()
decryptor := crypto.NewDecryptor(tmp, streamer.AEAD, info.Size()) decryptor := crypto.NewDecryptor(tmp, streamer.AEAD, info.Size())
app.finalizeUpload(writer, request, decryptor, convergentKey, header.Filename) app.finalizeUpload(writer, request, decryptor, convergentKey, filename)
} }
func (app *App) HandleChunk(writer http.ResponseWriter, request *http.Request) { func (app *App) HandleChunk(writer http.ResponseWriter, request *http.Request) {
@@ -159,42 +180,41 @@ func (app *App) HandleFinish(writer http.ResponseWriter, request *http.Request)
return return
} }
decryptors, closeAll, err := app.getChunkDecryptors(uid, total)
if err != nil {
app.Logger.Error("Failed to open chunks", "err", err)
app.SendError(writer, request, http.StatusInternalServerError)
return
}
defer func() { defer func() {
closeAll()
if err := os.RemoveAll(filepath.Join(app.Conf.StorageDir, TempDirName, uid)); err != nil { if err := os.RemoveAll(filepath.Join(app.Conf.StorageDir, TempDirName, uid)); err != nil {
app.Logger.Error("Failed to remove chunk dir", "err", err) app.Logger.Error("Failed to remove chunk dir", "err", err)
} }
}() }()
readers := make([]io.Reader, len(decryptors))
for i, d := range decryptors {
readers[i] = d
}
hasher := sha256.New() hasher := sha256.New()
if _, err := io.Copy(hasher, io.MultiReader(readers...)); err != nil { for i := range total {
app.Logger.Error("Failed to hash chunks", "err", err) rc, err := app.openChunkDecryptor(uid, i)
app.SendError(writer, request, http.StatusInternalServerError) if err != nil {
return app.Logger.Error("Failed to open chunk for hashing", "index", i, "err", err)
app.SendError(writer, request, http.StatusInternalServerError)
return
}
if _, err := io.Copy(hasher, rc); err != nil {
_ = rc.Close()
app.Logger.Error("Failed to hash chunk", "index", i, "err", err)
app.SendError(writer, request, http.StatusInternalServerError)
return
}
_ = rc.Close()
} }
convergentKey := hasher.Sum(nil)[:crypto.KeySize] convergentKey := hasher.Sum(nil)[:crypto.KeySize]
for _, d := range decryptors { multiSrc := &SequentialChunkReader{
if _, err := d.Seek(0, io.SeekStart); err != nil { app: app,
app.Logger.Error("Failed to reset chunk decryptor", "err", err) uid: uid,
app.SendError(writer, request, http.StatusInternalServerError) total: total,
return
}
} }
defer func() {
multiSrc := io.MultiReader(readers...) if err := multiSrc.Close(); err != nil {
app.Logger.Error("Failed to close sequential reader", "uid", uid, "err", err)
}
}()
app.finalizeUpload(writer, request, multiSrc, convergentKey, request.FormValue("filename")) app.finalizeUpload(writer, request, multiSrc, convergentKey, request.FormValue("filename"))
} }