This commit is contained in:
2025-10-03 03:45:29 +02:00
parent 48b88892b5
commit 5427393547
6 changed files with 183 additions and 133 deletions

View File

@@ -0,0 +1,82 @@
package websocket
import (
"log/slog"
"sync"
"time"
"golang.org/x/net/websocket"
)
const (
writeWait = 10 * time.Second
pongWait = 60 * time.Second
)
// Client is a middleman between the websocket connection and the hub.
type Client struct {
hub *Hub
conn *websocket.Conn
send chan []byte
closeOnce sync.Once
}
// close is a thread-safe method to clean up the client's resources.
// It ensures that the unregister and connection close operations happen exactly once.
func (c *Client) close() {
c.closeOnce.Do(func() {
slog.Debug("closing client connection", "remoteAddr", c.conn.RemoteAddr())
c.hub.unregister <- c
if err := c.conn.Close(); err != nil {
// This error is expected if the other end has already hung up.
slog.Debug("error while closing client connection", "error", err, "remoteAddr", c.conn.RemoteAddr())
}
})
}
// readPump is responsible for detecting a dead connection via read deadlines.
func (c *Client) readPump() {
defer c.close()
if err := c.conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
slog.Warn("failed to set initial read deadline", "error", err, "remoteAddr", c.conn.RemoteAddr())
return
}
var msg string
for {
if err := websocket.Message.Receive(c.conn, &msg); err != nil {
slog.Debug("client read error, triggering disconnect", "error", err, "remoteAddr", c.conn.RemoteAddr())
break
}
if err := c.conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
slog.Warn("failed to reset read deadline", "error", err, "remoteAddr", c.conn.RemoteAddr())
break
}
}
}
// writePump pumps messages from the hub to the websocket connection.
func (c *Client) writePump() {
defer c.close()
for {
select {
case message, ok := <-c.send:
if !ok {
slog.Debug("hub closed channel, closing connection", "remoteAddr", c.conn.RemoteAddr())
return
}
if err := c.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
slog.Warn("failed to set write deadline", "error", err, "remoteAddr", c.conn.RemoteAddr())
return
}
if err := websocket.Message.Send(c.conn, string(message)); err != nil {
slog.Warn("client write error", "error", err, "remoteAddr", c.conn.RemoteAddr())
return
}
}
}
}

View File

@@ -7,34 +7,23 @@ import (
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
) )
// newWebsocketHandler creates a new WebSocket handler closure. // newWebsocketHandler creates the handler that upgrades connections.
func (s *Server) newWebsocketHandler() websocket.Handler { func (s *Server) newWebsocketHandler() websocket.Handler {
return func(ws *websocket.Conn) { return func(ws *websocket.Conn) {
defer func() {
s.hub.unregister <- ws
if err := ws.Close(); err != nil {
slog.Warn("error while closing websocket connection", "error", err, "remoteAddr", ws.RemoteAddr())
}
}()
origin := ws.Config().Origin.String() origin := ws.Config().Origin.String()
if !s.originChecker(origin) { if !s.originChecker(origin) {
slog.Warn("origin not allowed, rejecting connection", "origin", origin) slog.Warn("origin not allowed, rejecting connection", "origin", origin)
return return
} }
s.hub.register <- ws slog.Debug("client connected, upgrading connection", "remoteAddr", ws.RemoteAddr())
// Send the last known state immediately upon connection. client := &Client{hub: s.hub, conn: ws, send: make(chan []byte, 256)}
s.poller.SendLastState(ws)
// Block by reading from the client to detect disconnection. client.hub.register <- client
var msg string
for { go client.writePump()
if err := websocket.Message.Receive(ws, &msg); err != nil { client.readPump()
break // Client has disconnected.
}
}
} }
} }

View File

