Compare commits

4 Commits

Author SHA1 Message Date
Micheal Smith
c3b168f86f Added configuration tests. 2025-11-29 16:31:15 -06:00
Micheal Smith
b46a03c13e Added documentation. 2025-11-23 13:40:50 -06:00
Micheal Smith
17b087e618 Implemented external processes as potential plugins. 2025-11-20 04:30:18 -06:00
Micheal Smith
30e2d9a448 Renamed commands/Command to plugin/Plugin. 2025-11-14 07:19:28 -06:00
16 changed files with 1303 additions and 145 deletions

76
Cargo.lock generated
View File

@@ -180,9 +180,9 @@ checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43"
[[package]] [[package]]
name = "bytes" name = "bytes"
version = "1.10.1" version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3"
[[package]] [[package]]
name = "cargo-husky" name = "cargo-husky"
@@ -226,9 +226,9 @@ dependencies = [
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.5.51" version = "4.5.53"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c26d721170e0295f191a69bd9a1f93efcdb0aff38684b61ab5750468972e5f5" checksum = "c9e340e012a1bf4935f5282ed1436d1489548e8f72308207ea5df0e23d2d03f8"
dependencies = [ dependencies = [
"clap_builder", "clap_builder",
"clap_derive", "clap_derive",
@@ -236,9 +236,9 @@ dependencies = [
[[package]] [[package]]
name = "clap_builder" name = "clap_builder"
version = "4.5.51" version = "4.5.53"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75835f0c7bf681bfd05abe44e965760fea999a5286c6eb2d59883634fd02011a" checksum = "d76b5d13eaa18c901fd2f7fca939fefe3a0727a953561fefdf3b2922b8569d00"
dependencies = [ dependencies = [
"anstream", "anstream",
"anstyle", "anstyle",
@@ -824,9 +824,9 @@ dependencies = [
[[package]] [[package]]
name = "genai" name = "genai"
version = "0.4.3" version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48317c8c4a7011ffb748502f9c45408a351103ad225f26825d84f2ff0ac49b25" checksum = "814c33e79506556ecba6b5f8e39a2fe423262fd3903856377ad2ae6a857c6032"
dependencies = [ dependencies = [
"bytes", "bytes",
"derive_more 2.0.1", "derive_more 2.0.1",
@@ -2130,6 +2130,7 @@ name = "robotnik"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"better-panic", "better-panic",
"bytes",
"cargo-husky", "cargo-husky",
"clap", "clap",
"color-eyre", "color-eyre",
@@ -2143,6 +2144,7 @@ dependencies = [
"rstest", "rstest",
"serde", "serde",
"serde_json", "serde_json",
"serial_test",
"tempfile", "tempfile",
"tokio", "tokio",
"tracing", "tracing",
@@ -2284,6 +2286,15 @@ version = "1.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
[[package]]
name = "scc"
version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46e6f046b7fef48e2660c57ed794263155d713de679057f2d0c169bfc6e756cc"
dependencies = [
"sdd",
]
[[package]] [[package]]
name = "schannel" name = "schannel"
version = "0.1.28" version = "0.1.28"
@@ -2323,6 +2334,12 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "sdd"
version = "3.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "490dcfcbfef26be6800d11870ff2df8774fa6e86d047e3e8c8a76b25655e41ca"
[[package]] [[package]]
name = "security-framework" name = "security-framework"
version = "2.11.1" version = "2.11.1"
@@ -2439,9 +2456,9 @@ dependencies = [
[[package]] [[package]]
name = "serde_with" name = "serde_with"
version = "3.15.1" version = "3.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa66c845eee442168b2c8134fec70ac50dc20e760769c8ba0ad1319ca1959b04" checksum = "10574371d41b0d9b2cff89418eda27da52bcaff2cc8741db26382a77c29131f1"
dependencies = [ dependencies = [
"base64", "base64",
"chrono", "chrono",
@@ -2458,9 +2475,9 @@ dependencies = [
[[package]] [[package]]
name = "serde_with_macros" name = "serde_with_macros"
version = "3.15.1" version = "3.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b91a903660542fced4e99881aa481bdbaec1634568ee02e0b8bd57c64cb38955" checksum = "08a72d8216842fdd57820dc78d840bef99248e35fb2554ff923319e60f2d686b"
dependencies = [ dependencies = [
"darling", "darling",
"proc-macro2", "proc-macro2",
@@ -2468,6 +2485,31 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "serial_test"
version = "3.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b258109f244e1d6891bf1053a55d63a5cd4f8f4c30cf9a1280989f80e7a1fa9"
dependencies = [
"futures",
"log",
"once_cell",
"parking_lot",
"scc",
"serial_test_derive",
]
[[package]]
name = "serial_test_derive"
version = "3.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d69265a08751de7844521fd15003ae0a888e035773ba05695c5c759a6f89eef"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "sha2" name = "sha2"
version = "0.10.9" version = "0.10.9"
@@ -2494,6 +2536,15 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "signal-hook-registry"
version = "1.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b2a4719bff48cee6b39d12c020eeb490953ad2443b7055bd0b21fca26bd8c28b"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "slab" name = "slab"
version = "0.4.11" version = "0.4.11"
@@ -2702,6 +2753,7 @@ dependencies = [
"libc", "libc",
"mio", "mio",
"pin-project-lite", "pin-project-lite",
"signal-hook-registry",
"socket2", "socket2",
"tokio-macros", "tokio-macros",
"windows-sys 0.61.2", "windows-sys 0.61.2",

View File

@@ -5,6 +5,7 @@ edition = "2024"
[dependencies] [dependencies]
better-panic = "0.3.0" better-panic = "0.3.0"
bytes = "1"
color-eyre = "0.6.3" color-eyre = "0.6.3"
directories = "6.0" directories = "6.0"
futures = "0.3" futures = "0.3"
@@ -15,36 +16,42 @@ serde_json = "1.0"
tracing = "0.1" tracing = "0.1"
tracing-subscriber = "0.3" tracing-subscriber = "0.3"
[dependencies.nix] [dependencies.nix]
version = "0.30.1" version = "0.30.1"
features = [ "fs" ] features = ["fs", "resource"]
[dependencies.clap] [dependencies.clap]
version = "4.5" version = "4.5"
features = [ "derive" ] features = ["derive"]
[dependencies.config] [dependencies.config]
version = "0.15" version = "0.15"
features = [ "toml" ] features = ["toml"]
[dependencies.serde] [dependencies.serde]
version = "1.0" version = "1.0"
features = [ "derive" ] features = ["derive"]
[dependencies.tokio] [dependencies.tokio]
version = "1" version = "1"
features = [ "io-util", "macros", "net", "rt-multi-thread", "sync" ] features = [
"io-util",
"macros",
"net",
"process",
"rt-multi-thread",
"sync",
"time",
]
[dev-dependencies] [dev-dependencies]
rstest = "0.24" rstest = "0.24"
serial_test = "3.2"
tempfile = "3.13" tempfile = "3.13"
[dev-dependencies.cargo-husky] [dev-dependencies.cargo-husky]
version = "1" version = "1"
features = [ features = ["run-cargo-check", "run-cargo-clippy"]
"run-cargo-check",
"run-cargo-clippy",
]
[profile.release] [profile.release]
strip = true strip = true

View File

@@ -2,6 +2,8 @@ edition = "2024"
style_edition = "2024" style_edition = "2024"
comment_width = 100 comment_width = 100
format_code_in_doc_comments = true format_code_in_doc_comments = true
format_macro_bodies = true
format_macro_matchers = true
imports_granularity = "Crate" imports_granularity = "Crate"
imports_layout = "HorizontalVertical" imports_layout = "HorizontalVertical"
wrap_comments = true wrap_comments = true

View File

@@ -1,3 +1,8 @@
//! Handles interaction with IRC.
//!
//! Each instance of [`Chat`] handles a single connection to an IRC
//! server.
use std::sync::Arc; use std::sync::Arc;
use color_eyre::{Result, eyre::WrapErr}; use color_eyre::{Result, eyre::WrapErr};
@@ -7,50 +12,57 @@ use irc::client::prelude::{Client, Command, Config as IRCConfig, Message};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tracing::{Level, event, instrument}; use tracing::{Level, event, instrument};
use crate::{Event, EventManager, LLMHandle, commands}; use crate::{Event, EventManager, LLMHandle, plugin};
/// Chat struct that is used to interact with IRC chat.
#[derive(Debug)] #[derive(Debug)]
pub struct Chat { pub struct Chat {
/// The actual IRC [`irc::client`](client).
client: Client, client: Client,
/// Event manager for handling plugin interaction.
event_manager: Arc<EventManager>, event_manager: Arc<EventManager>,
/// Handle for whichever LLM is being used.
llm_handle: LLMHandle, // FIXME: This needs to be thread safe, and shared, etc. llm_handle: LLMHandle, // FIXME: This needs to be thread safe, and shared, etc.
} }
// Need: owners, channels, username, nick, server, password
#[instrument]
pub async fn new(
settings: &MainConfig,
handle: &LLMHandle,
manager: Arc<EventManager>,
) -> Result<Chat> {
// Going to just assign and let the irc library handle errors for now, and
// add my own checking if necessary.
let port: u16 = settings.get("port")?;
let channels: Vec<String> = settings.get("channels").wrap_err("No channels provided.")?;
event!(Level::INFO, "Channels = {:?}", channels);
let config = IRCConfig {
server: settings.get_string("server").ok(),
nickname: settings.get_string("nickname").ok(),
port: Some(port),
username: settings.get_string("username").ok(),
use_tls: settings.get_bool("use_tls").ok(),
channels,
..IRCConfig::default()
};
event!(Level::INFO, "IRC connection starting...");
Ok(Chat {
client: Client::from_config(config).await?,
llm_handle: handle.clone(),
event_manager: manager,
})
}
impl Chat { impl Chat {
pub async fn run(&mut self, mut command_in: mpsc::Receiver<commands::Command>) -> Result<()> { // Need: owners, channels, username, nick, server, password rather than reading
// the config values directly.
/// Creates a new [`Chat`].
#[instrument]
pub async fn new(
settings: &MainConfig,
handle: &LLMHandle,
manager: Arc<EventManager>,
) -> Result<Chat> {
// Going to just assign and let the irc library handle errors for now, and
// add my own checking if necessary.
let port: u16 = settings.get("port")?;
let channels: Vec<String> = settings.get("channels").wrap_err("No channels provided.")?;
event!(Level::INFO, "Channels = {:?}", channels);
let config = IRCConfig {
server: settings.get_string("server").ok(),
nickname: settings.get_string("nickname").ok(),
port: Some(port),
username: settings.get_string("username").ok(),
use_tls: settings.get_bool("use_tls").ok(),
channels,
..IRCConfig::default()
};
event!(Level::INFO, "IRC connection starting...");
Ok(Chat {
client: Client::from_config(config).await?,
llm_handle: handle.clone(),
event_manager: manager,
})
}
/// Drives the event loop for the chat.
pub async fn run(&mut self, mut command_in: mpsc::Receiver<plugin::PluginMsg>) -> Result<()> {
self.client.identify()?; self.client.identify()?;
let mut stream = self.client.stream()?; let mut stream = self.client.stream()?;
@@ -69,7 +81,7 @@ impl Chat {
command = command_in.recv() => { command = command_in.recv() => {
event!(Level::INFO, "Received command {:#?}", command); event!(Level::INFO, "Received command {:#?}", command);
match command { match command {
Some(commands::Command::SendMessage {channel, message} ) => { Some(plugin::PluginMsg::SendMessage {channel, message} ) => {
// Now to pass on the message. // Now to pass on the message.
event!(Level::INFO, "Trying to send to channel."); event!(Level::INFO, "Trying to send to channel.");
self.client.send_privmsg(&channel, &message).wrap_err("Couldn't send to channel")?; self.client.send_privmsg(&channel, &message).wrap_err("Couldn't send to channel")?;

193
src/command.rs Normal file
View File

@@ -0,0 +1,193 @@
//! Commands that are associated with external processes (commands).
//!
//! Process based plugins are just an assortment of executable files in
//! a provided directory. They are given arguments, and the response from
//! them is expected on stdout.
use std::{
path::{Path, PathBuf},
time::Duration,
};
use bytes::Bytes;
use color_eyre::{Result, eyre::eyre};
use tokio::{fs::try_exists, process::Command, time::timeout};
use tracing::{Level, event};
/// Handle containing information about the directory containing commands.
#[derive(Debug)]
pub struct CommandDir {
command_path: PathBuf,
}
impl CommandDir {
/// Register a path containing commands.
pub fn new(command_path: impl AsRef<Path>) -> Self {
event!(
Level::INFO,
"CommandDir initialized with path: {:?}",
command_path.as_ref()
);
CommandDir {
command_path: command_path.as_ref().to_path_buf(),
}
}
/// Look for a command. If it exists Ok(path) is returned.
async fn find_command(&self, name: impl AsRef<Path>) -> Result<String> {
let path = self.command_path.join(name.as_ref());
event!(
Level::INFO,
"Looking for {} command.",
name.as_ref().display()
);
match try_exists(&path).await {
Ok(true) => Ok(path.to_string_lossy().to_string()),
Ok(false) => Err(eyre!(format!("{} Not found.", path.to_string_lossy()))),
Err(e) => Err(e.into()),
}
}
/// Run the given [`command_name`]. It should exist in the directory provided as
/// the command_path.
pub async fn run_command(
&self,
command_name: impl AsRef<str>,
input: impl AsRef<str>,
) -> Result<Bytes> {
let path = self.find_command(Path::new(command_name.as_ref())).await?;
// Well it exists let's cross our fingers...
let output = Command::new(path).arg(input.as_ref()).output().await?;
if output.status.success() {
// So far so good
Ok(Bytes::from(output.stdout))
} else {
// Whoops
Err(eyre!(format!(
"Error running {}: {}",
command_name.as_ref(),
output.status
)))
}
}
/// [`run_command`] but with a timeout.
pub async fn run_command_with_timeout(
&self,
command_name: impl AsRef<str>,
input: impl AsRef<str>,
time_out: Duration,
) -> Result<Bytes> {
timeout(time_out, self.run_command(command_name, input)).await?
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::{
fs::{self, Permissions},
os::unix::fs::PermissionsExt,
};
use tempfile::TempDir;
fn create_test_script(dir: &Path, name: &str, script: &str) -> PathBuf {
let path = dir.join(name);
fs::write(&path, script).unwrap();
fs::set_permissions(&path, Permissions::from_mode(0o755)).unwrap();
path
}
#[test]
fn test_command_dir_new() {
let dir = CommandDir::new("/some/path");
assert_eq!(dir.command_path, PathBuf::from("/some/path"));
}
#[tokio::test]
async fn test_find_command_exists() {
let temp = TempDir::new().unwrap();
create_test_script(temp.path(), "test_cmd", "#!/bin/bash\necho hello");
let cmd_dir = CommandDir::new(temp.path());
let result = cmd_dir.find_command("test_cmd").await;
assert!(result.is_ok());
assert!(result.unwrap().contains("test_cmd"));
}
#[tokio::test]
async fn test_find_command_not_found() {
let temp = TempDir::new().unwrap();
let cmd_dir = CommandDir::new(temp.path());
let result = cmd_dir.find_command("nonexistent").await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Not found"));
}
#[tokio::test]
async fn test_run_command_success() {
let temp = TempDir::new().unwrap();
create_test_script(temp.path(), "echo_cmd", "#!/bin/bash\necho \"$1\"");
let cmd_dir = CommandDir::new(temp.path());
let result = cmd_dir.run_command("echo_cmd", "hello world").await;
assert!(result.is_ok());
let output = result.unwrap();
assert_eq!(output.as_ref(), b"hello world\n");
}
#[tokio::test]
async fn test_run_command_failure() {
let temp = TempDir::new().unwrap();
create_test_script(temp.path(), "fail_cmd", "#!/bin/bash\nexit 1");
let cmd_dir = CommandDir::new(temp.path());
let result = cmd_dir.run_command("fail_cmd", "input").await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Error running"));
}
#[tokio::test]
async fn test_run_command_not_found() {
let temp = TempDir::new().unwrap();
let cmd_dir = CommandDir::new(temp.path());
let result = cmd_dir.run_command("missing", "input").await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_run_command_with_timeout_success() {
let temp = TempDir::new().unwrap();
create_test_script(temp.path(), "fast_cmd", "#!/bin/bash\necho \"$1\"");
let cmd_dir = CommandDir::new(temp.path());
let result = cmd_dir
.run_command_with_timeout("fast_cmd", "quick", Duration::from_secs(5))
.await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_run_command_with_timeout_expires() {
let temp = TempDir::new().unwrap();
create_test_script(temp.path(), "slow_cmd", "#!/bin/bash\nsleep 10\necho done");
let cmd_dir = CommandDir::new(temp.path());
let result = cmd_dir
.run_command_with_timeout("slow_cmd", "input", Duration::from_millis(100))
.await;
assert!(result.is_err());
}
}

View File

@@ -1,18 +0,0 @@
use std::fmt::Display;
use serde::{Deserialize, Serialize};
#[derive(Debug, Deserialize, Serialize)]
pub enum Command {
SendMessage { channel: String, message: String },
}
impl Display for Command {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::SendMessage { channel, message } => {
write!(f, "[{channel}]: {message}")
}
}
}
}

View File

@@ -1,13 +1,19 @@
//! Internal representations of incoming events.
use irc::proto::{Command, Message}; use irc::proto::{Command, Message};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// Represents an event. Probably from IRC.
#[derive(Deserialize, Serialize)] #[derive(Deserialize, Serialize)]
pub struct Event { pub struct Event {
/// Who is the message from?
from: String, from: String,
/// What is the message?
message: String, message: String,
} }
impl Event { impl Event {
/// Creates a new message.
pub fn new(from: impl Into<String>, msg: impl Into<String>) -> Self { pub fn new(from: impl Into<String>, msg: impl Into<String>) -> Self {
Self { Self {
from: from.into(), from: from.into(),

View File

@@ -1,3 +1,5 @@
//! Handler for events to and from IPC, and process plugins.
use std::{collections::VecDeque, path::Path, sync::Arc}; use std::{collections::VecDeque, path::Path, sync::Arc};
use color_eyre::Result; use color_eyre::Result;
@@ -9,12 +11,14 @@ use tokio::{
}; };
use tracing::{error, info}; use tracing::{error, info};
use crate::{commands::Command, event::Event}; use crate::{event::Event, plugin::PluginMsg};
// Hard coding for now. Maybe make this a parameter to new. // Hard coding for now. Maybe make this a parameter to new.
const EVENT_BUF_MAX: usize = 1000; const EVENT_BUF_MAX: usize = 1000;
// Manager for communication with plugins. /// Manager for communication with plugins.
///
/// Keeps events in a ring buffer to track a certain amount of history.
#[derive(Debug)] #[derive(Debug)]
pub struct EventManager { pub struct EventManager {
announce: broadcast::Sender<String>, // Everything broadcasts here. announce: broadcast::Sender<String>, // Everything broadcasts here.
@@ -22,6 +26,7 @@ pub struct EventManager {
} }
impl EventManager { impl EventManager {
/// Create a new [`EventManager``].
pub fn new() -> Result<Self> { pub fn new() -> Result<Self> {
let (announce, _) = broadcast::channel(100); let (announce, _) = broadcast::channel(100);
@@ -31,6 +36,7 @@ impl EventManager {
}) })
} }
/// Broadcast an event to every subscribed listener.
pub async fn broadcast(&self, event: &Event) -> Result<()> { pub async fn broadcast(&self, event: &Event) -> Result<()> {
let msg = serde_json::to_string(event)? + "\n"; let msg = serde_json::to_string(event)? + "\n";
@@ -49,7 +55,10 @@ impl EventManager {
} }
// NB: This assumes it has exclusive control of the FIFO. // NB: This assumes it has exclusive control of the FIFO.
pub async fn start_fifo<P>(path: &P, command_tx: mpsc::Sender<Command>) -> Result<()> /// Opens a fifo at [`path`]. This is where some plugins can send response events
/// to. The messages MUST be formatted in JSON and match one of the possible
/// [`PluginMsg`](plugin messages).
pub async fn start_fifo<P>(path: &P, command_tx: mpsc::Sender<PluginMsg>) -> Result<()>
where where
P: AsRef<Path> + NixPath + ?Sized, P: AsRef<Path> + NixPath + ?Sized,
{ {
@@ -65,7 +74,7 @@ impl EventManager {
while reader.read_line(&mut line).await? > 0 { while reader.read_line(&mut line).await? > 0 {
// Now handle the command. // Now handle the command.
let cmd: Command = serde_json::from_str(&line)?; let cmd: PluginMsg = serde_json::from_str(&line)?;
info!("Command received: {:?}.", cmd); info!("Command received: {:?}.", cmd);
command_tx.send(cmd).await?; command_tx.send(cmd).await?;
line.clear(); line.clear();
@@ -73,6 +82,8 @@ impl EventManager {
} }
} }
/// Start a UNIX socket that will provide broadcast messages to any client that opens
/// the socket for listening.
pub async fn start_listening(self: Arc<Self>, broadcast_path: impl AsRef<Path>) { pub async fn start_listening(self: Arc<Self>, broadcast_path: impl AsRef<Path>) {
let listener = UnixListener::bind(broadcast_path).unwrap(); let listener = UnixListener::bind(broadcast_path).unwrap();
@@ -93,6 +104,7 @@ impl EventManager {
} }
} }
/// Send any events queued up to the [`stream`].
async fn send_events(&self, stream: UnixStream) -> Result<()> { async fn send_events(&self, stream: UnixStream) -> Result<()> {
let mut writer = stream; let mut writer = stream;
@@ -316,7 +328,7 @@ mod tests {
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
// Write a command to the FIFO // Write a command to the FIFO
let cmd = Command::SendMessage { let cmd = PluginMsg::SendMessage {
channel: "#test".to_string(), channel: "#test".to_string(),
message: "hello".to_string(), message: "hello".to_string(),
}; };
@@ -338,7 +350,7 @@ mod tests {
.expect("channel closed"); .expect("channel closed");
match received { match received {
Command::SendMessage { channel, message } => { PluginMsg::SendMessage { channel, message } => {
assert_eq!(channel, "#test"); assert_eq!(channel, "#test");
assert_eq!(message, "hello"); assert_eq!(message, "hello");
} }
@@ -362,15 +374,15 @@ mod tests {
// Write multiple commands // Write multiple commands
let commands = vec![ let commands = vec![
Command::SendMessage { PluginMsg::SendMessage {
channel: "#chan1".to_string(), channel: "#chan1".to_string(),
message: "first".to_string(), message: "first".to_string(),
}, },
Command::SendMessage { PluginMsg::SendMessage {
channel: "#chan2".to_string(), channel: "#chan2".to_string(),
message: "second".to_string(), message: "second".to_string(),
}, },
Command::SendMessage { PluginMsg::SendMessage {
channel: "#chan3".to_string(), channel: "#chan3".to_string(),
message: "third".to_string(), message: "third".to_string(),
}, },
@@ -395,7 +407,7 @@ mod tests {
.expect("channel closed"); .expect("channel closed");
match first { match first {
Command::SendMessage { channel, message } => { PluginMsg::SendMessage { channel, message } => {
assert_eq!(channel, "#chan1"); assert_eq!(channel, "#chan1");
assert_eq!(message, "first"); assert_eq!(message, "first");
} }
@@ -407,7 +419,7 @@ mod tests {
.expect("channel closed"); .expect("channel closed");
match second { match second {
Command::SendMessage { channel, message } => { PluginMsg::SendMessage { channel, message } => {
assert_eq!(channel, "#chan2"); assert_eq!(channel, "#chan2");
assert_eq!(message, "second"); assert_eq!(message, "second");
} }
@@ -419,7 +431,7 @@ mod tests {
.expect("channel closed"); .expect("channel closed");
match third { match third {
Command::SendMessage { channel, message } => { PluginMsg::SendMessage { channel, message } => {
assert_eq!(channel, "#chan3"); assert_eq!(channel, "#chan3");
assert_eq!(message, "third"); assert_eq!(message, "third");
} }
@@ -449,7 +461,7 @@ mod tests {
let tx = pipe::OpenOptions::new().open_sender(&path).unwrap(); let tx = pipe::OpenOptions::new().open_sender(&path).unwrap();
let mut tx = tokio::io::BufWriter::new(tx); let mut tx = tokio::io::BufWriter::new(tx);
let cmd = Command::SendMessage { let cmd = PluginMsg::SendMessage {
channel: "#first".to_string(), channel: "#first".to_string(),
message: "batch1".to_string(), message: "batch1".to_string(),
}; };
@@ -465,7 +477,7 @@ mod tests {
.expect("channel closed"); .expect("channel closed");
match first { match first {
Command::SendMessage { channel, message } => { PluginMsg::SendMessage { channel, message } => {
assert_eq!(channel, "#first"); assert_eq!(channel, "#first");
assert_eq!(message, "batch1"); assert_eq!(message, "batch1");
} }
@@ -482,7 +494,7 @@ mod tests {
let tx = pipe::OpenOptions::new().open_sender(&fifo_path).unwrap(); let tx = pipe::OpenOptions::new().open_sender(&fifo_path).unwrap();
let mut tx = tokio::io::BufWriter::new(tx); let mut tx = tokio::io::BufWriter::new(tx);
let cmd = Command::SendMessage { let cmd = PluginMsg::SendMessage {
channel: "#second".to_string(), channel: "#second".to_string(),
message: "batch2".to_string(), message: "batch2".to_string(),
}; };
@@ -497,7 +509,7 @@ mod tests {
.expect("channel closed"); .expect("channel closed");
match second { match second {
Command::SendMessage { channel, message } => { PluginMsg::SendMessage { channel, message } => {
assert_eq!(channel, "#second"); assert_eq!(channel, "#second");
assert_eq!(message, "batch2"); assert_eq!(message, "batch2");
} }
@@ -524,7 +536,7 @@ mod tests {
let tx = pipe::OpenOptions::new().open_sender(&fifo_path).unwrap(); let tx = pipe::OpenOptions::new().open_sender(&fifo_path).unwrap();
let mut tx = tokio::io::BufWriter::new(tx); let mut tx = tokio::io::BufWriter::new(tx);
let cmd1 = Command::SendMessage { let cmd1 = PluginMsg::SendMessage {
channel: "#test".to_string(), channel: "#test".to_string(),
message: "first".to_string(), message: "first".to_string(),
}; };
@@ -537,7 +549,7 @@ mod tests {
// Write whitespace line // Write whitespace line
tx.write_all(b" \n").await.unwrap(); tx.write_all(b" \n").await.unwrap();
let cmd2 = Command::SendMessage { let cmd2 = PluginMsg::SendMessage {
channel: "#test".to_string(), channel: "#test".to_string(),
message: "second".to_string(), message: "second".to_string(),
}; };
@@ -553,7 +565,7 @@ mod tests {
.expect("channel closed"); .expect("channel closed");
match first { match first {
Command::SendMessage { channel, message } => { PluginMsg::SendMessage { channel, message } => {
assert_eq!(channel, "#test"); assert_eq!(channel, "#test");
assert_eq!(message, "first"); assert_eq!(message, "first");
} }

View File

@@ -1,26 +0,0 @@
// Provides an IPC socket to communicate with other processes.
use std::path::Path;
use color_eyre::Result;
use tokio::net::UnixListener;
pub struct IPC {
listener: UnixListener,
}
impl IPC {
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
let listener = UnixListener::bind(path)?;
Ok(Self { listener })
}
pub async fn run(&self) -> Result<()> {
loop {
match self.listener.accept().await {
Ok((_stream, _addr)) => {}
Err(e) => return Err(e.into()),
}
}
}
}

View File

@@ -1,4 +1,5 @@
// Robotnik libraries #![warn(missing_docs)]
#![doc = include_str!("../README.md")]
use std::{os::unix::fs, sync::Arc}; use std::{os::unix::fs, sync::Arc};
@@ -9,13 +10,14 @@ use tracing::{Level, info};
use tracing_subscriber::FmtSubscriber; use tracing_subscriber::FmtSubscriber;
pub mod chat; pub mod chat;
pub mod commands; pub mod command;
pub mod event; pub mod event;
pub mod event_manager; pub mod event_manager;
pub mod ipc; pub mod plugin;
pub mod qna; pub mod qna;
pub mod setup; pub mod setup;
pub use chat::Chat;
pub use event::Event; pub use event::Event;
pub use event_manager::EventManager; pub use event_manager::EventManager;
pub use qna::LLMHandle; pub use qna::LLMHandle;
@@ -25,7 +27,9 @@ const DEFAULT_INSTRUCT: &str =
be sent in a single IRC response according to the specification. Keep answers to be sent in a single IRC response according to the specification. Keep answers to
500 characters or less."; 500 characters or less.";
// NB: Everything should fail if logging doesn't start properly. /// Initialize all logging facilities.
///
/// This should cause a panic if there's a failure.
async fn init_logging() { async fn init_logging() {
better_panic::install(); better_panic::install();
setup_panic!(); setup_panic!();
@@ -37,6 +41,10 @@ async fn init_logging() {
tracing::subscriber::set_global_default(subscriber).unwrap(); tracing::subscriber::set_global_default(subscriber).unwrap();
} }
/// Sets up and runs the main event loop.
///
/// Should return an error if it's recoverable, but could panic if something
/// is particularly bad.
pub async fn run() -> Result<()> { pub async fn run() -> Result<()> {
init_logging().await; init_logging().await;
info!("Starting up."); info!("Starting up.");
@@ -69,7 +77,7 @@ pub async fn run() -> Result<()> {
let ev_manager = Arc::new(EventManager::new()?); let ev_manager = Arc::new(EventManager::new()?);
let ev_manager_clone = Arc::clone(&ev_manager); let ev_manager_clone = Arc::clone(&ev_manager);
let mut c = chat::new(&config, &handle, Arc::clone(&ev_manager)).await?; let mut c = Chat::new(&config, &handle, Arc::clone(&ev_manager)).await?;
let (from_plugins, to_chat) = mpsc::channel(100); let (from_plugins, to_chat) = mpsc::channel(100);

37
src/plugin.rs Normal file
View File

@@ -0,0 +1,37 @@
//! Plugin command definitions.
// Dear future me: If you forget the JSON translations in the future you'll
// thank me for the comment overkill.
use std::fmt::Display;
use serde::{Deserialize, Serialize};
/// Message types accepted from plugins.
#[derive(Debug, Deserialize, Serialize)]
pub enum PluginMsg {
/// Plugin message indicating the bot should send a [`message`] to [`channel`].
/// {
/// "SendMessage": {
/// "channel": "channel_name",
/// "message": "your message here"
/// }
///
/// }
SendMessage {
/// The IRC channel to send the [`message`] to.
channel: String,
/// The [`message`] to send.
message: String,
},
}
impl Display for PluginMsg {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::SendMessage { channel, message } => {
write!(f, "[{channel}]: {message}")
}
}
}
}

View File

@@ -1,3 +1,5 @@
//! Handles communication with a genai compatible LLM.
use color_eyre::Result; use color_eyre::Result;
use futures::StreamExt; use futures::StreamExt;
use genai::{ use genai::{
@@ -8,8 +10,11 @@ use genai::{
}; };
use tracing::info; use tracing::info;
// NB: Docs are quick and dirty as this might move into a plugin.
// Represents an LLM completion source. // Represents an LLM completion source.
// FIXME: Clone is probably temporary. // FIXME: Clone is probably temporary.
/// Struct containing information about the LLM.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct LLMHandle { pub struct LLMHandle {
chat_request: ChatRequest, chat_request: ChatRequest,
@@ -18,6 +23,7 @@ pub struct LLMHandle {
} }
impl LLMHandle { impl LLMHandle {
/// Create a new handle.
pub fn new( pub fn new(
api_key: String, api_key: String,
_base_url: impl AsRef<str>, _base_url: impl AsRef<str>,
@@ -44,6 +50,7 @@ impl LLMHandle {
}) })
} }
/// Send a chat message to the LLM with the response being returned as a [`String`].
pub async fn send_request(&mut self, message: impl Into<String>) -> Result<String> { pub async fn send_request(&mut self, message: impl Into<String>) -> Result<String> {
let mut req = self.chat_request.clone(); let mut req = self.chat_request.clone();
let client = self.client.clone(); let client = self.client.clone();

View File

@@ -1,3 +1,7 @@
//! Handles configuration for the bot.
//!
//! Both command line, and configuration file options are handled here.
use clap::Parser; use clap::Parser;
use color_eyre::{Result, eyre::WrapErr}; use color_eyre::{Result, eyre::WrapErr};
use config::Config; use config::Config;
@@ -6,6 +10,7 @@ use std::path::PathBuf;
use tracing::{info, instrument}; use tracing::{info, instrument};
// TODO: use [clap(long, short, help_heading = Some(section))] // TODO: use [clap(long, short, help_heading = Some(section))]
/// Struct of potential arguments.
#[derive(Clone, Debug, Parser)] #[derive(Clone, Debug, Parser)]
#[command(about, version)] #[command(about, version)]
pub struct Args { pub struct Args {
@@ -30,6 +35,7 @@ pub struct Args {
pub instruct: Option<String>, pub instruct: Option<String>,
#[arg(long)] #[arg(long)]
/// Name of the model to use. E.g. 'deepseek-chat'
pub model: Option<String>, pub model: Option<String>,
#[arg(long)] #[arg(long)]
@@ -60,21 +66,36 @@ pub struct Args {
/// IRC Username /// IRC Username
pub username: Option<String>, pub username: Option<String>,
#[arg(long)] #[arg(long = "no-tls")]
/// Whether or not to use TLS when connecting to the IRC server. /// Whether or not to use TLS when connecting to the IRC server.
pub use_tls: Option<bool>, pub use_tls: Option<bool>,
} }
/// Handle for interacting with the bot configuration.
pub struct Setup { pub struct Setup {
/// Handle for the configuration file options.
pub config: Config, pub config: Config,
} }
#[instrument] #[instrument]
/// Initialize a new [`Setup`] instance.
///
/// This reads the settings file which becomes the bot's default configuration.
/// These settings shall be overridden by any command line options.
pub async fn init() -> Result<Setup> { pub async fn init() -> Result<Setup> {
// Get arguments. These overrule configuration file, and environment // Get arguments. These overrule configuration file, and environment
// variables if applicable. // variables if applicable.
let args = Args::parse(); let args = Args::parse();
let settings = make_config(args)?;
Ok(Setup { config: settings })
}
/// Create a configuration object from arguments.
///
/// This is exposed for testing purposes.
pub fn make_config(args: Args) -> Result<Config> {
// Use default config location unless specified. // Use default config location unless specified.
let config_location: PathBuf = if let Some(ref path) = args.config_file { let config_location: PathBuf = if let Some(ref path) = args.config_file {
path.to_owned() path.to_owned()
@@ -88,7 +109,7 @@ pub async fn init() -> Result<Setup> {
info!("Starting."); info!("Starting.");
let settings = Config::builder() Config::builder()
.add_source(config::File::with_name(&config_location.to_string_lossy()).required(false)) .add_source(config::File::with_name(&config_location.to_string_lossy()).required(false))
.add_source(config::Environment::with_prefix("BOT")) .add_source(config::Environment::with_prefix("BOT"))
// Doing all of these overrides provides a unified access point for options, // Doing all of these overrides provides a unified access point for options,
@@ -98,15 +119,14 @@ pub async fn init() -> Result<Setup> {
.set_override_option("chroot-dir", args.chroot_dir.clone())? .set_override_option("chroot-dir", args.chroot_dir.clone())?
.set_override_option("command-path", args.command_dir.clone())? .set_override_option("command-path", args.command_dir.clone())?
.set_override_option("model", args.model.clone())? .set_override_option("model", args.model.clone())?
.set_override_option("nick-password", args.nick_password.clone())?
.set_override_option("instruct", args.instruct.clone())? .set_override_option("instruct", args.instruct.clone())?
.set_override_option("channels", args.channels.clone())? .set_override_option("channels", args.channels.clone())?
.set_override_option("server", args.server.clone())? .set_override_option("server", args.server.clone())?
.set_override_option("port", args.port.clone())? // FIXME: Make this a default here not in clap. .set_override_option("port", args.port.clone())?
.set_override_option("nickname", args.nickname.clone())? .set_override_option("nickname", args.nickname.clone())?
.set_override_option("username", args.username.clone())? .set_override_option("username", args.username.clone())?
.set_override_option("use_tls", args.use_tls)? .set_override_option("use-tls", args.use_tls)?
.build() .build()
.wrap_err("Couldn't read configuration settings.")?; .wrap_err("Couldn't read configuration settings.")
Ok(Setup { config: settings })
} }

290
tests/command_test.rs Normal file
View File

@@ -0,0 +1,290 @@
use std::{
fs::{self, Permissions},
os::unix::fs::PermissionsExt,
path::Path,
time::Duration,
};
use robotnik::command::CommandDir;
use tempfile::TempDir;
/// Helper to create executable test scripts
fn create_command(dir: &Path, name: &str, script: &str) {
let path = dir.join(name);
fs::write(&path, script).unwrap();
fs::set_permissions(&path, Permissions::from_mode(0o755)).unwrap();
}
/// Parse a bot message like "!weather 07008" into (command_name, argument)
fn parse_bot_message(message: &str) -> Option<(&str, &str)> {
if !message.starts_with('!') {
return None;
}
let without_prefix = &message[1..];
let mut parts = without_prefix.splitn(2, ' ');
let command = parts.next()?;
let arg = parts.next().unwrap_or("");
Some((command, arg))
}
#[tokio::test]
async fn test_bot_message_finds_and_runs_command() {
let temp = TempDir::new().unwrap();
// Create a weather command that echoes the zip code
create_command(
temp.path(),
"weather",
r#"#!/bin/bash
echo "Weather for $1: Sunny, 72°F"
"#,
);
let cmd_dir = CommandDir::new(temp.path());
let message = "!weather 10096";
// Parse the message
let (command_name, arg) = parse_bot_message(message).unwrap();
assert_eq!(command_name, "weather");
assert_eq!(arg, "10096");
// Find and run the command
let result = cmd_dir.run_command(command_name, arg).await;
assert!(result.is_ok());
let bytes = result.unwrap();
let output = String::from_utf8_lossy(&bytes);
assert!(output.contains("Weather for 10096"));
assert!(output.contains("Sunny"));
}
#[tokio::test]
async fn test_bot_message_command_not_found() {
let temp = TempDir::new().unwrap();
let cmd_dir = CommandDir::new(temp.path());
let message = "!nonexistent arg";
let (command_name, arg) = parse_bot_message(message).unwrap();
let result = cmd_dir.run_command(command_name, arg).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Not found"));
}
#[tokio::test]
async fn test_bot_message_with_multiple_arguments() {
let temp = TempDir::new().unwrap();
// Create a command that handles multiple words as a single argument
create_command(
temp.path(),
"echo",
r#"#!/bin/bash
echo "You said: $1"
"#,
);
let cmd_dir = CommandDir::new(temp.path());
let message = "!echo hello world how are you";
let (command_name, arg) = parse_bot_message(message).unwrap();
assert_eq!(command_name, "echo");
assert_eq!(arg, "hello world how are you");
let result = cmd_dir.run_command(command_name, arg).await;
assert!(result.is_ok());
let bytes = result.unwrap();
let output = String::from_utf8_lossy(&bytes);
assert!(output.contains("hello world how are you"));
}
#[tokio::test]
async fn test_bot_message_without_argument() {
let temp = TempDir::new().unwrap();
create_command(
temp.path(),
"help",
r#"#!/bin/bash
echo "Available commands: weather, echo, help"
"#,
);
let cmd_dir = CommandDir::new(temp.path());
let message = "!help";
let (command_name, arg) = parse_bot_message(message).unwrap();
assert_eq!(command_name, "help");
assert_eq!(arg, "");
let result = cmd_dir.run_command(command_name, arg).await;
assert!(result.is_ok());
let bytes = result.unwrap();
let output = String::from_utf8_lossy(&bytes);
assert!(output.contains("Available commands"));
}
#[tokio::test]
async fn test_bot_message_command_returns_error_exit_code() {
let temp = TempDir::new().unwrap();
// Create a command that fails for invalid input
create_command(
temp.path(),
"validate",
r#"#!/bin/bash
if [ -z "$1" ]; then
echo "Error: Input required" >&2
exit 1
fi
echo "Valid: $1"
"#,
);
let cmd_dir = CommandDir::new(temp.path());
let message = "!validate";
let (command_name, arg) = parse_bot_message(message).unwrap();
let result = cmd_dir.run_command(command_name, arg).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Error running"));
}
#[tokio::test]
async fn test_bot_message_with_timeout() {
let temp = TempDir::new().unwrap();
create_command(
temp.path(),
"quick",
r#"#!/bin/bash
echo "Result: $1"
"#,
);
let cmd_dir = CommandDir::new(temp.path());
let message = "!quick test";
let (command_name, arg) = parse_bot_message(message).unwrap();
let result = cmd_dir
.run_command_with_timeout(command_name, arg, Duration::from_secs(5))
.await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_bot_message_command_times_out() {
let temp = TempDir::new().unwrap();
create_command(
temp.path(),
"slow",
r#"#!/bin/bash
sleep 10
echo "Done"
"#,
);
let cmd_dir = CommandDir::new(temp.path());
let message = "!slow arg";
let (command_name, arg) = parse_bot_message(message).unwrap();
let result = cmd_dir
.run_command_with_timeout(command_name, arg, Duration::from_millis(100))
.await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_multiple_commands_in_directory() {
let temp = TempDir::new().unwrap();
create_command(
temp.path(),
"weather",
r#"#!/bin/bash
echo "Weather: Sunny"
"#,
);
create_command(
temp.path(),
"time",
r#"#!/bin/bash
echo "Time: 12:00"
"#,
);
create_command(
temp.path(),
"joke",
r#"#!/bin/bash
echo "Why did the robot go on vacation? To recharge!"
"#,
);
let cmd_dir = CommandDir::new(temp.path());
// Test each command
let messages = ["!weather", "!time", "!joke"];
let expected = ["Sunny", "12:00", "recharge"];
for (message, expect) in messages.iter().zip(expected.iter()) {
let (command_name, arg) = parse_bot_message(message).unwrap();
let result = cmd_dir.run_command(command_name, arg).await;
assert!(result.is_ok());
let bytes = result.unwrap();
let output = String::from_utf8_lossy(&bytes);
assert!(
output.contains(expect),
"Expected '{}' in '{}'",
expect,
output
);
}
}
#[tokio::test]
async fn test_non_bot_message_ignored() {
// Messages not starting with ! should be ignored
let messages = ["hello world", "weather 10096", "?help", "/command", ""];
for message in messages {
assert!(
parse_bot_message(message).is_none(),
"Should ignore: {}",
message
);
}
}
#[tokio::test]
async fn test_command_output_is_bytes() {
let temp = TempDir::new().unwrap();
// Create a command that outputs binary-safe content
create_command(
temp.path(),
"binary",
r#"#!/bin/bash
printf "Hello\x00World"
"#,
);
let cmd_dir = CommandDir::new(temp.path());
let message = "!binary test";
let (command_name, arg) = parse_bot_message(message).unwrap();
let result = cmd_dir.run_command(command_name, arg).await;
assert!(result.is_ok());
let output = result.unwrap();
// Should preserve the null byte
assert_eq!(&output[..], b"Hello\x00World");
}

View File

@@ -486,7 +486,7 @@ async fn test_json_deserialization_of_received_events() {
reader.read_line(&mut line).await.unwrap(); reader.read_line(&mut line).await.unwrap();
// Should be valid JSON // Should be valid JSON
let parsed: serde_json::Value = serde_json::from_str(&line.trim()).unwrap(); let parsed: serde_json::Value = serde_json::from_str(line.trim()).unwrap();
assert_eq!(parsed["message"], test_message); assert_eq!(parsed["message"], test_message);

556
tests/setup_test.rs Normal file
View File

@@ -0,0 +1,556 @@
use robotnik::setup::{Args, make_config};
use serial_test::serial;
use std::{fs, path::PathBuf};
use tempfile::TempDir;
/// Helper to create a temporary config file
fn create_config_file(dir: &TempDir, content: &str) -> PathBuf {
let config_path = dir.path().join("config.toml");
fs::write(&config_path, content).unwrap();
config_path
}
/// Helper to parse config using environment and config file
async fn parse_config_from_file(config_path: &PathBuf) -> config::Config {
config::Config::builder()
.add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
.build()
.unwrap()
}
#[tokio::test]
#[serial]
async fn test_setup_make_config_overrides() {
let temp = TempDir::new().unwrap();
let config_content = "\
api-key = \"file-key\"
model = \"file-model\"
port = 6667
";
let config_path = create_config_file(&temp, config_content);
// Construct Args with overrides
let args = Args {
api_key: Some("cli-key".to_string()),
base_url: None, /* Should fail if required and not in file/env? No, base-url is optional
* in args */
chroot_dir: None,
command_dir: None,
instruct: None,
model: None, // Should fallback to file
channels: None,
config_file: Some(config_path),
server: None, // Should use default or file? Args has default "irc.libera.chat"
port: Some("9999".to_string()),
nickname: None,
nick_password: None,
username: None,
use_tls: None,
};
let config = make_config(args).expect("Failed to make config");
// Check overrides
assert_eq!(config.get_string("api-key").unwrap(), "cli-key");
assert_eq!(config.get_string("port").unwrap(), "9999");
assert_eq!(config.get_int("port").unwrap(), 9999);
// Check fallback to file
assert_eq!(config.get_string("model").unwrap(), "file-model");
}
#[tokio::test]
async fn test_config_file_loads_all_settings() {
let temp = TempDir::new().unwrap();
let config_content = "\
api-key = \"test-api-key-123\"
base-url = \"https://api.test.com\"
chroot-dir = \"/test/chroot\"
command-path = \"/test/commands\"
model = \"test-model\"
instruct = \"Test instructions\"
server = \"test.irc.server\"
port = 6667
channels = [\"#test1\", \"#test2\"]
username = \"testuser\"
nickname = \"testnick\"
use-tls = false
";
let config_path = create_config_file(&temp, config_content);
let config = parse_config_from_file(&config_path).await;
// Verify all settings are loaded correctly
assert_eq!(config.get_string("api-key").unwrap(), "test-api-key-123");
assert_eq!(
config.get_string("base-url").unwrap(),
"https://api.test.com"
);
assert_eq!(config.get_string("chroot-dir").unwrap(), "/test/chroot");
assert_eq!(config.get_string("command-path").unwrap(), "/test/commands");
assert_eq!(config.get_string("model").unwrap(), "test-model");
assert_eq!(config.get_string("instruct").unwrap(), "Test instructions");
assert_eq!(config.get_string("server").unwrap(), "test.irc.server");
assert_eq!(config.get_int("port").unwrap(), 6667);
let channels: Vec<String> = config.get("channels").unwrap();
assert_eq!(channels, vec!["#test1", "#test2"]);
assert_eq!(config.get_string("username").unwrap(), "testuser");
assert_eq!(config.get_string("nickname").unwrap(), "testnick");
assert_eq!(config.get_bool("use-tls").unwrap(), false);
}
#[tokio::test]
async fn test_config_file_partial_settings() {
let temp = TempDir::new().unwrap();
// Only provide required settings
let config_content = "\
api-key = \"minimal-key\"
base-url = \"https://minimal.api.com\"
model = \"minimal-model\"
server = \"minimal.server\"
port = 6697
channels = [\"#minimal\"]
";
let config_path = create_config_file(&temp, config_content);
let config = parse_config_from_file(&config_path).await;
// Verify required settings are loaded
assert_eq!(config.get_string("api-key").unwrap(), "minimal-key");
assert_eq!(
config.get_string("base-url").unwrap(),
"https://minimal.api.com"
);
assert_eq!(config.get_string("model").unwrap(), "minimal-model");
// Verify optional settings are not present
assert!(config.get_string("chroot-dir").is_err());
assert!(config.get_string("instruct").is_err());
assert!(config.get_string("username").is_err());
}
#[tokio::test]
#[serial]
async fn test_config_with_environment_variables() {
// NOTE: This test documents a limitation in setup.rs
// setup.rs uses Environment::with_prefix("BOT") without a separator
// This means BOT_API_KEY maps to "api_key", NOT "api-key"
// Since config.toml uses kebab-case, environment variables won't override properly
// This is a known issue in the current implementation
let temp = TempDir::new().unwrap();
let config_content = "\
api_key = \"file-api-key\"
base_url = \"https://file.api.com\"
model = \"file-model\"
";
let config_path = create_config_file(&temp, config_content);
// Set environment variables (with BOT_ prefix as setup.rs uses)
unsafe {
std::env::set_var("BOT_API_KEY", "env-api-key");
std::env::set_var("BOT_MODEL", "env-model");
}
let config = config::Config::builder()
.add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
.add_source(config::Environment::with_prefix("BOT"))
.build()
.unwrap();
// Environment variables should override file settings (when using underscore keys)
assert_eq!(config.get_string("api_key").unwrap(), "env-api-key");
assert_eq!(config.get_string("model").unwrap(), "env-model");
// File setting should be used when no env var
assert_eq!(
config.get_string("base_url").unwrap(),
"https://file.api.com"
);
// Cleanup
unsafe {
std::env::remove_var("BOT_API_KEY");
std::env::remove_var("BOT_MODEL");
}
}
#[tokio::test]
async fn test_command_line_overrides_config_file() {
let temp = TempDir::new().unwrap();
let config_content = "\
api-key = \"file-api-key\"
base-url = \"https://file.api.com\"
model = \"file-model\"
server = \"file.server\"
port = 6667
channels = [\"#file\"]
nickname = \"filenick\"
username = \"fileuser\"
";
let config_path = create_config_file(&temp, config_content);
// Simulate command-line arguments overriding config file
let config = config::Config::builder()
.add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
.set_override_option("api-key", Some("cli-api-key".to_string()))
.unwrap()
.set_override_option("model", Some("cli-model".to_string()))
.unwrap()
.set_override_option("server", Some("cli.server".to_string()))
.unwrap()
.set_override_option("nickname", Some("clinick".to_string()))
.unwrap()
.build()
.unwrap();
// Command-line values should override file settings
assert_eq!(config.get_string("api-key").unwrap(), "cli-api-key");
assert_eq!(config.get_string("model").unwrap(), "cli-model");
assert_eq!(config.get_string("server").unwrap(), "cli.server");
assert_eq!(config.get_string("nickname").unwrap(), "clinick");
// Non-overridden values should come from file
assert_eq!(
config.get_string("base-url").unwrap(),
"https://file.api.com"
);
assert_eq!(config.get_string("username").unwrap(), "fileuser");
assert_eq!(config.get_int("port").unwrap(), 6667);
}
#[tokio::test]
#[serial]
async fn test_command_line_overrides_environment_and_file() {
let temp = TempDir::new().unwrap();
let config_content = "\
api_key = \"file-api-key\"
model = \"file-model\"
base_url = \"https://file.api.com\"
";
let config_path = create_config_file(&temp, config_content);
// Set environment variable
unsafe {
std::env::set_var("BOT_API_KEY", "env-api-key");
}
// Build config with all three sources
let config = config::Config::builder()
.add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
.add_source(config::Environment::with_prefix("BOT"))
.set_override_option("api_key", Some("cli-api-key".to_string()))
.unwrap()
.build()
.unwrap();
// Command-line should win over both environment and file
assert_eq!(config.get_string("api_key").unwrap(), "cli-api-key");
// Cleanup
unsafe {
std::env::remove_var("BOT_API_KEY");
}
}
#[tokio::test]
#[serial]
async fn test_precedence_order() {
// Test: CLI > Environment > Config File > Defaults
// Using underscore keys to match how setup.rs actually works
let temp = TempDir::new().unwrap();
let config_content = "\
api_key = \"file-key\"
base_url = \"https://file-url.com\"
model = \"file-model\"
server = \"file-server\"
";
let config_path = create_config_file(&temp, config_content);
// Set environment variables
unsafe {
std::env::set_var("BOT_BASE_URL", "https://env-url.com");
std::env::set_var("BOT_MODEL", "env-model");
}
let config = config::Config::builder()
.add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
.add_source(config::Environment::with_prefix("BOT"))
.set_override_option("model", Some("cli-model".to_string()))
.unwrap()
.build()
.unwrap();
// CLI overrides everything
assert_eq!(config.get_string("model").unwrap(), "cli-model");
// Environment overrides file
assert_eq!(
config.get_string("base_url").unwrap(),
"https://env-url.com"
);
// File is used when no env or CLI
assert_eq!(config.get_string("api_key").unwrap(), "file-key");
assert_eq!(config.get_string("server").unwrap(), "file-server");
// Cleanup
unsafe {
std::env::remove_var("BOT_BASE_URL");
std::env::remove_var("BOT_MODEL");
}
}
#[tokio::test]
async fn test_boolean_use_tls_setting() {
let temp = TempDir::new().unwrap();
// Test with use-tls = true (kebab-case as in config.toml)
let config_content_true = r#"
use-tls = true
"#;
let config_path = create_config_file(&temp, config_content_true);
let config = parse_config_from_file(&config_path).await;
assert_eq!(config.get_bool("use-tls").unwrap(), true);
// Test with use-tls = false
let config_content_false = r#"
use-tls = false
"#;
let config_path = create_config_file(&temp, config_content_false);
let config = parse_config_from_file(&config_path).await;
assert_eq!(config.get_bool("use-tls").unwrap(), false);
}
#[tokio::test]
async fn test_use_tls_naming_inconsistency() {
// This test documents a bug: setup.rs uses "use_tls" (underscore)
// but config.toml uses "use-tls" (kebab-case)
let temp = TempDir::new().unwrap();
let config_content = r#"
use-tls = true
"#;
let config_path = create_config_file(&temp, config_content);
// Build config the way setup.rs does it
let config = config::Config::builder()
.add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
// setup.rs line 119 uses "use_tls" (underscore) instead of "use-tls" (kebab)
.set_override_option("use_tls", Some(false))
.unwrap()
.build()
.unwrap();
// This should read from the override (false), not the file (true)
// But due to the naming mismatch, it might not work as expected
// The config file uses "use-tls" but the override uses "use_tls"
// With kebab-case (matches config.toml)
assert_eq!(config.get_bool("use-tls").unwrap(), true);
// With underscore (matches setup.rs override)
assert_eq!(config.get_bool("use_tls").unwrap(), false);
}
#[tokio::test]
async fn test_channels_as_array() {
let temp = TempDir::new().unwrap();
let config_content = "\
channels = [\"#chan1\", \"#chan2\", \"#chan3\"]
";
let config_path = create_config_file(&temp, config_content);
let config = parse_config_from_file(&config_path).await;
let channels: Vec<String> = config.get("channels").unwrap();
assert_eq!(channels.len(), 3);
assert_eq!(channels[0], "#chan1");
assert_eq!(channels[1], "#chan2");
assert_eq!(channels[2], "#chan3");
}
#[tokio::test]
async fn test_channels_override_from_cli() {
let temp = TempDir::new().unwrap();
let config_content = "\
channels = [\"#file1\", \"#file2\"]
";
let config_path = create_config_file(&temp, config_content);
let cli_channels = vec![
"#cli1".to_string(),
"#cli2".to_string(),
"#cli3".to_string(),
];
let config = config::Config::builder()
.add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
.set_override_option("channels", Some(cli_channels.clone()))
.unwrap()
.build()
.unwrap();
let channels: Vec<String> = config.get("channels").unwrap();
assert_eq!(channels, cli_channels);
assert_eq!(channels.len(), 3);
}
#[tokio::test]
async fn test_port_as_integer() {
let temp = TempDir::new().unwrap();
let config_content = r#"
port = 6697
"#;
let config_path = create_config_file(&temp, config_content);
let config = parse_config_from_file(&config_path).await;
// Port should be readable as both integer and string
assert_eq!(config.get_int("port").unwrap(), 6697);
assert_eq!(config.get_string("port").unwrap(), "6697");
}
#[tokio::test]
async fn test_port_override_from_cli_as_string() {
// setup.rs passes port as Option<String> from clap
let temp = TempDir::new().unwrap();
let config_content = r#"
port = 6667
"#;
let config_path = create_config_file(&temp, config_content);
let config = config::Config::builder()
.add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
.set_override_option("port", Some("9999".to_string()))
.unwrap()
.build()
.unwrap();
// CLI override should work
assert_eq!(config.get_string("port").unwrap(), "9999");
assert_eq!(config.get_int("port").unwrap(), 9999);
}
#[tokio::test]
async fn test_missing_required_fields_fails() {
let temp = TempDir::new().unwrap();
// Create config without required api-key
let config_content = r#"
model = "test-model"
"#;
let config_path = create_config_file(&temp, config_content);
let config = parse_config_from_file(&config_path).await;
// Should fail when trying to get required field
assert!(config.get_string("api-key").is_err());
}
#[tokio::test]
async fn test_optional_instruct_field() {
let temp = TempDir::new().unwrap();
let config_content = r#"
instruct = "Custom bot instructions"
"#;
let config_path = create_config_file(&temp, config_content);
let config = parse_config_from_file(&config_path).await;
assert_eq!(
config.get_string("instruct").unwrap(),
"Custom bot instructions"
);
}
#[tokio::test]
async fn test_command_path_field() {
// command-path is in config.toml but not used anywhere in the code
let temp = TempDir::new().unwrap();
let config_content = r#"
command-path = "/custom/commands"
"#;
let config_path = create_config_file(&temp, config_content);
let config = parse_config_from_file(&config_path).await;
assert_eq!(
config.get_string("command-path").unwrap(),
"/custom/commands"
);
}
#[tokio::test]
async fn test_chroot_dir_field() {
let temp = TempDir::new().unwrap();
let config_content = r#"
chroot-dir = "/var/lib/bot/root"
"#;
let config_path = create_config_file(&temp, config_content);
let config = parse_config_from_file(&config_path).await;
assert_eq!(
config.get_string("chroot-dir").unwrap(),
"/var/lib/bot/root"
);
}
#[tokio::test]
async fn test_empty_config_file() {
let temp = TempDir::new().unwrap();
let config_content = "";
let config_path = create_config_file(&temp, config_content);
// Should build successfully but have no values
let config = parse_config_from_file(&config_path).await;
assert!(config.get_string("api-key").is_err());
assert!(config.get_string("model").is_err());
}
#[tokio::test]
async fn test_all_cli_override_keys_match_config_format() {
// This test documents which override keys in setup.rs match the config.toml format
let temp = TempDir::new().unwrap();
let config_content = "\
api-key = \"test\"
base-url = \"https://test.com\"
chroot-dir = \"/test\"
command-path = \"/cmds\"
model = \"test-model\"
instruct = \"test\"
channels = [\"#test\"]
server = \"test.server\"
port = 6697
nickname = \"test\"
username = \"test\"
use-tls = true
";
let config_path = create_config_file(&temp, config_content);
let config = parse_config_from_file(&config_path).await;
// All these should work with kebab-case (as in config.toml)
assert!(config.get_string("api-key").is_ok());
assert!(config.get_string("base-url").is_ok());
assert!(config.get_string("chroot-dir").is_ok());
assert!(config.get_string("command-path").is_ok());
assert!(config.get_string("model").is_ok());
assert!(config.get_string("instruct").is_ok());
let channels_result: Result<Vec<String>, _> = config.get("channels");
assert!(channels_result.is_ok());
assert!(config.get_string("server").is_ok());
assert!(config.get_int("port").is_ok());
assert!(config.get_string("nickname").is_ok());
assert!(config.get_string("username").is_ok());
assert!(config.get_bool("use-tls").is_ok());
}