Compare commits

...

4 Commits

8 changed files with 295 additions and 12 deletions

2
.gitmodules vendored
View File

@@ -1,3 +1,3 @@
[submodule "vendor/async-tftp-rs"]
path = vendor/async-tftp-rs
url = git@github.com:thequux/async-tftp-rs.git
url = https://github.com/thequux/async-tftp-rs.git

7
Cargo.lock generated
View File

@@ -635,6 +635,12 @@ version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "humantime"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]]
name = "hyper"
version = "0.14.27"
@@ -1289,6 +1295,7 @@ dependencies = [
"async-tftp",
"fern",
"futures",
"humantime",
"libc",
"listenfd",
"log",

View File

@@ -19,4 +19,5 @@ mlua = { version = "0.9.1", features = ["luau-jit", "vendored", "async", "send"]
reqwest = { version = "0.11.22", features = ["stream"] }
listenfd = "1.0.1"
libc = "0.2.150"
users = "0.11.0"
users = "0.11.0"
humantime = "2.1.0"

View File

@@ -5,7 +5,7 @@
#env.GREET = "devenv";
# https://devenv.sh/packages/
packages = [ pkgs.git pkgs.openssl ];
packages = [ pkgs.git pkgs.openssl pkgs.lua ];
# https://devenv.sh/scripts/
#scripts.hello.exec = "echo hello from $GREET";

View File

@@ -1,10 +1,14 @@
use std::fmt::{Display, Formatter};
use std::future::Future;
use mlua::{FromLua, IntoLua, Lua, UserData, UserDataFields, Value};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use mlua::{FromLua, FromLuaMulti, IntoLua, Lua, UserData, UserDataFields, UserDataMethods, Value};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::path::PathBuf;
use std::str::FromStr;
use anyhow::anyhow;
use async_tftp::packet::Error;
mod ipaddr;
#[derive(Clone, Debug)]
pub struct Client {
pub address: SocketAddr,
@@ -26,6 +30,17 @@ pub enum Resource {
Error(Error)
}
impl Display for Resource {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Resource::Http(url) => write!(f, "HTTP {url}"),
Resource::File(path) => write!(f, "FILE {path}"),
Resource::Data(data) => write!(f, "DATA ({} bytes)", data.len()),
Resource::Error(err) => write!(f, "ERR {err:?}")
}
}
}
impl Resource {
}
@@ -34,7 +49,7 @@ impl UserData for Resource {
}
impl<'lua> FromLua<'lua> for Resource {
fn from_lua(value: Value<'lua>, _lua: &'lua Lua) -> mlua::Result<Self> {
fn from_lua(value: Value<'lua>, _: &'lua Lua) -> mlua::Result<Self> {
value.as_userdata().ok_or(mlua::Error::UserDataTypeMismatch).and_then(|value| value.take())
}
}
@@ -67,16 +82,12 @@ pub struct EngineImpl {
chan: Option<tokio::sync::mpsc::Receiver<EngineReq>>,
}
impl EngineImpl {
pub(crate) fn init(&mut self) -> anyhow::Result<()> {
let lua = &* self.lua;
let lua = self.lua;
lua.load_from_std_lib(mlua::StdLib::ALL)?;
lua.register_userdata_type::<IpAddr>(|registry| {
registry.add_field_method_get("version", |_, ip| Ok(if ip.is_ipv4() { 4} else {6}));
})?;
{
// prepare resource types...
let resources = lua.create_table()?;
@@ -98,11 +109,16 @@ impl EngineImpl {
err_tbl.set("FileAlreadyExists", Resource::Error(Error::FileAlreadyExists))?;
err_tbl.set("NoSuchUser", Resource::Error(Error::NoSuchUser))?;
err_tbl.set("Message", err_fn)?;
err_tbl.set_readonly(true);
resources.set("ERROR", err_tbl)?;
resources.set_readonly(true);
lua.globals().set("resource", resources)?;
lua.globals().set("state", lua.create_table()?)?;
// Construct data types
ipaddr::register(lua)?;
//
}

238
src/engine/ipaddr.rs Normal file
View File

@@ -0,0 +1,238 @@
use std::fmt::Display;
use std::net::{IpAddr, Ipv6Addr};
use std::str::FromStr;
use std::fmt::Formatter;
use mlua::{FromLua, FromLuaMulti, Lua, UserData, UserDataFields, UserDataMethods, Value};
#[derive(Copy, Clone)]
struct Cidr {
addr: IpAddr,
prefix: u8,
}
impl Display for Cidr {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}/{}", self.addr, self.prefix)
}
}
impl Cidr {
fn to_v6(&self) -> Self {
match self.addr {
IpAddr::V4(ip) => Cidr {
addr: IpAddr::V6(ip.to_ipv6_mapped()),
prefix: self.prefix + 96,
},
IpAddr::V6(_) => *self,
}
}
}
#[repr(transparent)]
#[derive(Copy, Clone)]
struct IpAddrWrapper(IpAddr);
impl IpAddrWrapper {
fn octet_len(&self) -> usize {
match &self.0 {
IpAddr::V4(_) => 4,
IpAddr::V6(_) => 16,
}
}
fn to_v6(&self) -> Ipv6Addr {
match self.0 {
IpAddr::V4(v4) => v4.to_ipv6_mapped(),
IpAddr::V6(v6) => v6
}
}
}
impl UserData for IpAddrWrapper {
fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) {
fields.add_field_method_get("version", |_, ip| Ok(if ip.0.is_ipv4() { 4} else {6}));
fields.add_field_method_get("is_ipv4", |_, ip| Ok(ip.0.is_ipv4()));
fields.add_field_method_get("is_ipv6", |_, ip| Ok(ip.0.is_ipv6()));
fields.add_field_method_get("bytes", |lua, ip| match &ip.0 {
IpAddr::V4(v4) => lua.create_string(v4.octets()),
IpAddr::V6(v6) => lua.create_string(v6.octets()),
});
fields.add_meta_field("__name", "IpAddr");
}
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_method("to_v6", |_, ip, ()| match &ip.0 {
IpAddr::V4(v4) => Ok(IpAddrWrapper(IpAddr::V6(v4.to_ipv6_mapped()))),
IpAddr::V6(v6) => Ok(IpAddrWrapper(IpAddr::V6(*v6))),
});
methods.add_meta_method("__index", |_, ip, index: usize| match &ip.0 {
IpAddr::V4(ip) if 1 <= index && index <= 4 => Ok(ip.octets()[index-1]),
IpAddr::V6(ip) if 1 <= index && index <= 16 => Ok(ip.octets()[index-1]),
_ => Err(mlua::Error::runtime("Index out of range")),
});
methods.add_meta_method_mut("__newindex", |_, ip, (index, value): (usize, u8)| {
match &mut ip.0 {
IpAddr::V4(ip) if 1 <= index && index <= 4 => {
let mut octets = ip.octets();
octets[index-1] = value;
IpAddr::V4(octets.into());
},
IpAddr::V6(ip) if 1 <= index && index <= 16 => {
let mut octets = ip.octets();
octets[index-1] = value;
IpAddr::V6(octets.into());
},
_ => return Err(mlua::Error::runtime("Index out of range")),
};
Ok(())
});
methods.add_meta_method("__tostring", |_, ip, ()| Ok(ip.0.to_string()));
methods.add_meta_function("__call", |lua, args: mlua::MultiValue| Ok(IpAddrWrapper({
eprintln!("IpAddr __call Received {} args", args.len());
if args.len() == 4 {
// IPV4 direct bytes
let (a,b,c,d) = <(u8,u8,u8,u8)>::from_lua_args(args, 0, Some("sftpd.IpAddr"), lua)?;
IpAddr::V4([a,b,c,d].into())
} else if args.len() == 16 {
let mut octets = [0u8;16];
for i in 0..16 {
octets[i] = u8::from_lua(args.get(i).unwrap().clone(), lua)?;
}
IpAddr::V6(octets.into())
} else if args.len() == 1 {
return IpAddrWrapper::from_lua(args[0].clone(), lua);
} else {
return Err(mlua::Error::runtime("Invalid arguments to stfptd.IpAddr"))
}
})));
methods.add_meta_method("__eq", |_, me, other: IpAddrWrapper| Ok(me.to_v6() == other.to_v6()))
}
}
impl<'lua> FromLua<'lua> for IpAddrWrapper {
fn from_lua(value: Value<'lua>, _: &'lua Lua) -> mlua::Result<Self> {
return if let Some(ud) = value.as_userdata() {
Ok(ud.borrow::<IpAddrWrapper>()?.clone())
} else if let Some(s) = value.as_str() {
IpAddr::from_str(s)
.map(IpAddrWrapper)
.map_err(mlua::Error::external)
} else {
Err(mlua::Error::FromLuaConversionError {
from: value.type_name(),
to: "IpAddr",
message: None,
})
}
}
}
impl UserData for Cidr {
fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) {
fields.add_field_method_get("addr", |_lua, cidr| Ok(IpAddrWrapper(cidr.addr)));
fields.add_field_method_set("addr", |_lua, cidr, value: IpAddrWrapper| {
if cidr.addr.is_ipv4() ^ value.0.is_ipv4() {
return Err(mlua::Error::runtime(format!(
"Cannot assign v{} addr to v{} CIDR",
if value.0.is_ipv4() { 4 } else { 6 },
if cidr.addr.is_ipv4() { 4 } else { 6 },
)));
}
cidr.addr = value.0;
Ok(())
});
fields.add_field_method_get("prefix", |_lua, cidr| Ok(cidr.prefix));
fields.add_field_method_set("prefix", |_lua, cidr, value: u8| {
if value >= if cidr.addr.is_ipv4() { 32 } else { 128 } {
return Err(mlua::Error::runtime("Invalid prefix length"));
}
cidr.prefix = value;
Ok(())
});
}
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_method("contains", |_, cidr, addr: IpAddrWrapper| {
let addr = u128::from_be_bytes(addr.to_v6().octets());
let cidr = cidr.to_v6();
if cidr.prefix == 0 {
// Global network
return Ok(true)
}
let mask = (!0u128) << (128 - cidr.prefix);
let network = u128::from_be_bytes(IpAddrWrapper(cidr.addr).to_v6().octets());
Ok(addr & mask == network & mask)
});
methods.add_meta_method("__tostring", |_, cidr, ()| Ok(cidr.to_string()));
methods.add_meta_method("__eq", |_, cidr, other: Cidr| {
let cidr = cidr.to_v6();
let other = other.to_v6();
Ok(cidr.prefix == other.prefix && cidr.addr == other.addr)
});
methods.add_meta_function("__call", |lua, args: mlua::MultiValue| {
if args.len() == 0 {
Ok(Cidr{addr: IpAddr::V6(Ipv6Addr::UNSPECIFIED), prefix: 0})
} else if args.len() == 1 {
Cidr::from_lua_args(args, 0, Some("stftpd.Cidr"), lua)
} else {
let (IpAddrWrapper(addr), prefix) = <(IpAddrWrapper, u8)>::from_lua_args(args, 0, Some("stftpd.Cidr"), lua)?;
let max_prefix = if addr.is_ipv4() { 32 } else { 64 };
if prefix > max_prefix {
return Err(mlua::Error::runtime("Invalid prefix"));
}
Ok(Cidr{addr, prefix})
}
});
}
}
impl<'lua> FromLua<'lua> for Cidr {
fn from_lua(value: Value<'lua>, _: &'lua Lua) -> mlua::Result<Self> {
return if let Some(ud) = value.as_userdata() {
Ok(ud.borrow::<Cidr>()?.clone())
} else if let Some(s) = value.as_str() {
let (addr, prefix) = s.split_once("/")
.map(|(a,p)| (a,Some(p)))
.unwrap_or((s, None));
let addr = IpAddr::from_str(addr)?;
let prefix = prefix.map(u8::from_str)
.transpose()
.map_err(mlua::Error::external)?;
let max_prefix = if addr.is_ipv4() { 32 } else { 128 };
let prefix = prefix.unwrap_or(max_prefix);
if prefix > max_prefix {
return Err(mlua::Error::runtime("CIDR prefix to long"))
}
Ok(Cidr {
addr, prefix
})
} else {
Err(mlua::Error::FromLuaConversionError {
from: value.type_name(),
to: "Cidr",
message: None,
})
}
}
}
pub fn register(lua: &'static Lua) -> anyhow::Result<()> {
let globals = lua.globals();
let stftpd: Value = globals.get("stftpd")?;
let stftpd = if stftpd.is_nil() {
let newtab = lua.create_table()?;
globals.set("stftpd", newtab.clone())?;
newtab
} else {
stftpd.as_table().unwrap().clone()
};
stftpd.set("Cidr", lua.create_proxy::<Cidr>()?)?;
stftpd.set("IpAddr", lua.create_proxy::<IpAddrWrapper>()?)?;
Ok(())
}

