zeek_websocket_derive/lib.rs
1use proc_macro_error2::{abort, proc_macro_error};
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote};
4use syn::{Data, DeriveInput, Field, Fields, Ident, Type, parse_macro_input, spanned::Spanned};
5
6/// # Derive macro to convert a type from and to a `zeek_websocket_types::Value`
7///
8/// Zeek's WebSocket API encodes Zeek `record` values as vectors. This derive macro adds support
9/// for automatically converting Rust types to and from the encoding. It supports `struct`s
10/// made up of fields which implement `TryFrom<Value>` and `Into<Value>`.
11///
12/// ```
13/// # use zeek_websocket_derive::ZeekType;
14/// # use zeek_websocket_types::Value;
15/// #[derive(Debug, PartialEq)] // Not required.
16/// #[derive(ZeekType)]
17/// struct Record {
18/// a: i64,
19/// b: u64,
20/// }
21///
22/// let r = Record { a: -32, b: 1024 };
23///
24/// let value = Value::from(r);
25/// assert_eq!(
26/// value,
27/// Value::Vector(vec![Value::Integer(-32), Value::Count(1024)]));
28///
29/// let r = Record::try_from(value).unwrap();
30/// assert_eq!(r, Record { a: -32, b: 1024 });
31/// ```
32///
33/// If more than the expected number of fields are received they are silently discarded.
34///
35/// ```
36/// # use zeek_websocket_derive::ZeekType;
37/// # use zeek_websocket_types::Value;
38/// # #[derive(Debug, PartialEq)] // Not required.
39/// # #[derive(ZeekType)]
40/// # struct Record {
41/// # a: i64,
42/// # b: u64,
43/// # }
44/// #
45/// let v = Value::Vector(vec![Value::Integer(1), Value::Count(2), Value::Count(3)]);
46/// let r: Record = v.try_into().unwrap();
47/// assert_eq!(r, Record { a: 1, b: 2 });
48///
49/// // Unknown fields are not magically added back when encoding. This is supported by Zeek.
50/// let v2 = Value::from(r);
51/// assert_eq!(v2, Value::Vector(vec![Value::Integer(1), Value::Count(2)]));
52/// ```
53///
54/// ## Optional fields
55///
56/// Zeek `record` fields which can be unset are marked `&optional`, e.g.,
57///
58/// ```zeek
59/// type X: record {
60/// a: count;
61/// b: int &optional;
62/// };
63/// ```
64///
65/// This is used to evolve Zeek `record` types so that users do not need to be updated if
66/// more fields are added.
67///
68/// The WebSocket API encodes unset fields as `Value::None`. To work with such types the Rust type
69/// should be an `Option`, e.g.,
70///
71/// ```
72/// # use zeek_websocket_types::Value;
73/// # use zeek_websocket_derive::ZeekType;
74/// #[derive(ZeekType)]
75/// struct X {
76/// a: u64,
77/// b: Option<i64>,
78/// }
79/// ```
80///
81/// `Value::None` maps onto `Option::None`.
82///
83/// ```
84/// # use zeek_websocket_types::Value;
85/// # use zeek_websocket_derive::ZeekType;
86/// # #[derive(Debug, PartialEq)] // Not required.
87/// # #[derive(ZeekType)]
88/// # struct X {
89/// # a: u64,
90/// # b: Option<i64>,
91/// # }
92/// let v = Value::Vector(vec![Value::Count(1), Value::None]);
93/// let x: X = v.try_into().unwrap();
94/// assert_eq!(x, X { a: 1, b: None });
95/// ```
96///
97/// Anything else maps onto `Option::Some`.
98/// ```
99/// # use zeek_websocket_types::Value;
100/// # use zeek_websocket_derive::ZeekType;
101/// # #[derive(Debug, PartialEq)] // Not required.
102/// # #[derive(ZeekType)]
103/// # struct X {
104/// # a: u64,
105/// # b: Option<i64>,
106/// # }
107/// let v = Value::Vector(vec![Value::Count(1), Value::Integer(2)]);
108/// let x: X = v.try_into().unwrap();
109/// assert_eq!(x, X { a: 1, b: Some(2) });
110/// ```
111///
112/// If no value was received for an optional field it is set to `None`. Non-`Option` fields are
113/// always required.
114///
115/// ```
116/// # use zeek_websocket_types::Value;
117/// # use zeek_websocket_derive::ZeekType;
118/// # #[derive(Debug, PartialEq)] // Not required.
119/// # #[derive(ZeekType)]
120/// # struct X {
121/// # a: u64,
122/// # b: Option<i64>,
123/// # }
124/// let v = Value::Vector(vec![Value::Count(1)]);
125/// let x: X = v.try_into().unwrap();
126/// assert_eq!(x, X { a: 1, b: None });
127///
128/// // Error for non-`Option` fields.
129/// let x: Result<X, _> = Value::Vector(vec![]).try_into();
130/// assert!(x.is_err());
131/// ```
132///
133#[proc_macro_error]
134#[proc_macro_derive(ZeekType)]
135pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
136 let ast = parse_macro_input!(input as DeriveInput);
137
138 let name = &ast.ident;
139
140 let Data::Struct(struct_) = &ast.data else {
141 abort!(ast.span(), "only structs can derive ZeekType");
142 };
143
144 let Fields::Named(fields) = &struct_.fields else {
145 abort!(
146 ast.span(),
147 "only structs with named fields can derive ZeekType"
148 );
149 };
150 let fields: Vec<_> = fields.named.iter().cloned().collect();
151
152 let value_from = impl_value_from(name, &fields);
153 let from_value = impl_from_value(name, &fields);
154
155 quote! {
156 #value_from
157
158 #from_value
159 }
160 .into()
161}
162
163fn impl_value_from(name: &Ident, fields: &[Field]) -> TokenStream {
164 let fields = fields.iter().map(|f| {
165 let Some(field_name) = &f.ident else {
166 abort!(f.span(), "unsupported field name");
167 };
168
169 if looks_like_option(&f.ty) {
170 let x = format_ident!("__zeek_websocket_derive__impl_value_from_{name}__{field_name}");
171 quote! {
172 match value.#field_name {
173 Some(#x) => ::zeek_websocket_types::Value::from(#x),
174 None => ::zeek_websocket_types::Value::None,
175 }
176 }
177 } else {
178 quote! { ::zeek_websocket_types::Value::from(value.#field_name) }
179 }
180 });
181
182 quote! {
183 impl From<#name> for ::zeek_websocket_types::Value {
184 fn from(value: #name) -> Self {
185 Self::from(::std::vec::Vec::<::zeek_websocket_types::Value>::from([#(#fields), *]))
186 }
187 }
188 }
189}
190
191fn impl_from_value(name: &Ident, fields: &[Field]) -> TokenStream {
192 let xs = format_ident!("__zeek_websocket_derive__impl_from_value_{name}");
193
194 // Validate that all fields are named so we can cleanly unwrap below.
195 if fields.iter().any(|f| f.ident.is_none()) {
196 abort!(name.span(), "unnamed fields are unsupported");
197 }
198
199 let fields = fields.iter().map(|f| {
200 let field_name = f.ident.as_ref().unwrap();
201
202 let init = quote! {
203 #xs
204 .next()
205 .unwrap_or(::zeek_websocket_types::Value::None)
206 };
207
208 let x = format_ident!("__zeek_websocket_derive__impl_from_value_{name}__{field_name}");
209
210 (
211 field_name.clone(),
212 if looks_like_option(&f.ty) {
213 quote! {
214 let #field_name = match #init {
215 ::zeek_websocket_types::Value::None => None,
216 #x => Some(#x.try_into()?),
217 };
218 }
219 } else {
220 quote! { let #field_name = #init.try_into()?; }
221 },
222 )
223 });
224
225 let (names, inits): (Vec<_>, Vec<_>) = fields.unzip();
226
227 quote! {
228 impl TryFrom<::zeek_websocket_types::Value> for #name {
229 type Error = ::zeek_websocket_types::ConversionError;
230
231 fn try_from(value: ::zeek_websocket_types::Value) -> Result<Self, Self::Error> {
232 #[allow(non_snake_case)]
233 let ::zeek_websocket_types::Value::Vector(#xs) = value else {
234 return Err(::zeek_websocket_types::ConversionError::MismatchedTypes);
235 };
236 let mut #xs = #xs.into_iter();
237
238 #(#inits)*
239
240 Ok(#name { #(#names),* })
241 }
242 }
243 }
244}
245
246// Helper function to detect whether a type is likely an instance of `Option`.
247fn looks_like_option(ty: &Type) -> bool {
248 let Type::Path(p) = ty else {
249 abort!(ty.span(), "unsupported type");
250 };
251
252 p.path
253 .segments
254 .last()
255 .map(|ty| ty.ident == "Option")
256 .unwrap_or_default()
257}