1use std::{
3 collections::{HashMap, HashSet},
4 sync::Arc,
5};
6
7use futures_util::{SinkExt, StreamExt};
8use tokio::{
9 sync::{
10 RwLock,
11 mpsc::{Receiver, Sender, channel},
12 },
13 task::JoinHandle,
14};
15use tokio_tungstenite::{
16 connect_async,
17 tungstenite::{self, http::Uri},
18};
19use tungstenite::ClientRequestBuilder;
20use typed_builder::TypedBuilder;
21use zeek_websocket_types::Event;
22
23use crate::{
24 Binding,
25 protocol::{self, ProtocolError},
26};
27
28#[derive(TypedBuilder)]
30#[builder(
31 doc,
32 build_method(into = Result<Client, Error>),
33 mutators(
34 pub fn subscribe<S: Into<String>>(&mut self, topic: S) {
36 self.subscriptions.insert(topic.into());
37 }
38
39))]
40pub struct ClientConfig {
41 #[builder(via_mutators)]
42 subscriptions: HashSet<String>,
43
44 endpoint: Uri,
46
47 #[builder(setter(into))]
49 app_name: String,
50
51 #[builder(default = 1024)]
53 buffer_capacity: usize,
54}
55
56impl From<ClientConfig> for Result<Client, Error> {
57 fn from(
58 ClientConfig {
59 subscriptions,
60 endpoint,
61 buffer_capacity,
62 app_name,
63 }: ClientConfig,
64 ) -> Self {
65 let (events_sender, events) = channel(buffer_capacity);
66
67 let bindings: Result<_, _> = subscriptions
68 .into_iter()
69 .map(|topic| {
70 let client = TopicHandler::new(
71 app_name.clone(),
72 topic.clone(),
73 endpoint.clone(),
74 events_sender.clone(),
75 );
76 Ok::<(std::string::String, TopicHandler), Error>((topic, client))
77 })
78 .collect();
79 let bindings = Arc::new(RwLock::<HashMap<_, _>>::new(bindings?));
80
81 Ok(Client {
82 app_name,
83 endpoint,
84 bindings,
85 events,
86 events_sender,
87 })
88 }
89}
90
91pub struct Client {
128 app_name: String,
129
130 endpoint: Uri,
131 bindings: Arc<RwLock<HashMap<String, TopicHandler>>>,
132
133 events: Receiver<Result<(String, Event), ProtocolError>>,
134 events_sender: Sender<Result<(String, Event), ProtocolError>>,
135}
136
137impl Client {
138 pub async fn publish_event<S: Into<String>>(&mut self, topic: S, event: Event) {
141 let topic = topic.into();
142
143 let mut bindings = self.bindings.write().await;
144 let client = bindings.entry(topic.clone()).or_insert_with(|| {
145 TopicHandler::new(
146 self.app_name.clone(),
147 topic.clone(),
148 self.endpoint.clone(),
149 self.events_sender.clone(),
150 )
151 });
152
153 let _ = client.publish_sink.send((topic, event)).await;
154 }
155
156 pub async fn receive_event(&mut self) -> Result<Option<(String, Event)>, Error> {
165 Ok(self.events.recv().await.transpose()?)
166 }
167
168 pub async fn subscribe<S: Into<String>>(&mut self, topic: S) {
172 let topic = topic.into();
173
174 let mut bindings = self.bindings.write().await;
175 bindings.entry(topic.clone()).or_insert_with(|| {
176 TopicHandler::new(
177 self.app_name.clone(),
178 topic,
179 self.endpoint.clone(),
180 self.events_sender.clone(),
181 )
182 });
183 }
184
185 pub async fn unsubscribe(&mut self, topic: &str) -> bool {
189 let mut bindings = self.bindings.write().await;
190 bindings.remove(topic).is_some()
191 }
192}
193
194struct TopicHandler {
195 _loop: JoinHandle<Result<(), Error>>,
196 publish_sink: Sender<(String, Event)>,
197}
198
199impl TopicHandler {
200 fn new(
201 app_name: String,
202 topic: String,
203 endpoint: Uri,
204 events_sender: Sender<Result<(String, Event), ProtocolError>>,
205 ) -> Self {
206 let (publish_sink, mut publish) = channel(1);
207
208 let loop_ = tokio::spawn(async move {
209 let mut binding = Binding::new(vec![topic.clone()]);
210
211 let endpoint =
212 ClientRequestBuilder::new(endpoint).with_header("X-Application-Name", app_name);
213
214 let (mut stream, ..) = connect_async(endpoint).await?;
215
216 loop {
217 tokio::select! {
218 r = publish.recv() => {
219 let Some((topic, event)) = r else { return Ok(()); };
220 binding.publish_event::<String>(topic, event)?;
221 }
222 s = stream.next() => {
223 if let Ok(message) = match s {
224 Some(payload) => match payload {
225 Ok(p) => p.try_into(),
226 Err(e) => {
227 handle_transport_error(e, &mut binding, &topic)?;
228 continue;
229 },
230 },
231
232 None => continue,
233 } {
234 binding.handle_incoming(message)?;
235 }
236 }
237 };
238
239 while let Some(bin) = binding.outgoing() {
240 if let Err(e) = stream.send(tungstenite::Message::binary(bin)).await {
241 handle_transport_error(e, &mut binding, &topic)?;
242 }
243 }
244
245 while let Some(payload) = binding.receive_event().transpose() {
246 let _ = events_sender.send(payload).await;
247 }
248 }
249 });
250
251 TopicHandler {
252 _loop: loop_,
253 publish_sink,
254 }
255 }
256}
257
258#[derive(thiserror::Error, Debug, PartialEq)]
260pub enum Error {
261 #[error("failure in websocket transport: {0}")]
262 Transport(String),
263 #[error("protocol-related error: {0}")]
264 ProtocolError(#[from] protocol::ProtocolError),
265}
266
267impl From<tungstenite::Error> for Error {
268 fn from(value: tungstenite::Error) -> Self {
269 Self::Transport(value.to_string())
270 }
271}
272
273fn handle_transport_error(
277 e: tungstenite::Error,
278 binding: &mut Binding,
279 topic: &str,
280) -> Result<(), Error> {
281 match e {
282 tungstenite::Error::AttackAttempt
284 | tungstenite::Error::AlreadyClosed
285 | tungstenite::Error::ConnectionClosed
286 | tungstenite::Error::Io(_) => {
287 *binding = Binding::new(vec![topic]);
288 Ok(())
289 }
290
291 tungstenite::Error::Protocol(_)
293 | tungstenite::Error::WriteBufferFull(_)
294 | tungstenite::Error::Capacity(_)
295 | tungstenite::Error::Tls(_)
296 | tungstenite::Error::Url(_)
297 | tungstenite::Error::Http(_)
298 | tungstenite::Error::HttpFormat(_)
299 | tungstenite::Error::Utf8(_) => Err(Error::from(e)),
300 }
301}
302
#[cfg(test)]
mod test {
    use std::time::Duration;

    use zeek_websocket_types::Event;

    use crate::client::ClientConfig;

    /// Exercises the public `Client` API end to end.
    ///
    /// The endpoint is not expected to host a websocket server. The assertions therefore
    /// pin the subscription bookkeeping, which does not depend on the connection
    /// succeeding, rather than event delivery.
    #[tokio::test]
    async fn basic() {
        let endpoint = "ws://127.0.0.1";
        let mut client = ClientConfig::builder()
            .endpoint(endpoint.try_into().unwrap())
            .app_name("foo")
            .subscribe("/info")
            .build()
            .unwrap();

        client
            .publish_event("/info", Event::new("info", vec!["hi!"]))
            .await;

        client
            .publish_event("/not-yet-subscribed", Event::new("info", vec!["hi!"]))
            .await;

        tokio::select! {
            _e = client.receive_event() => {}
            _timeout = tokio::time::sleep(Duration::from_millis(10)) => {}
        };

        // Re-subscribing an existing topic is a no-op; a new topic gets a handler.
        client.subscribe("/info").await;
        client.subscribe("/foo").await;

        // Publishing to an unknown topic implicitly subscribed to it.
        assert!(client.unsubscribe("/not-yet-subscribed").await);
        assert!(client.unsubscribe("/info").await);
        assert!(client.unsubscribe("/foo").await);

        // Removing a topic that is no longer subscribed reports that nothing was removed.
        assert!(!client.unsubscribe("/info").await);
    }
}