use std::fmt::Display; use std::net::{IpAddr, Ipv6Addr}; use std::str::FromStr; use std::fmt::Formatter; use mlua::{FromLua, FromLuaMulti, Lua, UserData, UserDataFields, UserDataMethods, Value}; #[derive(Copy, Clone)] struct Cidr { addr: IpAddr, prefix: u8, } impl Display for Cidr { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{}/{}", self.addr, self.prefix) } } impl Cidr { fn to_v6(&self) -> Self { match self.addr { IpAddr::V4(ip) => Cidr { addr: IpAddr::V6(ip.to_ipv6_mapped()), prefix: self.prefix + 96, }, IpAddr::V6(_) => *self, } } } #[repr(transparent)] #[derive(Copy, Clone)] struct IpAddrWrapper(IpAddr); impl IpAddrWrapper { fn octet_len(&self) -> usize { match &self.0 { IpAddr::V4(_) => 4, IpAddr::V6(_) => 16, } } fn to_v6(&self) -> Ipv6Addr { match self.0 { IpAddr::V4(v4) => v4.to_ipv6_mapped(), IpAddr::V6(v6) => v6 } } } impl UserData for IpAddrWrapper { fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) { fields.add_field_method_get("version", |_, ip| Ok(if ip.0.is_ipv4() { 4} else {6})); fields.add_field_method_get("is_ipv4", |_, ip| Ok(ip.0.is_ipv4())); fields.add_field_method_get("is_ipv6", |_, ip| Ok(ip.0.is_ipv6())); fields.add_field_method_get("bytes", |lua, ip| match &ip.0 { IpAddr::V4(v4) => lua.create_string(v4.octets()), IpAddr::V6(v6) => lua.create_string(v6.octets()), }); fields.add_meta_field("__name", "IpAddr"); } fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { methods.add_method("to_v6", |_, ip, ()| match &ip.0 { IpAddr::V4(v4) => Ok(IpAddrWrapper(IpAddr::V6(v4.to_ipv6_mapped()))), IpAddr::V6(v6) => Ok(IpAddrWrapper(IpAddr::V6(*v6))), }); methods.add_meta_method("__index", |_, ip, index: usize| match &ip.0 { IpAddr::V4(ip) if 1 <= index && index <= 4 => Ok(ip.octets()[index-1]), IpAddr::V6(ip) if 1 <= index && index <= 16 => Ok(ip.octets()[index-1]), _ => Err(mlua::Error::runtime("Index out of range")), }); methods.add_meta_method_mut("__newindex", |_, ip, (index, value): (usize, u8)| { match &mut ip.0 { IpAddr::V4(ip) if 1 <= index && index <= 4 => { let mut octets = ip.octets(); octets[index-1] = value; IpAddr::V4(octets.into()); }, IpAddr::V6(ip) if 1 <= index && index <= 16 => { let mut octets = ip.octets(); octets[index-1] = value; IpAddr::V6(octets.into()); }, _ => return Err(mlua::Error::runtime("Index out of range")), }; Ok(()) }); methods.add_meta_method("__tostring", |_, ip, ()| Ok(ip.0.to_string())); methods.add_meta_function("__call", |lua, args: mlua::MultiValue| Ok(IpAddrWrapper({ eprintln!("IpAddr __call Received {} args", args.len()); if args.len() == 4 { // IPV4 direct bytes let (a,b,c,d) = <(u8,u8,u8,u8)>::from_lua_args(args, 0, Some("sftpd.IpAddr"), lua)?; IpAddr::V4([a,b,c,d].into()) } else if args.len() == 16 { let mut octets = [0u8;16]; for i in 0..16 { octets[i] = u8::from_lua(args.get(i).unwrap().clone(), lua)?; } IpAddr::V6(octets.into()) } else if args.len() == 1 { return IpAddrWrapper::from_lua(args[0].clone(), lua); } else { return Err(mlua::Error::runtime("Invalid arguments to stfptd.IpAddr")) } }))); methods.add_meta_method("__eq", |_, me, other: IpAddrWrapper| Ok(me.to_v6() == other.to_v6())) } } impl<'lua> FromLua<'lua> for IpAddrWrapper { fn from_lua(value: Value<'lua>, _: &'lua Lua) -> mlua::Result { return if let Some(ud) = value.as_userdata() { Ok(ud.borrow::()?.clone()) } else if let Some(s) = value.as_str() { IpAddr::from_str(s) .map(IpAddrWrapper) .map_err(mlua::Error::external) } else { Err(mlua::Error::FromLuaConversionError { from: value.type_name(), to: "IpAddr", message: None, }) } } } impl UserData for Cidr { fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) { fields.add_field_method_get("addr", |_lua, cidr| Ok(IpAddrWrapper(cidr.addr))); fields.add_field_method_set("addr", |_lua, cidr, value: IpAddrWrapper| { if cidr.addr.is_ipv4() ^ value.0.is_ipv4() { return Err(mlua::Error::runtime(format!( "Cannot assign v{} addr to v{} CIDR", if value.0.is_ipv4() { 4 } else { 6 }, if cidr.addr.is_ipv4() { 4 } else { 6 }, ))); } cidr.addr = value.0; Ok(()) }); fields.add_field_method_get("prefix", |_lua, cidr| Ok(cidr.prefix)); fields.add_field_method_set("prefix", |_lua, cidr, value: u8| { if value >= if cidr.addr.is_ipv4() { 32 } else { 128 } { return Err(mlua::Error::runtime("Invalid prefix length")); } cidr.prefix = value; Ok(()) }); } fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { methods.add_method("contains", |_, cidr, addr: IpAddrWrapper| { let addr = u128::from_be_bytes(addr.to_v6().octets()); let cidr = cidr.to_v6(); if cidr.prefix == 0 { // Global network return Ok(true) } let mask = (!0u128) << (128 - cidr.prefix); let network = u128::from_be_bytes(IpAddrWrapper(cidr.addr).to_v6().octets()); Ok(addr & mask == network & mask) }); methods.add_meta_method("__tostring", |_, cidr, ()| Ok(cidr.to_string())); methods.add_meta_method("__eq", |_, cidr, other: Cidr| { let cidr = cidr.to_v6(); let other = other.to_v6(); Ok(cidr.prefix == other.prefix && cidr.addr == other.addr) }); methods.add_meta_function("__call", |lua, args: mlua::MultiValue| { if args.len() == 0 { Ok(Cidr{addr: IpAddr::V6(Ipv6Addr::UNSPECIFIED), prefix: 0}) } else if args.len() == 1 { Cidr::from_lua_args(args, 0, Some("stftpd.Cidr"), lua) } else { let (IpAddrWrapper(addr), prefix) = <(IpAddrWrapper, u8)>::from_lua_args(args, 0, Some("stftpd.Cidr"), lua)?; let max_prefix = if addr.is_ipv4() { 32 } else { 64 }; if prefix > max_prefix { return Err(mlua::Error::runtime("Invalid prefix")); } Ok(Cidr{addr, prefix}) } }); } } impl<'lua> FromLua<'lua> for Cidr { fn from_lua(value: Value<'lua>, _: &'lua Lua) -> mlua::Result { return if let Some(ud) = value.as_userdata() { Ok(ud.borrow::()?.clone()) } else if let Some(s) = value.as_str() { let (addr, prefix) = s.split_once("/") .map(|(a,p)| (a,Some(p))) .unwrap_or((s, None)); let addr = IpAddr::from_str(addr)?; let prefix = prefix.map(u8::from_str) .transpose() .map_err(mlua::Error::external)?; let max_prefix = if addr.is_ipv4() { 32 } else { 128 }; let prefix = prefix.unwrap_or(max_prefix); if prefix > max_prefix { return Err(mlua::Error::runtime("CIDR prefix to long")) } Ok(Cidr { addr, prefix }) } else { Err(mlua::Error::FromLuaConversionError { from: value.type_name(), to: "Cidr", message: None, }) } } } pub fn register(lua: &'static Lua) -> anyhow::Result<()> { let globals = lua.globals(); let stftpd: Value = globals.get("stftpd")?; let stftpd = if stftpd.is_nil() { let newtab = lua.create_table()?; globals.set("stftpd", newtab.clone())?; newtab } else { stftpd.as_table().unwrap().clone() }; stftpd.set("Cidr", lua.create_proxy::()?)?; stftpd.set("IpAddr", lua.create_proxy::()?)?; Ok(()) }