implement storage + improve code

This commit is contained in:
csehviktor
2025-07-06 02:28:29 +02:00
parent d2a82e973b
commit bf9d1e4da6
14 changed files with 357 additions and 62 deletions

View File

@@ -5,10 +5,13 @@ edition = "2024"
[dependencies] [dependencies]
tokio = { version = "1.45.1", features = ["full"] } tokio = { version = "1.45.1", features = ["full"] }
serde = { version = "1.0.219", features = ["derive"] }
common = { path = "../common" } common = { path = "../common" }
anyhow = "1.0.98" anyhow = "1.0.98"
rumqttd = "0.19.0" rumqttd = "0.19.0"
toml = "0.8.23" toml = "0.8.23"
warp = "0.3.7" warp = "0.3.7"
rusqlite = "0.36.0"
chrono = "0.4.41"
async-trait = "0.1.88"
futures-util = "0.3.31" futures-util = "0.3.31"
serde = { version = "1.0.219", features = ["derive"] }

View File

@@ -1,3 +1,7 @@
[storage]
sqlite = true # if set to false, broker will use memory (not optimal for production)
db_path = "_db/"
[mqtt] [mqtt]
id = 0 id = 0

View File

@@ -1,22 +1,24 @@
use rumqttd::{Broker, Config}; use rumqttd::{Broker, Config};
use std::sync::Arc; use std::sync::Arc;
use crate::bridge::ClientManager; use crate::{bridge::ClientManager, storage::StorageRepositoryImpl};
use super::subscriber::MqttSubscriber; use super::subscriber::MqttSubscriber;
pub struct MqttBroker { pub struct MqttBroker {
broker: &'static mut Broker, broker: &'static mut Broker,
clients: Arc<ClientManager>, clients: Arc<ClientManager>,
storage: Arc<StorageRepositoryImpl>,
} }
impl MqttBroker { impl MqttBroker {
pub async fn new(cfg: Config) -> Self { pub async fn new(cfg: Config, storage: Arc<StorageRepositoryImpl>) -> Self {
let clients = Arc::new(ClientManager::new()); let clients = Arc::new(ClientManager::new());
let broker: &'static mut Broker = Box::leak(Box::new(Broker::new(cfg))); let broker: &'static mut Broker = Box::leak(Box::new(Broker::new(cfg)));
Self { Self {
broker, broker,
clients, clients,
storage,
} }
} }
@@ -25,7 +27,7 @@ impl MqttBroker {
} }
pub async fn run(self) -> anyhow::Result<()> { pub async fn run(self) -> anyhow::Result<()> {
let mut subscriber = MqttSubscriber::new(&self.broker, self.clients); let mut subscriber = MqttSubscriber::new(&self.broker, self.clients, self.storage.inner());
println!("starting mqtt broker on specified port"); println!("starting mqtt broker on specified port");

View File

@@ -2,21 +2,23 @@ use rumqttd::{local::LinkRx, Broker, Notification};
use common::{StatusMessage, MQTT_TOPIC}; use common::{StatusMessage, MQTT_TOPIC};
use std::sync::Arc; use std::sync::Arc;
use crate::bridge::ClientManager; use crate::{bridge::ClientManager, storage::StorageRepository};
pub struct MqttSubscriber { pub struct MqttSubscriber {
link_rx: LinkRx, link_rx: LinkRx,
clients: Arc<ClientManager>, clients: Arc<ClientManager>,
storage: Arc<dyn StorageRepository>,
} }
impl MqttSubscriber { impl MqttSubscriber {
pub fn new(broker: &Broker, clients: Arc<ClientManager>) -> Self { pub fn new(broker: &Broker, clients: Arc<ClientManager>, storage: Arc<dyn StorageRepository>) -> Self {
let (mut link_tx, link_rx) = broker.link("internal-subscriber").unwrap(); let (mut link_tx, link_rx) = broker.link("internal-subscriber").unwrap();
link_tx.subscribe(MQTT_TOPIC).unwrap(); link_tx.subscribe(MQTT_TOPIC).unwrap();
Self { Self {
link_rx, link_rx,
clients, clients,
storage,
} }
} }
@@ -27,6 +29,10 @@ impl MqttSubscriber {
if let Ok(payload_str) = payload.to_string() { if let Ok(payload_str) = payload.to_string() {
self.clients.broadcast(payload_str).await; self.clients.broadcast(payload_str).await;
if let Err(e) = self.storage.record_message(&payload.agent).await {
eprintln!("failed to record message for {}: {}", &payload.agent, e);
}
} }
} }
} }

View File

@@ -1,5 +1,5 @@
use rumqttd::Config;
use serde::Deserialize; use serde::Deserialize;
use rumqttd::Config;
const CONFIG_PATH: &str = if cfg!(debug_assertions) { const CONFIG_PATH: &str = if cfg!(debug_assertions) {
"server/config.toml" "server/config.toml"
@@ -7,8 +7,15 @@ const CONFIG_PATH: &str = if cfg!(debug_assertions) {
"config.toml" "config.toml"
}; };
#[derive(Debug, Deserialize)]
pub struct StorageConfig {
pub sqlite: bool,
pub db_path: String,
}
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct Configuration { pub struct Configuration {
pub storage: StorageConfig,
pub mqtt: Config, pub mqtt: Config,
} }

View File

@@ -1,22 +1,31 @@
use storage::{StorageRepositoryImpl, StorageStrategy};
use broker::manager::MqttBroker; use broker::manager::MqttBroker;
use config::load_config; use config::load_config;
use websocket::server::Websocket; use server::Server;
use std::sync::Arc;
pub mod broker; pub mod broker;
pub mod bridge; pub mod bridge;
pub mod config; pub mod config;
pub mod websocket; pub mod storage;
pub mod server;
#[tokio::main] #[tokio::main]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
let cfg = load_config()?; let cfg = load_config()?;
let broker = MqttBroker::new(cfg.mqtt).await; let storage = Arc::new(if cfg.storage.sqlite {
let ws = Websocket::new(broker.clients()); StorageRepositoryImpl::new(StorageStrategy::SQLite(format!("{}agents.db", cfg.storage.db_path)))
} else {
StorageRepositoryImpl::new(StorageStrategy::InMemory)
});
let broker = MqttBroker::new(cfg.mqtt, Arc::clone(&storage)).await;
let server = Server::new(broker.clients(), Arc::clone(&storage));
tokio::select! { tokio::select! {
res = broker.run() => res?, res = broker.run() => res?,
res = ws.run() => res, res = server.serve() => res,
} }
Ok(()) Ok(())

27
server/src/server/http.rs Normal file
View File

@@ -0,0 +1,27 @@
use warp::{reply::json, Filter, Reply, Rejection};
use std::sync::Arc;
use crate::storage::StorageRepository;
pub struct HttpRoutes {
storage: Arc<dyn StorageRepository>,
}
impl HttpRoutes {
pub fn new(storage: Arc<dyn StorageRepository>) -> Self {
Self { storage }
}
pub fn routes(self: Arc<Self>) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
warp::path!("agents")
.and(warp::get())
.and(warp::any().map(move || self.storage.clone()))
.and_then(Self::get_agents)
}
async fn get_agents(storage: Arc<dyn StorageRepository>) -> Result<impl Reply, Rejection> {
let agents = storage.get_agents().await.unwrap();
Ok(json(&agents))
}
}

37
server/src/server/mod.rs Normal file
View File

@@ -0,0 +1,37 @@
use http::HttpRoutes;
use warp::Filter;
use websocket::WebsocketRoutes;
use std::sync::Arc;
use crate::{bridge::ClientManager, storage::StorageRepositoryImpl};
pub mod http;
pub mod websocket;
pub struct Server {
clients: Arc<ClientManager>,
storage: Arc<StorageRepositoryImpl>,
}
impl Server {
pub fn new(clients: Arc<ClientManager>, storage: Arc<StorageRepositoryImpl>) -> Self {
Self {
clients,
storage,
}
}
pub async fn serve(&self) {
let http_routes = Arc::new(HttpRoutes::new(self.storage.inner()));
let ws_routes = Arc::new(WebsocketRoutes::new(self.clients.clone()));
let cors = warp::cors()
.allow_any_origin()
.allow_methods(vec!["GET"]);
let routes = http_routes.routes().with(cors).or(ws_routes.routes());
println!("starting websocket server on :{}", 3000);
warp::serve(routes).run(([0, 0, 0, 0], 3000)).await;
}
}

View File

@@ -0,0 +1,42 @@
use warp::{filters::ws::{Message, WebSocket}, Filter, Reply, Rejection};
use futures_util::{StreamExt, SinkExt};
use tokio::sync::mpsc::channel;
use std::sync::Arc;
use crate::bridge::ClientManager;
pub struct WebsocketRoutes {
clients: Arc<ClientManager>,
}
impl WebsocketRoutes {
pub fn new(clients: Arc<ClientManager>) -> Self {
Self { clients }
}
pub fn routes(self: Arc<Self>) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
warp::path("ws")
.and(warp::get())
.and(warp::ws())
.and(warp::any().map(move || self.clients.clone()))
.map(|websocket: warp::ws::Ws, clients: Arc<ClientManager>| {
websocket.on_upgrade(move |websocket| Self::handle_ws_connection(websocket, clients))
})
}
async fn handle_ws_connection(websocket: WebSocket, clients: Arc<ClientManager>) {
let (mut ws_tx, _) = websocket.split();
let (tx, mut rx) = channel::<String>(100);
clients.add_client(tx).await;
tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
if ws_tx.send(Message::text(msg)).await.is_err() {
break;
}
}
});
}
}

