zeek_websocket/
client.rs

1//! Client implementation
2use std::{
3    collections::{HashMap, HashSet},
4    sync::Arc,
5};
6
7use futures_util::{SinkExt, StreamExt};
8use tokio::{
9    sync::{
10        RwLock,
11        mpsc::{Receiver, Sender, channel},
12    },
13    task::JoinHandle,
14};
15use tokio_tungstenite::{
16    connect_async,
17    tungstenite::{self, http::Uri},
18};
19use tungstenite::ClientRequestBuilder;
20use typed_builder::TypedBuilder;
21use zeek_websocket_types::Event;
22
23use crate::{
24    Binding,
25    protocol::{self, ProtocolError},
26};
27
28/// Builder to construct a [`Client`].
29#[derive(TypedBuilder)]
30#[builder(
31    doc,
32    build_method(into = Result<Client, Error>),
33    mutators(
34        /// Subscribe to topic to receive events.
35        pub fn subscribe<S: Into<String>>(&mut self, topic: S) {
36            self.subscriptions.insert(topic.into());
37        }
38
39))]
40pub struct ClientConfig {
41    #[builder(via_mutators)]
42    subscriptions: HashSet<String>,
43
44    /// Zeek WebSocket endpoint to connect to.
45    endpoint: Uri,
46
47    /// String use by the client to identify itself against Zeek.
48    #[builder(setter(into))]
49    app_name: String,
50
51    /// How many events to buffer before exerting backpressure.
52    #[builder(default = 1024)]
53    buffer_capacity: usize,
54}
55
56impl From<ClientConfig> for Result<Client, Error> {
57    fn from(
58        ClientConfig {
59            subscriptions,
60            endpoint,
61            buffer_capacity,
62            app_name,
63        }: ClientConfig,
64    ) -> Self {
65        let (events_sender, events) = channel(buffer_capacity);
66
67        let bindings: Result<_, _> = subscriptions
68            .into_iter()
69            .map(|topic| {
70                let client = TopicHandler::new(
71                    app_name.clone(),
72                    Some(topic.clone()),
73                    endpoint.clone(),
74                    events_sender.clone(),
75                );
76                Ok::<(Option<String>, TopicHandler), Error>((Some(topic), client))
77            })
78            .collect();
79        let bindings = Arc::new(RwLock::<HashMap<_, _>>::new(bindings?));
80
81        Ok(Client {
82            app_name,
83            endpoint,
84            bindings,
85            events,
86            events_sender,
87        })
88    }
89}
90
91/// # Tokio-based for the Zeek WebSocket API
92///
93/// [`Client`] implements an async client for the Zeek WebSocket API. It is intended to be run
94/// inside a `tokio` runtime. The general workflow is to build a client with the [`ClientConfig`]
95/// builder interface, and then either publish or receive events.
96///
97/// ## Example
98///
99/// ```no_run
100/// use anyhow::Result;
101/// use zeek_websocket::client::ClientConfig;
102/// use zeek_websocket::Event;
103///
104/// #[tokio::main]
105/// async fn main() -> Result<()> {
106///     let mut client = ClientConfig::builder()
107///         .app_name("my_client_application")
108///         .subscribe("/info")
109///         .endpoint("ws://127.0.0.1:8080/v1/messages/json".try_into()?)
110///         .build()?;
111///
112///     client
113///         .publish_event("/ping", Event::new("ping", vec!["abc"]))
114///         .await;
115///
116///     loop {
117///         // Client automatically receives events on topics it sent to.
118///         if let Some((_topic, event)) = client.receive_event().await? {
119///             eprintln!("{event:?}");
120///             break;
121///         }
122///     }
123///
124///     Ok(())
125/// }
126/// ```
127pub struct Client {
128    app_name: String,
129
130    endpoint: Uri,
131    bindings: Arc<RwLock<HashMap<Option<String>, TopicHandler>>>,
132
133    events: Receiver<Result<(String, Event), ProtocolError>>,
134    events_sender: Sender<Result<(String, Event), ProtocolError>>,
135}
136
137impl Client {
138    /// Publish an [`Event`] to `topic`.
139    pub async fn publish_event<S: Into<String>>(&mut self, topic: S, event: Event) {
140        let topic = topic.into();
141
142        let mut bindings = self.bindings.write().await;
143
144        let client = if let Some(client) = bindings.get(&Some(topic.clone())) {
145            // If we are subscribed to a topic use its handler for publishing the event.
146            client
147        } else {
148            // Else use a null handler which does not receive events, but can still publish.
149            bindings.entry(None).or_insert_with(|| {
150                TopicHandler::new(
151                    self.app_name.clone(),
152                    None,
153                    self.endpoint.clone(),
154                    self.events_sender.clone(),
155                )
156            })
157        };
158
159        let _ = client.publish_sink.send((topic.clone(), event)).await;
160    }
161
162    /// Receive the next [`Event`] or [`Error`].
163    ///
164    /// If an event was received it will be returned as `Ok(Some((topic, event)))`.
165    ///
166    /// # Errors
167    ///
168    /// Might return a [`ProtocolError`] from the underlying binding, e.g., Zeek, or an
169    /// transport-related error.
170    pub async fn receive_event(&mut self) -> Result<Option<(String, Event)>, Error> {
171        Ok(self.events.recv().await.transpose()?)
172    }
173
174    /// Subscribe to a topic.
175    ///
176    /// This is a noop if the client is already subscribed.
177    pub async fn subscribe<S: Into<String>>(&mut self, topic: S) {
178        let topic = topic.into();
179
180        let mut bindings = self.bindings.write().await;
181        bindings.entry(Some(topic.clone())).or_insert_with(|| {
182            TopicHandler::new(
183                self.app_name.clone(),
184                Some(topic),
185                self.endpoint.clone(),
186                self.events_sender.clone(),
187            )
188        });
189    }
190
191    /// Unsubscribe from a topic.
192    ///
193    /// This is a noop if the client was not subscribed.
194    pub async fn unsubscribe(&mut self, topic: &str) -> bool {
195        let mut bindings = self.bindings.write().await;
196        bindings.remove(&Some(topic.to_string())).is_some()
197    }
198}
199
200struct TopicHandler {
201    _loop: JoinHandle<Result<(), Error>>,
202    publish_sink: Sender<(String, Event)>,
203}
204
205impl TopicHandler {
206    fn new(
207        app_name: String,
208        topic: Option<String>,
209        endpoint: Uri,
210        events_sender: Sender<Result<(String, Event), ProtocolError>>,
211    ) -> Self {
212        let (publish_sink, mut publish) = channel(1);
213
214        let loop_ = tokio::spawn(async move {
215            let topics = if let Some(topic) = &topic {
216                vec![topic.clone()]
217            } else {
218                vec![]
219            };
220            let mut binding = Binding::new(topics);
221
222            let endpoint =
223                ClientRequestBuilder::new(endpoint).with_header("X-Application-Name", app_name);
224
225            let (mut stream, ..) = connect_async(endpoint).await?;
226
227            loop {
228                tokio::select! {
229                    r = publish.recv() => {
230                        let Some((topic, event)) = r else { return Ok(()); };
231                        binding.publish_event::<String>(topic, event);
232                    }
233                    s = stream.next() => {
234                        if let Ok(message) = match s {
235                            Some(payload) => match payload {
236                                Ok(p) => p.try_into(),
237                                Err(e) => {
238                                    if let Some(topic) = &topic {
239                                        handle_transport_error(e, &mut binding, topic)?;
240                                    }
241                                    continue;
242                                },
243                            },
244
245                            None => continue,
246                        } {
247                            binding.handle_incoming(message)?;
248                        }
249                    }
250                };
251
252                while let Some(bin) = binding.outgoing() {
253                    if let Err(e) = stream.send(tungstenite::Message::binary(bin)).await
254                        && let Some(topic) = &topic
255                    {
256                        handle_transport_error(e, &mut binding, topic)?;
257                    }
258                }
259
260                while let Some(payload) = binding.receive_event().transpose() {
261                    let _ = events_sender.send(payload).await;
262                }
263            }
264        });
265
266        TopicHandler {
267            _loop: loop_,
268            publish_sink,
269        }
270    }
271}
272
273/// Error enum for client-related errors.
274#[derive(thiserror::Error, Debug, PartialEq)]
275pub enum Error {
276    #[error("failure in websocket transport: {0}")]
277    Transport(String),
278    #[error("protocol-related error: {0}")]
279    ProtocolError(#[from] protocol::ProtocolError),
280}
281
282impl From<tungstenite::Error> for Error {
283    fn from(value: tungstenite::Error) -> Self {
284        Self::Transport(value.to_string())
285    }
286}
287
288/// Error handling for transport errors.
289///
290/// Returns a `Ok(())` if the incoming could be handled, or an `Err` if it should be propagated up.
291fn handle_transport_error(
292    e: tungstenite::Error,
293    binding: &mut Binding,
294    topic: &str,
295) -> Result<(), Error> {
296    match e {
297        // Errors we can handle by gracefully resubscribing.
298        tungstenite::Error::AttackAttempt
299        | tungstenite::Error::AlreadyClosed
300        | tungstenite::Error::ConnectionClosed
301        | tungstenite::Error::Io(_) => {
302            *binding = Binding::new(vec![topic]);
303            Ok(())
304        }
305
306        // Errors we bail on and bubble up to the user.
307        tungstenite::Error::Protocol(_)
308        | tungstenite::Error::WriteBufferFull(_)
309        | tungstenite::Error::Capacity(_)
310        | tungstenite::Error::Tls(_)
311        | tungstenite::Error::Url(_)
312        | tungstenite::Error::Http(_)
313        | tungstenite::Error::HttpFormat(_)
314        | tungstenite::Error::Utf8(_) => Err(Error::from(e)),
315    }
316}
317
#[cfg(test)]
mod test {
    use std::time::Duration;

    use zeek_websocket_types::Event;

    use crate::client::ClientConfig;

    /// Exercise the client API against an endpoint that presumably has no Zeek listening.
    ///
    /// No call may panic even though the connections are expected to fail. Subscription
    /// bookkeeping is checked through the return values of `unsubscribe`.
    #[tokio::test]
    async fn basic() {
        let endpoint = "ws://127.0.0.1";
        let mut client = ClientConfig::builder()
            .endpoint(endpoint.try_into().unwrap())
            .app_name("foo")
            .subscribe("/info")
            .build()
            .unwrap();

        // Publish on a subscribed topic and on one without a subscription.
        client
            .publish_event("/info", Event::new("info", ["hi!"]))
            .await;

        client
            .publish_event("/not-yet-subscribed", Event::new("info", ["hi!"]))
            .await;

        // Receiving must not panic; bound the wait with a short timeout.
        tokio::select! {
            _e = client.receive_event() => {}
            _timeout = tokio::time::sleep(Duration::from_millis(10)) => {}
        };

        // Re-subscribing to `/info` is a noop; `/foo` is a new subscription.
        client.subscribe("/info").await;
        client.subscribe("/foo").await;

        assert!(client.unsubscribe("/info").await);
        assert!(client.unsubscribe("/foo").await);
        // Nothing is left to remove the second time around.
        assert!(!client.unsubscribe("/foo").await);
    }
}