1use std::collections::VecDeque;
49
50use thiserror::Error;
51use tungstenite::Bytes;
52use zeek_websocket_types::{Data, Event, Message, Value};
53
54use crate::types::Subscriptions;
55
56pub struct Binding {
60 state: State,
61 subscriptions: Subscriptions,
62
63 inbox: Inbox,
64 outbox: Outbox,
65}
66
/// Handshake state of a [`Binding`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// No ACK has been handled yet.
    Subscribing,
    /// An ACK has been handled; any further ACK is a protocol error.
    Subscribed,
}
71
72impl Binding {
73 #[must_use]
80 pub fn new<S>(subscriptions: S) -> Self
81 where
82 S: Into<Subscriptions>,
83 {
84 let subscriptions = subscriptions.into();
85 Self {
86 state: State::Subscribing,
87 inbox: Inbox(VecDeque::new()),
88 outbox: Outbox(VecDeque::from([subscriptions.clone().into()])),
89 subscriptions,
90 }
91 }
92
93 pub fn handle_incoming(&mut self, message: Message) -> Result<(), ProtocolError> {
103 match &message {
104 Message::Ack { .. } => match self.state {
105 State::Subscribing => {
106 self.state = State::Subscribed;
107 }
108 State::Subscribed => return Err(ProtocolError::AlreadySubscribed),
109 },
110 Message::DataMessage {
111 data: Data::Other(unexpected),
112 ..
113 } => {
114 return Err(ProtocolError::UnexpectedEventPayload(unexpected.clone()));
115 }
116 _ => {
117 self.inbox.handle(message);
118 }
119 }
120
121 Ok(())
122 }
123
124 pub fn outgoing(&mut self) -> Option<Bytes> {
126 self.outbox.next_data()
127 }
128
129 pub fn receive_event(&mut self) -> Result<Option<(String, Event)>, ProtocolError> {
135 if let Some(message) = self.inbox.next_message() {
136 match message {
137 Message::DataMessage { topic, data } => {
138 let event = match data {
139 Data::Event(event) => event,
140 Data::Other(..) => unreachable!(), };
142 return Ok(Some((topic, event)));
143 }
144 Message::Error { code, context } => {
145 return Err(ProtocolError::ZeekError { code, context });
146 }
147 Message::Ack { .. } => {
148 unreachable!() }
150 }
151 }
152
153 Ok(None)
154 }
155
156 fn enqueue(&mut self, message: Message) -> Result<(), ProtocolError> {
158 match message {
159 Message::DataMessage { topic, data } => {
160 let is_subscribed = self
161 .subscriptions
162 .0
163 .iter()
164 .any(|s| s.as_str() == topic.as_str());
165
166 if is_subscribed {
167 self.outbox.enqueue(Message::DataMessage { topic, data });
168 } else {
169 return Err(ProtocolError::SendOnNonSubscribed(
170 topic,
171 self.subscriptions.clone(),
172 data,
173 ))?;
174 }
175 }
176 _ => self.outbox.enqueue(message),
177 }
178
179 Ok(())
180 }
181
182 pub fn publish_event<S>(&mut self, topic: S, event: Event) -> Result<(), ProtocolError>
189 where
190 S: Into<String>,
191 {
192 self.enqueue(Message::new_data(topic.into(), event))
193 }
194
195 #[must_use]
201 pub fn split(self) -> (Inbox, Outbox) {
202 (self.inbox, self.outbox)
203 }
204}
205
206pub struct Inbox(VecDeque<Message>);
208
209impl Inbox {
210 pub fn handle(&mut self, message: Message) {
212 self.0.push_back(message);
213 }
214
215 #[must_use]
217 pub fn next_message(&mut self) -> Option<Message> {
218 self.0.pop_front()
219 }
220
221 pub fn next_event(&mut self) -> Option<(String, Event)> {
225 while let Some(message) = self.next_message() {
226 if let Message::DataMessage {
227 topic,
228 data: Data::Event(event),
229 } = message
230 {
231 return Some((topic, event));
232 }
233 }
234
235 None
236 }
237}
238
239pub struct Outbox(VecDeque<tungstenite::Message>);
241
242impl Outbox {
243 pub fn next_data(&mut self) -> Option<Bytes> {
245 self.0.pop_front().map(tungstenite::Message::into_data)
246 }
247
248 pub fn enqueue<M>(&mut self, message: M)
250 where
251 M: Into<tungstenite::Message>,
252 {
253 self.0.push_back(message.into());
254 }
255
256 pub fn enqueue_event<S>(&mut self, topic: S, event: Event)
258 where
259 S: Into<String>,
260 {
261 self.enqueue(Message::new_data(topic.into(), event));
262 }
263}
264
265#[derive(Error, Debug, PartialEq)]
267pub enum ProtocolError {
268 #[error("received an ACK while already subscribed")]
270 AlreadySubscribed,
271
272 #[error("attempted to send on topic '{0}' but only subscribed to '{1}'")]
273 SendOnNonSubscribed(String, Subscriptions, Data),
274
275 #[error("Zeek error {code}: {context}")]
276 ZeekError { code: String, context: String },
277
278 #[error("unexpected event payload received")]
279 UnexpectedEventPayload(Value),
280}
281
#[cfg(test)]
mod test {
    use crate::{
        protocol::{Binding, ProtocolError},
        types::{Data, Event, Message, Subscriptions, Value},
    };

