Compare commits

18 Commits

Author SHA1 Message Date
Micheal Smith
c3b168f86f Added configuration tests. 2025-11-29 16:31:15 -06:00
Micheal Smith
b46a03c13e Added documentation. 2025-11-23 13:40:50 -06:00
Micheal Smith
17b087e618 Implemented external processes as potential plugins. 2025-11-20 04:30:18 -06:00
Micheal Smith
30e2d9a448 Renamed commands/Command to plugin/Plugin. 2025-11-14 07:19:28 -06:00
Micheal Smith
70de039610 Broadcast, and FIFO are currently functional. 2025-11-14 05:06:57 -06:00
8ec4f2860c Implementing fifo channel, and command structure. 2025-11-13 21:07:49 -06:00
Micheal Smith
21d9c3f002 Adding response FIFO. 2025-11-12 06:42:30 -06:00
Micheal Smith
f880795b44 Added an integration test for events. 2025-11-11 00:39:58 -06:00
a158ee385f Moved most main.rs functionality into lib.rs 2025-11-10 22:58:44 -06:00
Micheal Smith
2da7cc4450 Breaking out some portions for integration testing. 2025-11-10 22:20:09 -06:00
Micheal Smith
3af95235e6 Added some tests at least for the broadcast buffering. 2025-11-10 05:26:59 -06:00
Micheal Smith
5d390ee9f3 Adding some IPC. 2025-11-09 08:26:39 -06:00
Micheal Smith
7f7981d6cd Added a hard coded 500 character limit. 2025-11-06 17:28:46 -06:00
Micheal Smith
ae190cc421 Fixed instruct argument issue. 2025-11-03 22:31:13 -06:00
Micheal Smith
ae44cc947b Fixed a misspelling 2025-11-03 22:13:11 -06:00
Micheal Smith
a3ebca0bb2 Added some release optimizations. 2025-10-31 13:55:57 -05:00
Micheal Smith
8fa79932d6 Removed unnecessary use. 2025-10-31 05:36:55 -05:00
Micheal Smith
9719d9203c Fixed multiline handling, and did a cargo fmt. 2025-10-31 05:34:34 -05:00
16 changed files with 2973 additions and 373 deletions

605
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -4,17 +4,58 @@ version = "0.1.0"
edition = "2024"
[dependencies]
# TODO: make this a dev and/or debug dependency later.
better-panic = "0.3.0"
clap = { version = "4.5", features = [ "derive" ] }
bytes = "1"
color-eyre = "0.6.3"
config = { version = "0.15", features = [ "toml" ] }
directories = "6.0"
dotenvy_macro = "0.15"
futures = "0.3"
human-panic = "2.0"
genai = "0.4.3"
irc = "1.1"
tokio = { version = "1", features = [ "full" ] }
serde_json = "1.0"
tracing = "0.1"
tracing-subscriber = "0.3"
[dependencies.nix]
version = "0.30.1"
features = ["fs", "resource"]
[dependencies.clap]
version = "4.5"
features = ["derive"]
[dependencies.config]
version = "0.15"
features = ["toml"]
[dependencies.serde]
version = "1.0"
features = ["derive"]
[dependencies.tokio]
version = "1"
features = [
"io-util",
"macros",
"net",
"process",
"rt-multi-thread",
"sync",
"time",
]
[dev-dependencies]
rstest = "0.24"
serial_test = "3.2"
tempfile = "3.13"
[dev-dependencies.cargo-husky]
version = "1"
features = ["run-cargo-check", "run-cargo-clippy"]
[profile.release]
strip = true
opt-level = "z"
lto = true
codegen-units = 1
panic = "abort"

View File

@@ -2,6 +2,8 @@ edition = "2024"
style_edition = "2024"
comment_width = 100
format_code_in_doc_comments = true
format_macro_bodies = true
format_macro_matchers = true
imports_granularity = "Crate"
imports_layout = "Vertical"
imports_layout = "HorizontalVertical"
wrap_comments = true

View File