View File

@@ -0,0 +1,59 @@
use std::collections::HashMap;
use chrono::Utc;
use tokio::sync::Mutex;
use async_trait::async_trait;
use super::{StorageRepository, UptimeMessage, UptimeModel};
pub struct InMemoryRepository {
agents: Mutex<HashMap<String, UptimeModel>>
}
impl InMemoryRepository {
pub fn new() -> Self {
Self {
agents: Default::default(),
}
}
}
#[async_trait]
impl StorageRepository for InMemoryRepository {
async fn record_message(&self, agent: &str) -> anyhow::Result<()> {
let mut agents = self.agents.lock().await;
let now = Utc::now();
agents.entry(agent.to_string())
.and_modify(|a| {
a.last_seen = now;
a.message_count += 1;
})
.or_insert_with(|| UptimeModel {
id: agent.to_string(),
first_seen: now,
last_seen: now,
message_count: 1,
});
Ok(())
}
/*
async fn get_uptime(&self, agent: &str) -> anyhow::Result<Option<UptimeMessage>> {
let agents = self.agents.lock().await;
match agents.get(agent) {
Some(data) => {
Ok(Some(data.clone().into()))
}
None => Ok(None),
}
}
*/
async fn get_agents(&self) -> anyhow::Result<Vec<UptimeMessage>> {
let agents = self.agents.lock().await;
Ok(agents.values().cloned().map(Into::into).collect())
}
}

