zeek_websocket/client.rs
1//! Client implementation
2//!
3//! # Tokio-based clients for the Zeek WebSocket API
4//!
5//! This module provides a trait [`ZeekClient`] and a [`Service`] which can be used to create full
6//! asynchronous clients for the Zeek WebSocket API under [`tokio`]. Users implement `ZeekClient`
7//! to specify the runtime behavior of the client. After implementing `ZeekClient` for a client
8//! type it needs to be wrapped in a `Service`, e.g.,
9//!
10//! ```
11//! # use zeek_websocket::client::*;
12//! struct Client {
13//! outbox: Option<Outbox>
14//! }
15//!
16//! impl ZeekClient for Client {
17//! async fn connected(&mut self, _endpoint: String, _version: String) {}
18//! async fn event(&mut self, _topic: String, _event: zeek_websocket::Event) {}
19//! async fn error(&mut self, _error: zeek_websocket::protocol::ProtocolError) {}
20//! }
21//!
22//! let service = Service::new(|outbox| Client {
23//! outbox: Some(outbox)
24//! });
25//! ```
26//!
//! [`Service::new`] passes along an [`Outbox`] which can be used to publish (topic, [`Event`])
//! tuples to Zeek with [`Outbox::send`]. Clients should store the `Outbox`: once it (and every
//! clone of it) has been dropped, the `Service` closes the connection to Zeek. One way to control
//! the lifetime of the API connection is to store an `Option<Outbox>` in the client so it can
//! explicitly be reset to `None`.
32//!
//! The service needs to be started explicitly with [`Service::serve`], which returns a `Future`
//! that becomes ready once the service has terminated, either because the connection was shut
//! down or because of a fatal error.
36//!
37//! ## Example
38//!
39//! This example implements a client which publishes an event to Zeek and waits for the response
40//! before exiting. The hypothetical event here is `echo`,
41//!
42//! ```zeek
43//! global echo: event(message: string);
44//! ```
45//!
//! and the server will publish the event back on the same topic. Since the client is subscribed
//! to the topic it publishes to, it will see the response and can then reset its internally held
//! `Outbox` to signal to the `Service` that the connection should be closed.
49//!
50//! ```
51//! # use zeek_websocket::client::*;
52//! # use zeek_websocket::*;
53//! struct Client {
54//! outbox: Option<Outbox>,
55//! };
56//!
57//! impl ZeekClient for Client {
58//! async fn connected(&mut self, endpoint: String, version: String) {
59//! // Once connected send a single echo event. The server will send
60//! // the event back to us.
61//! if let Some(outbox) = &self.outbox {
62//! outbox
63//! .send("/topic".to_owned(), Event::new("echo", ["hello!"]))
64//! .await
65//! .unwrap();
66//! }
67//! }
68//!
69//! async fn event(&mut self, topic: String, event: Event) {
70//! // If we see the `echo` event from the server drop our `outbox`.
71//! // This will cause the service to terminate.
72//! if &event.name == "echo" {
73//! self.outbox.take();
74//! }
75//! }
76//!
77//! async fn error(&mut self, error: protocol::ProtocolError) {
78//! todo!()
79//! }
80//! }
81//!
82//! # let rt = tokio::runtime::Builder::new_multi_thread()
83//! # .enable_io()
84//! # .build()
85//! # .unwrap();
86//! # rt.block_on(async move {
87//! let uri = "ws://localhost:8080/v1/messages/json".try_into().unwrap();
88//! # let uri: tungstenite::http::Uri = uri;
89//! # let zeek = zeek_websocket::test::MockServer::default();
90//! # let uri = zeek.endpoint().clone();
91//!
92//! let service = Service::new(|outbox| Client {
93//! outbox: Some(outbox),
94//! });
95//!
96//! service
97//! .serve(
98//! "my-client",
99//! uri,
100//! Subscriptions::from(&["/topic"]),
101//! )
102//! .await.unwrap();
103//! # });
104//! ```
105
106use std::num::NonZeroUsize;
107
108use futures_util::{SinkExt, StreamExt};
109use tokio::sync::mpsc::{self};
110use tokio_tungstenite::{
111 connect_async,
112 tungstenite::{self, http::Uri},
113};
114use tungstenite::{
115 ClientRequestBuilder, Utf8Bytes,
116 protocol::{CloseFrame, frame::coding::CloseCode},
117};
118use zeek_websocket_types::{DeserializationError, Event, Message, Subscriptions};
119
120use crate::{
121 Binding,
122 protocol::{self},
123};
124
125/// Runtime for a [`ZeekClient`].
126pub struct Service<S> {
127 client: S,
128 rx: mpsc::Receiver<(String, Event)>,
129}
130
131impl<C: ZeekClient> Service<C> {
132 /// Construct a new service which the given configuration. The returned `Service` needs to be
133 /// started with [`Service::serve`].
134 #[allow(clippy::needless_pass_by_value)]
135 pub fn new_with_config<F>(config: ServiceConfig, init: F) -> Self
136 where
137 F: FnOnce(Outbox) -> C,
138 {
139 let (tx, rx) = mpsc::channel(config.outbox_size.into());
140 let client = init(Outbox(tx));
141 Self { client, rx }
142 }
143
144 /// Constructs a new service with the default configuration. See
145 /// [`Service::new_with_config`] and [`ServiceConfig::default`] for more details.
146 pub fn new<F>(init: F) -> Self
147 where
148 F: FnOnce(Outbox) -> C,
149 {
150 // We give the client a channel of size `1` for publishing. This prevents the client from
151 // overwhelming the service loop with too much data. We could probably also pick a slightly
152 // bigger number for less backpressure.
153
154 Self::new_with_config(ServiceConfig::default(), init)
155 }
156
157 /// Run the client against the server until either
158 ///
159 /// - the client drops its event sender, or
160 /// - we encounter a fatal error.
161 ///
162 /// # Errors
163 ///
164 /// We return errors for
165 ///
166 /// - transport-related issues which are not recoverable
167 /// - errors to deserialize messages
168 pub async fn serve<S, T>(mut self, app_name: S, uri: Uri, subscriptions: T) -> Result<(), Error>
169 where
170 S: Into<String>,
171 T: Into<Subscriptions>,
172 {
173 let request = ClientRequestBuilder::new(uri).with_header("X-Application-Name", app_name);
174
175 let (mut stream, ..) = connect_async(request)
176 .await
177 .map_err(|e| Error::Transport(e.to_string()))?;
178
179 let mut binding = Binding::new(subscriptions);
180
181 // Handle subscription.
182 while let Some(x) = binding.outgoing() {
183 stream.send(x.into()).await?;
184 }
185
186 loop {
187 let Some(ack) = stream.next().await else {
188 // The server closed the connection.
189 return Ok(());
190 };
191
192 let ack = ack.map_err(|e| Error::Transport(e.to_string()))?;
193 if ack.is_ping() {
194 continue;
195 }
196
197 match ack.try_into()? {
198 zeek_websocket_types::Message::Ack { endpoint, version } => {
199 self.client.connected(endpoint, version).await;
200 break;
201 }
202 message => {
203 return Err(Error::Transport(format!("expected ACK, got '{message:?}'")))?;
204 }
205 }
206 }
207
208 loop {
209 tokio::select! {
210 s = self.rx.recv() => {
211 let Some((topic, event)) = s else {
212 // Sender closed, graceful exit.
213 stream
214 .send(tungstenite::Message::Close(Some(CloseFrame {
215 code: CloseCode::Normal,
216 reason: Utf8Bytes::default(),
217 })))
218 .await?;
219 break;
220 };
221
222 binding.publish_event(topic, event);
223
224 while let Some(x) = binding.outgoing() {
225 stream.send(x.into()).await?;
226 }
227 }
228
229 r = stream.next() => {
230 let Some(r) = r else {
231 // Connection closed, graceful exit.
232 break;
233 };
234
235 let r = r.map_err(|e| Error::Transport(e.to_string()))?;
236 if r.is_ping() {
237 continue;
238 }
239
240 let m: Message = match r.try_into() {
241 Ok(m) => m,
242 Err(e) => {
243 self.client.error(e.into()).await;
244 continue;
245 }
246 };
247
248 binding.handle_incoming(m)?;
249
250 while let Some(received) = binding.receive_event().transpose() {
251 match received {
252 Ok((topic, event)) => self.client.event(topic, event).await,
253 Err(e) => self.client.error(e).await,
254 }
255 }
256 }
257 }
258 }
259
260 Ok(())
261 }
262}
263
264/// Handle for publishing into a [`Service`].
265///
266/// This is intended to be held by implementers of [`ZeekClient`] to publish events to Zeek, and
267/// is created during e.g., [`Service::new`]. `Service` holds on to the receiving side and will keep
268/// checking it. Dropping the `Outbox` indicates to the `Service` that the client is done and will
269/// cause it to terminate, so clients should hold the `Outbox` for as long as they intend to stay
270/// connected, and explicitly `drop` it.
271#[derive(Clone)]
272pub struct Outbox(mpsc::Sender<(String, Event)>);
273
274impl Outbox {
275 /// Enqueue an event on the given topic.
276 ///
277 /// # Errors
278 ///
279 /// Returns back the enqueued event when the outbox has been closed.
280 pub async fn send(&self, topic: String, event: Event) -> Result<(), (String, Event)> {
281 self.0.send((topic, event)).await.map_err(|e| e.0)
282 }
283}
284
/// Configuration for a [`Service`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ServiceConfig {
    /// The number of entries which can be enqueued in the outbox before [`Outbox::send`] has to
    /// wait for the service to pick some of them up.
    pub outbox_size: NonZeroUsize,
}

