diff --git a/go.sum b/go.sum index 18a64ee..9bf73e1 100644 --- a/go.sum +++ b/go.sum @@ -2,18 +2,12 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= -github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/objx v0.5.3 h1:jmXUvGomnU1o3W/V5h2VEradbpJDwGrzugQQvL0POH4= github.com/stretchr/objx v0.5.3/go.mod h1:rDQraq+vQZU7Fde9LOZLr8Tax6zZvy4kuNKF+QYS+U0= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= go.etcd.io/bbolt v1.4.3 h1:dEadXpI6G79deX5prL3QRNP6JB8UxVkqo4UPnHaNXJo= go.etcd.io/bbolt v1.4.3/go.mod h1:tKQlpPaYCVFctUIgFKFnAlvbmB3tpy1vkTnDWohtc0E= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= -golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= diff --git a/handler/errors.go b/handler/errors.go index 5c145f2..5d2b137 100644 --- a/handler/errors.go +++ b/handler/errors.go @@ -10,11 +10,6 @@ func notFound(slug string, err error, w http.ResponseWriter, r *http.Request) { respondWithError(slug, err, w, r, http.StatusNotFound) } -// badRequest handles 400 Bad Request errors. -func badRequest(slug string, err error, w http.ResponseWriter, r *http.Request) { - respondWithError(slug, err, w, r, http.StatusBadRequest) -} - // internal handles 500 Internal Server Error errors. func internal(slug string, err error, w http.ResponseWriter, r *http.Request) { respondWithError(slug, err, w, r, http.StatusInternalServerError) diff --git a/handler/http.go b/handler/http.go index e6d6993..42e2cfd 100644 --- a/handler/http.go +++ b/handler/http.go @@ -87,7 +87,7 @@ func (h *HttpHandler) HandleSet(w http.ResponseWriter, r *http.Request) { return } - if err := h.store.Set(id, contentHash, content); err != nil { + if err := h.store.Set(id, contentHash, content, nil); err != nil { internal("could not save bin", err, w, r) return } diff --git a/handler/http_test.go b/handler/http_test.go index 9d85def..8053b9c 100644 --- a/handler/http_test.go +++ b/handler/http_test.go @@ -1,7 +1,9 @@ package handler import ( + "crypto/rand" "errors" + "html/template" "net/http" "net/http/httptest" "net/url" @@ -31,8 +33,8 @@ func (m *MockStore) GetIDByHash(hash string) (string, bool, error) { return args.String(0), args.Bool(1), args.Error(2) } -func (m *MockStore) Set(id, hash, content string) error { - args := m.Called(id, hash, content) +func (m *MockStore) Set(id, hash, content string, metadata map[string]interface{}) error { + args := m.Called(id, hash, content, metadata) return args.Error(0) } @@ -57,7 +59,7 @@ func TestHandleSet(t *testing.T) { content := "new content" ch := hash(content) s.On("GetIDByHash", ch).Return("", false, nil).Once() - s.On("Set", mock.Anything, ch, content).Return(nil).Once() + s.On("Set", mock.Anything, ch, content, mock.Anything).Return(nil).Once() form := url.Values{"content": {content}} req := httptest.NewRequest("POST", "/", strings.NewReader(form.Encode())) @@ -120,7 +122,7 @@ func TestHandleSet(t *testing.T) { content := "error content" ch := hash(content) s.On("GetIDByHash", ch).Return("", false, nil).Once() - s.On("Set", mock.Anything, ch, content).Return(errors.New("db error")).Once() + s.On("Set", mock.Anything, ch, content, mock.Anything).Return(errors.New("db error")).Once() form := url.Values{"content": {content}} req := httptest.NewRequest("POST", "/", strings.NewReader(form.Encode())) @@ -130,6 +132,44 @@ func TestHandleSet(t *testing.T) { h.HandleSet(rr, req) assert.Equal(t, http.StatusInternalServerError, rr.Code) }) + + t.Run("Generate ID Error", func(t *testing.T) { + originalReader := rand.Reader + defer func() { rand.Reader = originalReader }() + rand.Reader = errorReader{} + + content := "generate error content" + ch := hash(content) + s.On("GetIDByHash", ch).Return("", false, nil).Once() + + form := url.Values{"content": {content}} + req := httptest.NewRequest("POST", "/", strings.NewReader(form.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + rr := httptest.NewRecorder() + + h.HandleSet(rr, req) + assert.Equal(t, http.StatusInternalServerError, rr.Code) + }) + + t.Run("Malformed Form Data", func(t *testing.T) { + req := httptest.NewRequest("POST", "/", strings.NewReader("content=%zz")) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + rr := httptest.NewRecorder() + h.HandleSet(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid form data") + }) + + t.Run("Too Large Content", func(t *testing.T) { + content := strings.Repeat("a", 2048) + form := url.Values{"content": {content}} + req := httptest.NewRequest("POST", "/", strings.NewReader(form.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + rr := httptest.NewRecorder() + h.HandleSet(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Content too large") + }) } func TestHandleGet(t *testing.T) { @@ -167,6 +207,18 @@ func TestHandleGet(t *testing.T) { h.HandleGet(rr, req) assert.Equal(t, http.StatusInternalServerError, rr.Code) }) + + t.Run("With Extension", func(t *testing.T) { + id := "testid" + s.On("Get", id).Return(&store.Paste{Content: "hello", CreatedAt: time.Now()}, true, nil).Once() + req := httptest.NewRequest("GET", "/"+id+".go", nil) + req.SetPathValue("id", id+".go") + rr := httptest.NewRecorder() + + h.HandleGet(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), "hello") + }) } type FailingResponseWriter struct { @@ -184,6 +236,14 @@ func TestHandleHomeError(t *testing.T) { h.HandleHome(rr, req) } +type mockTemplateStore struct { + mock.Mock +} + +func (m *mockTemplateStore) ExecuteTemplate(w http.ResponseWriter, name string, data interface{}) error { + return errors.New("template error") +} + func TestHandleRaw(t *testing.T) { s := new(MockStore) h := NewHandler(s, 1024, "../view/templates/*.html") @@ -202,6 +262,19 @@ func TestHandleRaw(t *testing.T) { assert.Equal(t, "text/plain; charset=utf-8", rr.Header().Get("Content-Type")) }) + t.Run("With Extension", func(t *testing.T) { + id := "testid" + content := "raw content" + s.On("Get", id).Return(&store.Paste{Content: content}, true, nil).Once() + req := httptest.NewRequest("GET", "/raw/"+id+".txt", nil) + req.SetPathValue("id", id+".txt") + rr := httptest.NewRecorder() + + h.HandleRaw(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, content, rr.Body.String()) + }) + t.Run("Not Found", func(t *testing.T) { s.On("Get", "missing").Return(nil, false, nil).Once() req := httptest.NewRequest("GET", "/raw/missing", nil) @@ -211,6 +284,16 @@ func TestHandleRaw(t *testing.T) { h.HandleRaw(rr, req) assert.Equal(t, http.StatusNotFound, rr.Code) }) + + t.Run("Store Error", func(t *testing.T) { + s.On("Get", "error").Return(nil, false, errors.New("db error")).Once() + req := httptest.NewRequest("GET", "/raw/error", nil) + req.SetPathValue("id", "error") + rr := httptest.NewRecorder() + + h.HandleRaw(rr, req) + assert.Equal(t, http.StatusInternalServerError, rr.Code) + }) } func TestHandleGetTemplateError(t *testing.T) { @@ -224,3 +307,75 @@ func TestHandleGetTemplateError(t *testing.T) { h.HandleGet(rr, req) } + +func TestHandleSetTemplateError(t *testing.T) { + h := NewHandler(nil, 1024, "../view/templates/*.html") + + t.Run("Empty Content", func(t *testing.T) { + form := url.Values{"content": {""}} + req := httptest.NewRequest("POST", "/", strings.NewReader(form.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + rr := &FailingResponseWriter{*httptest.NewRecorder()} + h.HandleSet(rr, req) + }) + + t.Run("Parse Error", func(t *testing.T) { + req := httptest.NewRequest("POST", "/", strings.NewReader("content=%zz")) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + rr := &FailingResponseWriter{*httptest.NewRecorder()} + h.HandleSet(rr, req) + }) +} + +func TestTemplateErrors(t *testing.T) { + tmpl := template.New("empty") + h := &HttpHandler{ + templates: tmpl, + maxSize: 1024, + } + + t.Run("HandleHome Error", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + rr := httptest.NewRecorder() + h.HandleHome(rr, req) + assert.Equal(t, http.StatusInternalServerError, rr.Code) + }) + + t.Run("HandleGet Error", func(t *testing.T) { + s := new(MockStore) + h.store = s + id := "testid" + s.On("Get", id).Return(&store.Paste{Content: "hello", CreatedAt: time.Now()}, true, nil).Once() + req := httptest.NewRequest("GET", "/"+id, nil) + req.SetPathValue("id", id) + rr := httptest.NewRecorder() + h.HandleGet(rr, req) + assert.Equal(t, http.StatusInternalServerError, rr.Code) + }) + + t.Run("HandleSet Empty Content Error", func(t *testing.T) { + form := url.Values{"content": {""}} + req := httptest.NewRequest("POST", "/", strings.NewReader(form.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + rr := httptest.NewRecorder() + h.HandleSet(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + }) + + t.Run("HandleSet Parse Error", func(t *testing.T) { + req := httptest.NewRequest("POST", "/", strings.NewReader("content=%zz")) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + rr := httptest.NewRecorder() + h.HandleSet(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + }) +} +func TestSafeHTML(t *testing.T) { + h := NewHandler(nil, 1024, "../view/templates/*.html") + tmpl, err := h.templates.New("test").Parse(`{{ safeHTML "
" }}`) + assert.NoError(t, err) + rr := httptest.NewRecorder() + err = tmpl.Execute(rr, nil) + assert.NoError(t, err) + assert.Equal(t, "
", rr.Body.String()) +} diff --git a/handler/util.go b/handler/util.go index 99b80d3..f0442b5 100644 --- a/handler/util.go +++ b/handler/util.go @@ -5,22 +5,23 @@ import ( "crypto/sha256" "encoding/hex" "fmt" - "io" + "math/big" "time" ) const charset = "abcdefghijklmnopqrstuvwxyz0123456789" func generateId() (string, error) { - bytes := make([]byte, 10) - if _, err := io.ReadFull(rand.Reader, bytes); err != nil { - return "", err + result := make([]byte, 10) + max := big.NewInt(int64(len(charset))) + for i := 0; i < 10; i++ { + n, err := rand.Int(rand.Reader, max) + if err != nil { + return "", err + } + result[i] = charset[n.Int64()] } - - for i := range bytes { - bytes[i] = charset[bytes[i]%byte(len(charset))] - } - return string(bytes), nil + return string(result), nil } func TimeAgo(t time.Time) string { diff --git a/handler/util_test.go b/handler/util_test.go index 38010a6..146aa2f 100644 --- a/handler/util_test.go +++ b/handler/util_test.go @@ -1,18 +1,36 @@ package handler import ( + "crypto/rand" + "errors" "testing" "time" "github.com/stretchr/testify/assert" ) +type errorReader struct{} + +func (e errorReader) Read(p []byte) (n int, err error) { + return 0, errors.New("read error") +} + func TestGenerateId(t *testing.T) { id, err := generateId() assert.NoError(t, err) assert.Equal(t, 10, len(id)) } +func TestGenerateIdError(t *testing.T) { + originalReader := rand.Reader + defer func() { rand.Reader = originalReader }() + rand.Reader = errorReader{} + + id, err := generateId() + assert.Error(t, err) + assert.Empty(t, id) +} + func TestHash(t *testing.T) { c := "test content" h1 := hash(c) diff --git a/main.go b/main.go index 351d63c..ffcc431 100644 --- a/main.go +++ b/main.go @@ -36,23 +36,40 @@ func parseFlags() *config { return cfg } +func securityHeadersMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("X-Frame-Options", "DENY") + w.Header().Set("Content-Security-Policy", "default-src 'self'; style-src 'self' 'unsafe-inline'; script-src 'self' 'unsafe-inline'") + next.ServeHTTP(w, r) + }) +} + func main() { cfg := parseFlags() slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, nil))) - storage, err := store.NewBoltStore(cfg.dbPath) + storage, err := store.NewBoltStore(cfg.dbPath, nil) if err != nil { slog.Error("failed to initialize store", "error", err) os.Exit(1) } defer storage.Close() + ctx, cancel := context.WithCancel(context.Background()) + go func() { ticker := time.NewTicker(1 * time.Hour) defer ticker.Stop() - for range ticker.C { - storage.Cleanup(cfg.ttl) + for { + select { + case <-ticker.C: + storage.Cleanup(cfg.ttl) + case <-ctx.Done(): + return + } } }() @@ -82,7 +99,7 @@ func main() { server := &http.Server{ Addr: cfg.addr, - Handler: mux, + Handler: securityHeadersMiddleware(mux), } go func() { @@ -97,10 +114,12 @@ func main() { signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) <-quit - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() + cancel() - if err := server.Shutdown(ctx); err != nil { + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + + if err := server.Shutdown(shutdownCtx); err != nil { slog.Error("server shutdown failed", "err", err) } } diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..66e5f51 --- /dev/null +++ b/main_test.go @@ -0,0 +1,27 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSecurityHeadersMiddleware(t *testing.T) { + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + middleware := securityHeadersMiddleware(nextHandler) + + req := httptest.NewRequest("GET", "/", nil) + rr := httptest.NewRecorder() + + middleware.ServeHTTP(rr, req) + + assert.Equal(t, "max-age=63072000; includeSubDomains", rr.Header().Get("Strict-Transport-Security")) + assert.Equal(t, "nosniff", rr.Header().Get("X-Content-Type-Options")) + assert.Equal(t, "DENY", rr.Header().Get("X-Frame-Options")) + assert.Contains(t, rr.Header().Get("Content-Security-Policy"), "default-src 'self'") +} diff --git a/store/boltdb.go b/store/boltdb.go index e04e6a2..262d63a 100644 --- a/store/boltdb.go +++ b/store/boltdb.go @@ -16,21 +16,32 @@ type BoltStore struct { db *bbolt.DB } -func NewBoltStore(path string) (*BoltStore, error) { - db, err := bbolt.Open(path, 0600, nil) +func NewBoltStore(path string, opts *bbolt.Options) (*BoltStore, error) { + db, err := bbolt.Open(path, 0600, opts) if err != nil { return nil, err } - err = db.Update(func(tx *bbolt.Tx) error { - if _, err := tx.CreateBucketIfNotExists(pastesBucket); err != nil { - return err + bucketsExist := false + db.View(func(tx *bbolt.Tx) error { + if tx.Bucket(pastesBucket) != nil && tx.Bucket(hashesBucket) != nil { + bucketsExist = true } - _, err := tx.CreateBucketIfNotExists(hashesBucket) - return err + return nil }) - if err != nil { - return nil, err + + if !bucketsExist { + err = db.Update(func(tx *bbolt.Tx) error { + if _, err := tx.CreateBucketIfNotExists(pastesBucket); err != nil { + return err + } + _, err := tx.CreateBucketIfNotExists(hashesBucket) + return err + }) + if err != nil { + db.Close() + return nil, err + } } return &BoltStore{db: db}, nil @@ -70,7 +81,7 @@ func (s *BoltStore) GetIDByHash(hash string) (string, bool, error) { return id, exists, err } -func (s *BoltStore) Set(id, hash, content string) error { +func (s *BoltStore) Set(id, hash, content string, metadata map[string]interface{}) error { return s.db.Update(func(tx *bbolt.Tx) error { pb := tx.Bucket(pastesBucket) hb := tx.Bucket(hashesBucket) @@ -78,6 +89,8 @@ func (s *BoltStore) Set(id, hash, content string) error { paste := Paste{ Content: content, CreatedAt: time.Now(), + Hash: hash, + Metadata: metadata, } encoded, err := json.Marshal(paste) if err != nil { @@ -93,12 +106,23 @@ func (s *BoltStore) Set(id, hash, content string) error { func (s *BoltStore) Del(id string) error { return s.db.Update(func(tx *bbolt.Tx) error { - return tx.Bucket(pastesBucket).Delete([]byte(id)) + pb := tx.Bucket(pastesBucket) + hb := tx.Bucket(hashesBucket) + + val := pb.Get([]byte(id)) + if val != nil { + var p Paste + if err := json.Unmarshal(val, &p); err == nil && p.Hash != "" { + hb.Delete([]byte(p.Hash)) + } + } + return pb.Delete([]byte(id)) }) } func (s *BoltStore) Cleanup(maxAge time.Duration) { var keysToDelete [][]byte + var hashesToDelete [][]byte s.db.View(func(tx *bbolt.Tx) error { b := tx.Bucket(pastesBucket) c := b.Cursor() @@ -107,6 +131,9 @@ func (s *BoltStore) Cleanup(maxAge time.Duration) { if err := json.Unmarshal(v, &p); err == nil { if time.Since(p.CreatedAt) > maxAge { keysToDelete = append(keysToDelete, k) + if p.Hash != "" { + hashesToDelete = append(hashesToDelete, []byte(p.Hash)) + } } } } @@ -115,9 +142,13 @@ func (s *BoltStore) Cleanup(maxAge time.Duration) { if len(keysToDelete) > 0 { s.db.Update(func(tx *bbolt.Tx) error { - b := tx.Bucket(pastesBucket) + pb := tx.Bucket(pastesBucket) + hb := tx.Bucket(hashesBucket) for _, k := range keysToDelete { - b.Delete(k) + pb.Delete(k) + } + for _, h := range hashesToDelete { + hb.Delete(h) } return nil }) diff --git a/store/boltdb_test.go b/store/boltdb_test.go index b965fe5..7019b6d 100644 --- a/store/boltdb_test.go +++ b/store/boltdb_test.go @@ -8,22 +8,40 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.etcd.io/bbolt" + bbolterrors "go.etcd.io/bbolt/errors" ) func TestBoltStore(t *testing.T) { dbPath := "test.db" defer os.Remove(dbPath) - s, err := NewBoltStore(dbPath) + s, err := NewBoltStore(dbPath, nil) require.NoError(t, err) defer s.Close() + t.Run("Open Existing Store", func(t *testing.T) { + dbPath := "existing_store.db" + defer os.Remove(dbPath) + + s1, err := NewBoltStore(dbPath, nil) + require.NoError(t, err) + + err = s1.Close() + require.NoError(t, err) + + s2, err := NewBoltStore(dbPath, nil) + require.NoError(t, err) + defer s2.Close() + + assert.NotNil(t, s2) + }) + t.Run("Set and Get", func(t *testing.T) { id := "id1" hash := "hash1" content := "content1" - err := s.Set(id, hash, content) + err := s.Set(id, hash, content, nil) assert.NoError(t, err) p, exists, err := s.Get(id) @@ -37,7 +55,7 @@ func TestBoltStore(t *testing.T) { hash := "hash2" content := "content2" - s.Set(id, hash, content) + s.Set(id, hash, content, nil) storedID, exists, err := s.GetIDByHash(hash) assert.NoError(t, err) assert.True(t, exists) @@ -46,19 +64,25 @@ func TestBoltStore(t *testing.T) { t.Run("Del", func(t *testing.T) { id := "id3" - s.Set(id, "h3", "c3") + s.Set(id, "h3", "c3", nil) err := s.Del(id) assert.NoError(t, err) _, exists, _ := s.Get(id) assert.False(t, exists) + + _, hashExists, _ := s.GetIDByHash("h3") + assert.False(t, hashExists) }) t.Run("Cleanup", func(t *testing.T) { - s.Set("old", "oldhash", "oldcontent") + s.Set("old", "oldhash", "oldcontent", nil) s.Cleanup(-time.Hour) _, exists, _ := s.Get("old") assert.False(t, exists) + + _, hashExists, _ := s.GetIDByHash("oldhash") + assert.False(t, hashExists) }) t.Run("Cleanup Bad Data", func(t *testing.T) { @@ -68,15 +92,71 @@ func TestBoltStore(t *testing.T) { s.Cleanup(-time.Hour) // Should skip without panic }) + + t.Run("Set Marshal Error", func(t *testing.T) { + dbPath := "marshal_error.db" + db, _ := bbolt.Open(dbPath, 0666, nil) + store := &BoltStore{db: db} + defer db.Close() + defer os.Remove(dbPath) + + err := store.Set("id", "h", "c", map[string]interface{}{ + "foo": func() {}, + }) + assert.Error(t, err) + }) + + t.Run("Set Put Error", func(t *testing.T) { + dbPath := "put_error.db" + s, err := NewBoltStore(dbPath, nil) + require.NoError(t, err) + defer s.Close() + defer os.Remove(dbPath) + + err = s.Set("", "hash", "content", nil) + assert.ErrorIs(t, err, bbolterrors.ErrKeyRequired) + }) + + t.Run("Del Error", func(t *testing.T) { + dbPath := "error_del.db" + db, _ := bbolt.Open(dbPath, 0666, nil) + store := &BoltStore{db: db} + db.Close() + err := store.Del("id") + assert.Error(t, err) + os.Remove(dbPath) + }) + t.Run("Get Error", func(t *testing.T) { + dbPath := "error_get.db" + db, _ := bbolt.Open(dbPath, 0666, nil) + store := &BoltStore{db: db} + db.Close() + _, _, err := store.Get("id") + assert.Error(t, err) + os.Remove(dbPath) + }) } func TestNewBoltStoreError(t *testing.T) { - // Use a directory name as file path to trigger error err := os.Mkdir("testdir", 0755) require.NoError(t, err) defer os.RemoveAll("testdir") - s, err := NewBoltStore("testdir") + s, err := NewBoltStore("testdir", nil) assert.Error(t, err) assert.Nil(t, s) + + t.Run("Bucket Creation Error", func(t *testing.T) { + dbPath := "bucket_error.db" + defer os.Remove(dbPath) + + originalPastesBucket := pastesBucket + pastesBucket = []byte("") + defer func() { pastesBucket = originalPastesBucket }() + + s, err := NewBoltStore(dbPath, nil) + + assert.ErrorIs(t, err, bbolterrors.ErrBucketNameRequired) + assert.Nil(t, s) + }) } diff --git a/store/memory.go b/store/memory.go index b49f34b..4f294f2 100644 --- a/store/memory.go +++ b/store/memory.go @@ -32,12 +32,14 @@ func (s *MemoryStore) GetIDByHash(hash string) (string, bool, error) { return id, ok, nil } -func (s *MemoryStore) Set(id, hash, content string) error { +func (s *MemoryStore) Set(id, hash, content string, metadata map[string]interface{}) error { s.mu.Lock() defer s.mu.Unlock() s.pastes[id] = &Paste{ Content: content, CreatedAt: time.Now(), + Hash: hash, + Metadata: metadata, } s.hashes[hash] = id return nil @@ -46,6 +48,9 @@ func (s *MemoryStore) Set(id, hash, content string) error { func (s *MemoryStore) Del(id string) error { s.mu.Lock() defer s.mu.Unlock() + if p, ok := s.pastes[id]; ok && p.Hash != "" { + delete(s.hashes, p.Hash) + } delete(s.pastes, id) return nil } diff --git a/store/memory_test.go b/store/memory_test.go index 15b9a7f..f7d7359 100644 --- a/store/memory_test.go +++ b/store/memory_test.go @@ -14,7 +14,7 @@ func TestMemoryStore(t *testing.T) { hash := "hash1" content := "content1" - err := s.Set(id, hash, content) + err := s.Set(id, hash, content, nil) assert.NoError(t, err) p, exists, err := s.Get(id) @@ -29,7 +29,7 @@ func TestMemoryStore(t *testing.T) { hash := "hash2" content := "content2" - s.Set(id, hash, content) + s.Set(id, hash, content, nil) storedID, exists, err := s.GetIDByHash(hash) assert.NoError(t, err) assert.True(t, exists) @@ -50,11 +50,14 @@ func TestMemoryStore(t *testing.T) { t.Run("Del", func(t *testing.T) { id := "id3" - s.Set(id, "h3", "c3") + s.Set(id, "h3", "c3", nil) err := s.Del(id) assert.NoError(t, err) _, exists, _ := s.Get(id) assert.False(t, exists) + + _, hashExists, _ := s.GetIDByHash("h3") + assert.False(t, hashExists) }) } diff --git a/store/store.go b/store/store.go index 28f1a00..af60ffd 100644 --- a/store/store.go +++ b/store/store.go @@ -3,13 +3,15 @@ package store import "time" type Paste struct { - Content string `json:"content"` - CreatedAt time.Time `json:"createdAt"` + Content string `json:"content"` + CreatedAt time.Time `json:"createdAt"` + Hash string `json:"hash,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` } type Store interface { Get(id string) (*Paste, bool, error) GetIDByHash(hash string) (string, bool, error) - Set(id, hash, content string) error + Set(id, hash, content string, metadata map[string]interface{}) error Del(id string) error } diff --git a/view/static/style.css b/view/static/style.css index fa271d6..6a8825e 100644 --- a/view/static/style.css +++ b/view/static/style.css @@ -12,6 +12,7 @@ --margin: 1rem; --line-height: 1.5; + --digits: 1; --font-mono: "Fira Code", "JetBrains Mono", "Cascadia Code", "Source Code Pro", "Menlo", "Monaco", "Consolas", monospace; }