mirror of
https://github.com/skidoodle/safebin.git
synced 2026-04-28 11:17:42 +02:00
@@ -0,0 +1,37 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetEnv(t *testing.T) {
|
||||||
|
key := "SAFEBIN_TEST_KEY"
|
||||||
|
val := "somevalue"
|
||||||
|
|
||||||
|
if got := getEnv(key, "default"); got != "default" {
|
||||||
|
t.Errorf("Expected default, got %s", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Setenv(key, val)
|
||||||
|
if got := getEnv(key, "default"); got != val {
|
||||||
|
t.Errorf("Expected %s, got %s", val, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetEnvInt(t *testing.T) {
|
||||||
|
key := "SAFEBIN_TEST_INT"
|
||||||
|
|
||||||
|
if got := getEnvInt(key, 8080); got != 8080 {
|
||||||
|
t.Errorf("Expected default 8080, got %d", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Setenv(key, "9090")
|
||||||
|
if got := getEnvInt(key, 8080); got != 9090 {
|
||||||
|
t.Errorf("Expected 9090, got %d", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Setenv(key, "notanumber")
|
||||||
|
if got := getEnvInt(key, 8080); got != 8080 {
|
||||||
|
t.Errorf("Expected fallback on invalid input, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCalculateRetention(t *testing.T) {
|
||||||
|
maxMB := int64(100)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fileSize int64
|
||||||
|
wantMin time.Duration
|
||||||
|
wantMax time.Duration
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Tiny file (Max retention)",
|
||||||
|
fileSize: 1024,
|
||||||
|
wantMin: MaxRetention - time.Hour,
|
||||||
|
wantMax: MaxRetention,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Max size file (Min retention)",
|
||||||
|
fileSize: 100 * MegaByte,
|
||||||
|
wantMin: MinRetention,
|
||||||
|
wantMax: MinRetention + time.Minute,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Half size file (Somewhere in between)",
|
||||||
|
fileSize: 50 * MegaByte,
|
||||||
|
wantMin: 24 * time.Hour,
|
||||||
|
wantMax: MaxRetention,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Oversized file (Min retention)",
|
||||||
|
fileSize: 200 * MegaByte,
|
||||||
|
wantMin: MinRetention,
|
||||||
|
wantMax: MinRetention + time.Minute,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := CalculateRetention(tc.fileSize, maxMB)
|
||||||
|
if got < tc.wantMin || got > tc.wantMax {
|
||||||
|
t.Errorf("Retention for size %d: got %v, want between %v and %v",
|
||||||
|
tc.fileSize, got, tc.wantMin, tc.wantMax)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,167 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"html/template"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupTestApp(t *testing.T) (*App, string) {
|
||||||
|
storageDir := t.TempDir()
|
||||||
|
os.MkdirAll(filepath.Join(storageDir, "tmp"), 0700)
|
||||||
|
|
||||||
|
tmplDir := filepath.Join(storageDir, "templates")
|
||||||
|
os.MkdirAll(tmplDir, 0700)
|
||||||
|
os.WriteFile(filepath.Join(tmplDir, "base.html"), []byte(`{{define "base"}}{{template "content" .}}{{end}}`), 0600)
|
||||||
|
os.WriteFile(filepath.Join(tmplDir, "index.html"), []byte(`{{define "content"}}OK{{end}}`), 0600)
|
||||||
|
|
||||||
|
tmpl := template.Must(template.New("base").Parse(`{{define "base"}}OK{{end}}`))
|
||||||
|
|
||||||
|
app := &App{
|
||||||
|
Conf: Config{
|
||||||
|
StorageDir: storageDir,
|
||||||
|
MaxMB: 10,
|
||||||
|
},
|
||||||
|
Logger: discardLogger(),
|
||||||
|
Tmpl: tmpl,
|
||||||
|
}
|
||||||
|
|
||||||
|
return app, storageDir
|
||||||
|
}
|
||||||
|
|
||||||
|
func discardLogger() *slog.Logger {
|
||||||
|
return slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_StandardUploadAndDownload(t *testing.T) {
|
||||||
|
app, _ := setupTestApp(t)
|
||||||
|
server := httptest.NewServer(app.Routes())
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
body := &bytes.Buffer{}
|
||||||
|
writer := multipart.NewWriter(body)
|
||||||
|
part, _ := writer.CreateFormFile("file", "test.txt")
|
||||||
|
content := []byte("Hello Safebin")
|
||||||
|
part.Write(content)
|
||||||
|
writer.Close()
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("POST", server.URL+"/", body)
|
||||||
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Upload request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("Upload failed status: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
respBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
respStr := string(respBytes)
|
||||||
|
parts := strings.Split(strings.TrimSpace(respStr), "/")
|
||||||
|
slugWithExt := parts[len(parts)-1]
|
||||||
|
|
||||||
|
downloadURL := fmt.Sprintf("%s/%s", server.URL, slugWithExt)
|
||||||
|
resp, err = http.Get(downloadURL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Download request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("Download failed status: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
downloadedContent, _ := io.ReadAll(resp.Body)
|
||||||
|
if !bytes.Equal(content, downloadedContent) {
|
||||||
|
t.Errorf("Content mismatch. Want %s, got %s", content, downloadedContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_ChunkedUpload(t *testing.T) {
|
||||||
|
app, _ := setupTestApp(t)
|
||||||
|
server := httptest.NewServer(app.Routes())
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
uploadID := "testchunkid123"
|
||||||
|
content := []byte("Chunk1Content-Chunk2Content")
|
||||||
|
chunk1 := content[:13]
|
||||||
|
chunk2 := content[13:]
|
||||||
|
|
||||||
|
uploadChunk(t, server.URL, uploadID, 0, chunk1)
|
||||||
|
uploadChunk(t, server.URL, uploadID, 1, chunk2)
|
||||||
|
|
||||||
|
finishURL := fmt.Sprintf("%s/upload/finish", server.URL)
|
||||||
|
form := map[string]string{
|
||||||
|
"upload_id": uploadID,
|
||||||
|
"total": "2",
|
||||||
|
"filename": "chunked.txt",
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := postForm(t, finishURL, form)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("Finish failed: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
respBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
respStr := string(respBytes)
|
||||||
|
parts := strings.Split(strings.TrimSpace(respStr), "/")
|
||||||
|
slugWithExt := parts[len(parts)-1]
|
||||||
|
|
||||||
|
downloadURL := fmt.Sprintf("%s/%s", server.URL, slugWithExt)
|
||||||
|
dlResp, _ := http.Get(downloadURL)
|
||||||
|
dlBytes, _ := io.ReadAll(dlResp.Body)
|
||||||
|
dlResp.Body.Close()
|
||||||
|
|
||||||
|
if !bytes.Equal(content, dlBytes) {
|
||||||
|
t.Errorf("Chunked reassembly failed. Want %s, got %s", content, dlBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func uploadChunk(t *testing.T, baseURL, uid string, idx int, data []byte) {
|
||||||
|
body := &bytes.Buffer{}
|
||||||
|
writer := multipart.NewWriter(body)
|
||||||
|
writer.WriteField("upload_id", uid)
|
||||||
|
writer.WriteField("index", fmt.Sprintf("%d", idx))
|
||||||
|
part, _ := writer.CreateFormFile("chunk", "blob")
|
||||||
|
part.Write(data)
|
||||||
|
writer.Close()
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("POST", baseURL+"/upload/chunk", body)
|
||||||
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil || resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("Chunk %d upload failed: %v", idx, err)
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func postForm(t *testing.T, url string, fields map[string]string) *http.Response {
|
||||||
|
body := &bytes.Buffer{}
|
||||||
|
writer := multipart.NewWriter(body)
|
||||||
|
for k, v := range fields {
|
||||||
|
writer.WriteField(k, v)
|
||||||
|
}
|
||||||
|
writer.Close()
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("POST", url, body)
|
||||||
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Post form failed: %v", err)
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
@@ -0,0 +1,85 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCleanup_AbandonedMerge(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
tmpStorage := filepath.Join(tmpDir, "tmp")
|
||||||
|
os.MkdirAll(tmpStorage, 0700)
|
||||||
|
|
||||||
|
app := &App{
|
||||||
|
Conf: Config{StorageDir: tmpDir},
|
||||||
|
Logger: discardLogger(),
|
||||||
|
}
|
||||||
|
|
||||||
|
abandonedFile := filepath.Join(tmpStorage, "m_abandoned_upload_id")
|
||||||
|
if err := os.WriteFile(abandonedFile, []byte("partial data"), 0600); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
oldTime := time.Now().Add(-TempExpiry - time.Hour)
|
||||||
|
if err := os.Chtimes(abandonedFile, oldTime, oldTime); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
app.CleanTemp(tmpStorage)
|
||||||
|
|
||||||
|
if _, err := os.Stat(abandonedFile); !os.IsNotExist(err) {
|
||||||
|
t.Error("Cleanup failed to remove abandoned merge file from crashed session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCleanup_AbandonedChunks(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
tmpStorage := filepath.Join(tmpDir, "tmp")
|
||||||
|
os.MkdirAll(tmpStorage, 0700)
|
||||||
|
|
||||||
|
app := &App{
|
||||||
|
Conf: Config{StorageDir: tmpDir},
|
||||||
|
Logger: discardLogger(),
|
||||||
|
}
|
||||||
|
|
||||||
|
chunkDir := filepath.Join(tmpStorage, "some_upload_id")
|
||||||
|
os.MkdirAll(chunkDir, 0700)
|
||||||
|
os.WriteFile(filepath.Join(chunkDir, "0"), []byte("chunk data"), 0600)
|
||||||
|
|
||||||
|
oldTime := time.Now().Add(-TempExpiry - time.Hour)
|
||||||
|
os.Chtimes(chunkDir, oldTime, oldTime)
|
||||||
|
|
||||||
|
app.CleanTemp(tmpStorage)
|
||||||
|
|
||||||
|
if _, err := os.Stat(chunkDir); !os.IsNotExist(err) {
|
||||||
|
t.Error("Cleanup failed to remove abandoned chunk directory")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCleanup_ExpiredStorage(t *testing.T) {
|
||||||
|
storageDir := t.TempDir()
|
||||||
|
app := &App{
|
||||||
|
Conf: Config{
|
||||||
|
StorageDir: storageDir,
|
||||||
|
MaxMB: 100,
|
||||||
|
},
|
||||||
|
Logger: discardLogger(),
|
||||||
|
}
|
||||||
|
|
||||||
|
filename := "large_file_id"
|
||||||
|
path := filepath.Join(storageDir, filename)
|
||||||
|
f, _ := os.Create(path)
|
||||||
|
f.Truncate(100 * MegaByte) // Max size
|
||||||
|
f.Close()
|
||||||
|
|
||||||
|
oldTime := time.Now().Add(-MinRetention - time.Hour)
|
||||||
|
os.Chtimes(path, oldTime, oldTime)
|
||||||
|
|
||||||
|
app.CleanStorage(storageDir)
|
||||||
|
|
||||||
|
if _, err := os.Stat(path); !os.IsNotExist(err) {
|
||||||
|
t.Error("Cleanup failed to remove expired large file")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,142 @@
|
|||||||
|
package crypto_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/skidoodle/safebin/internal/crypto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDeriveKey(t *testing.T) {
|
||||||
|
data := []byte("some random file content")
|
||||||
|
reader := bytes.NewReader(data)
|
||||||
|
|
||||||
|
key1, err := crypto.DeriveKey(reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DeriveKey failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(key1) != 16 {
|
||||||
|
t.Errorf("Expected key length 16, got %d", len(key1))
|
||||||
|
}
|
||||||
|
|
||||||
|
reader.Seek(0, 0)
|
||||||
|
key2, err := crypto.DeriveKey(reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DeriveKey failed second time: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(key1, key2) {
|
||||||
|
t.Error("DeriveKey is not deterministic")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetID(t *testing.T) {
|
||||||
|
key := make([]byte, 16)
|
||||||
|
ext := ".txt"
|
||||||
|
id1 := crypto.GetID(key, ext)
|
||||||
|
id2 := crypto.GetID(key, ext)
|
||||||
|
|
||||||
|
if id1 != id2 {
|
||||||
|
t.Error("GetID is not deterministic")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(id1) == 0 {
|
||||||
|
t.Error("GetID returned empty string")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncryptDecryptStream(t *testing.T) {
|
||||||
|
payloadSize := (64 * 1024) * 3
|
||||||
|
payload := make([]byte, payloadSize)
|
||||||
|
rand.Read(payload)
|
||||||
|
|
||||||
|
key := make([]byte, 16)
|
||||||
|
rand.Read(key)
|
||||||
|
|
||||||
|
var encryptedBuf bytes.Buffer
|
||||||
|
streamer, err := crypto.NewGCMStreamer(key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create streamer: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := streamer.EncryptStream(&encryptedBuf, bytes.NewReader(payload)); err != nil {
|
||||||
|
t.Fatalf("EncryptStream failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
encryptedReader := bytes.NewReader(encryptedBuf.Bytes())
|
||||||
|
decryptor := crypto.NewDecryptor(encryptedReader, streamer.AEAD, int64(encryptedBuf.Len()))
|
||||||
|
|
||||||
|
decrypted := make([]byte, payloadSize)
|
||||||
|
n, err := io.ReadFull(decryptor, decrypted)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadFull failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n != payloadSize {
|
||||||
|
t.Errorf("Expected %d bytes, got %d", payloadSize, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(payload, decrypted) {
|
||||||
|
t.Error("Decrypted content does not match original payload")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptorSeeking(t *testing.T) {
|
||||||
|
chunkSize := 64 * 1024
|
||||||
|
payload := make([]byte, chunkSize*4)
|
||||||
|
for i := range len(payload) {
|
||||||
|
payload[i] = byte(i % 255)
|
||||||
|
}
|
||||||
|
|
||||||
|
key := make([]byte, 16)
|
||||||
|
rand.Read(key)
|
||||||
|
|
||||||
|
var encryptedBuf bytes.Buffer
|
||||||
|
streamer, _ := crypto.NewGCMStreamer(key)
|
||||||
|
streamer.EncryptStream(&encryptedBuf, bytes.NewReader(payload))
|
||||||
|
|
||||||
|
r := bytes.NewReader(encryptedBuf.Bytes())
|
||||||
|
d := crypto.NewDecryptor(r, streamer.AEAD, int64(encryptedBuf.Len()))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
offset int64
|
||||||
|
whence int
|
||||||
|
read int
|
||||||
|
}{
|
||||||
|
{"Start of file", 0, io.SeekStart, 100},
|
||||||
|
{"Middle of chunk 1", 1000, io.SeekStart, 100},
|
||||||
|
{"Start of chunk 2", int64(chunkSize), io.SeekStart, 100},
|
||||||
|
{"Middle of chunk 2", int64(chunkSize) + 50, io.SeekStart, 100},
|
||||||
|
{"Near end", int64(len(payload)) - 10, io.SeekStart, 10},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
pos, err := d.Seek(tc.offset, tc.whence)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Seek failed: %v", err)
|
||||||
|
}
|
||||||
|
if pos != tc.offset {
|
||||||
|
t.Errorf("Expected pos %d, got %d", tc.offset, pos)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, tc.read)
|
||||||
|
n, err := io.ReadFull(d, buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Read failed: %v", err)
|
||||||
|
}
|
||||||
|
if n != tc.read {
|
||||||
|
t.Errorf("Expected %d bytes, got %d", tc.read, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := payload[tc.offset : tc.offset+int64(tc.read)]
|
||||||
|
if !bytes.Equal(buf, expected) {
|
||||||
|
t.Errorf("Data mismatch at offset %d", tc.offset)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user