1use std::collections::VecDeque;
49
50use thiserror::Error;
51use tungstenite::Bytes;
52use zeek_websocket_types::{Data, Event, Message, Value};
53
54use crate::types::Subscriptions;
55
56pub struct Binding {
60 state: State,
61 subscriptions: Subscriptions,
62
63 inbox: Inbox,
64 outbox: Outbox,
65}
66
/// Progress of the subscription handshake of a [`Binding`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// The subscription request was queued; the peer's ACK is still pending.
    Subscribing,
    /// The peer acknowledged the subscription.
    Subscribed,
}
71
72impl Binding {
73 #[must_use]
80 pub fn new<S>(subscriptions: S) -> Self
81 where
82 S: Into<Subscriptions>,
83 {
84 let subscriptions = subscriptions.into();
85 Self {
86 state: State::Subscribing,
87 inbox: Inbox(VecDeque::new()),
88 outbox: Outbox(VecDeque::from([subscriptions.clone().into()])),
89 subscriptions,
90 }
91 }
92
93 pub fn handle_incoming(&mut self, message: Message) -> Result<(), ProtocolError> {
103 match &message {
104 Message::Ack { .. } => match self.state {
105 State::Subscribing => {
106 self.state = State::Subscribed;
107 }
108 State::Subscribed => return Err(ProtocolError::AlreadySubscribed),
109 },
110 Message::DataMessage {
111 data: Data::Other(unexpected),
112 ..
113 } => {
114 return Err(ProtocolError::UnexpectedEventPayload(unexpected.clone()));
115 }
116 _ => {
117 self.inbox.handle(message);
118 }
119 }
120
121 Ok(())
122 }
123
124 pub fn outgoing(&mut self) -> Option<Bytes> {
126 self.outbox.next_data()
127 }
128
129 pub fn receive_event(&mut self) -> Result<Option<(String, Event)>, ProtocolError> {
135 if let Some(message) = self.inbox.next_message() {
136 match message {
137 Message::DataMessage { topic, data } => {
138 let event = match data {
139 Data::Event(event) => event,
140 Data::Other(..) => unreachable!(), };
142 return Ok(Some((topic, event)));
143 }
144 Message::Error { code, context } => {
145 return Err(ProtocolError::ZeekError { code, context });
146 }
147 Message::Ack { .. } => {
148 unreachable!() }
150 }
151 }
152
153 Ok(None)
154 }
155
156 fn enqueue(&mut self, message: Message) -> Result<(), ProtocolError> {
158 match message {
159 Message::DataMessage { topic, data } => {
160 let is_subscribed = self
161 .subscriptions
162 .0
163 .iter()
164 .any(|s| s.as_str() == topic.as_str());
165
166 if is_subscribed {
167 self.outbox.enqueue(Message::DataMessage { topic, data });
168 } else {
169 return Err(ProtocolError::SendOnNonSubscribed(
170 topic,
171 self.subscriptions.clone(),
172 data,
173 ))?;
174 }
175 }
176 _ => self.outbox.enqueue(message),
177 }
178
179 Ok(())
180 }
181
182 pub fn publish_event<S>(&mut self, topic: S, event: Event) -> Result<(), ProtocolError>
189 where
190 S: Into<String>,
191 {
192 self.enqueue(Message::new_data(topic.into(), event))
193 }
194
195 #[must_use]
201 pub fn split(self) -> (Inbox, Outbox) {
202 (self.inbox, self.outbox)
203 }
204}
205
206pub struct Inbox(VecDeque<Message>);
208
209impl Inbox {
210 pub fn handle(&mut self, message: Message) {
212 self.0.push_back(message);
213 }
214
215 #[must_use]
217 pub fn next_message(&mut self) -> Option<Message> {
218 self.0.pop_front()
219 }
220
221 pub fn next_event(&mut self) -> Option<(String, Event)> {
225 while let Some(message) = self.next_message() {
226 if let Message::DataMessage {
227 topic,
228 data: Data::Event(event),
229 } = message
230 {
231 return Some((topic, event));
232 }
233 }
234
235 None
236 }
237}
238
239pub struct Outbox(VecDeque<tungstenite::Message>);
241
242impl Outbox {
243 pub fn next_data(&mut self) -> Option<Bytes> {
245 self.0.pop_front().map(tungstenite::Message::into_data)
246 }
247
248 pub fn enqueue<M>(&mut self, message: M)
250 where
251 M: Into<tungstenite::Message>,
252 {
253 self.0.push_back(message.into());
254 }
255
256 pub fn enqueue_event<S>(&mut self, topic: S, event: Event)
258 where
259 S: Into<String>,
260 {
261 self.enqueue(Message::new_data(topic.into(), event));
262 }
263}
264
265#[derive(Error, Debug, PartialEq)]
267pub enum ProtocolError {
268 #[error("received an ACK while already subscribed")]
270 AlreadySubscribed,
271
272 #[error("attempted to send on topic '{0}' but only subscribed to '{1}'")]
273 SendOnNonSubscribed(String, Subscriptions, Data),
274
275 #[error("Zeek error {code}: {context}")]
276 ZeekError { code: String, context: String },
277
278 #[error("unexpected event payload received")]
279 UnexpectedEventPayload(Value),
280}
281
#[cfg(test)]
mod test {
    use crate::{
        protocol::{Binding, ProtocolError},
        types::{Data, Event, Message, Subscriptions, Value},
    };

