Much progress
This commit is contained in:
215
src/handler.rs
Normal file
215
src/handler.rs
Normal file
@@ -0,0 +1,215 @@
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::future::Future;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use async_lock::Mutex;
|
||||
use async_tftp::async_trait;
|
||||
use async_tftp::packet::Error;
|
||||
use async_tftp::server::handlers::{DirHandler, DirHandlerMode};
|
||||
use futures::{AsyncRead, AsyncWrite, TryStreamExt};
|
||||
use mlua::{FromLua, Lua, UserData, UserDataFields, Value};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Handler {
|
||||
lua: Arc<async_lock::Mutex<mlua::Lua>>,
|
||||
call_key: Arc<mlua::RegistryKey>,
|
||||
dir_handler: Arc<DirHandler>,
|
||||
http: reqwest::Client,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Resource {
|
||||
Http(String), // Parameter is URL
|
||||
File(String), // Parameter is content
|
||||
Data(Vec<u8>),
|
||||
Error(Error)
|
||||
}
|
||||
|
||||
impl Resource {
|
||||
}
|
||||
|
||||
impl UserData for Resource {
|
||||
|
||||
}
|
||||
|
||||
impl<'lua> FromLua<'lua> for Resource {
|
||||
fn from_lua(value: Value<'lua>, _lua: &'lua Lua) -> mlua::Result<Self> {
|
||||
value.as_userdata().ok_or(mlua::Error::UserDataTypeMismatch).and_then(|value| value.take())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct Client {
|
||||
address: SocketAddr,
|
||||
for_write: bool,
|
||||
}
|
||||
|
||||
impl UserData for Client {
|
||||
fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) {
|
||||
fields.add_field_method_get("mode", |_lua, client| Ok(if client.for_write { "w" } else { "r" }));
|
||||
fields.add_field_method_get("address", |lua, client| lua.create_any_userdata(client.address.ip()));
|
||||
}
|
||||
}
|
||||
|
||||
impl Handler {
|
||||
pub fn new(srv_path: impl AsRef<Path>) -> Result<Self, anyhow::Error> {
|
||||
let lua = mlua::Lua::new();
|
||||
|
||||
let http = reqwest::Client::builder()
|
||||
.user_agent(concat!("sftpd/", env!("CARGO_PKG_VERSION")))
|
||||
.build()?;
|
||||
|
||||
lua.register_userdata_type::<IpAddr>(|registry| {
|
||||
registry.add_field_method_get("version", |_, ip| Ok(if ip.is_ipv4() { 4} else {6}));
|
||||
})?;
|
||||
|
||||
{
|
||||
// prepare resource types...
|
||||
let resources = lua.create_table()?;
|
||||
resources.set("HTTP", lua.create_function(|_lua, url: String| Ok(Resource::Http(url)))?)?;
|
||||
resources.set("FILE", lua.create_function(|_lua, path: String| Ok(Resource::File(path)))?)?;
|
||||
resources.set("DATA", lua.create_function(|_lua, url: mlua::String| Ok(Resource::Data(url.as_bytes().to_vec())))?)?;
|
||||
|
||||
let err_tbl = lua.create_table()?;
|
||||
let err_mtbl = lua.create_table()?;
|
||||
err_tbl.set_metatable(Some(err_mtbl.clone()));
|
||||
|
||||
let err_fn = lua.create_function(|_, msg: String| Ok(Resource::Error(Error::Msg(msg))))?;
|
||||
err_mtbl.set("__call", err_fn.clone())?;
|
||||
err_tbl.set("FileNotFound", Resource::Error(Error::FileNotFound))?;
|
||||
err_tbl.set("Unknown", Resource::Error(Error::UnknownError))?;
|
||||
err_tbl.set("PermissionDenied", Resource::Error(Error::PermissionDenied))?;
|
||||
err_tbl.set("DiskFull", Resource::Error(Error::DiskFull))?;
|
||||
err_tbl.set("IllegalOperation", Resource::Error(Error::IllegalOperation))?;
|
||||
err_tbl.set("FileAlreadyExists", Resource::Error(Error::FileAlreadyExists))?;
|
||||
err_tbl.set("NoSuchUser", Resource::Error(Error::NoSuchUser))?;
|
||||
err_tbl.set("Message", err_fn)?;
|
||||
|
||||
resources.set("ERROR", err_tbl)?;
|
||||
lua.globals().set("resource", resources)?;
|
||||
lua.globals().set("state", lua.create_table()?)?;
|
||||
|
||||
//
|
||||
|
||||
}
|
||||
let handler_fn = lua.create_registry_value(0)?;
|
||||
|
||||
Ok(Self {
|
||||
lua: Arc::new(Mutex::new(lua)),
|
||||
call_key: Arc::new(handler_fn),
|
||||
http,
|
||||
dir_handler: Arc::new(DirHandler::new(srv_path, DirHandlerMode::ReadWrite)?),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn load_script(&mut self, data: PathBuf) -> impl Future<Output=anyhow::Result<()>> {
|
||||
let lua = self.lua.clone();
|
||||
let key = self.call_key.clone();
|
||||
async move {
|
||||
let lua = lua.lock_arc().await;
|
||||
let chunk = lua.load(data);
|
||||
let script_fn = chunk.into_function()?;
|
||||
|
||||
// Prepare a fake client to determine whether the script should just be run in full for each request
|
||||
// or if it returns a function
|
||||
let client = Client {
|
||||
address: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
|
||||
for_write: false,
|
||||
};
|
||||
|
||||
let result: mlua::Value = script_fn.call_async(("", client, None as Option<u64>)).await?;
|
||||
if result.is_function() {
|
||||
lua.replace_registry_value(&*key, result)?
|
||||
} else if Resource::from_lua(result, &*lua).is_ok() {
|
||||
// We must run the script, it seems
|
||||
lua.replace_registry_value(&*key, script_fn)?
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn call_handler_int(lua: Arc<Mutex<mlua::Lua>>, key: Arc<mlua::RegistryKey>, path: String, client: Client, size: Option<u64>) -> Result<Resource, Error> {
|
||||
lua.lock().await
|
||||
.registry_value::<mlua::Function>(&*key).map_err(|_| Error::UnknownError)?
|
||||
.call_async((path, client, size)).await
|
||||
.map_err(|_| Error::UnknownError)
|
||||
|
||||
}
|
||||
|
||||
async fn call_handler(&mut self, path: String, client: Client, size: Option<u64>) -> Result<Resource, Error> {
|
||||
|
||||
let handle = tokio::task::spawn_local(Self::call_handler_int(self.lua.clone(), self.call_key.clone(), path, client, size));
|
||||
handle.await.map_err(|err| Error::Msg(err.to_string()))
|
||||
.and_then(|x| x)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl async_tftp::server::Handler for Handler {
|
||||
type Reader = Box<dyn AsyncRead + Send + Unpin + 'static>;
|
||||
type Writer = Box<dyn AsyncWrite + Send + Unpin + 'static>;
|
||||
|
||||
async fn read_req_open(&mut self, client: &SocketAddr, path: &Path) -> Result<(Self::Reader, Option<u64>), Error> {
|
||||
let lua_client = Client {
|
||||
address: client.clone(),
|
||||
for_write: false,
|
||||
};
|
||||
let resource: Resource = self.call_handler(path.to_str().ok_or(Error::FileNotFound)?.to_owned(), lua_client, None).await?;
|
||||
|
||||
match resource {
|
||||
Resource::Http(url) => {
|
||||
// TODO: Add headers describing client
|
||||
let req = self.http.get(url).send().await.map_err(|err| Error::Msg(err.to_string()))?;
|
||||
let size = req.content_length();
|
||||
let stream = req.bytes_stream()
|
||||
.map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
|
||||
.into_async_read();
|
||||
|
||||
Ok((Box::new(stream), size))
|
||||
}
|
||||
Resource::File(path) => {
|
||||
let (rdr, size) = (*self.dir_handler).clone().read_req_open(client, Path::new(path.as_str())).await?;
|
||||
Ok((Box::new(rdr), size))
|
||||
}
|
||||
Resource::Data(data) => {
|
||||
let len = data.len() as u64;
|
||||
Ok((
|
||||
Box::new(futures::io::Cursor::new(data)),
|
||||
Some(len),
|
||||
))
|
||||
}
|
||||
Resource::Error(err) => { return Err(err) }
|
||||
}
|
||||
}
|
||||
|
||||
async fn write_req_open(&mut self, _client: &SocketAddr, _path: &Path, _size: Option<u64>) -> Result<Self::Writer, Error> {
|
||||
todo!();
|
||||
|
||||
#[cfg(ignore)]
|
||||
{
|
||||
let lua_client = Client {
|
||||
address: client.clone(),
|
||||
for_write: true,
|
||||
};
|
||||
|
||||
let resource: Resource = {
|
||||
let mut lua = self.lua.lock_arc().await;
|
||||
let lc = lua_client.clone().into_lua(&*lua).map_err(|_| Error::UnknownError)?;
|
||||
let handle_fn: mlua::Function = lua.registry_value(&self.call_key).map_err(|_| Error::UnknownError)?;
|
||||
|
||||
let result = handle_fn.call_async((path.to_str().ok_or(Error::FileNotFound)?, lc, size)).await.map_err(|_| Error::UnknownError)?;
|
||||
|
||||
// let lc = lc.as_userdata().and_then(|ud| ud.take().ok()).unwrap_or(lua_client);
|
||||
result
|
||||
};
|
||||
|
||||
match resource {
|
||||
Resource::Http(_) => { todo!() }
|
||||
Resource::File(_) => { todo!() }
|
||||
Resource::Data(_) => { todo!() }
|
||||
Resource::Error(_) => { todo!() }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
114
src/main.rs
114
src/main.rs
@@ -1,3 +1,113 @@
|
||||
fn main() {
|
||||
println!("Hello, world!");
|
||||
use std::ffi::c_void;
|
||||
use std::net;
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
use std::path::{Path, PathBuf};
|
||||
use anyhow::anyhow;
|
||||
use structopt::StructOpt;
|
||||
|
||||
mod handler;
|
||||
|
||||
#[derive(StructOpt, Debug)]
|
||||
struct Options {
|
||||
#[structopt(short="s", long="script", env="STFTPD_SCRIPT")]
|
||||
/// The lua script to determine how to handle requests
|
||||
script: PathBuf,
|
||||
/// Systemd socket activated mode. Can also be used for inetd activation
|
||||
#[structopt(long)]
|
||||
systemd: bool,
|
||||
/// The address and port to listen on
|
||||
#[structopt(short="l", env="STFTPD_LISTEN", default_value=":69")]
|
||||
listen: String,
|
||||
#[structopt(short="u")]
|
||||
/// User to drop privileges to
|
||||
user: Option<String>,
|
||||
#[structopt(short="g")]
|
||||
/// User to drop privileges to
|
||||
group: Option<String>,
|
||||
#[structopt(short="d")]
|
||||
/// Directory to serve files from
|
||||
serve: Option<PathBuf>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let opts = Options::from_args();
|
||||
|
||||
|
||||
let mut handler = handler::Handler::new(
|
||||
opts.serve.as_ref()
|
||||
.map(PathBuf::as_path)
|
||||
.unwrap_or_else(|| Path::new(""))
|
||||
)?;
|
||||
let load_fut = tokio::task::spawn_local(handler.load_script(opts.script.clone()));
|
||||
load_fut.await??;
|
||||
|
||||
|
||||
let sock = if opts.systemd {
|
||||
let mut lfds = listenfd::ListenFd::from_env();
|
||||
if lfds.len() > 0 {
|
||||
lfds.take_udp_socket(0)?.ok_or(anyhow!("Failed to receive socket from systemd"))?
|
||||
} else {
|
||||
// inetd activation
|
||||
let sock_fd = 0;
|
||||
let sock = unsafe {
|
||||
// validate the socket
|
||||
let mut sockaddr : libc::sockaddr = std::mem::zeroed();
|
||||
let mut sockaddr_sz = std::mem::size_of_val(&sockaddr) as libc::socklen_t;
|
||||
let mut ty : libc::c_int = 0;
|
||||
let mut ty_sz = std::mem::size_of_val(&ty) as libc::socklen_t;
|
||||
let ret = libc::getsockname(sock_fd, &mut sockaddr, &mut sockaddr_sz);
|
||||
if ret != 0 {
|
||||
return Err(std::io::Error::last_os_error().into())
|
||||
}
|
||||
let ret = libc::getsockopt(
|
||||
sock_fd,
|
||||
libc::SOL_SOCKET,
|
||||
libc::SO_TYPE,
|
||||
&mut ty as *mut libc::c_int as *mut c_void,
|
||||
&mut ty_sz
|
||||
);
|
||||
if ret != 0 {
|
||||
return Err(std::io::Error::last_os_error().into())
|
||||
}
|
||||
if sockaddr.sa_family as libc::c_int != libc::AF_INET && sockaddr.sa_family as libc::c_int != libc::AF_INET6 || ty != libc::SOCK_DGRAM {
|
||||
return Err(anyhow!("Can only listen on inet or inet6 UDP sockets"))
|
||||
}
|
||||
let owned = std::os::fd::BorrowedFd::borrow_raw(sock_fd).try_clone_to_owned()?;
|
||||
// Putz around with stdin and stdout so that we don't accidentally write to the socket.
|
||||
let mut pipe_fds = [0 as libc::c_int; 2];
|
||||
if libc::pipe(&mut pipe_fds[0]) < 0 {
|
||||
return Err(std::io::Error::last_os_error().into())
|
||||
}
|
||||
if libc::close(pipe_fds[1]) < 0 {
|
||||
return Err(std::io::Error::last_os_error().into())
|
||||
}
|
||||
if libc::dup2(pipe_fds[0], 0) < 0 || libc::dup2(2, 1) < 0 {
|
||||
return Err(std::io::Error::last_os_error().into())
|
||||
}
|
||||
net::UdpSocket::from(owned)
|
||||
};
|
||||
sock
|
||||
}
|
||||
} else {
|
||||
let (host, port) = opts.listen.split_once(':').ok_or(anyhow!("Invalid listen address"))?;
|
||||
let port = u16::from_str_radix(port, 10).map_err(|_| anyhow!("Invalid listen address"))?;
|
||||
if host.is_empty() {
|
||||
net::UdpSocket::bind((Ipv6Addr::UNSPECIFIED, port)).or_else(
|
||||
|_| net::UdpSocket::bind((Ipv4Addr::UNSPECIFIED, port))
|
||||
)?
|
||||
} else {
|
||||
net::UdpSocket::bind((host, port))?
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
let server = async_tftp::server::TftpServerBuilder::with_handler(handler.clone())
|
||||
.std_socket(sock)?
|
||||
.build().await?;
|
||||
|
||||
|
||||
server.serve().await?;
|
||||
println!("{opts:#?}");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user