Compare commits
db292c2fd1 ... main
21 Commits

| SHA1 |
|---|
| c3b168f86f |
| b46a03c13e |
| 17b087e618 |
| 30e2d9a448 |
| 70de039610 |
| 8ec4f2860c |
| 21d9c3f002 |
| f880795b44 |
| a158ee385f |
| 2da7cc4450 |
| 3af95235e6 |
| 5d390ee9f3 |
| 7f7981d6cd |
| ae190cc421 |
| ae44cc947b |
| a3ebca0bb2 |
| 8fa79932d6 |
| 9719d9203c |
| 138df60661 |
| 5f30fdbf77 |
| b86e46fe00 |
1442  Cargo.lock  generated
File diff suppressed because it is too large.
53  Cargo.toml

@@ -4,17 +4,58 @@ version = "0.1.0"
 edition = "2024"
 
 [dependencies]
-# TODO: make this a dev and/or debug dependency later.
 better-panic = "0.3.0"
-clap = { version = "4.5", features = [ "derive" ] }
+bytes = "1"
 color-eyre = "0.6.3"
-config = { version = "0.15", features = [ "toml" ] }
 directories = "6.0"
-dotenvy_macro = "0.15"
 futures = "0.3"
 human-panic = "2.0"
-genai = "0.4.0-alpha.9"
+genai = "0.4.3"
 irc = "1.1"
-tokio = { version = "1", features = [ "full" ] }
+serde_json = "1.0"
 tracing = "0.1"
 tracing-subscriber = "0.3"
 
+[dependencies.nix]
+version = "0.30.1"
+features = ["fs", "resource"]
+
+[dependencies.clap]
+version = "4.5"
+features = ["derive"]
+
+[dependencies.config]
+version = "0.15"
+features = ["toml"]
+
+[dependencies.serde]
+version = "1.0"
+features = ["derive"]
+
+[dependencies.tokio]
+version = "1"
+features = [
+    "io-util",
+    "macros",
+    "net",
+    "process",
+    "rt-multi-thread",
+    "sync",
+    "time",
+]
+
+[dev-dependencies]
+rstest = "0.24"
+serial_test = "3.2"
+tempfile = "3.13"
+
+[dev-dependencies.cargo-husky]
+version = "1"
+features = ["run-cargo-check", "run-cargo-clippy"]
+
+[profile.release]
+strip = true
+opt-level = "z"
+lto = true
+codegen-units = 1
+panic = "abort"
11  LICENSE  Normal file

@@ -0,0 +1,11 @@
Copyright 2025 Micheal Smith

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

3. Redistributions in any form must retain this license verbatim. No additional licensing terms, including but not limited to the GNU General Public License, may be imposed on the original or modified work.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
22  README.md  Normal file

@@ -0,0 +1,22 @@
# Robotnik - A Basic LLM Capable IRC Bot.

This is an IRC bot. The name is based on a fictional video game villain.
Currently it supports any LLM that uses the OpenAI style of interface. They
can be selected via command line options, environment variables, or via a configuration
file. There is a [configuration file](config.toml) that *should* contain all available options
currently.

## Some supported but ~~possibly~~ *mostly* untested LLMs:

| Name       | Model             | Base URL                                  | Tested |
|------------|-------------------|-------------------------------------------|--------|
| OpenAI     | gpt-5             | https://api.openai.com/v1                 | no     |
| Deepseek   | deepseek-chat     | https://api.deepseek.com/v1               | yes    |
| Anthropic  | claude-sonnet-4-0 | https://api.anthropic.com/v1              | no     |
| Gemini     | gemini-2.5-turbo  | https://generativelanguage.googleapis.com | no     |
| OpenRouter | some-model        | https://api.openrouter.ai/v1              | no     |

## Further reading...

There should be a man page that might be useful. Otherwise the -h/--help
switch should hopefully suffice.
config.toml

@@ -2,6 +2,10 @@
 
 api-key = "<YOUR-KEY>"
 base-url = "api.openai.com"
+chroot-dir = "/home/bot/root"
+
+# If using chroot (recommended) then this will be relative.
+command-path = "/cmds"
 
 # If you don't already know the model name you can generally find a listing
 # on the models API pages.
42  robotnik.1  Normal file

@@ -0,0 +1,42 @@
.Dd $Mdocdate$
.Dt robotnik 1
.Os
.Sh NAME
.Nm robotnik
.Nd A simple bot that among other things uses the OpenAI API.
.\" .Sh LIBRARY
.\" For sections 2, 3, and 9 only.
.\" Not used in OpenBSD.
.Sh SYNOPSIS
.Nm progname
.Op Fl options
.Ar
.Sh DESCRIPTION
The
.Nm
utility processes files ...
.\" .Sh CONTEXT
.\" For section 9 functions only.
.\" .Sh IMPLEMENTATION NOTES
.\" Not used in OpenBSD.
.\" .Sh RETURN VALUES
.\" For sections 2, 3, and 9 function return values only.
.\" .Sh ENVIRONMENT
.\" For sections 1, 6, 7, and 8 only.
.\" .Sh FILES
.\" .Sh EXIT STATUS
.\" For sections 1, 6, and 8 only.
.\" .Sh EXAMPLES
.\" .Sh DIAGNOSTICS
.\" For sections 1, 4, 6, 7, 8, and 9 printf/stderr messages only.
.\" .Sh ERRORS
.\" For sections 2, 3, 4, and 9 errno settings only.
.\" .Sh SEE ALSO
.\" .Xr foobar 1
.\" .Sh STANDARDS
.\" .Sh HISTORY
.\" .Sh AUTHORS
.\" .Sh CAVEATS
.\" .Sh BUGS
.\" .Sh SECURITY CONSIDERATIONS
.\" Not used in OpenBSD.
rustfmt.toml

@@ -2,6 +2,8 @@ edition = "2024"
 style_edition = "2024"
 comment_width = 100
 format_code_in_doc_comments = true
+format_macro_bodies = true
+format_macro_matchers = true
 imports_granularity = "Crate"
-imports_layout = "Vertical"
+imports_layout = "HorizontalVertical"
 wrap_comments = true
183  src/chat.rs

@@ -1,89 +1,134 @@
-use color_eyre::{
-    Result,
-    eyre::{
-        OptionExt,
-        WrapErr,
-    },
-};
-// Lots of namespace confusion potential
-use crate::qna::LLMHandle;
+//! Handles interaction with IRC.
+//!
+//! Each instance of [`Chat`] handles a single connection to an IRC
+//! server.
+
+use std::sync::Arc;
+
+use color_eyre::{Result, eyre::WrapErr};
+
 use config::Config as MainConfig;
 use futures::StreamExt;
-use irc::client::prelude::{
-    Client as IRCClient,
-    Command,
-    Config as IRCConfig,
-};
-use tracing::{
-    Level,
-    event,
-    instrument,
-};
+use irc::client::prelude::{Client, Command, Config as IRCConfig, Message};
+use tokio::sync::mpsc;
+use tracing::{Level, event, instrument};
 
+use crate::{Event, EventManager, LLMHandle, plugin};
+
+/// Chat struct that is used to interact with IRC chat.
 #[derive(Debug)]
 pub struct Chat {
-    client: IRCClient,
+    /// The actual IRC [`irc::client`](client).
+    client: Client,
+    /// Event manager for handling plugin interaction.
+    event_manager: Arc<EventManager>,
+    /// Handle for whichever LLM is being used.
     llm_handle: LLMHandle, // FIXME: This needs to be thread safe, and shared, etc.
 }
 
-// Need: owners, channels, username, nick, server, password
-#[instrument]
-pub async fn new(settings: &MainConfig, handle: &LLMHandle) -> Result<Chat> {
-    // Going to just assign and let the irc library handle errors for now, and
-    // add my own checking if necessary.
-    let port: u16 = settings.get("port")?;
-    let channels: Vec<String> = settings.get("channels")
-        .wrap_err("No channels provided.")?;
-
-    event!(Level::INFO, "Channels = {:?}", channels);
-
-    let config = IRCConfig {
-        server: settings.get_string("server").ok(),
-        nickname: settings.get_string("nickname").ok(),
-        port: Some(port),
-        username: settings.get_string("username").ok(),
-        use_tls: settings.get_bool("use_tls").ok(),
-        channels,
-        ..IRCConfig::default()
-    };
-
-    event!(Level::INFO, "IRC connection starting...");
-
-    Ok(Chat {
-        client: IRCClient::from_config(config).await?,
-        llm_handle: handle.clone(),
-    })
-}
-
 impl Chat {
-    pub async fn run(&mut self) -> Result<()> {
-        let client = &mut self.client;
-
-        client.identify()?;
-
-        let outgoing = client
-            .outgoing()
-            .ok_or_eyre("Couldn't get outgoing irc sink.")?;
-        let mut stream = client.stream()?;
-
-        tokio::spawn(async move {
-            if let Err(e) = outgoing.await {
-                event!(Level::ERROR, "Failed to drive output: {}", e);
-            }
-        });
-
-        while let Some(message) = stream.next().await.transpose()? {
-            if let Command::PRIVMSG(channel, message) = message.command {
-                if message.starts_with("!gem") {
-                    let msg = self.llm_handle.send_request(message).await?;
-                    event!(Level::INFO, "Message received.");
-                    client
-                        .send_privmsg(channel, msg)
-                        .wrap_err("Couldn't send response to channel.")?;
+    // Need: owners, channels, username, nick, server, password rather than reading
+    // the config values directly.
+    /// Creates a new [`Chat`].
+    #[instrument]
+    pub async fn new(
+        settings: &MainConfig,
+        handle: &LLMHandle,
+        manager: Arc<EventManager>,
+    ) -> Result<Chat> {
+        // Going to just assign and let the irc library handle errors for now, and
+        // add my own checking if necessary.
+        let port: u16 = settings.get("port")?;
+        let channels: Vec<String> = settings.get("channels").wrap_err("No channels provided.")?;
+
+        event!(Level::INFO, "Channels = {:?}", channels);
+
+        let config = IRCConfig {
+            server: settings.get_string("server").ok(),
+            nickname: settings.get_string("nickname").ok(),
+            port: Some(port),
+            username: settings.get_string("username").ok(),
+            use_tls: settings.get_bool("use_tls").ok(),
+            channels,
+            ..IRCConfig::default()
+        };
+
+        event!(Level::INFO, "IRC connection starting...");
+
+        Ok(Chat {
+            client: Client::from_config(config).await?,
+            llm_handle: handle.clone(),
+            event_manager: manager,
+        })
+    }
+
+    /// Drives the event loop for the chat.
+    pub async fn run(&mut self, mut command_in: mpsc::Receiver<plugin::PluginMsg>) -> Result<()> {
+        self.client.identify()?;
+
+        let mut stream = self.client.stream()?;
+
+        loop {
+            tokio::select! {
+                message = stream.next() => {
+                    match message {
+                        Some(Ok(msg)) => {
+                            self.handle_chat_message(&msg).await?;
+                        }
+                        Some(Err(e)) => return Err(e.into()),
+                        None => break, // disconnected
+                    }
+                }
+                command = command_in.recv() => {
+                    event!(Level::INFO, "Received command {:#?}", command);
+                    match command {
+                        Some(plugin::PluginMsg::SendMessage { channel, message }) => {
+                            // Now to pass on the message.
+                            event!(Level::INFO, "Trying to send to channel.");
+                            self.client.send_privmsg(&channel, &message).wrap_err("Couldn't send to channel")?;
+                            event!(Level::INFO, "Message sent successfully.");
+                        }
+                        None => {
+                            event!(Level::ERROR,
+                                "Command channel unexpectedly closed - \
+                                FIFO reader may have crashed");
+                            break;
+                        }
+                    }
                 }
             }
         }
 
         Ok(())
     }
+
+    async fn handle_chat_message(&mut self, message: &Message) -> Result<()> {
+        // Broadcast anything here. If it should not be broadcasted then
+        // TryFrom should fail.
+        if let Ok(event) = Event::try_from(message) {
+            self.event_manager.broadcast(&event).await?;
+        }
+
+        // Only handle PRIVMSG for now.
+        if let Command::PRIVMSG(channel, msg) = &message.command {
+            // Just preserve the original behavior for now.
+            if msg.starts_with("!gem") {
+                let mut llm_response = self.llm_handle.send_request(msg).await?;
+
+                event!(Level::INFO, "Asked: {message}");
+                event!(Level::INFO, "Response: {llm_response}");
+
+                // Keep responses to one line for now.
+                llm_response.retain(|c| c != '\n' && c != '\r');
+
+                // TODO: Make this configurable.
+                llm_response.truncate(500);
+
+                event!(Level::INFO, "Sending {llm_response} to channel {channel}");
+                self.client.send_privmsg(channel, llm_response)?;
+            }
+        }
+
+        Ok(())
+    }
 }
193  src/command.rs  Normal file

@@ -0,0 +1,193 @@
//! Commands that are associated with external processes (commands).
//!
//! Process based plugins are just an assortment of executable files in
//! a provided directory. They are given arguments, and the response from
//! them is expected on stdout.

use std::{
    path::{Path, PathBuf},
    time::Duration,
};

use bytes::Bytes;
use color_eyre::{Result, eyre::eyre};
use tokio::{fs::try_exists, process::Command, time::timeout};
use tracing::{Level, event};

/// Handle containing information about the directory containing commands.
#[derive(Debug)]
pub struct CommandDir {
    command_path: PathBuf,
}

impl CommandDir {
    /// Register a path containing commands.
    pub fn new(command_path: impl AsRef<Path>) -> Self {
        event!(
            Level::INFO,
            "CommandDir initialized with path: {:?}",
            command_path.as_ref()
        );
        CommandDir {
            command_path: command_path.as_ref().to_path_buf(),
        }
    }

    /// Look for a command. If it exists Ok(path) is returned.
    async fn find_command(&self, name: impl AsRef<Path>) -> Result<String> {
        let path = self.command_path.join(name.as_ref());

        event!(
            Level::INFO,
            "Looking for {} command.",
            name.as_ref().display()
        );

        match try_exists(&path).await {
            Ok(true) => Ok(path.to_string_lossy().to_string()),
            Ok(false) => Err(eyre!(format!("{} Not found.", path.to_string_lossy()))),
            Err(e) => Err(e.into()),
        }
    }

    /// Run the given [`command_name`]. It should exist in the directory provided as
    /// the command_path.
    pub async fn run_command(
        &self,
        command_name: impl AsRef<str>,
        input: impl AsRef<str>,
    ) -> Result<Bytes> {
        let path = self.find_command(Path::new(command_name.as_ref())).await?;
        // Well it exists let's cross our fingers...
        let output = Command::new(path).arg(input.as_ref()).output().await?;

        if output.status.success() {
            // So far so good
            Ok(Bytes::from(output.stdout))
        } else {
            // Whoops
            Err(eyre!(format!(
                "Error running {}: {}",
                command_name.as_ref(),
                output.status
            )))
        }
    }

    /// [`run_command`] but with a timeout.
    pub async fn run_command_with_timeout(
        &self,
        command_name: impl AsRef<str>,
        input: impl AsRef<str>,
        time_out: Duration,
    ) -> Result<Bytes> {
        timeout(time_out, self.run_command(command_name, input)).await?
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        fs::{self, Permissions},
        os::unix::fs::PermissionsExt,
    };
    use tempfile::TempDir;

    fn create_test_script(dir: &Path, name: &str, script: &str) -> PathBuf {
        let path = dir.join(name);
        fs::write(&path, script).unwrap();
        fs::set_permissions(&path, Permissions::from_mode(0o755)).unwrap();
        path
    }

    #[test]
    fn test_command_dir_new() {
        let dir = CommandDir::new("/some/path");
        assert_eq!(dir.command_path, PathBuf::from("/some/path"));
    }

    #[tokio::test]
    async fn test_find_command_exists() {
        let temp = TempDir::new().unwrap();
        create_test_script(temp.path(), "test_cmd", "#!/bin/bash\necho hello");

        let cmd_dir = CommandDir::new(temp.path());
        let result = cmd_dir.find_command("test_cmd").await;

        assert!(result.is_ok());
        assert!(result.unwrap().contains("test_cmd"));
    }

    #[tokio::test]
    async fn test_find_command_not_found() {
        let temp = TempDir::new().unwrap();
        let cmd_dir = CommandDir::new(temp.path());

        let result = cmd_dir.find_command("nonexistent").await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Not found"));
    }

    #[tokio::test]
    async fn test_run_command_success() {
        let temp = TempDir::new().unwrap();
        create_test_script(temp.path(), "echo_cmd", "#!/bin/bash\necho \"$1\"");

        let cmd_dir = CommandDir::new(temp.path());
        let result = cmd_dir.run_command("echo_cmd", "hello world").await;

        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output.as_ref(), b"hello world\n");
    }

    #[tokio::test]
    async fn test_run_command_failure() {
        let temp = TempDir::new().unwrap();
        create_test_script(temp.path(), "fail_cmd", "#!/bin/bash\nexit 1");

        let cmd_dir = CommandDir::new(temp.path());
        let result = cmd_dir.run_command("fail_cmd", "input").await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Error running"));
    }

    #[tokio::test]
    async fn test_run_command_not_found() {
        let temp = TempDir::new().unwrap();
        let cmd_dir = CommandDir::new(temp.path());

        let result = cmd_dir.run_command("missing", "input").await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_run_command_with_timeout_success() {
        let temp = TempDir::new().unwrap();
        create_test_script(temp.path(), "fast_cmd", "#!/bin/bash\necho \"$1\"");

        let cmd_dir = CommandDir::new(temp.path());
        let result = cmd_dir
            .run_command_with_timeout("fast_cmd", "quick", Duration::from_secs(5))
            .await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_run_command_with_timeout_expires() {
        let temp = TempDir::new().unwrap();
        create_test_script(temp.path(), "slow_cmd", "#!/bin/bash\nsleep 10\necho done");

        let cmd_dir = CommandDir::new(temp.path());
        let result = cmd_dir
            .run_command_with_timeout("slow_cmd", "input", Duration::from_millis(100))
            .await;

        assert!(result.is_err());
    }
}
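For orientation, this is a minimal sketch of how a caller might drive CommandDir outside the bot. The command name "weather", its argument, and the /cmds path are illustrative assumptions, not part of the diff; the API (new, run_command_with_timeout returning Bytes) is taken from the code above.

use std::time::Duration;

use color_eyre::Result;
use robotnik::command::CommandDir;

#[tokio::main]
async fn main() -> Result<()> {
    // Register the directory that holds the executable commands.
    let commands = CommandDir::new("/cmds");

    // Run a hypothetical `weather` script with one argument, allowing five seconds.
    let output = commands
        .run_command_with_timeout("weather", "Helsinki", Duration::from_secs(5))
        .await?;

    // The command's stdout comes back as raw bytes.
    println!("{}", String::from_utf8_lossy(&output));
    Ok(())
}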
38  src/event.rs  Normal file

@@ -0,0 +1,38 @@
//! Internal representations of incoming events.

use irc::proto::{Command, Message};
use serde::{Deserialize, Serialize};

/// Represents an event. Probably from IRC.
#[derive(Deserialize, Serialize)]
pub struct Event {
    /// Who is the message from?
    from: String,
    /// What is the message?
    message: String,
}

impl Event {
    /// Creates a new message.
    pub fn new(from: impl Into<String>, msg: impl Into<String>) -> Self {
        Self {
            from: from.into(),
            message: msg.into(),
        }
    }
}

impl TryFrom<&Message> for Event {
    type Error = &'static str;

    fn try_from(value: &Message) -> Result<Self, Self::Error> {
        let from = value.response_target().unwrap_or("unknown").to_string();
        match &value.command {
            Command::PRIVMSG(_channel, message) => Ok(Event {
                from,
                message: message.clone(),
            }),
            _ => Err("Not a PRIVMSG"),
        }
    }
}
587  src/event_manager.rs  Normal file

@@ -0,0 +1,587 @@
//! Handler for events to and from IPC, and process plugins.

use std::{collections::VecDeque, path::Path, sync::Arc};

use color_eyre::Result;
use nix::{NixPath, sys::stat, unistd::mkfifo};
use tokio::{
    io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
    net::{UnixListener, UnixStream, unix::pipe},
    sync::{RwLock, broadcast, mpsc},
};
use tracing::{error, info};

use crate::{event::Event, plugin::PluginMsg};

// Hard coding for now. Maybe make this a parameter to new.
const EVENT_BUF_MAX: usize = 1000;

/// Manager for communication with plugins.
///
/// Keeps events in a ring buffer to track a certain amount of history.
#[derive(Debug)]
pub struct EventManager {
    announce: broadcast::Sender<String>, // Everything broadcasts here.
    events: Arc<RwLock<VecDeque<String>>>, // Ring buffer.
}

impl EventManager {
    /// Create a new [`EventManager`].
    pub fn new() -> Result<Self> {
        let (announce, _) = broadcast::channel(100);

        Ok(Self {
            announce,
            events: Arc::new(RwLock::new(VecDeque::<String>::new())),
        })
    }

    /// Broadcast an event to every subscribed listener.
    pub async fn broadcast(&self, event: &Event) -> Result<()> {
        let msg = serde_json::to_string(event)? + "\n";

        let mut events = self.events.write().await;

        if events.len() >= EVENT_BUF_MAX {
            events.pop_front();
        }

        events.push_back(msg.clone());
        drop(events);

        let _ = self.announce.send(msg);

        Ok(())
    }

    // NB: This assumes it has exclusive control of the FIFO.
    /// Opens a fifo at [`path`]. This is where some plugins can send response events
    /// to. The messages MUST be formatted in JSON and match one of the possible
    /// [`PluginMsg`](plugin messages).
    pub async fn start_fifo<P>(path: &P, command_tx: mpsc::Sender<PluginMsg>) -> Result<()>
    where
        P: AsRef<Path> + NixPath + ?Sized,
    {
        // Overwrite, or create the FIFO.
        let _ = std::fs::remove_file(path);
        mkfifo(path, stat::Mode::S_IRWXU)?;

        loop {
            let rx = pipe::OpenOptions::new().open_receiver(path)?;

            let mut reader = BufReader::new(rx);
            let mut line = String::new();

            while reader.read_line(&mut line).await? > 0 {
                // Now handle the command.
                let cmd: PluginMsg = serde_json::from_str(&line)?;
                info!("Command received: {:?}.", cmd);
                command_tx.send(cmd).await?;
                line.clear();
            }
        }
    }

    /// Start a UNIX socket that will provide broadcast messages to any client that opens
    /// the socket for listening.
    pub async fn start_listening(self: Arc<Self>, broadcast_path: impl AsRef<Path>) {
        let listener = UnixListener::bind(broadcast_path).unwrap();

        loop {
            match listener.accept().await {
                Ok((stream, _addr)) => {
                    info!("New broadcast subscriber");
                    // Spawn a new stream for the plugin. The loop
                    // runs recursively from there.
                    let broadcaster = Arc::clone(&self);
                    tokio::spawn(async move {
                        // send events.
                        let _ = broadcaster.send_events(stream).await;
                    });
                }
                Err(e) => error!("Accept error: {e}"),
            }
        }
    }

    /// Send any events queued up to the [`stream`].
    async fn send_events(&self, stream: UnixStream) -> Result<()> {
        let mut writer = stream;

        // Take care of history.
        let events = self.events.read().await;
        for event in events.iter() {
            writer.write_all(event.as_bytes()).await?;
        }
        drop(events);

        // Now just broadcast the new events.
        let mut rx = self.announce.subscribe();
        while let Ok(event) = rx.recv().await {
            if writer.write_all(event.as_bytes()).await.is_err() {
                // *click*
                break;
            }
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    #[tokio::test]
    async fn test_new_event_manager_has_empty_buffer() {
        let manager = EventManager::new().unwrap();
        let events = manager.events.read().await;
        assert_eq!(events.len(), 0);
    }

    #[tokio::test]
    async fn test_broadcast_adds_event_to_buffer() {
        let manager = EventManager::new().unwrap();
        let event = Event::new("test_user", "test message");

        manager.broadcast(&event).await.unwrap();

        let events = manager.events.read().await;
        assert_eq!(events.len(), 1);
        assert!(events[0].contains("test message"));
        assert!(events[0].ends_with('\n'));
    }

    #[tokio::test]
    async fn test_broadcast_serializes_event_as_json() {
        let manager = EventManager::new().unwrap();
        let event = Event::new("test_user", "hello world");

        manager.broadcast(&event).await.unwrap();

        let events = manager.events.read().await;
        let stored = &events[0];

        // Should be valid JSON
        let parsed: serde_json::Value = serde_json::from_str(stored.trim()).unwrap();
        assert_eq!(parsed["message"], "hello world");
    }

    #[rstest]
    #[case(1)]
    #[case(10)]
    #[case(100)]
    #[case(999)]
    #[tokio::test]
    async fn test_buffer_holds_events_below_max(#[case] count: usize) {
        let manager = EventManager::new().unwrap();

        for i in 0..count {
            let event = Event::new("test_user", format!("event {}", i));
            manager.broadcast(&event).await.unwrap();
        }

        let events = manager.events.read().await;
        assert_eq!(events.len(), count);
    }

    #[tokio::test]
    async fn test_buffer_at_exactly_max_capacity() {
        let manager = EventManager::new().unwrap();

        // Fill to exactly EVENT_BUF_MAX (1000)
        for i in 0..EVENT_BUF_MAX {
            let event = Event::new("test_user", format!("event {}", i));
            manager.broadcast(&event).await.unwrap();
        }

        let events = manager.events.read().await;
        assert_eq!(events.len(), EVENT_BUF_MAX);
        assert!(events[0].contains("event 0"));
        assert!(events[EVENT_BUF_MAX - 1].contains("event 999"));
    }

    #[rstest]
    #[case(1)]
    #[case(10)]
    #[case(100)]
    #[case(500)]
    #[tokio::test]
    async fn test_buffer_overflow_evicts_oldest_fifo(#[case] overflow: usize) {
        let manager = EventManager::new().unwrap();
        let total = EVENT_BUF_MAX + overflow;

        // Broadcast more events than buffer can hold
        for i in 0..total {
            let event = Event::new("test_user", format!("event {}", i));
            manager.broadcast(&event).await.unwrap();
        }

        let events = manager.events.read().await;

        // Buffer should still be at max capacity
        assert_eq!(events.len(), EVENT_BUF_MAX);

        // Oldest events (0 through overflow-1) should be evicted
        // Buffer should contain events [overflow..total)
        let first_event = &events[0];
        let last_event = &events[EVENT_BUF_MAX - 1];

        assert!(first_event.contains(&format!("event {}", overflow)));
        assert!(last_event.contains(&format!("event {}", total - 1)));

        // Verify the evicted events are NOT in the buffer
        let buffer_string = events.iter().cloned().collect::<String>();
        assert!(!buffer_string.contains(r#""message":"event 0""#));
    }

    #[tokio::test]
    async fn test_multiple_broadcasts_maintain_order() {
        let manager = EventManager::new().unwrap();
        let messages = vec!["first", "second", "third", "fourth", "fifth"];

        for msg in &messages {
            let event = Event::new("test_user", *msg);
            manager.broadcast(&event).await.unwrap();
        }

        let events = manager.events.read().await;
        assert_eq!(events.len(), messages.len());

        for (i, expected) in messages.iter().enumerate() {
            assert!(events[i].contains(expected));
        }
    }

    #[tokio::test]
    async fn test_buffer_wraparound_maintains_newest_events() {
        let manager = EventManager::new().unwrap();

        // Fill buffer completely
        for i in 0..EVENT_BUF_MAX {
            let event = Event::new("test_user", format!("old {}", i));
            manager.broadcast(&event).await.unwrap();
        }

        // Add 5 more events
        for i in 0..5 {
            let event = Event::new("test_user", format!("new {}", i));
            manager.broadcast(&event).await.unwrap();
        }

        let events = manager.events.read().await;
        assert_eq!(events.len(), EVENT_BUF_MAX);

        // First 5 old events should be gone
        let buffer_string = events.iter().cloned().collect::<String>();
        assert!(!buffer_string.contains(r#""message":"old 0""#));
        assert!(!buffer_string.contains(r#""message":"old 4""#));

        // But old 5 should still be there (now at the front)
        assert!(events[0].contains("old 5"));

        // New events should be at the end
        assert!(events[EVENT_BUF_MAX - 5].contains("new 0"));
        assert!(events[EVENT_BUF_MAX - 1].contains("new 4"));
    }

    #[tokio::test]
    async fn test_concurrent_broadcasts_all_stored() {
        let manager = Arc::new(EventManager::new().unwrap());
        let mut handles = vec![];

        // Spawn 10 concurrent tasks, each broadcasting 10 events
        for task_id in 0..10 {
            let manager_clone = Arc::clone(&manager);
            let handle = tokio::spawn(async move {
                for i in 0..10 {
                    let event = Event::new("test_user", format!("task {} event {}", task_id, i));
                    manager_clone.broadcast(&event).await.unwrap();
                }
            });
            handles.push(handle);
        }

        // Wait for all tasks to complete
        for handle in handles {
            handle.await.unwrap();
        }

        let events = manager.events.read().await;
        assert_eq!(events.len(), 100);
    }

    #[tokio::test]
    async fn test_fifo_receives_and_forwards_single_command() {
        let temp_dir = tempfile::tempdir().unwrap();
        let fifo_path = temp_dir.path().join("test.fifo");
        let (tx, mut rx) = mpsc::channel(10);

        // Spawn the FIFO reader
        let path = fifo_path.clone();
        tokio::spawn(async move {
            let _ = EventManager::start_fifo(&path, tx).await;
        });

        // Give it time to create the FIFO
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;

        // Write a command to the FIFO
        let cmd = PluginMsg::SendMessage {
            channel: "#test".to_string(),
            message: "hello".to_string(),
        };
        let json = serde_json::to_string(&cmd).unwrap() + "\n";

        // Open FIFO for writing and write the command
        tokio::spawn(async move {
            use tokio::io::AsyncWriteExt;
            let tx = pipe::OpenOptions::new().open_sender(&fifo_path).unwrap();
            let mut tx = tokio::io::BufWriter::new(tx);
            tx.write_all(json.as_bytes()).await.unwrap();
            tx.flush().await.unwrap();
        });

        // Should receive the command within a reasonable time
        let received = tokio::time::timeout(tokio::time::Duration::from_secs(1), rx.recv())
            .await
            .expect("timeout waiting for command")
            .expect("channel closed");

        match received {
            PluginMsg::SendMessage { channel, message } => {
                assert_eq!(channel, "#test");
                assert_eq!(message, "hello");
            }
        }
    }

    #[tokio::test]
    async fn test_fifo_handles_multiple_commands() {
        let temp_dir = tempfile::tempdir().unwrap();
        let fifo_path = temp_dir.path().join("test.fifo");
        let (tx, mut rx) = mpsc::channel(10);

        // Spawn the FIFO reader
        let path = fifo_path.clone();
        tokio::spawn(async move {
            let _ = EventManager::start_fifo(&path, tx).await;
        });

        // Give it time to create the FIFO
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;

        // Write multiple commands
        let commands = vec![
            PluginMsg::SendMessage {
                channel: "#chan1".to_string(),
                message: "first".to_string(),
            },
            PluginMsg::SendMessage {
                channel: "#chan2".to_string(),
                message: "second".to_string(),
            },
            PluginMsg::SendMessage {
                channel: "#chan3".to_string(),
                message: "third".to_string(),
            },
        ];

        tokio::spawn(async move {
            use tokio::io::AsyncWriteExt;
            let tx = pipe::OpenOptions::new().open_sender(&fifo_path).unwrap();
            let mut tx = tokio::io::BufWriter::new(tx);

            for cmd in commands {
                let json = serde_json::to_string(&cmd).unwrap() + "\n";
                tx.write_all(json.as_bytes()).await.unwrap();
            }
            tx.flush().await.unwrap();
        });

        // Receive all three commands in order
        let first = tokio::time::timeout(tokio::time::Duration::from_secs(1), rx.recv())
            .await
            .expect("timeout on first")
            .expect("channel closed");

        match first {
            PluginMsg::SendMessage { channel, message } => {
                assert_eq!(channel, "#chan1");
                assert_eq!(message, "first");
            }
        }

        let second = tokio::time::timeout(tokio::time::Duration::from_secs(1), rx.recv())
            .await
            .expect("timeout on second")
            .expect("channel closed");

        match second {
            PluginMsg::SendMessage { channel, message } => {
                assert_eq!(channel, "#chan2");
                assert_eq!(message, "second");
            }
        }

        let third = tokio::time::timeout(tokio::time::Duration::from_secs(1), rx.recv())
            .await
            .expect("timeout on third")
            .expect("channel closed");

        match third {
            PluginMsg::SendMessage { channel, message } => {
                assert_eq!(channel, "#chan3");
                assert_eq!(message, "third");
            }
        }
    }

    #[tokio::test]
    async fn test_fifo_reopens_after_writer_closes() {
        let temp_dir = tempfile::tempdir().unwrap();
        let fifo_path = temp_dir.path().join("test.fifo");
        let (tx, mut rx) = mpsc::channel(10);

        // Spawn the FIFO reader
        let path = fifo_path.clone();
        tokio::spawn(async move {
            let _ = EventManager::start_fifo(&path, tx).await;
        });

        // Give it time to create the FIFO
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;

        // First writer sends a command and closes
        {
            use tokio::io::AsyncWriteExt;
            let path = fifo_path.clone();
            tokio::spawn(async move {
                let tx = pipe::OpenOptions::new().open_sender(&path).unwrap();
                let mut tx = tokio::io::BufWriter::new(tx);

                let cmd = PluginMsg::SendMessage {
                    channel: "#first".to_string(),
                    message: "batch1".to_string(),
                };
                let json = serde_json::to_string(&cmd).unwrap() + "\n";
                tx.write_all(json.as_bytes()).await.unwrap();
                tx.flush().await.unwrap();
                // Writer drops here, closing the FIFO
            });

            let first = tokio::time::timeout(tokio::time::Duration::from_secs(1), rx.recv())
                .await
                .expect("timeout on first batch")
                .expect("channel closed");

            match first {
                PluginMsg::SendMessage { channel, message } => {
                    assert_eq!(channel, "#first");
                    assert_eq!(message, "batch1");
                }
            }
        }

        // Give the FIFO time to reopen
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        // Second writer opens and sends a command
        {
            use tokio::io::AsyncWriteExt;
            tokio::spawn(async move {
                let tx = pipe::OpenOptions::new().open_sender(&fifo_path).unwrap();
                let mut tx = tokio::io::BufWriter::new(tx);

                let cmd = PluginMsg::SendMessage {
                    channel: "#second".to_string(),
                    message: "batch2".to_string(),
                };
                let json = serde_json::to_string(&cmd).unwrap() + "\n";
                tx.write_all(json.as_bytes()).await.unwrap();
                tx.flush().await.unwrap();
            });

            let second = tokio::time::timeout(tokio::time::Duration::from_secs(1), rx.recv())
                .await
                .expect("timeout on second batch - FIFO may not have reopened")
                .expect("channel closed");

            match second {
                PluginMsg::SendMessage { channel, message } => {
                    assert_eq!(channel, "#second");
                    assert_eq!(message, "batch2");
                }
            }
        }
    }

    #[tokio::test]
    async fn test_fifo_handles_empty_lines() {
        let temp_dir = tempfile::tempdir().unwrap();
        let fifo_path = temp_dir.path().join("test.fifo");
        let (tx, mut rx) = mpsc::channel(10);

        // Spawn the FIFO reader
        let path = fifo_path.clone();
        let handle = tokio::spawn(async move { EventManager::start_fifo(&path, tx).await });

        // Give it time to create the FIFO
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;

        // Write command, empty line, whitespace, another command
        tokio::spawn(async move {
            use tokio::io::AsyncWriteExt;
            let tx = pipe::OpenOptions::new().open_sender(&fifo_path).unwrap();
            let mut tx = tokio::io::BufWriter::new(tx);

            let cmd1 = PluginMsg::SendMessage {
                channel: "#test".to_string(),
                message: "first".to_string(),
            };
            let json1 = serde_json::to_string(&cmd1).unwrap() + "\n";
            tx.write_all(json1.as_bytes()).await.unwrap();

            // Write empty line
            tx.write_all(b"\n").await.unwrap();

            // Write whitespace line
            tx.write_all(b" \n").await.unwrap();

            let cmd2 = PluginMsg::SendMessage {
                channel: "#test".to_string(),
                message: "second".to_string(),
            };
            let json2 = serde_json::to_string(&cmd2).unwrap() + "\n";
            tx.write_all(json2.as_bytes()).await.unwrap();
            tx.flush().await.unwrap();
        });

        // Should receive first command
        let first = tokio::time::timeout(tokio::time::Duration::from_millis(500), rx.recv())
            .await
            .expect("timeout on first")
            .expect("channel closed");

        match first {
            PluginMsg::SendMessage { channel, message } => {
                assert_eq!(channel, "#test");
                assert_eq!(message, "first");
            }
        }

        // The empty/whitespace lines should cause JSON parse errors
        // which will cause start_fifo to error and exit
        // So we expect the handle to complete (with an error)
        let result = tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
            .await
            .expect("FIFO task should exit due to parse error");

        // The task should have errored
        assert!(
            result.unwrap().is_err(),
            "Expected parse error from empty line"
        );
    }
}
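As a companion sketch, this is how an external plugin might subscribe to the broadcast socket that run() opens at /tmp/robo.sock (that path comes from src/lib.rs below). Each event is one line of JSON serialized from Event { from, message }; the binary name and error handling here are illustrative assumptions, not part of the diff.

use color_eyre::Result;
use tokio::{
    io::{AsyncBufReadExt, BufReader},
    net::UnixStream,
};

#[tokio::main]
async fn main() -> Result<()> {
    // Connect to the broadcast socket the bot binds in run().
    let stream = UnixStream::connect("/tmp/robo.sock").await?;
    let mut lines = BufReader::new(stream).lines();

    // The manager first replays its ring-buffer history, then live events.
    while let Some(line) = lines.next_line().await? {
        let event: serde_json::Value = serde_json::from_str(&line)?;
        println!("{} said: {}", event["from"], event["message"]);
    }
    Ok(())
}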
100  src/lib.rs  Normal file

@@ -0,0 +1,100 @@
#![warn(missing_docs)]
#![doc = include_str!("../README.md")]

use std::{os::unix::fs, sync::Arc};

use color_eyre::{Result, eyre::WrapErr};
use human_panic::setup_panic;
use tokio::sync::mpsc;
use tracing::{Level, info};
use tracing_subscriber::FmtSubscriber;

pub mod chat;
pub mod command;
pub mod event;
pub mod event_manager;
pub mod plugin;
pub mod qna;
pub mod setup;

pub use chat::Chat;
pub use event::Event;
pub use event_manager::EventManager;
pub use qna::LLMHandle;

const DEFAULT_INSTRUCT: &str =
    "You are a shady, yet helpful IRC bot. You try to give responses that can
be sent in a single IRC response according to the specification. Keep answers to
500 characters or less.";

/// Initialize all logging facilities.
///
/// This should cause a panic if there's a failure.
async fn init_logging() {
    better_panic::install();
    setup_panic!();

    let subscriber = FmtSubscriber::builder()
        .with_max_level(Level::TRACE)
        .finish();

    tracing::subscriber::set_global_default(subscriber).unwrap();
}

/// Sets up and runs the main event loop.
///
/// Should return an error if it's recoverable, but could panic if something
/// is particularly bad.
pub async fn run() -> Result<()> {
    init_logging().await;
    info!("Starting up.");

    let settings = setup::init().await.wrap_err("Failed to initialize.")?;
    let config = settings.config;

    // NOTE: Doing chroot this way might be impractical.
    if let Ok(chroot_path) = config.get_string("chroot-dir") {
        info!("Attempting to chroot to {}", chroot_path);
        fs::chroot(&chroot_path)
            .wrap_err_with(|| format!("Failed setting chroot '{}'", chroot_path))?;
        std::env::set_current_dir("/").wrap_err("Couldn't change directory after chroot.")?;
    }

    let handle = qna::LLMHandle::new(
        config.get_string("api-key").wrap_err("API missing.")?,
        config
            .get_string("base-url")
            .wrap_err("base-url missing.")?,
        config
            .get_string("model")
            .wrap_err("model string missing.")?,
        config
            .get_string("instruct")
            .unwrap_or_else(|_| DEFAULT_INSTRUCT.to_string()),
    )
    .wrap_err("Couldn't initialize LLM handle.")?;

    let ev_manager = Arc::new(EventManager::new()?);
    let ev_manager_clone = Arc::clone(&ev_manager);

    let mut c = Chat::new(&config, &handle, Arc::clone(&ev_manager)).await?;

    let (from_plugins, to_chat) = mpsc::channel(100);

    tokio::select! {
        _ = ev_manager_clone.start_listening("/tmp/robo.sock") => {
            // Event listener ended
        }
        result = c.run(to_chat) => {
            if let Err(e) = result {
                tracing::error!("Chat run error: {:?}", e);
                return Err(e);
            }
        }
        fifo = EventManager::start_fifo("/tmp/robo_in.sock", from_plugins) => {
            fifo.wrap_err("FIFO reader failed.")?;
        }
    }

    Ok(())
}
53  src/main.rs

@@ -1,55 +1,6 @@
-use color_eyre::{
-    Result,
-    eyre::WrapErr,
-};
-use human_panic::setup_panic;
-use tracing::{
-    Level,
-    info,
-};
-use tracing_subscriber::FmtSubscriber;
-
-mod chat;
-mod command;
-mod qna;
-mod setup;
-
-const DEFAULT_INSTRUCT: &'static str =
-    "You are a shady, yet helpful IRC bot. You try to give responses that can
-be sent in a single IRC response according to the specification.";
+use color_eyre::Result;
 
 #[tokio::main]
 async fn main() -> Result<()> {
-    // Some error sprucing.
-    better_panic::install();
-    setup_panic!();
-
-    let subscriber = FmtSubscriber::builder()
-        .with_max_level(Level::TRACE)
-        .finish();
-
-    tracing::subscriber::set_global_default(subscriber)
-        .wrap_err("Failed to setup trace logging.")?;
-
-    info!("Starting");
-
-    let settings = setup::init().await.wrap_err("Failed to initialize.")?;
-    let config = settings.config;
-
-    let handle = qna::new(
-        config.get_string("api-key").wrap_err("API missing.")?,
-        config
-            .get_string("base-url")
-            .wrap_err("base-url missing.")?,
-        config
-            .get_string("model")
-            .wrap_err("model string missing.")?,
-        config.get_string("instruct").unwrap_or_else(|_| DEFAULT_INSTRUCT.to_string()),
-    )
-    .wrap_err("Couldn't initialize LLM handle.")?;
-    let mut c = chat::new(&config, &handle).await?;
-
-    c.run().await.unwrap();
-
-    Ok(())
+    robotnik::run().await
 }
37  src/plugin.rs  Normal file

@@ -0,0 +1,37 @@
//! Plugin command definitions.

// Dear future me: If you forget the JSON translations in the future you'll
// thank me for the comment overkill.

use std::fmt::Display;

use serde::{Deserialize, Serialize};

/// Message types accepted from plugins.
#[derive(Debug, Deserialize, Serialize)]
pub enum PluginMsg {
    /// Plugin message indicating the bot should send a [`message`] to [`channel`].
    /// {
    ///     "SendMessage": {
    ///         "channel": "channel_name",
    ///         "message": "your message here"
    ///     }
    /// }
    SendMessage {
        /// The IRC channel to send the [`message`] to.
        channel: String,
        /// The [`message`] to send.
        message: String,
    },
}

impl Display for PluginMsg {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::SendMessage { channel, message } => {
                write!(f, "[{channel}]: {message}")
            }
        }
    }
}
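To make the JSON translation concrete, here is a minimal sketch of a plugin that sends one PluginMsg back to the bot through the command FIFO. The serde_json-plus-newline framing mirrors what the event_manager tests above do; the FIFO path /tmp/robo_in.sock comes from src/lib.rs, while the channel and message text are made up for illustration.

use robotnik::plugin::PluginMsg;
use tokio::{fs::OpenOptions, io::AsyncWriteExt};

#[tokio::main]
async fn main() -> color_eyre::Result<()> {
    let msg = PluginMsg::SendMessage {
        channel: "#mychannel".to_string(),
        message: "hello from a plugin".to_string(),
    };

    // One JSON object per line is exactly what the FIFO reader expects.
    let line = serde_json::to_string(&msg)? + "\n";

    // The bot creates this FIFO in run(); opening it for writing wakes the reader.
    let mut fifo = OpenOptions::new().write(true).open("/tmp/robo_in.sock").await?;
    fifo.write_all(line.as_bytes()).await?;
    Ok(())
}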
71
src/qna.rs
71
src/qna.rs
@@ -1,23 +1,20 @@
+//! Handles communication with a genai compatible LLM.
+
 use color_eyre::Result;
 use futures::StreamExt;
 use genai::{
     Client,
     ModelIden,
-    chat::{
-        ChatMessage,
-        ChatRequest,
-        ChatStreamEvent,
-        StreamChunk,
-    },
-    resolver::{
-        AuthData,
-        AuthResolver,
-    },
+    chat::{ChatMessage, ChatRequest, ChatStreamEvent, StreamChunk},
+    resolver::{AuthData, AuthResolver},
 };
 use tracing::info;
 
+// NB: Docs are quick and dirty as this might move into a plugin.
+
 // Represents an LLM completion source.
 // FIXME: Clone is probably temporary.
+/// Struct containing information about the LLM.
 #[derive(Clone, Debug)]
 pub struct LLMHandle {
     chat_request: ChatRequest,
@@ -25,33 +22,35 @@ pub struct LLMHandle {
     model: String,
 }
 
-pub fn new(
-    api_key: String,
-    _base_url: impl AsRef<str>,
-    model: impl Into<String>,
-    system_role: String,
-) -> Result<LLMHandle> {
-    let auth_resolver = AuthResolver::from_resolver_fn(
-        |_model_iden: ModelIden| -> Result<Option<AuthData>, genai::resolver::Error> {
-            // let ModelIden { adapter_kind, model_name } = model_iden;
-
-            Ok(Some(AuthData::from_single(api_key)))
-        },
-    );
-
-    let client = Client::builder().with_auth_resolver(auth_resolver).build();
-    let chat_request = ChatRequest::default().with_system(system_role);
-
-    info!("New LLMHandle created.");
-
-    Ok(LLMHandle {
-        client,
-        chat_request,
-        model: model.into(),
-    })
-}
-
 impl LLMHandle {
+    /// Create a new handle.
+    pub fn new(
+        api_key: String,
+        _base_url: impl AsRef<str>,
+        model: impl Into<String>,
+        system_role: String,
+    ) -> Result<LLMHandle> {
+        let auth_resolver = AuthResolver::from_resolver_fn(
+            |_model_iden: ModelIden| -> Result<Option<AuthData>, genai::resolver::Error> {
+                // let ModelIden { adapter_kind, model_name } = model_iden;
+
+                Ok(Some(AuthData::from_single(api_key)))
+            },
+        );
+
+        let client = Client::builder().with_auth_resolver(auth_resolver).build();
+        let chat_request = ChatRequest::default().with_system(system_role);
+
+        info!("New LLMHandle created.");
+
+        Ok(LLMHandle {
+            client,
+            chat_request,
+            model: model.into(),
+        })
+    }
+
+    /// Send a chat message to the LLM with the response being returned as a [`String`].
    pub async fn send_request(&mut self, message: impl Into<String>) -> Result<String> {
        let mut req = self.chat_request.clone();
        let client = self.client.clone();
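The hunk above moves the constructor into `impl LLMHandle`. Purely as an illustration (not part of this change set), a caller would now drive the handle roughly as below; the `robotnik::llm` module path, the key, the model name, and the prompt are placeholder assumptions:

// Illustrative sketch only. Assumes the handle is exported as robotnik::llm::LLMHandle.
use color_eyre::Result;
use robotnik::llm::LLMHandle;

async fn demo() -> Result<()> {
    let mut llm = LLMHandle::new(
        "sk-example".to_string(),                // api_key (placeholder)
        "https://api.openai.com",                // _base_url (currently unused)
        "deepseek-chat",                         // model
        "You are a terse IRC bot.".to_string(),  // system_role
    )?;
    let reply = llm.send_request("Say hello to the channel").await?;
    println!("{reply}");
    Ok(())
}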
src/setup.rs (86 changed lines)
@@ -1,78 +1,101 @@
+//! Handles configuration for the bot.
+//!
+//! Both command line, and configuration file options are handled here.
+
 use clap::Parser;
-use color_eyre::{
-    Result,
-    eyre::WrapErr,
-};
+use color_eyre::{Result, eyre::WrapErr};
 use config::Config;
 use directories::ProjectDirs;
 use std::path::PathBuf;
-use tracing::{
-    info,
-    instrument,
-};
+use tracing::{info, instrument};
 
 // TODO: use [clap(long, short, help_heading = Some(section))]
+/// Struct of potential arguments.
 #[derive(Clone, Debug, Parser)]
 #[command(about, version)]
-pub(crate) struct Args {
+pub struct Args {
     #[arg(short, long)]
     /// API Key for the LLM in use.
-    pub(crate) api_key: Option<String>,
+    pub api_key: Option<String>,
 
     #[arg(short, long, default_value = "https://api.openai.com")]
     /// Base URL for the LLM API to use.
-    pub(crate) base_url: Option<String>,
+    pub base_url: Option<String>,
 
+    /// Directory to use for chroot (recommended).
+    #[arg(long)]
+    pub chroot_dir: Option<String>,
+
+    /// Root directory for file based command structure.
+    #[arg(long)]
+    pub command_dir: Option<String>,
+
     #[arg(long)]
     /// Instructions to the model on how to behave.
-    pub(crate) intruct: Option<String>,
+    pub instruct: Option<String>,
 
     #[arg(long)]
-    pub(crate) model: Option<String>,
+    /// Name of the model to use. E.g. 'deepseek-chat'
+    pub model: Option<String>,
 
     #[arg(long)]
     /// List of IRC channels to join.
-    pub(crate) channels: Option<Vec<String>>,
+    pub channels: Option<Vec<String>>,
 
     #[arg(short, long)]
     /// Custom configuration file location if need be.
-    pub(crate) config_file: Option<PathBuf>,
+    pub config_file: Option<PathBuf>,
 
     #[arg(short, long, default_value = "irc.libera.chat")]
     /// IRC server.
-    pub(crate) server: Option<String>,
+    pub server: Option<String>,
 
     #[arg(short, long, default_value = "6697")]
     /// Port of the IRC server.
-    pub(crate) port: Option<String>,
+    pub port: Option<String>,
 
     #[arg(long)]
     /// IRC Nickname.
-    pub(crate) nickname: Option<String>,
+    pub nickname: Option<String>,
 
     #[arg(long)]
     /// IRC Nick Password
-    pub(crate) nick_password: Option<String>,
+    pub nick_password: Option<String>,
 
     #[arg(long)]
     /// IRC Username
-    pub(crate) username: Option<String>,
+    pub username: Option<String>,
 
-    #[arg(long)]
+    #[arg(long = "no-tls")]
     /// Whether or not to use TLS when connecting to the IRC server.
-    pub(crate) use_tls: Option<bool>,
+    pub use_tls: Option<bool>,
 }
 
-pub(crate) struct Setup {
-    pub(crate) config: Config,
+/// Handle for interacting with the bot configuration.
+pub struct Setup {
+    /// Handle for the configuration file options.
+    pub config: Config,
 }
 
 #[instrument]
+/// Initialize a new [`Setup`] instance.
+///
+/// This reads the settings file which becomes the bot's default configuration.
+/// These settings shall be overridden by any command line options.
 pub async fn init() -> Result<Setup> {
     // Get arguments. These overrule configuration file, and environment
     // variables if applicable.
     let args = Args::parse();
 
+    let settings = make_config(args)?;
+
+    Ok(Setup { config: settings })
+}
+
+/// Create a configuration object from arguments.
+///
+/// This is exposed for testing purposes.
+pub fn make_config(args: Args) -> Result<Config> {
     // Use default config location unless specified.
     let config_location: PathBuf = if let Some(ref path) = args.config_file {
         path.to_owned()
@@ -86,23 +109,24 @@ pub async fn init() -> Result<Setup> {
 
     info!("Starting.");
 
-    let settings = Config::builder()
+    Config::builder()
         .add_source(config::File::with_name(&config_location.to_string_lossy()).required(false))
         .add_source(config::Environment::with_prefix("BOT"))
         // Doing all of these overrides provides a unified access point for options,
         // but a derive macro could do this a bit better if this becomes too large.
         .set_override_option("api-key", args.api_key.clone())?
         .set_override_option("base-url", args.base_url.clone())?
+        .set_override_option("chroot-dir", args.chroot_dir.clone())?
+        .set_override_option("command-path", args.command_dir.clone())?
         .set_override_option("model", args.model.clone())?
-        .set_override_option("instruct", args.model.clone())?
+        .set_override_option("nick-password", args.nick_password.clone())?
+        .set_override_option("instruct", args.instruct.clone())?
        .set_override_option("channels", args.channels.clone())?
        .set_override_option("server", args.server.clone())?
-        .set_override_option("port", args.port.clone())? // FIXME: Make this a default here not in clap.
+        .set_override_option("port", args.port.clone())?
        .set_override_option("nickname", args.nickname.clone())?
        .set_override_option("username", args.username.clone())?
-        .set_override_option("use_tls", args.use_tls)?
+        .set_override_option("use-tls", args.use_tls)?
        .build()
-        .wrap_err("Couldn't read configuration settings.")?;
-
-    Ok(Setup { config: settings })
+        .wrap_err("Couldn't read configuration settings.")
 }
tests/command_test.rs (new file, 290 lines)

use std::{
    fs::{self, Permissions},
    os::unix::fs::PermissionsExt,
    path::Path,
    time::Duration,
};

use robotnik::command::CommandDir;
use tempfile::TempDir;

/// Helper to create executable test scripts
fn create_command(dir: &Path, name: &str, script: &str) {
    let path = dir.join(name);
    fs::write(&path, script).unwrap();
    fs::set_permissions(&path, Permissions::from_mode(0o755)).unwrap();
}

/// Parse a bot message like "!weather 07008" into (command_name, argument)
fn parse_bot_message(message: &str) -> Option<(&str, &str)> {
    if !message.starts_with('!') {
        return None;
    }
    let without_prefix = &message[1..];
    let mut parts = without_prefix.splitn(2, ' ');
    let command = parts.next()?;
    let arg = parts.next().unwrap_or("");
    Some((command, arg))
}

#[tokio::test]
async fn test_bot_message_finds_and_runs_command() {
    let temp = TempDir::new().unwrap();

    // Create a weather command that echoes the zip code
    create_command(
        temp.path(),
        "weather",
        r#"#!/bin/bash
echo "Weather for $1: Sunny, 72°F"
"#,
    );

    let cmd_dir = CommandDir::new(temp.path());
    let message = "!weather 10096";

    // Parse the message
    let (command_name, arg) = parse_bot_message(message).unwrap();
    assert_eq!(command_name, "weather");
    assert_eq!(arg, "10096");

    // Find and run the command
    let result = cmd_dir.run_command(command_name, arg).await;

    assert!(result.is_ok());
    let bytes = result.unwrap();
    let output = String::from_utf8_lossy(&bytes);
    assert!(output.contains("Weather for 10096"));
    assert!(output.contains("Sunny"));
}

#[tokio::test]
async fn test_bot_message_command_not_found() {
    let temp = TempDir::new().unwrap();
    let cmd_dir = CommandDir::new(temp.path());

    let message = "!nonexistent arg";
    let (command_name, arg) = parse_bot_message(message).unwrap();

    let result = cmd_dir.run_command(command_name, arg).await;

    assert!(result.is_err());
    assert!(result.unwrap_err().to_string().contains("Not found"));
}

#[tokio::test]
async fn test_bot_message_with_multiple_arguments() {
    let temp = TempDir::new().unwrap();

    // Create a command that handles multiple words as a single argument
    create_command(
        temp.path(),
        "echo",
        r#"#!/bin/bash
echo "You said: $1"
"#,
    );

    let cmd_dir = CommandDir::new(temp.path());
    let message = "!echo hello world how are you";

    let (command_name, arg) = parse_bot_message(message).unwrap();
    assert_eq!(command_name, "echo");
    assert_eq!(arg, "hello world how are you");

    let result = cmd_dir.run_command(command_name, arg).await;

    assert!(result.is_ok());
    let bytes = result.unwrap();
    let output = String::from_utf8_lossy(&bytes);
    assert!(output.contains("hello world how are you"));
}

#[tokio::test]
async fn test_bot_message_without_argument() {
    let temp = TempDir::new().unwrap();

    create_command(
        temp.path(),
        "help",
        r#"#!/bin/bash
echo "Available commands: weather, echo, help"
"#,
    );

    let cmd_dir = CommandDir::new(temp.path());
    let message = "!help";

    let (command_name, arg) = parse_bot_message(message).unwrap();
    assert_eq!(command_name, "help");
    assert_eq!(arg, "");

    let result = cmd_dir.run_command(command_name, arg).await;

    assert!(result.is_ok());
    let bytes = result.unwrap();
    let output = String::from_utf8_lossy(&bytes);
    assert!(output.contains("Available commands"));
}

#[tokio::test]
async fn test_bot_message_command_returns_error_exit_code() {
    let temp = TempDir::new().unwrap();

    // Create a command that fails for invalid input
    create_command(
        temp.path(),
        "validate",
        r#"#!/bin/bash
if [ -z "$1" ]; then
    echo "Error: Input required" >&2
    exit 1
fi
echo "Valid: $1"
"#,
    );

    let cmd_dir = CommandDir::new(temp.path());
    let message = "!validate";

    let (command_name, arg) = parse_bot_message(message).unwrap();
    let result = cmd_dir.run_command(command_name, arg).await;

    assert!(result.is_err());
    assert!(result.unwrap_err().to_string().contains("Error running"));
}

#[tokio::test]
async fn test_bot_message_with_timeout() {
    let temp = TempDir::new().unwrap();

    create_command(
        temp.path(),
        "quick",
        r#"#!/bin/bash
echo "Result: $1"
"#,
    );

    let cmd_dir = CommandDir::new(temp.path());
    let message = "!quick test";

    let (command_name, arg) = parse_bot_message(message).unwrap();
    let result = cmd_dir
        .run_command_with_timeout(command_name, arg, Duration::from_secs(5))
        .await;

    assert!(result.is_ok());
}

#[tokio::test]
async fn test_bot_message_command_times_out() {
    let temp = TempDir::new().unwrap();

    create_command(
        temp.path(),
        "slow",
        r#"#!/bin/bash
sleep 10
echo "Done"
"#,
    );

    let cmd_dir = CommandDir::new(temp.path());
    let message = "!slow arg";

    let (command_name, arg) = parse_bot_message(message).unwrap();
    let result = cmd_dir
        .run_command_with_timeout(command_name, arg, Duration::from_millis(100))
        .await;

    assert!(result.is_err());
}

#[tokio::test]
async fn test_multiple_commands_in_directory() {
    let temp = TempDir::new().unwrap();

    create_command(
        temp.path(),
        "weather",
        r#"#!/bin/bash
echo "Weather: Sunny"
"#,
    );

    create_command(
        temp.path(),
        "time",
        r#"#!/bin/bash
echo "Time: 12:00"
"#,
    );

    create_command(
        temp.path(),
        "joke",
        r#"#!/bin/bash
echo "Why did the robot go on vacation? To recharge!"
"#,
    );

    let cmd_dir = CommandDir::new(temp.path());

    // Test each command
    let messages = ["!weather", "!time", "!joke"];
    let expected = ["Sunny", "12:00", "recharge"];

    for (message, expect) in messages.iter().zip(expected.iter()) {
        let (command_name, arg) = parse_bot_message(message).unwrap();
        let result = cmd_dir.run_command(command_name, arg).await;
        assert!(result.is_ok());
        let bytes = result.unwrap();
        let output = String::from_utf8_lossy(&bytes);
        assert!(
            output.contains(expect),
            "Expected '{}' in '{}'",
            expect,
            output
        );
    }
}

#[tokio::test]
async fn test_non_bot_message_ignored() {
    // Messages not starting with ! should be ignored
    let messages = ["hello world", "weather 10096", "?help", "/command", ""];

    for message in messages {
        assert!(
            parse_bot_message(message).is_none(),
            "Should ignore: {}",
            message
        );
    }
}

#[tokio::test]
async fn test_command_output_is_bytes() {
    let temp = TempDir::new().unwrap();

    // Create a command that outputs binary-safe content
    create_command(
        temp.path(),
        "binary",
        r#"#!/bin/bash
printf "Hello\x00World"
"#,
    );

    let cmd_dir = CommandDir::new(temp.path());
    let message = "!binary test";

    let (command_name, arg) = parse_bot_message(message).unwrap();
    let result = cmd_dir.run_command(command_name, arg).await;

    assert!(result.is_ok());
    let output = result.unwrap();
    // Should preserve the null byte
    assert_eq!(&output[..], b"Hello\x00World");
}
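These tests drive `CommandDir` from `robotnik::command`, whose implementation is not part of this comparison. Below is a minimal sketch consistent with what the tests assume (lookup under a root directory, stdout returned as bytes, "Not found" / "Error running" errors, and a timeout wrapper); it is illustrative only and the real src/command.rs may differ:

// Hypothetical sketch of the interface exercised by tests/command_test.rs.
use std::{path::PathBuf, time::Duration};

use bytes::Bytes;
use color_eyre::eyre::{Result, eyre};
use tokio::{process::Command, time::timeout};

pub struct CommandDir {
    root: PathBuf,
}

impl CommandDir {
    pub fn new(root: impl Into<PathBuf>) -> Self {
        Self { root: root.into() }
    }

    /// Run `<root>/<name> <arg>` and return its stdout as bytes.
    pub async fn run_command(&self, name: &str, arg: &str) -> Result<Bytes> {
        let path = self.root.join(name);
        if !path.is_file() {
            return Err(eyre!("Not found: {name}"));
        }
        let output = Command::new(&path).arg(arg).output().await?;
        if !output.status.success() {
            return Err(eyre!("Error running {name}"));
        }
        Ok(Bytes::from(output.stdout))
    }

    /// Same as `run_command`, but give up after `limit`.
    pub async fn run_command_with_timeout(
        &self,
        name: &str,
        arg: &str,
        limit: Duration,
    ) -> Result<Bytes> {
        timeout(limit, self.run_command(name, arg))
            .await
            .map_err(|_| eyre!("Command {name} timed out"))?
    }
}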
tests/event_test.rs (new file, 495 lines)

use std::{sync::Arc, time::Duration};

use robotnik::{event::Event, event_manager::EventManager};
use rstest::rstest;
use tokio::{
    io::{AsyncBufReadExt, BufReader},
    net::UnixStream,
    time::timeout,
};

const TEST_SOCKET_BASE: &str = "/tmp/robotnik_test";

/// Helper to create unique socket paths for parallel tests
fn test_socket_path(name: &str) -> String {
    format!("{}_{}_{}", TEST_SOCKET_BASE, name, std::process::id())
}

/// Helper to read one JSON event from a stream
async fn read_event(
    reader: &mut BufReader<UnixStream>,
) -> Result<Event, Box<dyn std::error::Error>> {
    let mut line = String::new();
    reader.read_line(&mut line).await?;
    let event: Event = serde_json::from_str(&line)?;
    Ok(event)
}

/// Helper to read all available events with a timeout
async fn read_events_with_timeout(
    reader: &mut BufReader<UnixStream>,
    max_count: usize,
    timeout_ms: u64,
) -> Vec<String> {
    let mut events = Vec::new();
    for _ in 0..max_count {
        let mut line = String::new();
        match timeout(
            Duration::from_millis(timeout_ms),
            reader.read_line(&mut line),
        )
        .await
        {
            Ok(Ok(0)) => break, // EOF
            Ok(Ok(_)) => events.push(line),
            Ok(Err(_)) => break, // Read error
            Err(_) => break,     // Timeout
        }
    }
    events
}

#[tokio::test]
async fn test_client_connects_and_receives_event() {
    let socket_path = test_socket_path("basic_connect");
    let manager = Arc::new(EventManager::new().unwrap());

    // Start the listener
    let listener_manager = Arc::clone(&manager);
    let socket_path_clone = socket_path.clone();
    tokio::spawn(async move {
        listener_manager.start_listening(socket_path_clone).await;
    });

    // Give the listener time to start
    tokio::time::sleep(Duration::from_millis(50)).await;

    // Broadcast an event
    let event = Event::new("test_user", "test message");
    manager.broadcast(&event).await.unwrap();

    // Connect as a client
    let stream = UnixStream::connect(&socket_path).await.unwrap();
    let mut reader = BufReader::new(stream);

    // Read the event
    let mut line = String::new();
    reader.read_line(&mut line).await.unwrap();

    assert!(line.contains("test message"));
    assert!(line.ends_with('\n'));

    // Cleanup
    let _ = std::fs::remove_file(&socket_path);
}

#[tokio::test]
async fn test_client_receives_event_history() {
    let socket_path = test_socket_path("event_history");
    let manager = Arc::new(EventManager::new().unwrap());

    // Broadcast events BEFORE starting the listener
    for i in 0..5 {
        let event = Event::new("test_user", format!("historical event {}", i));
        manager.broadcast(&event).await.unwrap();
    }

    // Start the listener
    let listener_manager = Arc::clone(&manager);
    let socket_path_clone = socket_path.clone();
    tokio::spawn(async move {
        listener_manager.start_listening(socket_path_clone).await;
    });

    tokio::time::sleep(Duration::from_millis(50)).await;

    // Connect as a client
    let stream = UnixStream::connect(&socket_path).await.unwrap();
    let mut reader = BufReader::new(stream);

    // Should receive all 5 historical events
    let events = read_events_with_timeout(&mut reader, 5, 100).await;

    assert_eq!(events.len(), 5);
    assert!(events[0].contains("historical event 0"));
    assert!(events[4].contains("historical event 4"));

    // Cleanup
    let _ = std::fs::remove_file(&socket_path);
}

#[tokio::test]
async fn test_multiple_clients_receive_same_events() {
    let socket_path = test_socket_path("multiple_clients");
    let manager = Arc::new(EventManager::new().unwrap());

    // Start the listener
    let listener_manager = Arc::clone(&manager);
    let socket_path_clone = socket_path.clone();
    tokio::spawn(async move {
        listener_manager.start_listening(socket_path_clone).await;
    });

    tokio::time::sleep(Duration::from_millis(50)).await;

    // Connect 3 clients
    let stream1 = UnixStream::connect(&socket_path).await.unwrap();
    let stream2 = UnixStream::connect(&socket_path).await.unwrap();
    let stream3 = UnixStream::connect(&socket_path).await.unwrap();

    let mut reader1 = BufReader::new(stream1);
    let mut reader2 = BufReader::new(stream2);
    let mut reader3 = BufReader::new(stream3);

    // Broadcast a new event
    let event = Event::new("test_user", "broadcast to all");
    manager.broadcast(&event).await.unwrap();

    // All clients should receive the event
    let mut line1 = String::new();
    let mut line2 = String::new();
    let mut line3 = String::new();

    timeout(Duration::from_millis(100), reader1.read_line(&mut line1))
        .await
        .unwrap()
        .unwrap();
    timeout(Duration::from_millis(100), reader2.read_line(&mut line2))
        .await
        .unwrap()
        .unwrap();
    timeout(Duration::from_millis(100), reader3.read_line(&mut line3))
        .await
        .unwrap()
        .unwrap();

    assert!(line1.contains("broadcast to all"));
    assert!(line2.contains("broadcast to all"));
    assert!(line3.contains("broadcast to all"));

    // Cleanup
    let _ = std::fs::remove_file(&socket_path);
}

#[tokio::test]
async fn test_late_joiner_receives_full_history() {
    let socket_path = test_socket_path("late_joiner");
    let manager = Arc::new(EventManager::new().unwrap());

    // Start the listener
    let listener_manager = Arc::clone(&manager);
    let socket_path_clone = socket_path.clone();
    tokio::spawn(async move {
        listener_manager.start_listening(socket_path_clone).await;
    });

    tokio::time::sleep(Duration::from_millis(50)).await;

    // First client connects
    let stream1 = UnixStream::connect(&socket_path).await.unwrap();
    let mut reader1 = BufReader::new(stream1);

    // Broadcast several events
    for i in 0..10 {
        let event = Event::new("test_user", format!("event {}", i));
        manager.broadcast(&event).await.unwrap();
    }

    // Consume events from first client
    let _ = read_events_with_timeout(&mut reader1, 10, 100).await;

    // Late joiner connects
    let stream2 = UnixStream::connect(&socket_path).await.unwrap();
    let mut reader2 = BufReader::new(stream2);

    // Late joiner should receive all 10 events from history
    let events = read_events_with_timeout(&mut reader2, 10, 100).await;

    assert_eq!(events.len(), 10);
    assert!(events[0].contains("event 0"));
    assert!(events[9].contains("event 9"));

    // Cleanup
    let _ = std::fs::remove_file(&socket_path);
}

#[tokio::test]
async fn test_client_receives_events_in_order() {
    let socket_path = test_socket_path("event_order");
    let manager = Arc::new(EventManager::new().unwrap());

    // Start the listener
    let listener_manager = Arc::clone(&manager);
    let socket_path_clone = socket_path.clone();
    tokio::spawn(async move {
        listener_manager.start_listening(socket_path_clone).await;
    });

    tokio::time::sleep(Duration::from_millis(50)).await;

    // Connect client
    let stream = UnixStream::connect(&socket_path).await.unwrap();
    let mut reader = BufReader::new(stream);

    // Broadcast events rapidly
    let count = 50;
    for i in 0..count {
        let event = Event::new("test_user", format!("sequence {}", i));
        manager.broadcast(&event).await.unwrap();
    }

    // Read all events
    let events = read_events_with_timeout(&mut reader, count, 500).await;

    assert_eq!(events.len(), count);

    // Verify order
    for (i, event) in events.iter().enumerate() {
        assert!(
            event.contains(&format!("sequence {}", i)),
            "Event {} out of order: {}",
            i,
            event
        );
    }

    // Cleanup
    let _ = std::fs::remove_file(&socket_path);
}

#[tokio::test]
async fn test_concurrent_broadcasts_during_client_connections() {
    let socket_path = test_socket_path("concurrent_ops");
    let manager = Arc::new(EventManager::new().unwrap());

    // Start the listener
    let listener_manager = Arc::clone(&manager);
    let socket_path_clone = socket_path.clone();
    tokio::spawn(async move {
        listener_manager.start_listening(socket_path_clone).await;
    });

    tokio::time::sleep(Duration::from_millis(50)).await;

    // Connect client 1 BEFORE any broadcasts
    let stream1 = UnixStream::connect(&socket_path).await.unwrap();
    let mut reader1 = BufReader::new(stream1);

    // Spawn a task that continuously broadcasts
    let broadcast_manager = Arc::clone(&manager);
    let broadcast_handle = tokio::spawn(async move {
        for i in 0..100 {
            let event = Event::new("test_user", format!("concurrent event {}", i));
            broadcast_manager.broadcast(&event).await.unwrap();
            tokio::time::sleep(Duration::from_millis(5)).await;
        }
    });

    // While broadcasting, connect more clients at different times
    tokio::time::sleep(Duration::from_millis(100)).await;
    let stream2 = UnixStream::connect(&socket_path).await.unwrap();
    let mut reader2 = BufReader::new(stream2);

    tokio::time::sleep(Duration::from_millis(150)).await;
    let stream3 = UnixStream::connect(&socket_path).await.unwrap();
    let mut reader3 = BufReader::new(stream3);

    // Wait for broadcasts to complete
    broadcast_handle.await.unwrap();

    // All clients should have received events
    let events1 = read_events_with_timeout(&mut reader1, 100, 200).await;
    let events2 = read_events_with_timeout(&mut reader2, 100, 200).await;
    let events3 = read_events_with_timeout(&mut reader3, 100, 200).await;

    // Client 1 connected first (before any broadcasts), should get all 100
    assert_eq!(events1.len(), 100);

    // Client 2 connected after ~20 events were broadcast
    // Gets ~20 from history + ~80 live = 100
    assert_eq!(events2.len(), 100);

    // Client 3 connected after ~50 events were broadcast
    // Gets ~50 from history + ~50 live = 100
    assert_eq!(events3.len(), 100);

    // Verify they all received events in order
    assert!(events1[0].contains("concurrent event 0"));
    assert!(events2[0].contains("concurrent event 0"));
    assert!(events3[0].contains("concurrent event 0"));

    // Cleanup
    let _ = std::fs::remove_file(&socket_path);
}

#[tokio::test]
async fn test_buffer_overflow_affects_new_clients() {
    let socket_path = test_socket_path("buffer_overflow");
    let manager = Arc::new(EventManager::new().unwrap());

    // Broadcast more than buffer max (1000)
    for i in 0..1100 {
        let event = Event::new("test_user", format!("overflow event {}", i));
        manager.broadcast(&event).await.unwrap();
    }

    // Start the listener
    let listener_manager = Arc::clone(&manager);
    let socket_path_clone = socket_path.clone();
    tokio::spawn(async move {
        listener_manager.start_listening(socket_path_clone).await;
    });

    tokio::time::sleep(Duration::from_millis(50)).await;

    // New client connects
    let stream = UnixStream::connect(&socket_path).await.unwrap();
    let mut reader = BufReader::new(stream);

    // Should receive exactly 1000 events (buffer max)
    let events = read_events_with_timeout(&mut reader, 1100, 500).await;

    assert_eq!(events.len(), 1000);

    // First event should be 100 (oldest 100 were evicted)
    assert!(events[0].contains("overflow event 100"));

    // Last event should be 1099
    assert!(events[999].contains("overflow event 1099"));

    // Cleanup
    let _ = std::fs::remove_file(&socket_path);
}

#[rstest]
#[case(10, 1)]
#[case(50, 5)]
#[tokio::test]
async fn test_client_count_scaling(#[case] num_clients: usize, #[case] events_per_client: usize) {
    let socket_path = test_socket_path(&format!("scaling_{}_{}", num_clients, events_per_client));
    let manager = Arc::new(EventManager::new().unwrap());

    // Start the listener
    let listener_manager = Arc::clone(&manager);
    let socket_path_clone = socket_path.clone();
    tokio::spawn(async move {
        listener_manager.start_listening(socket_path_clone).await;
    });

    tokio::time::sleep(Duration::from_millis(50)).await;

    // Connect many clients
    let mut readers = Vec::new();
    for _ in 0..num_clients {
        let stream = UnixStream::connect(&socket_path).await.unwrap();
        readers.push(BufReader::new(stream));
    }

    // Broadcast events
    for i in 0..events_per_client {
        let event = Event::new("test_user", format!("scale event {}", i));
        manager.broadcast(&event).await.unwrap();
    }

    // Verify all clients received all events
    for reader in &mut readers {
        let events = read_events_with_timeout(reader, events_per_client, 200).await;
        assert_eq!(events.len(), events_per_client);
    }

    // Cleanup
    let _ = std::fs::remove_file(&socket_path);
}

#[tokio::test]
async fn test_client_disconnect_doesnt_affect_others() {
    let socket_path = test_socket_path("disconnect");
    let manager = Arc::new(EventManager::new().unwrap());

    // Start the listener
    let listener_manager = Arc::clone(&manager);
    let socket_path_clone = socket_path.clone();
    tokio::spawn(async move {
        listener_manager.start_listening(socket_path_clone).await;
    });

    tokio::time::sleep(Duration::from_millis(50)).await;

    // Connect 3 clients
    let stream1 = UnixStream::connect(&socket_path).await.unwrap();
    let stream2 = UnixStream::connect(&socket_path).await.unwrap();
    let stream3 = UnixStream::connect(&socket_path).await.unwrap();

    let mut reader1 = BufReader::new(stream1);
    let mut reader2 = BufReader::new(stream2);
    let mut reader3 = BufReader::new(stream3);

    // Broadcast initial event
    manager
        .broadcast(&Event::new("test_user", "before disconnect"))
        .await
        .unwrap();

    // All receive it
    let _ = read_events_with_timeout(&mut reader1, 1, 100).await;
    let _ = read_events_with_timeout(&mut reader2, 1, 100).await;
    let _ = read_events_with_timeout(&mut reader3, 1, 100).await;

    // Drop client 2 (simulates disconnect)
    drop(reader2);

    // Broadcast another event
    manager
        .broadcast(&Event::new("test_user", "after disconnect"))
        .await
        .unwrap();

    // Clients 1 and 3 should still receive it
    let events1 = read_events_with_timeout(&mut reader1, 1, 100).await;
    let events3 = read_events_with_timeout(&mut reader3, 1, 100).await;

    assert_eq!(events1.len(), 1);
    assert_eq!(events3.len(), 1);
    assert!(events1[0].contains("after disconnect"));
    assert!(events3[0].contains("after disconnect"));

    // Cleanup
    let _ = std::fs::remove_file(&socket_path);
}

#[tokio::test]
async fn test_json_deserialization_of_received_events() {
    let socket_path = test_socket_path("json_deser");
    let manager = Arc::new(EventManager::new().unwrap());

    // Start the listener
    let listener_manager = Arc::clone(&manager);
    let socket_path_clone = socket_path.clone();
    tokio::spawn(async move {
        listener_manager.start_listening(socket_path_clone).await;
    });

    tokio::time::sleep(Duration::from_millis(50)).await;

    // Broadcast an event with special characters
    let test_message = "special chars: @#$% newline\\n tab\\t quotes \"test\"";
    manager
        .broadcast(&Event::new("test_user", test_message))
        .await
        .unwrap();

    // Connect and deserialize
    let stream = UnixStream::connect(&socket_path).await.unwrap();
    let mut reader = BufReader::new(stream);

    let mut line = String::new();
    reader.read_line(&mut line).await.unwrap();

    // Should be valid JSON
    let parsed: serde_json::Value = serde_json::from_str(line.trim()).unwrap();

    assert_eq!(parsed["message"], test_message);

    // Cleanup
    let _ = std::fs::remove_file(&socket_path);
}
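Likewise, `EventManager` from `robotnik::event_manager` is exercised above but not shown in this comparison. The following is a hedged sketch of only the behaviour the tests rely on: a bounded replay buffer of 1000 serialized events and a Unix-socket listener that replays them as JSON lines; live fan-out to already-connected clients is deliberately omitted, and field names such as `nick` are assumptions:

// Hypothetical sketch; the real src/event_manager.rs may differ.
use std::collections::VecDeque;

use color_eyre::Result;
use serde::{Deserialize, Serialize};
use tokio::{io::AsyncWriteExt, net::UnixListener, sync::Mutex};

const HISTORY_MAX: usize = 1000;

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Event {
    pub nick: String,    // assumed field name for the first Event::new argument
    pub message: String, // matches the "message" key asserted in the JSON test
}

pub struct EventManager {
    // Bounded history of already-serialized events; a full implementation
    // would also hold a broadcast channel for live delivery.
    history: Mutex<VecDeque<String>>,
}

impl EventManager {
    pub fn new() -> Result<Self> {
        Ok(Self { history: Mutex::new(VecDeque::new()) })
    }

    /// Serialize the event and append it to the replay buffer, evicting the oldest entry once full.
    pub async fn broadcast(&self, event: &Event) -> Result<()> {
        let mut history = self.history.lock().await;
        if history.len() == HISTORY_MAX {
            history.pop_front();
        }
        history.push_back(serde_json::to_string(event)?);
        Ok(())
    }

    /// Accept clients on a Unix socket and replay the buffer as newline-delimited JSON.
    pub async fn start_listening(&self, path: String) {
        let listener = UnixListener::bind(&path).expect("bind socket");
        loop {
            let Ok((mut stream, _)) = listener.accept().await else { continue };
            for line in self.history.lock().await.iter() {
                let _ = stream.write_all(format!("{line}\n").as_bytes()).await;
            }
        }
    }
}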
tests/setup_test.rs (new file, 556 lines)

use robotnik::setup::{Args, make_config};
use serial_test::serial;
use std::{fs, path::PathBuf};
use tempfile::TempDir;

/// Helper to create a temporary config file
fn create_config_file(dir: &TempDir, content: &str) -> PathBuf {
    let config_path = dir.path().join("config.toml");
    fs::write(&config_path, content).unwrap();
    config_path
}

/// Helper to parse config using environment and config file
async fn parse_config_from_file(config_path: &PathBuf) -> config::Config {
    config::Config::builder()
        .add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
        .build()
        .unwrap()
}

#[tokio::test]
#[serial]
async fn test_setup_make_config_overrides() {
    let temp = TempDir::new().unwrap();
    let config_content = "\
api-key = \"file-key\"
model = \"file-model\"
port = 6667
";
    let config_path = create_config_file(&temp, config_content);

    // Construct Args with overrides
    let args = Args {
        api_key: Some("cli-key".to_string()),
        base_url: None, /* Should fail if required and not in file/env? No, base-url is optional
                         * in args */
        chroot_dir: None,
        command_dir: None,
        instruct: None,
        model: None, // Should fallback to file
        channels: None,
        config_file: Some(config_path),
        server: None, // Should use default or file? Args has default "irc.libera.chat"
        port: Some("9999".to_string()),
        nickname: None,
        nick_password: None,
        username: None,
        use_tls: None,
    };

    let config = make_config(args).expect("Failed to make config");

    // Check overrides
    assert_eq!(config.get_string("api-key").unwrap(), "cli-key");
    assert_eq!(config.get_string("port").unwrap(), "9999");
    assert_eq!(config.get_int("port").unwrap(), 9999);

    // Check fallback to file
    assert_eq!(config.get_string("model").unwrap(), "file-model");
}

#[tokio::test]
async fn test_config_file_loads_all_settings() {
    let temp = TempDir::new().unwrap();
    let config_content = "\
api-key = \"test-api-key-123\"
base-url = \"https://api.test.com\"
chroot-dir = \"/test/chroot\"
command-path = \"/test/commands\"
model = \"test-model\"
instruct = \"Test instructions\"
server = \"test.irc.server\"
port = 6667
channels = [\"#test1\", \"#test2\"]
username = \"testuser\"
nickname = \"testnick\"
use-tls = false
";

    let config_path = create_config_file(&temp, config_content);
    let config = parse_config_from_file(&config_path).await;

    // Verify all settings are loaded correctly
    assert_eq!(config.get_string("api-key").unwrap(), "test-api-key-123");
    assert_eq!(
        config.get_string("base-url").unwrap(),
        "https://api.test.com"
    );
    assert_eq!(config.get_string("chroot-dir").unwrap(), "/test/chroot");
    assert_eq!(config.get_string("command-path").unwrap(), "/test/commands");
    assert_eq!(config.get_string("model").unwrap(), "test-model");
    assert_eq!(config.get_string("instruct").unwrap(), "Test instructions");
    assert_eq!(config.get_string("server").unwrap(), "test.irc.server");
    assert_eq!(config.get_int("port").unwrap(), 6667);

    let channels: Vec<String> = config.get("channels").unwrap();
    assert_eq!(channels, vec!["#test1", "#test2"]);

    assert_eq!(config.get_string("username").unwrap(), "testuser");
    assert_eq!(config.get_string("nickname").unwrap(), "testnick");
    assert_eq!(config.get_bool("use-tls").unwrap(), false);
}

#[tokio::test]
async fn test_config_file_partial_settings() {
    let temp = TempDir::new().unwrap();
    // Only provide required settings
    let config_content = "\
api-key = \"minimal-key\"
base-url = \"https://minimal.api.com\"
model = \"minimal-model\"
server = \"minimal.server\"
port = 6697
channels = [\"#minimal\"]
";

    let config_path = create_config_file(&temp, config_content);
    let config = parse_config_from_file(&config_path).await;

    // Verify required settings are loaded
    assert_eq!(config.get_string("api-key").unwrap(), "minimal-key");
    assert_eq!(
        config.get_string("base-url").unwrap(),
        "https://minimal.api.com"
    );
    assert_eq!(config.get_string("model").unwrap(), "minimal-model");

    // Verify optional settings are not present
    assert!(config.get_string("chroot-dir").is_err());
    assert!(config.get_string("instruct").is_err());
    assert!(config.get_string("username").is_err());
}

#[tokio::test]
#[serial]
async fn test_config_with_environment_variables() {
    // NOTE: This test documents a limitation in setup.rs
    // setup.rs uses Environment::with_prefix("BOT") without a separator
    // This means BOT_API_KEY maps to "api_key", NOT "api-key"
    // Since config.toml uses kebab-case, environment variables won't override properly
    // This is a known issue in the current implementation

    let temp = TempDir::new().unwrap();
    let config_content = "\
api_key = \"file-api-key\"
base_url = \"https://file.api.com\"
model = \"file-model\"
";

    let config_path = create_config_file(&temp, config_content);

    // Set environment variables (with BOT_ prefix as setup.rs uses)
    unsafe {
        std::env::set_var("BOT_API_KEY", "env-api-key");
        std::env::set_var("BOT_MODEL", "env-model");
    }

    let config = config::Config::builder()
        .add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
        .add_source(config::Environment::with_prefix("BOT"))
        .build()
        .unwrap();

    // Environment variables should override file settings (when using underscore keys)
    assert_eq!(config.get_string("api_key").unwrap(), "env-api-key");
    assert_eq!(config.get_string("model").unwrap(), "env-model");
    // File setting should be used when no env var
    assert_eq!(
        config.get_string("base_url").unwrap(),
        "https://file.api.com"
    );

    // Cleanup
    unsafe {
        std::env::remove_var("BOT_API_KEY");
        std::env::remove_var("BOT_MODEL");
    }
}

#[tokio::test]
async fn test_command_line_overrides_config_file() {
    let temp = TempDir::new().unwrap();
    let config_content = "\
api-key = \"file-api-key\"
base-url = \"https://file.api.com\"
model = \"file-model\"
server = \"file.server\"
port = 6667
channels = [\"#file\"]
nickname = \"filenick\"
username = \"fileuser\"
";

    let config_path = create_config_file(&temp, config_content);

    // Simulate command-line arguments overriding config file
    let config = config::Config::builder()
        .add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
        .set_override_option("api-key", Some("cli-api-key".to_string()))
        .unwrap()
        .set_override_option("model", Some("cli-model".to_string()))
        .unwrap()
        .set_override_option("server", Some("cli.server".to_string()))
        .unwrap()
        .set_override_option("nickname", Some("clinick".to_string()))
        .unwrap()
        .build()
        .unwrap();

    // Command-line values should override file settings
    assert_eq!(config.get_string("api-key").unwrap(), "cli-api-key");
    assert_eq!(config.get_string("model").unwrap(), "cli-model");
    assert_eq!(config.get_string("server").unwrap(), "cli.server");
    assert_eq!(config.get_string("nickname").unwrap(), "clinick");

    // Non-overridden values should come from file
    assert_eq!(
        config.get_string("base-url").unwrap(),
        "https://file.api.com"
    );
    assert_eq!(config.get_string("username").unwrap(), "fileuser");
    assert_eq!(config.get_int("port").unwrap(), 6667);
}

#[tokio::test]
#[serial]
async fn test_command_line_overrides_environment_and_file() {
    let temp = TempDir::new().unwrap();
    let config_content = "\
api_key = \"file-api-key\"
model = \"file-model\"
base_url = \"https://file.api.com\"
";

    let config_path = create_config_file(&temp, config_content);

    // Set environment variable
    unsafe {
        std::env::set_var("BOT_API_KEY", "env-api-key");
    }

    // Build config with all three sources
    let config = config::Config::builder()
        .add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
        .add_source(config::Environment::with_prefix("BOT"))
        .set_override_option("api_key", Some("cli-api-key".to_string()))
        .unwrap()
        .build()
        .unwrap();

    // Command-line should win over both environment and file
    assert_eq!(config.get_string("api_key").unwrap(), "cli-api-key");

    // Cleanup
    unsafe {
        std::env::remove_var("BOT_API_KEY");
    }
}

#[tokio::test]
#[serial]
async fn test_precedence_order() {
    // Test: CLI > Environment > Config File > Defaults
    // Using underscore keys to match how setup.rs actually works
    let temp = TempDir::new().unwrap();
    let config_content = "\
api_key = \"file-key\"
base_url = \"https://file-url.com\"
model = \"file-model\"
server = \"file-server\"
";

    let config_path = create_config_file(&temp, config_content);

    // Set environment variables
    unsafe {
        std::env::set_var("BOT_BASE_URL", "https://env-url.com");
        std::env::set_var("BOT_MODEL", "env-model");
    }

    let config = config::Config::builder()
        .add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
        .add_source(config::Environment::with_prefix("BOT"))
        .set_override_option("model", Some("cli-model".to_string()))
        .unwrap()
        .build()
        .unwrap();

    // CLI overrides everything
    assert_eq!(config.get_string("model").unwrap(), "cli-model");

    // Environment overrides file
    assert_eq!(
        config.get_string("base_url").unwrap(),
        "https://env-url.com"
    );

    // File is used when no env or CLI
    assert_eq!(config.get_string("api_key").unwrap(), "file-key");
    assert_eq!(config.get_string("server").unwrap(), "file-server");

    // Cleanup
    unsafe {
        std::env::remove_var("BOT_BASE_URL");
        std::env::remove_var("BOT_MODEL");
    }
}

#[tokio::test]
async fn test_boolean_use_tls_setting() {
    let temp = TempDir::new().unwrap();

    // Test with use-tls = true (kebab-case as in config.toml)
    let config_content_true = r#"
use-tls = true
"#;
    let config_path = create_config_file(&temp, config_content_true);
    let config = parse_config_from_file(&config_path).await;
    assert_eq!(config.get_bool("use-tls").unwrap(), true);

    // Test with use-tls = false
    let config_content_false = r#"
use-tls = false
"#;
    let config_path = create_config_file(&temp, config_content_false);
    let config = parse_config_from_file(&config_path).await;
    assert_eq!(config.get_bool("use-tls").unwrap(), false);
}

#[tokio::test]
async fn test_use_tls_naming_inconsistency() {
    // This test documents a bug: setup.rs uses "use_tls" (underscore)
    // but config.toml uses "use-tls" (kebab-case)
    let temp = TempDir::new().unwrap();
    let config_content = r#"
use-tls = true
"#;

    let config_path = create_config_file(&temp, config_content);

    // Build config the way setup.rs does it
    let config = config::Config::builder()
        .add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
        // setup.rs line 119 uses "use_tls" (underscore) instead of "use-tls" (kebab)
        .set_override_option("use_tls", Some(false))
        .unwrap()
        .build()
        .unwrap();

    // This should read from the override (false), not the file (true)
    // But due to the naming mismatch, it might not work as expected
    // The config file uses "use-tls" but the override uses "use_tls"

    // With kebab-case (matches config.toml)
    assert_eq!(config.get_bool("use-tls").unwrap(), true);

    // With underscore (matches setup.rs override)
    assert_eq!(config.get_bool("use_tls").unwrap(), false);
}

#[tokio::test]
async fn test_channels_as_array() {
    let temp = TempDir::new().unwrap();
    let config_content = "\
channels = [\"#chan1\", \"#chan2\", \"#chan3\"]
";

    let config_path = create_config_file(&temp, config_content);
    let config = parse_config_from_file(&config_path).await;

    let channels: Vec<String> = config.get("channels").unwrap();
    assert_eq!(channels.len(), 3);
    assert_eq!(channels[0], "#chan1");
    assert_eq!(channels[1], "#chan2");
    assert_eq!(channels[2], "#chan3");
}

#[tokio::test]
async fn test_channels_override_from_cli() {
    let temp = TempDir::new().unwrap();
    let config_content = "\
channels = [\"#file1\", \"#file2\"]
";

    let config_path = create_config_file(&temp, config_content);

    let cli_channels = vec![
        "#cli1".to_string(),
        "#cli2".to_string(),
        "#cli3".to_string(),
    ];

    let config = config::Config::builder()
        .add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
        .set_override_option("channels", Some(cli_channels.clone()))
        .unwrap()
        .build()
        .unwrap();

    let channels: Vec<String> = config.get("channels").unwrap();
    assert_eq!(channels, cli_channels);
    assert_eq!(channels.len(), 3);
}

#[tokio::test]
async fn test_port_as_integer() {
    let temp = TempDir::new().unwrap();
    let config_content = r#"
port = 6697
"#;

    let config_path = create_config_file(&temp, config_content);
    let config = parse_config_from_file(&config_path).await;

    // Port should be readable as both integer and string
    assert_eq!(config.get_int("port").unwrap(), 6697);
    assert_eq!(config.get_string("port").unwrap(), "6697");
}

#[tokio::test]
async fn test_port_override_from_cli_as_string() {
    // setup.rs passes port as Option<String> from clap
    let temp = TempDir::new().unwrap();
    let config_content = r#"
port = 6667
"#;

    let config_path = create_config_file(&temp, config_content);

    let config = config::Config::builder()
        .add_source(config::File::with_name(&config_path.to_string_lossy()).required(true))
        .set_override_option("port", Some("9999".to_string()))
        .unwrap()
        .build()
        .unwrap();

    // CLI override should work
    assert_eq!(config.get_string("port").unwrap(), "9999");
    assert_eq!(config.get_int("port").unwrap(), 9999);
}

#[tokio::test]
async fn test_missing_required_fields_fails() {
    let temp = TempDir::new().unwrap();
    // Create config without required api-key
    let config_content = r#"
model = "test-model"
"#;

    let config_path = create_config_file(&temp, config_content);
    let config = parse_config_from_file(&config_path).await;

    // Should fail when trying to get required field
    assert!(config.get_string("api-key").is_err());
}

#[tokio::test]
async fn test_optional_instruct_field() {
    let temp = TempDir::new().unwrap();
    let config_content = r#"
instruct = "Custom bot instructions"
"#;

    let config_path = create_config_file(&temp, config_content);
    let config = parse_config_from_file(&config_path).await;

    assert_eq!(
        config.get_string("instruct").unwrap(),
        "Custom bot instructions"
    );
}

#[tokio::test]
async fn test_command_path_field() {
    // command-path is in config.toml but not used anywhere in the code
    let temp = TempDir::new().unwrap();
    let config_content = r#"
command-path = "/custom/commands"
"#;

    let config_path = create_config_file(&temp, config_content);
    let config = parse_config_from_file(&config_path).await;

    assert_eq!(
        config.get_string("command-path").unwrap(),
        "/custom/commands"
    );
}

#[tokio::test]
async fn test_chroot_dir_field() {
    let temp = TempDir::new().unwrap();
    let config_content = r#"
chroot-dir = "/var/lib/bot/root"
"#;

    let config_path = create_config_file(&temp, config_content);
    let config = parse_config_from_file(&config_path).await;

    assert_eq!(
        config.get_string("chroot-dir").unwrap(),
        "/var/lib/bot/root"
    );
}

#[tokio::test]
async fn test_empty_config_file() {
    let temp = TempDir::new().unwrap();
    let config_content = "";

    let config_path = create_config_file(&temp, config_content);

    // Should build successfully but have no values
    let config = parse_config_from_file(&config_path).await;

    assert!(config.get_string("api-key").is_err());
    assert!(config.get_string("model").is_err());
}

#[tokio::test]
async fn test_all_cli_override_keys_match_config_format() {
    // This test documents which override keys in setup.rs match the config.toml format
    let temp = TempDir::new().unwrap();
    let config_content = "\
api-key = \"test\"
base-url = \"https://test.com\"
chroot-dir = \"/test\"
command-path = \"/cmds\"
model = \"test-model\"
instruct = \"test\"
channels = [\"#test\"]
server = \"test.server\"
port = 6697
nickname = \"test\"
username = \"test\"
use-tls = true
";

    let config_path = create_config_file(&temp, config_content);
    let config = parse_config_from_file(&config_path).await;

    // All these should work with kebab-case (as in config.toml)
    assert!(config.get_string("api-key").is_ok());
    assert!(config.get_string("base-url").is_ok());
    assert!(config.get_string("chroot-dir").is_ok());
    assert!(config.get_string("command-path").is_ok());
    assert!(config.get_string("model").is_ok());
    assert!(config.get_string("instruct").is_ok());
    let channels_result: Result<Vec<String>, _> = config.get("channels");
    assert!(channels_result.is_ok());
    assert!(config.get_string("server").is_ok());
    assert!(config.get_int("port").is_ok());
    assert!(config.get_string("nickname").is_ok());
    assert!(config.get_string("username").is_ok());
    assert!(config.get_bool("use-tls").is_ok());
}