refactor: split handlers.go and centralize config

Signed-off-by: skidoodle <contact@albert.lol>
This commit is contained in:
2026-01-18 19:25:35 +01:00
parent aca7267301
commit 00e5c95fe3
7 changed files with 447 additions and 458 deletions
+121 -13
View File
@@ -2,22 +2,19 @@ package app
import (
"context"
"fmt"
"io"
"math"
"os"
"path/filepath"
"strconv"
"time"
)
const (
cleanupInterval = 1 * time.Hour
tempExpiry = 4 * time.Hour
minRetention = 24 * time.Hour
maxRetention = 365 * 24 * time.Hour
bytesInMB = 1 << 20
"github.com/skidoodle/safebin/internal/crypto"
)
func (app *App) StartCleanupTask(ctx context.Context) {
ticker := time.NewTicker(cleanupInterval)
ticker := time.NewTicker(CleanupInterval)
for {
select {
@@ -31,6 +28,117 @@ func (app *App) StartCleanupTask(ctx context.Context) {
}
}
func (app *App) saveChunk(uid string, idx int, src io.Reader) error {
dir := filepath.Join(app.Conf.StorageDir, "tmp", uid)
if err := os.MkdirAll(dir, PermUserRWX); err != nil {
return fmt.Errorf("create chunk dir: %w", err)
}
dest, err := os.Create(filepath.Join(dir, strconv.Itoa(idx)))
if err != nil {
return fmt.Errorf("create chunk file: %w", err)
}
defer func() {
if closeErr := dest.Close(); closeErr != nil {
app.Logger.Error("Failed to close chunk dest", "err", closeErr)
}
}()
if _, err := io.Copy(dest, src); err != nil {
return fmt.Errorf("copy chunk: %w", err)
}
return nil
}
func (app *App) mergeChunks(uid string, total int) (string, error) {
tmpPath := filepath.Join(app.Conf.StorageDir, "tmp", "m_"+uid)
merged, err := os.Create(tmpPath)
if err != nil {
return "", fmt.Errorf("create merge file: %w", err)
}
defer func() {
if closeErr := merged.Close(); closeErr != nil {
app.Logger.Error("Failed to close merged file", "err", closeErr)
}
}()
limit := app.Conf.MaxMB * MegaByte
var written int64
for i := range total {
partPath := filepath.Join(app.Conf.StorageDir, "tmp", uid, strconv.Itoa(i))
part, err := os.Open(partPath)
if err != nil {
return "", fmt.Errorf("open chunk %d: %w", i, err)
}
n, err := io.Copy(merged, part)
if closeErr := part.Close(); closeErr != nil {
app.Logger.Error("Failed to close chunk part", "err", closeErr)
}
if err != nil {
return "", fmt.Errorf("append chunk %d: %w", i, err)
}
written += n
if written > limit {
return "", io.ErrShortWrite
}
}
return tmpPath, nil
}
func (app *App) encryptAndSave(src io.Reader, key []byte, finalPath string) error {
out, err := os.Create(finalPath + ".tmp")
if err != nil {
return fmt.Errorf("create final file: %w", err)
}
var closed bool
defer func() {
if !closed {
if closeErr := out.Close(); closeErr != nil {
app.Logger.Error("Failed to close final file", "err", closeErr)
}
}
if removeErr := os.Remove(finalPath + ".tmp"); removeErr != nil && !os.IsNotExist(removeErr) {
app.Logger.Error("Failed to remove temp final file", "err", removeErr)
}
}()
streamer, err := crypto.NewGCMStreamer(key)
if err != nil {
return fmt.Errorf("create streamer: %w", err)
}
if err := streamer.EncryptStream(out, src); err != nil {
return fmt.Errorf("encrypt stream: %w", err)
}
if err := out.Close(); err != nil {
return fmt.Errorf("close final file: %w", err)
}
closed = true
if err := os.Rename(finalPath+".tmp", finalPath); err != nil {
return fmt.Errorf("rename final file: %w", err)
}
return nil
}
func (app *App) CleanStorage(path string) {
entries, err := os.ReadDir(path)
if err != nil {
@@ -67,7 +175,7 @@ func (app *App) CleanTemp(path string) {
continue
}
if time.Since(info.ModTime()) > tempExpiry {
if time.Since(info.ModTime()) > TempExpiry {
if err := os.RemoveAll(filepath.Join(path, entry.Name())); err != nil {
app.Logger.Error("Failed to remove expired temp file", "path", entry.Name(), "err", err)
}
@@ -76,13 +184,13 @@ func (app *App) CleanTemp(path string) {
}
func CalculateRetention(fileSize, maxMB int64) time.Duration {
ratio := math.Max(0, math.Min(1, float64(fileSize)/float64(maxMB*bytesInMB)))
ratio := math.Max(0, math.Min(1, float64(fileSize)/float64(maxMB*MegaByte)))
invRatio := 1.0 - ratio
retention := float64(maxRetention) * (invRatio * invRatio * invRatio)
retention := float64(MaxRetention) * (invRatio * invRatio * invRatio)
if retention < float64(minRetention) {
return minRetention
if retention < float64(MinRetention) {
return MinRetention
}
return time.Duration(retention)