mirror of
https://github.com/skidoodle/safebin.git
synced 2026-04-28 11:17:42 +02:00
feat: implement sequential chunk reading and decryption
Signed-off-by: skidoodle <contact@albert.lol>
This commit is contained in:
+83
-46
@@ -70,56 +70,93 @@ func (app *App) saveChunk(uid string, idx int, src io.Reader) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *App) getChunkDecryptors(uid string, total int) ([]io.ReadSeeker, func(), error) {
|
func (app *App) openChunkDecryptor(uid string, idx int) (io.ReadCloser, error) {
|
||||||
files := make([]*os.File, 0, total)
|
partPath := filepath.Join(app.Conf.StorageDir, TempDirName, uid, strconv.Itoa(idx))
|
||||||
decryptors := make([]io.ReadSeeker, 0, total)
|
f, err := os.Open(partPath)
|
||||||
|
if err != nil {
|
||||||
closeAll := func() {
|
return nil, fmt.Errorf("open chunk %d: %w", idx, err)
|
||||||
for _, f := range files {
|
|
||||||
_ = f.Close()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range total {
|
key := make([]byte, crypto.KeySize)
|
||||||
partPath := filepath.Join(app.Conf.StorageDir, TempDirName, uid, strconv.Itoa(i))
|
if _, err := io.ReadFull(f, key); err != nil {
|
||||||
f, err := os.Open(partPath)
|
_ = f.Close()
|
||||||
if err != nil {
|
return nil, fmt.Errorf("read chunk key %d: %w", idx, err)
|
||||||
closeAll()
|
|
||||||
return nil, nil, fmt.Errorf("open chunk %d: %w", i, err)
|
|
||||||
}
|
|
||||||
files = append(files, f)
|
|
||||||
|
|
||||||
key := make([]byte, crypto.KeySize)
|
|
||||||
if _, err := io.ReadFull(f, key); err != nil {
|
|
||||||
closeAll()
|
|
||||||
return nil, nil, fmt.Errorf("read chunk key %d: %w", i, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
info, err := f.Stat()
|
|
||||||
if err != nil {
|
|
||||||
closeAll()
|
|
||||||
return nil, nil, fmt.Errorf("stat chunk %d: %w", i, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
bodySize := info.Size() - int64(crypto.KeySize)
|
|
||||||
if bodySize < 0 {
|
|
||||||
closeAll()
|
|
||||||
return nil, nil, fmt.Errorf("invalid chunk size %d", i)
|
|
||||||
}
|
|
||||||
|
|
||||||
bodyReader := io.NewSectionReader(f, int64(crypto.KeySize), bodySize)
|
|
||||||
|
|
||||||
streamer, err := crypto.NewGCMStreamer(key)
|
|
||||||
if err != nil {
|
|
||||||
closeAll()
|
|
||||||
return nil, nil, fmt.Errorf("create streamer %d: %w", i, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
decryptor := crypto.NewDecryptor(bodyReader, streamer.AEAD, bodySize)
|
|
||||||
decryptors = append(decryptors, decryptor)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return decryptors, closeAll, nil
|
info, err := f.Stat()
|
||||||
|
if err != nil {
|
||||||
|
_ = f.Close()
|
||||||
|
return nil, fmt.Errorf("stat chunk %d: %w", idx, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
bodySize := info.Size() - int64(crypto.KeySize)
|
||||||
|
if bodySize < 0 {
|
||||||
|
_ = f.Close()
|
||||||
|
return nil, fmt.Errorf("invalid chunk size %d", idx)
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyReader := io.NewSectionReader(f, int64(crypto.KeySize), bodySize)
|
||||||
|
|
||||||
|
streamer, err := crypto.NewGCMStreamer(key)
|
||||||
|
if err != nil {
|
||||||
|
_ = f.Close()
|
||||||
|
return nil, fmt.Errorf("create streamer %d: %w", idx, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
decryptor := crypto.NewDecryptor(bodyReader, streamer.AEAD, bodySize)
|
||||||
|
|
||||||
|
return &chunkReadCloser{Decryptor: decryptor, f: f}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type chunkReadCloser struct {
|
||||||
|
*crypto.Decryptor
|
||||||
|
f *os.File
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *chunkReadCloser) Close() error {
|
||||||
|
return c.f.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
type SequentialChunkReader struct {
|
||||||
|
app *App
|
||||||
|
uid string
|
||||||
|
total int
|
||||||
|
currentIdx int
|
||||||
|
currentRC io.ReadCloser
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SequentialChunkReader) Read(p []byte) (n int, err error) {
|
||||||
|
if s.currentRC == nil {
|
||||||
|
if s.currentIdx >= s.total {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
rc, err := s.app.openChunkDecryptor(s.uid, s.currentIdx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
s.currentRC = rc
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err = s.currentRC.Read(p)
|
||||||
|
if err == io.EOF {
|
||||||
|
_ = s.currentRC.Close()
|
||||||
|
s.currentRC = nil
|
||||||
|
s.currentIdx++
|
||||||
|
|
||||||
|
if n > 0 {
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
return s.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SequentialChunkReader) Close() error {
|
||||||
|
if s.currentRC != nil {
|
||||||
|
return s.currentRC.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *App) encryptAndSave(src io.Reader, key []byte, finalPath string) error {
|
func (app *App) encryptAndSave(src io.Reader, key []byte, finalPath string) error {
|
||||||
|
|||||||
@@ -172,7 +172,7 @@ func TestSaveChunk_EncryptsData(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetChunkDecryptors_RestoresData(t *testing.T) {
|
func TestSequentialChunkReader_RestoresData(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
app := &App{
|
app := &App{
|
||||||
Conf: Config{StorageDir: tmpDir},
|
Conf: Config{StorageDir: tmpDir},
|
||||||
@@ -190,29 +190,24 @@ func TestGetChunkDecryptors_RestoresData(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
decryptors, closeFn, err := app.getChunkDecryptors(uid, 2)
|
reader := &SequentialChunkReader{
|
||||||
if err != nil {
|
app: app,
|
||||||
t.Fatalf("getChunkDecryptors failed: %v", err)
|
uid: uid,
|
||||||
|
total: 2,
|
||||||
}
|
}
|
||||||
defer closeFn()
|
defer func() {
|
||||||
|
if err := reader.Close(); err != nil {
|
||||||
|
t.Errorf("Failed to close reader: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if len(decryptors) != 2 {
|
restored, err := io.ReadAll(reader)
|
||||||
t.Fatalf("Expected 2 decryptors, got %d", len(decryptors))
|
if err != nil {
|
||||||
|
t.Fatalf("ReadAll failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
buf1, err := io.ReadAll(decryptors[0])
|
expected := append(data1, data2...)
|
||||||
if err != nil {
|
if !bytes.Equal(restored, expected) {
|
||||||
t.Fatalf("Failed to read decryptor 1: %v", err)
|
t.Errorf("Restored data mismatch.\nWant: %s\nGot: %s", expected, restored)
|
||||||
}
|
|
||||||
if !bytes.Equal(buf1, data1) {
|
|
||||||
t.Errorf("Chunk 1 mismatch. Want %s, got %s", data1, buf1)
|
|
||||||
}
|
|
||||||
|
|
||||||
buf2, err := io.ReadAll(decryptors[1])
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to read decryptor 2: %v", err)
|
|
||||||
}
|
|
||||||
if !bytes.Equal(buf2, data2) {
|
|
||||||
t.Errorf("Chunk 2 mismatch. Want %s, got %s", data2, buf2)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+58
-38
@@ -3,12 +3,14 @@ package app
|
|||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/skidoodle/safebin/internal/crypto"
|
"github.com/skidoodle/safebin/internal/crypto"
|
||||||
)
|
)
|
||||||
@@ -19,20 +21,36 @@ func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request)
|
|||||||
limit := (app.Conf.MaxMB * MegaByte) + MegaByte
|
limit := (app.Conf.MaxMB * MegaByte) + MegaByte
|
||||||
request.Body = http.MaxBytesReader(writer, request.Body, limit)
|
request.Body = http.MaxBytesReader(writer, request.Body, limit)
|
||||||
|
|
||||||
file, header, err := request.FormFile("file")
|
mr, err := request.MultipartReader()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err.Error() == "http: request body too large" {
|
|
||||||
app.SendError(writer, request, http.StatusRequestEntityTooLarge)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
app.SendError(writer, request, http.StatusBadRequest)
|
app.SendError(writer, request, http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() {
|
|
||||||
if closeErr := file.Close(); closeErr != nil {
|
var filename string
|
||||||
app.Logger.Error("Failed to close upload file", "err", closeErr)
|
var partReader io.Reader
|
||||||
|
|
||||||
|
for {
|
||||||
|
part, err := mr.NextPart()
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}()
|
if err != nil {
|
||||||
|
app.SendError(writer, request, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if part.FormName() == "file" {
|
||||||
|
filename = part.FileName()
|
||||||
|
partReader = part
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if partReader == nil {
|
||||||
|
app.SendError(writer, request, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
tmp, err := os.CreateTemp(filepath.Join(app.Conf.StorageDir, TempDirName), "up_*")
|
tmp, err := os.CreateTemp(filepath.Join(app.Conf.StorageDir, TempDirName), "up_*")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -58,11 +76,10 @@ func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request)
|
|||||||
|
|
||||||
pr, pw := io.Pipe()
|
pr, pw := io.Pipe()
|
||||||
hasher := sha256.New()
|
hasher := sha256.New()
|
||||||
|
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
_, err := io.Copy(io.MultiWriter(hasher, pw), file)
|
_, err := io.Copy(io.MultiWriter(hasher, pw), partReader)
|
||||||
_ = pw.CloseWithError(err)
|
_ = pw.CloseWithError(err)
|
||||||
errChan <- err
|
errChan <- err
|
||||||
}()
|
}()
|
||||||
@@ -87,8 +104,12 @@ func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := <-errChan; err != nil {
|
if err := <-errChan; err != nil {
|
||||||
app.Logger.Error("Failed to read/hash upload", "err", err)
|
if errors.Is(err, http.ErrMissingBoundary) || strings.Contains(err.Error(), "request body too large") {
|
||||||
app.SendError(writer, request, http.StatusRequestEntityTooLarge)
|
app.SendError(writer, request, http.StatusRequestEntityTooLarge)
|
||||||
|
} else {
|
||||||
|
app.Logger.Error("Failed to read/hash upload", "err", err)
|
||||||
|
app.SendError(writer, request, http.StatusInternalServerError)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,7 +124,7 @@ func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request)
|
|||||||
info, _ := tmp.Stat()
|
info, _ := tmp.Stat()
|
||||||
decryptor := crypto.NewDecryptor(tmp, streamer.AEAD, info.Size())
|
decryptor := crypto.NewDecryptor(tmp, streamer.AEAD, info.Size())
|
||||||
|
|
||||||
app.finalizeUpload(writer, request, decryptor, convergentKey, header.Filename)
|
app.finalizeUpload(writer, request, decryptor, convergentKey, filename)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *App) HandleChunk(writer http.ResponseWriter, request *http.Request) {
|
func (app *App) HandleChunk(writer http.ResponseWriter, request *http.Request) {
|
||||||
@@ -159,42 +180,41 @@ func (app *App) HandleFinish(writer http.ResponseWriter, request *http.Request)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
decryptors, closeAll, err := app.getChunkDecryptors(uid, total)
|
|
||||||
if err != nil {
|
|
||||||
app.Logger.Error("Failed to open chunks", "err", err)
|
|
||||||
app.SendError(writer, request, http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() {
|
defer func() {
|
||||||
closeAll()
|
|
||||||
if err := os.RemoveAll(filepath.Join(app.Conf.StorageDir, TempDirName, uid)); err != nil {
|
if err := os.RemoveAll(filepath.Join(app.Conf.StorageDir, TempDirName, uid)); err != nil {
|
||||||
app.Logger.Error("Failed to remove chunk dir", "err", err)
|
app.Logger.Error("Failed to remove chunk dir", "err", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
readers := make([]io.Reader, len(decryptors))
|
|
||||||
for i, d := range decryptors {
|
|
||||||
readers[i] = d
|
|
||||||
}
|
|
||||||
|
|
||||||
hasher := sha256.New()
|
hasher := sha256.New()
|
||||||
if _, err := io.Copy(hasher, io.MultiReader(readers...)); err != nil {
|
for i := range total {
|
||||||
app.Logger.Error("Failed to hash chunks", "err", err)
|
rc, err := app.openChunkDecryptor(uid, i)
|
||||||
app.SendError(writer, request, http.StatusInternalServerError)
|
if err != nil {
|
||||||
return
|
app.Logger.Error("Failed to open chunk for hashing", "index", i, "err", err)
|
||||||
|
app.SendError(writer, request, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := io.Copy(hasher, rc); err != nil {
|
||||||
|
_ = rc.Close()
|
||||||
|
app.Logger.Error("Failed to hash chunk", "index", i, "err", err)
|
||||||
|
app.SendError(writer, request, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = rc.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
convergentKey := hasher.Sum(nil)[:crypto.KeySize]
|
convergentKey := hasher.Sum(nil)[:crypto.KeySize]
|
||||||
|
|
||||||
for _, d := range decryptors {
|
multiSrc := &SequentialChunkReader{
|
||||||
if _, err := d.Seek(0, io.SeekStart); err != nil {
|
app: app,
|
||||||
app.Logger.Error("Failed to reset chunk decryptor", "err", err)
|
uid: uid,
|
||||||
app.SendError(writer, request, http.StatusInternalServerError)
|
total: total,
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
defer func() {
|
||||||
multiSrc := io.MultiReader(readers...)
|
if err := multiSrc.Close(); err != nil {
|
||||||
|
app.Logger.Error("Failed to close sequential reader", "uid", uid, "err", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
app.finalizeUpload(writer, request, multiSrc, convergentKey, request.FormValue("filename"))
|
app.finalizeUpload(writer, request, multiSrc, convergentKey, request.FormValue("filename"))
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user