View File

@@ -6,6 +6,7 @@ use async_tftp::async_trait;
use async_tftp::packet::Error;
use async_tftp::server::handlers::{DirHandler, DirHandlerMode};
use futures::{AsyncRead, AsyncWrite, TryStreamExt};
use log::info;
use reqwest::{Body, StatusCode};
use tokio_util::compat::TokioAsyncWriteCompatExt;
use crate::engine::{Client, Engine, Resource};
@@ -44,6 +45,8 @@ impl async_tftp::server::Handler for Handler {
// .to_str().ok_or(Error::FileNotFound)?.to_owned()
let resource: Resource = self.engine.resolve(path.to_owned(), &mut lua_client, None).await?;
info!("GET {path:?} from {client:?} -> {resource}");
match resource {
Resource::Http(url) => {
// TODO: Add headers describing client
@@ -83,6 +86,7 @@ impl async_tftp::server::Handler for Handler {
for_write: false,
};
let resource: Resource = self.engine.resolve(path.to_owned(), &mut lua_client, size).await?;
info!("PUT {path:?} from {client:?} -> {resource}");
match resource {
Resource::Http(url) => {

View File

@@ -3,6 +3,7 @@ use std::ffi::c_void;
use std::net;
use std::net::{Ipv4Addr, Ipv6Addr};
use std::path::{Path, PathBuf};
use std::time::SystemTime;
use anyhow::anyhow;
use structopt::StructOpt;
use tokio::signal::unix::SignalKind;
@@ -36,6 +37,22 @@ struct Options {
async fn main() -> anyhow::Result<()> {
let opts = Options::from_args();
// Configure logging
fern::Dispatch::new()
.format(|out, message, record| {
out.finish(format_args!(
"[{} {} {}] {}",
humantime::format_rfc3339_seconds(SystemTime::now()),
record.level(),
record.target(),
message
))
})
.level(log::LevelFilter::Debug)
.chain(std::io::stdout())
.chain(fern::log_file("output.log")?)
.apply()?;
let group = opts.group.map(|name| {
if let Ok(gid) = libc::gid_t::from_str_radix(name.as_str(), 10) {
Ok(gid)