1use std::{
3 collections::{HashMap, HashSet},
4 sync::Arc,
5};
6
7use futures_util::{SinkExt, StreamExt};
8use tokio::{
9 sync::{
10 RwLock,
11 mpsc::{Receiver, Sender, channel},
12 },
13 task::JoinHandle,
14};
15use tokio_tungstenite::{
16 connect_async,
17 tungstenite::{self, http::Uri},
18};
19use tungstenite::ClientRequestBuilder;
20use typed_builder::TypedBuilder;
21use zeek_websocket_types::Event;
22
23use crate::{
24 Binding,
25 protocol::{self, ProtocolError},
26};
27
28#[derive(TypedBuilder)]
30#[builder(
31 doc,
32 build_method(into = Result<Client, Error>),
33 mutators(
34 pub fn subscribe<S: Into<String>>(&mut self, topic: S) {
36 self.subscriptions.insert(topic.into());
37 }
38
39))]
40pub struct ClientConfig {
41 #[builder(via_mutators)]
42 subscriptions: HashSet<String>,
43
44 endpoint: Uri,
46
47 #[builder(setter(into))]
49 app_name: String,
50
51 #[builder(default = 1024)]
53 buffer_capacity: usize,
54}
55
56impl From<ClientConfig> for Result<Client, Error> {
57 fn from(
58 ClientConfig {
59 subscriptions,
60 endpoint,
61 buffer_capacity,
62 app_name,
63 }: ClientConfig,
64 ) -> Self {
65 let (events_sender, events) = channel(buffer_capacity);
66
67 let bindings: Result<_, _> = subscriptions
68 .into_iter()
69 .map(|topic| {
70 let client = TopicHandler::new(
71 app_name.clone(),
72 Some(topic.clone()),
73 endpoint.clone(),
74 events_sender.clone(),
75 );
76 Ok::<(Option<String>, TopicHandler), Error>((Some(topic), client))
77 })
78 .collect();
79 let bindings = Arc::new(RwLock::<HashMap<_, _>>::new(bindings?));
80
81 Ok(Client {
82 app_name,
83 endpoint,
84 bindings,
85 events,
86 events_sender,
87 })
88 }
89}
90
91pub struct Client {
128 app_name: String,
129
130 endpoint: Uri,
131 bindings: Arc<RwLock<HashMap<Option<String>, TopicHandler>>>,
132
133 events: Receiver<Result<(String, Event), ProtocolError>>,
134 events_sender: Sender<Result<(String, Event), ProtocolError>>,
135}
136
137impl Client {
138 pub async fn publish_event<S: Into<String>>(&mut self, topic: S, event: Event) {
140 let topic = topic.into();
141
142 let mut bindings = self.bindings.write().await;
143
144 let client = if let Some(client) = bindings.get(&Some(topic.clone())) {
145 client
147 } else {
148 bindings.entry(None).or_insert_with(|| {
150 TopicHandler::new(
151 self.app_name.clone(),
152 None,
153 self.endpoint.clone(),
154 self.events_sender.clone(),
155 )
156 })
157 };
158
159 let _ = client.publish_sink.send((topic.clone(), event)).await;
160 }
161
162 pub async fn receive_event(&mut self) -> Result<Option<(String, Event)>, Error> {
171 Ok(self.events.recv().await.transpose()?)
172 }
173
174 pub async fn subscribe<S: Into<String>>(&mut self, topic: S) {
178 let topic = topic.into();
179
180 let mut bindings = self.bindings.write().await;
181 bindings.entry(Some(topic.clone())).or_insert_with(|| {
182 TopicHandler::new(
183 self.app_name.clone(),
184 Some(topic),
185 self.endpoint.clone(),
186 self.events_sender.clone(),
187 )
188 });
189 }
190
191 pub async fn unsubscribe(&mut self, topic: &str) -> bool {
195 let mut bindings = self.bindings.write().await;
196 bindings.remove(&Some(topic.to_string())).is_some()
197 }
198}
199
200struct TopicHandler {
201 _loop: JoinHandle<Result<(), Error>>,
202 publish_sink: Sender<(String, Event)>,
203}
204
205impl TopicHandler {
206 fn new(
207 app_name: String,
208 topic: Option<String>,
209 endpoint: Uri,
210 events_sender: Sender<Result<(String, Event), ProtocolError>>,
211 ) -> Self {
212 let (publish_sink, mut publish) = channel(1);
213
214 let loop_ = tokio::spawn(async move {
215 let topics = if let Some(topic) = &topic {
216 vec![topic.clone()]
217 } else {
218 vec![]
219 };
220 let mut binding = Binding::new(topics);
221
222 let endpoint =
223 ClientRequestBuilder::new(endpoint).with_header("X-Application-Name", app_name);
224
225 let (mut stream, ..) = connect_async(endpoint).await?;
226
227 loop {
228 tokio::select! {
229 r = publish.recv() => {
230 let Some((topic, event)) = r else { return Ok(()); };
231 binding.publish_event::<String>(topic, event);
232 }
233 s = stream.next() => {
234 if let Ok(message) = match s {
235 Some(payload) => match payload {
236 Ok(p) => p.try_into(),
237 Err(e) => {
238 if let Some(topic) = &topic {
239 handle_transport_error(e, &mut binding, topic)?;
240 }
241 continue;
242 },
243 },
244
245 None => continue,
246 } {
247 binding.handle_incoming(message)?;
248 }
249 }
250 };
251
252 while let Some(bin) = binding.outgoing() {
253 if let Err(e) = stream.send(tungstenite::Message::binary(bin)).await
254 && let Some(topic) = &topic
255 {
256 handle_transport_error(e, &mut binding, topic)?;
257 }
258 }
259
260 while let Some(payload) = binding.receive_event().transpose() {
261 let _ = events_sender.send(payload).await;
262 }
263 }
264 });
265
266 TopicHandler {
267 _loop: loop_,
268 publish_sink,
269 }
270 }
271}
272
273#[derive(thiserror::Error, Debug, PartialEq)]
275pub enum Error {
276 #[error("failure in websocket transport: {0}")]
277 Transport(String),
278 #[error("protocol-related error: {0}")]
279 ProtocolError(#[from] protocol::ProtocolError),
280}
281
282impl From<tungstenite::Error> for Error {
283 fn from(value: tungstenite::Error) -> Self {
284 Self::Transport(value.to_string())
285 }
286}
287
288fn handle_transport_error(
292 e: tungstenite::Error,
293 binding: &mut Binding,
294 topic: &str,
295) -> Result<(), Error> {
296 match e {
297 tungstenite::Error::AttackAttempt
299 | tungstenite::Error::AlreadyClosed
300 | tungstenite::Error::ConnectionClosed
301 | tungstenite::Error::Io(_) => {
302 *binding = Binding::new(vec![topic]);
303 Ok(())
304 }
305
306 tungstenite::Error::Protocol(_)
308 | tungstenite::Error::WriteBufferFull(_)
309 | tungstenite::Error::Capacity(_)
310 | tungstenite::Error::Tls(_)
311 | tungstenite::Error::Url(_)
312 | tungstenite::Error::Http(_)
313 | tungstenite::Error::HttpFormat(_)
314 | tungstenite::Error::Utf8(_) => Err(Error::from(e)),
315 }
316}
317
#[cfg(test)]
mod test {
    use std::time::Duration;

    use zeek_websocket_types::Event;

    use crate::client::ClientConfig;

    /// Exercises the public client API end to end.
    ///
    /// The endpoint is a bare loopback address; the test only relies on the
    /// client-side bookkeeping (subscription map), not on a successful connection.
    #[tokio::test]
    async fn basic() {
        let endpoint = "ws://127.0.0.1";
        let mut client = ClientConfig::builder()
            .endpoint(endpoint.try_into().unwrap())
            .app_name("foo")
            .subscribe("/info")
            .build()
            .unwrap();

        // Publishing on a subscribed topic and on an unsubscribed one must both return.
        client
            .publish_event("/info", Event::new("info", ["hi!"]))
            .await;

        client
            .publish_event("/not-yet-subscribed", Event::new("info", ["hi!"]))
            .await;

        // Receiving must not block past the timeout when nothing arrives.
        tokio::select! {
            _e = client.receive_event() => {}
            _timeout = tokio::time::sleep(Duration::from_millis(10)) => {}
        };

        // "/info" was already subscribed via the builder; "/foo" is new.
        client.subscribe("/info").await;
        client.subscribe("/foo").await;

        // Each subscribed topic can be removed exactly once.
        assert!(client.unsubscribe("/info").await);
        assert!(client.unsubscribe("/foo").await);
        assert!(!client.unsubscribe("/info").await);
        assert!(!client.unsubscribe("/foo").await);

        // A topic that was only published to was never subscribed.
        assert!(!client.unsubscribe("/not-yet-subscribed").await);
    }
}