zeek_websocket/
client.rs

1//! Client implementation
2//!
3//! # Tokio-based clients for the Zeek WebSocket API
4//!
5//! This module provides a trait [`ZeekClient`] and a [`Service`] which can be used to create full
6//! asynchronous clients for the Zeek WebSocket API under [`tokio`]. Users implement `ZeekClient`
7//! to specify the runtime behavior of the client. After implementing `ZeekClient` for a client
8//! type it needs to be wrapped in a `Service`, e.g.,
9//!
10//! ```
11//! # use zeek_websocket::client::*;
12//! struct Client {
13//!     outbox: Option<Outbox>
14//! }
15//!
16//! impl ZeekClient for Client {
17//!     async fn connected(&mut self, _endpoint: String, _version: String) {}
18//!     async fn event(&mut self, _topic: String, _event: zeek_websocket::Event) {}
19//!     async fn error(&mut self, _error: zeek_websocket::protocol::ProtocolError) {}
20//! }
21//!
22//! let service = Service::new(|outbox| Client {
23//!     outbox: Some(outbox)
24//! });
25//! ```
26//!
//! [`Service::new`] passes along an [`Outbox`] which can be used to publish (topic, [`Event`])
//! tuples to Zeek with `Outbox::send`. Clients should store the `Outbox`: once it (and every
//! clone of it) has been dropped, the `Service` closes the connection to Zeek. One way to control
//! the lifetime of the API connection is to store an `Option<Outbox>` in the client so it can
//! explicitly be reset to `None`.
32//!
//! The service must be started explicitly with [`Service::serve`], which returns a `Future` that
//! becomes ready once the service has terminated, either because the connection was shut down or
//! because of a fatal error.
36//!
37//! ## Example
38//!
39//! This example implements a client which publishes an event to Zeek and waits for the response
40//! before exiting. The hypothetical event here is `echo`,
41//!
42//! ```zeek
43//! global echo: event(message: string);
44//! ```
45//!
//! and the server will publish the event back on the same topic. Since the client is subscribed
//! to the topic it publishes to, it will see the response and can then reset its internally held
//! `Outbox` to signal to the `Service` that the connection should be closed.
49//!
50//! ```
51//! # use zeek_websocket::client::*;
52//! # use zeek_websocket::*;
53//! struct Client {
54//!     outbox: Option<Outbox>,
55//! };
56//!
57//! impl ZeekClient for Client {
58//!     async fn connected(&mut self, endpoint: String, version: String) {
59//!         // Once connected send a single echo event. The server will send
60//!         // the event back to us.
61//!         if let Some(outbox) = &self.outbox {
62//!             outbox
63//!                 .send("/topic".to_owned(), Event::new("echo", ["hello!"]))
64//!                 .await
65//!                 .unwrap();
66//!         }
67//!     }
68//!
69//!     async fn event(&mut self, topic: String, event: Event) {
70//!         // If we see the `echo` event from the server drop our `outbox`.
71//!         // This will cause the service to terminate.
72//!         if &event.name == "echo" {
73//!             self.outbox.take();
74//!         }
75//!     }
76//!
77//!     async fn error(&mut self, error: protocol::ProtocolError) {
78//!         todo!()
79//!     }
80//! }
81//!
82//! # let rt = tokio::runtime::Builder::new_multi_thread()
83//! #     .enable_io()
84//! #     .build()
85//! #     .unwrap();
86//! # rt.block_on(async move {
87//! let uri = "ws://localhost:8080/v1/messages/json".try_into().unwrap();
88//! # let uri: tungstenite::http::Uri = uri;
89//! # let zeek = zeek_websocket::test::MockServer::default();
90//! # let uri = zeek.endpoint().clone();
91//!
92//! let service = Service::new(|outbox| Client {
93//!     outbox: Some(outbox),
94//! });
95//!
96//! service
97//!     .serve(
98//!         "my-client",
99//!         uri,
100//!         Subscriptions::from(&["/topic"]),
101//!     )
102//!     .await.unwrap();
103//! # });
104//! ```
105
106use std::num::NonZeroUsize;
107
108use futures_util::{SinkExt, StreamExt};
109use tokio::sync::mpsc::{self};
110use tokio_tungstenite::{
111    connect_async,
112    tungstenite::{self, http::Uri},
113};
114use tungstenite::{
115    ClientRequestBuilder, Utf8Bytes,
116    protocol::{CloseFrame, frame::coding::CloseCode},
117};
118use zeek_websocket_types::{DeserializationError, Event, Message, Subscriptions};
119
120use crate::{
121    Binding,
122    protocol::{self},
123};
124
125/// Runtime for a [`ZeekClient`].
126pub struct Service<S> {
127    client: S,
128    rx: mpsc::Receiver<(String, Event)>,
129}
130
131impl<C: ZeekClient> Service<C> {
132    /// Construct a new service which the given configuration. The returned `Service` needs to be
133    /// started with [`Service::serve`].
134    #[allow(clippy::needless_pass_by_value)]
135    pub fn new_with_config<F>(config: ServiceConfig, init: F) -> Self
136    where
137        F: FnOnce(Outbox) -> C,
138    {
139        let (tx, rx) = mpsc::channel(config.outbox_size.into());
140        let client = init(Outbox(tx));
141        Self { client, rx }
142    }
143
144    /// Constructs a new service with the default configuration. See
145    /// [`Service::new_with_config`] and [`ServiceConfig::default`] for more details.
146    pub fn new<F>(init: F) -> Self
147    where
148        F: FnOnce(Outbox) -> C,
149    {
150        // We give the client a channel of size `1` for publishing. This prevents the client from
151        // overwhelming the service loop with too much data. We could probably also pick a slightly
152        // bigger number for less backpressure.
153
154        Self::new_with_config(ServiceConfig::default(), init)
155    }
156
157    /// Run the client against the server until either
158    ///
159    /// - the client drops its event sender, or
160    /// - we encounter a fatal error.
161    ///
162    /// # Errors
163    ///
164    /// We return errors for
165    ///
166    /// - transport-related issues which are not recoverable
167    /// - errors to deserialize messages
168    pub async fn serve<S, T>(mut self, app_name: S, uri: Uri, subscriptions: T) -> Result<(), Error>
169    where
170        S: Into<String>,
171        T: Into<Subscriptions>,
172    {
173        let request = ClientRequestBuilder::new(uri).with_header("X-Application-Name", app_name);
174
175        let (mut stream, ..) = connect_async(request)
176            .await
177            .map_err(|e| Error::Transport(e.to_string()))?;
178
179        let mut binding = Binding::new(subscriptions);
180
181        // Handle subscription.
182        while let Some(x) = binding.outgoing() {
183            stream.send(x.into()).await?;
184        }
185
186        loop {
187            let Some(ack) = stream.next().await else {
188                // The server closed the connection.
189                return Ok(());
190            };
191
192            let ack = ack.map_err(|e| Error::Transport(e.to_string()))?;
193            if ack.is_ping() {
194                continue;
195            }
196
197            match ack.try_into()? {
198                zeek_websocket_types::Message::Ack { endpoint, version } => {
199                    self.client.connected(endpoint, version).await;
200                    break;
201                }
202                message => {
203                    return Err(Error::Transport(format!("expected ACK, got '{message:?}'")))?;
204                }
205            }
206        }
207
208        loop {
209            tokio::select! {
210                s = self.rx.recv() => {
211                    let Some((topic, event)) = s else {
212                        // Sender closed, graceful exit.
213                        stream
214                            .send(tungstenite::Message::Close(Some(CloseFrame {
215                                code: CloseCode::Normal,
216                                reason: Utf8Bytes::default(),
217                            })))
218                            .await?;
219                        break;
220                    };
221
222                    binding.publish_event(topic, event);
223
224                    while let Some(x) = binding.outgoing() {
225                        stream.send(x.into()).await?;
226                    }
227                }
228
229                r = stream.next() => {
230                    let Some(r) = r else {
231                        // Connection closed, graceful exit.
232                        break;
233                    };
234
235                    let r = r.map_err(|e| Error::Transport(e.to_string()))?;
236                    if r.is_ping() {
237                        continue;
238                    }
239
240                    let m: Message = match r.try_into() {
241                        Ok(m) => m,
242                        Err(e) => {
243                            self.client.error(e.into()).await;
244                            continue;
245                        }
246                    };
247
248                    binding.handle_incoming(m)?;
249
250                    while let Some(received) = binding.receive_event().transpose() {
251                        match received {
252                            Ok((topic, event)) => self.client.event(topic, event).await,
253                            Err(e) => self.client.error(e).await,
254                        }
255                    }
256                }
257            }
258        }
259
260        Ok(())
261    }
262}
263
264/// Handle for publishing into a [`Service`].
265///
266/// This is intended to be held by implementers of [`ZeekClient`] to publish events to Zeek, and
267/// is created during e.g., [`Service::new`]. `Service` holds on to the receiving side and will keep
268/// checking it. Dropping the `Outbox` indicates to the `Service` that the client is done and will
269/// cause it to terminate, so clients should hold the `Outbox` for as long as they intend to stay
270/// connected, and explicitly `drop` it.
271#[derive(Clone)]
272pub struct Outbox(mpsc::Sender<(String, Event)>);
273
274impl Outbox {
275    /// Enqueue an event on the given topic.
276    ///
277    /// # Errors
278    ///
279    /// Returns back the enqueued event when the outbox has been closed.
280    pub async fn send(&self, topic: String, event: Event) -> Result<(), (String, Event)> {
281        self.0.send((topic, event)).await.map_err(|e| e.0)
282    }
283}
284
/// Configuration for a [`Service`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ServiceConfig {
    /// The number of `(topic, event)` entries which can be enqueued in the outbox before
    /// publishing has to wait for the service to catch up.
    pub outbox_size: NonZeroUsize,
}

