diff --git a/src/engine.rs b/src/engine.rs index da6615b..345f8e2 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -1,7 +1,9 @@ +use std::fmt::{Display, Formatter}; use std::future::Future; -use mlua::{FromLua, IntoLua, Lua, UserData, UserDataFields, Value}; -use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use mlua::{FromLua, FromLuaMulti, IntoLua, Lua, UserData, UserDataFields, UserDataMethods, Value}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::path::PathBuf; +use std::str::FromStr; use anyhow::anyhow; use async_tftp::packet::Error; @@ -26,6 +28,17 @@ pub enum Resource { Error(Error) } +impl Display for Resource { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Resource::Http(url) => write!(f, "HTTP {url}"), + Resource::File(path) => write!(f, "FILE {path}"), + Resource::Data(data) => write!(f, "DATA ({} bytes)", data.len()), + Resource::Error(err) => write!(f, "ERR {err:?}") + } + } +} + impl Resource { } @@ -34,7 +47,7 @@ impl UserData for Resource { } impl<'lua> FromLua<'lua> for Resource { - fn from_lua(value: Value<'lua>, _lua: &'lua Lua) -> mlua::Result { + fn from_lua(value: Value<'lua>, _: &'lua Lua) -> mlua::Result { value.as_userdata().ok_or(mlua::Error::UserDataTypeMismatch).and_then(|value| value.take()) } } @@ -67,16 +80,228 @@ pub struct EngineImpl { chan: Option>, } +#[derive(Copy, Clone)] +struct Cidr { + addr: IpAddr, + prefix: u8, +} + +impl Display for Cidr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}/{}", self.addr, self.prefix) + } +} + +impl Cidr { + fn to_v6(&self) -> Self { + match self.addr { + IpAddr::V4(ip) => Cidr { + addr: IpAddr::V6(ip.to_ipv6_mapped()), + prefix: self.prefix + 96, + }, + IpAddr::V6(_) => *self, + } + } +} + +#[repr(transparent)] +#[derive(Copy, Clone)] +struct IpAddrWrapper(IpAddr); + +impl IpAddrWrapper { + fn octet_len(&self) -> usize { + match &self.0 { + IpAddr::V4(_) => 4, + IpAddr::V6(_) => 16, + } + } + + fn to_v6(&self) -> Ipv6Addr { + match self.0 { + IpAddr::V4(v4) => v4.to_ipv6_mapped(), + IpAddr::V6(v6) => v6 + } + } +} + +impl UserData for IpAddrWrapper { + fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("version", |_, ip| Ok(if ip.0.is_ipv4() { 4} else {6})); + fields.add_field_method_get("is_ipv4", |_, ip| Ok(ip.0.is_ipv4())); + fields.add_field_method_get("is_ipv6", |_, ip| Ok(ip.0.is_ipv6())); + fields.add_field_method_get("bytes", |lua, ip| match &ip.0 { + IpAddr::V4(v4) => lua.create_string(v4.octets()), + IpAddr::V6(v6) => lua.create_string(v6.octets()), + }); + fields.add_meta_field("__name", "IpAddr"); + } + + fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + methods.add_method("to_v6", |_, ip, ()| match &ip.0 { + IpAddr::V4(v4) => Ok(IpAddrWrapper(IpAddr::V6(v4.to_ipv6_mapped()))), + IpAddr::V6(v6) => Ok(IpAddrWrapper(IpAddr::V6(*v6))), + }); + methods.add_meta_method("__index", |_, ip, index: usize| match &ip.0 { + IpAddr::V4(ip) if 1 <= index && index <= 4 => Ok(ip.octets()[index-1]), + IpAddr::V6(ip) if 1 <= index && index <= 16 => Ok(ip.octets()[index-1]), + _ => Err(mlua::Error::runtime("Index out of range")), + }); + methods.add_meta_method_mut("__newindex", |_, ip, (index, value): (usize, u8)| { + match &mut ip.0 { + IpAddr::V4(ip) if 1 <= index && index <= 4 => { + let mut octets = ip.octets(); + octets[index-1] = value; + IpAddr::V4(octets.into()); + }, + IpAddr::V6(ip) if 1 <= index && index <= 16 => { + let mut octets = ip.octets(); + octets[index-1] = value; + IpAddr::V6(octets.into()); + }, + _ => return Err(mlua::Error::runtime("Index out of range")), + }; + Ok(()) + }); + methods.add_meta_method("__tostring", |_, ip, ()| Ok(ip.0.to_string())); + methods.add_meta_function("__call", |lua, args: mlua::MultiValue| Ok(IpAddrWrapper({ + eprintln!("IpAddr __call Received {} args", args.len()); + if args.len() == 4 { + // IPV4 direct bytes + let (a,b,c,d) = <(u8,u8,u8,u8)>::from_lua_args(args, 0, Some("sftpd.IpAddr"), lua)?; + IpAddr::V4([a,b,c,d].into()) + } else if args.len() == 16 { + let mut octets = [0u8;16]; + for i in 0..16 { + octets[i] = u8::from_lua(args.get(i).unwrap().clone(), lua)?; + } + IpAddr::V6(octets.into()) + } else if args.len() == 1 { + return IpAddrWrapper::from_lua(args[0].clone(), lua); + } else { + return Err(mlua::Error::runtime("Invalid arguments to stfptd.IpAddr")) + } + }))); + methods.add_meta_method("__eq", |_, me, other: IpAddrWrapper| Ok(me.to_v6() == other.to_v6())) + } +} + +impl<'lua> FromLua<'lua> for IpAddrWrapper { + fn from_lua(value: Value<'lua>, _: &'lua Lua) -> mlua::Result { + return if let Some(ud) = value.as_userdata() { + Ok(ud.borrow::()?.clone()) + } else if let Some(s) = value.as_str() { + IpAddr::from_str(s) + .map(IpAddrWrapper) + .map_err(mlua::Error::external) + } else { + Err(mlua::Error::FromLuaConversionError { + from: value.type_name(), + to: "IpAddr", + message: None, + }) + } + } +} + +impl UserData for Cidr { + fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("addr", |_lua, cidr| Ok(IpAddrWrapper(cidr.addr))); + fields.add_field_method_set("addr", |_lua, cidr, value: IpAddrWrapper| { + if cidr.addr.is_ipv4() ^ value.0.is_ipv4() { + return Err(mlua::Error::runtime(format!( + "Cannot assign v{} addr to v{} CIDR", + if value.0.is_ipv4() { 4 } else { 6 }, + if cidr.addr.is_ipv4() { 4 } else { 6 }, + ))); + } + cidr.addr = value.0; + Ok(()) + }); + fields.add_field_method_get("prefix", |_lua, cidr| Ok(cidr.prefix)); + fields.add_field_method_set("prefix", |_lua, cidr, value: u8| { + if value >= if cidr.addr.is_ipv4() { 32 } else { 128 } { + return Err(mlua::Error::runtime("Invalid prefix length")); + } + cidr.prefix = value; + Ok(()) + }); + } + + fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + methods.add_method("contains", |_, cidr, addr: IpAddrWrapper| { + let addr = u128::from_be_bytes(addr.to_v6().octets()); + let cidr = cidr.to_v6(); + if cidr.prefix == 0 { + // Global network + return Ok(true) + } + let mask = (!0u128) << (128 - cidr.prefix); + let network = u128::from_be_bytes(IpAddrWrapper(cidr.addr).to_v6().octets()); + + Ok(addr & mask == network & mask) + }); + + methods.add_meta_method("__tostring", |_, cidr, ()| Ok(cidr.to_string())); + methods.add_meta_method("__eq", |_, cidr, other: Cidr| { + let cidr = cidr.to_v6(); + let other = other.to_v6(); + Ok(cidr.prefix == other.prefix && cidr.addr == other.addr) + }); + + methods.add_meta_function("__call", |lua, args: mlua::MultiValue| { + if args.len() == 0 { + Ok(Cidr{addr: IpAddr::V6(Ipv6Addr::UNSPECIFIED), prefix: 0}) + } else if args.len() == 1 { + Cidr::from_lua_args(args, 0, Some("stftpd.Cidr"), lua) + } else { + let (IpAddrWrapper(addr), prefix) = <(IpAddrWrapper, u8)>::from_lua_args(args, 0, Some("stftpd.Cidr"), lua)?; + let max_prefix = if addr.is_ipv4() { 32 } else { 64 }; + if prefix > max_prefix { + return Err(mlua::Error::runtime("Invalid prefix")); + } + Ok(Cidr{addr, prefix}) + } + }); + } +} + +impl<'lua> FromLua<'lua> for Cidr { + fn from_lua(value: Value<'lua>, _: &'lua Lua) -> mlua::Result { + return if let Some(ud) = value.as_userdata() { + Ok(ud.borrow::()?.clone()) + } else if let Some(s) = value.as_str() { + let (addr, prefix) = s.split_once("/") + .map(|(a,p)| (a,Some(p))) + .unwrap_or((s, None)); + let addr = IpAddr::from_str(addr)?; + let prefix = prefix.map(u8::from_str) + .transpose() + .map_err(mlua::Error::external)?; + + let max_prefix = if addr.is_ipv4() { 32 } else { 128 }; + let prefix = prefix.unwrap_or(max_prefix); + if prefix > max_prefix { + return Err(mlua::Error::runtime("CIDR prefix to long")) + } + + Ok(Cidr { + addr, prefix + }) + } else { + Err(mlua::Error::FromLuaConversionError { + from: value.type_name(), + to: "Cidr", + message: None, + }) + } + } +} + impl EngineImpl { pub(crate) fn init(&mut self) -> anyhow::Result<()> { let lua = &* self.lua; lua.load_from_std_lib(mlua::StdLib::ALL)?; - - lua.register_userdata_type::(|registry| { - registry.add_field_method_get("version", |_, ip| Ok(if ip.is_ipv4() { 4} else {6})); - })?; - { // prepare resource types... let resources = lua.create_table()?; @@ -98,11 +323,19 @@ impl EngineImpl { err_tbl.set("FileAlreadyExists", Resource::Error(Error::FileAlreadyExists))?; err_tbl.set("NoSuchUser", Resource::Error(Error::NoSuchUser))?; err_tbl.set("Message", err_fn)?; + err_tbl.set_readonly(true); resources.set("ERROR", err_tbl)?; + resources.set_readonly(true); + lua.globals().set("resource", resources)?; lua.globals().set("state", lua.create_table()?)?; + // Construct data types + let stftpd = lua.create_table()?; + stftpd.set("Cidr", lua.create_proxy::()?)?; + stftpd.set("IpAddr", lua.create_proxy::()?)?; + // }