@@ -3,30 +3,28 @@ package websocket
import ( import (
"context" "context"
"log/slog" "log/slog"
"spotify-ws/internal/spotify"
"sync" "sync"
"golang.org/x/net/websocket"
) )
// Hub manages the set of active clients and broadcasts messages. // Hub maintains the set of active clients and broadcasts messages to them.
type Hub struct { type Hub struct {
clients map[*websocket.Conn]struct{} clients map[*Client]struct{}
mu sync.RWMutex broadcast chan []byte
realtime bool register chan *Client
register chan *websocket.Conn unregister chan *Client
unregister chan *websocket.Conn
broadcast chan *spotify.CurrentlyPlaying // The last known state, protected by a mutex. This is sent to new clients.
lastState []byte
mu sync.RWMutex
} }
// NewHub creates a new Hub. // NewHub creates a new Hub.
func NewHub(realtime bool) *Hub { func NewHub() *Hub {
return &Hub{ return &Hub{
clients: make(map[*websocket.Conn]struct{}), clients: make(map[*Client]struct{}),
realtime: realtime, broadcast: make(chan []byte),
register: make(chan *websocket.Conn), register: make(chan *Client),
unregister: make(chan *websocket.Conn), unregister: make(chan *Client),
broadcast: make(chan *spotify.CurrentlyPlaying),
} }
} }
@@ -38,57 +36,36 @@ func (h *Hub) Run(ctx context.Context) {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
h.closeAllConnections() for client := range h.clients {
close(client.send)
}
return return
case client := <-h.register: case client := <-h.register:
h.mu.Lock()
h.clients[client] = struct{}{} h.clients[client] = struct{}{}
h.mu.Unlock() h.mu.RLock()
slog.Debug("client registered", "remoteAddr", client.RemoteAddr()) if h.lastState != nil {
client.send <- h.lastState
}
h.mu.RUnlock()
case client := <-h.unregister: case client := <-h.unregister:
h.mu.Lock()
if _, ok := h.clients[client]; ok { if _, ok := h.clients[client]; ok {
delete(h.clients, client) delete(h.clients, client)
close(client.send)
} }
case message := <-h.broadcast:
h.mu.Lock()
h.lastState = message
h.mu.Unlock() h.mu.Unlock()
slog.Debug("client unregistered", "remoteAddr", client.RemoteAddr())
case state := <-h.broadcast:
h.broadcastState(state)
}
}
}
// Broadcast sends a state update to all connected clients. for client := range h.clients {
func (h *Hub) Broadcast(state *spotify.CurrentlyPlaying) { select {
h.broadcast <- state case client.send <- message:
} default:
slog.Warn("client send buffer full, disconnecting", "remoteAddr", client.conn.RemoteAddr())
// broadcastState handles the actual message sending. close(client.send)
func (h *Hub) broadcastState(state *spotify.CurrentlyPlaying) { delete(h.clients, client)
h.mu.RLock() }
defer h.mu.RUnlock()
if len(h.clients) == 0 {
return
}
clientPayload := newPlaybackState(state, h.realtime)
for client := range h.clients {
go func(c *websocket.Conn) {
if err := websocket.JSON.Send(c, clientPayload); err != nil {
slog.Warn("failed to broadcast message", "error", err, "remoteAddr", c.RemoteAddr())
} }
}(client)
}
}
// closeAllConnections closes all active client connections during shutdown.
func (h *Hub) closeAllConnections() {
h.mu.Lock()
defer h.mu.Unlock()
for client := range h.clients {
if err := client.Close(); err != nil {
slog.Warn("error closing client connection during shutdown", "error", err, "remoteAddr", client.RemoteAddr())
} }
} }
} }

View File

