// SPDX-License-Identifier: MIT use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::future::Future; use std::path::{Path, PathBuf}; use std::sync::Arc; use async_lock::Mutex; use async_tftp::async_trait; use async_tftp::packet::Error; use async_tftp::server::handlers::{DirHandler, DirHandlerMode}; use futures::{AsyncRead, AsyncWrite, TryStreamExt}; use mlua::{FromLua, Lua, UserData, UserDataFields, Value}; #[derive(Clone)] pub struct Handler { lua: Arc>, call_key: Arc, dir_handler: Arc, http: reqwest::Client, } #[derive(Debug)] pub enum Resource { Http(String), // Parameter is URL File(String), // Parameter is content Data(Vec), Error(Error) } impl Resource { } impl UserData for Resource { } impl<'lua> FromLua<'lua> for Resource { fn from_lua(value: Value<'lua>, _lua: &'lua Lua) -> mlua::Result { value.as_userdata().ok_or(mlua::Error::UserDataTypeMismatch).and_then(|value| value.take()) } } #[derive(Clone, Debug)] struct Client { address: SocketAddr, for_write: bool, } impl UserData for Client { fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) { fields.add_field_method_get("mode", |_lua, client| Ok(if client.for_write { "w" } else { "r" })); fields.add_field_method_get("address", |lua, client| lua.create_any_userdata(client.address.ip())); } } impl Handler { pub fn new(srv_path: impl AsRef) -> Result { let lua = mlua::Lua::new(); let http = reqwest::Client::builder() .user_agent(concat!("sftpd/", env!("CARGO_PKG_VERSION"))) .build()?; lua.register_userdata_type::(|registry| { registry.add_field_method_get("version", |_, ip| Ok(if ip.is_ipv4() { 4} else {6})); })?; { // prepare resource types... let resources = lua.create_table()?; resources.set("HTTP", lua.create_function(|_lua, url: String| Ok(Resource::Http(url)))?)?; resources.set("FILE", lua.create_function(|_lua, path: String| Ok(Resource::File(path)))?)?; resources.set("DATA", lua.create_function(|_lua, url: mlua::String| Ok(Resource::Data(url.as_bytes().to_vec())))?)?; let err_tbl = lua.create_table()?; let err_mtbl = lua.create_table()?; err_tbl.set_metatable(Some(err_mtbl.clone())); let err_fn = lua.create_function(|_, msg: String| Ok(Resource::Error(Error::Msg(msg))))?; err_mtbl.set("__call", err_fn.clone())?; err_tbl.set("FileNotFound", Resource::Error(Error::FileNotFound))?; err_tbl.set("Unknown", Resource::Error(Error::UnknownError))?; err_tbl.set("PermissionDenied", Resource::Error(Error::PermissionDenied))?; err_tbl.set("DiskFull", Resource::Error(Error::DiskFull))?; err_tbl.set("IllegalOperation", Resource::Error(Error::IllegalOperation))?; err_tbl.set("FileAlreadyExists", Resource::Error(Error::FileAlreadyExists))?; err_tbl.set("NoSuchUser", Resource::Error(Error::NoSuchUser))?; err_tbl.set("Message", err_fn)?; resources.set("ERROR", err_tbl)?; lua.globals().set("resource", resources)?; lua.globals().set("state", lua.create_table()?)?; // } let handler_fn = lua.create_registry_value(0)?; Ok(Self { lua: Arc::new(Mutex::new(lua)), call_key: Arc::new(handler_fn), http, dir_handler: Arc::new(DirHandler::new(srv_path, DirHandlerMode::ReadWrite)?), }) } pub fn load_script(&mut self, data: PathBuf) -> impl Future> { let lua = self.lua.clone(); let key = self.call_key.clone(); async move { let lua = lua.lock_arc().await; let chunk = lua.load(data); let script_fn = chunk.into_function()?; // Prepare a fake client to determine whether the script should just be run in full for each request // or if it returns a function let client = Client { address: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), for_write: false, }; let result: mlua::Value = script_fn.call_async(("", client, None as Option)).await?; if result.is_function() { lua.replace_registry_value(&*key, result)? } else if Resource::from_lua(result, &*lua).is_ok() { // We must run the script, it seems lua.replace_registry_value(&*key, script_fn)? } Ok(()) } } async fn call_handler_int(lua: Arc>, key: Arc, path: String, client: Client, size: Option) -> Result { lua.lock().await .registry_value::(&*key).map_err(|_| Error::UnknownError)? .call_async((path, client, size)).await .map_err(|_| Error::UnknownError) } async fn call_handler(&mut self, path: String, client: Client, size: Option) -> Result { let handle = tokio::task::spawn_local(Self::call_handler_int(self.lua.clone(), self.call_key.clone(), path, client, size)); handle.await.map_err(|err| Error::Msg(err.to_string())) .and_then(|x| x) } } #[async_trait] impl async_tftp::server::Handler for Handler { type Reader = Box; type Writer = Box; async fn read_req_open(&mut self, client: &SocketAddr, path: &Path) -> Result<(Self::Reader, Option), Error> { let lua_client = Client { address: client.clone(), for_write: false, }; let resource: Resource = self.call_handler(path.to_str().ok_or(Error::FileNotFound)?.to_owned(), lua_client, None).await?; match resource { Resource::Http(url) => { // TODO: Add headers describing client let req = self.http.get(url).send().await.map_err(|err| Error::Msg(err.to_string()))?; let size = req.content_length(); let stream = req.bytes_stream() .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e)) .into_async_read(); Ok((Box::new(stream), size)) } Resource::File(path) => { let (rdr, size) = (*self.dir_handler).clone().read_req_open(client, Path::new(path.as_str())).await?; Ok((Box::new(rdr), size)) } Resource::Data(data) => { let len = data.len() as u64; Ok(( Box::new(futures::io::Cursor::new(data)), Some(len), )) } Resource::Error(err) => { return Err(err) } } } async fn write_req_open(&mut self, _client: &SocketAddr, _path: &Path, _size: Option) -> Result { todo!(); #[cfg(ignore)] { let lua_client = Client { address: client.clone(), for_write: true, }; let resource: Resource = { let mut lua = self.lua.lock_arc().await; let lc = lua_client.clone().into_lua(&*lua).map_err(|_| Error::UnknownError)?; let handle_fn: mlua::Function = lua.registry_value(&self.call_key).map_err(|_| Error::UnknownError)?; let result = handle_fn.call_async((path.to_str().ok_or(Error::FileNotFound)?, lc, size)).await.map_err(|_| Error::UnknownError)?; // let lc = lc.as_userdata().and_then(|ud| ud.take().ok()).unwrap_or(lua_client); result }; match resource { Resource::Http(_) => { todo!() } Resource::File(_) => { todo!() } Resource::Data(_) => { todo!() } Resource::Error(_) => { todo!() } } } } }