diff --git a/internal/websocket/client.go b/internal/websocket/client.go new file mode 100644 index 0000000..fddcd5b --- /dev/null +++ b/internal/websocket/client.go @@ -0,0 +1,82 @@ +package websocket + +import ( + "log/slog" + "sync" + "time" + + "golang.org/x/net/websocket" +) + +const ( + writeWait = 10 * time.Second + pongWait = 60 * time.Second +) + +// Client is a middleman between the websocket connection and the hub. +type Client struct { + hub *Hub + conn *websocket.Conn + send chan []byte + closeOnce sync.Once +} + +// close is a thread-safe method to clean up the client's resources. +// It ensures that the unregister and connection close operations happen exactly once. +func (c *Client) close() { + c.closeOnce.Do(func() { + slog.Debug("closing client connection", "remoteAddr", c.conn.RemoteAddr()) + c.hub.unregister <- c + if err := c.conn.Close(); err != nil { + // This error is expected if the other end has already hung up. + slog.Debug("error while closing client connection", "error", err, "remoteAddr", c.conn.RemoteAddr()) + } + }) +} + +// readPump is responsible for detecting a dead connection via read deadlines. +func (c *Client) readPump() { + defer c.close() + + if err := c.conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil { + slog.Warn("failed to set initial read deadline", "error", err, "remoteAddr", c.conn.RemoteAddr()) + return + } + + var msg string + for { + if err := websocket.Message.Receive(c.conn, &msg); err != nil { + slog.Debug("client read error, triggering disconnect", "error", err, "remoteAddr", c.conn.RemoteAddr()) + break + } + if err := c.conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil { + slog.Warn("failed to reset read deadline", "error", err, "remoteAddr", c.conn.RemoteAddr()) + break + } + } +} + +// writePump pumps messages from the hub to the websocket connection. +func (c *Client) writePump() { + defer c.close() + + for { + select { + case message, ok := <-c.send: + if !ok { + slog.Debug("hub closed channel, closing connection", "remoteAddr", c.conn.RemoteAddr()) + return + } + + if err := c.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil { + slog.Warn("failed to set write deadline", "error", err, "remoteAddr", c.conn.RemoteAddr()) + return + } + + if err := websocket.Message.Send(c.conn, string(message)); err != nil { + slog.Warn("client write error", "error", err, "remoteAddr", c.conn.RemoteAddr()) + return + } + } + } +} diff --git a/internal/websocket/handler.go b/internal/websocket/handler.go index 4432bba..9ec5cfb 100644 --- a/internal/websocket/handler.go +++ b/internal/websocket/handler.go @@ -7,34 +7,23 @@ import ( "golang.org/x/net/websocket" ) -// newWebsocketHandler creates a new WebSocket handler closure. +// newWebsocketHandler creates the handler that upgrades connections. func (s *Server) newWebsocketHandler() websocket.Handler { return func(ws *websocket.Conn) { - defer func() { - s.hub.unregister <- ws - if err := ws.Close(); err != nil { - slog.Warn("error while closing websocket connection", "error", err, "remoteAddr", ws.RemoteAddr()) - } - }() - origin := ws.Config().Origin.String() if !s.originChecker(origin) { slog.Warn("origin not allowed, rejecting connection", "origin", origin) return } - s.hub.register <- ws + slog.Debug("client connected, upgrading connection", "remoteAddr", ws.RemoteAddr()) - // Send the last known state immediately upon connection. - s.poller.SendLastState(ws) + client := &Client{hub: s.hub, conn: ws, send: make(chan []byte, 256)} - // Block by reading from the client to detect disconnection. - var msg string - for { - if err := websocket.Message.Receive(ws, &msg); err != nil { - break // Client has disconnected. - } - } + client.hub.register <- client + + go client.writePump() + client.readPump() } } diff --git a/internal/websocket/hub.go b/internal/websocket/hub.go index 7900299..7edd81d 100644 --- a/internal/websocket/hub.go +++ b/internal/websocket/hub.go @@ -3,30 +3,28 @@ package websocket import ( "context" "log/slog" - "spotify-ws/internal/spotify" "sync" - - "golang.org/x/net/websocket" ) -// Hub manages the set of active clients and broadcasts messages. +// Hub maintains the set of active clients and broadcasts messages to them. type Hub struct { - clients map[*websocket.Conn]struct{} - mu sync.RWMutex - realtime bool - register chan *websocket.Conn - unregister chan *websocket.Conn - broadcast chan *spotify.CurrentlyPlaying + clients map[*Client]struct{} + broadcast chan []byte + register chan *Client + unregister chan *Client + + // The last known state, protected by a mutex. This is sent to new clients. + lastState []byte + mu sync.RWMutex } // NewHub creates a new Hub. -func NewHub(realtime bool) *Hub { +func NewHub() *Hub { return &Hub{ - clients: make(map[*websocket.Conn]struct{}), - realtime: realtime, - register: make(chan *websocket.Conn), - unregister: make(chan *websocket.Conn), - broadcast: make(chan *spotify.CurrentlyPlaying), + clients: make(map[*Client]struct{}), + broadcast: make(chan []byte), + register: make(chan *Client), + unregister: make(chan *Client), } } @@ -38,57 +36,36 @@ func (h *Hub) Run(ctx context.Context) { for { select { case <-ctx.Done(): - h.closeAllConnections() + for client := range h.clients { + close(client.send) + } return case client := <-h.register: - h.mu.Lock() h.clients[client] = struct{}{} - h.mu.Unlock() - slog.Debug("client registered", "remoteAddr", client.RemoteAddr()) + h.mu.RLock() + if h.lastState != nil { + client.send <- h.lastState + } + h.mu.RUnlock() case client := <-h.unregister: - h.mu.Lock() if _, ok := h.clients[client]; ok { delete(h.clients, client) + close(client.send) } + case message := <-h.broadcast: + h.mu.Lock() + h.lastState = message h.mu.Unlock() - slog.Debug("client unregistered", "remoteAddr", client.RemoteAddr()) - case state := <-h.broadcast: - h.broadcastState(state) - } - } -} -// Broadcast sends a state update to all connected clients. -func (h *Hub) Broadcast(state *spotify.CurrentlyPlaying) { - h.broadcast <- state -} - -// broadcastState handles the actual message sending. -func (h *Hub) broadcastState(state *spotify.CurrentlyPlaying) { - h.mu.RLock() - defer h.mu.RUnlock() - - if len(h.clients) == 0 { - return - } - - clientPayload := newPlaybackState(state, h.realtime) - for client := range h.clients { - go func(c *websocket.Conn) { - if err := websocket.JSON.Send(c, clientPayload); err != nil { - slog.Warn("failed to broadcast message", "error", err, "remoteAddr", c.RemoteAddr()) + for client := range h.clients { + select { + case client.send <- message: + default: + slog.Warn("client send buffer full, disconnecting", "remoteAddr", client.conn.RemoteAddr()) + close(client.send) + delete(h.clients, client) + } } - }(client) - } -} - -// closeAllConnections closes all active client connections during shutdown. -func (h *Hub) closeAllConnections() { - h.mu.Lock() - defer h.mu.Unlock() - for client := range h.clients { - if err := client.Close(); err != nil { - slog.Warn("error closing client connection during shutdown", "error", err, "remoteAddr", client.RemoteAddr()) } } } diff --git a/internal/websocket/poller.go b/internal/websocket/poller.go index 413647b..8420848 100644 --- a/internal/websocket/poller.go +++ b/internal/websocket/poller.go @@ -2,36 +2,39 @@ package websocket import ( "context" + "encoding/json" "log/slog" "sync" "time" "spotify-ws/internal/spotify" - - "golang.org/x/net/websocket" ) // Poller is responsible for fetching data from the Spotify API periodically. type Poller struct { client *spotify.Client hub *Hub + realtime bool lastState *spotify.CurrentlyPlaying mu sync.RWMutex } // NewPoller creates a new Poller. -func NewPoller(client *spotify.Client, hub *Hub) *Poller { +func NewPoller(client *spotify.Client, hub *Hub, realtime bool) *Poller { return &Poller{ - client: client, - hub: hub, + client: client, + hub: hub, + realtime: realtime, } } -// Run starts the polling loop. It must be run in a separate goroutine. +// Run starts the polling loop. func (p *Poller) Run(ctx context.Context) { slog.Info("poller started") defer slog.Info("poller stopped") + p.UpdateState(ctx) + ticker := time.NewTicker(3 * time.Second) defer ticker.Stop() @@ -61,38 +64,31 @@ func (p *Poller) UpdateState(ctx context.Context) { p.mu.Unlock() if hasChanged { - if !p.hub.realtime { + if !p.realtime { trackName := "Nothing" if current.Item != nil { trackName = current.Item.Name } slog.Info("state changed, broadcasting update", "isPlaying", current.IsPlaying, "track", trackName) } - p.hub.Broadcast(current) - } -} -// SendLastState sends the cached state to a single new client. -func (p *Poller) SendLastState(ws *websocket.Conn) { - p.mu.RLock() - defer p.mu.RUnlock() + payload := newPlaybackState(current, p.realtime) + message, err := json.Marshal(payload) + if err != nil { + slog.Error("failed to marshal playback state", "error", err) + return + } - if p.lastState == nil { - return - } - clientPayload := newPlaybackState(p.lastState, p.hub.realtime) - if err := websocket.JSON.Send(ws, clientPayload); err != nil { - slog.Warn("failed to send initial state to client", "error", err, "remoteAddr", ws.RemoteAddr()) + p.hub.broadcast <- message } } // hasStateChanged performs a robust comparison between the new and old states. -// This function must be called within a lock. func (p *Poller) hasStateChanged(current *spotify.CurrentlyPlaying) bool { if p.lastState == nil { return true } - if p.hub.realtime && current.IsPlaying && current.Item != nil { + if p.realtime && current.IsPlaying && current.Item != nil { return true } if p.lastState.IsPlaying != current.IsPlaying { diff --git a/internal/websocket/server.go b/internal/websocket/server.go index cceca2a..9ab37f0 100644 --- a/internal/websocket/server.go +++ b/internal/websocket/server.go @@ -5,14 +5,14 @@ import ( "errors" "log/slog" "net/http" + "strings" "sync" "time" "spotify-ws/internal/spotify" ) -// Server is the main application orchestrator. It owns all components -// and manages the application's lifecycle. +// Server is the main application orchestrator. type Server struct { addr string httpServer *http.Server @@ -23,13 +23,12 @@ type Server struct { // NewServer creates a new, fully configured WebSocket server. func NewServer(addr string, allowedOrigins []string, spotifyClient *spotify.Client, realtime bool) *Server { - hub := NewHub(realtime) - poller := NewPoller(spotifyClient, hub) + hub := NewHub() + poller := NewPoller(spotifyClient, hub, realtime) - // Create a closure for origin checking to keep the Server's dependencies clean. originChecker := func(origin string) bool { if len(allowedOrigins) == 0 { - return true // Allow all if not specified. + return true } for _, allowedOrigin := range allowedOrigins { if allowedOrigin == origin { @@ -47,15 +46,32 @@ func NewServer(addr string, allowedOrigins []string, spotifyClient *spotify.Clie } } -// Run starts the server and its components. It blocks until the context is -// canceled and all components have shut down gracefully. +// Run starts the server and its components. func (s *Server) Run(ctx context.Context) error { - // Do an initial state fetch before starting the server. - s.poller.UpdateState(ctx) - mux := http.NewServeMux() - mux.Handle("/", s.newWebsocketHandler()) + wsHandler := s.newWebsocketHandler() + mux.HandleFunc("/health", healthHandler) + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + isWebSocket := r.Header.Get("Upgrade") == "websocket" && + strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade") + + if isWebSocket { + wsHandler.ServeHTTP(w, r) + return + } + if r.URL.Path != "/" { + http.NotFound(w, r) + return + } + w.Header().Set("Upgrade", "websocket") + w.Header().Set("Connection", "Upgrade") + w.Header().Set("X-Source", "github.com/skidoodle/spotify-ws") + w.WriteHeader(http.StatusUpgradeRequired) + if _, err := w.Write([]byte("426 Upgrade Required (github.com/skidoodle/spotify-ws)")); err != nil { + slog.Warn("failed to write upgrade required response", "error", err) + } + }) s.httpServer = &http.Server{ Addr: s.addr, @@ -63,7 +79,7 @@ func (s *Server) Run(ctx context.Context) error { } var wg sync.WaitGroup - wg.Add(2) // For the hub and the poller. + wg.Add(2) go func() { defer wg.Done() @@ -75,32 +91,24 @@ func (s *Server) Run(ctx context.Context) error { s.poller.Run(ctx) }() - // Start the HTTP server. go func() { - if err := s.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - slog.Error("http server error", "error", err) + <-ctx.Done() + slog.Info("shutdown signal received, stopping http server") + + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if err := s.httpServer.Shutdown(shutdownCtx); err != nil { + slog.Error("http server shutdown error", "error", err) } }() - // Wait for shutdown signal. - <-ctx.Done() - slog.Info("shutdown signal received") + slog.Info("http server listening", "addr", s.addr) + if err := s.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + return err + } - // The hub and poller will stop automatically via the context. - // We just need to shut down the HTTP server and wait for goroutines to finish. - s.shutdown() wg.Wait() return nil } - -// shutdown gracefully shuts down the HTTP server. -func (s *Server) shutdown() { - slog.Info("shutting down http server") - shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - if err := s.httpServer.Shutdown(shutdownCtx); err != nil { - slog.Error("http server shutdown error", "error", err) - } -} diff --git a/internal/websocket/state.go b/internal/websocket/state.go index db80432..860ab45 100644 --- a/internal/websocket/state.go +++ b/internal/websocket/state.go @@ -2,8 +2,7 @@ package websocket import "spotify-ws/internal/spotify" -// PlaybackState is the client-facing data structure. It conditionally omits -// real-time data fields from JSON based on the server's mode. +// PlaybackState is the client-facing data structure. type PlaybackState struct { IsPlaying bool `json:"is_playing"` ProgressMs int `json:"progress_ms,omitempty"` @@ -12,7 +11,6 @@ type PlaybackState struct { } // newPlaybackState creates a client-facing PlaybackState from the internal Spotify data. -// It includes progress data only if the server is in real-time mode. func newPlaybackState(data *spotify.CurrentlyPlaying, realtime bool) PlaybackState { if data == nil { return PlaybackState{IsPlaying: false}