diff --git a/Cargo.lock b/Cargo.lock index f75d381..4280434 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1907,6 +1907,8 @@ dependencies = [ "genai", "human-panic", "irc", + "serde", + "serde_json", "tokio", "tracing", "tracing-subscriber", diff --git a/Cargo.toml b/Cargo.toml index 5364f49..ed0f02d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,9 @@ futures = "0.3" human-panic = "2.0" genai = "0.4.3" irc = "1.1" -tokio = { version = "1", features = [ "macros", "rt-multi-thread" ] } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +tokio = { version = "1", features = [ "io-util", "macros", "net", "rt-multi-thread", "sync" ] } tracing = "0.1" tracing-subscriber = "0.3" diff --git a/rustfmt.toml b/rustfmt.toml index 768987b..a3f43d4 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -3,5 +3,5 @@ style_edition = "2024" comment_width = 100 format_code_in_doc_comments = true imports_granularity = "Crate" -imports_layout = "Vertical" +imports_layout = "HorizontalVertical" wrap_comments = true diff --git a/src/chat.rs b/src/chat.rs index eb9d13b..5bd7239 100644 --- a/src/chat.rs +++ b/src/chat.rs @@ -1,24 +1,10 @@ -use color_eyre::{ - Result, - eyre::{ - OptionExt, - WrapErr, - }, -}; +use color_eyre::{Result, eyre::WrapErr}; // Lots of namespace confusion potential use crate::qna::LLMHandle; use config::Config as MainConfig; use futures::StreamExt; -use irc::client::prelude::{ - Client as IRCClient, - Command, - Config as IRCConfig, -}; -use tracing::{ - Level, - event, - instrument, -}; +use irc::client::prelude::{Client as IRCClient, Command, Config as IRCConfig}; +use tracing::{Level, event, instrument}; #[derive(Debug)] pub struct Chat { @@ -73,7 +59,9 @@ impl Chat { // Make it all one line. msg.retain(|c| c != '\n' && c != '\r'); msg.truncate(500); - client.send_privmsg(&channel, msg).wrap_err("Could not send to {channel}")?; + client + .send_privmsg(&channel, msg) + .wrap_err("Could not send to {channel}")?; } } diff --git a/src/commands.rs b/src/commands.rs index 1c51c7b..79eb091 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -1,8 +1,5 @@ use color_eyre::Result; -use std::path::{ - Path, - PathBuf, -}; +use std::path::{Path, PathBuf}; #[derive(Clone, Debug)] pub struct Root { @@ -16,7 +13,7 @@ impl Root { } } - pub fn run_command(cmd_string: impl AsRef) -> Result<()> { + pub fn run_command(_cmd_string: impl AsRef) -> Result<()> { todo!(); } } diff --git a/src/event.rs b/src/event.rs new file mode 100644 index 0000000..bc7356d --- /dev/null +++ b/src/event.rs @@ -0,0 +1,14 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Deserialize, Serialize)] +pub struct Event { + message: String, +} + +impl Event { + pub fn new(msg: impl Into) -> Self { + Self { + message: msg.into(), + } + } +} diff --git a/src/event_manager.rs b/src/event_manager.rs new file mode 100644 index 0000000..8838eae --- /dev/null +++ b/src/event_manager.rs @@ -0,0 +1,89 @@ +use std::{collections::VecDeque, path::Path, sync::Arc}; + +use color_eyre::Result; +use tokio::{ + io::AsyncWriteExt, + net::{UnixListener, UnixStream}, + sync::{RwLock, broadcast}, +}; +use tracing::error; + +use crate::event::Event; + +// Hard coding for now. Maybe make this a parameter to new. +const EVENT_BUF_MAX: usize = 1000; + +// Manager for communication with plugins. +pub struct EventManager { + announce: broadcast::Sender, // Everything broadcasts here. + events: Arc>>, // Ring buffer. +} + +impl EventManager { + pub fn new() -> Result { + let (announce, _) = broadcast::channel(100); + + Ok(Self { + announce, + events: Arc::new(RwLock::new(VecDeque::::new())), + }) + } + + pub async fn broadcast(&self, event: &Event) -> Result<()> { + let msg = serde_json::to_string(event)? + "\n"; + + let mut events = self.events.write().await; + + if events.len() >= EVENT_BUF_MAX { + events.pop_front(); + } + + events.push_back(msg.clone()); + drop(events); + + let _ = self.announce.send(msg); + + Ok(()) + } + + pub async fn start_listening(self: Arc, path: impl AsRef) { + let listener = UnixListener::bind(path).unwrap(); + + loop { + match listener.accept().await { + Ok((stream, _)) => { + // Spawn a new stream for the plugin. The loop + // runs recursively from there. + let broadcaster = Arc::clone(&self); + tokio::spawn(async move { + // send events. + let _ = broadcaster.send_events(stream).await; + }); + } + Err(e) => error!("Accept error: {e}"), + } + } + } + + async fn send_events(&self, stream: UnixStream) -> Result<()> { + let mut writer = stream; + + // Take care of history. + let events = self.events.read().await; + for event in events.iter() { + writer.write_all(event.as_bytes()).await?; + } + drop(events); + + // Now just broadcast the new events. + let mut rx = self.announce.subscribe(); + while let Ok(event) = rx.recv().await { + if writer.write_all(event.as_bytes()).await.is_err() { + // *click* + break; + } + } + + Ok(()) + } +} diff --git a/src/ipc.rs b/src/ipc.rs new file mode 100644 index 0000000..2a66903 --- /dev/null +++ b/src/ipc.rs @@ -0,0 +1,26 @@ +// Provides an IPC socket to communicate with other processes. + +use std::path::Path; + +use color_eyre::Result; +use tokio::net::UnixListener; + +pub struct IPC { + listener: UnixListener, +} + +impl IPC { + pub fn new(path: impl AsRef) -> Result { + let listener = UnixListener::bind(path)?; + Ok(Self { listener }) + } + + pub async fn run(&self) -> Result<()> { + loop { + match self.listener.accept().await { + Ok((_stream, _addr)) => {} + Err(e) => return Err(e.into()), + } + } + } +} diff --git a/src/main.rs b/src/main.rs index 8adec4d..dd8e281 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,17 +1,14 @@ -use color_eyre::{ - Result, - eyre::WrapErr, -}; +use color_eyre::{Result, eyre::WrapErr}; use human_panic::setup_panic; -use std::os::unix::fs; -use tracing::{ - Level, - info, -}; +use std::{os::unix::fs, sync::Arc}; +use tracing::{Level, info}; use tracing_subscriber::FmtSubscriber; +use crate::event_manager::EventManager; + mod chat; -mod commands; +mod event; +mod event_manager; mod qna; mod setup; @@ -47,18 +44,17 @@ async fn main() -> Result<()> { } // Setup root path for commands. - let cmd_root = if let Ok(command_path) = config.get_string("command-path") { - Some(commands::Root::new(command_path)) - } else { - None - }; + // let cmd_root = if let Ok(command_path) = config.get_string("command-path") { + // Some(commands::Root::new(command_path)) + // } else { + // None + // }; let handle = qna::LLMHandle::new( config.get_string("api-key").wrap_err("API missing.")?, config .get_string("base-url") .wrap_err("base-url missing.")?, - cmd_root, config .get_string("model") .wrap_err("model string missing.")?, @@ -67,9 +63,23 @@ async fn main() -> Result<()> { .unwrap_or_else(|_| DEFAULT_INSTRUCT.to_string()), ) .wrap_err("Couldn't initialize LLM handle.")?; + + let ev_manager = Arc::new(EventManager::new()?); + let ev_manager_clone = Arc::clone(&ev_manager); + ev_manager_clone + .broadcast(&event::Event::new("Starting...")) + .await?; + let mut c = chat::new(&config, &handle).await?; - c.run().await.unwrap(); + tokio::select! { + _ = ev_manager_clone.start_listening("/tmp/robo.sock") => { + // Event listener ended + } + result = c.run() => { + result.unwrap(); + } + } Ok(()) } diff --git a/src/qna.rs b/src/qna.rs index 696af22..a5658eb 100644 --- a/src/qna.rs +++ b/src/qna.rs @@ -1,19 +1,10 @@ -use crate::commands; use color_eyre::Result; use futures::StreamExt; use genai::{ Client, ModelIden, - chat::{ - ChatMessage, - ChatRequest, - ChatStreamEvent, - StreamChunk, - }, - resolver::{ - AuthData, - AuthResolver, - }, + chat::{ChatMessage, ChatRequest, ChatStreamEvent, StreamChunk}, + resolver::{AuthData, AuthResolver}, }; use tracing::info; @@ -23,7 +14,6 @@ use tracing::info; pub struct LLMHandle { chat_request: ChatRequest, client: Client, - cmd_root: Option, model: String, } @@ -31,7 +21,6 @@ impl LLMHandle { pub fn new( api_key: String, _base_url: impl AsRef, - cmd_root: Option, model: impl Into, system_role: String, ) -> Result { @@ -51,7 +40,6 @@ impl LLMHandle { Ok(LLMHandle { client, chat_request, - cmd_root, model: model.into(), }) } diff --git a/src/setup.rs b/src/setup.rs index 3496d3c..63a4374 100644 --- a/src/setup.rs +++ b/src/setup.rs @@ -1,15 +1,9 @@ use clap::Parser; -use color_eyre::{ - Result, - eyre::WrapErr, -}; +use color_eyre::{Result, eyre::WrapErr}; use config::Config; use directories::ProjectDirs; use std::path::PathBuf; -use tracing::{ - info, - instrument, -}; +use tracing::{info, instrument}; // TODO: use [clap(long, short, help_heading = Some(section))] #[derive(Clone, Debug, Parser)]