impl Default for ServiceConfig {
    /// Constructs a default service configuration with an outbox size of 256.
    ///
    /// ```
    /// # use std::num::NonZeroUsize;
    /// # use zeek_websocket::client::ServiceConfig;
    /// assert_eq!(
    ///     ServiceConfig::default(),
    ///     ServiceConfig {
    ///         outbox_size: NonZeroUsize::new(256).unwrap(),
    ///     },
    /// );
    /// ```
    fn default() -> Self {
        // Evaluated at compile time. Unlike `NonZeroUsize::new_unchecked`, a zero here cannot
        // cause undefined behavior; it would fail the build instead.
        const DEFAULT_OUTBOX_SIZE: NonZeroUsize = match NonZeroUsize::new(256) {
            Some(size) => size,
            None => panic!("default outbox size must be non-zero"),
        };

        Self {
            outbox_size: DEFAULT_OUTBOX_SIZE,
        }
    }
}
311
312pub trait ZeekClient {
313    /// Callback invoked when we have finished the handshake with the server.
314    fn connected(
315        &mut self,
316        endpoint: String,
317        version: String,
318    ) -> impl std::future::Future<Output = ()> + Send;
319
320    /// Callback invoked when an event is received.
321    fn event(
322        &mut self,
323        topic: String,
324        event: Event,
325    ) -> impl std::future::Future<Output = ()> + Send;
326
327    /// Callback invoked when an error is received.
328    fn error(
329        &mut self,
330        error: protocol::ProtocolError,
331    ) -> impl std::future::Future<Output = ()> + Send;
332}
333
334/// Error enum for client-related errors.
335#[derive(thiserror::Error, Debug, PartialEq)]
336pub enum Error {
337    #[error("failure in websocket transport: {0}")]
338    Transport(String),
339
340    #[error("protocol-related error: {0}")]
341    ProtocolError(#[from] protocol::ProtocolError),
342}
343
344impl From<tungstenite::Error> for Error {
345    fn from(value: tungstenite::Error) -> Self {
346        Self::Transport(value.to_string())
347    }
348}
349
350impl From<DeserializationError> for Error {
351    fn from(value: DeserializationError) -> Self {
352        Self::ProtocolError(value.into())
353    }
354}
355
#[cfg(test)]
mod test {
    use tokio::sync::mpsc::{self};
    use zeek_websocket_types::{Event, Subscriptions};

    use crate::{
        client::{Error, Outbox, Service, ZeekClient},
        protocol::ProtocolError,
        test::MockServer,
    };

    /// Connecting to a port nobody listens on must surface as a transport error.
    #[tokio::test]
    async fn unreachable_remote() {
        struct Client {
            _outbox: Outbox,
        }

        impl ZeekClient for Client {
            async fn connected(&mut self, _endpoint: String, _version: String) {}
            async fn event(&mut self, _topic: String, _event: Event) {}
            async fn error(&mut self, _error: ProtocolError) {}
        }

        let service = Service::new(|_outbox| Client { _outbox });

        let status = service
            .serve(
                "foo",
                "ws://localhost:1".try_into().unwrap(),
                Subscriptions::default(),
            )
            .await;
        assert!(matches!(status, Err(Error::Transport(_))), "{status:?}");
    }

    /// A client publishing an `echo` event once connected must receive an event back from the
    /// mock server before the service terminates.
    #[tokio::test]
    async fn echo() {
        static TOPIC: &str = "/topic";

        struct C {
            outbox: Outbox,
            seen_events: mpsc::Sender<(String, Event)>,
        }

        impl C {
            fn new(outbox: Outbox, seen_events: mpsc::Sender<(String, Event)>) -> Self {
                Self {
                    outbox,
                    seen_events,
                }
            }
        }

        impl ZeekClient for C {
            async fn connected(&mut self, _endpoint: String, _version: String) {
                self.outbox
                    .send(TOPIC.into(), Event::new("echo", [42]))
                    .await
                    .unwrap();
            }

            async fn event(&mut self, topic: String, event: Event) {
                // Hand the event to the test body, which logs it.
                self.seen_events.send((topic, event)).await.unwrap();
            }

            async fn error(&mut self, error: ProtocolError) {
                eprintln!("Error: {error:?}");
            }
        }

        let zeek = MockServer::default();

        let (seen, mut events) = mpsc::channel(1);

        let service = Service::new(|outbox| C::new(outbox, seen));

        tokio::select! {
            Some((topic, event)) = events.recv() => {
                eprintln!("Event {topic:?}: {event:?}");
            }
            s = service.serve("foo", zeek.endpoint().clone(), Subscriptions::from(&[TOPIC])) => {
                panic!("We should have received an event but instead the service returned with {s:?}");
            }
        }
    }
}