1use std::collections::VecDeque;
49
50use thiserror::Error;
51use tungstenite::Bytes;
52use zeek_websocket_types::{Data, Event, Message, Value};
53
54use crate::types::Subscriptions;
55
56pub struct Binding {
60 state: State,
61
62 inbox: Inbox,
63 outbox: Outbox,
64}
65
/// Progress of the subscription handshake.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// The subscription request is queued or sent, but no ACK has arrived yet.
    Subscribing,
    /// The ACK for the subscription has been received.
    Subscribed,
}
70
71impl Binding {
72 #[must_use]
79 pub fn new<S>(subscriptions: S) -> Self
80 where
81 S: Into<Subscriptions>,
82 {
83 let subscriptions = subscriptions.into();
84 Self {
85 state: State::Subscribing,
86 inbox: Inbox(VecDeque::new()),
87 outbox: Outbox(VecDeque::from([subscriptions.into()])),
88 }
89 }
90
91 pub fn handle_incoming(&mut self, message: Message) -> Result<(), ProtocolError> {
101 match &message {
102 Message::Ack { .. } => match self.state {
103 State::Subscribing => {
104 self.state = State::Subscribed;
105 }
106 State::Subscribed => return Err(ProtocolError::AlreadySubscribed),
107 },
108 Message::DataMessage {
109 data: Data::Other(unexpected),
110 ..
111 } => {
112 return Err(ProtocolError::UnexpectedEventPayload(unexpected.clone()));
113 }
114 _ => {
115 self.inbox.handle(message);
116 }
117 }
118
119 Ok(())
120 }
121
122 pub fn outgoing(&mut self) -> Option<Bytes> {
124 self.outbox.next_data()
125 }
126
127 pub fn receive_event(&mut self) -> Result<Option<(String, Event)>, ProtocolError> {
133 if let Some(message) = self.inbox.next_message() {
134 match message {
135 Message::DataMessage { topic, data } => {
136 let event = match data {
137 Data::Event(event) => event,
138 Data::Other(..) => unreachable!(), };
140 return Ok(Some((topic, event)));
141 }
142 Message::Error { code, context } => {
143 return Err(ProtocolError::ZeekError { code, context });
144 }
145 Message::Ack { .. } => {
146 unreachable!() }
148 }
149 }
150
151 Ok(None)
152 }
153
154 fn enqueue(&mut self, message: Message) {
156 match message {
157 Message::DataMessage { topic, data } => {
158 self.outbox.enqueue(Message::DataMessage { topic, data });
159 }
160 _ => self.outbox.enqueue(message),
161 }
162 }
163
164 pub fn publish_event<S>(&mut self, topic: S, event: Event)
166 where
167 S: Into<String>,
168 {
169 self.enqueue(Message::new_data(topic.into(), event));
170 }
171
172 #[must_use]
178 pub fn split(self) -> (Inbox, Outbox) {
179 (self.inbox, self.outbox)
180 }
181}
182
183pub struct Inbox(VecDeque<Message>);
185
186impl Inbox {
187 pub fn handle(&mut self, message: Message) {
189 self.0.push_back(message);
190 }
191
192 #[must_use]
194 pub fn next_message(&mut self) -> Option<Message> {
195 self.0.pop_front()
196 }
197
198 pub fn next_event(&mut self) -> Option<(String, Event)> {
202 while let Some(message) = self.next_message() {
203 if let Message::DataMessage {
204 topic,
205 data: Data::Event(event),
206 } = message
207 {
208 return Some((topic, event));
209 }
210 }
211
212 None
213 }
214}
215
216pub struct Outbox(VecDeque<tungstenite::Message>);
218
219impl Outbox {
220 pub fn next_data(&mut self) -> Option<Bytes> {
222 self.0.pop_front().map(tungstenite::Message::into_data)
223 }
224
225 pub fn enqueue<M>(&mut self, message: M)
227 where
228 M: Into<tungstenite::Message>,
229 {
230 self.0.push_back(message.into());
231 }
232
233 pub fn enqueue_event<S>(&mut self, topic: S, event: Event)
235 where
236 S: Into<String>,
237 {
238 self.enqueue(Message::new_data(topic.into(), event));
239 }
240}
241
242#[derive(Error, Debug, PartialEq)]
244pub enum ProtocolError {
245 #[error("received an ACK while already subscribed")]
247 AlreadySubscribed,
248
249 #[error("Zeek error {code}: {context}")]
250 ZeekError { code: String, context: String },
251
252 #[error("unexpected event payload received")]
253 UnexpectedEventPayload(Value),
254}
255
#[cfg(test)]
mod test {
    use crate::{
        protocol::{Binding, ProtocolError},
        types::{Data, Event, Message, Subscriptions, Value},
    };

