mirror of
https://github.com/skidoodle/pastebin
synced 2026-04-28 03:07:40 +02:00
resolve dangliing hashes
This commit is contained in:
@@ -2,18 +2,12 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/objx v0.5.3 h1:jmXUvGomnU1o3W/V5h2VEradbpJDwGrzugQQvL0POH4=
|
||||
github.com/stretchr/objx v0.5.3/go.mod h1:rDQraq+vQZU7Fde9LOZLr8Tax6zZvy4kuNKF+QYS+U0=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
go.etcd.io/bbolt v1.4.3 h1:dEadXpI6G79deX5prL3QRNP6JB8UxVkqo4UPnHaNXJo=
|
||||
go.etcd.io/bbolt v1.4.3/go.mod h1:tKQlpPaYCVFctUIgFKFnAlvbmB3tpy1vkTnDWohtc0E=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||
|
||||
@@ -10,11 +10,6 @@ func notFound(slug string, err error, w http.ResponseWriter, r *http.Request) {
|
||||
respondWithError(slug, err, w, r, http.StatusNotFound)
|
||||
}
|
||||
|
||||
// badRequest handles 400 Bad Request errors.
|
||||
func badRequest(slug string, err error, w http.ResponseWriter, r *http.Request) {
|
||||
respondWithError(slug, err, w, r, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// internal handles 500 Internal Server Error errors.
|
||||
func internal(slug string, err error, w http.ResponseWriter, r *http.Request) {
|
||||
respondWithError(slug, err, w, r, http.StatusInternalServerError)
|
||||
|
||||
+1
-1
@@ -87,7 +87,7 @@ func (h *HttpHandler) HandleSet(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.Set(id, contentHash, content); err != nil {
|
||||
if err := h.store.Set(id, contentHash, content, nil); err != nil {
|
||||
internal("could not save bin", err, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
+159
-4
@@ -1,7 +1,9 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
@@ -31,8 +33,8 @@ func (m *MockStore) GetIDByHash(hash string) (string, bool, error) {
|
||||
return args.String(0), args.Bool(1), args.Error(2)
|
||||
}
|
||||
|
||||
func (m *MockStore) Set(id, hash, content string) error {
|
||||
args := m.Called(id, hash, content)
|
||||
func (m *MockStore) Set(id, hash, content string, metadata map[string]interface{}) error {
|
||||
args := m.Called(id, hash, content, metadata)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
@@ -57,7 +59,7 @@ func TestHandleSet(t *testing.T) {
|
||||
content := "new content"
|
||||
ch := hash(content)
|
||||
s.On("GetIDByHash", ch).Return("", false, nil).Once()
|
||||
s.On("Set", mock.Anything, ch, content).Return(nil).Once()
|
||||
s.On("Set", mock.Anything, ch, content, mock.Anything).Return(nil).Once()
|
||||
|
||||
form := url.Values{"content": {content}}
|
||||
req := httptest.NewRequest("POST", "/", strings.NewReader(form.Encode()))
|
||||
@@ -120,7 +122,7 @@ func TestHandleSet(t *testing.T) {
|
||||
content := "error content"
|
||||
ch := hash(content)
|
||||
s.On("GetIDByHash", ch).Return("", false, nil).Once()
|
||||
s.On("Set", mock.Anything, ch, content).Return(errors.New("db error")).Once()
|
||||
s.On("Set", mock.Anything, ch, content, mock.Anything).Return(errors.New("db error")).Once()
|
||||
|
||||
form := url.Values{"content": {content}}
|
||||
req := httptest.NewRequest("POST", "/", strings.NewReader(form.Encode()))
|
||||
@@ -130,6 +132,44 @@ func TestHandleSet(t *testing.T) {
|
||||
h.HandleSet(rr, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("Generate ID Error", func(t *testing.T) {
|
||||
originalReader := rand.Reader
|
||||
defer func() { rand.Reader = originalReader }()
|
||||
rand.Reader = errorReader{}
|
||||
|
||||
content := "generate error content"
|
||||
ch := hash(content)
|
||||
s.On("GetIDByHash", ch).Return("", false, nil).Once()
|
||||
|
||||
form := url.Values{"content": {content}}
|
||||
req := httptest.NewRequest("POST", "/", strings.NewReader(form.Encode()))
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
h.HandleSet(rr, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("Malformed Form Data", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/", strings.NewReader("content=%zz"))
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
rr := httptest.NewRecorder()
|
||||
h.HandleSet(rr, req)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
assert.Contains(t, rr.Body.String(), "Invalid form data")
|
||||
})
|
||||
|
||||
t.Run("Too Large Content", func(t *testing.T) {
|
||||
content := strings.Repeat("a", 2048)
|
||||
form := url.Values{"content": {content}}
|
||||
req := httptest.NewRequest("POST", "/", strings.NewReader(form.Encode()))
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
rr := httptest.NewRecorder()
|
||||
h.HandleSet(rr, req)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
assert.Contains(t, rr.Body.String(), "Content too large")
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleGet(t *testing.T) {
|
||||
@@ -167,6 +207,18 @@ func TestHandleGet(t *testing.T) {
|
||||
h.HandleGet(rr, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("With Extension", func(t *testing.T) {
|
||||
id := "testid"
|
||||
s.On("Get", id).Return(&store.Paste{Content: "hello", CreatedAt: time.Now()}, true, nil).Once()
|
||||
req := httptest.NewRequest("GET", "/"+id+".go", nil)
|
||||
req.SetPathValue("id", id+".go")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
h.HandleGet(rr, req)
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Contains(t, rr.Body.String(), "hello")
|
||||
})
|
||||
}
|
||||
|
||||
type FailingResponseWriter struct {
|
||||
@@ -184,6 +236,14 @@ func TestHandleHomeError(t *testing.T) {
|
||||
h.HandleHome(rr, req)
|
||||
}
|
||||
|
||||
type mockTemplateStore struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockTemplateStore) ExecuteTemplate(w http.ResponseWriter, name string, data interface{}) error {
|
||||
return errors.New("template error")
|
||||
}
|
||||
|
||||
func TestHandleRaw(t *testing.T) {
|
||||
s := new(MockStore)
|
||||
h := NewHandler(s, 1024, "../view/templates/*.html")
|
||||
@@ -202,6 +262,19 @@ func TestHandleRaw(t *testing.T) {
|
||||
assert.Equal(t, "text/plain; charset=utf-8", rr.Header().Get("Content-Type"))
|
||||
})
|
||||
|
||||
t.Run("With Extension", func(t *testing.T) {
|
||||
id := "testid"
|
||||
content := "raw content"
|
||||
s.On("Get", id).Return(&store.Paste{Content: content}, true, nil).Once()
|
||||
req := httptest.NewRequest("GET", "/raw/"+id+".txt", nil)
|
||||
req.SetPathValue("id", id+".txt")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
h.HandleRaw(rr, req)
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, content, rr.Body.String())
|
||||
})
|
||||
|
||||
t.Run("Not Found", func(t *testing.T) {
|
||||
s.On("Get", "missing").Return(nil, false, nil).Once()
|
||||
req := httptest.NewRequest("GET", "/raw/missing", nil)
|
||||
@@ -211,6 +284,16 @@ func TestHandleRaw(t *testing.T) {
|
||||
h.HandleRaw(rr, req)
|
||||
assert.Equal(t, http.StatusNotFound, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("Store Error", func(t *testing.T) {
|
||||
s.On("Get", "error").Return(nil, false, errors.New("db error")).Once()
|
||||
req := httptest.NewRequest("GET", "/raw/error", nil)
|
||||
req.SetPathValue("id", "error")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
h.HandleRaw(rr, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleGetTemplateError(t *testing.T) {
|
||||
@@ -224,3 +307,75 @@ func TestHandleGetTemplateError(t *testing.T) {
|
||||
|
||||
h.HandleGet(rr, req)
|
||||
}
|
||||
|
||||
func TestHandleSetTemplateError(t *testing.T) {
|
||||
h := NewHandler(nil, 1024, "../view/templates/*.html")
|
||||
|
||||
t.Run("Empty Content", func(t *testing.T) {
|
||||
form := url.Values{"content": {""}}
|
||||
req := httptest.NewRequest("POST", "/", strings.NewReader(form.Encode()))
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
rr := &FailingResponseWriter{*httptest.NewRecorder()}
|
||||
h.HandleSet(rr, req)
|
||||
})
|
||||
|
||||
t.Run("Parse Error", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/", strings.NewReader("content=%zz"))
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
rr := &FailingResponseWriter{*httptest.NewRecorder()}
|
||||
h.HandleSet(rr, req)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTemplateErrors(t *testing.T) {
|
||||
tmpl := template.New("empty")
|
||||
h := &HttpHandler{
|
||||
templates: tmpl,
|
||||
maxSize: 1024,
|
||||
}
|
||||
|
||||
t.Run("HandleHome Error", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
h.HandleHome(rr, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("HandleGet Error", func(t *testing.T) {
|
||||
s := new(MockStore)
|
||||
h.store = s
|
||||
id := "testid"
|
||||
s.On("Get", id).Return(&store.Paste{Content: "hello", CreatedAt: time.Now()}, true, nil).Once()
|
||||
req := httptest.NewRequest("GET", "/"+id, nil)
|
||||
req.SetPathValue("id", id)
|
||||
rr := httptest.NewRecorder()
|
||||
h.HandleGet(rr, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("HandleSet Empty Content Error", func(t *testing.T) {
|
||||
form := url.Values{"content": {""}}
|
||||
req := httptest.NewRequest("POST", "/", strings.NewReader(form.Encode()))
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
rr := httptest.NewRecorder()
|
||||
h.HandleSet(rr, req)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
})
|
||||
|
||||
t.Run("HandleSet Parse Error", func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/", strings.NewReader("content=%zz"))
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
rr := httptest.NewRecorder()
|
||||
h.HandleSet(rr, req)
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
})
|
||||
}
|
||||
func TestSafeHTML(t *testing.T) {
|
||||
h := NewHandler(nil, 1024, "../view/templates/*.html")
|
||||
tmpl, err := h.templates.New("test").Parse(`{{ safeHTML "<br>" }}`)
|
||||
assert.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
err = tmpl.Execute(rr, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "<br>", rr.Body.String())
|
||||
}
|
||||
|
||||
+8
-7
@@ -5,22 +5,23 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"time"
|
||||
)
|
||||
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
|
||||
func generateId() (string, error) {
|
||||
bytes := make([]byte, 10)
|
||||
if _, err := io.ReadFull(rand.Reader, bytes); err != nil {
|
||||
result := make([]byte, 10)
|
||||
max := big.NewInt(int64(len(charset)))
|
||||
for i := 0; i < 10; i++ {
|
||||
n, err := rand.Int(rand.Reader, max)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for i := range bytes {
|
||||
bytes[i] = charset[bytes[i]%byte(len(charset))]
|
||||
result[i] = charset[n.Int64()]
|
||||
}
|
||||
return string(bytes), nil
|
||||
return string(result), nil
|
||||
}
|
||||
|
||||
func TimeAgo(t time.Time) string {
|
||||
|
||||
@@ -1,18 +1,36 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type errorReader struct{}
|
||||
|
||||
func (e errorReader) Read(p []byte) (n int, err error) {
|
||||
return 0, errors.New("read error")
|
||||
}
|
||||
|
||||
func TestGenerateId(t *testing.T) {
|
||||
id, err := generateId()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 10, len(id))
|
||||
}
|
||||
|
||||
func TestGenerateIdError(t *testing.T) {
|
||||
originalReader := rand.Reader
|
||||
defer func() { rand.Reader = originalReader }()
|
||||
rand.Reader = errorReader{}
|
||||
|
||||
id, err := generateId()
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, id)
|
||||
}
|
||||
|
||||
func TestHash(t *testing.T) {
|
||||
c := "test content"
|
||||
h1 := hash(c)
|
||||
|
||||
@@ -36,23 +36,40 @@ func parseFlags() *config {
|
||||
return cfg
|
||||
}
|
||||
|
||||
func securityHeadersMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
w.Header().Set("Content-Security-Policy", "default-src 'self'; style-src 'self' 'unsafe-inline'; script-src 'self' 'unsafe-inline'")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func main() {
|
||||
cfg := parseFlags()
|
||||
|
||||
slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, nil)))
|
||||
|
||||
storage, err := store.NewBoltStore(cfg.dbPath)
|
||||
storage, err := store.NewBoltStore(cfg.dbPath, nil)
|
||||
if err != nil {
|
||||
slog.Error("failed to initialize store", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer storage.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
storage.Cleanup(cfg.ttl)
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -82,7 +99,7 @@ func main() {
|
||||
|
||||
server := &http.Server{
|
||||
Addr: cfg.addr,
|
||||
Handler: mux,
|
||||
Handler: securityHeadersMiddleware(mux),
|
||||
}
|
||||
|
||||
go func() {
|
||||
@@ -97,10 +114,12 @@ func main() {
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
cancel()
|
||||
|
||||
if err := server.Shutdown(ctx); err != nil {
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer shutdownCancel()
|
||||
|
||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||
slog.Error("server shutdown failed", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSecurityHeadersMiddleware(t *testing.T) {
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
middleware := securityHeadersMiddleware(nextHandler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
middleware.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, "max-age=63072000; includeSubDomains", rr.Header().Get("Strict-Transport-Security"))
|
||||
assert.Equal(t, "nosniff", rr.Header().Get("X-Content-Type-Options"))
|
||||
assert.Equal(t, "DENY", rr.Header().Get("X-Frame-Options"))
|
||||
assert.Contains(t, rr.Header().Get("Content-Security-Policy"), "default-src 'self'")
|
||||
}
|
||||
+37
-6
@@ -16,12 +16,21 @@ type BoltStore struct {
|
||||
db *bbolt.DB
|
||||
}
|
||||
|
||||
func NewBoltStore(path string) (*BoltStore, error) {
|
||||
db, err := bbolt.Open(path, 0600, nil)
|
||||
func NewBoltStore(path string, opts *bbolt.Options) (*BoltStore, error) {
|
||||
db, err := bbolt.Open(path, 0600, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bucketsExist := false
|
||||
db.View(func(tx *bbolt.Tx) error {
|
||||
if tx.Bucket(pastesBucket) != nil && tx.Bucket(hashesBucket) != nil {
|
||||
bucketsExist = true
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if !bucketsExist {
|
||||
err = db.Update(func(tx *bbolt.Tx) error {
|
||||
if _, err := tx.CreateBucketIfNotExists(pastesBucket); err != nil {
|
||||
return err
|
||||
@@ -30,8 +39,10 @@ func NewBoltStore(path string) (*BoltStore, error) {
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
db.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &BoltStore{db: db}, nil
|
||||
}
|
||||
@@ -70,7 +81,7 @@ func (s *BoltStore) GetIDByHash(hash string) (string, bool, error) {
|
||||
return id, exists, err
|
||||
}
|
||||
|
||||
func (s *BoltStore) Set(id, hash, content string) error {
|
||||
func (s *BoltStore) Set(id, hash, content string, metadata map[string]interface{}) error {
|
||||
return s.db.Update(func(tx *bbolt.Tx) error {
|
||||
pb := tx.Bucket(pastesBucket)
|
||||
hb := tx.Bucket(hashesBucket)
|
||||
@@ -78,6 +89,8 @@ func (s *BoltStore) Set(id, hash, content string) error {
|
||||
paste := Paste{
|
||||
Content: content,
|
||||
CreatedAt: time.Now(),
|
||||
Hash: hash,
|
||||
Metadata: metadata,
|
||||
}
|
||||
encoded, err := json.Marshal(paste)
|
||||
if err != nil {
|
||||
@@ -93,12 +106,23 @@ func (s *BoltStore) Set(id, hash, content string) error {
|
||||
|
||||
func (s *BoltStore) Del(id string) error {
|
||||
return s.db.Update(func(tx *bbolt.Tx) error {
|
||||
return tx.Bucket(pastesBucket).Delete([]byte(id))
|
||||
pb := tx.Bucket(pastesBucket)
|
||||
hb := tx.Bucket(hashesBucket)
|
||||
|
||||
val := pb.Get([]byte(id))
|
||||
if val != nil {
|
||||
var p Paste
|
||||
if err := json.Unmarshal(val, &p); err == nil && p.Hash != "" {
|
||||
hb.Delete([]byte(p.Hash))
|
||||
}
|
||||
}
|
||||
return pb.Delete([]byte(id))
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BoltStore) Cleanup(maxAge time.Duration) {
|
||||
var keysToDelete [][]byte
|
||||
var hashesToDelete [][]byte
|
||||
s.db.View(func(tx *bbolt.Tx) error {
|
||||
b := tx.Bucket(pastesBucket)
|
||||
c := b.Cursor()
|
||||
@@ -107,6 +131,9 @@ func (s *BoltStore) Cleanup(maxAge time.Duration) {
|
||||
if err := json.Unmarshal(v, &p); err == nil {
|
||||
if time.Since(p.CreatedAt) > maxAge {
|
||||
keysToDelete = append(keysToDelete, k)
|
||||
if p.Hash != "" {
|
||||
hashesToDelete = append(hashesToDelete, []byte(p.Hash))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -115,9 +142,13 @@ func (s *BoltStore) Cleanup(maxAge time.Duration) {
|
||||
|
||||
if len(keysToDelete) > 0 {
|
||||
s.db.Update(func(tx *bbolt.Tx) error {
|
||||
b := tx.Bucket(pastesBucket)
|
||||
pb := tx.Bucket(pastesBucket)
|
||||
hb := tx.Bucket(hashesBucket)
|
||||
for _, k := range keysToDelete {
|
||||
b.Delete(k)
|
||||
pb.Delete(k)
|
||||
}
|
||||
for _, h := range hashesToDelete {
|
||||
hb.Delete(h)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
+87
-7
@@ -8,22 +8,40 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.etcd.io/bbolt"
|
||||
bbolterrors "go.etcd.io/bbolt/errors"
|
||||
)
|
||||
|
||||
func TestBoltStore(t *testing.T) {
|
||||
dbPath := "test.db"
|
||||
defer os.Remove(dbPath)
|
||||
|
||||
s, err := NewBoltStore(dbPath)
|
||||
s, err := NewBoltStore(dbPath, nil)
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
t.Run("Open Existing Store", func(t *testing.T) {
|
||||
dbPath := "existing_store.db"
|
||||
defer os.Remove(dbPath)
|
||||
|
||||
s1, err := NewBoltStore(dbPath, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s1.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
s2, err := NewBoltStore(dbPath, nil)
|
||||
require.NoError(t, err)
|
||||
defer s2.Close()
|
||||
|
||||
assert.NotNil(t, s2)
|
||||
})
|
||||
|
||||
t.Run("Set and Get", func(t *testing.T) {
|
||||
id := "id1"
|
||||
hash := "hash1"
|
||||
content := "content1"
|
||||
|
||||
err := s.Set(id, hash, content)
|
||||
err := s.Set(id, hash, content, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
p, exists, err := s.Get(id)
|
||||
@@ -37,7 +55,7 @@ func TestBoltStore(t *testing.T) {
|
||||
hash := "hash2"
|
||||
content := "content2"
|
||||
|
||||
s.Set(id, hash, content)
|
||||
s.Set(id, hash, content, nil)
|
||||
storedID, exists, err := s.GetIDByHash(hash)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
@@ -46,19 +64,25 @@ func TestBoltStore(t *testing.T) {
|
||||
|
||||
t.Run("Del", func(t *testing.T) {
|
||||
id := "id3"
|
||||
s.Set(id, "h3", "c3")
|
||||
s.Set(id, "h3", "c3", nil)
|
||||
err := s.Del(id)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, exists, _ := s.Get(id)
|
||||
assert.False(t, exists)
|
||||
|
||||
_, hashExists, _ := s.GetIDByHash("h3")
|
||||
assert.False(t, hashExists)
|
||||
})
|
||||
|
||||
t.Run("Cleanup", func(t *testing.T) {
|
||||
s.Set("old", "oldhash", "oldcontent")
|
||||
s.Set("old", "oldhash", "oldcontent", nil)
|
||||
s.Cleanup(-time.Hour)
|
||||
_, exists, _ := s.Get("old")
|
||||
assert.False(t, exists)
|
||||
|
||||
_, hashExists, _ := s.GetIDByHash("oldhash")
|
||||
assert.False(t, hashExists)
|
||||
})
|
||||
|
||||
t.Run("Cleanup Bad Data", func(t *testing.T) {
|
||||
@@ -68,15 +92,71 @@ func TestBoltStore(t *testing.T) {
|
||||
s.Cleanup(-time.Hour)
|
||||
// Should skip without panic
|
||||
})
|
||||
|
||||
t.Run("Set Marshal Error", func(t *testing.T) {
|
||||
dbPath := "marshal_error.db"
|
||||
db, _ := bbolt.Open(dbPath, 0666, nil)
|
||||
store := &BoltStore{db: db}
|
||||
defer db.Close()
|
||||
defer os.Remove(dbPath)
|
||||
|
||||
err := store.Set("id", "h", "c", map[string]interface{}{
|
||||
"foo": func() {},
|
||||
})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Set Put Error", func(t *testing.T) {
|
||||
dbPath := "put_error.db"
|
||||
s, err := NewBoltStore(dbPath, nil)
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
defer os.Remove(dbPath)
|
||||
|
||||
err = s.Set("", "hash", "content", nil)
|
||||
assert.ErrorIs(t, err, bbolterrors.ErrKeyRequired)
|
||||
})
|
||||
|
||||
t.Run("Del Error", func(t *testing.T) {
|
||||
dbPath := "error_del.db"
|
||||
db, _ := bbolt.Open(dbPath, 0666, nil)
|
||||
store := &BoltStore{db: db}
|
||||
db.Close()
|
||||
err := store.Del("id")
|
||||
assert.Error(t, err)
|
||||
os.Remove(dbPath)
|
||||
})
|
||||
t.Run("Get Error", func(t *testing.T) {
|
||||
dbPath := "error_get.db"
|
||||
db, _ := bbolt.Open(dbPath, 0666, nil)
|
||||
store := &BoltStore{db: db}
|
||||
db.Close()
|
||||
_, _, err := store.Get("id")
|
||||
assert.Error(t, err)
|
||||
os.Remove(dbPath)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewBoltStoreError(t *testing.T) {
|
||||
// Use a directory name as file path to trigger error
|
||||
err := os.Mkdir("testdir", 0755)
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll("testdir")
|
||||
|
||||
s, err := NewBoltStore("testdir")
|
||||
s, err := NewBoltStore("testdir", nil)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, s)
|
||||
|
||||
t.Run("Bucket Creation Error", func(t *testing.T) {
|
||||
dbPath := "bucket_error.db"
|
||||
defer os.Remove(dbPath)
|
||||
|
||||
originalPastesBucket := pastesBucket
|
||||
pastesBucket = []byte("")
|
||||
defer func() { pastesBucket = originalPastesBucket }()
|
||||
|
||||
s, err := NewBoltStore(dbPath, nil)
|
||||
|
||||
assert.ErrorIs(t, err, bbolterrors.ErrBucketNameRequired)
|
||||
assert.Nil(t, s)
|
||||
})
|
||||
}
|
||||
|
||||
+6
-1
@@ -32,12 +32,14 @@ func (s *MemoryStore) GetIDByHash(hash string) (string, bool, error) {
|
||||
return id, ok, nil
|
||||
}
|
||||
|
||||
func (s *MemoryStore) Set(id, hash, content string) error {
|
||||
func (s *MemoryStore) Set(id, hash, content string, metadata map[string]interface{}) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.pastes[id] = &Paste{
|
||||
Content: content,
|
||||
CreatedAt: time.Now(),
|
||||
Hash: hash,
|
||||
Metadata: metadata,
|
||||
}
|
||||
s.hashes[hash] = id
|
||||
return nil
|
||||
@@ -46,6 +48,9 @@ func (s *MemoryStore) Set(id, hash, content string) error {
|
||||
func (s *MemoryStore) Del(id string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if p, ok := s.pastes[id]; ok && p.Hash != "" {
|
||||
delete(s.hashes, p.Hash)
|
||||
}
|
||||
delete(s.pastes, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ func TestMemoryStore(t *testing.T) {
|
||||
hash := "hash1"
|
||||
content := "content1"
|
||||
|
||||
err := s.Set(id, hash, content)
|
||||
err := s.Set(id, hash, content, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
p, exists, err := s.Get(id)
|
||||
@@ -29,7 +29,7 @@ func TestMemoryStore(t *testing.T) {
|
||||
hash := "hash2"
|
||||
content := "content2"
|
||||
|
||||
s.Set(id, hash, content)
|
||||
s.Set(id, hash, content, nil)
|
||||
storedID, exists, err := s.GetIDByHash(hash)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
@@ -50,11 +50,14 @@ func TestMemoryStore(t *testing.T) {
|
||||
|
||||
t.Run("Del", func(t *testing.T) {
|
||||
id := "id3"
|
||||
s.Set(id, "h3", "c3")
|
||||
s.Set(id, "h3", "c3", nil)
|
||||
err := s.Del(id)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, exists, _ := s.Get(id)
|
||||
assert.False(t, exists)
|
||||
|
||||
_, hashExists, _ := s.GetIDByHash("h3")
|
||||
assert.False(t, hashExists)
|
||||
})
|
||||
}
|
||||
|
||||
+3
-1
@@ -5,11 +5,13 @@ import "time"
|
||||
type Paste struct {
|
||||
Content string `json:"content"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
Hash string `json:"hash,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
type Store interface {
|
||||
Get(id string) (*Paste, bool, error)
|
||||
GetIDByHash(hash string) (string, bool, error)
|
||||
Set(id, hash, content string) error
|
||||
Set(id, hash, content string, metadata map[string]interface{}) error
|
||||
Del(id string) error
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
|
||||
--margin: 1rem;
|
||||
--line-height: 1.5;
|
||||
--digits: 1;
|
||||
--font-mono: "Fira Code", "JetBrains Mono", "Cascadia Code", "Source Code Pro", "Menlo", "Monaco", "Consolas", monospace;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user