diff --git a/Cargo.lock b/Cargo.lock index 0a3254f..297b20a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,23 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "ahash" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", +] + +[[package]] +name = "allocator-api2" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" + [[package]] name = "ansi_term" version = "0.12.1" @@ -546,6 +563,10 @@ name = "hashbrown" version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f93e7192158dbcda357bdec5fb5788eebf8bbac027f3f33e719d29135ae84156" +dependencies = [ + "ahash", + "allocator-api2", +] [[package]] name = "heck" @@ -1261,10 +1282,12 @@ dependencies = [ "futures", "libc", "listenfd", + "log", "mlua", "reqwest", "structopt", "tokio", + "tokio-util", ] [[package]] @@ -1443,7 +1466,10 @@ checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15" dependencies = [ "bytes", "futures-core", + "futures-io", "futures-sink", + "futures-util", + "hashbrown", "pin-project-lite", "tokio", "tracing", diff --git a/Cargo.toml b/Cargo.toml index f33c91e..dacce0c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,8 +11,10 @@ async-tftp = { path = "vendor/async-tftp-rs" } futures = "0.3.29" anyhow = "1" fern = "0.6.2" +log = "0.4.20" structopt = "0.3" tokio = { version = "1.32.0", features = ["rt-multi-thread", "macros", "rt"] } +tokio-util = { version = "0.7.10", features = ["io", "io-util", "rt", "compat"] } mlua = { version = "0.9.1", features = ["luau-jit", "vendored", "async", "send"] } reqwest = { version = "0.11.22", features = ["stream"] } listenfd = "1.0.1" diff --git a/README.md b/README.md index a35b518..b4b8b0a 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,25 @@ -Scriptable TFTP server -====================== +# Scriptable TFTP server -This is a TFTP server, designed for use with Installa. +This is a TFTP application server. Potential uses include: +* Integration with a provisioning system to deliver custom boot scripts +* Backing up router configs directly to an HTTP server +* Integration -It supports responding to TFTP requests differently depending on client, possibly contacting a database or HTTP server for instructions, and bridging to a HTTP server or a pipe. +It supports responding to TFTP requests differently depending on client, possibly contacting +a database or HTTP server for instructions, and bridging to a HTTP server or a pipe. -Further, it comes with full support for RFC 7440, which can speed up TFTP downloads by an order of magnitude on low-latency gigabit networks, or even more on high-latency links. +# Status + +* Lua scripting + * [x] Direct file serving + * [x] HTTP proxy + * [x] Lua-computed raw data + * [x] Lua-computed error code + * [ ] Redis support + * [ ] IP Address calculation (CIDR support) + * [ ] Examination of FILE root + * [ ] Lua-driven HTTP requests +* TFTP features + * [ ] RFC 7440 - window size extension (can speed up TFTP downloads by >10x) + +# Installation diff --git a/doc/lua_api.md b/doc/lua_api.md index f8d94c6..260e769 100644 --- a/doc/lua_api.md +++ b/doc/lua_api.md @@ -1,5 +1,9 @@ # Overview +> :warning: +> Not all of this is presently implemented. See the [README](../README.md) for details +> on implementation status. + The lua script must have the following form: ```lua @@ -99,4 +103,24 @@ You can test the version using either `.version`, which returns either 4 or 6, o `.is_v4` and `.is_v6`. If you need the IPv6-mapped version of a v4 address, (i.e., `::ffff:0.0.0.0`), this is available through `.to_v6` -This is probably most useful in conjunction with `stftpd.Cidr` \ No newline at end of file +This is probably most useful in conjunction with `stftpd.Cidr` + +# Examples + +Probably the most useful starting point for your own script will be the following: + +```lua +return function(path, client, size) + if path:match("^/?^/?https?://") then + if path:sub(1,1) == "/" then + path = path:sub(2) + end + return resource.HTTP(path) + end + return resource.FILE(path) +end +``` + +This detects HTTP (or HTTPS) urls, and transparently proxies them to the appropriate HTTP server. +Some TFTP clients automatically insert a leading `/`, so this function strips it if given. +If the given path doesn't look like an HTTP URL, this example tries to handle it as a file. \ No newline at end of file diff --git a/src/engine.rs b/src/engine.rs new file mode 100644 index 0000000..7fe88c4 --- /dev/null +++ b/src/engine.rs @@ -0,0 +1,257 @@ +use std::future::Future; +use mlua::{FromLua, IntoLua, Lua, UserData, UserDataFields, Value}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::path::{Path, PathBuf}; +use anyhow::anyhow; +use async_tftp::packet::Error; +use futures::TryFutureExt; +use tokio::sync::oneshot::error::RecvError; + +pub struct LuaRunner { + +} + +#[derive(Clone, Debug)] +pub struct Client { + pub address: SocketAddr, + pub for_write: bool, +} + +impl UserData for Client { + fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("mode", |_lua, client| Ok(if client.for_write { "w" } else { "r" })); + fields.add_field_method_get("address", |lua, client| lua.create_any_userdata(client.address.ip())); + } +} + +#[derive(Debug)] +pub enum Resource { + Http(String), // Parameter is URL + File(String), // Parameter is content + Data(Vec), + Error(Error) +} + +impl Resource { +} + +impl UserData for Resource { + +} + +impl<'lua> FromLua<'lua> for Resource { + fn from_lua(value: Value<'lua>, _lua: &'lua Lua) -> mlua::Result { + value.as_userdata().ok_or(mlua::Error::UserDataTypeMismatch).and_then(|value| value.take()) + } +} + +struct EngineResolveReq { + r_chan: tokio::sync::oneshot::Sender>, + path: PathBuf, + client: Client, + size: Option, +} + +struct EngineResolveRsp { + resource: Resource, + client: Client, +} +enum EngineReq { + Resolve(EngineResolveReq), + ScriptChange(PathBuf, tokio::sync::oneshot::Sender>), +} + + +#[derive(Clone)] +pub struct Engine { + chan: tokio::sync::mpsc::Sender +} + +pub struct EngineImpl { + lua: &'static mlua::Lua, + resolver: std::sync::Arc>>, + chan: Option>, +} + +impl EngineImpl { + pub(crate) fn init(&mut self) -> anyhow::Result<()> { + let lua = &* self.lua; + lua.load_from_std_lib(mlua::StdLib::ALL)?; + + + lua.register_userdata_type::(|registry| { + registry.add_field_method_get("version", |_, ip| Ok(if ip.is_ipv4() { 4} else {6})); + })?; + + { + // prepare resource types... + let resources = lua.create_table()?; + resources.set("HTTP", lua.create_function(|_lua, url: String| Ok(Resource::Http(url)))?)?; + resources.set("FILE", lua.create_function(|_lua, path: String| Ok(Resource::File(path)))?)?; + resources.set("DATA", lua.create_function(|_lua, url: mlua::String| Ok(Resource::Data(url.as_bytes().to_vec())))?)?; + + let err_tbl = lua.create_table()?; + let err_mtbl = lua.create_table()?; + err_tbl.set_metatable(Some(err_mtbl.clone())); + + let err_fn = lua.create_function(|_, msg: String| Ok(Resource::Error(Error::Msg(msg))))?; + err_mtbl.set("__call", err_fn.clone())?; + err_tbl.set("FileNotFound", Resource::Error(Error::FileNotFound))?; + err_tbl.set("Unknown", Resource::Error(Error::UnknownError))?; + err_tbl.set("PermissionDenied", Resource::Error(Error::PermissionDenied))?; + err_tbl.set("DiskFull", Resource::Error(Error::DiskFull))?; + err_tbl.set("IllegalOperation", Resource::Error(Error::IllegalOperation))?; + err_tbl.set("FileAlreadyExists", Resource::Error(Error::FileAlreadyExists))?; + err_tbl.set("NoSuchUser", Resource::Error(Error::NoSuchUser))?; + err_tbl.set("Message", err_fn)?; + + resources.set("ERROR", err_tbl)?; + lua.globals().set("resource", resources)?; + lua.globals().set("state", lua.create_table()?)?; + + // + } + + // Default resolver + Ok(()) + } + + pub fn load_script(&self, path: PathBuf) -> impl Future> { + + let lua: &'static mlua::Lua = self.lua; + let resolver = self.resolver.clone(); + + async move { + log::info!("Loading new script from {:?}", path); + let chunk = lua.load::<'static, '_>(path.clone()); + + let script_fn = chunk.into_function()?; + + // Prepare a fake client to determine whether the script should just be run in full for each request + // or if it returns a function + let client = Client { + address: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), + for_write: false, + }; + + let result: mlua::Value<'static> = script_fn.call_async(("", client, None as Option)).await?; + if let Ok(result) = mlua::Function::from_lua(result.clone(), lua) { + log::info!("Successfully loaded standard script from {:?}", path); + *resolver.write().await = result; + } else if Resource::from_lua(result.clone(), lua).is_ok() { + // We must run the script, it seems + log::info!("Successfully loaded simple script from {:?}", path); + *resolver.write().await = script_fn; + } else { + anyhow::bail!("Invalid return type {}", result.type_name()); + } + Ok(()) + } + } + + pub async fn run(mut self) { + let mut chan = self.chan.take().unwrap(); + while let Some(req) = chan.recv().await { + match req { + EngineReq::Resolve(req) => { + let lua = self.lua; + let resolver = self.resolver.read().await.clone(); + tokio::task::spawn_local(Self::resolve(lua, resolver, req)); + } + EngineReq::ScriptChange(path, rsp) => { + let fut = self.load_script(path); + tokio::task::spawn_local(async move { + let result = fut.await; + if let Err(err) = &result { + log::warn!("Failed to load script: {}", err); + } + rsp.send(result).ok(); + }); + } + } + } + } + + async fn resolve(lua: &'static mlua::Lua, resolver: mlua::Function<'static>, req: EngineResolveReq) { + let rsp = Self::resolve_impl(lua, resolver, req.path, req.client, req.size).await; + req.r_chan.send(rsp).ok(); + } + async fn resolve_impl(lua: &'static mlua::Lua, resolver: mlua::Function<'static>, path: PathBuf, client: Client, size: Option) -> Result { + let lua_client = client.clone() + .into_lua(lua) + .map_err(|err| Error::Msg(err.to_string()))?; + let path = path.to_str().ok_or(Error::FileNotFound)?.to_owned(); + let (resource, size): (Resource, Option) = resolver + .call_async((path, lua_client.clone(), size)) + .await + .map_err(|err| Error::Msg(err.to_string()))?; + let client = lua_client.as_userdata() + .and_then(|ud| ud.take().ok()).unwrap_or(client); + + Ok(EngineResolveRsp{ + resource, + client, + }) + + } +} + +impl Engine { + pub fn new() -> anyhow::Result<(Self, EngineImpl)> { + let lua = Box::leak(Box::new(mlua::Lua::new())); + // Add stdlib + let handler_fn = lua.create_registry_value(0)?; + + let (req_snd, req_rcv) = tokio::sync::mpsc::channel(1); + + let engine = Self { + chan: req_snd + }; + + let resolver = lua.create_function(|_, path| Ok(Resource::File(path)))?; + + let engine_impl = EngineImpl { + lua, + resolver: async_lock::RwLock::new(resolver).into(), + chan: Some(req_rcv), + }; + + Ok((engine, engine_impl)) + + } + + pub fn resolve<'a>(&self, path: PathBuf, client: &'a mut Client, size: Option) -> impl Future> + Send + 'a { + let (o_snd, o_rcv) = tokio::sync::oneshot::channel(); + let req = EngineResolveReq{ + r_chan: o_snd, + path, + client: client.clone(), + size, + }; + let r_ch = self.chan.clone(); + async move { + if let Err(err) = r_ch.send(EngineReq::Resolve(req)).await { + return Err(Error::Msg(err.to_string())) + } + match o_rcv.await { + Ok(Ok(EngineResolveRsp{resource, client: r_client})) => { + *client = r_client; + Ok(resource) + } + Ok(Err(err)) => Err(err), + Err(err) => Err(Error::Msg(err.to_string())), + } + } + } + + pub fn load_script(&self, path: PathBuf) -> impl Future> { + let chan = self.chan.clone(); + async move { + let (o_snd, o_rcv) = tokio::sync::oneshot::channel(); + chan.send(EngineReq::ScriptChange(path, o_snd)) + .await.map_err(|_| anyhow!("Failed to send script load command"))?; + o_rcv.await??; + Ok(()) + } + } +} \ No newline at end of file diff --git a/src/handler.rs b/src/handler.rs index 1572784..31ac94d 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -8,142 +8,30 @@ use async_tftp::async_trait; use async_tftp::packet::Error; use async_tftp::server::handlers::{DirHandler, DirHandlerMode}; use futures::{AsyncRead, AsyncWrite, TryStreamExt}; -use mlua::{FromLua, Lua, UserData, UserDataFields, Value}; +use mlua::{FromLua, UserDataFields}; +use reqwest::Body; +use tokio_util::compat::TokioAsyncWriteCompatExt; +use crate::engine::{Client, Engine, Resource}; #[derive(Clone)] pub struct Handler { - lua: Arc>, - call_key: Arc, + engine: Engine, dir_handler: Arc, http: reqwest::Client, } -#[derive(Debug)] -pub enum Resource { - Http(String), // Parameter is URL - File(String), // Parameter is content - Data(Vec), - Error(Error) -} - -impl Resource { -} - -impl UserData for Resource { - -} - -impl<'lua> FromLua<'lua> for Resource { - fn from_lua(value: Value<'lua>, _lua: &'lua Lua) -> mlua::Result { - value.as_userdata().ok_or(mlua::Error::UserDataTypeMismatch).and_then(|value| value.take()) - } -} - -#[derive(Clone, Debug)] -struct Client { - address: SocketAddr, - for_write: bool, -} - -impl UserData for Client { - fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) { - fields.add_field_method_get("mode", |_lua, client| Ok(if client.for_write { "w" } else { "r" })); - fields.add_field_method_get("address", |lua, client| lua.create_any_userdata(client.address.ip())); - } -} - impl Handler { - pub fn new(srv_path: impl AsRef) -> Result { - let lua = mlua::Lua::new(); - + pub fn new(engine: Engine, srv_path: impl AsRef) -> Result { let http = reqwest::Client::builder() .user_agent(concat!("sftpd/", env!("CARGO_PKG_VERSION"))) .build()?; - lua.register_userdata_type::(|registry| { - registry.add_field_method_get("version", |_, ip| Ok(if ip.is_ipv4() { 4} else {6})); - })?; - - { - // prepare resource types... - let resources = lua.create_table()?; - resources.set("HTTP", lua.create_function(|_lua, url: String| Ok(Resource::Http(url)))?)?; - resources.set("FILE", lua.create_function(|_lua, path: String| Ok(Resource::File(path)))?)?; - resources.set("DATA", lua.create_function(|_lua, url: mlua::String| Ok(Resource::Data(url.as_bytes().to_vec())))?)?; - - let err_tbl = lua.create_table()?; - let err_mtbl = lua.create_table()?; - err_tbl.set_metatable(Some(err_mtbl.clone())); - - let err_fn = lua.create_function(|_, msg: String| Ok(Resource::Error(Error::Msg(msg))))?; - err_mtbl.set("__call", err_fn.clone())?; - err_tbl.set("FileNotFound", Resource::Error(Error::FileNotFound))?; - err_tbl.set("Unknown", Resource::Error(Error::UnknownError))?; - err_tbl.set("PermissionDenied", Resource::Error(Error::PermissionDenied))?; - err_tbl.set("DiskFull", Resource::Error(Error::DiskFull))?; - err_tbl.set("IllegalOperation", Resource::Error(Error::IllegalOperation))?; - err_tbl.set("FileAlreadyExists", Resource::Error(Error::FileAlreadyExists))?; - err_tbl.set("NoSuchUser", Resource::Error(Error::NoSuchUser))?; - err_tbl.set("Message", err_fn)?; - - resources.set("ERROR", err_tbl)?; - lua.globals().set("resource", resources)?; - lua.globals().set("state", lua.create_table()?)?; - - // - - } - let handler_fn = lua.create_registry_value(0)?; - Ok(Self { - lua: Arc::new(Mutex::new(lua)), - call_key: Arc::new(handler_fn), + engine, http, dir_handler: Arc::new(DirHandler::new(srv_path, DirHandlerMode::ReadWrite)?), }) } - - pub fn load_script(&mut self, data: PathBuf) -> impl Future> { - let lua = self.lua.clone(); - let key = self.call_key.clone(); - async move { - let lua = lua.lock_arc().await; - let chunk = lua.load(data); - let script_fn = chunk.into_function()?; - - // Prepare a fake client to determine whether the script should just be run in full for each request - // or if it returns a function - let client = Client { - address: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), - for_write: false, - }; - - let result: mlua::Value = script_fn.call_async(("", client, None as Option)).await?; - if result.is_function() { - lua.replace_registry_value(&*key, result)? - } else if Resource::from_lua(result, &*lua).is_ok() { - // We must run the script, it seems - lua.replace_registry_value(&*key, script_fn)? - } - - Ok(()) - } - } - - async fn call_handler_int(lua: Arc>, key: Arc, path: String, client: Client, size: Option) -> Result { - lua.lock().await - .registry_value::(&*key).map_err(|_| Error::UnknownError)? - .call_async((path, client, size)).await - .map_err(|_| Error::UnknownError) - - } - - async fn call_handler(&mut self, path: String, client: Client, size: Option) -> Result { - - let handle = tokio::task::spawn_local(Self::call_handler_int(self.lua.clone(), self.call_key.clone(), path, client, size)); - handle.await.map_err(|err| Error::Msg(err.to_string())) - .and_then(|x| x) - } } #[async_trait] @@ -152,11 +40,12 @@ impl async_tftp::server::Handler for Handler { type Writer = Box; async fn read_req_open(&mut self, client: &SocketAddr, path: &Path) -> Result<(Self::Reader, Option), Error> { - let lua_client = Client { + let mut lua_client = Client { address: client.clone(), for_write: false, }; - let resource: Resource = self.call_handler(path.to_str().ok_or(Error::FileNotFound)?.to_owned(), lua_client, None).await?; + // .to_str().ok_or(Error::FileNotFound)?.to_owned() + let resource: Resource = self.engine.resolve(path.to_owned(), &mut lua_client, None).await?; match resource { Resource::Http(url) => { @@ -184,33 +73,36 @@ impl async_tftp::server::Handler for Handler { } } - async fn write_req_open(&mut self, _client: &SocketAddr, _path: &Path, _size: Option) -> Result { - todo!(); + async fn write_req_open(&mut self, client: &SocketAddr, path: &Path, size: Option) -> Result { + let mut lua_client = Client { + address: client.clone(), + for_write: false, + }; + let resource: Resource = self.engine.resolve(path.to_owned(), &mut lua_client, size).await?; - #[cfg(ignore)] - { - let lua_client = Client { - address: client.clone(), - for_write: true, - }; + match resource { + Resource::Http(url) => { + // TODO: Add headers describing client + let mut req = self.http.post(url); + let (body, writer) = tokio::io::duplex(10240); - let resource: Resource = { - let mut lua = self.lua.lock_arc().await; - let lc = lua_client.clone().into_lua(&*lua).map_err(|_| Error::UnknownError)?; - let handle_fn: mlua::Function = lua.registry_value(&self.call_key).map_err(|_| Error::UnknownError)?; - - let result = handle_fn.call_async((path.to_str().ok_or(Error::FileNotFound)?, lc, size)).await.map_err(|_| Error::UnknownError)?; - - // let lc = lc.as_userdata().and_then(|ud| ud.take().ok()).unwrap_or(lua_client); - result - }; - - match resource { - Resource::Http(_) => { todo!() } - Resource::File(_) => { todo!() } - Resource::Data(_) => { todo!() } - Resource::Error(_) => { todo!() } + if let Some(len) = size { + req = req.header("Content-Length", len.to_string()); + } + let req = req + .body(Body::wrap_stream(tokio_util::io::ReaderStream::new(body))) + .send(); + tokio::task::spawn(req); + Ok(Box::new(writer.compat_write())) } + Resource::File(path) => { + let wtr = (*self.dir_handler).clone().write_req_open(client, Path::new(path.as_str()), size).await?; + Ok(Box::new(wtr)) + } + Resource::Data(_) => { + Err(Error::FileAlreadyExists) + } + Resource::Error(err) => { return Err(err) } } } } diff --git a/src/main.rs b/src/main.rs index 1b39e61..f183dac 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,6 +7,7 @@ use anyhow::anyhow; use structopt::StructOpt; mod handler; +mod engine; #[derive(StructOpt, Debug)] struct Options { @@ -34,15 +35,20 @@ struct Options { async fn main() -> anyhow::Result<()> { let opts = Options::from_args(); + let local_set = tokio::task::LocalSet::new(); + + let (engine, mut engine_impl) = engine::Engine::new()?; + engine_impl.init()?; + engine_impl.load_script(opts.script.clone()).await?; + + local_set.spawn_local(engine_impl.run()); let mut handler = handler::Handler::new( + engine.clone(), opts.serve.as_ref() .map(PathBuf::as_path) .unwrap_or_else(|| Path::new("")) )?; - let load_fut = tokio::task::spawn_local(handler.load_script(opts.script.clone())); - load_fut.await??; - let sock = if opts.systemd { let mut lfds = listenfd::ListenFd::from_env(); @@ -103,12 +109,22 @@ async fn main() -> anyhow::Result<()> { }; - let server = async_tftp::server::TftpServerBuilder::with_handler(handler.clone()) - .std_socket(sock)? - .build().await?; + let main_task = async move { + let server = async_tftp::server::TftpServerBuilder::with_handler(handler.clone()) + .std_socket(sock)? + .build().await?; - server.serve().await?; - println!("{opts:#?}"); - Ok(()) + server.serve().await?; + Ok::<(), anyhow::Error>(()) + }; + + tokio::select! { + _ = local_set => { + anyhow::bail!("Lua thread exited early") + } + ret = main_task => { + return ret + } + } }