@@ -1,40 +1,40 @@
use color_eyre::{
Result,
eyre::{
OptionExt,
WrapErr,
},
};
// Lots of namespace confusion potential
use crate::{
commands,
qna::LLMHandle,
};
//! Handles interaction with IRC.
//!
//! Each instance of [`Chat`] handles a single connection to an IRC
//! server.
use std::sync::Arc;
use color_eyre::{Result, eyre::WrapErr};
use config::Config as MainConfig;
use futures::StreamExt;
use irc::client::prelude::{
Client as IRCClient,
Command,
Config as IRCConfig,
};
use tracing::{
Level,
event,
instrument,
};
use irc::client::prelude::{Client, Command, Config as IRCConfig, Message};
use tokio::sync::mpsc;
use tracing::{Level, event, instrument};
use crate::{Event, EventManager, LLMHandle, plugin};
/// Chat struct that is used to interact with IRC chat.
#[derive(Debug)]
pub struct Chat {
client: IRCClient,
/// The actual IRC [`irc::client`](client).
client: Client,
/// Event manager for handling plugin interaction.
event_manager: Arc<EventManager>,
/// Handle for whichever LLM is being used.
llm_handle: LLMHandle, // FIXME: This needs to be thread safe, and shared, etc.
}
// Need: owners, channels, username, nick, server, password
#[instrument]
pub async fn new(
impl Chat {
// Need: owners, channels, username, nick, server, password rather than reading
// the config values directly.
/// Creates a new [`Chat`].
#[instrument]
pub async fn new(
settings: &MainConfig,
handle: &LLMHandle,
) -> Result<Chat> {
manager: Arc<EventManager>,
) -> Result<Chat> {
// Going to just assign and let the irc library handle errors for now, and
// add my own checking if necessary.
let port: u16 = settings.get("port")?;
@@ -55,37 +55,77 @@ pub async fn new(
event!(Level::INFO, "IRC connection starting...");
Ok(Chat {
client: IRCClient::from_config(config).await?,
client: Client::from_config(config).await?,
llm_handle: handle.clone(),
event_manager: manager,
})
}
impl Chat {
pub async fn run(&mut self) -> Result<()> {
let client = &mut self.client;
client.identify()?;
let outgoing = client
.outgoing()
.ok_or_eyre("Couldn't get outgoing irc sink.")?;
let mut stream = client.stream()?;
tokio::spawn(async move {
if let Err(e) = outgoing.await {
event!(Level::ERROR, "Failed to drive output: {}", e);
}
});
while let Some(message) = stream.next().await.transpose()? {
if let Command::PRIVMSG(channel, message) = message.command
&& message.starts_with("!gem")
{
let msg = self.llm_handle.send_request(message).await?;
event!(Level::INFO, "Message received.");
client
.send_privmsg(channel, msg)
.wrap_err("Couldn't send response to channel.")?;
/// Drives the event loop for the chat.
pub async fn run(&mut self, mut command_in: mpsc::Receiver<plugin::PluginMsg>) -> Result<()> {
self.client.identify()?;
let mut stream = self.client.stream()?;
loop {
tokio::select! {
message = stream.next() => {
match message {
Some(Ok(msg)) => {
self.handle_chat_message(&msg).await?;
}
Some(Err(e)) => return Err(e.into()),
None => break, // disconnected
}
}
command = command_in.recv() => {
event!(Level::INFO, "Received command {:#?}", command);
match command {
Some(plugin::PluginMsg::SendMessage {channel, message} ) => {
// Now to pass on the message.
event!(Level::INFO, "Trying to send to channel.");
self.client.send_privmsg(&channel, &message).wrap_err("Couldn't send to channel")?;
event!(Level::INFO, "Message sent successfully.");
}
None => {
event!(Level::ERROR,
"Command channel unexpectedly closed - \
FIFO reader may have crashed");
break;
}
}
}
}
}
Ok(())
}
async fn handle_chat_message(&mut self, message: &Message) -> Result<()> {
// Broadcast anything here. If it should not be broadcasted then
// TryFrom should fail.
if let Ok(event) = Event::try_from(message) {
self.event_manager.broadcast(&event).await?;
}
// Only handle PRIVMSG for now.
if let Command::PRIVMSG(channel, msg) = &message.command {
// Just preserve the original behavior for now.
if msg.starts_with("!gem") {
let mut llm_response = self.llm_handle.send_request(msg).await?;
event!(Level::INFO, "Asked: {message}");
event!(Level::INFO, "Response: {llm_response}");
// Keep responses to one line for now.
llm_response.retain(|c| c != '\n' && c != '\r');
// TODO: Make this configurable.
llm_response.truncate(500);
event!(Level::INFO, "Sending {llm_response} to channel {channel}");
self.client.send_privmsg(channel, llm_response)?;
}
}

193
src/command.rs Normal file
View File

@@ -0,0 +1,193 @@
//! Commands that are associated with external processes (commands).
//!
//! Process based plugins are just an assortment of executable files in
//! a provided directory. They are given arguments, and the response from
//! them is expected on stdout.
use std::{
path::{Path, PathBuf},
time::Duration,
};
use bytes::Bytes;
use color_eyre::{Result, eyre::eyre};
use tokio::{fs::try_exists, process::Command, time::timeout};
use tracing::{Level, event};
/// Handle containing information about the directory containing commands.
#[derive(Debug)]
pub struct CommandDir {
command_path: PathBuf,
}
impl CommandDir {
/// Register a path containing commands.
pub fn new(command_path: impl AsRef<Path>) -> Self {
event!(
Level::INFO,
"CommandDir initialized with path: {:?}",
command_path.as_ref()
);
CommandDir {
command_path: command_path.as_ref().to_path_buf(),
}
}
/// Look for a command. If it exists Ok(path) is returned.
async fn find_command(&self, name: impl AsRef<Path>) -> Result<String> {
let path = self.command_path.join(name.as_ref());
event!(
Level::INFO,
"Looking for {} command.",
name.as_ref().display()
);
match try_exists(&path).await {
Ok(true) => Ok(path.to_string_lossy().to_string()),
Ok(false) => Err(eyre!(format!("{} Not found.", path.to_string_lossy()))),
Err(e) => Err(e.into()),
}
}
/// Run the given [`command_name`]. It should exist in the directory provided as
/// the command_path.
pub async fn run_command(
&self,
command_name: impl AsRef<str>,
input: impl AsRef<str>,
) -> Result<Bytes> {
let path = self.find_command(Path::new(command_name.as_ref())).await?;
// Well it exists let's cross our fingers...
let output = Command::new(path).arg(input.as_ref()).output().await?;
if output.status.success() {
// So far so good
Ok(Bytes::from(output.stdout))
} else {
// Whoops
Err(eyre!(format!(
"Error running {}: {}",
command_name.as_ref(),
output.status
)))
}
}
/// [`run_command`] but with a timeout.
pub async fn run_command_with_timeout(
&self,
command_name: impl AsRef<str>,
input: impl AsRef<str>,
time_out: Duration,
) -> Result<Bytes> {
timeout(time_out, self.run_command(command_name, input)).await?
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::{
fs::{self, Permissions},
os::unix::fs::PermissionsExt,
};
use tempfile::TempDir;
fn create_test_script(dir: &Path, name: &str, script: &str) -> PathBuf {
let path = dir.join(name);
fs::write(&path, script).unwrap();
fs::set_permissions(&path, Permissions::from_mode(0o755)).unwrap();
path
}
#[test]
fn test_command_dir_new() {
let dir = CommandDir::new("/some/path");
assert_eq!(dir.command_path, PathBuf::from("/some/path"));
}
#[tokio::test]
async fn test_find_command_exists() {
let temp = TempDir::new().unwrap();
create_test_script(temp.path(), "test_cmd", "#!/bin/bash\necho hello");
let cmd_dir = CommandDir::new(temp.path());
let result = cmd_dir.find_command("test_cmd").await;
assert!(result.is_ok());
assert!(result.unwrap().contains("test_cmd"));
}
#[tokio::test]
async fn test_find_command_not_found() {
let temp = TempDir::new().unwrap();
let cmd_dir = CommandDir::new(temp.path());
let result = cmd_dir.find_command("nonexistent").await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Not found"));
}
#[tokio::test]
async fn test_run_command_success() {
let temp = TempDir::new().unwrap();
create_test_script(temp.path(), "echo_cmd", "#!/bin/bash\necho \"$1\"");
let cmd_dir = CommandDir::new(temp.path());
let result = cmd_dir.run_command("echo_cmd", "hello world").await;
assert!(result.is_ok());
let output = result.unwrap();
assert_eq!(output.as_ref(), b"hello world\n");
}
#[tokio::test]
async fn test_run_command_failure() {
let temp = TempDir::new().unwrap();
create_test_script(temp.path(), "fail_cmd", "#!/bin/bash\nexit 1");
let cmd_dir = CommandDir::new(temp.path());
let result = cmd_dir.run_command("fail_cmd", "input").await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Error running"));
}
#[tokio::test]
async fn test_run_command_not_found() {
let temp = TempDir::new().unwrap();
let cmd_dir = CommandDir::new(temp.path());
let result = cmd_dir.run_command("missing", "input").await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_run_command_with_timeout_success() {
let temp = TempDir::new().unwrap();
create_test_script(temp.path(), "fast_cmd", "#!/bin/bash\necho \"$1\"");
let cmd_dir = CommandDir::new(temp.path());
let result = cmd_dir
.run_command_with_timeout("fast_cmd", "quick", Duration::from_secs(5))
.await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_run_command_with_timeout_expires() {
let temp = TempDir::new().unwrap();
create_test_script(temp.path(), "slow_cmd", "#!/bin/bash\nsleep 10\necho done");
let cmd_dir = CommandDir::new(temp.path());
let result = cmd_dir
.run_command_with_timeout("slow_cmd", "input", Duration::from_millis(100))
.await;
assert!(result.is_err());
}
}

View File

@@ -1,21 +0,0 @@
use color_eyre::Result;
use std::{
path::{Path, PathBuf},
};
#[derive(Clone, Debug)]
pub struct Root {
path: PathBuf,
}
impl Root {
pub fn new(path: impl AsRef<Path>) -> Self {
Root {
path: path.as_ref().to_owned(),
}
}
pub fn run_command(cmd_string: impl AsRef<str>) -> Result<()> {
todo!();
}
}

38
src/event.rs Normal file
View File

@@ -0,0 +1,38 @@
//! Internal representations of incoming events.
use irc::proto::{Command, Message};
use serde::{Deserialize, Serialize};
/// Represents an event. Probably from IRC.
#[derive(Deserialize, Serialize)]
pub struct Event {
/// Who is the message from?
from: String,
/// What is the message?
message: String,
}
impl Event {
/// Creates a new message.
pub fn new(from: impl Into<String>, msg: impl Into<String>) -> Self {
Self {
from: from.into(),
message: msg.into(),
}
}
}
impl TryFrom<&Message> for Event {
type Error = &'static str;
fn try_from(value: &Message) -> Result<Self, Self::Error> {
let from = value.response_target().unwrap_or("unknown").to_string();
match &value.command {
Command::PRIVMSG(_channel, message) => Ok(Event {
from,
message: message.clone(),
}),
_ => Err("Not a PRIVMSG"),
}
}
}

587
src/event_manager.rs Normal file
View File

@@ -0,0 +1,587 @@
//! Handler for events to and from IPC, and process plugins.
use std::{collections::VecDeque, path::Path, sync::Arc};
use color_eyre::Result;
use nix::{NixPath, sys::stat, unistd::mkfifo};
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
net::{UnixListener, UnixStream, unix::pipe},
sync::{RwLock, broadcast, mpsc},
};
use tracing::{error, info};
use crate::{event::Event, plugin::PluginMsg};
// Hard coding for now. Maybe make this a parameter to new.
const EVENT_BUF_MAX: usize = 1000;
/// Manager for communication with plugins.
///
/// Keeps events in a ring buffer to track a certain amount of history.
#[derive(Debug)]
pub struct EventManager {
announce: broadcast::Sender<String>, // Everything broadcasts here.
events: Arc<RwLock<VecDeque<String>>>, // Ring buffer.
}
impl EventManager {
/// Create a new [`EventManager``].
pub fn new() -> Result<Self> {
let (announce, _) = broadcast::channel(100);
Ok(Self {
announce,
events: Arc::new(RwLock::new(VecDeque::<String>::new())),
})
}
/// Broadcast an event to every subscribed listener.
pub async fn broadcast(&self, event: &Event) -> Result<()> {
let msg = serde_json::to_string(event)? + "\n";
let mut events = self.events.write().await;
if events.len() >= EVENT_BUF_MAX {
events.pop_front();
}
events.push_back(msg.clone());
drop(events);
let _ = self.announce.send(msg);
Ok(())
}
// NB: This assumes it has exclusive control of the FIFO.
/// Opens a fifo at [`path`]. This is where some plugins can send response events
/// to. The messages MUST be formatted in JSON and match one of the possible
/// [`PluginMsg`](plugin messages).
pub async fn start_fifo<P>(path: &P, command_tx: mpsc::Sender<PluginMsg>) -> Result<()>
where
P: AsRef<Path> + NixPath + ?Sized,
{
// Overwrite, or create the FIFO.
let _ = std::fs::remove_file(path);
mkfifo(path, stat::Mode::S_IRWXU)?;
loop {
let rx = pipe::OpenOptions::new().open_receiver(path)?;
let mut reader = BufReader::new(rx);
let mut line = String::new();
while reader.read_line(&mut line).await? > 0 {
// Now handle the command.
let cmd: PluginMsg = serde_json::from_str(&line)?;
info!("Command received: {:?}.", cmd);
command_tx.send(cmd).await?;
line.clear();
}
}
}
/// Start a UNIX socket that will provide broadcast messages to any client that opens
/// the socket for listening.
pub async fn start_listening(self: Arc<Self>, broadcast_path: impl AsRef<Path>) {
let listener = UnixListener::bind(broadcast_path).unwrap();
loop {
match listener.accept().await {
Ok((stream, _addr)) => {
info!("New broadcast subscriber");
// Spawn a new stream for the plugin. The loop
// runs recursively from there.
let broadcaster = Arc::clone(&self);
tokio::spawn(async move {
// send events.
let _ = broadcaster.send_events(stream).await;
});
}
Err(e) => error!("Accept error: {e}"),
}
}
}
/// Send any events queued up to the [`stream`].
async fn send_events(&self, stream: UnixStream) -> Result<()> {
let mut writer = stream;
// Take care of history.
let events = self.events.read().await;
for event in events.iter() {
writer.write_all(event.as_bytes()).await?;
}
drop(events);
// Now just broadcast the new events.
let mut rx = self.announce.subscribe();
while let Ok(event) = rx.recv().await {
if writer.write_all(event.as_bytes()).await.is_err() {
// *click*
break;
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
#[tokio::test]
async fn test_new_event_manager_has_empty_buffer() {
let manager = EventManager::new().unwrap();
let events = manager.events.read().await;
assert_eq!(events.len(), 0);
}
#[tokio::test]
async fn test_broadcast_adds_event_to_buffer() {
let manager = EventManager::new().unwrap();
let event = Event::new("test_user", "test message");
manager.broadcast(&event).await.unwrap();
let events = manager.events.read().await;
assert_eq!(events.len(), 1);
assert!(events[0].contains("test message"));
assert!(events[0].ends_with('\n'));
}
#[tokio::test]
async fn test_broadcast_serializes_event_as_json() {
let manager = EventManager::new().unwrap();
let event = Event::new("test_user", "hello world");
manager.broadcast(&event).await.unwrap();
let events = manager.events.read().await;
let stored = &events[0];
// Should be valid JSON
let parsed: serde_json::Value = serde_json::from_str(stored.trim()).unwrap();
assert_eq!(parsed["message"], "hello world");
}
#[rstest]
#[case(1)]
#[case(10)]
#[case(100)]
#[case(999)]
#[tokio::test]
async fn test_buffer_holds_events_below_max(#[case] count: usize) {
let manager = EventManager::new().unwrap();
for i in 0..count {
let event = Event::new("test_user", format!("event {}", i));
manager.broadcast(&event).await.unwrap();
}
let events = manager.events.read().await;
assert_eq!(events.len(), count);
}
#[tokio::test]
async fn test_buffer_at_exactly_max_capacity() {
let manager = EventManager::new().unwrap();
// Fill to exactly EVENT_BUF_MAX (1000)
for i in 0..EVENT_BUF_MAX {
let event = Event::new("test_user", format!("event {}", i));
manager.broadcast(&event).await.unwrap();
}
let events = manager.events.read().await;
assert_eq!(events.len(), EVENT_BUF_MAX);
assert!(events[0].contains("event 0"));
assert!(events[EVENT_BUF_MAX - 1].contains("event 999"));
}
#[rstest]
#[case(1)]
#[case(10)]
#[case(100)]
#[case(500)]
#[tokio::test]
async fn test_buffer_overflow_evicts_oldest_fifo(#[case] overflow: usize) {
let manager = EventManager::new().unwrap();
let total = EVENT_BUF_MAX + overflow;
// Broadcast more events than buffer can hold
for i in 0..total {
let event = Event::new("test_user", format!("event {}", i));
manager.broadcast(&event).await.unwrap();
}
let events = manager.events.read().await;
// Buffer should still be at max capacity
assert_eq!(events.len(), EVENT_BUF_MAX);
// Oldest events (0 through overflow-1) should be evicted
// Buffer should contain events [overflow..total)
let first_event = &events[0];
let last_event = &events[EVENT_BUF_MAX - 1];
assert!(first_event.contains(&format!("event {}", overflow)));
assert!(last_event.contains(&format!("event {}", total - 1)));
// Verify the evicted events are NOT in the buffer
let buffer_string = events.iter().cloned().collect::<String>();
assert!(!buffer_string.contains(r#""message":"event 0""#));
}
#[tokio::test]
async fn test_multiple_broadcasts_maintain_order() {
let manager = EventManager::new().unwrap();
let messages = vec!["first", "second", "third", "fourth", "fifth"];
for msg in &messages {
let event = Event::new("test_user", *msg);
manager.broadcast(&event).await.unwrap();
}
let events = manager.events.read().await;
assert_eq!(events.len(), messages.len());
for (i, expected) in messages.iter().enumerate() {
assert!(events[i].contains(expected));
}
}
#[tokio::test]
async fn test_buffer_wraparound_maintains_newest_events() {
let manager = EventManager::new().unwrap();
// Fill buffer completely
for i in 0..EVENT_BUF_MAX {
let event = Event::new("test_user", format!("old {}", i));
manager.broadcast(&event).await.unwrap();
}
// Add 5 more events
for i in 0..5 {
let event = Event::new("test_user", format!("new {}", i));
manager.broadcast(&event).await.unwrap();
}
let events = manager.events.read().await;
assert_eq!(events.len(), EVENT_BUF_MAX);
// First 5 old events should be gone
let buffer_string = events.iter().cloned().collect::<String>();
assert!(!buffer_string.contains(r#""message":"old 0""#));
assert!(!buffer_string.contains(r#""message":"old 4""#));
// But old 5 should still be there (now at the front)
assert!(events[0].contains("old 5"));
// New events should be at the end
assert!(events[EVENT_BUF_MAX - 5].contains("new 0"));
assert!(events[EVENT_BUF_MAX - 1].contains("new 4"));
}
#[tokio::test]
async fn test_concurrent_broadcasts_all_stored() {
let manager = Arc::new(EventManager::new().unwrap());
let mut handles = vec![];
// Spawn 10 concurrent tasks, each broadcasting 10 events
for task_id in 0..10 {
let manager_clone = Arc::clone(&manager);
let handle = tokio::spawn(async move {
for i in 0..10 {
let event = Event::new("test_user", format!("task {} event {}", task_id, i));
manager_clone.broadcast(&event).await.unwrap();
}
});
handles.push(handle);
}
// Wait for all tasks to complete
for handle in handles {
handle.await.unwrap();
}
let events = manager.events.read().await;
assert_eq!(events.len(), 100);
}
#[tokio::test]
async fn test_fifo_receives_and_forwards_single_command() {
let temp_dir = tempfile::tempdir().unwrap();
let fifo_path = temp_dir.path().join("test.fifo");
let (tx, mut rx) = mpsc::channel(10);
// Spawn the FIFO reader
let path = fifo_path.clone();
tokio::spawn(async move {
let _ = EventManager::start_fifo(&path, tx).await;
});
// Give it time to create the FIFO
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
// Write a command to the FIFO
let cmd = PluginMsg::SendMessage {
channel: "#test".to_string(),
message: "hello".to_string(),
};
let json = serde_json::to_string(&cmd).unwrap() + "\n";
// Open FIFO for writing and write the command
tokio::spawn(async move {
use tokio::io::AsyncWriteExt;
let tx = pipe::OpenOptions::new().open_sender(&fifo_path).unwrap();
let mut tx = tokio::io::BufWriter::new(tx);
tx.write_all(json.as_bytes()).await.unwrap();
tx.flush().await.unwrap();
});
// Should receive the command within a reasonable time
let received = tokio::time::timeout(tokio::time::Duration::from_secs(1), rx.recv())
.await
.expect("timeout waiting for command")
.expect("channel closed");
match received {
PluginMsg::SendMessage { channel, message } => {
assert_eq!(channel, "#test");
assert_eq!(message, "hello");
}
}
}
#[tokio::test]
async fn test_fifo_handles_multiple_commands() {
let temp_dir = tempfile::tempdir().unwrap();
let fifo_path = temp_dir.path().join("test.fifo");
let (tx, mut rx) = mpsc::channel(10);
// Spawn the FIFO reader
let path = fifo_path.clone();
tokio::spawn(async move {
let _ = EventManager::start_fifo(&path, tx).await;
});
// Give it time to create the FIFO
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
// Write multiple commands
let commands = vec![
PluginMsg::SendMessage {
channel: "#chan1".to_string(),
message: "first".to_string(),
},
PluginMsg::SendMessage {
channel: "#chan2".to_string(),
message: "second".to_string(),
},
PluginMsg::SendMessage {
channel: "#chan3".to_string(),
message: "third".to_string(),
},
];
tokio::spawn(async move {
use tokio::io::AsyncWriteExt;
let tx = pipe::OpenOptions::new().open_sender(&fifo_path).unwrap();
let mut tx = tokio::io::BufWriter::new(tx);
for cmd in commands {
let json = serde_json::to_string(&cmd).unwrap() + "\n";
tx.write_all(json.as_bytes()).await.unwrap();
}
tx.flush().await.unwrap();
});
// Receive all three commands in order
let first = tokio::time::timeout(tokio::time::Duration::from_secs(1), rx.recv())
.await
.expect("timeout on first")
.expect("channel closed");
match first {
PluginMsg::SendMessage { channel, message } => {
assert_eq!(channel, "#chan1");
assert_eq!(message, "first");
}
}
let second = tokio::time::timeout(tokio::time::Duration::from_secs(1), rx.recv())
.await
.expect("timeout on second")
.expect("channel closed");
match second {
PluginMsg::SendMessage { channel, message } => {
assert_eq!(channel, "#chan2");
assert_eq!(message, "second");
}
}
let third = tokio::time::timeout(tokio::time::Duration::from_secs(1), rx.recv())
.await
.expect("timeout on third")
.expect("channel closed");
match third {
PluginMsg::SendMessage { channel, message } => {
assert_eq!(channel, "#chan3");
assert_eq!(message, "third");
}
}
}
#[tokio::test]
async fn test_fifo_reopens_after_writer_closes() {
let temp_dir = tempfile::tempdir().unwrap();
let fifo_path = temp_dir.path().join("test.fifo");
let (tx, mut rx) = mpsc::channel(10);
// Spawn the FIFO reader
let path = fifo_path.clone();
tokio::spawn(async move {
let _ = EventManager::start_fifo(&path, tx).await;
});
// Give it time to create the FIFO
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
// First writer sends a command and closes
{
use tokio::io::AsyncWriteExt;
let path = fifo_path.clone();
tokio::spawn(async move {
let tx = pipe::OpenOptions::new().open_sender(&path).unwrap();
let mut tx = tokio::io::BufWriter::new(tx);
let cmd = PluginMsg::SendMessage {
channel: "#first".to_string(),
message: "batch1".to_string(),
};
let json = serde_json::to_string(&cmd).unwrap() + "\n";
tx.write_all(json.as_bytes()).await.unwrap();
tx.flush().await.unwrap();
// Writer drops here, closing the FIFO
});
let first = tokio::time::timeout(tokio::time::Duration::from_secs(1), rx.recv())
.await
.expect("timeout on first batch")
.expect("channel closed");
match first {
PluginMsg::SendMessage { channel, message } => {
assert_eq!(channel, "#first");
assert_eq!(message, "batch1");
}
}
}
// Give the FIFO time to reopen
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Second writer opens and sends a command
{
use tokio::io::AsyncWriteExt;
tokio::spawn(async move {
let tx = pipe::OpenOptions::new().open_sender(&fifo_path).unwrap();
let mut tx = tokio::io::BufWriter::new(tx);
let cmd = PluginMsg::SendMessage {
channel: "#second".to_string(),
message: "batch2".to_string(),
};
let json = serde_json::to_string(&cmd).unwrap() + "\n";
tx.write_all(json.as_bytes()).await.unwrap();
tx.flush().await.unwrap();
});
let second = tokio::time::timeout(tokio::time::Duration::from_secs(1), rx.recv())
.await
.expect("timeout on second batch - FIFO may not have reopened")
.expect("channel closed");
match second {
PluginMsg::SendMessage { channel, message } => {
assert_eq!(channel, "#second");
assert_eq!(message, "batch2");
}
}
}
}
#[tokio::test]
async fn test_fifo_handles_empty_lines() {
let temp_dir = tempfile::tempdir().unwrap();
let fifo_path = temp_dir.path().join("test.fifo");
let (tx, mut rx) = mpsc::channel(10);
// Spawn the FIFO reader
let path = fifo_path.clone();
let handle = tokio::spawn(async move { EventManager::start_fifo(&path, tx).await });
// Give it time to create the FIFO
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
// Write command, empty line, whitespace, another command
tokio::spawn(async move {
use tokio::io::AsyncWriteExt;
let tx = pipe::OpenOptions::new().open_sender(&fifo_path).unwrap();
let mut tx = tokio::io::BufWriter::new(tx);
let cmd1 = PluginMsg::SendMessage {
channel: "#test".to_string(),
message: "first".to_string(),
};
let json1 = serde_json::to_string(&cmd1).unwrap() + "\n";
tx.write_all(json1.as_bytes()).await.unwrap();
// Write empty line
tx.write_all(b"\n").await.unwrap();
// Write whitespace line
tx.write_all(b" \n").await.unwrap();
let cmd2 = PluginMsg::SendMessage {
channel: "#test".to_string(),
message: "second".to_string(),
};
let json2 = serde_json::to_string(&cmd2).unwrap() + "\n";
tx.write_all(json2.as_bytes()).await.unwrap();
tx.flush().await.unwrap();
});
// Should receive first command
let first = tokio::time::timeout(tokio::time::Duration::from_millis(500), rx.recv())
.await
.expect("timeout on first")
.expect("channel closed");
match first {
PluginMsg::SendMessage { channel, message } => {
assert_eq!(channel, "#test");
assert_eq!(message, "first");
}
}
// The empty/whitespace lines should cause JSON parse errors
// which will cause start_fifo to error and exit
// So we expect the handle to complete (with an error)
let result = tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
.await
.expect("FIFO task should exit due to parse error");
// The task should have errored
assert!(
result.unwrap().is_err(),
"Expected parse error from empty line"
);
}
}

100
src/lib.rs Normal file
View File

@@ -0,0 +1,100 @@
#![warn(missing_docs)]
#![doc = include_str!("../README.md")]
use std::{os::unix::fs, sync::Arc};
use color_eyre::{Result, eyre::WrapErr};
use human_panic::setup_panic;
use tokio::sync::mpsc;
use tracing::{Level, info};
use tracing_subscriber::FmtSubscriber;
pub mod chat;
pub mod command;
pub mod event;
pub mod event_manager;
pub mod plugin;
pub mod qna;
pub mod setup;
pub use chat::Chat;
pub use event::Event;
pub use event_manager::EventManager;
pub use qna::LLMHandle;
const DEFAULT_INSTRUCT: &str =
"You are a shady, yet helpful IRC bot. You try to give responses that can
be sent in a single IRC response according to the specification. Keep answers to
500 characters or less.";
/// Initialize all logging facilities.
///
/// This should cause a panic if there's a failure.
async fn init_logging() {
better_panic::install();
setup_panic!();
let subscriber = FmtSubscriber::builder()
.with_max_level(Level::TRACE)
.finish();
tracing::subscriber::set_global_default(subscriber).unwrap();
}
/// Sets up and runs the main event loop.
///
/// Should return an error if it's recoverable, but could panic if something
/// is particularly bad.
pub async fn run() -> Result<()> {
init_logging().await;
info!("Starting up.");
let settings = setup::init().await.wrap_err("Failed to initialize.")?;
let config = settings.config;
// NOTE: Doing chroot this way might be impractical.
if let Ok(chroot_path) = config.get_string("chroot-dir") {
info!("Attempting to chroot to {}", chroot_path);
fs::chroot(&chroot_path)
.wrap_err_with(|| format!("Failed setting chroot '{}'", chroot_path))?;
std::env::set_current_dir("/").wrap_err("Couldn't change directory after chroot.")?;
}
let handle = qna::LLMHandle::new(
config.get_string("api-key").wrap_err("API missing.")?,
config
.get_string("base-url")
.wrap_err("base-url missing.")?,
config
.get_string("model")
.wrap_err("model string missing.")?,
config
.get_string("instruct")
.unwrap_or_else(|_| DEFAULT_INSTRUCT.to_string()),
)
.wrap_err("Couldn't initialize LLM handle.")?;
let ev_manager = Arc::new(EventManager::new()?);
let ev_manager_clone = Arc::clone(&ev_manager);
let mut c = Chat::new(&config, &handle, Arc::clone(&ev_manager)).await?;
let (from_plugins, to_chat) = mpsc::channel(100);
tokio::select! {
_ = ev_manager_clone.start_listening("/tmp/robo.sock") => {
// Event listener ended
}
result = c.run(to_chat) => {
if let Err(e) = result {
tracing::error!("Chat run error: {:?}", e);
return Err(e);
}
}
fifo = EventManager::start_fifo("/tmp/robo_in.sock", from_plugins) => {
fifo.wrap_err("FIFO reader failed.")?;
}
}
Ok(())
}

View File

@@ -1,75 +1,6 @@
use color_eyre::{
Result,
eyre::WrapErr,
};
use human_panic::setup_panic;
use std::os::unix::fs;
use tracing::{
Level,
info,
};
use tracing_subscriber::FmtSubscriber;
mod chat;
mod commands;
mod qna;
mod setup;
const DEFAULT_INSTRUCT: &str =
"You are a shady, yet helpful IRC bot. You try to give responses that can
be sent in a single IRC response according to the specification.";
use color_eyre::Result;
#[tokio::main]
async fn main() -> Result<()> {
// Some error sprucing.
better_panic::install();
setup_panic!();
let subscriber = FmtSubscriber::builder()
.with_max_level(Level::TRACE)
.finish();
tracing::subscriber::set_global_default(subscriber)
.wrap_err("Failed to setup trace logging.")?;
info!("Starting");
let settings = setup::init().await.wrap_err("Failed to initialize.")?;
let config = settings.config;
// chroot if applicable.
if let Ok(chroot_path) = config.get_string("chroot-dir") {
info!("Attempting to chroot to {}", chroot_path);
fs::chroot(&chroot_path)
.wrap_err_with(|| format!("Failed setting chroot '{}'", chroot_path))?;
std::env::set_current_dir("/").wrap_err("Couldn't change directory after chroot.")?;
}
// Setup root path for commands.
let cmd_root = if let Ok(command_path) = config.get_string("command-path") {
Some(commands::Root::new(command_path))
} else {
None
};
let handle = qna::LLMHandle::new(
config.get_string("api-key").wrap_err("API missing.")?,
config
.get_string("base-url")
.wrap_err("base-url missing.")?,
cmd_root,
config
.get_string("model")
.wrap_err("model string missing.")?,
config
.get_string("instruct")
.unwrap_or_else(|_| DEFAULT_INSTRUCT.to_string()),
)
.wrap_err("Couldn't initialize LLM handle.")?;
let mut c = chat::new(&config, &handle).await?;
c.run().await.unwrap();
Ok(())
robotnik::run().await
}

37
src/plugin.rs Normal file
View File

@@ -0,0 +1,37 @@
//! Plugin command definitions.
// Dear future me: If you forget the JSON translations in the future you'll
// thank me for the comment overkill.
use std::fmt::Display;
use serde::{Deserialize, Serialize};
/// Message types accepted from plugins.
#[derive(Debug, Deserialize, Serialize)]
pub enum PluginMsg {
/// Plugin message indicating the bot should send a [`message`] to [`channel`].
/// {
/// "SendMessage": {
/// "channel": "channel_name",
/// "message": "your message here"
/// }
///
/// }
SendMessage {
/// The IRC channel to send the [`message`] to.
channel: String,
/// The [`message`] to send.
message: String,
},
}
impl Display for PluginMsg {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::SendMessage { channel, message } => {
write!(f, "[{channel}]: {message}")
}
}
}
}

View File

@@ -1,37 +1,32 @@
use crate::commands;
//! Handles communication with a genai compatible LLM.
use color_eyre::Result;
use futures::StreamExt;
use genai::{
Client,
ModelIden,
chat::{
ChatMessage,
ChatRequest,
ChatStreamEvent,
StreamChunk,
},
resolver::{
AuthData,
AuthResolver,
},
chat::{ChatMessage, ChatRequest, ChatStreamEvent, StreamChunk},
resolver::{AuthData, AuthResolver},
};
use tracing::info;
// NB: Docs are quick and dirty as this might move into a plugin.
// Represents an LLM completion source.
// FIXME: Clone is probably temporary.
/// Struct containing information about the LLM.
#[derive(Clone, Debug)]
pub struct LLMHandle {
chat_request: ChatRequest,
client: Client,
cmd_root: Option<commands::Root>,
model: String,
}
impl LLMHandle {
/// Create a new handle.
pub fn new(
api_key: String,
_base_url: impl AsRef<str>,
cmd_root: Option<commands::Root>,
model: impl Into<String>,
system_role: String,
) -> Result<LLMHandle> {
@@ -51,11 +46,11 @@ impl LLMHandle {
Ok(LLMHandle {
client,
chat_request,
cmd_root,
model: model.into(),
})
}
/// Send a chat message to the LLM with the response being returned as a [`String`].
pub async fn send_request(&mut self, message: impl Into<String>) -> Result<String> {
let mut req = self.chat_request.clone();
let client = self.client.clone();
@@ -70,11 +65,6 @@ impl LLMHandle {
while let Some(Ok(stream_event)) = stream.next().await {
if let ChatStreamEvent::Chunk(StreamChunk { content }) = stream_event {
text.push_str(&content);
} else if let ChatStreamEvent::End(end) = stream_event {
let texts = end.captured_texts().unwrap();
for text in texts.into_iter() {
info!("An answer: {}", text);
}
}
}

View File

@@ -1,86 +1,101 @@
//! Handles configuration for the bot.
//!
//! Both command line, and configuration file options are handled here.
use clap::Parser;
use color_eyre::{
Result,
eyre::WrapErr,
};
use color_eyre::{Result, eyre::WrapErr};
use config::Config;
use directories::ProjectDirs;
use std::path::PathBuf;
use tracing::{
info,
instrument,
};
use tracing::{info, instrument};
// TODO: use [clap(long, short, help_heading = Some(section))]
/// Struct of potential arguments.
#[derive(Clone, Debug, Parser)]
#[command(about, version)]
pub(crate) struct Args {
pub struct Args {
#[arg(short, long)]
/// API Key for the LLM in use.
pub(crate) api_key: Option<String>,
pub api_key: Option<String>,
#[arg(short, long, default_value = "https://api.openai.com")]
/// Base URL for the LLM API to use.
pub(crate) base_url: Option<String>,
pub base_url: Option<String>,
/// Directory to use for chroot (recommended).
#[arg(long)]
pub(crate) chroot_dir: Option<String>,
pub chroot_dir: Option<String>,
/// Root directory for file based command structure.
#[arg(long)]
pub(crate) command_dir: Option<String>,
pub command_dir: Option<String>,
#[arg(long)]
/// Instructions to the model on how to behave.
pub(crate) intruct: Option<String>,
pub instruct: Option<String>,
#[arg(long)]
pub(crate) model: Option<String>,
/// Name of the model to use. E.g. 'deepseek-chat'
pub model: Option<String>,
#[arg(long)]
/// List of IRC channels to join.
pub(crate) channels: Option<Vec<String>>,
pub channels: Option<Vec<String>>,
#[arg(short, long)]
/// Custom configuration file location if need be.
pub(crate) config_file: Option<PathBuf>,
pub config_file: Option<PathBuf>,
#[arg(short, long, default_value = "irc.libera.chat")]
/// IRC server.
pub(crate) server: Option<String>,
pub server: Option<String>,
#[arg(short, long, default_value = "6697")]
/// Port of the IRC server.
pub(crate) port: Option<String>,
pub port: Option<String>,
#[arg(long)]
/// IRC Nickname.
pub(crate) nickname: Option<String>,
pub nickname: Option<String>,
#[arg(long)]
/// IRC Nick Password
pub(crate) nick_password: Option<String>,
pub nick_password: Option<String>,
#[arg(long)]
/// IRC Username
pub(crate) username: Option<String>,
pub username: Option<String>,
#[arg(long)]
#[arg(long = "no-tls")]
/// Whether or not to use TLS when connecting to the IRC server.
pub(crate) use_tls: Option<bool>,
pub use_tls: Option<bool>,
}
pub(crate) struct Setup {
pub(crate) config: Config,
/// Handle for interacting with the bot configuration.
pub struct Setup {
/// Handle for the configuration file options.
pub config: Config,
}
#[instrument]
/// Initialize a new [`Setup`] instance.
///
/// This reads the settings file which becomes the bot's default configuration.
/// These settings shall be overridden by any command line options.
pub async fn init() -> Result<Setup> {
// Get arguments. These overrule configuration file, and environment
// variables if applicable.
let args = Args::parse();
let settings = make_config(args)?;
Ok(Setup { config: settings })
}
/// Create a configuration object from arguments.
///
/// This is exposed for testing purposes.
pub fn make_config(args: Args) -> Result<Config> {
// Use default config location unless specified.
let config_location: PathBuf = if let Some(ref path) = args.config_file {
path.to_owned()
@@ -94,7 +109,7 @@ pub async fn init() -> Result<Setup> {
info!("Starting.");
let settings = Config::builder()
Config::builder()
.add_source(config::File::with_name(&config_location.to_string_lossy()).required(false))
.add_source(config::Environment::with_prefix("BOT"))
// Doing all of these overrides provides a unified access point for options,
@@ -104,15 +119,14 @@ pub async fn init() -> Result<Setup> {
.set_override_option("chroot-dir", args.chroot_dir.clone())?
.set_override_option("command-path", args.command_dir.clone())?
.set_override_option("model", args.model.clone())?
.set_override_option("instruct", args.model.clone())?
.set_override_option("nick-password", args.nick_password.clone())?
.set_override_option("instruct", args.instruct.clone())?
.set_override_option("channels", args.channels.clone())?
.set_override_option("server", args.server.clone())?
.set_override_option("port", args.port.clone())? // FIXME: Make this a default here not in clap.
.set_override_option("port", args.port.clone())?
.set_override_option("nickname", args.nickname.clone())?
.set_override_option("username", args.username.clone())?
.set_override_option("use_tls", args.use_tls)?
.set_override_option("use-tls", args.use_tls)?
.build()
.wrap_err("Couldn't read configuration settings.")?;
Ok(Setup { config: settings })
.wrap_err("Couldn't read configuration settings.")
}

290
tests/command_test.rs Normal file
View File

@@ -0,0 +1,290 @@
use std::{
fs::{self, Permissions},
os::unix::fs::PermissionsExt,
path::Path,
time::Duration,
};
use robotnik::command::CommandDir;
use tempfile::TempDir;
/// Helper to create executable test scripts
fn create_command(dir: &Path, name: &str, script: &str) {
let path = dir.join(name);
fs::write(&path, script).unwrap();
fs::set_permissions(&path, Permissions::from_mode(0o755)).unwrap();
}
/// Parse a bot message like "!weather 07008" into (command_name, argument)
fn parse_bot_message(message: &str) -> Option<(&str, &str)> {
if !message.starts_with('!') {
return None;
}
let without_prefix = &message[1..];
let mut parts = without_prefix.splitn(2, ' ');
let command = parts.next()?;
let arg = parts.next().unwrap_or("");
Some((command, arg))
}
#[tokio::test]
async fn test_bot_message_finds_and_runs_command() {
let temp = TempDir::new().unwrap();
// Create a weather command that echoes the zip code
create_command(
temp.path(),
"weather",
r#"#!/bin/bash
echo "Weather for $1: Sunny, 72°F"
"#,
);
let cmd_dir = CommandDir::new(temp.path());
let message = "!weather 10096";
// Parse the message
let (command_name, arg) = parse_bot_message(message).unwrap();
assert_eq!(command_name, "weather");
assert_eq!(arg, "10096");
// Find and run the command
let result = cmd_dir.run_command(command_name, arg).await;
assert!(result.is_ok());
let bytes = result.unwrap();
let output = String::from_utf8_lossy(&bytes);
assert!(output.contains("Weather for 10096"));
assert!(output.contains("Sunny"));
}
#[tokio::test]
async fn test_bot_message_command_not_found() {
let temp = TempDir::new().unwrap();
let cmd_dir = CommandDir::new(temp.path());
let message = "!nonexistent arg";
let (command_name, arg) = parse_bot_message(message).unwrap();
let result = cmd_dir.run_command(command_name, arg).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Not found"));
}
#[tokio::test]
async fn test_bot_message_with_multiple_arguments() {
let temp = TempDir::new().unwrap();
// Create a command that handles multiple words as a single argument
create_command(
temp.path(),
"echo",
r#"#!/bin/bash
echo "You said: $1"
"#,
);
let cmd_dir = CommandDir::new(temp.path());
let message = "!echo hello world how are you";
let (command_name, arg) = parse_bot_message(message).unwrap();
assert_eq!(command_name, "echo");
assert_eq!(arg, "hello world how are you");
let result = cmd_dir.run_command(command_name, arg).await;
assert!(result.is_ok());
let bytes = result.unwrap();
let output = String::from_utf8_lossy(&bytes);
assert!(output.contains("hello world how are you"));
}
#[tokio::test]
async fn test_bot_message_without_argument() {
let temp = TempDir::new().unwrap();
create_command(
temp.path(),
"help",
r#"#!/bin/bash
echo "Available commands: weather, echo, help"
"#,
);
let cmd_dir = CommandDir::new(temp.path());
let message = "!help";
let (command_name, arg) = parse_bot_message(message).unwrap();
assert_eq!(command_name, "help");
assert_eq!(arg, "");
let result = cmd_dir.run_command(command_name, arg).await;
assert!(result.is_ok());
let bytes = result.unwrap();
let output = String::from_utf8_lossy(&bytes);
assert!(output.contains("Available commands"));
}
#[tokio::test]
async fn test_bot_message_command_returns_error_exit_code() {
let temp = TempDir::new().unwrap();
// Create a command that fails for invalid input
create_command(
temp.path(),
"validate",
r#"#!/bin/bash
if [ -z "$1" ]; then
echo "Error: Input required" >&2
exit 1
fi
echo "Valid: $1"
"#,
);
let cmd_dir = CommandDir::new(temp.path());
let message = "!validate";
let (command_name, arg) = parse_bot_message(message).unwrap();
let result = cmd_dir.run_command(command_name, arg).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Error running"));
}
#[tokio::test]
async fn test_bot_message_with_timeout() {
let temp = TempDir::new().unwrap();
create_command(
temp.path(),
"quick",
r#"#!/bin/bash
echo "Result: $1"
"#,
);
let cmd_dir = CommandDir::new(temp.path());
let message = "!quick test";
let (command_name, arg) = parse_bot_message(message).unwrap();
let result = cmd_dir
.run_command_with_timeout(command_name, arg, Duration::from_secs(5))
.await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_bot_message_command_times_out() {
let temp = TempDir::new().unwrap();
create_command(
temp.path(),
"slow",
r#"#!/bin/bash
sleep 10
echo "Done"
"#,
);
let cmd_dir = CommandDir::new(temp.path());
let message = "!slow arg";
let (command_name, arg) = parse_bot_message(message).unwrap();
let result = cmd_dir
.run_command_with_timeout(command_name, arg, Duration::from_millis(100))
.await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_multiple_commands_in_directory() {
let temp = TempDir::new().unwrap();
create_command(
temp.path(),
"weather",
r#"#!/bin/bash
echo "Weather: Sunny"
"#,
);
create_command(
temp.path(),
"time",
r#"#!/bin/bash
echo "Time: 12:00"
"#,
);
create_command(
temp.path(),
"joke",
r#"#!/bin/bash
echo "Why did the robot go on vacation? To recharge!"
"#,
);
let cmd_dir = CommandDir::new(temp.path());
// Test each command
let messages = ["!weather", "!time", "!joke"];
let expected = ["Sunny", "12:00", "recharge"];
for (message, expect) in messages.iter().zip(expected.iter()) {
let (command_name, arg) = parse_bot_message(message).unwrap();
let result = cmd_dir.run_command(command_name, arg).await;
assert!(result.is_ok());
let bytes = result.unwrap();
let output = String::from_utf8_lossy(&bytes);
assert!(
output.contains(expect),
"Expected '{}' in '{}'",
expect,
output
);
}
}
#[tokio::test]
async fn test_non_bot_message_ignored() {
// Messages not starting with ! should be ignored
let messages = ["hello world", "weather 10096", "?help", "/command", ""];
for message in messages {
assert!(
parse_bot_message(message).is_none(),
"Should ignore: {}",
message
);
}
}
#[tokio::test]
async fn test_command_output_is_bytes() {
let temp = TempDir::new().unwrap();
// Create a command that outputs binary-safe content
create_command(
temp.path(),
"binary",
r#"#!/bin/bash
printf "Hello\x00World"
"#,
);
let cmd_dir = CommandDir::new(temp.path());
let message = "!binary test";
let (command_name, arg) = parse_bot_message(message).unwrap();
let result = cmd_dir.run_command(command_name, arg).await;
assert!(result.is_ok());
let output = result.unwrap();
// Should preserve the null byte
assert_eq!(&output[..], b"Hello\x00World");
}

495
tests/event_test.rs Normal file
View File

@@ -0,0 +1,495 @@
use std::{sync::Arc, time::Duration};
use robotnik::{event::Event, event_manager::EventManager};
use rstest::rstest;
use tokio::{
io::{AsyncBufReadExt, BufReader},
net::UnixStream,
time::timeout,
};
const TEST_SOCKET_BASE: &str = "/tmp/robotnik_test";
/// Helper to create unique socket paths for parallel tests
fn test_socket_path(name: &str) -> String {
format!("{}_{}_{}", TEST_SOCKET_BASE, name, std::process::id())
}
/// Helper to read one JSON event from a stream
async fn read_event(
reader: &mut BufReader<UnixStream>,
) -> Result<Event, Box<dyn std::error::Error>> {
let mut line = String::new();
reader.read_line(&mut line).await?;
let event: Event = serde_json::from_str(&line)?;
Ok(event)
}
/// Helper to read all available events with a timeout
async fn read_events_with_timeout(
reader: &mut BufReader<UnixStream>,
max_count: usize,
timeout_ms: u64,
) -> Vec<String> {
let mut events = Vec::new();
for _ in 0..max_count {
let mut line = String::new();
match timeout(
Duration::from_millis(timeout_ms),
reader.read_line(&mut line),
)
.await
{
Ok(Ok(0)) => break, // EOF
Ok(Ok(_)) => events.push(line),
Ok(Err(_)) => break, // Read error
Err(_) => break, // Timeout
}
}
events
}
#[tokio::test]
async fn test_client_connects_and_receives_event() {
let socket_path = test_socket_path("basic_connect");
let manager = Arc::new(EventManager::new().unwrap());
// Start the listener
let listener_manager = Arc::clone(&manager);
let socket_path_clone = socket_path.clone();
tokio::spawn(async move {
listener_manager.start_listening(socket_path_clone).await;
});
// Give the listener time to start
tokio::time::sleep(Duration::from_millis(50)).await;
// Broadcast an event
let event = Event::new("test_user", "test message");
manager.broadcast(&event).await.unwrap();
// Connect as a client
let stream = UnixStream::connect(&socket_path).await.unwrap();
let mut reader = BufReader::new(stream);
// Read the event
let mut line = String::new();
reader.read_line(&mut line).await.unwrap();
assert!(line.contains("test message"));
assert!(line.ends_with('\n'));
// Cleanup
let _ = std::fs::remove_file(&socket_path);
}
#[tokio::test]
async fn test_client_receives_event_history() {
let socket_path = test_socket_path("event_history");
let manager = Arc::new(EventManager::new().unwrap());
// Broadcast events BEFORE starting the listener
for i in 0..5 {
let event = Event::new("test_user", format!("historical event {}", i));
manager.broadcast(&event).await.unwrap();
}
// Start the listener
let listener_manager = Arc::clone(&manager);
let socket_path_clone = socket_path.clone();
tokio::spawn(async move {
listener_manager.start_listening(socket_path_clone).await;
});
tokio::time::sleep(Duration::from_millis(50)).await;
// Connect as a client
let stream = UnixStream::connect(&socket_path).await.unwrap();
let mut reader = BufReader::new(stream);
// Should receive all 5 historical events
let events = read_events_with_timeout(&mut reader, 5, 100).await;
assert_eq!(events.len(), 5);
assert!(events[0].contains("historical event 0"));
assert!(events[4].contains("historical event 4"));
// Cleanup
let _ = std::fs::remove_file(&socket_path);
}
#[tokio::test]
async fn test_multiple_clients_receive_same_events() {
let socket_path = test_socket_path("multiple_clients");
let manager = Arc::new(EventManager::new().unwrap());
// Start the listener
let listener_manager = Arc::clone(&manager);
let socket_path_clone = socket_path.clone();
tokio::spawn(async move {
listener_manager.start_listening(socket_path_clone).await;
});
tokio::time::sleep(Duration::from_millis(50)).await;
// Connect 3 clients
let stream1 = UnixStream::connect(&socket_path).await.unwrap();
let stream2 = UnixStream::connect(&socket_path).await.unwrap();
let stream3 = UnixStream::connect(&socket_path).await.unwrap();
let mut reader1 = BufReader::new(stream1);
let mut reader2 = BufReader::new(stream2);
let mut reader3 = BufReader::new(stream3);
// Broadcast a new event
let event = Event::new("test_user", "broadcast to all");
manager.broadcast(&event).await.unwrap();
// All clients should receive the event
let mut line1 = String::new();
let mut line2 = String::new();
let mut line3 = String::new();
timeout(Duration::from_millis(100), reader1.read_line(&mut line1))
.await
.unwrap()
.unwrap();
timeout(Duration::from_millis(100), reader2.read_line(&mut line2))
.await
.unwrap()
.unwrap();
timeout(Duration::from_millis(100), reader3.read_line(&mut line3))
.await
.unwrap()
.unwrap();
assert!(line1.contains("broadcast to all"));
assert!(line2.contains("broadcast to all"));
assert!(line3.contains("broadcast to all"));
// Cleanup
let _ = std::fs::remove_file(&socket_path);
}
#[tokio::test]
async fn test_late_joiner_receives_full_history() {
let socket_path = test_socket_path("late_joiner");
let manager = Arc::new(EventManager::new().unwrap());
// Start the listener
let listener_manager = Arc::clone(&manager);
let socket_path_clone = socket_path.clone();
tokio::spawn(async move {
listener_manager.start_listening(socket_path_clone).await;
});
tokio::time::sleep(Duration::from_millis(50)).await;
// First client connects
let stream1 = UnixStream::connect(&socket_path).await.unwrap();
let mut reader1 = BufReader::new(stream1);
// Broadcast several events
for i in 0..10 {
let event = Event::new("test_user", format!("event {}", i));
manager.broadcast(&event).await.unwrap();
}
// Consume events from first client
let _ = read_events_with_timeout(&mut reader1, 10, 100).await;
// Late joiner connects
let stream2 = UnixStream::connect(&socket_path).await.unwrap();
let mut reader2 = BufReader::new(stream2);
// Late joiner should receive all 10 events from history
let events = read_events_with_timeout(&mut reader2, 10, 100).await;
assert_eq!(events.len(), 10);
assert!(events[0].contains("event 0"));
assert!(events[9].contains("event 9"));
// Cleanup
let _ = std::fs::remove_file(&socket_path);
}
#[tokio::test]
async fn test_client_receives_events_in_order() {
let socket_path = test_socket_path("event_order");
let manager = Arc::new(EventManager::new().unwrap());
// Start the listener
let listener_manager = Arc::clone(&manager);
let socket_path_clone = socket_path.clone();
tokio::spawn(async move {
listener_manager.start_listening(socket_path_clone).await;
});
tokio::time::sleep(Duration::from_millis(50)).await;
// Connect client
let stream = UnixStream::connect(&socket_path).await.unwrap();
let mut reader = BufReader::new(stream);
// Broadcast events rapidly
let count = 50;
for i in 0..count {
let event = Event::new("test_user", format!("sequence {}", i));
manager.broadcast(&event).await.unwrap();
}
// Read all events
let events = read_events_with_timeout(&mut reader, count, 500).await;
assert_eq!(events.len(), count);
// Verify order
for (i, event) in events.iter().enumerate() {
assert!(
event.contains(&format!("sequence {}", i)),
"Event {} out of order: {}",
i,
event
);
}
// Cleanup
let _ = std::fs::remove_file(&socket_path);
}
#[tokio::test]
async fn test_concurrent_broadcasts_during_client_connections() {
let socket_path = test_socket_path("concurrent_ops");
let manager = Arc::new(EventManager::new().unwrap());
// Start the listener
let listener_manager = Arc::clone(&manager);
let socket_path_clone = socket_path.clone();
tokio::spawn(async move {
listener_manager.start_listening(socket_path_clone).await;
});
tokio::time::sleep(Duration::from_millis(50)).await;
// Connect client 1 BEFORE any broadcasts
let stream1 = UnixStream::connect(&socket_path).await.unwrap();
let mut reader1 = BufReader::new(stream1);
// Spawn a task that continuously broadcasts
let broadcast_manager = Arc::clone(&manager);
let broadcast_handle = tokio::spawn(async move {
for i in 0..100 {
let event = Event::new("test_user", format!("concurrent event {}", i));
broadcast_manager.broadcast(&event).await.unwrap();
tokio::time::sleep(Duration::from_millis(5)).await;
}
});
// While broadcasting, connect more clients at different times
tokio::time::sleep(Duration::from_millis(100)).await;
let stream2 = UnixStream::connect(&socket_path).await.unwrap();
let mut reader2 = BufReader::new(stream2);
tokio::time::sleep(Duration::from_millis(150)).await;
let stream3 = UnixStream::connect(&socket_path).await.unwrap();
let mut reader3 = BufReader::new(stream3);
// Wait for broadcasts to complete
broadcast_handle.await.unwrap();
// All clients should have received events
let events1 = read_events_with_timeout(&mut reader1, 100, 200).await;
let events2 = read_events_with_timeout(&mut reader2, 100, 200).await;
let events3 = read_events_with_timeout(&mut reader3, 100, 200).await;
// Client 1 connected first (before any broadcasts), should get all 100
assert_eq!(events1.len(), 100);
// Client 2 connected after ~20 events were broadcast
// Gets ~20 from history + ~80 live = 100
assert_eq!(events2.len(), 100);
// Client 3 connected after ~50 events were broadcast
// Gets ~50 from history + ~50 live = 100
assert_eq!(events3.len(), 100);
// Verify they all received events in order
assert!(events1[0].contains("concurrent event 0"));
assert!(events2[0].contains("concurrent event 0"));
assert!(events3[0].contains("concurrent event 0"));
// Cleanup
let _ = std::fs::remove_file(&socket_path);
}
#[tokio::test]
async fn test_buffer_overflow_affects_new_clients() {
let socket_path = test_socket_path("buffer_overflow");
let manager = Arc::new(EventManager::new().unwrap());
// Broadcast more than buffer max (1000)
for i in 0..1100 {
let event = Event::new("test_user", format!("overflow event {}", i));
manager.broadcast(&event).await.unwrap();
}
// Start the listener
let listener_manager = Arc::clone(&manager);
let socket_path_clone = socket_path.clone();
tokio::spawn(async move {
listener_manager.start_listening(socket_path_clone).await;
});
tokio::time::sleep(Duration::from_millis(50)).await;
// New client connects
let stream = UnixStream::connect(&socket_path).await.unwrap();
let mut reader = BufReader::new(stream);
// Should receive exactly 1000 events (buffer max)
let events = read_events_with_timeout(&mut reader, 1100, 500).await;
assert_eq!(events.len(), 1000);
// First event should be 100 (oldest 100 were evicted)
assert!(events[0].contains("overflow event 100"));
// Last event should be 1099
assert!(events[999].contains("overflow event 1099"));
// Cleanup
let _ = std::fs::remove_file(&socket_path);
}
#[rstest]
#[case(10, 1)]
#[case(50, 5)]
#[tokio::test]
async fn test_client_count_scaling(#[case] num_clients: usize, #[case] events_per_client: usize) {
let socket_path = test_socket_path(&format!("scaling_{}_{}", num_clients, events_per_client));
let manager = Arc::new(EventManager::new().unwrap());
// Start the listener
let listener_manager = Arc::clone(&manager);
let socket_path_clone = socket_path.clone();
tokio::spawn(async move {
listener_manager.start_listening(socket_path_clone).await;
});
tokio::time::sleep(Duration::from_millis(50)).await;
// Connect many clients
let mut readers = Vec::new();
for _ in 0..num_clients {
let stream = UnixStream::connect(&socket_path).await.unwrap();
readers.push(BufReader::new(stream));
}
// Broadcast events
for i in 0..events_per_client {
let event = Event::new("test_user", format!("scale event {}", i));
manager.broadcast(&event).await.unwrap();
}
// Verify all clients received all events
for reader in &mut readers {
let events = read_events_with_timeout(reader, events_per_client, 200).await;
assert_eq!(events.len(), events_per_client);
}
// Cleanup
let _ = std::fs::remove_file(&socket_path);
}
#[tokio::test]
async fn test_client_disconnect_doesnt_affect_others() {
let socket_path = test_socket_path("disconnect");
let manager = Arc::new(EventManager::new().unwrap());
// Start the listener
let listener_manager = Arc::clone(&manager);
let socket_path_clone = socket_path.clone();
tokio::spawn(async move {
listener_manager.start_listening(socket_path_clone).await;
});
tokio::time::sleep(Duration::from_millis(50)).await;
// Connect 3 clients
let stream1 = UnixStream::connect(&socket_path).await.unwrap();
let stream2 = UnixStream::connect(&socket_path).await.unwrap();
let stream3 = UnixStream::connect(&socket_path).await.unwrap();
let mut reader1 = BufReader::new(stream1);
let mut reader2 = BufReader::new(stream2);
let mut reader3 = BufReader::new(stream3);
// Broadcast initial event
manager
.broadcast(&Event::new("test_user", "before disconnect"))
.await
.unwrap();
// All receive it
let _ = read_events_with_timeout(&mut reader1, 1, 100).await;
let _ = read_events_with_timeout(&mut reader2, 1, 100).await;
let _ = read_events_with_timeout(&mut reader3, 1, 100).await;
// Drop client 2 (simulates disconnect)
drop(reader2);
// Broadcast another event
manager
.broadcast(&Event::new("test_user", "after disconnect"))
.await
.unwrap();
// Clients 1 and 3 should still receive it
let events1 = read_events_with_timeout(&mut reader1, 1, 100).await;
let events3 = read_events_with_timeout(&mut reader3, 1, 100).await;
assert_eq!(events1.len(), 1);
assert_eq!(events3.len(), 1);
assert!(events1[0].contains("after disconnect"));
assert!(events3[0].contains("after disconnect"));
// Cleanup
let _ = std::fs::remove_file(&socket_path);
}
#[tokio::test]
async fn test_json_deserialization_of_received_events() {
let socket_path = test_socket_path("json_deser");
let manager = Arc::new(EventManager::new().unwrap());
// Start the listener
let listener_manager = Arc::clone(&manager);
let socket_path_clone = socket_path.clone();
tokio::spawn(async move {
listener_manager.start_listening(socket_path_clone).await;
});
tokio::time::sleep(Duration::from_millis(50)).await;
// Broadcast an event with special characters
let test_message = "special chars: @#$% newline\\n tab\\t quotes \"test\"";
manager
.broadcast(&Event::new("test_user", test_message))
.await
.unwrap();
// Connect and deserialize
let stream = UnixStream::connect(&socket_path).await.unwrap();
let mut reader = BufReader::new(stream);
let mut line = String::new();
reader.read_line(&mut line).await.unwrap();
// Should be valid JSON
let parsed: serde_json::Value = serde_json::from_str(line.trim()).unwrap();
assert_eq!(parsed["message"], test_message);
// Cleanup
let _ = std::fs::remove_file(&socket_path);
}

556
tests/setup_test.rs Normal file
View File

@@ -0,0 +1,556 @@
use robotnik::setup::{Args, make_config};
use serial_test::serial;
use std::{fs, path::PathBuf};
use tempfile::TempDir;
/// Helper to create a temporary config file
fn create_config_file(dir: &TempDir, content: &str) -> PathBuf {
let config_path = dir.path().join("config.toml");
fs::write(&config_path, content).unwrap();
config_path
}
/// Helper to parse config using environment and config file
async fn parse_config_from_file(config_path: &PathBuf) -> config::Config {
config::Config::builder()
.add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
.build()
.unwrap()
}
#[tokio::test]
#[serial]
async fn test_setup_make_config_overrides() {
let temp = TempDir::new().unwrap();
let config_content = "\
api-key = \"file-key\"
model = \"file-model\"
port = 6667
";
let config_path = create_config_file(&temp, config_content);
// Construct Args with overrides
let args = Args {
api_key: Some("cli-key".to_string()),
base_url: None, /* Should fail if required and not in file/env? No, base-url is optional
* in args */
chroot_dir: None,
command_dir: None,
instruct: None,
model: None, // Should fallback to file
channels: None,
config_file: Some(config_path),
server: None, // Should use default or file? Args has default "irc.libera.chat"
port: Some("9999".to_string()),
nickname: None,
nick_password: None,
username: None,
use_tls: None,
};
let config = make_config(args).expect("Failed to make config");
// Check overrides
assert_eq!(config.get_string("api-key").unwrap(), "cli-key");
assert_eq!(config.get_string("port").unwrap(), "9999");
assert_eq!(config.get_int("port").unwrap(), 9999);
// Check fallback to file
assert_eq!(config.get_string("model").unwrap(), "file-model");
}
#[tokio::test]
async fn test_config_file_loads_all_settings() {
let temp = TempDir::new().unwrap();
let config_content = "\
api-key = \"test-api-key-123\"
base-url = \"https://api.test.com\"
chroot-dir = \"/test/chroot\"
command-path = \"/test/commands\"
model = \"test-model\"
instruct = \"Test instructions\"
server = \"test.irc.server\"
port = 6667
channels = [\"#test1\", \"#test2\"]
username = \"testuser\"
nickname = \"testnick\"
use-tls = false
";
let config_path = create_config_file(&temp, config_content);
let config = parse_config_from_file(&config_path).await;
// Verify all settings are loaded correctly
assert_eq!(config.get_string("api-key").unwrap(), "test-api-key-123");
assert_eq!(
config.get_string("base-url").unwrap(),
"https://api.test.com"
);
assert_eq!(config.get_string("chroot-dir").unwrap(), "/test/chroot");
assert_eq!(config.get_string("command-path").unwrap(), "/test/commands");
assert_eq!(config.get_string("model").unwrap(), "test-model");
assert_eq!(config.get_string("instruct").unwrap(), "Test instructions");
assert_eq!(config.get_string("server").unwrap(), "test.irc.server");
assert_eq!(config.get_int("port").unwrap(), 6667);
let channels: Vec<String> = config.get("channels").unwrap();
assert_eq!(channels, vec!["#test1", "#test2"]);
assert_eq!(config.get_string("username").unwrap(), "testuser");
assert_eq!(config.get_string("nickname").unwrap(), "testnick");
assert_eq!(config.get_bool("use-tls").unwrap(), false);
}
#[tokio::test]
async fn test_config_file_partial_settings() {
let temp = TempDir::new().unwrap();
// Only provide required settings
let config_content = "\
api-key = \"minimal-key\"
base-url = \"https://minimal.api.com\"
model = \"minimal-model\"
server = \"minimal.server\"
port = 6697
channels = [\"#minimal\"]
";
let config_path = create_config_file(&temp, config_content);
let config = parse_config_from_file(&config_path).await;
// Verify required settings are loaded
assert_eq!(config.get_string("api-key").unwrap(), "minimal-key");
assert_eq!(
config.get_string("base-url").unwrap(),
"https://minimal.api.com"
);
assert_eq!(config.get_string("model").unwrap(), "minimal-model");
// Verify optional settings are not present
assert!(config.get_string("chroot-dir").is_err());
assert!(config.get_string("instruct").is_err());
assert!(config.get_string("username").is_err());
}
#[tokio::test]
#[serial]
async fn test_config_with_environment_variables() {
// NOTE: This test documents a limitation in setup.rs
// setup.rs uses Environment::with_prefix("BOT") without a separator
// This means BOT_API_KEY maps to "api_key", NOT "api-key"
// Since config.toml uses kebab-case, environment variables won't override properly
// This is a known issue in the current implementation
let temp = TempDir::new().unwrap();
let config_content = "\
api_key = \"file-api-key\"
base_url = \"https://file.api.com\"
model = \"file-model\"
";
let config_path = create_config_file(&temp, config_content);
// Set environment variables (with BOT_ prefix as setup.rs uses)
unsafe {
std::env::set_var("BOT_API_KEY", "env-api-key");
std::env::set_var("BOT_MODEL", "env-model");
}
let config = config::Config::builder()
.add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
.add_source(config::Environment::with_prefix("BOT"))
.build()
.unwrap();
// Environment variables should override file settings (when using underscore keys)
assert_eq!(config.get_string("api_key").unwrap(), "env-api-key");
assert_eq!(config.get_string("model").unwrap(), "env-model");
// File setting should be used when no env var
assert_eq!(
config.get_string("base_url").unwrap(),
"https://file.api.com"
);
// Cleanup
unsafe {
std::env::remove_var("BOT_API_KEY");
std::env::remove_var("BOT_MODEL");
}
}
#[tokio::test]
async fn test_command_line_overrides_config_file() {
let temp = TempDir::new().unwrap();
let config_content = "\
api-key = \"file-api-key\"
base-url = \"https://file.api.com\"
model = \"file-model\"
server = \"file.server\"
port = 6667
channels = [\"#file\"]
nickname = \"filenick\"
username = \"fileuser\"
";
let config_path = create_config_file(&temp, config_content);
// Simulate command-line arguments overriding config file
let config = config::Config::builder()
.add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
.set_override_option("api-key", Some("cli-api-key".to_string()))
.unwrap()
.set_override_option("model", Some("cli-model".to_string()))
.unwrap()
.set_override_option("server", Some("cli.server".to_string()))
.unwrap()
.set_override_option("nickname", Some("clinick".to_string()))
.unwrap()
.build()
.unwrap();
// Command-line values should override file settings
assert_eq!(config.get_string("api-key").unwrap(), "cli-api-key");
assert_eq!(config.get_string("model").unwrap(), "cli-model");
assert_eq!(config.get_string("server").unwrap(), "cli.server");
assert_eq!(config.get_string("nickname").unwrap(), "clinick");
// Non-overridden values should come from file
assert_eq!(
config.get_string("base-url").unwrap(),
"https://file.api.com"
);
assert_eq!(config.get_string("username").unwrap(), "fileuser");
assert_eq!(config.get_int("port").unwrap(), 6667);
}
#[tokio::test]
#[serial]
async fn test_command_line_overrides_environment_and_file() {
let temp = TempDir::new().unwrap();
let config_content = "\
api_key = \"file-api-key\"
model = \"file-model\"
base_url = \"https://file.api.com\"
";
let config_path = create_config_file(&temp, config_content);
// Set environment variable
unsafe {
std::env::set_var("BOT_API_KEY", "env-api-key");
}
// Build config with all three sources
let config = config::Config::builder()
.add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
.add_source(config::Environment::with_prefix("BOT"))
.set_override_option("api_key", Some("cli-api-key".to_string()))
.unwrap()
.build()
.unwrap();
// Command-line should win over both environment and file
assert_eq!(config.get_string("api_key").unwrap(), "cli-api-key");
// Cleanup
unsafe {
std::env::remove_var("BOT_API_KEY");
}
}
#[tokio::test]
#[serial]
async fn test_precedence_order() {
// Test: CLI > Environment > Config File > Defaults
// Using underscore keys to match how setup.rs actually works
let temp = TempDir::new().unwrap();
let config_content = "\
api_key = \"file-key\"
base_url = \"https://file-url.com\"
model = \"file-model\"
server = \"file-server\"
";
let config_path = create_config_file(&temp, config_content);
// Set environment variables
unsafe {
std::env::set_var("BOT_BASE_URL", "https://env-url.com");
std::env::set_var("BOT_MODEL", "env-model");
}
let config = config::Config::builder()
.add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
.add_source(config::Environment::with_prefix("BOT"))
.set_override_option("model", Some("cli-model".to_string()))
.unwrap()
.build()
.unwrap();
// CLI overrides everything
assert_eq!(config.get_string("model").unwrap(), "cli-model");
// Environment overrides file
assert_eq!(
config.get_string("base_url").unwrap(),
"https://env-url.com"
);
// File is used when no env or CLI
assert_eq!(config.get_string("api_key").unwrap(), "file-key");
assert_eq!(config.get_string("server").unwrap(), "file-server");
// Cleanup
unsafe {
std::env::remove_var("BOT_BASE_URL");
std::env::remove_var("BOT_MODEL");
}
}
#[tokio::test]
async fn test_boolean_use_tls_setting() {
let temp = TempDir::new().unwrap();
// Test with use-tls = true (kebab-case as in config.toml)
let config_content_true = r#"
use-tls = true
"#;
let config_path = create_config_file(&temp, config_content_true);
let config = parse_config_from_file(&config_path).await;
assert_eq!(config.get_bool("use-tls").unwrap(), true);
// Test with use-tls = false
let config_content_false = r#"
use-tls = false
"#;
let config_path = create_config_file(&temp, config_content_false);
let config = parse_config_from_file(&config_path).await;
assert_eq!(config.get_bool("use-tls").unwrap(), false);
}
#[tokio::test]
async fn test_use_tls_naming_inconsistency() {
// This test documents a bug: setup.rs uses "use_tls" (underscore)
// but config.toml uses "use-tls" (kebab-case)
let temp = TempDir::new().unwrap();
let config_content = r#"
use-tls = true
"#;
let config_path = create_config_file(&temp, config_content);
// Build config the way setup.rs does it
let config = config::Config::builder()
.add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
// setup.rs line 119 uses "use_tls" (underscore) instead of "use-tls" (kebab)
.set_override_option("use_tls", Some(false))
.unwrap()
.build()
.unwrap();
// This should read from the override (false), not the file (true)
// But due to the naming mismatch, it might not work as expected
// The config file uses "use-tls" but the override uses "use_tls"
// With kebab-case (matches config.toml)
assert_eq!(config.get_bool("use-tls").unwrap(), true);
// With underscore (matches setup.rs override)
assert_eq!(config.get_bool("use_tls").unwrap(), false);
}
#[tokio::test]
async fn test_channels_as_array() {
let temp = TempDir::new().unwrap();
let config_content = "\
channels = [\"#chan1\", \"#chan2\", \"#chan3\"]
";
let config_path = create_config_file(&temp, config_content);
let config = parse_config_from_file(&config_path).await;
let channels: Vec<String> = config.get("channels").unwrap();
assert_eq!(channels.len(), 3);
assert_eq!(channels[0], "#chan1");
assert_eq!(channels[1], "#chan2");
assert_eq!(channels[2], "#chan3");
}
#[tokio::test]
async fn test_channels_override_from_cli() {
let temp = TempDir::new().unwrap();
let config_content = "\
channels = [\"#file1\", \"#file2\"]
";
let config_path = create_config_file(&temp, config_content);
let cli_channels = vec![
"#cli1".to_string(),
"#cli2".to_string(),
"#cli3".to_string(),
];
let config = config::Config::builder()
.add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
.set_override_option("channels", Some(cli_channels.clone()))
.unwrap()
.build()
.unwrap();
let channels: Vec<String> = config.get("channels").unwrap();
assert_eq!(channels, cli_channels);
assert_eq!(channels.len(), 3);
}
#[tokio::test]
async fn test_port_as_integer() {
let temp = TempDir::new().unwrap();
let config_content = r#"
port = 6697
"#;
let config_path = create_config_file(&temp, config_content);
let config = parse_config_from_file(&config_path).await;
// Port should be readable as both integer and string
assert_eq!(config.get_int("port").unwrap(), 6697);
assert_eq!(config.get_string("port").unwrap(), "6697");
}
#[tokio::test]
async fn test_port_override_from_cli_as_string() {
// setup.rs passes port as Option<String> from clap
let temp = TempDir::new().unwrap();
let config_content = r#"
port = 6667
"#;
let config_path = create_config_file(&temp, config_content);
let config = config::Config::builder()
.add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
.set_override_option("port", Some("9999".to_string()))
.unwrap()
.build()
.unwrap();
// CLI override should work
assert_eq!(config.get_string("port").unwrap(), "9999");
assert_eq!(config.get_int("port").unwrap(), 9999);
}
#[tokio::test]
async fn test_missing_required_fields_fails() {
let temp = TempDir::new().unwrap();
// Create config without required api-key
let config_content = r#"
model = "test-model"
"#;
let config_path = create_config_file(&temp, config_content);
let config = parse_config_from_file(&config_path).await;
// Should fail when trying to get required field
assert!(config.get_string("api-key").is_err());
}
#[tokio::test]
async fn test_optional_instruct_field() {
let temp = TempDir::new().unwrap();
let config_content = r#"
instruct = "Custom bot instructions"
"#;
let config_path = create_config_file(&temp, config_content);
let config = parse_config_from_file(&config_path).await;
assert_eq!(
config.get_string("instruct").unwrap(),
"Custom bot instructions"
);
}
#[tokio::test]
async fn test_command_path_field() {
// command-path is in config.toml but not used anywhere in the code
let temp = TempDir::new().unwrap();
let config_content = r#"
command-path = "/custom/commands"
"#;
let config_path = create_config_file(&temp, config_content);
let config = parse_config_from_file(&config_path).await;
assert_eq!(
config.get_string("command-path").unwrap(),
"/custom/commands"
);
}
#[tokio::test]
async fn test_chroot_dir_field() {
let temp = TempDir::new().unwrap();
let config_content = r#"
chroot-dir = "/var/lib/bot/root"
"#;
let config_path = create_config_file(&temp, config_content);
let config = parse_config_from_file(&config_path).await;
assert_eq!(
config.get_string("chroot-dir").unwrap(),
"/var/lib/bot/root"
);
}
#[tokio::test]
async fn test_empty_config_file() {
let temp = TempDir::new().unwrap();
let config_content = "";
let config_path = create_config_file(&temp, config_content);
// Should build successfully but have no values
let config = parse_config_from_file(&config_path).await;
assert!(config.get_string("api-key").is_err());
assert!(config.get_string("model").is_err());
}
#[tokio::test]
async fn test_all_cli_override_keys_match_config_format() {
// This test documents which override keys in setup.rs match the config.toml format
let temp = TempDir::new().unwrap();
let config_content = "\
api-key = \"test\"
base-url = \"https://test.com\"
chroot-dir = \"/test\"
command-path = \"/cmds\"
model = \"test-model\"
instruct = \"test\"
channels = [\"#test\"]
server = \"test.server\"
port = 6697
nickname = \"test\"
username = \"test\"
use-tls = true
";
let config_path = create_config_file(&temp, config_content);
let config = parse_config_from_file(&config_path).await;
// All these should work with kebab-case (as in config.toml)
assert!(config.get_string("api-key").is_ok());
assert!(config.get_string("base-url").is_ok());
assert!(config.get_string("chroot-dir").is_ok());
assert!(config.get_string("command-path").is_ok());
assert!(config.get_string("model").is_ok());
assert!(config.get_string("instruct").is_ok());
let channels_result: Result<Vec<String>, _> = config.get("channels");
assert!(channels_result.is_ok());
assert!(config.get_string("server").is_ok());
assert!(config.get_int("port").is_ok());
assert!(config.get_string("nickname").is_ok());
assert!(config.get_string("username").is_ok());
assert!(config.get_bool("use-tls").is_ok());
}