    /// Builds a mock ACK message.
    fn ack() -> Message {
        Message::Ack {
            endpoint: "mock".into(),
            version: "0.1".into(),
        }
    }

    #[test]
    fn recv() {
        let topic = "foo";

        let mut conn = Binding::new(&[topic]);

        // Nothing has been received yet.
        assert_eq!(conn.inbox.next_message(), None);

        // The first outgoing message is the subscription request.
        Subscriptions::try_from(tungstenite::Message::binary(conn.outgoing().unwrap())).unwrap();

        conn.handle_incoming(ack()).unwrap();

        // The ACK completes the handshake and is not queued.
        assert_eq!(conn.inbox.next_message(), None);
        assert_eq!(conn.receive_event(), Ok(None));

        conn.handle_incoming(Message::new_data(topic, Event::new("ping", [(); 0])))
            .unwrap();

        assert!(matches!(
            conn.inbox.next_message(),
            Some(Message::DataMessage {
                data: Data::Event(..),
                ..
            })
        ));
    }

    #[test]
    fn send() {
        let mut conn = Binding::new(&["foo"]);

        Subscriptions::try_from(tungstenite::Message::binary(conn.outgoing().unwrap())).unwrap();

        conn.handle_incoming(ack()).unwrap();

        conn.publish_event("foo", Event::new("ping", [(); 0]));

        let msg =
            Message::try_from(tungstenite::Message::binary(conn.outgoing().unwrap())).unwrap();
        assert!(matches!(
            msg,
            Message::DataMessage {
                data: Data::Event(..),
                ..
            }
        ));
    }

    #[test]
    fn split() {
        let (mut inbox, mut outbox) = Binding::new(&["foo"]).split();

        Subscriptions::try_from(tungstenite::Message::binary(outbox.next_data().unwrap())).unwrap();

        // Without the binding's handshake state, the inbox queues ACKs too.
        inbox.handle(ack());

        assert!(matches!(inbox.next_message(), Some(Message::Ack { .. })));
    }

    #[test]
    fn duplicate_ack() {
        let mut conn = Binding::new(&["foo"]);

        Subscriptions::try_from(tungstenite::Message::binary(conn.outgoing().unwrap())).unwrap();

        conn.handle_incoming(ack()).unwrap();

        // A second ACK is a protocol violation.
        assert_eq!(
            conn.handle_incoming(ack()),
            Err(ProtocolError::AlreadySubscribed)
        );
    }

    #[test]
    fn other_event_payload() {
        let mut conn = Binding::new(&["foo"]);
        conn.handle_incoming(ack()).unwrap();

        let other = Message::new_data("foo", Value::Count(42));
        assert_eq!(
            conn.handle_incoming(other),
            Err(ProtocolError::UnexpectedEventPayload(Value::Count(42)))
        );
    }

    #[test]
    fn next_incoming() {
        let mut conn = Binding::new(Subscriptions(Vec::new()));

        // Both messages are valid here; unwrap so a regression fails loudly
        // instead of being silently discarded.
        conn.handle_incoming(ack()).unwrap();
        conn.handle_incoming(Message::new_data("topic", Event::new("ping", [(); 0])))
            .unwrap();

        let (topic, event) = conn.receive_event().unwrap().unwrap();

        assert_eq!(topic, "topic");
        assert_eq!(event.name, "ping");

        assert_eq!(conn.inbox.next_message(), None);
    }

    #[test]
    fn error() {
        let mut conn = Binding::new(&["foo"]);
        conn.handle_incoming(ack()).unwrap();

        conn.handle_incoming(Message::Error {
            code: "code".to_string(),
            context: "context".to_string(),
        })
        .unwrap();

        assert_eq!(
            conn.receive_event(),
            Err(ProtocolError::ZeekError {
                code: "code".to_string(),
                context: "context".to_string()
            })
        );
    }

    #[test]
    fn publish_event() {
        let mut conn = Binding::new(&["foo"]);
        // Drop the queued subscription request.
        conn.outgoing().unwrap();

        conn.publish_event("foo", Event::new("ping", [(); 0]));
        let message =
            Message::try_from(tungstenite::Message::binary(conn.outgoing().unwrap())).unwrap();
        let Message::DataMessage {
            topic,
            data: Data::Event(event),
        } = message
        else {
            panic!()
        };
        assert_eq!(topic, "foo");
        assert_eq!(event.name, "ping");
    }
}