70
server/src/storage/mod.rs Normal file
View File

@@ -0,0 +1,70 @@
use serde::{Deserialize, Serialize};
use chrono::{DateTime, Utc};
use sqlite::SQLiteRepository;
use std::sync::Arc;
use memory::InMemoryRepository;
use async_trait::async_trait;
use common::MQTT_SEND_INTERVAL;
pub mod memory;
pub mod sqlite;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UptimeModel {
pub id: String,
pub first_seen: DateTime<Utc>,
pub last_seen: DateTime<Utc>,
pub message_count: u64,
}
impl Into<UptimeMessage> for UptimeModel {
fn into(self) -> UptimeMessage {
let duration = Utc::now().signed_duration_since(self.first_seen);
let expected_messages = duration.num_seconds() as f64 / MQTT_SEND_INTERVAL as f64;
let uptime_pct = (self.message_count as f64 / expected_messages * 100.0).min(100.0);
UptimeMessage {
agent: self.id,
uptime: uptime_pct,
last_seen: self.last_seen,
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct UptimeMessage {
pub agent: String,
pub uptime: f64,
pub last_seen: DateTime<Utc>,
}
#[async_trait]
pub trait StorageRepository: Send + Sync {
async fn record_message(&self, agent: &str) -> anyhow::Result<()>;
async fn get_agents(&self) -> anyhow::Result<Vec<UptimeMessage>>;
}
pub enum StorageStrategy {
InMemory,
SQLite(String),
}
pub struct StorageRepositoryImpl {
inner: Arc<dyn StorageRepository>
}
impl StorageRepositoryImpl {
pub fn new(strategy: StorageStrategy) -> Self {
let inner: Arc<dyn StorageRepository> = match strategy {
StorageStrategy::InMemory => Arc::new(InMemoryRepository::new()),
StorageStrategy::SQLite(path) => Arc::new(SQLiteRepository::new(path)),
};
Self { inner }
}
pub fn inner(&self) -> Arc<dyn StorageRepository> {
self.inner.clone()
}
}

View File

@@ -0,0 +1,79 @@
use async_trait::async_trait;
use rusqlite::Connection;
use tokio::sync::Mutex;
use std::path::Path;
use chrono::{DateTime, Utc};
use super::{StorageRepository, UptimeMessage, UptimeModel};
pub struct SQLiteRepository {
conn: Mutex<Connection>
}
impl SQLiteRepository {
pub fn new<P: AsRef<Path>>(path: P) -> Self {
let conn = Connection::open(path).unwrap();
conn.pragma_update(None, "journal_mode", "WAL").unwrap();
conn.pragma_update(None, "foreign_keys", "ON").unwrap();
conn.pragma_update(None, "synchronous", "NORMAL").unwrap();
conn.execute_batch(
r#"
CREATE TABLE IF NOT EXISTS agents (
id TEXT PRIMARY KEY,
first_seen TEXT NOT NULL,
last_seen TEXT NOT NULL,
message_count INTEGER NOT NULL DEFAULT 0
) STRICT;
CREATE INDEX IF NOT EXISTS idx_agents_id ON agents(id);
CREATE INDEX IF NOT EXISTS idx_agents_times ON agents(first_seen, last_seen);
"#
).unwrap();
Self {
conn: Mutex::new(conn)
}
}
}
#[async_trait]
impl StorageRepository for SQLiteRepository {
async fn record_message(&self, agent: &str) -> anyhow::Result<()> {
let conn = self.conn.lock().await;
let now = Utc::now().to_rfc3339();
conn.execute(r#"
INSERT INTO agents (id, first_seen, last_seen, message_count)
VALUES (?1, ?2, ?2, 1)
ON CONFLICT (id) DO UPDATE SET
last_seen = excluded.last_seen,
message_count = message_count + 1;
"#, [agent, &now]
)?;
Ok(())
}
async fn get_agents(&self) -> anyhow::Result<Vec<UptimeMessage>> {
let conn = self.conn.lock().await;
let mut stmt = conn.prepare("SELECT id, first_seen, last_seen, message_count FROM agents")?;
let result = stmt.query_map([], |row| {
let first_seen: DateTime<Utc> = row.get::<_, String>(1)?.parse().unwrap();
let last_seen: DateTime<Utc> = row.get::<_, String>(2)?.parse().unwrap();
Ok(UptimeModel {
id: row.get(0)?,
first_seen,
last_seen,
message_count: row.get(3)?,
})
})?;
let models: Result<Vec<UptimeModel>, _> = result.collect();
Ok(models?.into_iter().map(Into::into).collect())
}
}

View File

@@ -1,23 +0,0 @@
use warp::filters::ws::{Message, WebSocket};
use futures_util::{SinkExt, StreamExt};
use tokio::sync::mpsc::channel;
use std::sync::Arc;
use crate::bridge::ClientManager;
pub mod server;
pub async fn handle_ws_connection(websocket: WebSocket, clients: Arc<ClientManager>) {
let (mut ws_tx, _) = websocket.split();
let (tx, mut rx) = channel::<String>(100);
clients.add_client(tx).await;
tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
if ws_tx.send(Message::text(msg)).await.is_err() {
break;
}
}
});
}

View File

@@ -1,27 +0,0 @@
use warp::Filter;
use std::sync::Arc;
use crate::bridge::ClientManager;
use super::handle_ws_connection;
pub struct Websocket {
clients: Arc<ClientManager>,
}
impl Websocket {
pub fn new(clients: Arc<ClientManager>) -> Self {
Self { clients }
}
pub async fn run(self) {
let route = warp::path("ws")
.and(warp::ws())
.and(warp::any().map(move || self.clients.clone()))
.map(|ws: warp::ws::Ws, clients| {
ws.on_upgrade(move |websocket| handle_ws_connection(websocket, clients))
});
println!("starting websocket server on :3000");
warp::serve(route).run(([0, 0, 0, 0], 3000)).await;
}
}