    /// Builds the ACK a peer sends after accepting the subscription request.
    fn ack() -> Message {
        Message::Ack {
            endpoint: "mock".into(),
            version: "0.1".into(),
        }
    }

    #[test]
    fn recv() {
        let topic = "foo";

        let mut conn = Binding::new(&[topic]);

        assert_eq!(conn.inbox.next_message(), None);

        // The first outgoing frame is the subscription request.
        Subscriptions::try_from(tungstenite::Message::binary(conn.outgoing().unwrap())).unwrap();
        conn.handle_incoming(ack()).unwrap();

        // The ACK completes the handshake and is not queued.
        assert_eq!(conn.inbox.next_message(), None);
        assert_eq!(conn.receive_event(), Ok(None));

        let ping = Message::new_data(topic, Event::new("ping", Vec::<Value>::new()));
        conn.handle_incoming(ping).unwrap();

        assert!(matches!(
            conn.inbox.next_message(),
            Some(Message::DataMessage {
                data: Data::Event(..),
                ..
            })
        ));
    }

    #[test]
    fn send() {
        let mut conn = Binding::new(&["foo"]);

        Subscriptions::try_from(tungstenite::Message::binary(conn.outgoing().unwrap())).unwrap();
        conn.handle_incoming(ack()).unwrap();

        conn.publish_event("foo", Event::new("ping", Vec::<Value>::new()))
            .unwrap();

        let msg =
            Message::try_from(tungstenite::Message::binary(conn.outgoing().unwrap())).unwrap();
        assert!(matches!(
            msg,
            Message::DataMessage {
                data: Data::Event(..),
                ..
            }
        ));
    }

    #[test]
    fn split() {
        let (mut inbox, mut outbox) = Binding::new(&["foo"]).split();

        Subscriptions::try_from(tungstenite::Message::binary(outbox.next_data().unwrap())).unwrap();
        // A split inbox queues every message, including ACKs.
        inbox.handle(ack());

        assert!(matches!(inbox.next_message(), Some(Message::Ack { .. })));
    }

    #[test]
    fn split_inbox_next_event_skips_non_events() {
        let (mut inbox, _outbox) = Binding::new(&["foo"]).split();

        inbox.handle(ack());
        inbox.handle(Message::new_data(
            "foo",
            Event::new("ping", Vec::<Value>::new()),
        ));

        let (topic, event) = inbox.next_event().unwrap();
        assert_eq!(topic, "foo");
        assert_eq!(event.name, "ping");

        assert_eq!(inbox.next_event(), None);
    }

    #[test]
    fn send_on_non_subscribed() {
        let mut conn = Binding::new(&["foo"]);

        let message = tungstenite::Message::binary(conn.outgoing().unwrap());
        let subscription = Subscriptions::try_from(message).unwrap();
        assert_eq!(subscription, Subscriptions::from(&["foo"]));

        let event = Event::new("ping", vec!["ping on 'bar'"]);
        assert_eq!(
            conn.publish_event("bar", event.clone()),
            Err(ProtocolError::SendOnNonSubscribed(
                "bar".to_string(),
                Subscriptions::from(&["foo"]),
                Data::Event(event),
            ))
        );
    }

    #[test]
    fn duplicate_ack() {
        let mut conn = Binding::new(&["foo"]);

        Subscriptions::try_from(tungstenite::Message::binary(conn.outgoing().unwrap())).unwrap();
        conn.handle_incoming(ack()).unwrap();

        assert_eq!(
            conn.handle_incoming(ack()),
            Err(ProtocolError::AlreadySubscribed)
        );
    }

    #[test]
    fn other_event_payload() {
        let mut conn = Binding::new(&["foo"]);
        conn.handle_incoming(ack()).unwrap();

        let other = Message::new_data("foo", Value::Count(42));
        assert_eq!(
            conn.handle_incoming(other),
            Err(ProtocolError::UnexpectedEventPayload(Value::Count(42)))
        );
    }

    #[test]
    fn next_incoming() {
        let mut conn = Binding::new(Subscriptions(Vec::new()));

        conn.handle_incoming(ack()).unwrap();
        conn.handle_incoming(Message::new_data(
            "topic",
            Event::new("ping", Vec::<Value>::new()),
        ))
        .unwrap();

        let (topic, event) = conn.receive_event().unwrap().unwrap();
        assert_eq!(topic, "topic");
        assert_eq!(event.name, "ping");

        assert_eq!(conn.inbox.next_message(), None);
    }

    #[test]
    fn error() {
        let mut conn = Binding::new(&["foo"]);
        conn.handle_incoming(ack()).unwrap();

        conn.handle_incoming(Message::Error {
            code: "code".to_string(),
            context: "context".to_string(),
        })
        .unwrap();

        assert_eq!(
            conn.receive_event(),
            Err(ProtocolError::ZeekError {
                code: "code".to_string(),
                context: "context".to_string()
            })
        );
    }

    #[test]
    fn publish_event() {
        let mut conn = Binding::new(&["foo"]);
        // Discard the subscription request queued by `Binding::new`.
        conn.outgoing().unwrap();

        conn.publish_event("foo", Event::new("ping", Vec::<Value>::new()))
            .unwrap();
        let message =
            Message::try_from(tungstenite::Message::binary(conn.outgoing().unwrap())).unwrap();
        let Message::DataMessage {
            topic,
            data: Data::Event(event),
        } = message
        else {
            panic!("expected an event data message")
        };
        assert_eq!(topic, "foo");
        assert_eq!(event.name, "ping");
    }
}