zeek_websocket/
client.rs

1//! Client implementation
2use std::{
3    collections::{HashMap, HashSet},
4    sync::Arc,
5};
6
7use futures_util::{SinkExt, StreamExt};
8use tokio::{
9    sync::{
10        RwLock,
11        mpsc::{Receiver, Sender, channel},
12    },
13    task::JoinHandle,
14};
15use tokio_tungstenite::{
16    connect_async,
17    tungstenite::{self, http::Uri},
18};
19use tungstenite::ClientRequestBuilder;
20use typed_builder::TypedBuilder;
21use zeek_websocket_types::Event;
22
23use crate::{
24    Binding,
25    protocol::{self, ProtocolError},
26};
27
28/// Builder to construct a [`Client`].
29#[derive(TypedBuilder)]
30#[builder(
31    doc,
32    build_method(into = Result<Client, Error>),
33    mutators(
34        /// Subscribe to topic to receive events.
35        pub fn subscribe<S: Into<String>>(&mut self, topic: S) {
36            self.subscriptions.insert(topic.into());
37        }
38
39))]
40pub struct ClientConfig {
41    #[builder(via_mutators)]
42    subscriptions: HashSet<String>,
43
44    /// Zeek WebSocket endpoint to connect to.
45    endpoint: Uri,
46
47    /// String use by the client to identify itself against Zeek.
48    #[builder(setter(into))]
49    app_name: String,
50
51    /// How many events to buffer before exerting backpressure.
52    #[builder(default = 1024)]
53    buffer_capacity: usize,
54}
55
56impl From<ClientConfig> for Result<Client, Error> {
57    fn from(
58        ClientConfig {
59            subscriptions,
60            endpoint,
61            buffer_capacity,
62            app_name,
63        }: ClientConfig,
64    ) -> Self {
65        let (events_sender, events) = channel(buffer_capacity);
66
67        let bindings: Result<_, _> = subscriptions
68            .into_iter()
69            .map(|topic| {
70                let client = TopicHandler::new(
71                    app_name.clone(),
72                    topic.clone(),
73                    endpoint.clone(),
74                    events_sender.clone(),
75                );
76                Ok::<(std::string::String, TopicHandler), Error>((topic, client))
77            })
78            .collect();
79        let bindings = Arc::new(RwLock::<HashMap<_, _>>::new(bindings?));
80
81        Ok(Client {
82            app_name,
83            endpoint,
84            bindings,
85            events,
86            events_sender,
87        })
88    }
89}
90
91/// # Tokio-based for the Zeek WebSocket API
92///
93/// [`Client`] implements an async client for the Zeek WebSocket API. It is intended to be run
94/// inside a [`tokio`] runtime. The general workflow is to build a client with the [`ClientConfig`]
95/// builder interface, and then either publish or receive events.
96///
97/// ## Example
98///
99/// ```no_run
100/// use anyhow::Result;
101/// use zeek_websocket::client::ClientConfig;
102/// pub use zeek_websocket::Event;
103///
104/// #[tokio::main]
105/// async fn main() -> Result<()> {
106///     let mut client = ClientConfig::builder()
107///         .app_name("my_client_application")
108///         .subscribe("/info")
109///         .endpoint("ws://127.0.0.1:8080/v1/messages/json".try_into()?)
110///         .build()?;
111///
112///     client
113///         .publish_event("/ping", Event::new("ping", vec!["abc"]))
114///         .await;
115///
116///     loop {
117///         // Client automatically receives events on topics it sent to.
118///         if let Some((_topic, event)) = client.receive_event().await? {
119///             eprintln!("{event:?}");
120///             break;
121///         }
122///     }
123///
124///     Ok(())
125/// }
126/// ```
127pub struct Client {
128    app_name: String,
129
130    endpoint: Uri,
131    bindings: Arc<RwLock<HashMap<String, TopicHandler>>>,
132
133    events: Receiver<Result<(String, Event), ProtocolError>>,
134    events_sender: Sender<Result<(String, Event), ProtocolError>>,
135}
136
137impl Client {
138    /// Publish an [`Event`] to `topic`. The client will be automatically subscribed to the topic
139    /// if is not already.
140    pub async fn publish_event<S: Into<String>>(&mut self, topic: S, event: Event) {
141        let topic = topic.into();
142
143        let mut bindings = self.bindings.write().await;
144        let client = bindings.entry(topic.clone()).or_insert_with(|| {
145            TopicHandler::new(
146                self.app_name.clone(),
147                topic.clone(),
148                self.endpoint.clone(),
149                self.events_sender.clone(),
150            )
151        });
152
153        let _ = client.publish_sink.send((topic, event)).await;
154    }
155
156    /// Receive the next [`Event`] or [`Error`].
157    ///
158    /// If an event was received it will be returned as `Ok(Some((topic, event)))`.
159    ///
160    /// # Errors
161    ///
162    /// Might return a [`ProtocolError`] from the underlying binding, e.g., Zeek, or an
163    /// transport-related error.
164    pub async fn receive_event(&mut self) -> Result<Option<(String, Event)>, Error> {
165        Ok(self.events.recv().await.transpose()?)
166    }
167
168    /// Subscribe to a topic.
169    ///
170    /// This is a noop if the client is already subscribed.
171    pub async fn subscribe<S: Into<String>>(&mut self, topic: S) {
172        let topic = topic.into();
173
174        let mut bindings = self.bindings.write().await;
175        bindings.entry(topic.clone()).or_insert_with(|| {
176            TopicHandler::new(
177                self.app_name.clone(),
178                topic,
179                self.endpoint.clone(),
180                self.events_sender.clone(),
181            )
182        });
183    }
184
185    /// Unsubscribe from a topic.
186    ///
187    /// This is a noop if the client was not subscribed.
188    pub async fn unsubscribe(&mut self, topic: &str) -> bool {
189        let mut bindings = self.bindings.write().await;
190        bindings.remove(topic).is_some()
191    }
192}
193
194struct TopicHandler {
195    _loop: JoinHandle<Result<(), Error>>,
196    publish_sink: Sender<(String, Event)>,
197}
198
199impl TopicHandler {
200    fn new(
201        app_name: String,
202        topic: String,
203        endpoint: Uri,
204        events_sender: Sender<Result<(String, Event), ProtocolError>>,
205    ) -> Self {
206        let (publish_sink, mut publish) = channel(1);
207
208        let loop_ = tokio::spawn(async move {
209            let mut binding = Binding::new(vec![topic.clone()]);
210
211            let endpoint =
212                ClientRequestBuilder::new(endpoint).with_header("X-Application-Name", app_name);
213
214            let (mut stream, ..) = connect_async(endpoint).await?;
215
216            loop {
217                tokio::select! {
218                    r = publish.recv() => {
219                        let Some((topic, event)) = r else { return Ok(()); };
220                        binding.publish_event::<String>(topic, event)?;
221                    }
222                    s = stream.next() => {
223                        if let Ok(message) = match s {
224                            Some(payload) => match payload {
225                                Ok(p) => p.try_into(),
226                                Err(e) => {
227                                    handle_transport_error(e, &mut binding, &topic)?;
228                                    continue;
229                                },
230                            },
231
232                            None => continue,
233                        } {
234                            binding.handle_incoming(message)?;
235                        }
236                    }
237                };
238
239                while let Some(bin) = binding.outgoing() {
240                    if let Err(e) = stream.send(tungstenite::Message::binary(bin)).await {
241                        handle_transport_error(e, &mut binding, &topic)?;
242                    }
243                }
244
245                while let Some(payload) = binding.receive_event().transpose() {
246                    let _ = events_sender.send(payload).await;
247                }
248            }
249        });
250
251        TopicHandler {
252            _loop: loop_,
253            publish_sink,
254        }
255    }
256}
257
258/// Error enum for client-related errors.
259#[derive(thiserror::Error, Debug, PartialEq)]
260pub enum Error {
261    #[error("failure in websocket transport: {0}")]
262    Transport(String),
263    #[error("protocol-related error: {0}")]
264    ProtocolError(#[from] protocol::ProtocolError),
265}
266
267impl From<tungstenite::Error> for Error {
268    fn from(value: tungstenite::Error) -> Self {
269        Self::Transport(value.to_string())
270    }
271}
272
273/// Error handling for transport errors.
274///
275/// Returns a `Ok(())` if the incoming could be handled, or an `Err` if it should be propagated up.
276fn handle_transport_error(
277    e: tungstenite::Error,
278    binding: &mut Binding,
279    topic: &str,
280) -> Result<(), Error> {
281    match e {
282        // Errors we can handle by gracefully resubscribing.
283        tungstenite::Error::AttackAttempt
284        | tungstenite::Error::AlreadyClosed
285        | tungstenite::Error::ConnectionClosed
286        | tungstenite::Error::Io(_) => {
287            *binding = Binding::new(vec![topic]);
288            Ok(())
289        }
290
291        // Errors we bail on and bubble up to the user.
292        tungstenite::Error::Protocol(_)
293        | tungstenite::Error::WriteBufferFull(_)
294        | tungstenite::Error::Capacity(_)
295        | tungstenite::Error::Tls(_)
296        | tungstenite::Error::Url(_)
297        | tungstenite::Error::Http(_)
298        | tungstenite::Error::HttpFormat(_)
299        | tungstenite::Error::Utf8(_) => Err(Error::from(e)),
300    }
301}
302
#[cfg(test)]
mod test {
    use std::time::Duration;

    use zeek_websocket_types::Event;

    use crate::client::ClientConfig;

    /// Smoke test exercising the whole public client API against an endpoint address.
    #[tokio::test]
    async fn basic() {
        let mut client = ClientConfig::builder()
            .app_name("foo")
            .subscribe("/info")
            .endpoint("ws://127.0.0.1".try_into().unwrap())
            .build()
            .unwrap();

        // Publish to a topic we subscribed to at build time ...
        let to_subscribed = Event::new("info", vec!["hi!"]);
        client.publish_event("/info", to_subscribed).await;

        // ... and to one which gets subscribed implicitly.
        let to_unsubscribed = Event::new("info", vec!["hi!"]);
        client
            .publish_event("/not-yet-subscribed", to_unsubscribed)
            .await;

        // Receiving must not block the test; give up after a short while.
        tokio::select! {
            _ = client.receive_event() => {}
            _ = tokio::time::sleep(Duration::from_millis(10)) => {}
        };

        // One topic that is already subscribed, one that is new.
        for topic in ["/info", "/foo"] {
            client.subscribe(topic).await;
        }

        for topic in ["/info", "/foo"] {
            client.unsubscribe(topic).await;
        }
    }
}