1use std::collections::VecDeque;
49
50use thiserror::Error;
51use tungstenite::Bytes;
52use zeek_websocket_types::{Data, DeserializationError, Event, Message, Value};
53
54use crate::types::Subscriptions;
55
56pub struct Binding {
60 state: State,
61
62 inbox: Inbox,
63 outbox: Outbox,
64}
65
66enum State {
67 Subscribing,
68 Subscribed,
69}
70
71impl Binding {
72 #[must_use]
79 pub fn new<S>(subscriptions: S) -> Self
80 where
81 S: Into<Subscriptions>,
82 {
83 let subscriptions = subscriptions.into();
84 Self {
85 state: State::Subscribing,
86 inbox: Inbox(VecDeque::new()),
87 outbox: Outbox(VecDeque::from([subscriptions.into()])),
88 }
89 }
90
91 pub fn handle_incoming(&mut self, message: Message) -> Result<(), ProtocolError> {
98 match &message {
99 Message::Ack { .. } => match self.state {
100 State::Subscribing => {
101 self.state = State::Subscribed;
102 }
103 State::Subscribed => return Err(ProtocolError::AlreadySubscribed),
104 },
105 Message::DataMessage {
106 data: Data::Other(unexpected),
107 ..
108 } => {
109 return Err(ProtocolError::UnexpectedEventPayload(unexpected.clone()));
110 }
111 _ => {
112 self.inbox.handle(message);
113 }
114 }
115
116 Ok(())
117 }
118
119 pub fn outgoing(&mut self) -> Option<Bytes> {
121 self.outbox.next_data()
122 }
123
124 pub fn receive_event(&mut self) -> Result<Option<(String, Event)>, ProtocolError> {
130 if let Some(message) = self.inbox.next_message() {
131 match message {
132 Message::DataMessage { topic, data } => {
133 let event = match data {
134 Data::Event(event) => event,
135 Data::Other(..) => unreachable!(), };
137 return Ok(Some((topic, event)));
138 }
139 Message::Error { code, context } => {
140 return Err(ProtocolError::ZeekError { code, context });
141 }
142 Message::Ack { .. } => {
143 unreachable!() }
145 }
146 }
147
148 Ok(None)
149 }
150
151 fn enqueue(&mut self, message: Message) {
153 match message {
154 Message::DataMessage { topic, data } => {
155 self.outbox.enqueue(Message::DataMessage { topic, data });
156 }
157 _ => self.outbox.enqueue(message),
158 }
159 }
160
161 pub fn publish_event<S>(&mut self, topic: S, event: Event)
163 where
164 S: Into<String>,
165 {
166 self.enqueue(Message::new_data(topic.into(), event));
167 }
168
169 #[must_use]
175 pub fn split(self) -> (Inbox, Outbox) {
176 (self.inbox, self.outbox)
177 }
178}
179
180pub struct Inbox(VecDeque<Message>);
182
183impl Inbox {
184 pub fn handle(&mut self, message: Message) {
186 self.0.push_back(message);
187 }
188
189 #[must_use]
191 pub fn next_message(&mut self) -> Option<Message> {
192 self.0.pop_front()
193 }
194
195 pub fn next_event(&mut self) -> Option<(String, Event)> {
199 while let Some(message) = self.next_message() {
200 if let Message::DataMessage {
201 topic,
202 data: Data::Event(event),
203 } = message
204 {
205 return Some((topic, event));
206 }
207 }
208
209 None
210 }
211}
212
213pub struct Outbox(VecDeque<tungstenite::Message>);
215
216impl Outbox {
217 pub fn next_data(&mut self) -> Option<Bytes> {
219 self.0.pop_front().map(tungstenite::Message::into_data)
220 }
221
222 pub fn enqueue<M>(&mut self, message: M)
224 where
225 M: Into<tungstenite::Message>,
226 {
227 self.0.push_back(message.into());
228 }
229
230 pub fn enqueue_event<S>(&mut self, topic: S, event: Event)
232 where
233 S: Into<String>,
234 {
235 self.enqueue(Message::new_data(topic.into(), event));
236 }
237}
238
239#[derive(Error, Debug, PartialEq)]
241pub enum ProtocolError {
242 #[error("received an ACK while already subscribed")]
243 AckExpected,
244
245 #[error("received an ACK while already subscribed")]
246 AlreadySubscribed,
247
248 #[error("Zeek error {code}: {context}")]
249 ZeekError { code: String, context: String },
250
251 #[error("unexpected event payload received")]
252 UnexpectedEventPayload(Value),
253
254 #[error("could not deserialize message: {0}")]
255 DeserializationError(#[from] DeserializationError),
256}
257
258#[cfg(test)]
259mod test {
260 use crate::{
261 protocol::{Binding, ProtocolError},
262 types::{Data, Event, Message, Subscriptions, Value},
263 };
264
265 fn ack() -> Message {
266 Message::Ack {
267 endpoint: "mock".into(),
268 version: "0.1".into(),
269 }
270 }
271
272 #[test]
273 fn recv() {
274 let topic = "foo";
275
276 let mut conn = Binding::new(&[topic]);
277
278 assert_eq!(conn.inbox.next_message(), None);
280
281 Subscriptions::try_from(tungstenite::Message::binary(conn.outgoing().unwrap())).unwrap();
283 conn.handle_incoming(ack().into()).unwrap();
284
285 assert_eq!(conn.inbox.next_message(), None);
287 assert_eq!(conn.receive_event(), Ok(None));
288
289 conn.handle_incoming(Message::new_data(topic, Event::new("ping", [(); 0])).into())
291 .unwrap();
292
293 assert!(matches!(
294 conn.inbox.next_message(),
295 Some(Message::DataMessage {
296 data: Data::Event(..),
297 ..
298 })
299 ));
300 }
301
302 #[test]
303 fn send() {
304 let mut conn = Binding::new(&["foo"]);
305
306 Subscriptions::try_from(tungstenite::Message::binary(conn.outgoing().unwrap())).unwrap();
308 conn.handle_incoming(ack().into()).unwrap();
309
310 conn.publish_event("foo", Event::new("ping", [(); 0]));
312
313 let msg =
315 Message::try_from(tungstenite::Message::binary(conn.outgoing().unwrap())).unwrap();
316 assert!(matches!(
317 msg,
318 Message::DataMessage {
319 data: Data::Event(..),
320 ..
321 }
322 ));
323 }
324
325 #[test]
326 fn split() {
327 let (mut inbox, mut outbox) = Binding::new(&["foo"]).split();
328
329 Subscriptions::try_from(tungstenite::Message::binary(outbox.next_data().unwrap())).unwrap();
331 inbox.handle(ack().into());
332
333 assert!(matches!(inbox.next_message(), Some(Message::Ack { .. })));
334 }
335
336 #[test]
337 fn duplicate_ack() {
338 let mut conn = Binding::new(&["foo"]);
339
340 Subscriptions::try_from(tungstenite::Message::binary(conn.outgoing().unwrap())).unwrap();
342 conn.handle_incoming(ack().into()).unwrap();
343
344 assert_eq!(
346 conn.handle_incoming(ack().into()),
347 Err(ProtocolError::AlreadySubscribed)
348 );
349 }
350
351 #[test]
352 fn other_event_payload() {
353 let mut conn = Binding::new(&["foo"]);
354 conn.handle_incoming(ack()).unwrap();
355
356 let other = Message::new_data("foo", Value::Count(42));
357 assert_eq!(
358 conn.handle_incoming(other),
359 Err(ProtocolError::UnexpectedEventPayload(Value::Count(42)))
360 );
361 }
362
363 #[test]
364 fn next_incoming() {
365 let mut conn = Binding::new(Subscriptions::default());
366
367 let _ = conn.handle_incoming(ack());
369 let _ = conn.handle_incoming(Message::new_data("topic", Event::new("ping", [(); 0])));
370
371 let (topic, event) = conn.receive_event().unwrap().unwrap();
374 assert_eq!(topic, "topic");
375 assert_eq!(event.name, "ping");
376
377 assert_eq!(conn.inbox.next_message(), None);
378 }
379
380 #[test]
381 fn error() {
382 let mut conn = Binding::new(&["foo"]);
383 conn.handle_incoming(ack()).unwrap();
384
385 conn.handle_incoming(Message::Error {
386 code: "code".to_string(),
387 context: "context".to_string(),
388 })
389 .unwrap();
390
391 assert_eq!(
392 conn.receive_event(),
393 Err(ProtocolError::ZeekError {
394 code: "code".to_string(),
395 context: "context".to_string()
396 })
397 );
398 }
399
400 #[test]
401 fn publish_event() {
402 let mut conn = Binding::new(&["foo"]);
403 conn.outgoing().unwrap();
405
406 conn.publish_event("foo", Event::new("ping", [(); 0]));
407 let message =
408 Message::try_from(tungstenite::Message::binary(conn.outgoing().unwrap())).unwrap();
409 let Message::DataMessage {
410 topic,
411 data: Data::Event(event),
412 } = message
413 else {
414 panic!()
415 };
416 assert_eq!(topic, "foo");
417 assert_eq!(event.name, "ping");
418 }
419}