impl Default for ServiceConfig {
    /// Constructs a default service configuration.
    ///
    /// ```
    /// # use std::num::NonZeroUsize;
    /// # use zeek_websocket::client::ServiceConfig;
    /// assert_eq!(
    ///     ServiceConfig::default(),
    ///     ServiceConfig {
    ///         outbox_size: NonZeroUsize::new(256).unwrap(),
    ///     },
    /// );
    /// ```
    fn default() -> Self {
        // Evaluated at compile time: a zero value here fails the build instead of requiring
        // `unsafe { NonZeroUsize::new_unchecked(..) }`.
        const DEFAULT_OUTBOX_SIZE: NonZeroUsize = match NonZeroUsize::new(256) {
            Some(size) => size,
            None => panic!("default outbox size must be non-zero"),
        };

        Self {
            outbox_size: DEFAULT_OUTBOX_SIZE,
        }
    }
}
311
312pub trait ZeekClient {
313 /// Callback invoked when we have finished the handshake with the server.
314 fn connected(
315 &mut self,
316 endpoint: String,
317 version: String,
318 ) -> impl std::future::Future<Output = ()> + Send;
319
320 /// Callback invoked when an event is received.
321 fn event(
322 &mut self,
323 topic: String,
324 event: Event,
325 ) -> impl std::future::Future<Output = ()> + Send;
326
327 /// Callback invoked when an error is received.
328 fn error(
329 &mut self,
330 error: protocol::ProtocolError,
331 ) -> impl std::future::Future<Output = ()> + Send;
332}
333
334/// Error enum for client-related errors.
335#[derive(thiserror::Error, Debug, PartialEq)]
336pub enum Error {
337 #[error("failure in websocket transport: {0}")]
338 Transport(String),
339
340 #[error("protocol-related error: {0}")]
341 ProtocolError(#[from] protocol::ProtocolError),
342}
343
344impl From<tungstenite::Error> for Error {
345 fn from(value: tungstenite::Error) -> Self {
346 Self::Transport(value.to_string())
347 }
348}
349
350impl From<DeserializationError> for Error {
351 fn from(value: DeserializationError) -> Self {
352 Self::ProtocolError(value.into())
353 }
354}
355
#[cfg(test)]
mod test {
    use tokio::sync::mpsc::{self};
    use zeek_websocket_types::{Event, Subscriptions};

    use crate::{
        client::{Error, Outbox, Service, ZeekClient},
        protocol::ProtocolError,
        test::MockServer,
    };

    /// Connecting to a port nobody listens on must surface as a transport error.
    #[tokio::test]
    async fn unreachable_remote() {
        struct Client {
            // Only held to keep the outbox open; never used to publish.
            _outbox: Outbox,
        }

        impl ZeekClient for Client {
            async fn connected(&mut self, _endpoint: String, _version: String) {}
            async fn event(&mut self, _topic: String, _event: Event) {}
            async fn error(&mut self, _error: ProtocolError) {}
        }

        let service = Service::new(|_outbox| Client { _outbox });

        let status = service
            .serve(
                "foo",
                "ws://localhost:1".try_into().unwrap(),
                Subscriptions::default(),
            )
            .await;
        assert!(matches!(status, Err(Error::Transport(_))), "{status:?}");
    }

    /// A client which publishes an `echo` event on a topic it is subscribed to expects to
    /// receive an event back from the mock server.
    #[tokio::test]
    async fn echo() {
        const TOPIC: &str = "/topic";

        struct C {
            outbox: Outbox,
            seen_events: mpsc::Sender<(String, Event)>,
        }

        impl C {
            fn new(outbox: Outbox, seen_events: mpsc::Sender<(String, Event)>) -> Self {
                Self {
                    outbox,
                    seen_events,
                }
            }
        }

        impl ZeekClient for C {
            async fn connected(&mut self, _endpoint: String, _version: String) {
                self.outbox
                    .send(TOPIC.into(), Event::new("echo", [42]))
                    .await
                    .unwrap();
            }

            async fn event(&mut self, topic: String, event: Event) {
                eprintln!("Event {topic:?}: {event:?}");
                self.seen_events.send((topic, event)).await.unwrap();
            }

            async fn error(&mut self, error: ProtocolError) {
                eprintln!("Error: {error:?}");
            }
        }

        let zeek = MockServer::default();

        let (seen, mut events) = mpsc::channel(1);

        let service = Service::new(|sender| C::new(sender, seen));

        tokio::select! {
            Some((topic, event)) = events.recv() => {
                eprintln!("Event {topic:?}: {event:?}");
            }
            s = service.serve("foo", zeek.endpoint().clone(), Subscriptions::from(&[TOPIC])) => {
                unreachable!("We should have received an event but instead the service returned with {s:?}");
            }
        }
    }

    /// Pins the user-facing rendering of transport errors.
    #[test]
    fn transport_error_display() {
        assert_eq!(
            Error::Transport("boom".to_owned()).to_string(),
            "failure in websocket transport: boom"
        );
    }
}