    /// Builds a mock ACK message.
    fn ack() -> Message {
        Message::Ack {
            endpoint: "mock".into(),
            version: "0.1".into(),
        }
    }

    #[test]
    fn recv() {
        let topic = "foo";

        let mut conn = Binding::new(&[topic]);

        assert_eq!(conn.inbox.next_message(), None);

        // The first outgoing message is the subscription request.
        Subscriptions::try_from(tungstenite::Message::binary(conn.outgoing().unwrap())).unwrap();

        conn.handle_incoming(ack()).unwrap();

        // The ACK completes the handshake and is not queued.
        assert_eq!(conn.inbox.next_message(), None);
        assert_eq!(conn.receive_event(), Ok(None));

        conn.handle_incoming(Message::new_data(topic, Event::new("ping", [(); 0])))
            .unwrap();

        assert!(matches!(
            conn.inbox.next_message(),
            Some(Message::DataMessage {
                data: Data::Event(..),
                ..
            })
        ));
    }

    #[test]
    fn send() {
        let mut conn = Binding::new(&["foo"]);

        Subscriptions::try_from(tungstenite::Message::binary(conn.outgoing().unwrap())).unwrap();

        conn.handle_incoming(ack()).unwrap();

        conn.publish_event("foo", Event::new("ping", [(); 0]))
            .unwrap();

        let msg =
            Message::try_from(tungstenite::Message::binary(conn.outgoing().unwrap())).unwrap();
        assert!(matches!(
            msg,
            Message::DataMessage {
                data: Data::Event(..),
                ..
            }
        ));
    }

    #[test]
    fn split() {
        let (mut inbox, mut outbox) = Binding::new(&["foo"]).split();

        Subscriptions::try_from(tungstenite::Message::binary(outbox.next_data().unwrap())).unwrap();

        // A split inbox queues ACKs verbatim instead of tracking them.
        inbox.handle(ack());

        assert!(matches!(inbox.next_message(), Some(Message::Ack { .. })));
    }

    #[test]
    fn split_next_event_skips_non_events() {
        let (mut inbox, _outbox) = Binding::new(&["foo"]).split();

        inbox.handle(ack());
        inbox.handle(Message::new_data("foo", Event::new("ping", [(); 0])));

        // The ACK queued ahead of the event is consumed and dropped.
        let (topic, event) = inbox.next_event().unwrap();
        assert_eq!(topic, "foo");
        assert_eq!(event.name, "ping");
        assert_eq!(inbox.next_message(), None);
    }

    #[test]
    fn send_on_non_subscribed() {
        let mut conn = Binding::new(&["foo"]);

        let message = tungstenite::Message::binary(conn.outgoing().unwrap());
        let subscription = Subscriptions::try_from(message).unwrap();
        assert_eq!(subscription, Subscriptions::from(&["foo"]));

        let event = Event::new("ping", ["ping on 'bar'"]);
        assert_eq!(
            conn.publish_event("bar", event.clone()),
            Err(ProtocolError::SendOnNonSubscribed(
                "bar".to_string(),
                Subscriptions::from(&["foo"]),
                Data::Event(event),
            ))
        );
    }

    #[test]
    fn duplicate_ack() {
        let mut conn = Binding::new(&["foo"]);

        Subscriptions::try_from(tungstenite::Message::binary(conn.outgoing().unwrap())).unwrap();

        conn.handle_incoming(ack()).unwrap();

        assert_eq!(
            conn.handle_incoming(ack()),
            Err(ProtocolError::AlreadySubscribed)
        );
    }

    #[test]
    fn other_event_payload() {
        let mut conn = Binding::new(&["foo"]);
        conn.handle_incoming(ack()).unwrap();

        let other = Message::new_data("foo", Value::Count(42));
        assert_eq!(
            conn.handle_incoming(other),
            Err(ProtocolError::UnexpectedEventPayload(Value::Count(42)))
        );
    }

    #[test]
    fn next_incoming() {
        let mut conn = Binding::new(Subscriptions(Vec::new()));

        // Unwrap so an unexpected protocol error fails the test instead of
        // being silently ignored.
        conn.handle_incoming(ack()).unwrap();
        conn.handle_incoming(Message::new_data("topic", Event::new("ping", [(); 0])))
            .unwrap();

        let (topic, event) = conn.receive_event().unwrap().unwrap();
        assert_eq!(topic, "topic");
        assert_eq!(event.name, "ping");

        assert_eq!(conn.inbox.next_message(), None);
    }

    #[test]
    fn error() {
        let mut conn = Binding::new(&["foo"]);
        conn.handle_incoming(ack()).unwrap();

        conn.handle_incoming(Message::Error {
            code: "code".to_string(),
            context: "context".to_string(),
        })
        .unwrap();

        assert_eq!(
            conn.receive_event(),
            Err(ProtocolError::ZeekError {
                code: "code".to_string(),
                context: "context".to_string()
            })
        );
    }

    #[test]
    fn publish_event() {
        let mut conn = Binding::new(&["foo"]);
        // Drain the subscription request so the next payload is the event.
        conn.outgoing().unwrap();

        conn.publish_event("foo", Event::new("ping", [(); 0]))
            .unwrap();
        let message =
            Message::try_from(tungstenite::Message::binary(conn.outgoing().unwrap())).unwrap();
        let Message::DataMessage {
            topic,
            data: Data::Event(event),
        } = message
        else {
            panic!()
        };
        assert_eq!(topic, "foo");
        assert_eq!(event.name, "ping");
    }
}