diff --git a/internal/app/config_test.go b/internal/app/config_test.go new file mode 100644 index 0000000..0821e22 --- /dev/null +++ b/internal/app/config_test.go @@ -0,0 +1,37 @@ +package app + +import ( + "testing" +) + +func TestGetEnv(t *testing.T) { + key := "SAFEBIN_TEST_KEY" + val := "somevalue" + + if got := getEnv(key, "default"); got != "default" { + t.Errorf("Expected default, got %s", got) + } + + t.Setenv(key, val) + if got := getEnv(key, "default"); got != val { + t.Errorf("Expected %s, got %s", val, got) + } +} + +func TestGetEnvInt(t *testing.T) { + key := "SAFEBIN_TEST_INT" + + if got := getEnvInt(key, 8080); got != 8080 { + t.Errorf("Expected default 8080, got %d", got) + } + + t.Setenv(key, "9090") + if got := getEnvInt(key, 8080); got != 9090 { + t.Errorf("Expected 9090, got %d", got) + } + + t.Setenv(key, "notanumber") + if got := getEnvInt(key, 8080); got != 8080 { + t.Errorf("Expected fallback on invalid input, got %d", got) + } +} diff --git a/internal/app/retention_test.go b/internal/app/retention_test.go new file mode 100644 index 0000000..a9e9906 --- /dev/null +++ b/internal/app/retention_test.go @@ -0,0 +1,52 @@ +package app + +import ( + "testing" + "time" +) + +func TestCalculateRetention(t *testing.T) { + maxMB := int64(100) + + tests := []struct { + name string + fileSize int64 + wantMin time.Duration + wantMax time.Duration + }{ + { + name: "Tiny file (Max retention)", + fileSize: 1024, + wantMin: MaxRetention - time.Hour, + wantMax: MaxRetention, + }, + { + name: "Max size file (Min retention)", + fileSize: 100 * MegaByte, + wantMin: MinRetention, + wantMax: MinRetention + time.Minute, + }, + { + name: "Half size file (Somewhere in between)", + fileSize: 50 * MegaByte, + wantMin: 24 * time.Hour, + wantMax: MaxRetention, + }, + { + name: "Oversized file (Min retention)", + fileSize: 200 * MegaByte, + wantMin: MinRetention, + wantMax: MinRetention + time.Minute, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := CalculateRetention(tc.fileSize, maxMB) + if got < tc.wantMin || got > tc.wantMax { + t.Errorf("Retention for size %d: got %v, want between %v and %v", + tc.fileSize, got, tc.wantMin, tc.wantMax) + } + }) + } +} diff --git a/internal/app/server_test.go b/internal/app/server_test.go new file mode 100644 index 0000000..0d7d19f --- /dev/null +++ b/internal/app/server_test.go @@ -0,0 +1,167 @@ +package app + +import ( + "bytes" + "fmt" + "html/template" + "io" + "log/slog" + "mime/multipart" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" +) + +func setupTestApp(t *testing.T) (*App, string) { + storageDir := t.TempDir() + os.MkdirAll(filepath.Join(storageDir, "tmp"), 0700) + + tmplDir := filepath.Join(storageDir, "templates") + os.MkdirAll(tmplDir, 0700) + os.WriteFile(filepath.Join(tmplDir, "base.html"), []byte(`{{define "base"}}{{template "content" .}}{{end}}`), 0600) + os.WriteFile(filepath.Join(tmplDir, "index.html"), []byte(`{{define "content"}}OK{{end}}`), 0600) + + tmpl := template.Must(template.New("base").Parse(`{{define "base"}}OK{{end}}`)) + + app := &App{ + Conf: Config{ + StorageDir: storageDir, + MaxMB: 10, + }, + Logger: discardLogger(), + Tmpl: tmpl, + } + + return app, storageDir +} + +func discardLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} + +func TestIntegration_StandardUploadAndDownload(t *testing.T) { + app, _ := setupTestApp(t) + server := httptest.NewServer(app.Routes()) + defer server.Close() + + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + part, _ := writer.CreateFormFile("file", "test.txt") + content := []byte("Hello Safebin") + part.Write(content) + writer.Close() + + req, _ := http.NewRequest("POST", server.URL+"/", body) + req.Header.Set("Content-Type", writer.FormDataContentType()) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Upload request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("Upload failed status: %d", resp.StatusCode) + } + + respBytes, _ := io.ReadAll(resp.Body) + respStr := string(respBytes) + parts := strings.Split(strings.TrimSpace(respStr), "/") + slugWithExt := parts[len(parts)-1] + + downloadURL := fmt.Sprintf("%s/%s", server.URL, slugWithExt) + resp, err = http.Get(downloadURL) + if err != nil { + t.Fatalf("Download request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("Download failed status: %d", resp.StatusCode) + } + + downloadedContent, _ := io.ReadAll(resp.Body) + if !bytes.Equal(content, downloadedContent) { + t.Errorf("Content mismatch. Want %s, got %s", content, downloadedContent) + } +} + +func TestIntegration_ChunkedUpload(t *testing.T) { + app, _ := setupTestApp(t) + server := httptest.NewServer(app.Routes()) + defer server.Close() + + uploadID := "testchunkid123" + content := []byte("Chunk1Content-Chunk2Content") + chunk1 := content[:13] + chunk2 := content[13:] + + uploadChunk(t, server.URL, uploadID, 0, chunk1) + uploadChunk(t, server.URL, uploadID, 1, chunk2) + + finishURL := fmt.Sprintf("%s/upload/finish", server.URL) + form := map[string]string{ + "upload_id": uploadID, + "total": "2", + "filename": "chunked.txt", + } + + resp := postForm(t, finishURL, form) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("Finish failed: %d", resp.StatusCode) + } + + respBytes, _ := io.ReadAll(resp.Body) + respStr := string(respBytes) + parts := strings.Split(strings.TrimSpace(respStr), "/") + slugWithExt := parts[len(parts)-1] + + downloadURL := fmt.Sprintf("%s/%s", server.URL, slugWithExt) + dlResp, _ := http.Get(downloadURL) + dlBytes, _ := io.ReadAll(dlResp.Body) + dlResp.Body.Close() + + if !bytes.Equal(content, dlBytes) { + t.Errorf("Chunked reassembly failed. Want %s, got %s", content, dlBytes) + } +} + +func uploadChunk(t *testing.T, baseURL, uid string, idx int, data []byte) { + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + writer.WriteField("upload_id", uid) + writer.WriteField("index", fmt.Sprintf("%d", idx)) + part, _ := writer.CreateFormFile("chunk", "blob") + part.Write(data) + writer.Close() + + req, _ := http.NewRequest("POST", baseURL+"/upload/chunk", body) + req.Header.Set("Content-Type", writer.FormDataContentType()) + resp, err := http.DefaultClient.Do(req) + if err != nil || resp.StatusCode != http.StatusOK { + t.Fatalf("Chunk %d upload failed: %v", idx, err) + } + resp.Body.Close() +} + +func postForm(t *testing.T, url string, fields map[string]string) *http.Response { + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + for k, v := range fields { + writer.WriteField(k, v) + } + writer.Close() + + req, _ := http.NewRequest("POST", url, body) + req.Header.Set("Content-Type", writer.FormDataContentType()) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Post form failed: %v", err) + } + return resp +} diff --git a/internal/app/storage_test.go b/internal/app/storage_test.go new file mode 100644 index 0000000..ae7d286 --- /dev/null +++ b/internal/app/storage_test.go @@ -0,0 +1,85 @@ +package app + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestCleanup_AbandonedMerge(t *testing.T) { + tmpDir := t.TempDir() + tmpStorage := filepath.Join(tmpDir, "tmp") + os.MkdirAll(tmpStorage, 0700) + + app := &App{ + Conf: Config{StorageDir: tmpDir}, + Logger: discardLogger(), + } + + abandonedFile := filepath.Join(tmpStorage, "m_abandoned_upload_id") + if err := os.WriteFile(abandonedFile, []byte("partial data"), 0600); err != nil { + t.Fatal(err) + } + + oldTime := time.Now().Add(-TempExpiry - time.Hour) + if err := os.Chtimes(abandonedFile, oldTime, oldTime); err != nil { + t.Fatal(err) + } + + app.CleanTemp(tmpStorage) + + if _, err := os.Stat(abandonedFile); !os.IsNotExist(err) { + t.Error("Cleanup failed to remove abandoned merge file from crashed session") + } +} + +func TestCleanup_AbandonedChunks(t *testing.T) { + tmpDir := t.TempDir() + tmpStorage := filepath.Join(tmpDir, "tmp") + os.MkdirAll(tmpStorage, 0700) + + app := &App{ + Conf: Config{StorageDir: tmpDir}, + Logger: discardLogger(), + } + + chunkDir := filepath.Join(tmpStorage, "some_upload_id") + os.MkdirAll(chunkDir, 0700) + os.WriteFile(filepath.Join(chunkDir, "0"), []byte("chunk data"), 0600) + + oldTime := time.Now().Add(-TempExpiry - time.Hour) + os.Chtimes(chunkDir, oldTime, oldTime) + + app.CleanTemp(tmpStorage) + + if _, err := os.Stat(chunkDir); !os.IsNotExist(err) { + t.Error("Cleanup failed to remove abandoned chunk directory") + } +} + +func TestCleanup_ExpiredStorage(t *testing.T) { + storageDir := t.TempDir() + app := &App{ + Conf: Config{ + StorageDir: storageDir, + MaxMB: 100, + }, + Logger: discardLogger(), + } + + filename := "large_file_id" + path := filepath.Join(storageDir, filename) + f, _ := os.Create(path) + f.Truncate(100 * MegaByte) // Max size + f.Close() + + oldTime := time.Now().Add(-MinRetention - time.Hour) + os.Chtimes(path, oldTime, oldTime) + + app.CleanStorage(storageDir) + + if _, err := os.Stat(path); !os.IsNotExist(err) { + t.Error("Cleanup failed to remove expired large file") + } +} diff --git a/internal/crypto/crypto_test.go b/internal/crypto/crypto_test.go new file mode 100644 index 0000000..0e9200b --- /dev/null +++ b/internal/crypto/crypto_test.go @@ -0,0 +1,142 @@ +package crypto_test + +import ( + "bytes" + "crypto/rand" + "io" + "testing" + + "github.com/skidoodle/safebin/internal/crypto" +) + +func TestDeriveKey(t *testing.T) { + data := []byte("some random file content") + reader := bytes.NewReader(data) + + key1, err := crypto.DeriveKey(reader) + if err != nil { + t.Fatalf("DeriveKey failed: %v", err) + } + + if len(key1) != 16 { + t.Errorf("Expected key length 16, got %d", len(key1)) + } + + reader.Seek(0, 0) + key2, err := crypto.DeriveKey(reader) + if err != nil { + t.Fatalf("DeriveKey failed second time: %v", err) + } + + if !bytes.Equal(key1, key2) { + t.Error("DeriveKey is not deterministic") + } +} + +func TestGetID(t *testing.T) { + key := make([]byte, 16) + ext := ".txt" + id1 := crypto.GetID(key, ext) + id2 := crypto.GetID(key, ext) + + if id1 != id2 { + t.Error("GetID is not deterministic") + } + + if len(id1) == 0 { + t.Error("GetID returned empty string") + } +} + +func TestEncryptDecryptStream(t *testing.T) { + payloadSize := (64 * 1024) * 3 + payload := make([]byte, payloadSize) + rand.Read(payload) + + key := make([]byte, 16) + rand.Read(key) + + var encryptedBuf bytes.Buffer + streamer, err := crypto.NewGCMStreamer(key) + if err != nil { + t.Fatalf("Failed to create streamer: %v", err) + } + + if err := streamer.EncryptStream(&encryptedBuf, bytes.NewReader(payload)); err != nil { + t.Fatalf("EncryptStream failed: %v", err) + } + + encryptedReader := bytes.NewReader(encryptedBuf.Bytes()) + decryptor := crypto.NewDecryptor(encryptedReader, streamer.AEAD, int64(encryptedBuf.Len())) + + decrypted := make([]byte, payloadSize) + n, err := io.ReadFull(decryptor, decrypted) + if err != nil { + t.Fatalf("ReadFull failed: %v", err) + } + + if n != payloadSize { + t.Errorf("Expected %d bytes, got %d", payloadSize, n) + } + + if !bytes.Equal(payload, decrypted) { + t.Error("Decrypted content does not match original payload") + } +} + +func TestDecryptorSeeking(t *testing.T) { + chunkSize := 64 * 1024 + payload := make([]byte, chunkSize*4) + for i := range len(payload) { + payload[i] = byte(i % 255) + } + + key := make([]byte, 16) + rand.Read(key) + + var encryptedBuf bytes.Buffer + streamer, _ := crypto.NewGCMStreamer(key) + streamer.EncryptStream(&encryptedBuf, bytes.NewReader(payload)) + + r := bytes.NewReader(encryptedBuf.Bytes()) + d := crypto.NewDecryptor(r, streamer.AEAD, int64(encryptedBuf.Len())) + + tests := []struct { + name string + offset int64 + whence int + read int + }{ + {"Start of file", 0, io.SeekStart, 100}, + {"Middle of chunk 1", 1000, io.SeekStart, 100}, + {"Start of chunk 2", int64(chunkSize), io.SeekStart, 100}, + {"Middle of chunk 2", int64(chunkSize) + 50, io.SeekStart, 100}, + {"Near end", int64(len(payload)) - 10, io.SeekStart, 10}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + pos, err := d.Seek(tc.offset, tc.whence) + if err != nil { + t.Fatalf("Seek failed: %v", err) + } + if pos != tc.offset { + t.Errorf("Expected pos %d, got %d", tc.offset, pos) + } + + buf := make([]byte, tc.read) + n, err := io.ReadFull(d, buf) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + if n != tc.read { + t.Errorf("Expected %d bytes, got %d", tc.read, n) + } + + expected := payload[tc.offset : tc.offset+int64(tc.read)] + if !bytes.Equal(buf, expected) { + t.Errorf("Data mismatch at offset %d", tc.offset) + } + }) + } +}