diff --git a/src/lib.rs b/src/lib.rs index 7d89f97..19026cb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -132,10 +132,13 @@ use proc_macro::TokenStream; use quote::{format_ident, quote}; use syn::{ Ident, + Item, LitStr, Result, Token, + Visibility, parse::{Parse, ParseStream}, + parse_file, parse_macro_input, }; @@ -280,18 +283,58 @@ pub fn folder_router(input: TokenStream) -> TokenStream { expanded.into() } -fn methods_for_route(route_path: &PathBuf) -> Vec<&str> { - let file_content = fs::read_to_string(&route_path).unwrap_or_default(); - let methods = ["get", "post", "put", "delete", "patch", "head", "options"]; +/// parses the file at the specified location using syn +/// and returns a Vec<&'static str> of all used http verb fns +/// e.g. for the file +/// +/// ```rust +/// pub async fn get() {} # ✅ => "get" be added to vec +/// pub fn post() {} # not async +/// async fn delete() {} # not pub +/// fn patch() {} # not pub nor async +/// pub fn non_verb() {} # not a http verb +/// ``` +/// +/// it returns: `vec!["get"]` +/// +fn methods_for_route(route_path: &PathBuf) -> Vec<&'static str> { + // Read the file content + let file_content = match fs::read_to_string(route_path) { + Ok(content) => content, + Err(_) => return Vec::new(), + }; - let mut method_registrations = Vec::new(); - for method in methods { - if file_content.contains(&format!("pub async fn {}(", method)) { - // let method_ident = format_ident!("{}", method); - method_registrations.push(method); + // Parse the file content into a syn syntax tree + let file = match parse_file(&file_content) { + Ok(file) => file, + Err(_) => return Vec::new(), + }; + + // Define HTTP methods we're looking for + let methods = ["get", "post", "put", "delete", "patch", "head", "options"]; + let mut found_methods = Vec::new(); + + // Examine each item in the file + for item in &file.items { + if let Item::Fn(fn_item) = item { + let fn_name = fn_item.sig.ident.to_string(); + + // Check if the function name is one of our HTTP methods + if let Some(&method) = methods.iter().find(|&&m| m == fn_name) { + // Check if the function is public + let is_public = matches!(fn_item.vis, Visibility::Public(_)); + + // Check if the function is async + let is_async = fn_item.sig.asyncness.is_some(); + + if is_public && is_async { + found_methods.push(method); + } + } } } - method_registrations + + found_methods } // Add a route to the module tree