mirror of
https://github.com/skidoodle/safebin.git
synced 2026-04-28 19:27:41 +02:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
2bcf339408
|
|||
|
2df37e9002
|
|||
|
722dbaa6aa
|
|||
|
2d6a3ab216
|
|||
|
d18ef48bd4
|
|||
|
e18be18029
|
|||
|
a69e5a52a3
|
@@ -1,3 +1,2 @@
|
|||||||
storage/*
|
storage/*
|
||||||
# Added by goreleaser init:
|
|
||||||
dist/
|
dist/
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ const (
|
|||||||
ShutdownTimeout = 10 * time.Second
|
ShutdownTimeout = 10 * time.Second
|
||||||
|
|
||||||
UploadChunkSize = 8 << 20
|
UploadChunkSize = 8 << 20
|
||||||
|
MinChunkSize = 1 << 20
|
||||||
MaxRequestOverhead = 10 << 20
|
MaxRequestOverhead = 10 << 20
|
||||||
PermUserRWX = 0o700
|
PermUserRWX = 0o700
|
||||||
MegaByte = 1 << 20
|
MegaByte = 1 << 20
|
||||||
@@ -35,8 +36,10 @@ const (
|
|||||||
MinRetention = 24 * time.Hour
|
MinRetention = 24 * time.Hour
|
||||||
MaxRetention = 365 * 24 * time.Hour
|
MaxRetention = 365 * 24 * time.Hour
|
||||||
|
|
||||||
|
DBDirName = "db"
|
||||||
DBFileName = "safebin.db"
|
DBFileName = "safebin.db"
|
||||||
DBBucketName = "files"
|
DBBucketName = "files"
|
||||||
|
DBBucketIndexName = "expiry_index"
|
||||||
TempDirName = "tmp"
|
TempDirName = "tmp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+13
-2
@@ -1,6 +1,7 @@
|
|||||||
package app
|
package app
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -15,15 +16,25 @@ type FileMeta struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func InitDB(storageDir string) (*bbolt.DB, error) {
|
func InitDB(storageDir string) (*bbolt.DB, error) {
|
||||||
path := filepath.Join(storageDir, DBFileName)
|
dbDir := filepath.Join(storageDir, DBDirName)
|
||||||
|
if err := os.MkdirAll(dbDir, PermUserRWX); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
path := filepath.Join(dbDir, DBFileName)
|
||||||
db, err := bbolt.Open(path, 0600, &bbolt.Options{Timeout: 1 * time.Second})
|
db, err := bbolt.Open(path, 0600, &bbolt.Options{Timeout: 1 * time.Second})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.Update(func(tx *bbolt.Tx) error {
|
err = db.Update(func(tx *bbolt.Tx) error {
|
||||||
_, err := tx.CreateBucketIfNotExists([]byte(DBBucketName))
|
if _, err := tx.CreateBucketIfNotExists([]byte(DBBucketName)); err != nil {
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
|
if _, err := tx.CreateBucketIfNotExists([]byte(DBBucketIndexName)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
+14
-3
@@ -23,16 +23,18 @@ func TestInitDB(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
dbPath := filepath.Join(tmpDir, DBFileName)
|
dbPath := filepath.Join(tmpDir, DBDirName, DBFileName)
|
||||||
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
|
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
|
||||||
t.Error("Database file was not created")
|
t.Error("Database file was not created")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.View(func(tx *bbolt.Tx) error {
|
err = db.View(func(tx *bbolt.Tx) error {
|
||||||
b := tx.Bucket([]byte(DBBucketName))
|
if b := tx.Bucket([]byte(DBBucketName)); b == nil {
|
||||||
if b == nil {
|
|
||||||
t.Errorf("Bucket '%s' was not created", DBBucketName)
|
t.Errorf("Bucket '%s' was not created", DBBucketName)
|
||||||
}
|
}
|
||||||
|
if b := tx.Bucket([]byte(DBBucketIndexName)); b == nil {
|
||||||
|
t.Errorf("Bucket '%s' was not created", DBBucketIndexName)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -85,6 +87,15 @@ func TestDB_MetadataLifecycle(t *testing.T) {
|
|||||||
if meta.ExpiresAt.Before(time.Now()) {
|
if meta.ExpiresAt.Before(time.Now()) {
|
||||||
t.Error("Expiration time is in the past")
|
t.Error("Expiration time is in the past")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bIndex := tx.Bucket([]byte(DBBucketIndexName))
|
||||||
|
indexKey := []byte(meta.ExpiresAt.Format(time.RFC3339) + "_" + fileID)
|
||||||
|
if val := bIndex.Get(indexKey); val == nil {
|
||||||
|
t.Error("Index entry not found")
|
||||||
|
} else if string(val) != fileID {
|
||||||
|
t.Errorf("Index value mismatch: want %s, got %s", fileID, string(val))
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package app
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"mime"
|
"mime"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -9,6 +10,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
"github.com/skidoodle/safebin/internal/crypto"
|
"github.com/skidoodle/safebin/internal/crypto"
|
||||||
|
"go.etcd.io/bbolt"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (app *App) HandleGetFile(writer http.ResponseWriter, request *http.Request) {
|
func (app *App) HandleGetFile(writer http.ResponseWriter, request *http.Request) {
|
||||||
@@ -28,14 +30,42 @@ func (app *App) HandleGetFile(writer http.ResponseWriter, request *http.Request)
|
|||||||
}
|
}
|
||||||
|
|
||||||
id := crypto.GetID(key, ext)
|
id := crypto.GetID(key, ext)
|
||||||
path := filepath.Join(app.Conf.StorageDir, id)
|
|
||||||
|
|
||||||
|
var meta FileMeta
|
||||||
|
err = app.DB.View(func(tx *bbolt.Tx) error {
|
||||||
|
b := tx.Bucket([]byte(DBBucketName))
|
||||||
|
if b == nil {
|
||||||
|
return fmt.Errorf("bucket not found")
|
||||||
|
}
|
||||||
|
data := b.Get([]byte(id))
|
||||||
|
if data == nil {
|
||||||
|
return fmt.Errorf("file not found")
|
||||||
|
}
|
||||||
|
return json.Unmarshal(data, &meta)
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
app.SendError(writer, request, http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
path := filepath.Join(app.Conf.StorageDir, id)
|
||||||
info, err := os.Stat(path)
|
info, err := os.Stat(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
app.SendError(writer, request, http.StatusNotFound)
|
app.SendError(writer, request, http.StatusNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if info.Size() != meta.Size {
|
||||||
|
app.Logger.Error("Integrity check failed: disk size mismatch",
|
||||||
|
"id", id,
|
||||||
|
"disk_bytes", info.Size(),
|
||||||
|
"expected_bytes", meta.Size,
|
||||||
|
)
|
||||||
|
app.SendError(writer, request, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
file, err := os.Open(path)
|
file, err := os.Open(path)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -72,10 +72,13 @@ func (app *App) RespondWithLink(writer http.ResponseWriter, request *http.Reques
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
scheme := "https"
|
scheme := request.Header.Get("X-Forwarded-Proto")
|
||||||
|
if scheme == "" {
|
||||||
|
scheme = "https"
|
||||||
if request.TLS == nil {
|
if request.TLS == nil {
|
||||||
scheme = "http"
|
scheme = "http"
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if _, err := fmt.Fprintf(writer, "%s://%s\n", scheme, link); err != nil {
|
if _, err := fmt.Fprintf(writer, "%s://%s\n", scheme, link); err != nil {
|
||||||
app.Logger.Error("Failed to write response", "err", err)
|
app.Logger.Error("Failed to write response", "err", err)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package app
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
@@ -12,6 +13,8 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/skidoodle/safebin/internal/crypto"
|
||||||
)
|
)
|
||||||
|
|
||||||
func setupTestApp(t *testing.T) (*App, string) {
|
func setupTestApp(t *testing.T) (*App, string) {
|
||||||
@@ -176,6 +179,113 @@ func TestIntegration_ChunkedUpload(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIntegration_ChunkedUpload_VerifyEncryption(t *testing.T) {
|
||||||
|
app, storageDir := setupTestApp(t)
|
||||||
|
server := httptest.NewServer(app.Routes())
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
uploadID := "securechunk123"
|
||||||
|
plaintext := []byte("This is a secret message that should be encrypted")
|
||||||
|
|
||||||
|
uploadChunk(t, server.URL, uploadID, 0, plaintext)
|
||||||
|
|
||||||
|
chunkPath := filepath.Join(storageDir, TempDirName, uploadID, "0")
|
||||||
|
encryptedData, err := os.ReadFile(chunkPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read chunk file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.Contains(encryptedData, plaintext) {
|
||||||
|
t.Fatal("Chunk file contains plaintext data!")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(encryptedData) <= crypto.KeySize {
|
||||||
|
t.Fatalf("Chunk file too small: %d bytes", len(encryptedData))
|
||||||
|
}
|
||||||
|
|
||||||
|
key := encryptedData[:crypto.KeySize]
|
||||||
|
ciphertext := encryptedData[crypto.KeySize:]
|
||||||
|
|
||||||
|
streamer, err := crypto.NewGCMStreamer(key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create streamer: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r := bytes.NewReader(ciphertext)
|
||||||
|
d := crypto.NewDecryptor(r, streamer.AEAD, int64(len(ciphertext)))
|
||||||
|
|
||||||
|
decrypted, err := io.ReadAll(d)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to decrypt chunk: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(decrypted, plaintext) {
|
||||||
|
t.Errorf("Decrypted data mismatch.\nWant: %s\nGot: %s", plaintext, decrypted)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_Upload_VerifyEncryption(t *testing.T) {
|
||||||
|
app, storageDir := setupTestApp(t)
|
||||||
|
server := httptest.NewServer(app.Routes())
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
plaintext := []byte("Sensitive Data For Full Upload")
|
||||||
|
|
||||||
|
body := &bytes.Buffer{}
|
||||||
|
writer := multipart.NewWriter(body)
|
||||||
|
part, err := writer.CreateFormFile("file", "secret.txt")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateFormFile failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := part.Write(plaintext); err != nil {
|
||||||
|
t.Fatalf("Write failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := writer.Close(); err != nil {
|
||||||
|
t.Fatalf("Writer close failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("POST", server.URL+"/", body)
|
||||||
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := resp.Body.Close(); err != nil {
|
||||||
|
t.Errorf("Failed to close response body: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
respBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
slug := filepath.Base(strings.TrimSpace(string(respBytes)))
|
||||||
|
|
||||||
|
if len(slug) < SlugLength {
|
||||||
|
t.Fatalf("Invalid slug: %s", slug)
|
||||||
|
}
|
||||||
|
keyBase64 := slug[:SlugLength]
|
||||||
|
key, _ := base64.RawURLEncoding.DecodeString(keyBase64)
|
||||||
|
ext := filepath.Ext("secret.txt")
|
||||||
|
id := crypto.GetID(key, ext)
|
||||||
|
|
||||||
|
finalPath := filepath.Join(storageDir, id)
|
||||||
|
finalData, err := os.ReadFile(finalPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read final file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.Contains(finalData, plaintext) {
|
||||||
|
t.Fatal("Final file contains plaintext!")
|
||||||
|
}
|
||||||
|
|
||||||
|
streamer, _ := crypto.NewGCMStreamer(key)
|
||||||
|
d := crypto.NewDecryptor(bytes.NewReader(finalData), streamer.AEAD, int64(len(finalData)))
|
||||||
|
decrypted, _ := io.ReadAll(d)
|
||||||
|
|
||||||
|
if !bytes.Equal(decrypted, plaintext) {
|
||||||
|
t.Error("Final file decryption failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func uploadChunk(t *testing.T, baseURL, uid string, idx int, data []byte) {
|
func uploadChunk(t *testing.T, baseURL, uid string, idx int, data []byte) {
|
||||||
body := &bytes.Buffer{}
|
body := &bytes.Buffer{}
|
||||||
writer := multipart.NewWriter(body)
|
writer := multipart.NewWriter(body)
|
||||||
|
|||||||
+85
-46
@@ -2,6 +2,7 @@ package app
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -48,55 +49,77 @@ func (app *App) saveChunk(uid string, idx int, src io.Reader) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if _, err := io.Copy(dest, src); err != nil {
|
key := make([]byte, crypto.KeySize)
|
||||||
return fmt.Errorf("copy chunk: %w", err)
|
if _, err := rand.Read(key); err != nil {
|
||||||
|
return fmt.Errorf("generate chunk key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := dest.Write(key); err != nil {
|
||||||
|
return fmt.Errorf("write chunk key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
streamer, err := crypto.NewGCMStreamer(key)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create streamer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := streamer.EncryptStream(dest, src); err != nil {
|
||||||
|
return fmt.Errorf("encrypt chunk: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *App) mergeChunks(uid string, total int) (string, error) {
|
func (app *App) getChunkDecryptors(uid string, total int) ([]io.ReadSeeker, func(), error) {
|
||||||
tmpPath := filepath.Join(app.Conf.StorageDir, TempDirName, "m_"+uid)
|
files := make([]*os.File, 0, total)
|
||||||
|
decryptors := make([]io.ReadSeeker, 0, total)
|
||||||
|
|
||||||
merged, err := os.Create(tmpPath)
|
closeAll := func() {
|
||||||
if err != nil {
|
for _, f := range files {
|
||||||
return "", fmt.Errorf("create merge file: %w", err)
|
_ = f.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if closeErr := merged.Close(); closeErr != nil {
|
|
||||||
app.Logger.Error("Failed to close merged file", "err", closeErr)
|
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
|
|
||||||
limit := app.Conf.MaxMB * MegaByte
|
|
||||||
var written int64
|
|
||||||
|
|
||||||
for i := range total {
|
for i := range total {
|
||||||
partPath := filepath.Join(app.Conf.StorageDir, TempDirName, uid, strconv.Itoa(i))
|
partPath := filepath.Join(app.Conf.StorageDir, TempDirName, uid, strconv.Itoa(i))
|
||||||
|
f, err := os.Open(partPath)
|
||||||
part, err := os.Open(partPath)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("open chunk %d: %w", i, err)
|
closeAll()
|
||||||
}
|
return nil, nil, fmt.Errorf("open chunk %d: %w", i, err)
|
||||||
|
}
|
||||||
n, err := io.Copy(merged, part)
|
files = append(files, f)
|
||||||
|
|
||||||
if closeErr := part.Close(); closeErr != nil {
|
key := make([]byte, crypto.KeySize)
|
||||||
app.Logger.Error("Failed to close chunk part", "err", closeErr)
|
if _, err := io.ReadFull(f, key); err != nil {
|
||||||
|
closeAll()
|
||||||
|
return nil, nil, fmt.Errorf("read chunk key %d: %w", i, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
info, err := f.Stat()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("append chunk %d: %w", i, err)
|
closeAll()
|
||||||
|
return nil, nil, fmt.Errorf("stat chunk %d: %w", i, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
written += n
|
bodySize := info.Size() - int64(crypto.KeySize)
|
||||||
if written > limit {
|
if bodySize < 0 {
|
||||||
return "", io.ErrShortWrite
|
closeAll()
|
||||||
}
|
return nil, nil, fmt.Errorf("invalid chunk size %d", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
return tmpPath, nil
|
bodyReader := io.NewSectionReader(f, int64(crypto.KeySize), bodySize)
|
||||||
|
|
||||||
|
streamer, err := crypto.NewGCMStreamer(key)
|
||||||
|
if err != nil {
|
||||||
|
closeAll()
|
||||||
|
return nil, nil, fmt.Errorf("create streamer %d: %w", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
decryptor := crypto.NewDecryptor(bodyReader, streamer.AEAD, bodySize)
|
||||||
|
decryptors = append(decryptors, decryptor)
|
||||||
|
}
|
||||||
|
|
||||||
|
return decryptors, closeAll, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *App) encryptAndSave(src io.Reader, key []byte, finalPath string) error {
|
func (app *App) encryptAndSave(src io.Reader, key []byte, finalPath string) error {
|
||||||
@@ -151,32 +174,42 @@ func (app *App) RegisterFile(id string, size int64) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return app.DB.Update(func(tx *bbolt.Tx) error {
|
return app.DB.Update(func(tx *bbolt.Tx) error {
|
||||||
b := tx.Bucket([]byte(DBBucketName))
|
bFiles := tx.Bucket([]byte(DBBucketName))
|
||||||
|
bIndex := tx.Bucket([]byte(DBBucketIndexName))
|
||||||
|
|
||||||
data, err := json.Marshal(meta)
|
data, err := json.Marshal(meta)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return b.Put([]byte(id), data)
|
|
||||||
|
if err := bFiles.Put([]byte(id), data); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
indexKey := []byte(meta.ExpiresAt.Format(time.RFC3339) + "_" + id)
|
||||||
|
return bIndex.Put(indexKey, []byte(id))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *App) CleanStorage() {
|
func (app *App) CleanStorage() {
|
||||||
now := time.Now()
|
now := time.Now().Format(time.RFC3339)
|
||||||
var toDelete []string
|
var toDeleteIDs []string
|
||||||
|
var toDeleteKeys []string
|
||||||
|
|
||||||
err := app.DB.View(func(tx *bbolt.Tx) error {
|
err := app.DB.View(func(tx *bbolt.Tx) error {
|
||||||
b := tx.Bucket([]byte(DBBucketName))
|
bIndex := tx.Bucket([]byte(DBBucketIndexName))
|
||||||
c := b.Cursor()
|
if bIndex == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
c := bIndex.Cursor()
|
||||||
|
|
||||||
for k, v := c.First(); k != nil; k, v = c.Next() {
|
for k, v := c.First(); k != nil; k, v = c.Next() {
|
||||||
var meta FileMeta
|
if string(k) > now {
|
||||||
if err := json.Unmarshal(v, &meta); err != nil {
|
break
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if now.After(meta.ExpiresAt) {
|
toDeleteKeys = append(toDeleteKeys, string(k))
|
||||||
toDelete = append(toDelete, string(k))
|
toDeleteIDs = append(toDeleteIDs, string(v))
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@@ -186,21 +219,27 @@ func (app *App) CleanStorage() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(toDelete) == 0 {
|
if len(toDeleteIDs) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = app.DB.Update(func(tx *bbolt.Tx) error {
|
err = app.DB.Update(func(tx *bbolt.Tx) error {
|
||||||
b := tx.Bucket([]byte(DBBucketName))
|
bFiles := tx.Bucket([]byte(DBBucketName))
|
||||||
for _, id := range toDelete {
|
bIndex := tx.Bucket([]byte(DBBucketIndexName))
|
||||||
|
|
||||||
|
for i, id := range toDeleteIDs {
|
||||||
path := filepath.Join(app.Conf.StorageDir, id)
|
path := filepath.Join(app.Conf.StorageDir, id)
|
||||||
if err := os.RemoveAll(path); err != nil {
|
if err := os.RemoveAll(path); err != nil {
|
||||||
app.Logger.Error("Failed to remove expired file", "path", id, "err", err)
|
app.Logger.Error("Failed to remove expired file", "path", id, "err", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := b.Delete([]byte(id)); err != nil {
|
if err := bFiles.Delete([]byte(id)); err != nil {
|
||||||
app.Logger.Error("Failed to delete metadata", "id", id, "err", err)
|
app.Logger.Error("Failed to delete metadata", "id", id, "err", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := bIndex.Delete([]byte(toDeleteKeys[i])); err != nil {
|
||||||
|
app.Logger.Error("Failed to delete index", "key", toDeleteKeys[i], "err", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|||||||
+102
-44
@@ -1,55 +1,19 @@
|
|||||||
package app
|
package app
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/skidoodle/safebin/internal/crypto"
|
||||||
"go.etcd.io/bbolt"
|
"go.etcd.io/bbolt"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCleanup_AbandonedMerge(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
tmpStorage := filepath.Join(tmpDir, TempDirName)
|
|
||||||
if err := os.MkdirAll(tmpStorage, 0700); err != nil {
|
|
||||||
t.Fatalf("MkdirAll failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
db, err := InitDB(tmpDir)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("InitDB failed: %v", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := db.Close(); err != nil {
|
|
||||||
t.Errorf("Failed to close DB: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
app := &App{
|
|
||||||
Conf: Config{StorageDir: tmpDir},
|
|
||||||
Logger: discardLogger(),
|
|
||||||
DB: db,
|
|
||||||
}
|
|
||||||
|
|
||||||
abandonedFile := filepath.Join(tmpStorage, "m_abandoned_upload_id")
|
|
||||||
if err := os.WriteFile(abandonedFile, []byte("partial data"), 0600); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
oldTime := time.Now().Add(-TempExpiry - time.Hour)
|
|
||||||
if err := os.Chtimes(abandonedFile, oldTime, oldTime); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
app.CleanTemp(tmpStorage)
|
|
||||||
|
|
||||||
if _, err := os.Stat(abandonedFile); !os.IsNotExist(err) {
|
|
||||||
t.Error("Cleanup failed to remove abandoned merge file from crashed session")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCleanup_AbandonedChunks(t *testing.T) {
|
func TestCleanup_AbandonedChunks(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
tmpStorage := filepath.Join(tmpDir, TempDirName)
|
tmpStorage := filepath.Join(tmpDir, TempDirName)
|
||||||
@@ -135,9 +99,16 @@ func TestCleanup_ExpiredStorage(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := app.DB.Update(func(tx *bbolt.Tx) error {
|
if err := app.DB.Update(func(tx *bbolt.Tx) error {
|
||||||
b := tx.Bucket([]byte(DBBucketName))
|
bFiles := tx.Bucket([]byte(DBBucketName))
|
||||||
|
bIndex := tx.Bucket([]byte(DBBucketIndexName))
|
||||||
|
|
||||||
data, _ := json.Marshal(expiredMeta)
|
data, _ := json.Marshal(expiredMeta)
|
||||||
return b.Put([]byte(filename), data)
|
if err := bFiles.Put([]byte(filename), data); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
indexKey := []byte(expiredMeta.ExpiresAt.Format(time.RFC3339) + "_" + filename)
|
||||||
|
return bIndex.Put(indexKey, []byte(filename))
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
t.Fatalf("DB Update failed: %v", err)
|
t.Fatalf("DB Update failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -149,12 +120,99 @@ func TestCleanup_ExpiredStorage(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := app.DB.View(func(tx *bbolt.Tx) error {
|
if err := app.DB.View(func(tx *bbolt.Tx) error {
|
||||||
b := tx.Bucket([]byte(DBBucketName))
|
bFiles := tx.Bucket([]byte(DBBucketName))
|
||||||
if v := b.Get([]byte(filename)); v != nil {
|
if v := bFiles.Get([]byte(filename)); v != nil {
|
||||||
t.Error("Cleanup failed to remove metadata")
|
t.Error("Cleanup failed to remove metadata")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bIndex := tx.Bucket([]byte(DBBucketIndexName))
|
||||||
|
indexKey := []byte(expiredMeta.ExpiresAt.Format(time.RFC3339) + "_" + filename)
|
||||||
|
if v := bIndex.Get(indexKey); v != nil {
|
||||||
|
t.Error("Cleanup failed to remove index entry")
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
t.Fatalf("DB View failed: %v", err)
|
t.Fatalf("DB View failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSaveChunk_EncryptsData(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
app := &App{
|
||||||
|
Conf: Config{StorageDir: tmpDir},
|
||||||
|
Logger: discardLogger(),
|
||||||
|
}
|
||||||
|
|
||||||
|
uid := "test-encrypt-chunk"
|
||||||
|
plaintext := make([]byte, 1024)
|
||||||
|
if _, err := rand.Read(plaintext); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := app.saveChunk(uid, 0, bytes.NewReader(plaintext)); err != nil {
|
||||||
|
t.Fatalf("saveChunk failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
path := filepath.Join(tmpDir, TempDirName, uid, "0")
|
||||||
|
fileData, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadFile failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.Equal(fileData, plaintext) {
|
||||||
|
t.Fatal("Chunk stored as plaintext!")
|
||||||
|
}
|
||||||
|
if bytes.Contains(fileData, plaintext) {
|
||||||
|
t.Fatal("Chunk contains plaintext!")
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedSize := crypto.KeySize + len(plaintext) + 16
|
||||||
|
if len(fileData) != expectedSize {
|
||||||
|
t.Errorf("Unexpected file size. Want %d, got %d", expectedSize, len(fileData))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetChunkDecryptors_RestoresData(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
app := &App{
|
||||||
|
Conf: Config{StorageDir: tmpDir},
|
||||||
|
Logger: discardLogger(),
|
||||||
|
}
|
||||||
|
|
||||||
|
uid := "test-restore"
|
||||||
|
data1 := []byte("chunk one data")
|
||||||
|
data2 := []byte("chunk two data")
|
||||||
|
|
||||||
|
if err := app.saveChunk(uid, 0, bytes.NewReader(data1)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := app.saveChunk(uid, 1, bytes.NewReader(data2)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
decryptors, closeFn, err := app.getChunkDecryptors(uid, 2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("getChunkDecryptors failed: %v", err)
|
||||||
|
}
|
||||||
|
defer closeFn()
|
||||||
|
|
||||||
|
if len(decryptors) != 2 {
|
||||||
|
t.Fatalf("Expected 2 decryptors, got %d", len(decryptors))
|
||||||
|
}
|
||||||
|
|
||||||
|
buf1, err := io.ReadAll(decryptors[0])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read decryptor 1: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(buf1, data1) {
|
||||||
|
t.Errorf("Chunk 1 mismatch. Want %s, got %s", data1, buf1)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf2, err := io.ReadAll(decryptors[1])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read decryptor 2: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(buf2, data2) {
|
||||||
|
t.Errorf("Chunk 2 mismatch. Want %s, got %s", data2, buf2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
+83
-69
@@ -1,7 +1,8 @@
|
|||||||
package app
|
package app
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
@@ -19,17 +20,14 @@ func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request)
|
|||||||
request.Body = http.MaxBytesReader(writer, request.Body, limit)
|
request.Body = http.MaxBytesReader(writer, request.Body, limit)
|
||||||
|
|
||||||
file, header, err := request.FormFile("file")
|
file, header, err := request.FormFile("file")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err.Error() == "http: request body too large" {
|
if err.Error() == "http: request body too large" {
|
||||||
app.SendError(writer, request, http.StatusRequestEntityTooLarge)
|
app.SendError(writer, request, http.StatusRequestEntityTooLarge)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
app.SendError(writer, request, http.StatusBadRequest)
|
app.SendError(writer, request, http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if closeErr := file.Close(); closeErr != nil {
|
if closeErr := file.Close(); closeErr != nil {
|
||||||
app.Logger.Error("Failed to close upload file", "err", closeErr)
|
app.Logger.Error("Failed to close upload file", "err", closeErr)
|
||||||
@@ -37,48 +35,88 @@ func (app *App) HandleUpload(writer http.ResponseWriter, request *http.Request)
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
tmp, err := os.CreateTemp(filepath.Join(app.Conf.StorageDir, TempDirName), "up_*")
|
tmp, err := os.CreateTemp(filepath.Join(app.Conf.StorageDir, TempDirName), "up_*")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
app.Logger.Error("Failed to create temp file", "err", err)
|
app.Logger.Error("Failed to create temp file", "err", err)
|
||||||
app.SendError(writer, request, http.StatusInternalServerError)
|
app.SendError(writer, request, http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpPath := tmp.Name()
|
tmpPath := tmp.Name()
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
|
_ = tmp.Close()
|
||||||
if removeErr := os.Remove(tmpPath); removeErr != nil && !os.IsNotExist(removeErr) {
|
if removeErr := os.Remove(tmpPath); removeErr != nil && !os.IsNotExist(removeErr) {
|
||||||
app.Logger.Error("Failed to remove temp file", "err", removeErr)
|
app.Logger.Error("Failed to remove temp file", "err", removeErr)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
ephemeralKey := make([]byte, crypto.KeySize)
|
||||||
|
if _, err := rand.Read(ephemeralKey); err != nil {
|
||||||
|
app.Logger.Error("Failed to generate ephemeral key", "err", err)
|
||||||
|
app.SendError(writer, request, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
hasher := sha256.New()
|
||||||
|
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(io.MultiWriter(hasher, pw), file)
|
||||||
|
_ = pw.CloseWithError(err)
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if closeErr := tmp.Close(); closeErr != nil {
|
if closeErr := pr.Close(); closeErr != nil {
|
||||||
app.Logger.Error("Failed to close temp file", "err", closeErr)
|
app.Logger.Error("Failed to close pipe reader", "err", closeErr)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if _, err := io.Copy(tmp, file); err != nil {
|
streamer, err := crypto.NewGCMStreamer(ephemeralKey)
|
||||||
app.Logger.Error("Failed to write temp file", "err", err)
|
if err != nil {
|
||||||
|
app.Logger.Error("Failed to create streamer", "err", err)
|
||||||
|
app.SendError(writer, request, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := streamer.EncryptStream(tmp, pr); err != nil {
|
||||||
|
app.Logger.Error("Failed to encrypt stream", "err", err)
|
||||||
|
app.SendError(writer, request, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := <-errChan; err != nil {
|
||||||
|
app.Logger.Error("Failed to read/hash upload", "err", err)
|
||||||
app.SendError(writer, request, http.StatusRequestEntityTooLarge)
|
app.SendError(writer, request, http.StatusRequestEntityTooLarge)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
app.FinalizeFile(writer, request, tmp, header.Filename)
|
convergentKey := hasher.Sum(nil)[:crypto.KeySize]
|
||||||
|
|
||||||
|
if _, err := tmp.Seek(0, 0); err != nil {
|
||||||
|
app.Logger.Error("Seek failed", "err", err)
|
||||||
|
app.SendError(writer, request, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
info, _ := tmp.Stat()
|
||||||
|
decryptor := crypto.NewDecryptor(tmp, streamer.AEAD, info.Size())
|
||||||
|
|
||||||
|
app.finalizeUpload(writer, request, decryptor, convergentKey, header.Filename)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *App) HandleChunk(writer http.ResponseWriter, request *http.Request) {
|
func (app *App) HandleChunk(writer http.ResponseWriter, request *http.Request) {
|
||||||
request.Body = http.MaxBytesReader(writer, request.Body, MaxRequestOverhead)
|
request.Body = http.MaxBytesReader(writer, request.Body, MaxRequestOverhead)
|
||||||
|
|
||||||
uid := request.FormValue("upload_id")
|
uid := request.FormValue("upload_id")
|
||||||
|
|
||||||
idx, err := strconv.Atoi(request.FormValue("index"))
|
idx, err := strconv.Atoi(request.FormValue("index"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
app.SendError(writer, request, http.StatusBadRequest)
|
app.SendError(writer, request, http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
maxChunks := int((app.Conf.MaxMB*MegaByte)/UploadChunkSize) + ChunkSafetyMargin
|
maxChunks := int((app.Conf.MaxMB*MegaByte)/MinChunkSize) + ChunkSafetyMargin
|
||||||
|
|
||||||
if !reUploadID.MatchString(uid) || idx > maxChunks || idx < 0 {
|
if !reUploadID.MatchString(uid) || idx > maxChunks || idx < 0 {
|
||||||
app.SendError(writer, request, http.StatusBadRequest)
|
app.SendError(writer, request, http.StatusBadRequest)
|
||||||
@@ -86,17 +124,14 @@ func (app *App) HandleChunk(writer http.ResponseWriter, request *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
file, _, err := request.FormFile("chunk")
|
file, _, err := request.FormFile("chunk")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err.Error() == "http: request body too large" {
|
if err.Error() == "http: request body too large" {
|
||||||
app.SendError(writer, request, http.StatusRequestEntityTooLarge)
|
app.SendError(writer, request, http.StatusRequestEntityTooLarge)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
app.SendError(writer, request, http.StatusBadRequest)
|
app.SendError(writer, request, http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if closeErr := file.Close(); closeErr != nil {
|
if closeErr := file.Close(); closeErr != nil {
|
||||||
app.Logger.Error("Failed to close chunk file", "err", closeErr)
|
app.Logger.Error("Failed to close chunk file", "err", closeErr)
|
||||||
@@ -111,75 +146,60 @@ func (app *App) HandleChunk(writer http.ResponseWriter, request *http.Request) {
|
|||||||
|
|
||||||
func (app *App) HandleFinish(writer http.ResponseWriter, request *http.Request) {
|
func (app *App) HandleFinish(writer http.ResponseWriter, request *http.Request) {
|
||||||
uid := request.FormValue("upload_id")
|
uid := request.FormValue("upload_id")
|
||||||
|
|
||||||
total, err := strconv.Atoi(request.FormValue("total"))
|
total, err := strconv.Atoi(request.FormValue("total"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
app.SendError(writer, request, http.StatusBadRequest)
|
app.SendError(writer, request, http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
maxChunks := int((app.Conf.MaxMB*MegaByte)/UploadChunkSize) + ChunkSafetyMargin
|
maxChunks := int((app.Conf.MaxMB*MegaByte)/MinChunkSize) + ChunkSafetyMargin
|
||||||
|
|
||||||
if !reUploadID.MatchString(uid) || total > maxChunks || total <= 0 {
|
if !reUploadID.MatchString(uid) || total > maxChunks || total <= 0 {
|
||||||
app.SendError(writer, request, http.StatusBadRequest)
|
app.SendError(writer, request, http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
mergedPath, err := app.mergeChunks(uid, total)
|
decryptors, closeAll, err := app.getChunkDecryptors(uid, total)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
app.Logger.Error("Merge failed", "err", err)
|
app.Logger.Error("Failed to open chunks", "err", err)
|
||||||
|
|
||||||
if errors.Is(err, io.ErrShortWrite) {
|
|
||||||
app.SendError(writer, request, http.StatusRequestEntityTooLarge)
|
|
||||||
} else {
|
|
||||||
app.SendError(writer, request, http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if removeErr := os.Remove(mergedPath); removeErr != nil && !os.IsNotExist(removeErr) {
|
|
||||||
app.Logger.Error("Failed to remove merged file", "err", removeErr)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
mergedRead, err := os.Open(mergedPath)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
app.Logger.Error("Failed to open merged file", "err", err)
|
|
||||||
app.SendError(writer, request, http.StatusInternalServerError)
|
app.SendError(writer, request, http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if closeErr := mergedRead.Close(); closeErr != nil {
|
closeAll()
|
||||||
app.Logger.Error("Failed to close merged reader", "err", closeErr)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
app.FinalizeFile(writer, request, mergedRead, request.FormValue("filename"))
|
|
||||||
|
|
||||||
if err := os.RemoveAll(filepath.Join(app.Conf.StorageDir, TempDirName, uid)); err != nil {
|
if err := os.RemoveAll(filepath.Join(app.Conf.StorageDir, TempDirName, uid)); err != nil {
|
||||||
app.Logger.Error("Failed to remove chunk dir", "err", err)
|
app.Logger.Error("Failed to remove chunk dir", "err", err)
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
readers := make([]io.Reader, len(decryptors))
|
||||||
|
for i, d := range decryptors {
|
||||||
|
readers[i] = d
|
||||||
|
}
|
||||||
|
|
||||||
|
hasher := sha256.New()
|
||||||
|
if _, err := io.Copy(hasher, io.MultiReader(readers...)); err != nil {
|
||||||
|
app.Logger.Error("Failed to hash chunks", "err", err)
|
||||||
|
app.SendError(writer, request, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
convergentKey := hasher.Sum(nil)[:crypto.KeySize]
|
||||||
|
|
||||||
|
for _, d := range decryptors {
|
||||||
|
if _, err := d.Seek(0, io.SeekStart); err != nil {
|
||||||
|
app.Logger.Error("Failed to reset chunk decryptor", "err", err)
|
||||||
|
app.SendError(writer, request, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
multiSrc := io.MultiReader(readers...)
|
||||||
|
|
||||||
|
app.finalizeUpload(writer, request, multiSrc, convergentKey, request.FormValue("filename"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *App) FinalizeFile(writer http.ResponseWriter, request *http.Request, src *os.File, filename string) {
|
func (app *App) finalizeUpload(writer http.ResponseWriter, request *http.Request, src io.Reader, key []byte, filename string) {
|
||||||
if _, err := src.Seek(0, 0); err != nil {
|
|
||||||
app.Logger.Error("Seek failed", "err", err)
|
|
||||||
app.SendError(writer, request, http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
key, err := crypto.DeriveKey(src)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
app.Logger.Error("Key derivation failed", "err", err)
|
|
||||||
app.SendError(writer, request, http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ext := filepath.Ext(filename)
|
ext := filepath.Ext(filename)
|
||||||
id := crypto.GetID(key, ext)
|
id := crypto.GetID(key, ext)
|
||||||
finalPath := filepath.Join(app.Conf.StorageDir, id)
|
finalPath := filepath.Join(app.Conf.StorageDir, id)
|
||||||
@@ -192,12 +212,6 @@ func (app *App) FinalizeFile(writer http.ResponseWriter, request *http.Request,
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := src.Seek(0, 0); err != nil {
|
|
||||||
app.Logger.Error("Seek failed", "err", err)
|
|
||||||
app.SendError(writer, request, http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := app.encryptAndSave(src, key, finalPath); err != nil {
|
if err := app.encryptAndSave(src, key, finalPath); err != nil {
|
||||||
app.Logger.Error("Encryption failed", "err", err)
|
app.Logger.Error("Encryption failed", "err", err)
|
||||||
app.SendError(writer, request, http.StatusInternalServerError)
|
app.SendError(writer, request, http.StatusInternalServerError)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ type Decryptor struct {
|
|||||||
aead cipher.AEAD
|
aead cipher.AEAD
|
||||||
size int64
|
size int64
|
||||||
offset int64
|
offset int64
|
||||||
|
phyOffset int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDecryptor(readSeeker io.ReadSeeker, aead cipher.AEAD, encryptedSize int64) *Decryptor {
|
func NewDecryptor(readSeeker io.ReadSeeker, aead cipher.AEAD, encryptedSize int64) *Decryptor {
|
||||||
@@ -35,6 +36,7 @@ func NewDecryptor(readSeeker io.ReadSeeker, aead cipher.AEAD, encryptedSize int6
|
|||||||
aead: aead,
|
aead: aead,
|
||||||
size: plainSize,
|
size: plainSize,
|
||||||
offset: 0,
|
offset: 0,
|
||||||
|
phyOffset: -1,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -49,14 +51,22 @@ func (d *Decryptor) Read(buf []byte) (int, error) {
|
|||||||
overhead := int64(d.aead.Overhead())
|
overhead := int64(d.aead.Overhead())
|
||||||
actualChunkSize := int64(GCMChunkSize) + overhead
|
actualChunkSize := int64(GCMChunkSize) + overhead
|
||||||
|
|
||||||
_, err := d.readSeeker.Seek(chunkIdx*actualChunkSize, io.SeekStart)
|
targetOffset := chunkIdx * actualChunkSize
|
||||||
if err != nil {
|
|
||||||
|
if d.phyOffset != targetOffset {
|
||||||
|
if _, err := d.readSeeker.Seek(targetOffset, io.SeekStart); err != nil {
|
||||||
return 0, fmt.Errorf("failed to seek: %w", err)
|
return 0, fmt.Errorf("failed to seek: %w", err)
|
||||||
}
|
}
|
||||||
|
d.phyOffset = targetOffset
|
||||||
|
}
|
||||||
|
|
||||||
encrypted := make([]byte, actualChunkSize)
|
encrypted := make([]byte, actualChunkSize)
|
||||||
|
|
||||||
bytesRead, err := io.ReadFull(d.readSeeker, encrypted)
|
bytesRead, err := io.ReadFull(d.readSeeker, encrypted)
|
||||||
|
if bytesRead > 0 {
|
||||||
|
d.phyOffset += int64(bytesRead)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) {
|
if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||||
return 0, fmt.Errorf("failed to read encrypted data: %w", err)
|
return 0, fmt.Errorf("failed to read encrypted data: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
+1
-1
@@ -50,7 +50,7 @@ async function handleUpload(file) {
|
|||||||
$("busy-state").classList.remove("hidden");
|
$("busy-state").classList.remove("hidden");
|
||||||
$("p-bar-container").classList.add("visible");
|
$("p-bar-container").classList.add("visible");
|
||||||
|
|
||||||
const uploadID = Math.random().toString(36).substring(2, 15);
|
const uploadID = Array.from(window.crypto.getRandomValues(new Uint8Array(16)), (b) => b.toString(16).padStart(2, "0")).join("");
|
||||||
const chunkSize = 1024 * 1024 * 8;
|
const chunkSize = 1024 * 1024 * 8;
|
||||||
const total = Math.ceil(file.size / chunkSize);
|
const total = Math.ceil(file.size / chunkSize);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user