Compare commits
18 Commits
f6afc959c8
...
plugin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
585afa5f6f
|
||
|
|
30e2d9a448
|
||
|
|
70de039610
|
||
|
8ec4f2860c
|
|||
|
|
21d9c3f002 | ||
|
|
f880795b44
|
||
| a158ee385f | |||
|
|
2da7cc4450
|
||
|
|
3af95235e6
|
||
|
|
5d390ee9f3
|
||
|
|
7f7981d6cd
|
||
|
|
ae190cc421
|
||
|
|
ae44cc947b
|
||
|
|
a3ebca0bb2
|
||
|
|
8fa79932d6
|
||
|
|
9719d9203c
|
||
|
|
138df60661
|
||
|
|
5f30fdbf77
|
1397
Cargo.lock
generated
1397
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
52
Cargo.toml
52
Cargo.toml
@@ -4,17 +4,57 @@ version = "0.1.0"
|
|||||||
edition = "2024"
|
edition = "2024"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
# TODO: make this a dev and/or debug dependency later.
|
|
||||||
better-panic = "0.3.0"
|
better-panic = "0.3.0"
|
||||||
clap = { version = "4.5", features = [ "derive" ] }
|
bytes = "1"
|
||||||
color-eyre = "0.6.3"
|
color-eyre = "0.6.3"
|
||||||
config = { version = "0.15", features = [ "toml" ] }
|
|
||||||
directories = "6.0"
|
directories = "6.0"
|
||||||
dotenvy_macro = "0.15"
|
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
human-panic = "2.0"
|
human-panic = "2.0"
|
||||||
genai = "0.4.0-alpha.9"
|
genai = "0.4.3"
|
||||||
irc = "1.1"
|
irc = "1.1"
|
||||||
tokio = { version = "1", features = [ "full" ] }
|
serde_json = "1.0"
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
tracing-subscriber = "0.3"
|
tracing-subscriber = "0.3"
|
||||||
|
|
||||||
|
[dependencies.nix]
|
||||||
|
version = "0.30.1"
|
||||||
|
features = ["fs"]
|
||||||
|
|
||||||
|
[dependencies.clap]
|
||||||
|
version = "4.5"
|
||||||
|
features = ["derive"]
|
||||||
|
|
||||||
|
[dependencies.config]
|
||||||
|
version = "0.15"
|
||||||
|
features = ["toml"]
|
||||||
|
|
||||||
|
[dependencies.serde]
|
||||||
|
version = "1.0"
|
||||||
|
features = ["derive"]
|
||||||
|
|
||||||
|
[dependencies.tokio]
|
||||||
|
version = "1"
|
||||||
|
features = [
|
||||||
|
"io-util",
|
||||||
|
"macros",
|
||||||
|
"net",
|
||||||
|
"process",
|
||||||
|
"rt-multi-thread",
|
||||||
|
"sync",
|
||||||
|
"time",
|
||||||
|
]
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
rstest = "0.24"
|
||||||
|
tempfile = "3.13"
|
||||||
|
|
||||||
|
[dev-dependencies.cargo-husky]
|
||||||
|
version = "1"
|
||||||
|
features = ["run-cargo-check", "run-cargo-clippy"]
|
||||||
|
|
||||||
|
[profile.release]
|
||||||
|
strip = true
|
||||||
|
opt-level = "z"
|
||||||
|
lto = true
|
||||||
|
codegen-units = 1
|
||||||
|
panic = "abort"
|
||||||
|
|||||||
@@ -3,12 +3,12 @@
|
|||||||
This is an IRC bot that. The name is based on a fictional video game villain.
|
This is an IRC bot that. The name is based on a fictional video game villain.
|
||||||
Currently it supports any LLM that uses the OpenAI style of interface. They
|
Currently it supports any LLM that uses the OpenAI style of interface. They
|
||||||
can be selected via command line options, environment variables, or via a configuration
|
can be selected via command line options, environment variables, or via a configuration
|
||||||
file. There is a [configureation file](config.toml) that *should* contain all available options
|
file. There is a [configuration file](config.toml) that *should* contain all available options
|
||||||
currently.
|
currently.
|
||||||
|
|
||||||
## Some supported but ~~possibly~~ *mostly* untested LLMs:
|
## Some supported but ~~possibly~~ *mostly* untested LLMs:
|
||||||
|
|
||||||
| Name | Model | Base URL | Teested |
|
| Name | Model | Base URL | Tested |
|
||||||
|------------|-------------------|-------------------------------------------|---------|
|
|------------|-------------------|-------------------------------------------|---------|
|
||||||
| OpenAI | gpt-5 | https://api.openai.com/v1 | no |
|
| OpenAI | gpt-5 | https://api.openai.com/v1 | no |
|
||||||
| Deepseek | deepseek-chat | https://api.deepseek.com/v1 | yes |
|
| Deepseek | deepseek-chat | https://api.deepseek.com/v1 | yes |
|
||||||
|
|||||||
42
robotnik.1
Normal file
42
robotnik.1
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
.Dd $Mdocdate$
|
||||||
|
.Dt robotnik 1
|
||||||
|
.Os
|
||||||
|
.Sh NAME
|
||||||
|
.Nm robotnik
|
||||||
|
.Nd A simple bot that among other things uses the OpenAI API.
|
||||||
|
.\" .Sh LIBRARY
|
||||||
|
.\" For sections 2, 3, and 9 only.
|
||||||
|
.\" Not used in OpenBSD.
|
||||||
|
.Sh SYNOPSIS
|
||||||
|
.Nm progname
|
||||||
|
.Op Fl options
|
||||||
|
.Ar
|
||||||
|
.Sh DESCRIPTION
|
||||||
|
The
|
||||||
|
.Nm
|
||||||
|
utility processes files ...
|
||||||
|
.\" .Sh CONTEXT
|
||||||
|
.\" For section 9 functions only.
|
||||||
|
.\" .Sh IMPLEMENTATION NOTES
|
||||||
|
.\" Not used in OpenBSD.
|
||||||
|
.\" .Sh RETURN VALUES
|
||||||
|
.\" For sections 2, 3, and 9 function return values only.
|
||||||
|
.\" .Sh ENVIRONMENT
|
||||||
|
.\" For sections 1, 6, 7, and 8 only.
|
||||||
|
.\" .Sh FILES
|
||||||
|
.\" .Sh EXIT STATUS
|
||||||
|
.\" For sections 1, 6, and 8 only.
|
||||||
|
.\" .Sh EXAMPLES
|
||||||
|
.\" .Sh DIAGNOSTICS
|
||||||
|
.\" For sections 1, 4, 6, 7, 8, and 9 printf/stderr messages only.
|
||||||
|
.\" .Sh ERRORS
|
||||||
|
.\" For sections 2, 3, 4, and 9 errno settings only.
|
||||||
|
.\" .Sh SEE ALSO
|
||||||
|
.\" .Xr foobar 1
|
||||||
|
.\" .Sh STANDARDS
|
||||||
|
.\" .Sh HISTORY
|
||||||
|
.\" .Sh AUTHORS
|
||||||
|
.\" .Sh CAVEATS
|
||||||
|
.\" .Sh BUGS
|
||||||
|
.\" .Sh SECURITY CONSIDERATIONS
|
||||||
|
.\" Not used in OpenBSD.
|
||||||
@@ -3,5 +3,5 @@ style_edition = "2024"
|
|||||||
comment_width = 100
|
comment_width = 100
|
||||||
format_code_in_doc_comments = true
|
format_code_in_doc_comments = true
|
||||||
imports_granularity = "Crate"
|
imports_granularity = "Crate"
|
||||||
imports_layout = "Vertical"
|
imports_layout = "HorizontalVertical"
|
||||||
wrap_comments = true
|
wrap_comments = true
|
||||||
|
|||||||
123
src/chat.rs
123
src/chat.rs
@@ -1,39 +1,32 @@
|
|||||||
use color_eyre::{
|
use std::sync::Arc;
|
||||||
Result,
|
|
||||||
eyre::{
|
use color_eyre::{Result, eyre::WrapErr};
|
||||||
OptionExt,
|
|
||||||
WrapErr,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
// Lots of namespace confusion potential
|
|
||||||
use crate::qna::LLMHandle;
|
|
||||||
use config::Config as MainConfig;
|
use config::Config as MainConfig;
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use irc::client::prelude::{
|
use irc::client::prelude::{Client, Command, Config as IRCConfig, Message};
|
||||||
Client as IRCClient,
|
use tokio::sync::mpsc;
|
||||||
Command,
|
use tracing::{Level, event, instrument};
|
||||||
Config as IRCConfig,
|
|
||||||
};
|
use crate::{Event, EventManager, LLMHandle, plugin};
|
||||||
use tracing::{
|
|
||||||
Level,
|
|
||||||
event,
|
|
||||||
instrument,
|
|
||||||
};
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Chat {
|
pub struct Chat {
|
||||||
client: IRCClient,
|
client: Client,
|
||||||
|
event_manager: Arc<EventManager>,
|
||||||
llm_handle: LLMHandle, // FIXME: This needs to be thread safe, and shared, etc.
|
llm_handle: LLMHandle, // FIXME: This needs to be thread safe, and shared, etc.
|
||||||
}
|
}
|
||||||
|
|
||||||
// Need: owners, channels, username, nick, server, password
|
// Need: owners, channels, username, nick, server, password
|
||||||
#[instrument]
|
#[instrument]
|
||||||
pub async fn new(settings: &MainConfig, handle: &LLMHandle) -> Result<Chat> {
|
pub async fn new(
|
||||||
|
settings: &MainConfig,
|
||||||
|
handle: &LLMHandle,
|
||||||
|
manager: Arc<EventManager>,
|
||||||
|
) -> Result<Chat> {
|
||||||
// Going to just assign and let the irc library handle errors for now, and
|
// Going to just assign and let the irc library handle errors for now, and
|
||||||
// add my own checking if necessary.
|
// add my own checking if necessary.
|
||||||
let port: u16 = settings.get("port")?;
|
let port: u16 = settings.get("port")?;
|
||||||
let channels: Vec<String> = settings.get("channels")
|
let channels: Vec<String> = settings.get("channels").wrap_err("No channels provided.")?;
|
||||||
.wrap_err("No channels provided.")?;
|
|
||||||
|
|
||||||
event!(Level::INFO, "Channels = {:?}", channels);
|
event!(Level::INFO, "Channels = {:?}", channels);
|
||||||
|
|
||||||
@@ -50,40 +43,80 @@ pub async fn new(settings: &MainConfig, handle: &LLMHandle) -> Result<Chat> {
|
|||||||
event!(Level::INFO, "IRC connection starting...");
|
event!(Level::INFO, "IRC connection starting...");
|
||||||
|
|
||||||
Ok(Chat {
|
Ok(Chat {
|
||||||
client: IRCClient::from_config(config).await?,
|
client: Client::from_config(config).await?,
|
||||||
llm_handle: handle.clone(),
|
llm_handle: handle.clone(),
|
||||||
|
event_manager: manager,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Chat {
|
impl Chat {
|
||||||
pub async fn run(&mut self) -> Result<()> {
|
pub async fn run(&mut self, mut command_in: mpsc::Receiver<plugin::Plugin>) -> Result<()> {
|
||||||
let client = &mut self.client;
|
self.client.identify()?;
|
||||||
|
|
||||||
client.identify()?;
|
let mut stream = self.client.stream()?;
|
||||||
|
|
||||||
let outgoing = client
|
loop {
|
||||||
.outgoing()
|
tokio::select! {
|
||||||
.ok_or_eyre("Couldn't get outgoing irc sink.")?;
|
message = stream.next() => {
|
||||||
let mut stream = client.stream()?;
|
match message {
|
||||||
|
Some(Ok(msg)) => {
|
||||||
|
self.handle_chat_message(&msg).await?;
|
||||||
|
}
|
||||||
|
Some(Err(e)) => return Err(e.into()),
|
||||||
|
None => break, // disconnected
|
||||||
|
}
|
||||||
|
}
|
||||||
|
command = command_in.recv() => {
|
||||||
|
event!(Level::INFO, "Received command {:#?}", command);
|
||||||
|
match command {
|
||||||
|
Some(plugin::Plugin::SendMessage {channel, message} ) => {
|
||||||
|
// Now to pass on the message.
|
||||||
|
event!(Level::INFO, "Trying to send to channel.");
|
||||||
|
self.client.send_privmsg(&channel, &message).wrap_err("Couldn't send to channel")?;
|
||||||
|
event!(Level::INFO, "Message sent successfully.");
|
||||||
|
|
||||||
tokio::spawn(async move {
|
}
|
||||||
if let Err(e) = outgoing.await {
|
None => {
|
||||||
event!(Level::ERROR, "Failed to drive output: {}", e);
|
event!(Level::ERROR,
|
||||||
}
|
"Command channel unexpectedly closed - \
|
||||||
});
|
FIFO reader may have crashed");
|
||||||
|
break;
|
||||||
while let Some(message) = stream.next().await.transpose()? {
|
}
|
||||||
if let Command::PRIVMSG(channel, message) = message.command {
|
}
|
||||||
if message.starts_with("!gem") {
|
|
||||||
let msg = self.llm_handle.send_request(message).await?;
|
|
||||||
event!(Level::INFO, "Message received.");
|
|
||||||
client
|
|
||||||
.send_privmsg(channel, msg)
|
|
||||||
.wrap_err("Couldn't send response to channel.")?;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn handle_chat_message(&mut self, message: &Message) -> Result<()> {
|
||||||
|
// Broadcast anything here. If it should not be broadcasted then
|
||||||
|
// TryFrom should fail.
|
||||||
|
if let Ok(event) = Event::try_from(message) {
|
||||||
|
self.event_manager.broadcast(&event).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only handle PRIVMSG for now.
|
||||||
|
if let Command::PRIVMSG(channel, msg) = &message.command {
|
||||||
|
// Just preserve the original behavior for now.
|
||||||
|
if msg.starts_with("!gem") {
|
||||||
|
let mut llm_response = self.llm_handle.send_request(msg).await?;
|
||||||
|
|
||||||
|
event!(Level::INFO, "Asked: {message}");
|
||||||
|
event!(Level::INFO, "Response: {llm_response}");
|
||||||
|
|
||||||
|
// Keep responses to one line for now.
|
||||||
|
llm_response.retain(|c| c != '\n' && c != '\r');
|
||||||
|
|
||||||
|
// TODO: Make this configurable.
|
||||||
|
llm_response.truncate(500);
|
||||||
|
|
||||||
|
event!(Level::INFO, "Sending {llm_response} to channel {channel}");
|
||||||
|
self.client.send_privmsg(channel, llm_response)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
183
src/command.rs
Normal file
183
src/command.rs
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
// Commands that are associated with external processes (commands).
|
||||||
|
|
||||||
|
use std::{
|
||||||
|
path::{Path, PathBuf},
|
||||||
|
time::Duration,
|
||||||
|
};
|
||||||
|
|
||||||
|
use bytes::Bytes;
|
||||||
|
use color_eyre::{Result, eyre::eyre};
|
||||||
|
use tokio::{fs::try_exists, process::Command, time::timeout};
|
||||||
|
use tracing::{Level, event};
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct CommandDir {
|
||||||
|
command_path: PathBuf,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CommandDir {
|
||||||
|
pub fn new(command_path: impl AsRef<Path>) -> Self {
|
||||||
|
event!(
|
||||||
|
Level::INFO,
|
||||||
|
"CommandDir initialized with path: {:?}",
|
||||||
|
command_path.as_ref()
|
||||||
|
);
|
||||||
|
CommandDir {
|
||||||
|
command_path: command_path.as_ref().to_path_buf(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn find_command(&self, name: impl AsRef<Path>) -> Result<String> {
|
||||||
|
let path = self.command_path.join(name.as_ref());
|
||||||
|
|
||||||
|
event!(
|
||||||
|
Level::INFO,
|
||||||
|
"Looking for {} command.",
|
||||||
|
name.as_ref().display()
|
||||||
|
);
|
||||||
|
|
||||||
|
match try_exists(&path).await {
|
||||||
|
Ok(true) => Ok(path.to_string_lossy().to_string()),
|
||||||
|
Ok(false) => Err(eyre!(format!("{} Not found.", path.to_string_lossy()))),
|
||||||
|
Err(e) => Err(e.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_command(
|
||||||
|
&self,
|
||||||
|
command_name: impl AsRef<str>,
|
||||||
|
input: impl AsRef<str>,
|
||||||
|
) -> Result<Bytes> {
|
||||||
|
let path = self.find_command(Path::new(command_name.as_ref())).await?;
|
||||||
|
// Well it exists let's cross our fingers...
|
||||||
|
let output = Command::new(path).arg(input.as_ref()).output().await?;
|
||||||
|
|
||||||
|
if output.status.success() {
|
||||||
|
// So far so good
|
||||||
|
Ok(Bytes::from(output.stdout))
|
||||||
|
} else {
|
||||||
|
// Whoops
|
||||||
|
Err(eyre!(format!(
|
||||||
|
"Error running {}: {}",
|
||||||
|
command_name.as_ref(),
|
||||||
|
output.status
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_command_with_timeout(
|
||||||
|
&self,
|
||||||
|
command_name: impl AsRef<str>,
|
||||||
|
input: impl AsRef<str>,
|
||||||
|
time_out: Duration,
|
||||||
|
) -> Result<Bytes> {
|
||||||
|
timeout(time_out, self.run_command(command_name, input)).await?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::{
|
||||||
|
fs::{self, Permissions},
|
||||||
|
os::unix::fs::PermissionsExt,
|
||||||
|
};
|
||||||
|
use tempfile::TempDir;
|
||||||
|
|
||||||
|
fn create_test_script(dir: &Path, name: &str, script: &str) -> PathBuf {
|
||||||
|
let path = dir.join(name);
|
||||||
|
fs::write(&path, script).unwrap();
|
||||||
|
fs::set_permissions(&path, Permissions::from_mode(0o755)).unwrap();
|
||||||
|
path
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_command_dir_new() {
|
||||||
|
let dir = CommandDir::new("/some/path");
|
||||||
|
assert_eq!(dir.command_path, PathBuf::from("/some/path"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_find_command_exists() {
|
||||||
|
let temp = TempDir::new().unwrap();
|
||||||
|
create_test_script(temp.path(), "test_cmd", "#!/bin/bash\necho hello");
|
||||||
|
|
||||||
|
let cmd_dir = CommandDir::new(temp.path());
|
||||||
|
let result = cmd_dir.find_command("test_cmd").await;
|
||||||
|
|
||||||
|
assert!(result.is_ok());
|
||||||
|
assert!(result.unwrap().contains("test_cmd"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_find_command_not_found() {
|
||||||
|
let temp = TempDir::new().unwrap();
|
||||||
|
let cmd_dir = CommandDir::new(temp.path());
|
||||||
|
|
||||||
|
let result = cmd_dir.find_command("nonexistent").await;
|
||||||
|
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(result.unwrap_err().to_string().contains("Not found"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_run_command_success() {
|
||||||
|
let temp = TempDir::new().unwrap();
|
||||||
|
create_test_script(temp.path(), "echo_cmd", "#!/bin/bash\necho \"$1\"");
|
||||||
|
|
||||||
|
let cmd_dir = CommandDir::new(temp.path());
|
||||||
|
let result = cmd_dir.run_command("echo_cmd", "hello world").await;
|
||||||
|
|
||||||
|
assert!(result.is_ok());
|
||||||
|
let output = result.unwrap();
|
||||||
|
assert_eq!(output.as_ref(), b"hello world\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_run_command_failure() {
|
||||||
|
let temp = TempDir::new().unwrap();
|
||||||
|
create_test_script(temp.path(), "fail_cmd", "#!/bin/bash\nexit 1");
|
||||||
|
|
||||||
|
let cmd_dir = CommandDir::new(temp.path());
|
||||||
|
let result = cmd_dir.run_command("fail_cmd", "input").await;
|
||||||
|
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(result.unwrap_err().to_string().contains("Error running"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_run_command_not_found() {
|
||||||
|
let temp = TempDir::new().unwrap();
|
||||||
|
let cmd_dir = CommandDir::new(temp.path());
|
||||||
|
|
||||||
|
let result = cmd_dir.run_command("missing", "input").await;
|
||||||
|
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_run_command_with_timeout_success() {
|
||||||
|
let temp = TempDir::new().unwrap();
|
||||||
|
create_test_script(temp.path(), "fast_cmd", "#!/bin/bash\necho \"$1\"");
|
||||||
|
|
||||||
|
let cmd_dir = CommandDir::new(temp.path());
|
||||||
|
let result = cmd_dir
|
||||||
|
.run_command_with_timeout("fast_cmd", "quick", Duration::from_secs(5))
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(result.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_run_command_with_timeout_expires() {
|
||||||
|
let temp = TempDir::new().unwrap();
|
||||||
|
create_test_script(temp.path(), "slow_cmd", "#!/bin/bash\nsleep 10\necho done");
|
||||||
|
|
||||||
|
let cmd_dir = CommandDir::new(temp.path());
|
||||||
|
let result = cmd_dir
|
||||||
|
.run_command_with_timeout("slow_cmd", "input", Duration::from_millis(100))
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
}
|
||||||
32
src/event.rs
Normal file
32
src/event.rs
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
use irc::proto::{Command, Message};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Deserialize, Serialize)]
|
||||||
|
pub struct Event {
|
||||||
|
from: String,
|
||||||
|
message: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Event {
|
||||||
|
pub fn new(from: impl Into<String>, msg: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
from: from.into(),
|
||||||
|
message: msg.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<&Message> for Event {
|
||||||
|
type Error = &'static str;
|
||||||
|
|
||||||
|
fn try_from(value: &Message) -> Result<Self, Self::Error> {
|
||||||
|
let from = value.response_target().unwrap_or("unknown").to_string();
|
||||||
|
match &value.command {
|
||||||
|
Command::PRIVMSG(_channel, message) => Ok(Event {
|
||||||
|
from,
|
||||||
|
message: message.clone(),
|
||||||
|
}),
|
||||||
|
_ => Err("Not a PRIVMSG"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
575
src/event_manager.rs
Normal file
575
src/event_manager.rs
Normal file
@@ -0,0 +1,575 @@
|
|||||||
|
use std::{collections::VecDeque, path::Path, sync::Arc};
|
||||||
|
|
||||||
|
use color_eyre::Result;
|
||||||
|
use nix::{NixPath, sys::stat, unistd::mkfifo};
|
||||||
|
use tokio::{
|
||||||
|
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
|
||||||
|
net::{UnixListener, UnixStream, unix::pipe},
|
||||||
|
sync::{RwLock, broadcast, mpsc},
|
||||||
|
};
|
||||||
|
use tracing::{error, info};
|
||||||
|
|
||||||
|
use crate::{event::Event, plugin::Plugin};
|
||||||
|
|
||||||
|
// Hard coding for now. Maybe make this a parameter to new.
|
||||||
|
const EVENT_BUF_MAX: usize = 1000;
|
||||||
|
|
||||||
|
// Manager for communication with plugins.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct EventManager {
|
||||||
|
announce: broadcast::Sender<String>, // Everything broadcasts here.
|
||||||
|
events: Arc<RwLock<VecDeque<String>>>, // Ring buffer.
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EventManager {
|
||||||
|
pub fn new() -> Result<Self> {
|
||||||
|
let (announce, _) = broadcast::channel(100);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
announce,
|
||||||
|
events: Arc::new(RwLock::new(VecDeque::<String>::new())),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn broadcast(&self, event: &Event) -> Result<()> {
|
||||||
|
let msg = serde_json::to_string(event)? + "\n";
|
||||||
|
|
||||||
|
let mut events = self.events.write().await;
|
||||||
|
|
||||||
|
if events.len() >= EVENT_BUF_MAX {
|
||||||
|
events.pop_front();
|
||||||
|
}
|
||||||
|
|
||||||
|
events.push_back(msg.clone());
|
||||||
|
drop(events);
|
||||||
|
|
||||||
|
let _ = self.announce.send(msg);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// NB: This assumes it has exclusive control of the FIFO.
|
||||||
|
pub async fn start_fifo<P>(path: &P, command_tx: mpsc::Sender<Plugin>) -> Result<()>
|
||||||
|
where
|
||||||
|
P: AsRef<Path> + NixPath + ?Sized,
|
||||||
|
{
|
||||||
|
// Overwrite, or create the FIFO.
|
||||||
|
let _ = std::fs::remove_file(path);
|
||||||
|
mkfifo(path, stat::Mode::S_IRWXU)?;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let rx = pipe::OpenOptions::new().open_receiver(path)?;
|
||||||
|
|
||||||
|
let mut reader = BufReader::new(rx);
|
||||||
|
let mut line = String::new();
|
||||||
|
|
||||||
|
while reader.read_line(&mut line).await? > 0 {
|
||||||
|
// Now handle the command.
|
||||||
|
let cmd: Plugin = serde_json::from_str(&line)?;
|
||||||
|
info!("Command received: {:?}.", cmd);
|
||||||
|
command_tx.send(cmd).await?;
|
||||||
|
line.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn start_listening(self: Arc<Self>, broadcast_path: impl AsRef<Path>) {
|
||||||
|
let listener = UnixListener::bind(broadcast_path).unwrap();
|
||||||
|
|
||||||
|
loop {
|
||||||
|
match listener.accept().await {
|
||||||
|
Ok((stream, _addr)) => {
|
||||||
|
info!("New broadcast subscriber");
|
||||||
|
// Spawn a new stream for the plugin. The loop
|
||||||
|
// runs recursively from there.
|
||||||
|
let broadcaster = Arc::clone(&self);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
// send events.
|
||||||
|
let _ = broadcaster.send_events(stream).await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Err(e) => error!("Accept error: {e}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_events(&self, stream: UnixStream) -> Result<()> {
|
||||||
|
let mut writer = stream;
|
||||||
|
|
||||||
|
// Take care of history.
|
||||||
|
let events = self.events.read().await;
|
||||||
|
for event in events.iter() {
|
||||||
|
writer.write_all(event.as_bytes()).await?;
|
||||||
|
}
|
||||||
|
drop(events);
|
||||||
|
|
||||||
|
// Now just broadcast the new events.
|
||||||
|
let mut rx = self.announce.subscribe();
|
||||||
|
while let Ok(event) = rx.recv().await {
|
||||||
|
if writer.write_all(event.as_bytes()).await.is_err() {
|
||||||
|
// *click*
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use rstest::rstest;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_new_event_manager_has_empty_buffer() {
|
||||||
|
let manager = EventManager::new().unwrap();
|
||||||
|
let events = manager.events.read().await;
|
||||||
|
assert_eq!(events.len(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_broadcast_adds_event_to_buffer() {
|
||||||
|
let manager = EventManager::new().unwrap();
|
||||||
|
let event = Event::new("test_user", "test message");
|
||||||
|
|
||||||
|
manager.broadcast(&event).await.unwrap();
|
||||||
|
|
||||||
|
let events = manager.events.read().await;
|
||||||
|
assert_eq!(events.len(), 1);
|
||||||
|
assert!(events[0].contains("test message"));
|
||||||
|
assert!(events[0].ends_with('\n'));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_broadcast_serializes_event_as_json() {
|
||||||
|
let manager = EventManager::new().unwrap();
|
||||||
|
let event = Event::new("test_user", "hello world");
|
||||||
|
|
||||||
|
manager.broadcast(&event).await.unwrap();
|
||||||
|
|
||||||
|
let events = manager.events.read().await;
|
||||||
|
let stored = &events[0];
|
||||||
|
|
||||||
|
// Should be valid JSON
|
||||||
|
let parsed: serde_json::Value = serde_json::from_str(stored.trim()).unwrap();
|
||||||
|
assert_eq!(parsed["message"], "hello world");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rstest]
|
||||||
|
#[case(1)]
|
||||||
|
#[case(10)]
|
||||||
|
#[case(100)]
|
||||||
|
#[case(999)]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_buffer_holds_events_below_max(#[case] count: usize) {
|
||||||
|
let manager = EventManager::new().unwrap();
|
||||||
|
|
||||||
|
for i in 0..count {
|
||||||
|
let event = Event::new("test_user", format!("event {}", i));
|
||||||
|
manager.broadcast(&event).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let events = manager.events.read().await;
|
||||||
|
assert_eq!(events.len(), count);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_buffer_at_exactly_max_capacity() {
|
||||||
|
let manager = EventManager::new().unwrap();
|
||||||
|
|
||||||
|
// Fill to exactly EVENT_BUF_MAX (1000)
|
||||||
|
for i in 0..EVENT_BUF_MAX {
|
||||||
|
let event = Event::new("test_user", format!("event {}", i));
|
||||||
|
manager.broadcast(&event).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let events = manager.events.read().await;
|
||||||
|
assert_eq!(events.len(), EVENT_BUF_MAX);
|
||||||
|
assert!(events[0].contains("event 0"));
|
||||||
|
assert!(events[EVENT_BUF_MAX - 1].contains("event 999"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rstest]
|
||||||
|
#[case(1)]
|
||||||
|
#[case(10)]
|
||||||
|
#[case(100)]
|
||||||
|
#[case(500)]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_buffer_overflow_evicts_oldest_fifo(#[case] overflow: usize) {
|
||||||
|
let manager = EventManager::new().unwrap();
|
||||||
|
let total = EVENT_BUF_MAX + overflow;
|
||||||
|
|
||||||
|
// Broadcast more events than buffer can hold
|
||||||
|
for i in 0..total {
|
||||||
|
let event = Event::new("test_user", format!("event {}", i));
|
||||||
|
manager.broadcast(&event).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let events = manager.events.read().await;
|
||||||
|
|
||||||
|
// Buffer should still be at max capacity
|
||||||
|
assert_eq!(events.len(), EVENT_BUF_MAX);
|
||||||
|
|
||||||
|
// Oldest events (0 through overflow-1) should be evicted
|
||||||
|
// Buffer should contain events [overflow..total)
|
||||||
|
let first_event = &events[0];
|
||||||
|
let last_event = &events[EVENT_BUF_MAX - 1];
|
||||||
|
|
||||||
|
assert!(first_event.contains(&format!("event {}", overflow)));
|
||||||
|
assert!(last_event.contains(&format!("event {}", total - 1)));
|
||||||
|
|
||||||
|
// Verify the evicted events are NOT in the buffer
|
||||||
|
let buffer_string = events.iter().cloned().collect::<String>();
|
||||||
|
assert!(!buffer_string.contains(r#""message":"event 0""#));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_multiple_broadcasts_maintain_order() {
|
||||||
|
let manager = EventManager::new().unwrap();
|
||||||
|
let messages = vec!["first", "second", "third", "fourth", "fifth"];
|
||||||
|
|
||||||
|
for msg in &messages {
|
||||||
|
let event = Event::new("test_user", *msg);
|
||||||
|
manager.broadcast(&event).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let events = manager.events.read().await;
|
||||||
|
assert_eq!(events.len(), messages.len());
|
||||||
|
|
||||||
|
for (i, expected) in messages.iter().enumerate() {
|
||||||
|
assert!(events[i].contains(expected));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_buffer_wraparound_maintains_newest_events() {
|
||||||
|
let manager = EventManager::new().unwrap();
|
||||||
|
|
||||||
|
// Fill buffer completely
|
||||||
|
for i in 0..EVENT_BUF_MAX {
|
||||||
|
let event = Event::new("test_user", format!("old {}", i));
|
||||||
|
manager.broadcast(&event).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add 5 more events
|
||||||
|
for i in 0..5 {
|
||||||
|
let event = Event::new("test_user", format!("new {}", i));
|
||||||
|
manager.broadcast(&event).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let events = manager.events.read().await;
|
||||||
|
assert_eq!(events.len(), EVENT_BUF_MAX);
|
||||||
|
|
||||||
|
// First 5 old events should be gone
|
||||||
|
let buffer_string = events.iter().cloned().collect::<String>();
|
||||||
|
assert!(!buffer_string.contains(r#""message":"old 0""#));
|
||||||
|
assert!(!buffer_string.contains(r#""message":"old 4""#));
|
||||||
|
|
||||||
|
// But old 5 should still be there (now at the front)
|
||||||
|
assert!(events[0].contains("old 5"));
|
||||||
|
|
||||||
|
// New events should be at the end
|
||||||
|
assert!(events[EVENT_BUF_MAX - 5].contains("new 0"));
|
||||||
|
assert!(events[EVENT_BUF_MAX - 1].contains("new 4"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_concurrent_broadcasts_all_stored() {
|
||||||
|
let manager = Arc::new(EventManager::new().unwrap());
|
||||||
|
let mut handles = vec![];
|
||||||
|
|
||||||
|
// Spawn 10 concurrent tasks, each broadcasting 10 events
|
||||||
|
for task_id in 0..10 {
|
||||||
|
let manager_clone = Arc::clone(&manager);
|
||||||
|
let handle = tokio::spawn(async move {
|
||||||
|
for i in 0..10 {
|
||||||
|
let event = Event::new("test_user", format!("task {} event {}", task_id, i));
|
||||||
|
manager_clone.broadcast(&event).await.unwrap();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
handles.push(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all tasks to complete
|
||||||
|
for handle in handles {
|
||||||
|
handle.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let events = manager.events.read().await;
|
||||||
|
assert_eq!(events.len(), 100);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_fifo_receives_and_forwards_single_command() {
|
||||||
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
|
let fifo_path = temp_dir.path().join("test.fifo");
|
||||||
|
let (tx, mut rx) = mpsc::channel(10);
|
||||||
|
|
||||||
|
// Spawn the FIFO reader
|
||||||
|
let path = fifo_path.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let _ = EventManager::start_fifo(&path, tx).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Give it time to create the FIFO
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
|
||||||
|
|
||||||
|
// Write a command to the FIFO
|
||||||
|
let cmd = Plugin::SendMessage {
|
||||||
|
channel: "#test".to_string(),
|
||||||
|
message: "hello".to_string(),
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&cmd).unwrap() + "\n";
|
||||||
|
|
||||||
|
// Open FIFO for writing and write the command
|
||||||
|
tokio::spawn(async move {
|
||||||
|
use tokio::io::AsyncWriteExt;
|
||||||
|
let tx = pipe::OpenOptions::new().open_sender(&fifo_path).unwrap();
|
||||||
|
let mut tx = tokio::io::BufWriter::new(tx);
|
||||||
|
tx.write_all(json.as_bytes()).await.unwrap();
|
||||||
|
tx.flush().await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Should receive the command within a reasonable time
|
||||||
|
let received = tokio::time::timeout(tokio::time::Duration::from_secs(1), rx.recv())
|
||||||
|
.await
|
||||||
|
.expect("timeout waiting for command")
|
||||||
|
.expect("channel closed");
|
||||||
|
|
||||||
|
match received {
|
||||||
|
Plugin::SendMessage { channel, message } => {
|
||||||
|
assert_eq!(channel, "#test");
|
||||||
|
assert_eq!(message, "hello");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_fifo_handles_multiple_commands() {
|
||||||
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
|
let fifo_path = temp_dir.path().join("test.fifo");
|
||||||
|
let (tx, mut rx) = mpsc::channel(10);
|
||||||
|
|
||||||
|
// Spawn the FIFO reader
|
||||||
|
let path = fifo_path.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let _ = EventManager::start_fifo(&path, tx).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Give it time to create the FIFO
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
|
||||||
|
|
||||||
|
// Write multiple commands
|
||||||
|
let commands = vec![
|
||||||
|
Plugin::SendMessage {
|
||||||
|
channel: "#chan1".to_string(),
|
||||||
|
message: "first".to_string(),
|
||||||
|
},
|
||||||
|
Plugin::SendMessage {
|
||||||
|
channel: "#chan2".to_string(),
|
||||||
|
message: "second".to_string(),
|
||||||
|
},
|
||||||
|
Plugin::SendMessage {
|
||||||
|
channel: "#chan3".to_string(),
|
||||||
|
message: "third".to_string(),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
use tokio::io::AsyncWriteExt;
|
||||||
|
let tx = pipe::OpenOptions::new().open_sender(&fifo_path).unwrap();
|
||||||
|
let mut tx = tokio::io::BufWriter::new(tx);
|
||||||
|
|
||||||
|
for cmd in commands {
|
||||||
|
let json = serde_json::to_string(&cmd).unwrap() + "\n";
|
||||||
|
tx.write_all(json.as_bytes()).await.unwrap();
|
||||||
|
}
|
||||||
|
tx.flush().await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Receive all three commands in order
|
||||||
|
let first = tokio::time::timeout(tokio::time::Duration::from_secs(1), rx.recv())
|
||||||
|
.await
|
||||||
|
.expect("timeout on first")
|
||||||
|
.expect("channel closed");
|
||||||
|
|
||||||
|
match first {
|
||||||
|
Plugin::SendMessage { channel, message } => {
|
||||||
|
assert_eq!(channel, "#chan1");
|
||||||
|
assert_eq!(message, "first");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let second = tokio::time::timeout(tokio::time::Duration::from_secs(1), rx.recv())
|
||||||
|
.await
|
||||||
|
.expect("timeout on second")
|
||||||
|
.expect("channel closed");
|
||||||
|
|
||||||
|
match second {
|
||||||
|
Plugin::SendMessage { channel, message } => {
|
||||||
|
assert_eq!(channel, "#chan2");
|
||||||
|
assert_eq!(message, "second");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let third = tokio::time::timeout(tokio::time::Duration::from_secs(1), rx.recv())
|
||||||
|
.await
|
||||||
|
.expect("timeout on third")
|
||||||
|
.expect("channel closed");
|
||||||
|
|
||||||
|
match third {
|
||||||
|
Plugin::SendMessage { channel, message } => {
|
||||||
|
assert_eq!(channel, "#chan3");
|
||||||
|
assert_eq!(message, "third");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_fifo_reopens_after_writer_closes() {
|
||||||
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
|
let fifo_path = temp_dir.path().join("test.fifo");
|
||||||
|
let (tx, mut rx) = mpsc::channel(10);
|
||||||
|
|
||||||
|
// Spawn the FIFO reader
|
||||||
|
let path = fifo_path.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let _ = EventManager::start_fifo(&path, tx).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Give it time to create the FIFO
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
|
||||||
|
|
||||||
|
// First writer sends a command and closes
|
||||||
|
{
|
||||||
|
use tokio::io::AsyncWriteExt;
|
||||||
|
let path = fifo_path.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let tx = pipe::OpenOptions::new().open_sender(&path).unwrap();
|
||||||
|
let mut tx = tokio::io::BufWriter::new(tx);
|
||||||
|
|
||||||
|
let cmd = Plugin::SendMessage {
|
||||||
|
channel: "#first".to_string(),
|
||||||
|
message: "batch1".to_string(),
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&cmd).unwrap() + "\n";
|
||||||
|
tx.write_all(json.as_bytes()).await.unwrap();
|
||||||
|
tx.flush().await.unwrap();
|
||||||
|
// Writer drops here, closing the FIFO
|
||||||
|
});
|
||||||
|
|
||||||
|
let first = tokio::time::timeout(tokio::time::Duration::from_secs(1), rx.recv())
|
||||||
|
.await
|
||||||
|
.expect("timeout on first batch")
|
||||||
|
.expect("channel closed");
|
||||||
|
|
||||||
|
match first {
|
||||||
|
Plugin::SendMessage { channel, message } => {
|
||||||
|
assert_eq!(channel, "#first");
|
||||||
|
assert_eq!(message, "batch1");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Give the FIFO time to reopen
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||||
|
|
||||||
|
// Second writer opens and sends a command
|
||||||
|
{
|
||||||
|
use tokio::io::AsyncWriteExt;
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let tx = pipe::OpenOptions::new().open_sender(&fifo_path).unwrap();
|
||||||
|
let mut tx = tokio::io::BufWriter::new(tx);
|
||||||
|
|
||||||
|
let cmd = Plugin::SendMessage {
|
||||||
|
channel: "#second".to_string(),
|
||||||
|
message: "batch2".to_string(),
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&cmd).unwrap() + "\n";
|
||||||
|
tx.write_all(json.as_bytes()).await.unwrap();
|
||||||
|
tx.flush().await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let second = tokio::time::timeout(tokio::time::Duration::from_secs(1), rx.recv())
|
||||||
|
.await
|
||||||
|
.expect("timeout on second batch - FIFO may not have reopened")
|
||||||
|
.expect("channel closed");
|
||||||
|
|
||||||
|
match second {
|
||||||
|
Plugin::SendMessage { channel, message } => {
|
||||||
|
assert_eq!(channel, "#second");
|
||||||
|
assert_eq!(message, "batch2");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_fifo_handles_empty_lines() {
|
||||||
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
|
let fifo_path = temp_dir.path().join("test.fifo");
|
||||||
|
let (tx, mut rx) = mpsc::channel(10);
|
||||||
|
|
||||||
|
// Spawn the FIFO reader
|
||||||
|
let path = fifo_path.clone();
|
||||||
|
let handle = tokio::spawn(async move { EventManager::start_fifo(&path, tx).await });
|
||||||
|
|
||||||
|
// Give it time to create the FIFO
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
|
||||||
|
|
||||||
|
// Write command, empty line, whitespace, another command
|
||||||
|
tokio::spawn(async move {
|
||||||
|
use tokio::io::AsyncWriteExt;
|
||||||
|
let tx = pipe::OpenOptions::new().open_sender(&fifo_path).unwrap();
|
||||||
|
let mut tx = tokio::io::BufWriter::new(tx);
|
||||||
|
|
||||||
|
let cmd1 = Plugin::SendMessage {
|
||||||
|
channel: "#test".to_string(),
|
||||||
|
message: "first".to_string(),
|
||||||
|
};
|
||||||
|
let json1 = serde_json::to_string(&cmd1).unwrap() + "\n";
|
||||||
|
tx.write_all(json1.as_bytes()).await.unwrap();
|
||||||
|
|
||||||
|
// Write empty line
|
||||||
|
tx.write_all(b"\n").await.unwrap();
|
||||||
|
|
||||||
|
// Write whitespace line
|
||||||
|
tx.write_all(b" \n").await.unwrap();
|
||||||
|
|
||||||
|
let cmd2 = Plugin::SendMessage {
|
||||||
|
channel: "#test".to_string(),
|
||||||
|
message: "second".to_string(),
|
||||||
|
};
|
||||||
|
let json2 = serde_json::to_string(&cmd2).unwrap() + "\n";
|
||||||
|
tx.write_all(json2.as_bytes()).await.unwrap();
|
||||||
|
tx.flush().await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Should receive first command
|
||||||
|
let first = tokio::time::timeout(tokio::time::Duration::from_millis(500), rx.recv())
|
||||||
|
.await
|
||||||
|
.expect("timeout on first")
|
||||||
|
.expect("channel closed");
|
||||||
|
|
||||||
|
match first {
|
||||||
|
Plugin::SendMessage { channel, message } => {
|
||||||
|
assert_eq!(channel, "#test");
|
||||||
|
assert_eq!(message, "first");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The empty/whitespace lines should cause JSON parse errors
|
||||||
|
// which will cause start_fifo to error and exit
|
||||||
|
// So we expect the handle to complete (with an error)
|
||||||
|
let result = tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
|
||||||
|
.await
|
||||||
|
.expect("FIFO task should exit due to parse error");
|
||||||
|
|
||||||
|
// The task should have errored
|
||||||
|
assert!(
|
||||||
|
result.unwrap().is_err(),
|
||||||
|
"Expected parse error from empty line"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
92
src/lib.rs
Normal file
92
src/lib.rs
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
// Robotnik libraries
|
||||||
|
|
||||||
|
use std::{os::unix::fs, sync::Arc};
|
||||||
|
|
||||||
|
use color_eyre::{Result, eyre::WrapErr};
|
||||||
|
use human_panic::setup_panic;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tracing::{Level, info};
|
||||||
|
use tracing_subscriber::FmtSubscriber;
|
||||||
|
|
||||||
|
pub mod chat;
|
||||||
|
pub mod command;
|
||||||
|
pub mod event;
|
||||||
|
pub mod event_manager;
|
||||||
|
pub mod plugin;
|
||||||
|
pub mod qna;
|
||||||
|
pub mod setup;
|
||||||
|
|
||||||
|
pub use event::Event;
|
||||||
|
pub use event_manager::EventManager;
|
||||||
|
pub use qna::LLMHandle;
|
||||||
|
|
||||||
|
const DEFAULT_INSTRUCT: &str =
|
||||||
|
"You are a shady, yet helpful IRC bot. You try to give responses that can
|
||||||
|
be sent in a single IRC response according to the specification. Keep answers to
|
||||||
|
500 characters or less.";
|
||||||
|
|
||||||
|
// NB: Everything should fail if logging doesn't start properly.
|
||||||
|
async fn init_logging() {
|
||||||
|
better_panic::install();
|
||||||
|
setup_panic!();
|
||||||
|
|
||||||
|
let subscriber = FmtSubscriber::builder()
|
||||||
|
.with_max_level(Level::TRACE)
|
||||||
|
.finish();
|
||||||
|
|
||||||
|
tracing::subscriber::set_global_default(subscriber).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run() -> Result<()> {
|
||||||
|
init_logging().await;
|
||||||
|
info!("Starting up.");
|
||||||
|
|
||||||
|
let settings = setup::init().await.wrap_err("Failed to initialize.")?;
|
||||||
|
let config = settings.config;
|
||||||
|
|
||||||
|
// NOTE: Doing chroot this way might be impractical.
|
||||||
|
if let Ok(chroot_path) = config.get_string("chroot-dir") {
|
||||||
|
info!("Attempting to chroot to {}", chroot_path);
|
||||||
|
fs::chroot(&chroot_path)
|
||||||
|
.wrap_err_with(|| format!("Failed setting chroot '{}'", chroot_path))?;
|
||||||
|
std::env::set_current_dir("/").wrap_err("Couldn't change directory after chroot.")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let handle = qna::LLMHandle::new(
|
||||||
|
config.get_string("api-key").wrap_err("API missing.")?,
|
||||||
|
config
|
||||||
|
.get_string("base-url")
|
||||||
|
.wrap_err("base-url missing.")?,
|
||||||
|
config
|
||||||
|
.get_string("model")
|
||||||
|
.wrap_err("model string missing.")?,
|
||||||
|
config
|
||||||
|
.get_string("instruct")
|
||||||
|
.unwrap_or_else(|_| DEFAULT_INSTRUCT.to_string()),
|
||||||
|
)
|
||||||
|
.wrap_err("Couldn't initialize LLM handle.")?;
|
||||||
|
|
||||||
|
let ev_manager = Arc::new(EventManager::new()?);
|
||||||
|
let ev_manager_clone = Arc::clone(&ev_manager);
|
||||||
|
|
||||||
|
let mut c = chat::new(&config, &handle, Arc::clone(&ev_manager)).await?;
|
||||||
|
|
||||||
|
let (from_plugins, to_chat) = mpsc::channel(100);
|
||||||
|
|
||||||
|
tokio::select! {
|
||||||
|
_ = ev_manager_clone.start_listening("/tmp/robo.sock") => {
|
||||||
|
// Event listener ended
|
||||||
|
}
|
||||||
|
result = c.run(to_chat) => {
|
||||||
|
if let Err(e) = result {
|
||||||
|
tracing::error!("Chat run error: {:?}", e);
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fifo = EventManager::start_fifo("/tmp/robo_in.sock", from_plugins) => {
|
||||||
|
fifo.wrap_err("FIFO reader failed.")?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
63
src/main.rs
63
src/main.rs
@@ -1,65 +1,6 @@
|
|||||||
use color_eyre::{
|
use color_eyre::Result;
|
||||||
Result,
|
|
||||||
eyre::WrapErr,
|
|
||||||
};
|
|
||||||
use human_panic::setup_panic;
|
|
||||||
use std::os::unix::fs;
|
|
||||||
use tracing::{
|
|
||||||
Level,
|
|
||||||
info,
|
|
||||||
};
|
|
||||||
use tracing_subscriber::FmtSubscriber;
|
|
||||||
|
|
||||||
mod chat;
|
|
||||||
mod commands;
|
|
||||||
mod qna;
|
|
||||||
mod setup;
|
|
||||||
|
|
||||||
const DEFAULT_INSTRUCT: &str =
|
|
||||||
"You are a shady, yet helpful IRC bot. You try to give responses that can
|
|
||||||
be sent in a single IRC response according to the specification.";
|
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<()> {
|
async fn main() -> Result<()> {
|
||||||
// Some error sprucing.
|
robotnik::run().await
|
||||||
better_panic::install();
|
|
||||||
setup_panic!();
|
|
||||||
|
|
||||||
let subscriber = FmtSubscriber::builder()
|
|
||||||
.with_max_level(Level::TRACE)
|
|
||||||
.finish();
|
|
||||||
|
|
||||||
tracing::subscriber::set_global_default(subscriber)
|
|
||||||
.wrap_err("Failed to setup trace logging.")?;
|
|
||||||
|
|
||||||
info!("Starting");
|
|
||||||
|
|
||||||
let settings = setup::init().await.wrap_err("Failed to initialize.")?;
|
|
||||||
let config = settings.config;
|
|
||||||
|
|
||||||
// chroot if applicable.
|
|
||||||
if let Ok(chroot_path) = config.get_string("chroot-dir") {
|
|
||||||
fs::chroot(&chroot_path)
|
|
||||||
.wrap_err_with(|| format!("Failed setting chroot '{}'", chroot_path.to_string()))?;
|
|
||||||
std::env::set_current_dir("/").wrap_err("Couldn't change directory after chroot.")?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let handle = qna::new(
|
|
||||||
config.get_string("api-key").wrap_err("API missing.")?,
|
|
||||||
config
|
|
||||||
.get_string("base-url")
|
|
||||||
.wrap_err("base-url missing.")?,
|
|
||||||
config
|
|
||||||
.get_string("model")
|
|
||||||
.wrap_err("model string missing.")?,
|
|
||||||
config
|
|
||||||
.get_string("instruct")
|
|
||||||
.unwrap_or_else(|_| DEFAULT_INSTRUCT.to_string()),
|
|
||||||
)
|
|
||||||
.wrap_err("Couldn't initialize LLM handle.")?;
|
|
||||||
let mut c = chat::new(&config, &handle).await?;
|
|
||||||
|
|
||||||
c.run().await.unwrap();
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|||||||
18
src/plugin.rs
Normal file
18
src/plugin.rs
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
use std::fmt::Display;
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
|
pub enum Plugin {
|
||||||
|
SendMessage { channel: String, message: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for Plugin {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
Self::SendMessage { channel, message } => {
|
||||||
|
write!(f, "[{channel}]: {message}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
64
src/qna.rs
64
src/qna.rs
@@ -3,16 +3,8 @@ use futures::StreamExt;
|
|||||||
use genai::{
|
use genai::{
|
||||||
Client,
|
Client,
|
||||||
ModelIden,
|
ModelIden,
|
||||||
chat::{
|
chat::{ChatMessage, ChatRequest, ChatStreamEvent, StreamChunk},
|
||||||
ChatMessage,
|
resolver::{AuthData, AuthResolver},
|
||||||
ChatRequest,
|
|
||||||
ChatStreamEvent,
|
|
||||||
StreamChunk,
|
|
||||||
},
|
|
||||||
resolver::{
|
|
||||||
AuthData,
|
|
||||||
AuthResolver,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
@@ -25,33 +17,33 @@ pub struct LLMHandle {
|
|||||||
model: String,
|
model: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new(
|
|
||||||
api_key: String,
|
|
||||||
_base_url: impl AsRef<str>,
|
|
||||||
model: impl Into<String>,
|
|
||||||
system_role: String,
|
|
||||||
) -> Result<LLMHandle> {
|
|
||||||
let auth_resolver = AuthResolver::from_resolver_fn(
|
|
||||||
|_model_iden: ModelIden| -> Result<Option<AuthData>, genai::resolver::Error> {
|
|
||||||
// let ModelIden { adapter_kind, model_name } = model_iden;
|
|
||||||
|
|
||||||
Ok(Some(AuthData::from_single(api_key)))
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
let client = Client::builder().with_auth_resolver(auth_resolver).build();
|
|
||||||
let chat_request = ChatRequest::default().with_system(system_role);
|
|
||||||
|
|
||||||
info!("New LLMHandle created.");
|
|
||||||
|
|
||||||
Ok(LLMHandle {
|
|
||||||
client,
|
|
||||||
chat_request,
|
|
||||||
model: model.into(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LLMHandle {
|
impl LLMHandle {
|
||||||
|
pub fn new(
|
||||||
|
api_key: String,
|
||||||
|
_base_url: impl AsRef<str>,
|
||||||
|
model: impl Into<String>,
|
||||||
|
system_role: String,
|
||||||
|
) -> Result<LLMHandle> {
|
||||||
|
let auth_resolver = AuthResolver::from_resolver_fn(
|
||||||
|
|_model_iden: ModelIden| -> Result<Option<AuthData>, genai::resolver::Error> {
|
||||||
|
// let ModelIden { adapter_kind, model_name } = model_iden;
|
||||||
|
|
||||||
|
Ok(Some(AuthData::from_single(api_key)))
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
let client = Client::builder().with_auth_resolver(auth_resolver).build();
|
||||||
|
let chat_request = ChatRequest::default().with_system(system_role);
|
||||||
|
|
||||||
|
info!("New LLMHandle created.");
|
||||||
|
|
||||||
|
Ok(LLMHandle {
|
||||||
|
client,
|
||||||
|
chat_request,
|
||||||
|
model: model.into(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn send_request(&mut self, message: impl Into<String>) -> Result<String> {
|
pub async fn send_request(&mut self, message: impl Into<String>) -> Result<String> {
|
||||||
let mut req = self.chat_request.clone();
|
let mut req = self.chat_request.clone();
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
|
|||||||
46
src/setup.rs
46
src/setup.rs
@@ -1,78 +1,72 @@
|
|||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use color_eyre::{
|
use color_eyre::{Result, eyre::WrapErr};
|
||||||
Result,
|
|
||||||
eyre::WrapErr,
|
|
||||||
};
|
|
||||||
use config::Config;
|
use config::Config;
|
||||||
use directories::ProjectDirs;
|
use directories::ProjectDirs;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use tracing::{
|
use tracing::{info, instrument};
|
||||||
info,
|
|
||||||
instrument,
|
|
||||||
};
|
|
||||||
|
|
||||||
// TODO: use [clap(long, short, help_heading = Some(section))]
|
// TODO: use [clap(long, short, help_heading = Some(section))]
|
||||||
#[derive(Clone, Debug, Parser)]
|
#[derive(Clone, Debug, Parser)]
|
||||||
#[command(about, version)]
|
#[command(about, version)]
|
||||||
pub(crate) struct Args {
|
pub struct Args {
|
||||||
#[arg(short, long)]
|
#[arg(short, long)]
|
||||||
/// API Key for the LLM in use.
|
/// API Key for the LLM in use.
|
||||||
pub(crate) api_key: Option<String>,
|
pub api_key: Option<String>,
|
||||||
|
|
||||||
#[arg(short, long, default_value = "https://api.openai.com")]
|
#[arg(short, long, default_value = "https://api.openai.com")]
|
||||||
/// Base URL for the LLM API to use.
|
/// Base URL for the LLM API to use.
|
||||||
pub(crate) base_url: Option<String>,
|
pub base_url: Option<String>,
|
||||||
|
|
||||||
/// Directory to use for chroot (recommended).
|
/// Directory to use for chroot (recommended).
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub(crate) chroot_dir: Option<String>,
|
pub chroot_dir: Option<String>,
|
||||||
|
|
||||||
/// Root directory for file based command structure.
|
/// Root directory for file based command structure.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub(crate) command_dir: Option<String>,
|
pub command_dir: Option<String>,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
/// Instructions to the model on how to behave.
|
/// Instructions to the model on how to behave.
|
||||||
pub(crate) intruct: Option<String>,
|
pub instruct: Option<String>,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub(crate) model: Option<String>,
|
pub model: Option<String>,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
/// List of IRC channels to join.
|
/// List of IRC channels to join.
|
||||||
pub(crate) channels: Option<Vec<String>>,
|
pub channels: Option<Vec<String>>,
|
||||||
|
|
||||||
#[arg(short, long)]
|
#[arg(short, long)]
|
||||||
/// Custom configuration file location if need be.
|
/// Custom configuration file location if need be.
|
||||||
pub(crate) config_file: Option<PathBuf>,
|
pub config_file: Option<PathBuf>,
|
||||||
|
|
||||||
#[arg(short, long, default_value = "irc.libera.chat")]
|
#[arg(short, long, default_value = "irc.libera.chat")]
|
||||||
/// IRC server.
|
/// IRC server.
|
||||||
pub(crate) server: Option<String>,
|
pub server: Option<String>,
|
||||||
|
|
||||||
#[arg(short, long, default_value = "6697")]
|
#[arg(short, long, default_value = "6697")]
|
||||||
/// Port of the IRC server.
|
/// Port of the IRC server.
|
||||||
pub(crate) port: Option<String>,
|
pub port: Option<String>,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
/// IRC Nickname.
|
/// IRC Nickname.
|
||||||
pub(crate) nickname: Option<String>,
|
pub nickname: Option<String>,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
/// IRC Nick Password
|
/// IRC Nick Password
|
||||||
pub(crate) nick_password: Option<String>,
|
pub nick_password: Option<String>,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
/// IRC Username
|
/// IRC Username
|
||||||
pub(crate) username: Option<String>,
|
pub username: Option<String>,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
/// Whether or not to use TLS when connecting to the IRC server.
|
/// Whether or not to use TLS when connecting to the IRC server.
|
||||||
pub(crate) use_tls: Option<bool>,
|
pub use_tls: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) struct Setup {
|
pub struct Setup {
|
||||||
pub(crate) config: Config,
|
pub config: Config,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument]
|
#[instrument]
|
||||||
@@ -104,7 +98,7 @@ pub async fn init() -> Result<Setup> {
|
|||||||
.set_override_option("chroot-dir", args.chroot_dir.clone())?
|
.set_override_option("chroot-dir", args.chroot_dir.clone())?
|
||||||
.set_override_option("command-path", args.command_dir.clone())?
|
.set_override_option("command-path", args.command_dir.clone())?
|
||||||
.set_override_option("model", args.model.clone())?
|
.set_override_option("model", args.model.clone())?
|
||||||
.set_override_option("instruct", args.model.clone())?
|
.set_override_option("instruct", args.instruct.clone())?
|
||||||
.set_override_option("channels", args.channels.clone())?
|
.set_override_option("channels", args.channels.clone())?
|
||||||
.set_override_option("server", args.server.clone())?
|
.set_override_option("server", args.server.clone())?
|
||||||
.set_override_option("port", args.port.clone())? // FIXME: Make this a default here not in clap.
|
.set_override_option("port", args.port.clone())? // FIXME: Make this a default here not in clap.
|
||||||
|
|||||||
290
tests/command_test.rs
Normal file
290
tests/command_test.rs
Normal file
@@ -0,0 +1,290 @@
|
|||||||
|
use std::{
|
||||||
|
fs::{self, Permissions},
|
||||||
|
os::unix::fs::PermissionsExt,
|
||||||
|
path::Path,
|
||||||
|
time::Duration,
|
||||||
|
};
|
||||||
|
|
||||||
|
use robotnik::command::CommandDir;
|
||||||
|
use tempfile::TempDir;
|
||||||
|
|
||||||
|
/// Helper to create executable test scripts
|
||||||
|
fn create_command(dir: &Path, name: &str, script: &str) {
|
||||||
|
let path = dir.join(name);
|
||||||
|
fs::write(&path, script).unwrap();
|
||||||
|
fs::set_permissions(&path, Permissions::from_mode(0o755)).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a bot message like "!weather 73135" into (command_name, argument)
|
||||||
|
fn parse_bot_message(message: &str) -> Option<(&str, &str)> {
|
||||||
|
if !message.starts_with('!') {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let without_prefix = &message[1..];
|
||||||
|
let mut parts = without_prefix.splitn(2, ' ');
|
||||||
|
let command = parts.next()?;
|
||||||
|
let arg = parts.next().unwrap_or("");
|
||||||
|
Some((command, arg))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_bot_message_finds_and_runs_command() {
|
||||||
|
let temp = TempDir::new().unwrap();
|
||||||
|
|
||||||
|
// Create a weather command that echoes the zip code
|
||||||
|
create_command(
|
||||||
|
temp.path(),
|
||||||
|
"weather",
|
||||||
|
r#"#!/bin/bash
|
||||||
|
echo "Weather for $1: Sunny, 72°F"
|
||||||
|
"#,
|
||||||
|
);
|
||||||
|
|
||||||
|
let cmd_dir = CommandDir::new(temp.path());
|
||||||
|
let message = "!weather 73135";
|
||||||
|
|
||||||
|
// Parse the message
|
||||||
|
let (command_name, arg) = parse_bot_message(message).unwrap();
|
||||||
|
assert_eq!(command_name, "weather");
|
||||||
|
assert_eq!(arg, "73135");
|
||||||
|
|
||||||
|
// Find and run the command
|
||||||
|
let result = cmd_dir.run_command(command_name, arg).await;
|
||||||
|
|
||||||
|
assert!(result.is_ok());
|
||||||
|
let bytes = result.unwrap();
|
||||||
|
let output = String::from_utf8_lossy(&bytes);
|
||||||
|
assert!(output.contains("Weather for 73135"));
|
||||||
|
assert!(output.contains("Sunny"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_bot_message_command_not_found() {
|
||||||
|
let temp = TempDir::new().unwrap();
|
||||||
|
let cmd_dir = CommandDir::new(temp.path());
|
||||||
|
|
||||||
|
let message = "!nonexistent arg";
|
||||||
|
let (command_name, arg) = parse_bot_message(message).unwrap();
|
||||||
|
|
||||||
|
let result = cmd_dir.run_command(command_name, arg).await;
|
||||||
|
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(result.unwrap_err().to_string().contains("Not found"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_bot_message_with_multiple_arguments() {
|
||||||
|
let temp = TempDir::new().unwrap();
|
||||||
|
|
||||||
|
// Create a command that handles multiple words as a single argument
|
||||||
|
create_command(
|
||||||
|
temp.path(),
|
||||||
|
"echo",
|
||||||
|
r#"#!/bin/bash
|
||||||
|
echo "You said: $1"
|
||||||
|
"#,
|
||||||
|
);
|
||||||
|
|
||||||
|
let cmd_dir = CommandDir::new(temp.path());
|
||||||
|
let message = "!echo hello world how are you";
|
||||||
|
|
||||||
|
let (command_name, arg) = parse_bot_message(message).unwrap();
|
||||||
|
assert_eq!(command_name, "echo");
|
||||||
|
assert_eq!(arg, "hello world how are you");
|
||||||
|
|
||||||
|
let result = cmd_dir.run_command(command_name, arg).await;
|
||||||
|
|
||||||
|
assert!(result.is_ok());
|
||||||
|
let bytes = result.unwrap();
|
||||||
|
let output = String::from_utf8_lossy(&bytes);
|
||||||
|
assert!(output.contains("hello world how are you"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_bot_message_without_argument() {
|
||||||
|
let temp = TempDir::new().unwrap();
|
||||||
|
|
||||||
|
create_command(
|
||||||
|
temp.path(),
|
||||||
|
"help",
|
||||||
|
r#"#!/bin/bash
|
||||||
|
echo "Available commands: weather, echo, help"
|
||||||
|
"#,
|
||||||
|
);
|
||||||
|
|
||||||
|
let cmd_dir = CommandDir::new(temp.path());
|
||||||
|
let message = "!help";
|
||||||
|
|
||||||
|
let (command_name, arg) = parse_bot_message(message).unwrap();
|
||||||
|
assert_eq!(command_name, "help");
|
||||||
|
assert_eq!(arg, "");
|
||||||
|
|
||||||
|
let result = cmd_dir.run_command(command_name, arg).await;
|
||||||
|
|
||||||
|
assert!(result.is_ok());
|
||||||
|
let bytes = result.unwrap();
|
||||||
|
let output = String::from_utf8_lossy(&bytes);
|
||||||
|
assert!(output.contains("Available commands"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_bot_message_command_returns_error_exit_code() {
|
||||||
|
let temp = TempDir::new().unwrap();
|
||||||
|
|
||||||
|
// Create a command that fails for invalid input
|
||||||
|
create_command(
|
||||||
|
temp.path(),
|
||||||
|
"validate",
|
||||||
|
r#"#!/bin/bash
|
||||||
|
if [ -z "$1" ]; then
|
||||||
|
echo "Error: Input required" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "Valid: $1"
|
||||||
|
"#,
|
||||||
|
);
|
||||||
|
|
||||||
|
let cmd_dir = CommandDir::new(temp.path());
|
||||||
|
let message = "!validate";
|
||||||
|
|
||||||
|
let (command_name, arg) = parse_bot_message(message).unwrap();
|
||||||
|
let result = cmd_dir.run_command(command_name, arg).await;
|
||||||
|
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(result.unwrap_err().to_string().contains("Error running"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_bot_message_with_timeout() {
|
||||||
|
let temp = TempDir::new().unwrap();
|
||||||
|
|
||||||
|
create_command(
|
||||||
|
temp.path(),
|
||||||
|
"quick",
|
||||||
|
r#"#!/bin/bash
|
||||||
|
echo "Result: $1"
|
||||||
|
"#,
|
||||||
|
);
|
||||||
|
|
||||||
|
let cmd_dir = CommandDir::new(temp.path());
|
||||||
|
let message = "!quick test";
|
||||||
|
|
||||||
|
let (command_name, arg) = parse_bot_message(message).unwrap();
|
||||||
|
let result = cmd_dir
|
||||||
|
.run_command_with_timeout(command_name, arg, Duration::from_secs(5))
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(result.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_bot_message_command_times_out() {
|
||||||
|
let temp = TempDir::new().unwrap();
|
||||||
|
|
||||||
|
create_command(
|
||||||
|
temp.path(),
|
||||||
|
"slow",
|
||||||
|
r#"#!/bin/bash
|
||||||
|
sleep 10
|
||||||
|
echo "Done"
|
||||||
|
"#,
|
||||||
|
);
|
||||||
|
|
||||||
|
let cmd_dir = CommandDir::new(temp.path());
|
||||||
|
let message = "!slow arg";
|
||||||
|
|
||||||
|
let (command_name, arg) = parse_bot_message(message).unwrap();
|
||||||
|
let result = cmd_dir
|
||||||
|
.run_command_with_timeout(command_name, arg, Duration::from_millis(100))
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_multiple_commands_in_directory() {
|
||||||
|
let temp = TempDir::new().unwrap();
|
||||||
|
|
||||||
|
create_command(
|
||||||
|
temp.path(),
|
||||||
|
"weather",
|
||||||
|
r#"#!/bin/bash
|
||||||
|
echo "Weather: Sunny"
|
||||||
|
"#,
|
||||||
|
);
|
||||||
|
|
||||||
|
create_command(
|
||||||
|
temp.path(),
|
||||||
|
"time",
|
||||||
|
r#"#!/bin/bash
|
||||||
|
echo "Time: 12:00"
|
||||||
|
"#,
|
||||||
|
);
|
||||||
|
|
||||||
|
create_command(
|
||||||
|
temp.path(),
|
||||||
|
"joke",
|
||||||
|
r#"#!/bin/bash
|
||||||
|
echo "Why did the robot go on vacation? To recharge!"
|
||||||
|
"#,
|
||||||
|
);
|
||||||
|
|
||||||
|
let cmd_dir = CommandDir::new(temp.path());
|
||||||
|
|
||||||
|
// Test each command
|
||||||
|
let messages = ["!weather", "!time", "!joke"];
|
||||||
|
let expected = ["Sunny", "12:00", "recharge"];
|
||||||
|
|
||||||
|
for (message, expect) in messages.iter().zip(expected.iter()) {
|
||||||
|
let (command_name, arg) = parse_bot_message(message).unwrap();
|
||||||
|
let result = cmd_dir.run_command(command_name, arg).await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
let bytes = result.unwrap();
|
||||||
|
let output = String::from_utf8_lossy(&bytes);
|
||||||
|
assert!(
|
||||||
|
output.contains(expect),
|
||||||
|
"Expected '{}' in '{}'",
|
||||||
|
expect,
|
||||||
|
output
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_non_bot_message_ignored() {
|
||||||
|
// Messages not starting with ! should be ignored
|
||||||
|
let messages = ["hello world", "weather 73135", "?help", "/command", ""];
|
||||||
|
|
||||||
|
for message in messages {
|
||||||
|
assert!(
|
||||||
|
parse_bot_message(message).is_none(),
|
||||||
|
"Should ignore: {}",
|
||||||
|
message
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_command_output_is_bytes() {
|
||||||
|
let temp = TempDir::new().unwrap();
|
||||||
|
|
||||||
|
// Create a command that outputs binary-safe content
|
||||||
|
create_command(
|
||||||
|
temp.path(),
|
||||||
|
"binary",
|
||||||
|
r#"#!/bin/bash
|
||||||
|
printf "Hello\x00World"
|
||||||
|
"#,
|
||||||
|
);
|
||||||
|
|
||||||
|
let cmd_dir = CommandDir::new(temp.path());
|
||||||
|
let message = "!binary test";
|
||||||
|
|
||||||
|
let (command_name, arg) = parse_bot_message(message).unwrap();
|
||||||
|
let result = cmd_dir.run_command(command_name, arg).await;
|
||||||
|
|
||||||
|
assert!(result.is_ok());
|
||||||
|
let output = result.unwrap();
|
||||||
|
// Should preserve the null byte
|
||||||
|
assert_eq!(&output[..], b"Hello\x00World");
|
||||||
|
}
|
||||||
495
tests/event_test.rs
Normal file
495
tests/event_test.rs
Normal file
@@ -0,0 +1,495 @@
|
|||||||
|
use std::{sync::Arc, time::Duration};
|
||||||
|
|
||||||
|
use robotnik::{event::Event, event_manager::EventManager};
|
||||||
|
use rstest::rstest;
|
||||||
|
use tokio::{
|
||||||
|
io::{AsyncBufReadExt, BufReader},
|
||||||
|
net::UnixStream,
|
||||||
|
time::timeout,
|
||||||
|
};
|
||||||
|
|
||||||
|
const TEST_SOCKET_BASE: &str = "/tmp/robotnik_test";
|
||||||
|
|
||||||
|
/// Helper to create unique socket paths for parallel tests
|
||||||
|
fn test_socket_path(name: &str) -> String {
|
||||||
|
format!("{}_{}_{}", TEST_SOCKET_BASE, name, std::process::id())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper to read one JSON event from a stream
|
||||||
|
async fn read_event(
|
||||||
|
reader: &mut BufReader<UnixStream>,
|
||||||
|
) -> Result<Event, Box<dyn std::error::Error>> {
|
||||||
|
let mut line = String::new();
|
||||||
|
reader.read_line(&mut line).await?;
|
||||||
|
let event: Event = serde_json::from_str(&line)?;
|
||||||
|
Ok(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper to read all available events with a timeout
|
||||||
|
async fn read_events_with_timeout(
|
||||||
|
reader: &mut BufReader<UnixStream>,
|
||||||
|
max_count: usize,
|
||||||
|
timeout_ms: u64,
|
||||||
|
) -> Vec<String> {
|
||||||
|
let mut events = Vec::new();
|
||||||
|
for _ in 0..max_count {
|
||||||
|
let mut line = String::new();
|
||||||
|
match timeout(
|
||||||
|
Duration::from_millis(timeout_ms),
|
||||||
|
reader.read_line(&mut line),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(Ok(0)) => break, // EOF
|
||||||
|
Ok(Ok(_)) => events.push(line),
|
||||||
|
Ok(Err(_)) => break, // Read error
|
||||||
|
Err(_) => break, // Timeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
events
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_client_connects_and_receives_event() {
|
||||||
|
let socket_path = test_socket_path("basic_connect");
|
||||||
|
let manager = Arc::new(EventManager::new().unwrap());
|
||||||
|
|
||||||
|
// Start the listener
|
||||||
|
let listener_manager = Arc::clone(&manager);
|
||||||
|
let socket_path_clone = socket_path.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
listener_manager.start_listening(socket_path_clone).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Give the listener time to start
|
||||||
|
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||||
|
|
||||||
|
// Broadcast an event
|
||||||
|
let event = Event::new("test_user", "test message");
|
||||||
|
manager.broadcast(&event).await.unwrap();
|
||||||
|
|
||||||
|
// Connect as a client
|
||||||
|
let stream = UnixStream::connect(&socket_path).await.unwrap();
|
||||||
|
let mut reader = BufReader::new(stream);
|
||||||
|
|
||||||
|
// Read the event
|
||||||
|
let mut line = String::new();
|
||||||
|
reader.read_line(&mut line).await.unwrap();
|
||||||
|
|
||||||
|
assert!(line.contains("test message"));
|
||||||
|
assert!(line.ends_with('\n'));
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
let _ = std::fs::remove_file(&socket_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_client_receives_event_history() {
|
||||||
|
let socket_path = test_socket_path("event_history");
|
||||||
|
let manager = Arc::new(EventManager::new().unwrap());
|
||||||
|
|
||||||
|
// Broadcast events BEFORE starting the listener
|
||||||
|
for i in 0..5 {
|
||||||
|
let event = Event::new("test_user", format!("historical event {}", i));
|
||||||
|
manager.broadcast(&event).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the listener
|
||||||
|
let listener_manager = Arc::clone(&manager);
|
||||||
|
let socket_path_clone = socket_path.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
listener_manager.start_listening(socket_path_clone).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||||
|
|
||||||
|
// Connect as a client
|
||||||
|
let stream = UnixStream::connect(&socket_path).await.unwrap();
|
||||||
|
let mut reader = BufReader::new(stream);
|
||||||
|
|
||||||
|
// Should receive all 5 historical events
|
||||||
|
let events = read_events_with_timeout(&mut reader, 5, 100).await;
|
||||||
|
|
||||||
|
assert_eq!(events.len(), 5);
|
||||||
|
assert!(events[0].contains("historical event 0"));
|
||||||
|
assert!(events[4].contains("historical event 4"));
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
let _ = std::fs::remove_file(&socket_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_multiple_clients_receive_same_events() {
|
||||||
|
let socket_path = test_socket_path("multiple_clients");
|
||||||
|
let manager = Arc::new(EventManager::new().unwrap());
|
||||||
|
|
||||||
|
// Start the listener
|
||||||
|
let listener_manager = Arc::clone(&manager);
|
||||||
|
let socket_path_clone = socket_path.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
listener_manager.start_listening(socket_path_clone).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||||
|
|
||||||
|
// Connect 3 clients
|
||||||
|
let stream1 = UnixStream::connect(&socket_path).await.unwrap();
|
||||||
|
let stream2 = UnixStream::connect(&socket_path).await.unwrap();
|
||||||
|
let stream3 = UnixStream::connect(&socket_path).await.unwrap();
|
||||||
|
|
||||||
|
let mut reader1 = BufReader::new(stream1);
|
||||||
|
let mut reader2 = BufReader::new(stream2);
|
||||||
|
let mut reader3 = BufReader::new(stream3);
|
||||||
|
|
||||||
|
// Broadcast a new event
|
||||||
|
let event = Event::new("test_user", "broadcast to all");
|
||||||
|
manager.broadcast(&event).await.unwrap();
|
||||||
|
|
||||||
|
// All clients should receive the event
|
||||||
|
let mut line1 = String::new();
|
||||||
|
let mut line2 = String::new();
|
||||||
|
let mut line3 = String::new();
|
||||||
|
|
||||||
|
timeout(Duration::from_millis(100), reader1.read_line(&mut line1))
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
timeout(Duration::from_millis(100), reader2.read_line(&mut line2))
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
timeout(Duration::from_millis(100), reader3.read_line(&mut line3))
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(line1.contains("broadcast to all"));
|
||||||
|
assert!(line2.contains("broadcast to all"));
|
||||||
|
assert!(line3.contains("broadcast to all"));
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
let _ = std::fs::remove_file(&socket_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_late_joiner_receives_full_history() {
|
||||||
|
let socket_path = test_socket_path("late_joiner");
|
||||||
|
let manager = Arc::new(EventManager::new().unwrap());
|
||||||
|
|
||||||
|
// Start the listener
|
||||||
|
let listener_manager = Arc::clone(&manager);
|
||||||
|
let socket_path_clone = socket_path.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
listener_manager.start_listening(socket_path_clone).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||||
|
|
||||||
|
// First client connects
|
||||||
|
let stream1 = UnixStream::connect(&socket_path).await.unwrap();
|
||||||
|
let mut reader1 = BufReader::new(stream1);
|
||||||
|
|
||||||
|
// Broadcast several events
|
||||||
|
for i in 0..10 {
|
||||||
|
let event = Event::new("test_user", format!("event {}", i));
|
||||||
|
manager.broadcast(&event).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Consume events from first client
|
||||||
|
let _ = read_events_with_timeout(&mut reader1, 10, 100).await;
|
||||||
|
|
||||||
|
// Late joiner connects
|
||||||
|
let stream2 = UnixStream::connect(&socket_path).await.unwrap();
|
||||||
|
let mut reader2 = BufReader::new(stream2);
|
||||||
|
|
||||||
|
// Late joiner should receive all 10 events from history
|
||||||
|
let events = read_events_with_timeout(&mut reader2, 10, 100).await;
|
||||||
|
|
||||||
|
assert_eq!(events.len(), 10);
|
||||||
|
assert!(events[0].contains("event 0"));
|
||||||
|
assert!(events[9].contains("event 9"));
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
let _ = std::fs::remove_file(&socket_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_client_receives_events_in_order() {
|
||||||
|
let socket_path = test_socket_path("event_order");
|
||||||
|
let manager = Arc::new(EventManager::new().unwrap());
|
||||||
|
|
||||||
|
// Start the listener
|
||||||
|
let listener_manager = Arc::clone(&manager);
|
||||||
|
let socket_path_clone = socket_path.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
listener_manager.start_listening(socket_path_clone).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||||
|
|
||||||
|
// Connect client
|
||||||
|
let stream = UnixStream::connect(&socket_path).await.unwrap();
|
||||||
|
let mut reader = BufReader::new(stream);
|
||||||
|
|
||||||
|
// Broadcast events rapidly
|
||||||
|
let count = 50;
|
||||||
|
for i in 0..count {
|
||||||
|
let event = Event::new("test_user", format!("sequence {}", i));
|
||||||
|
manager.broadcast(&event).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read all events
|
||||||
|
let events = read_events_with_timeout(&mut reader, count, 500).await;
|
||||||
|
|
||||||
|
assert_eq!(events.len(), count);
|
||||||
|
|
||||||
|
// Verify order
|
||||||
|
for (i, event) in events.iter().enumerate() {
|
||||||
|
assert!(
|
||||||
|
event.contains(&format!("sequence {}", i)),
|
||||||
|
"Event {} out of order: {}",
|
||||||
|
i,
|
||||||
|
event
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
let _ = std::fs::remove_file(&socket_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_concurrent_broadcasts_during_client_connections() {
|
||||||
|
let socket_path = test_socket_path("concurrent_ops");
|
||||||
|
let manager = Arc::new(EventManager::new().unwrap());
|
||||||
|
|
||||||
|
// Start the listener
|
||||||
|
let listener_manager = Arc::clone(&manager);
|
||||||
|
let socket_path_clone = socket_path.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
listener_manager.start_listening(socket_path_clone).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||||
|
|
||||||
|
// Connect client 1 BEFORE any broadcasts
|
||||||
|
let stream1 = UnixStream::connect(&socket_path).await.unwrap();
|
||||||
|
let mut reader1 = BufReader::new(stream1);
|
||||||
|
|
||||||
|
// Spawn a task that continuously broadcasts
|
||||||
|
let broadcast_manager = Arc::clone(&manager);
|
||||||
|
let broadcast_handle = tokio::spawn(async move {
|
||||||
|
for i in 0..100 {
|
||||||
|
let event = Event::new("test_user", format!("concurrent event {}", i));
|
||||||
|
broadcast_manager.broadcast(&event).await.unwrap();
|
||||||
|
tokio::time::sleep(Duration::from_millis(5)).await;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// While broadcasting, connect more clients at different times
|
||||||
|
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||||
|
let stream2 = UnixStream::connect(&socket_path).await.unwrap();
|
||||||
|
let mut reader2 = BufReader::new(stream2);
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(150)).await;
|
||||||
|
let stream3 = UnixStream::connect(&socket_path).await.unwrap();
|
||||||
|
let mut reader3 = BufReader::new(stream3);
|
||||||
|
|
||||||
|
// Wait for broadcasts to complete
|
||||||
|
broadcast_handle.await.unwrap();
|
||||||
|
|
||||||
|
// All clients should have received events
|
||||||
|
let events1 = read_events_with_timeout(&mut reader1, 100, 200).await;
|
||||||
|
let events2 = read_events_with_timeout(&mut reader2, 100, 200).await;
|
||||||
|
let events3 = read_events_with_timeout(&mut reader3, 100, 200).await;
|
||||||
|
|
||||||
|
// Client 1 connected first (before any broadcasts), should get all 100
|
||||||
|
assert_eq!(events1.len(), 100);
|
||||||
|
|
||||||
|
// Client 2 connected after ~20 events were broadcast
|
||||||
|
// Gets ~20 from history + ~80 live = 100
|
||||||
|
assert_eq!(events2.len(), 100);
|
||||||
|
|
||||||
|
// Client 3 connected after ~50 events were broadcast
|
||||||
|
// Gets ~50 from history + ~50 live = 100
|
||||||
|
assert_eq!(events3.len(), 100);
|
||||||
|
|
||||||
|
// Verify they all received events in order
|
||||||
|
assert!(events1[0].contains("concurrent event 0"));
|
||||||
|
assert!(events2[0].contains("concurrent event 0"));
|
||||||
|
assert!(events3[0].contains("concurrent event 0"));
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
let _ = std::fs::remove_file(&socket_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_buffer_overflow_affects_new_clients() {
|
||||||
|
let socket_path = test_socket_path("buffer_overflow");
|
||||||
|
let manager = Arc::new(EventManager::new().unwrap());
|
||||||
|
|
||||||
|
// Broadcast more than buffer max (1000)
|
||||||
|
for i in 0..1100 {
|
||||||
|
let event = Event::new("test_user", format!("overflow event {}", i));
|
||||||
|
manager.broadcast(&event).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the listener
|
||||||
|
let listener_manager = Arc::clone(&manager);
|
||||||
|
let socket_path_clone = socket_path.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
listener_manager.start_listening(socket_path_clone).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||||
|
|
||||||
|
// New client connects
|
||||||
|
let stream = UnixStream::connect(&socket_path).await.unwrap();
|
||||||
|
let mut reader = BufReader::new(stream);
|
||||||
|
|
||||||
|
// Should receive exactly 1000 events (buffer max)
|
||||||
|
let events = read_events_with_timeout(&mut reader, 1100, 500).await;
|
||||||
|
|
||||||
|
assert_eq!(events.len(), 1000);
|
||||||
|
|
||||||
|
// First event should be 100 (oldest 100 were evicted)
|
||||||
|
assert!(events[0].contains("overflow event 100"));
|
||||||
|
|
||||||
|
// Last event should be 1099
|
||||||
|
assert!(events[999].contains("overflow event 1099"));
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
let _ = std::fs::remove_file(&socket_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rstest]
|
||||||
|
#[case(10, 1)]
|
||||||
|
#[case(50, 5)]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_client_count_scaling(#[case] num_clients: usize, #[case] events_per_client: usize) {
|
||||||
|
let socket_path = test_socket_path(&format!("scaling_{}_{}", num_clients, events_per_client));
|
||||||
|
let manager = Arc::new(EventManager::new().unwrap());
|
||||||
|
|
||||||
|
// Start the listener
|
||||||
|
let listener_manager = Arc::clone(&manager);
|
||||||
|
let socket_path_clone = socket_path.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
listener_manager.start_listening(socket_path_clone).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||||
|
|
||||||
|
// Connect many clients
|
||||||
|
let mut readers = Vec::new();
|
||||||
|
for _ in 0..num_clients {
|
||||||
|
let stream = UnixStream::connect(&socket_path).await.unwrap();
|
||||||
|
readers.push(BufReader::new(stream));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Broadcast events
|
||||||
|
for i in 0..events_per_client {
|
||||||
|
let event = Event::new("test_user", format!("scale event {}", i));
|
||||||
|
manager.broadcast(&event).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all clients received all events
|
||||||
|
for reader in &mut readers {
|
||||||
|
let events = read_events_with_timeout(reader, events_per_client, 200).await;
|
||||||
|
assert_eq!(events.len(), events_per_client);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
let _ = std::fs::remove_file(&socket_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_client_disconnect_doesnt_affect_others() {
|
||||||
|
let socket_path = test_socket_path("disconnect");
|
||||||
|
let manager = Arc::new(EventManager::new().unwrap());
|
||||||
|
|
||||||
|
// Start the listener
|
||||||
|
let listener_manager = Arc::clone(&manager);
|
||||||
|
let socket_path_clone = socket_path.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
listener_manager.start_listening(socket_path_clone).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||||
|
|
||||||
|
// Connect 3 clients
|
||||||
|
let stream1 = UnixStream::connect(&socket_path).await.unwrap();
|
||||||
|
let stream2 = UnixStream::connect(&socket_path).await.unwrap();
|
||||||
|
let stream3 = UnixStream::connect(&socket_path).await.unwrap();
|
||||||
|
|
||||||
|
let mut reader1 = BufReader::new(stream1);
|
||||||
|
let mut reader2 = BufReader::new(stream2);
|
||||||
|
let mut reader3 = BufReader::new(stream3);
|
||||||
|
|
||||||
|
// Broadcast initial event
|
||||||
|
manager
|
||||||
|
.broadcast(&Event::new("test_user", "before disconnect"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// All receive it
|
||||||
|
let _ = read_events_with_timeout(&mut reader1, 1, 100).await;
|
||||||
|
let _ = read_events_with_timeout(&mut reader2, 1, 100).await;
|
||||||
|
let _ = read_events_with_timeout(&mut reader3, 1, 100).await;
|
||||||
|
|
||||||
|
// Drop client 2 (simulates disconnect)
|
||||||
|
drop(reader2);
|
||||||
|
|
||||||
|
// Broadcast another event
|
||||||
|
manager
|
||||||
|
.broadcast(&Event::new("test_user", "after disconnect"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Clients 1 and 3 should still receive it
|
||||||
|
let events1 = read_events_with_timeout(&mut reader1, 1, 100).await;
|
||||||
|
let events3 = read_events_with_timeout(&mut reader3, 1, 100).await;
|
||||||
|
|
||||||
|
assert_eq!(events1.len(), 1);
|
||||||
|
assert_eq!(events3.len(), 1);
|
||||||
|
assert!(events1[0].contains("after disconnect"));
|
||||||
|
assert!(events3[0].contains("after disconnect"));
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
let _ = std::fs::remove_file(&socket_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_json_deserialization_of_received_events() {
|
||||||
|
let socket_path = test_socket_path("json_deser");
|
||||||
|
let manager = Arc::new(EventManager::new().unwrap());
|
||||||
|
|
||||||
|
// Start the listener
|
||||||
|
let listener_manager = Arc::clone(&manager);
|
||||||
|
let socket_path_clone = socket_path.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
listener_manager.start_listening(socket_path_clone).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||||
|
|
||||||
|
// Broadcast an event with special characters
|
||||||
|
let test_message = "special chars: @#$% newline\\n tab\\t quotes \"test\"";
|
||||||
|
manager
|
||||||
|
.broadcast(&Event::new("test_user", test_message))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Connect and deserialize
|
||||||
|
let stream = UnixStream::connect(&socket_path).await.unwrap();
|
||||||
|
let mut reader = BufReader::new(stream);
|
||||||
|
|
||||||
|
let mut line = String::new();
|
||||||
|
reader.read_line(&mut line).await.unwrap();
|
||||||
|
|
||||||
|
// Should be valid JSON
|
||||||
|
let parsed: serde_json::Value = serde_json::from_str(line.trim()).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(parsed["message"], test_message);
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
let _ = std::fs::remove_file(&socket_path);
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user