@@ -2,36 +2,39 @@ package websocket
import ( import (
"context" "context"
"encoding/json"
"log/slog" "log/slog"
"sync" "sync"
"time" "time"
"spotify-ws/internal/spotify" "spotify-ws/internal/spotify"
"golang.org/x/net/websocket"
) )
// Poller is responsible for fetching data from the Spotify API periodically. // Poller is responsible for fetching data from the Spotify API periodically.
type Poller struct { type Poller struct {
client *spotify.Client client *spotify.Client
hub *Hub hub *Hub
realtime bool
lastState *spotify.CurrentlyPlaying lastState *spotify.CurrentlyPlaying
mu sync.RWMutex mu sync.RWMutex
} }
// NewPoller creates a new Poller. // NewPoller creates a new Poller.
func NewPoller(client *spotify.Client, hub *Hub) *Poller { func NewPoller(client *spotify.Client, hub *Hub, realtime bool) *Poller {
return &Poller{ return &Poller{
client: client, client: client,
hub: hub, hub: hub,
realtime: realtime,
} }
} }
// Run starts the polling loop. It must be run in a separate goroutine. // Run starts the polling loop.
func (p *Poller) Run(ctx context.Context) { func (p *Poller) Run(ctx context.Context) {
slog.Info("poller started") slog.Info("poller started")
defer slog.Info("poller stopped") defer slog.Info("poller stopped")
p.UpdateState(ctx)
ticker := time.NewTicker(3 * time.Second) ticker := time.NewTicker(3 * time.Second)
defer ticker.Stop() defer ticker.Stop()
@@ -61,38 +64,31 @@ func (p *Poller) UpdateState(ctx context.Context) {
p.mu.Unlock() p.mu.Unlock()
if hasChanged { if hasChanged {
if !p.hub.realtime { if !p.realtime {
trackName := "Nothing" trackName := "Nothing"
if current.Item != nil { if current.Item != nil {
trackName = current.Item.Name trackName = current.Item.Name
} }
slog.Info("state changed, broadcasting update", "isPlaying", current.IsPlaying, "track", trackName) slog.Info("state changed, broadcasting update", "isPlaying", current.IsPlaying, "track", trackName)
} }
p.hub.Broadcast(current)
}
}
// SendLastState sends the cached state to a single new client. payload := newPlaybackState(current, p.realtime)
func (p *Poller) SendLastState(ws *websocket.Conn) { message, err := json.Marshal(payload)
p.mu.RLock() if err != nil {
defer p.mu.RUnlock() slog.Error("failed to marshal playback state", "error", err)
return
}
if p.lastState == nil { p.hub.broadcast <- message
return
}
clientPayload := newPlaybackState(p.lastState, p.hub.realtime)
if err := websocket.JSON.Send(ws, clientPayload); err != nil {
slog.Warn("failed to send initial state to client", "error", err, "remoteAddr", ws.RemoteAddr())
} }
} }
// hasStateChanged performs a robust comparison between the new and old states. // hasStateChanged performs a robust comparison between the new and old states.
// This function must be called within a lock.
func (p *Poller) hasStateChanged(current *spotify.CurrentlyPlaying) bool { func (p *Poller) hasStateChanged(current *spotify.CurrentlyPlaying) bool {
if p.lastState == nil { if p.lastState == nil {
return true return true
} }
if p.hub.realtime && current.IsPlaying && current.Item != nil { if p.realtime && current.IsPlaying && current.Item != nil {
return true return true
} }
if p.lastState.IsPlaying != current.IsPlaying { if p.lastState.IsPlaying != current.IsPlaying {

View File

@@ -5,14 +5,14 @@ import (
"errors" "errors"
"log/slog" "log/slog"
"net/http" "net/http"
"strings"
"sync" "sync"
"time" "time"
"spotify-ws/internal/spotify" "spotify-ws/internal/spotify"
) )
// Server is the main application orchestrator. It owns all components // Server is the main application orchestrator.
// and manages the application's lifecycle.
type Server struct { type Server struct {
addr string addr string
httpServer *http.Server httpServer *http.Server
@@ -23,13 +23,12 @@ type Server struct {
// NewServer creates a new, fully configured WebSocket server. // NewServer creates a new, fully configured WebSocket server.
func NewServer(addr string, allowedOrigins []string, spotifyClient *spotify.Client, realtime bool) *Server { func NewServer(addr string, allowedOrigins []string, spotifyClient *spotify.Client, realtime bool) *Server {
hub := NewHub(realtime) hub := NewHub()
poller := NewPoller(spotifyClient, hub) poller := NewPoller(spotifyClient, hub, realtime)
// Create a closure for origin checking to keep the Server's dependencies clean.
originChecker := func(origin string) bool { originChecker := func(origin string) bool {
if len(allowedOrigins) == 0 { if len(allowedOrigins) == 0 {
return true // Allow all if not specified. return true
} }
for _, allowedOrigin := range allowedOrigins { for _, allowedOrigin := range allowedOrigins {
if allowedOrigin == origin { if allowedOrigin == origin {
@@ -47,15 +46,32 @@ func NewServer(addr string, allowedOrigins []string, spotifyClient *spotify.Clie
} }
} }
// Run starts the server and its components. It blocks until the context is // Run starts the server and its components.
// canceled and all components have shut down gracefully.
func (s *Server) Run(ctx context.Context) error { func (s *Server) Run(ctx context.Context) error {
// Do an initial state fetch before starting the server.
s.poller.UpdateState(ctx)
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle("/", s.newWebsocketHandler()) wsHandler := s.newWebsocketHandler()
mux.HandleFunc("/health", healthHandler) mux.HandleFunc("/health", healthHandler)
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
isWebSocket := r.Header.Get("Upgrade") == "websocket" &&
strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade")
if isWebSocket {
wsHandler.ServeHTTP(w, r)
return
}
if r.URL.Path != "/" {
http.NotFound(w, r)
return
}
w.Header().Set("Upgrade", "websocket")
w.Header().Set("Connection", "Upgrade")
w.Header().Set("X-Source", "github.com/skidoodle/spotify-ws")
w.WriteHeader(http.StatusUpgradeRequired)
if _, err := w.Write([]byte("426 Upgrade Required (github.com/skidoodle/spotify-ws)")); err != nil {
slog.Warn("failed to write upgrade required response", "error", err)
}
})
s.httpServer = &http.Server{ s.httpServer = &http.Server{
Addr: s.addr, Addr: s.addr,
@@ -63,7 +79,7 @@ func (s *Server) Run(ctx context.Context) error {
} }
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) // For the hub and the poller. wg.Add(2)
go func() { go func() {
defer wg.Done() defer wg.Done()
@@ -75,32 +91,24 @@ func (s *Server) Run(ctx context.Context) error {
s.poller.Run(ctx) s.poller.Run(ctx)
}() }()
// Start the HTTP server.
go func() { go func() {
if err := s.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { <-ctx.Done()
slog.Error("http server error", "error", err) slog.Info("shutdown signal received, stopping http server")
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := s.httpServer.Shutdown(shutdownCtx); err != nil {
slog.Error("http server shutdown error", "error", err)
} }
}() }()
// Wait for shutdown signal. slog.Info("http server listening", "addr", s.addr)
<-ctx.Done() if err := s.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
slog.Info("shutdown signal received") return err
}
// The hub and poller will stop automatically via the context.
// We just need to shut down the HTTP server and wait for goroutines to finish.
s.shutdown()
wg.Wait() wg.Wait()
return nil return nil
} }
// shutdown gracefully shuts down the HTTP server.
func (s *Server) shutdown() {
slog.Info("shutting down http server")
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.httpServer.Shutdown(shutdownCtx); err != nil {
slog.Error("http server shutdown error", "error", err)
}
}

View File

@@ -2,8 +2,7 @@ package websocket
import "spotify-ws/internal/spotify" import "spotify-ws/internal/spotify"
// PlaybackState is the client-facing data structure. It conditionally omits // PlaybackState is the client-facing data structure.
// real-time data fields from JSON based on the server's mode.
type PlaybackState struct { type PlaybackState struct {
IsPlaying bool `json:"is_playing"` IsPlaying bool `json:"is_playing"`
ProgressMs int `json:"progress_ms,omitempty"` ProgressMs int `json:"progress_ms,omitempty"`
@@ -12,7 +11,6 @@ type PlaybackState struct {
} }
// newPlaybackState creates a client-facing PlaybackState from the internal Spotify data. // newPlaybackState creates a client-facing PlaybackState from the internal Spotify data.
// It includes progress data only if the server is in real-time mode.
func newPlaybackState(data *spotify.CurrentlyPlaying, realtime bool) PlaybackState { func newPlaybackState(data *spotify.CurrentlyPlaying, realtime bool) PlaybackState {
if data == nil { if data == nil {
return PlaybackState{IsPlaying: false} return PlaybackState{IsPlaying: false}