Added support for IP address computation
This commit is contained in:
249
src/engine.rs
249
src/engine.rs
@@ -1,7 +1,9 @@
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::future::Future;
|
||||
use mlua::{FromLua, IntoLua, Lua, UserData, UserDataFields, Value};
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use mlua::{FromLua, FromLuaMulti, IntoLua, Lua, UserData, UserDataFields, UserDataMethods, Value};
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use std::path::PathBuf;
|
||||
use std::str::FromStr;
|
||||
use anyhow::anyhow;
|
||||
use async_tftp::packet::Error;
|
||||
|
||||
@@ -26,6 +28,17 @@ pub enum Resource {
|
||||
Error(Error)
|
||||
}
|
||||
|
||||
impl Display for Resource {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Resource::Http(url) => write!(f, "HTTP {url}"),
|
||||
Resource::File(path) => write!(f, "FILE {path}"),
|
||||
Resource::Data(data) => write!(f, "DATA ({} bytes)", data.len()),
|
||||
Resource::Error(err) => write!(f, "ERR {err:?}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Resource {
|
||||
}
|
||||
|
||||
@@ -34,7 +47,7 @@ impl UserData for Resource {
|
||||
}
|
||||
|
||||
impl<'lua> FromLua<'lua> for Resource {
|
||||
fn from_lua(value: Value<'lua>, _lua: &'lua Lua) -> mlua::Result<Self> {
|
||||
fn from_lua(value: Value<'lua>, _: &'lua Lua) -> mlua::Result<Self> {
|
||||
value.as_userdata().ok_or(mlua::Error::UserDataTypeMismatch).and_then(|value| value.take())
|
||||
}
|
||||
}
|
||||
@@ -67,16 +80,228 @@ pub struct EngineImpl {
|
||||
chan: Option<tokio::sync::mpsc::Receiver<EngineReq>>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
struct Cidr {
|
||||
addr: IpAddr,
|
||||
prefix: u8,
|
||||
}
|
||||
|
||||
impl Display for Cidr {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}/{}", self.addr, self.prefix)
|
||||
}
|
||||
}
|
||||
|
||||
impl Cidr {
|
||||
fn to_v6(&self) -> Self {
|
||||
match self.addr {
|
||||
IpAddr::V4(ip) => Cidr {
|
||||
addr: IpAddr::V6(ip.to_ipv6_mapped()),
|
||||
prefix: self.prefix + 96,
|
||||
},
|
||||
IpAddr::V6(_) => *self,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(transparent)]
|
||||
#[derive(Copy, Clone)]
|
||||
struct IpAddrWrapper(IpAddr);
|
||||
|
||||
impl IpAddrWrapper {
|
||||
fn octet_len(&self) -> usize {
|
||||
match &self.0 {
|
||||
IpAddr::V4(_) => 4,
|
||||
IpAddr::V6(_) => 16,
|
||||
}
|
||||
}
|
||||
|
||||
fn to_v6(&self) -> Ipv6Addr {
|
||||
match self.0 {
|
||||
IpAddr::V4(v4) => v4.to_ipv6_mapped(),
|
||||
IpAddr::V6(v6) => v6
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserData for IpAddrWrapper {
|
||||
fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) {
|
||||
fields.add_field_method_get("version", |_, ip| Ok(if ip.0.is_ipv4() { 4} else {6}));
|
||||
fields.add_field_method_get("is_ipv4", |_, ip| Ok(ip.0.is_ipv4()));
|
||||
fields.add_field_method_get("is_ipv6", |_, ip| Ok(ip.0.is_ipv6()));
|
||||
fields.add_field_method_get("bytes", |lua, ip| match &ip.0 {
|
||||
IpAddr::V4(v4) => lua.create_string(v4.octets()),
|
||||
IpAddr::V6(v6) => lua.create_string(v6.octets()),
|
||||
});
|
||||
fields.add_meta_field("__name", "IpAddr");
|
||||
}
|
||||
|
||||
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
|
||||
methods.add_method("to_v6", |_, ip, ()| match &ip.0 {
|
||||
IpAddr::V4(v4) => Ok(IpAddrWrapper(IpAddr::V6(v4.to_ipv6_mapped()))),
|
||||
IpAddr::V6(v6) => Ok(IpAddrWrapper(IpAddr::V6(*v6))),
|
||||
});
|
||||
methods.add_meta_method("__index", |_, ip, index: usize| match &ip.0 {
|
||||
IpAddr::V4(ip) if 1 <= index && index <= 4 => Ok(ip.octets()[index-1]),
|
||||
IpAddr::V6(ip) if 1 <= index && index <= 16 => Ok(ip.octets()[index-1]),
|
||||
_ => Err(mlua::Error::runtime("Index out of range")),
|
||||
});
|
||||
methods.add_meta_method_mut("__newindex", |_, ip, (index, value): (usize, u8)| {
|
||||
match &mut ip.0 {
|
||||
IpAddr::V4(ip) if 1 <= index && index <= 4 => {
|
||||
let mut octets = ip.octets();
|
||||
octets[index-1] = value;
|
||||
IpAddr::V4(octets.into());
|
||||
},
|
||||
IpAddr::V6(ip) if 1 <= index && index <= 16 => {
|
||||
let mut octets = ip.octets();
|
||||
octets[index-1] = value;
|
||||
IpAddr::V6(octets.into());
|
||||
},
|
||||
_ => return Err(mlua::Error::runtime("Index out of range")),
|
||||
};
|
||||
Ok(())
|
||||
});
|
||||
methods.add_meta_method("__tostring", |_, ip, ()| Ok(ip.0.to_string()));
|
||||
methods.add_meta_function("__call", |lua, args: mlua::MultiValue| Ok(IpAddrWrapper({
|
||||
eprintln!("IpAddr __call Received {} args", args.len());
|
||||
if args.len() == 4 {
|
||||
// IPV4 direct bytes
|
||||
let (a,b,c,d) = <(u8,u8,u8,u8)>::from_lua_args(args, 0, Some("sftpd.IpAddr"), lua)?;
|
||||
IpAddr::V4([a,b,c,d].into())
|
||||
} else if args.len() == 16 {
|
||||
let mut octets = [0u8;16];
|
||||
for i in 0..16 {
|
||||
octets[i] = u8::from_lua(args.get(i).unwrap().clone(), lua)?;
|
||||
}
|
||||
IpAddr::V6(octets.into())
|
||||
} else if args.len() == 1 {
|
||||
return IpAddrWrapper::from_lua(args[0].clone(), lua);
|
||||
} else {
|
||||
return Err(mlua::Error::runtime("Invalid arguments to stfptd.IpAddr"))
|
||||
}
|
||||
})));
|
||||
methods.add_meta_method("__eq", |_, me, other: IpAddrWrapper| Ok(me.to_v6() == other.to_v6()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'lua> FromLua<'lua> for IpAddrWrapper {
|
||||
fn from_lua(value: Value<'lua>, _: &'lua Lua) -> mlua::Result<Self> {
|
||||
return if let Some(ud) = value.as_userdata() {
|
||||
Ok(ud.borrow::<IpAddrWrapper>()?.clone())
|
||||
} else if let Some(s) = value.as_str() {
|
||||
IpAddr::from_str(s)
|
||||
.map(IpAddrWrapper)
|
||||
.map_err(mlua::Error::external)
|
||||
} else {
|
||||
Err(mlua::Error::FromLuaConversionError {
|
||||
from: value.type_name(),
|
||||
to: "IpAddr",
|
||||
message: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserData for Cidr {
|
||||
fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) {
|
||||
fields.add_field_method_get("addr", |_lua, cidr| Ok(IpAddrWrapper(cidr.addr)));
|
||||
fields.add_field_method_set("addr", |_lua, cidr, value: IpAddrWrapper| {
|
||||
if cidr.addr.is_ipv4() ^ value.0.is_ipv4() {
|
||||
return Err(mlua::Error::runtime(format!(
|
||||
"Cannot assign v{} addr to v{} CIDR",
|
||||
if value.0.is_ipv4() { 4 } else { 6 },
|
||||
if cidr.addr.is_ipv4() { 4 } else { 6 },
|
||||
)));
|
||||
}
|
||||
cidr.addr = value.0;
|
||||
Ok(())
|
||||
});
|
||||
fields.add_field_method_get("prefix", |_lua, cidr| Ok(cidr.prefix));
|
||||
fields.add_field_method_set("prefix", |_lua, cidr, value: u8| {
|
||||
if value >= if cidr.addr.is_ipv4() { 32 } else { 128 } {
|
||||
return Err(mlua::Error::runtime("Invalid prefix length"));
|
||||
}
|
||||
cidr.prefix = value;
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
|
||||
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
|
||||
methods.add_method("contains", |_, cidr, addr: IpAddrWrapper| {
|
||||
let addr = u128::from_be_bytes(addr.to_v6().octets());
|
||||
let cidr = cidr.to_v6();
|
||||
if cidr.prefix == 0 {
|
||||
// Global network
|
||||
return Ok(true)
|
||||
}
|
||||
let mask = (!0u128) << (128 - cidr.prefix);
|
||||
let network = u128::from_be_bytes(IpAddrWrapper(cidr.addr).to_v6().octets());
|
||||
|
||||
Ok(addr & mask == network & mask)
|
||||
});
|
||||
|
||||
methods.add_meta_method("__tostring", |_, cidr, ()| Ok(cidr.to_string()));
|
||||
methods.add_meta_method("__eq", |_, cidr, other: Cidr| {
|
||||
let cidr = cidr.to_v6();
|
||||
let other = other.to_v6();
|
||||
Ok(cidr.prefix == other.prefix && cidr.addr == other.addr)
|
||||
});
|
||||
|
||||
methods.add_meta_function("__call", |lua, args: mlua::MultiValue| {
|
||||
if args.len() == 0 {
|
||||
Ok(Cidr{addr: IpAddr::V6(Ipv6Addr::UNSPECIFIED), prefix: 0})
|
||||
} else if args.len() == 1 {
|
||||
Cidr::from_lua_args(args, 0, Some("stftpd.Cidr"), lua)
|
||||
} else {
|
||||
let (IpAddrWrapper(addr), prefix) = <(IpAddrWrapper, u8)>::from_lua_args(args, 0, Some("stftpd.Cidr"), lua)?;
|
||||
let max_prefix = if addr.is_ipv4() { 32 } else { 64 };
|
||||
if prefix > max_prefix {
|
||||
return Err(mlua::Error::runtime("Invalid prefix"));
|
||||
}
|
||||
Ok(Cidr{addr, prefix})
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl<'lua> FromLua<'lua> for Cidr {
|
||||
fn from_lua(value: Value<'lua>, _: &'lua Lua) -> mlua::Result<Self> {
|
||||
return if let Some(ud) = value.as_userdata() {
|
||||
Ok(ud.borrow::<Cidr>()?.clone())
|
||||
} else if let Some(s) = value.as_str() {
|
||||
let (addr, prefix) = s.split_once("/")
|
||||
.map(|(a,p)| (a,Some(p)))
|
||||
.unwrap_or((s, None));
|
||||
let addr = IpAddr::from_str(addr)?;
|
||||
let prefix = prefix.map(u8::from_str)
|
||||
.transpose()
|
||||
.map_err(mlua::Error::external)?;
|
||||
|
||||
let max_prefix = if addr.is_ipv4() { 32 } else { 128 };
|
||||
let prefix = prefix.unwrap_or(max_prefix);
|
||||
if prefix > max_prefix {
|
||||
return Err(mlua::Error::runtime("CIDR prefix to long"))
|
||||
}
|
||||
|
||||
Ok(Cidr {
|
||||
addr, prefix
|
||||
})
|
||||
} else {
|
||||
Err(mlua::Error::FromLuaConversionError {
|
||||
from: value.type_name(),
|
||||
to: "Cidr",
|
||||
message: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EngineImpl {
|
||||
pub(crate) fn init(&mut self) -> anyhow::Result<()> {
|
||||
let lua = &* self.lua;
|
||||
lua.load_from_std_lib(mlua::StdLib::ALL)?;
|
||||
|
||||
|
||||
lua.register_userdata_type::<IpAddr>(|registry| {
|
||||
registry.add_field_method_get("version", |_, ip| Ok(if ip.is_ipv4() { 4} else {6}));
|
||||
})?;
|
||||
|
||||
{
|
||||
// prepare resource types...
|
||||
let resources = lua.create_table()?;
|
||||
@@ -98,11 +323,19 @@ impl EngineImpl {
|
||||
err_tbl.set("FileAlreadyExists", Resource::Error(Error::FileAlreadyExists))?;
|
||||
err_tbl.set("NoSuchUser", Resource::Error(Error::NoSuchUser))?;
|
||||
err_tbl.set("Message", err_fn)?;
|
||||
err_tbl.set_readonly(true);
|
||||
|
||||
resources.set("ERROR", err_tbl)?;
|
||||
resources.set_readonly(true);
|
||||
|
||||
lua.globals().set("resource", resources)?;
|
||||
lua.globals().set("state", lua.create_table()?)?;
|
||||
|
||||
// Construct data types
|
||||
let stftpd = lua.create_table()?;
|
||||
stftpd.set("Cidr", lua.create_proxy::<Cidr>()?)?;
|
||||
stftpd.set("IpAddr", lua.create_proxy::<IpAddrWrapper>()?)?;
|
||||
|
||||
//
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user