diff --git a/server/Cargo.toml b/server/Cargo.toml index 3a45fe1..a3a78ff 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -5,10 +5,13 @@ edition = "2024" [dependencies] tokio = { version = "1.45.1", features = ["full"] } +serde = { version = "1.0.219", features = ["derive"] } common = { path = "../common" } anyhow = "1.0.98" rumqttd = "0.19.0" toml = "0.8.23" warp = "0.3.7" +rusqlite = "0.36.0" +chrono = "0.4.41" +async-trait = "0.1.88" futures-util = "0.3.31" -serde = { version = "1.0.219", features = ["derive"] } diff --git a/server/config.toml b/server/config.toml index cfb8a53..48308a6 100644 --- a/server/config.toml +++ b/server/config.toml @@ -1,3 +1,7 @@ +[storage] +sqlite = true # if set to false, broker will use memory (not optimal for production) +db_path = "_db/" + [mqtt] id = 0 diff --git a/server/src/broker/manager.rs b/server/src/broker/manager.rs index 34162ff..0ba9d64 100644 --- a/server/src/broker/manager.rs +++ b/server/src/broker/manager.rs @@ -1,22 +1,24 @@ use rumqttd::{Broker, Config}; use std::sync::Arc; -use crate::bridge::ClientManager; +use crate::{bridge::ClientManager, storage::StorageRepositoryImpl}; use super::subscriber::MqttSubscriber; pub struct MqttBroker { broker: &'static mut Broker, clients: Arc, + storage: Arc, } impl MqttBroker { - pub async fn new(cfg: Config) -> Self { + pub async fn new(cfg: Config, storage: Arc) -> Self { let clients = Arc::new(ClientManager::new()); let broker: &'static mut Broker = Box::leak(Box::new(Broker::new(cfg))); Self { broker, clients, + storage, } } @@ -25,7 +27,7 @@ impl MqttBroker { } pub async fn run(self) -> anyhow::Result<()> { - let mut subscriber = MqttSubscriber::new(&self.broker, self.clients); + let mut subscriber = MqttSubscriber::new(&self.broker, self.clients, self.storage.inner()); println!("starting mqtt broker on specified port"); diff --git a/server/src/broker/subscriber.rs b/server/src/broker/subscriber.rs index 7044de5..8b1d48f 100644 --- a/server/src/broker/subscriber.rs +++ b/server/src/broker/subscriber.rs @@ -2,21 +2,23 @@ use rumqttd::{local::LinkRx, Broker, Notification}; use common::{StatusMessage, MQTT_TOPIC}; use std::sync::Arc; -use crate::bridge::ClientManager; +use crate::{bridge::ClientManager, storage::StorageRepository}; pub struct MqttSubscriber { link_rx: LinkRx, clients: Arc, + storage: Arc, } impl MqttSubscriber { - pub fn new(broker: &Broker, clients: Arc) -> Self { + pub fn new(broker: &Broker, clients: Arc, storage: Arc) -> Self { let (mut link_tx, link_rx) = broker.link("internal-subscriber").unwrap(); link_tx.subscribe(MQTT_TOPIC).unwrap(); Self { link_rx, clients, + storage, } } @@ -27,6 +29,10 @@ impl MqttSubscriber { if let Ok(payload_str) = payload.to_string() { self.clients.broadcast(payload_str).await; + + if let Err(e) = self.storage.record_message(&payload.agent).await { + eprintln!("failed to record message for {}: {}", &payload.agent, e); + } } } } diff --git a/server/src/config.rs b/server/src/config.rs index 34d9244..41709e5 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -1,5 +1,5 @@ -use rumqttd::Config; use serde::Deserialize; +use rumqttd::Config; const CONFIG_PATH: &str = if cfg!(debug_assertions) { "server/config.toml" @@ -7,8 +7,15 @@ const CONFIG_PATH: &str = if cfg!(debug_assertions) { "config.toml" }; +#[derive(Debug, Deserialize)] +pub struct StorageConfig { + pub sqlite: bool, + pub db_path: String, +} + #[derive(Debug, Deserialize)] pub struct Configuration { + pub storage: StorageConfig, pub mqtt: Config, } diff --git a/server/src/main.rs b/server/src/main.rs index 2530b69..413a0ed 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,22 +1,31 @@ +use storage::{StorageRepositoryImpl, StorageStrategy}; use broker::manager::MqttBroker; use config::load_config; -use websocket::server::Websocket; +use server::Server; +use std::sync::Arc; pub mod broker; pub mod bridge; pub mod config; -pub mod websocket; +pub mod storage; +pub mod server; #[tokio::main] async fn main() -> anyhow::Result<()> { let cfg = load_config()?; - let broker = MqttBroker::new(cfg.mqtt).await; - let ws = Websocket::new(broker.clients()); + let storage = Arc::new(if cfg.storage.sqlite { + StorageRepositoryImpl::new(StorageStrategy::SQLite(format!("{}agents.db", cfg.storage.db_path))) + } else { + StorageRepositoryImpl::new(StorageStrategy::InMemory) + }); + + let broker = MqttBroker::new(cfg.mqtt, Arc::clone(&storage)).await; + let server = Server::new(broker.clients(), Arc::clone(&storage)); tokio::select! { res = broker.run() => res?, - res = ws.run() => res, + res = server.serve() => res, } Ok(()) diff --git a/server/src/server/http.rs b/server/src/server/http.rs new file mode 100644 index 0000000..a1f83f6 --- /dev/null +++ b/server/src/server/http.rs @@ -0,0 +1,27 @@ +use warp::{reply::json, Filter, Reply, Rejection}; +use std::sync::Arc; + +use crate::storage::StorageRepository; + +pub struct HttpRoutes { + storage: Arc, +} + +impl HttpRoutes { + pub fn new(storage: Arc) -> Self { + Self { storage } + } + + pub fn routes(self: Arc) -> impl Filter + Clone { + warp::path!("agents") + .and(warp::get()) + .and(warp::any().map(move || self.storage.clone())) + .and_then(Self::get_agents) + } + + async fn get_agents(storage: Arc) -> Result { + let agents = storage.get_agents().await.unwrap(); + + Ok(json(&agents)) + } +} diff --git a/server/src/server/mod.rs b/server/src/server/mod.rs new file mode 100644 index 0000000..c004364 --- /dev/null +++ b/server/src/server/mod.rs @@ -0,0 +1,37 @@ +use http::HttpRoutes; +use warp::Filter; +use websocket::WebsocketRoutes; +use std::sync::Arc; + +use crate::{bridge::ClientManager, storage::StorageRepositoryImpl}; + +pub mod http; +pub mod websocket; + +pub struct Server { + clients: Arc, + storage: Arc, +} + +impl Server { + pub fn new(clients: Arc, storage: Arc) -> Self { + Self { + clients, + storage, + } + } + + pub async fn serve(&self) { + let http_routes = Arc::new(HttpRoutes::new(self.storage.inner())); + let ws_routes = Arc::new(WebsocketRoutes::new(self.clients.clone())); + + let cors = warp::cors() + .allow_any_origin() + .allow_methods(vec!["GET"]); + + let routes = http_routes.routes().with(cors).or(ws_routes.routes()); + + println!("starting websocket server on :{}", 3000); + warp::serve(routes).run(([0, 0, 0, 0], 3000)).await; + } +} diff --git a/server/src/server/websocket.rs b/server/src/server/websocket.rs new file mode 100644 index 0000000..c6e329d --- /dev/null +++ b/server/src/server/websocket.rs @@ -0,0 +1,42 @@ +use warp::{filters::ws::{Message, WebSocket}, Filter, Reply, Rejection}; +use futures_util::{StreamExt, SinkExt}; +use tokio::sync::mpsc::channel; +use std::sync::Arc; + +use crate::bridge::ClientManager; + +pub struct WebsocketRoutes { + clients: Arc, +} + +impl WebsocketRoutes { + pub fn new(clients: Arc) -> Self { + Self { clients } + } + + pub fn routes(self: Arc) -> impl Filter + Clone { + warp::path("ws") + .and(warp::get()) + .and(warp::ws()) + .and(warp::any().map(move || self.clients.clone())) + .map(|websocket: warp::ws::Ws, clients: Arc| { + websocket.on_upgrade(move |websocket| Self::handle_ws_connection(websocket, clients)) + }) + } + + async fn handle_ws_connection(websocket: WebSocket, clients: Arc) { + let (mut ws_tx, _) = websocket.split(); + let (tx, mut rx) = channel::(100); + + clients.add_client(tx).await; + + tokio::spawn(async move { + while let Some(msg) = rx.recv().await { + if ws_tx.send(Message::text(msg)).await.is_err() { + break; + } + } + }); + } + +} diff --git a/server/src/storage/memory.rs b/server/src/storage/memory.rs new file mode 100644 index 0000000..97bb138 --- /dev/null +++ b/server/src/storage/memory.rs @@ -0,0 +1,59 @@ +use std::collections::HashMap; +use chrono::Utc; +use tokio::sync::Mutex; +use async_trait::async_trait; + +use super::{StorageRepository, UptimeMessage, UptimeModel}; + +pub struct InMemoryRepository { + agents: Mutex> +} + +impl InMemoryRepository { + pub fn new() -> Self { + Self { + agents: Default::default(), + } + } +} + +#[async_trait] +impl StorageRepository for InMemoryRepository { + async fn record_message(&self, agent: &str) -> anyhow::Result<()> { + let mut agents = self.agents.lock().await; + let now = Utc::now(); + + agents.entry(agent.to_string()) + .and_modify(|a| { + a.last_seen = now; + a.message_count += 1; + }) + .or_insert_with(|| UptimeModel { + id: agent.to_string(), + first_seen: now, + last_seen: now, + message_count: 1, + }); + + Ok(()) + } + + /* + async fn get_uptime(&self, agent: &str) -> anyhow::Result> { + let agents = self.agents.lock().await; + + match agents.get(agent) { + Some(data) => { + Ok(Some(data.clone().into())) + } + None => Ok(None), + } + } + */ + + async fn get_agents(&self) -> anyhow::Result> { + let agents = self.agents.lock().await; + + Ok(agents.values().cloned().map(Into::into).collect()) + } +} diff --git a/server/src/storage/mod.rs b/server/src/storage/mod.rs new file mode 100644 index 0000000..1023c3e --- /dev/null +++ b/server/src/storage/mod.rs @@ -0,0 +1,70 @@ +use serde::{Deserialize, Serialize}; +use chrono::{DateTime, Utc}; +use sqlite::SQLiteRepository; +use std::sync::Arc; +use memory::InMemoryRepository; +use async_trait::async_trait; +use common::MQTT_SEND_INTERVAL; + +pub mod memory; +pub mod sqlite; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UptimeModel { + pub id: String, + pub first_seen: DateTime, + pub last_seen: DateTime, + pub message_count: u64, +} + +impl Into for UptimeModel { + fn into(self) -> UptimeMessage { + let duration = Utc::now().signed_duration_since(self.first_seen); + let expected_messages = duration.num_seconds() as f64 / MQTT_SEND_INTERVAL as f64; + + let uptime_pct = (self.message_count as f64 / expected_messages * 100.0).min(100.0); + + UptimeMessage { + agent: self.id, + uptime: uptime_pct, + last_seen: self.last_seen, + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct UptimeMessage { + pub agent: String, + pub uptime: f64, + pub last_seen: DateTime, +} + +#[async_trait] +pub trait StorageRepository: Send + Sync { + async fn record_message(&self, agent: &str) -> anyhow::Result<()>; + async fn get_agents(&self) -> anyhow::Result>; +} + +pub enum StorageStrategy { + InMemory, + SQLite(String), +} + +pub struct StorageRepositoryImpl { + inner: Arc +} + +impl StorageRepositoryImpl { + pub fn new(strategy: StorageStrategy) -> Self { + let inner: Arc = match strategy { + StorageStrategy::InMemory => Arc::new(InMemoryRepository::new()), + StorageStrategy::SQLite(path) => Arc::new(SQLiteRepository::new(path)), + }; + + Self { inner } + } + + pub fn inner(&self) -> Arc { + self.inner.clone() + } +} diff --git a/server/src/storage/sqlite.rs b/server/src/storage/sqlite.rs new file mode 100644 index 0000000..c3ff57c --- /dev/null +++ b/server/src/storage/sqlite.rs @@ -0,0 +1,79 @@ +use async_trait::async_trait; +use rusqlite::Connection; +use tokio::sync::Mutex; +use std::path::Path; +use chrono::{DateTime, Utc}; + +use super::{StorageRepository, UptimeMessage, UptimeModel}; + +pub struct SQLiteRepository { + conn: Mutex +} + +impl SQLiteRepository { + pub fn new>(path: P) -> Self { + let conn = Connection::open(path).unwrap(); + + conn.pragma_update(None, "journal_mode", "WAL").unwrap(); + conn.pragma_update(None, "foreign_keys", "ON").unwrap(); + conn.pragma_update(None, "synchronous", "NORMAL").unwrap(); + + conn.execute_batch( + r#" + CREATE TABLE IF NOT EXISTS agents ( + id TEXT PRIMARY KEY, + first_seen TEXT NOT NULL, + last_seen TEXT NOT NULL, + message_count INTEGER NOT NULL DEFAULT 0 + ) STRICT; + + CREATE INDEX IF NOT EXISTS idx_agents_id ON agents(id); + CREATE INDEX IF NOT EXISTS idx_agents_times ON agents(first_seen, last_seen); + "# + ).unwrap(); + + Self { + conn: Mutex::new(conn) + } + } +} + +#[async_trait] +impl StorageRepository for SQLiteRepository { + async fn record_message(&self, agent: &str) -> anyhow::Result<()> { + let conn = self.conn.lock().await; + let now = Utc::now().to_rfc3339(); + + conn.execute(r#" + INSERT INTO agents (id, first_seen, last_seen, message_count) + VALUES (?1, ?2, ?2, 1) + ON CONFLICT (id) DO UPDATE SET + last_seen = excluded.last_seen, + message_count = message_count + 1; + "#, [agent, &now] + )?; + + Ok(()) + } + + async fn get_agents(&self) -> anyhow::Result> { + let conn = self.conn.lock().await; + let mut stmt = conn.prepare("SELECT id, first_seen, last_seen, message_count FROM agents")?; + + let result = stmt.query_map([], |row| { + let first_seen: DateTime = row.get::<_, String>(1)?.parse().unwrap(); + let last_seen: DateTime = row.get::<_, String>(2)?.parse().unwrap(); + + Ok(UptimeModel { + id: row.get(0)?, + first_seen, + last_seen, + message_count: row.get(3)?, + }) + })?; + + let models: Result, _> = result.collect(); + + Ok(models?.into_iter().map(Into::into).collect()) + } +} diff --git a/server/src/websocket/mod.rs b/server/src/websocket/mod.rs deleted file mode 100644 index af1b101..0000000 --- a/server/src/websocket/mod.rs +++ /dev/null @@ -1,23 +0,0 @@ -use warp::filters::ws::{Message, WebSocket}; -use futures_util::{SinkExt, StreamExt}; -use tokio::sync::mpsc::channel; -use std::sync::Arc; - -use crate::bridge::ClientManager; - -pub mod server; - -pub async fn handle_ws_connection(websocket: WebSocket, clients: Arc) { - let (mut ws_tx, _) = websocket.split(); - let (tx, mut rx) = channel::(100); - - clients.add_client(tx).await; - - tokio::spawn(async move { - while let Some(msg) = rx.recv().await { - if ws_tx.send(Message::text(msg)).await.is_err() { - break; - } - } - }); -} diff --git a/server/src/websocket/server.rs b/server/src/websocket/server.rs deleted file mode 100644 index e3cc05b..0000000 --- a/server/src/websocket/server.rs +++ /dev/null @@ -1,27 +0,0 @@ -use warp::Filter; -use std::sync::Arc; - -use crate::bridge::ClientManager; -use super::handle_ws_connection; - -pub struct Websocket { - clients: Arc, -} - -impl Websocket { - pub fn new(clients: Arc) -> Self { - Self { clients } - } - - pub async fn run(self) { - let route = warp::path("ws") - .and(warp::ws()) - .and(warp::any().map(move || self.clients.clone())) - .map(|ws: warp::ws::Ws, clients| { - ws.on_upgrade(move |websocket| handle_ws_connection(websocket, clients)) - }); - - println!("starting websocket server on :3000"); - warp::serve(route).run(([0, 0, 0, 0], 3000)).await; - } -}