zeek_websocket_derive/
lib.rs

1use proc_macro_error2::{abort, proc_macro_error};
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote};
4use syn::{Data, DeriveInput, Field, Fields, Ident, Type, parse_macro_input, spanned::Spanned};
5
6/// # Derive macro to convert a type from and to a `zeek_websocket_types::Value`
7///
8/// Zeek's WebSocket API encodes Zeek `record` values as vectors. This derive macro adds support
9/// for automatically converting Rust types to and from the encoding. It supports `struct`s
10/// made up of fields which implement `TryFrom<Value>` and `Into<Value>`.
11///
12/// ```
13/// # use zeek_websocket_derive::ZeekType;
14/// # use zeek_websocket_types::Value;
15/// #[derive(Debug, PartialEq)] // Not required.
16/// #[derive(ZeekType)]
17/// struct Record {
18///     a: i64,
19///     b: u64,
20/// }
21///
22/// let r = Record { a: -32, b: 1024 };
23///
24/// let value = Value::from(r);
25/// assert_eq!(
26///     value,
27///     Value::Vector(vec![Value::Integer(-32), Value::Count(1024)]));
28///
29/// let r = Record::try_from(value).unwrap();
30/// assert_eq!(r, Record { a: -32, b: 1024 });
31/// ```
32///
33/// If more than the expected number of fields are received they are silently discarded.
34///
35/// ```
36/// # use zeek_websocket_derive::ZeekType;
37/// # use zeek_websocket_types::Value;
38/// # #[derive(Debug, PartialEq)] // Not required.
39/// # #[derive(ZeekType)]
40/// # struct Record {
41/// #     a: i64,
42/// #     b: u64,
43/// # }
44/// #
45/// let v = Value::Vector(vec![Value::Integer(1), Value::Count(2), Value::Count(3)]);
46/// let r: Record = v.try_into().unwrap();
47/// assert_eq!(r, Record { a: 1, b: 2 });
48///
49/// // Unknown fields are not magically added back when encoding. This is supported by Zeek.
50/// let v2 = Value::from(r);
51/// assert_eq!(v2, Value::Vector(vec![Value::Integer(1), Value::Count(2)]));
52/// ```
53///
54/// ## Optional fields
55///
56/// Zeek `record` fields which can be unset are marked `&optional`, e.g.,
57///
58/// ```zeek
59/// type X: record {
60///     a: count;
61///     b: int &optional;
62/// };
63/// ```
64///
65/// This is used to evolve Zeek `record` types so that users do not need to be updated if
66/// more fields are added.
67///
68/// The WebSocket API encodes unset fields as `Value::None`. To work with such types the Rust type
69/// should be an `Option`, e.g.,
70///
71/// ```
72/// # use zeek_websocket_types::Value;
73/// # use zeek_websocket_derive::ZeekType;
74/// #[derive(ZeekType)]
75/// struct X {
76///     a: u64,
77///     b: Option<i64>,
78/// }
79/// ```
80///
81/// `Value::None` maps onto `Option::None`.
82///
83/// ```
84/// # use zeek_websocket_types::Value;
85/// # use zeek_websocket_derive::ZeekType;
86/// # #[derive(Debug, PartialEq)] // Not required.
87/// # #[derive(ZeekType)]
88/// # struct X {
89/// #     a: u64,
90/// #     b: Option<i64>,
91/// # }
92/// let v = Value::Vector(vec![Value::Count(1), Value::None]);
93/// let x: X = v.try_into().unwrap();
94/// assert_eq!(x, X { a: 1, b: None });
95/// ```
96///
97/// Anything else maps onto `Option::Some`.
98/// ```
99/// # use zeek_websocket_types::Value;
100/// # use zeek_websocket_derive::ZeekType;
101/// # #[derive(Debug, PartialEq)] // Not required.
102/// # #[derive(ZeekType)]
103/// # struct X {
104/// #     a: u64,
105/// #     b: Option<i64>,
106/// # }
107/// let v = Value::Vector(vec![Value::Count(1), Value::Integer(2)]);
108/// let x: X = v.try_into().unwrap();
109/// assert_eq!(x, X { a: 1, b: Some(2) });
110/// ```
111///
112/// If no value was received for an optional field it is set to `None`. Non-`Option` fields are
113/// always required.
114///
115/// ```
116/// # use zeek_websocket_types::Value;
117/// # use zeek_websocket_derive::ZeekType;
118/// # #[derive(Debug, PartialEq)] // Not required.
119/// # #[derive(ZeekType)]
120/// # struct X {
121/// #     a: u64,
122/// #     b: Option<i64>,
123/// # }
124/// let v = Value::Vector(vec![Value::Count(1)]);
125/// let x: X = v.try_into().unwrap();
126/// assert_eq!(x, X { a: 1, b: None });
127///
128/// // Error for non-`Option` fields.
129/// let x: Result<X, _> = Value::Vector(vec![]).try_into();
130/// assert!(x.is_err());
131/// ```
132///
133#[proc_macro_error]
134#[proc_macro_derive(ZeekType)]
135pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
136    let ast = parse_macro_input!(input as DeriveInput);
137
138    let name = &ast.ident;
139
140    let Data::Struct(struct_) = &ast.data else {
141        abort!(ast.span(), "only structs can derive ZeekType");
142    };
143
144    let Fields::Named(fields) = &struct_.fields else {
145        abort!(
146            ast.span(),
147            "only structs with named fields can derive ZeekType"
148        );
149    };
150    let fields: Vec<_> = fields.named.iter().cloned().collect();
151
152    let value_from = impl_value_from(name, &fields);
153    let from_value = impl_from_value(name, &fields);
154
155    quote! {
156        #value_from
157
158        #from_value
159    }
160    .into()
161}
162
163fn impl_value_from(name: &Ident, fields: &[Field]) -> TokenStream {
164    let fields = fields.iter().map(|f| {
165        let Some(field_name) = &f.ident else {
166            abort!(f.span(), "unsupported field name");
167        };
168
169        if looks_like_option(&f.ty) {
170            let x = format_ident!("__zeek_websocket_derive__impl_value_from_{name}__{field_name}");
171            quote! {
172                match value.#field_name {
173                    Some(#x) => ::zeek_websocket_types::Value::from(#x),
174                    None => ::zeek_websocket_types::Value::None,
175                }
176            }
177        } else {
178            quote! { ::zeek_websocket_types::Value::from(value.#field_name) }
179        }
180    });
181
182    quote! {
183        impl From<#name> for ::zeek_websocket_types::Value {
184            fn from(value: #name) -> Self {
185                Self::from(::std::vec::Vec::<::zeek_websocket_types::Value>::from([#(#fields), *]))
186            }
187        }
188    }
189}
190
191fn impl_from_value(name: &Ident, fields: &[Field]) -> TokenStream {
192    let xs = format_ident!("__zeek_websocket_derive__impl_from_value_{name}");
193
194    // Validate that all fields are named so we can cleanly unwrap below.
195    if fields.iter().any(|f| f.ident.is_none()) {
196        abort!(name.span(), "unnamed fields are unsupported");
197    }
198
199    let fields = fields.iter().map(|f| {
200        let field_name = f.ident.as_ref().unwrap();
201
202        let init = quote! {
203            #xs
204                .next()
205                .unwrap_or(::zeek_websocket_types::Value::None)
206        };
207
208        let x = format_ident!("__zeek_websocket_derive__impl_from_value_{name}__{field_name}");
209
210        (
211            field_name.clone(),
212            if looks_like_option(&f.ty) {
213                quote! {
214                    let #field_name = match #init {
215                        ::zeek_websocket_types::Value::None => None,
216                        #x => Some(#x.try_into()?),
217                    };
218                }
219            } else {
220                quote! { let #field_name = #init.try_into()?; }
221            },
222        )
223    });
224
225    let (names, inits): (Vec<_>, Vec<_>) = fields.unzip();
226
227    quote! {
228        impl TryFrom<::zeek_websocket_types::Value> for #name {
229            type Error = ::zeek_websocket_types::ConversionError;
230
231            fn try_from(value: ::zeek_websocket_types::Value) -> Result<Self, Self::Error> {
232                #[allow(non_snake_case)]
233                let ::zeek_websocket_types::Value::Vector(#xs) = value else {
234                    return Err(::zeek_websocket_types::ConversionError::MismatchedTypes);
235                };
236                let mut #xs = #xs.into_iter();
237
238                #(#inits)*
239
240                Ok(#name { #(#names),* })
241            }
242        }
243    }
244}
245
246// Helper function to detect whether a type is likely an instance of `Option`.
247fn looks_like_option(ty: &Type) -> bool {
248    let Type::Path(p) = ty else {
249        abort!(ty.span(), "unsupported type");
250    };
251
252    p.path
253        .segments
254        .last()
255        .map(|ty| ty.ident == "Option")
256        .unwrap_or_default()
257}