From 30b479534c6f8ac91695ecd1bc41bf793d2ede57 Mon Sep 17 00:00:00 2001 From: TQ Hirsch Date: Sun, 3 Mar 2024 14:58:54 +0100 Subject: [PATCH] Split out ipaddr module --- devenv.nix | 2 +- src/engine.rs | 225 +--------------------------------------- src/engine/ipaddr.rs | 238 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 243 insertions(+), 222 deletions(-) create mode 100644 src/engine/ipaddr.rs diff --git a/devenv.nix b/devenv.nix index b750cf6..d5811e9 100644 --- a/devenv.nix +++ b/devenv.nix @@ -5,7 +5,7 @@ #env.GREET = "devenv"; # https://devenv.sh/packages/ - packages = [ pkgs.git pkgs.openssl ]; + packages = [ pkgs.git pkgs.openssl pkgs.lua ]; # https://devenv.sh/scripts/ #scripts.hello.exec = "echo hello from $GREET"; diff --git a/src/engine.rs b/src/engine.rs index 345f8e2..38652ed 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -7,6 +7,8 @@ use std::str::FromStr; use anyhow::anyhow; use async_tftp::packet::Error; +mod ipaddr; + #[derive(Clone, Debug)] pub struct Client { pub address: SocketAddr, @@ -80,226 +82,10 @@ pub struct EngineImpl { chan: Option>, } -#[derive(Copy, Clone)] -struct Cidr { - addr: IpAddr, - prefix: u8, -} - -impl Display for Cidr { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}/{}", self.addr, self.prefix) - } -} - -impl Cidr { - fn to_v6(&self) -> Self { - match self.addr { - IpAddr::V4(ip) => Cidr { - addr: IpAddr::V6(ip.to_ipv6_mapped()), - prefix: self.prefix + 96, - }, - IpAddr::V6(_) => *self, - } - } -} - -#[repr(transparent)] -#[derive(Copy, Clone)] -struct IpAddrWrapper(IpAddr); - -impl IpAddrWrapper { - fn octet_len(&self) -> usize { - match &self.0 { - IpAddr::V4(_) => 4, - IpAddr::V6(_) => 16, - } - } - - fn to_v6(&self) -> Ipv6Addr { - match self.0 { - IpAddr::V4(v4) => v4.to_ipv6_mapped(), - IpAddr::V6(v6) => v6 - } - } -} - -impl UserData for IpAddrWrapper { - fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) { - fields.add_field_method_get("version", |_, ip| Ok(if ip.0.is_ipv4() { 4} else {6})); - fields.add_field_method_get("is_ipv4", |_, ip| Ok(ip.0.is_ipv4())); - fields.add_field_method_get("is_ipv6", |_, ip| Ok(ip.0.is_ipv6())); - fields.add_field_method_get("bytes", |lua, ip| match &ip.0 { - IpAddr::V4(v4) => lua.create_string(v4.octets()), - IpAddr::V6(v6) => lua.create_string(v6.octets()), - }); - fields.add_meta_field("__name", "IpAddr"); - } - - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_method("to_v6", |_, ip, ()| match &ip.0 { - IpAddr::V4(v4) => Ok(IpAddrWrapper(IpAddr::V6(v4.to_ipv6_mapped()))), - IpAddr::V6(v6) => Ok(IpAddrWrapper(IpAddr::V6(*v6))), - }); - methods.add_meta_method("__index", |_, ip, index: usize| match &ip.0 { - IpAddr::V4(ip) if 1 <= index && index <= 4 => Ok(ip.octets()[index-1]), - IpAddr::V6(ip) if 1 <= index && index <= 16 => Ok(ip.octets()[index-1]), - _ => Err(mlua::Error::runtime("Index out of range")), - }); - methods.add_meta_method_mut("__newindex", |_, ip, (index, value): (usize, u8)| { - match &mut ip.0 { - IpAddr::V4(ip) if 1 <= index && index <= 4 => { - let mut octets = ip.octets(); - octets[index-1] = value; - IpAddr::V4(octets.into()); - }, - IpAddr::V6(ip) if 1 <= index && index <= 16 => { - let mut octets = ip.octets(); - octets[index-1] = value; - IpAddr::V6(octets.into()); - }, - _ => return Err(mlua::Error::runtime("Index out of range")), - }; - Ok(()) - }); - methods.add_meta_method("__tostring", |_, ip, ()| Ok(ip.0.to_string())); - methods.add_meta_function("__call", |lua, args: mlua::MultiValue| Ok(IpAddrWrapper({ - eprintln!("IpAddr __call Received {} args", args.len()); - if args.len() == 4 { - // IPV4 direct bytes - let (a,b,c,d) = <(u8,u8,u8,u8)>::from_lua_args(args, 0, Some("sftpd.IpAddr"), lua)?; - IpAddr::V4([a,b,c,d].into()) - } else if args.len() == 16 { - let mut octets = [0u8;16]; - for i in 0..16 { - octets[i] = u8::from_lua(args.get(i).unwrap().clone(), lua)?; - } - IpAddr::V6(octets.into()) - } else if args.len() == 1 { - return IpAddrWrapper::from_lua(args[0].clone(), lua); - } else { - return Err(mlua::Error::runtime("Invalid arguments to stfptd.IpAddr")) - } - }))); - methods.add_meta_method("__eq", |_, me, other: IpAddrWrapper| Ok(me.to_v6() == other.to_v6())) - } -} - -impl<'lua> FromLua<'lua> for IpAddrWrapper { - fn from_lua(value: Value<'lua>, _: &'lua Lua) -> mlua::Result { - return if let Some(ud) = value.as_userdata() { - Ok(ud.borrow::()?.clone()) - } else if let Some(s) = value.as_str() { - IpAddr::from_str(s) - .map(IpAddrWrapper) - .map_err(mlua::Error::external) - } else { - Err(mlua::Error::FromLuaConversionError { - from: value.type_name(), - to: "IpAddr", - message: None, - }) - } - } -} - -impl UserData for Cidr { - fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) { - fields.add_field_method_get("addr", |_lua, cidr| Ok(IpAddrWrapper(cidr.addr))); - fields.add_field_method_set("addr", |_lua, cidr, value: IpAddrWrapper| { - if cidr.addr.is_ipv4() ^ value.0.is_ipv4() { - return Err(mlua::Error::runtime(format!( - "Cannot assign v{} addr to v{} CIDR", - if value.0.is_ipv4() { 4 } else { 6 }, - if cidr.addr.is_ipv4() { 4 } else { 6 }, - ))); - } - cidr.addr = value.0; - Ok(()) - }); - fields.add_field_method_get("prefix", |_lua, cidr| Ok(cidr.prefix)); - fields.add_field_method_set("prefix", |_lua, cidr, value: u8| { - if value >= if cidr.addr.is_ipv4() { 32 } else { 128 } { - return Err(mlua::Error::runtime("Invalid prefix length")); - } - cidr.prefix = value; - Ok(()) - }); - } - - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_method("contains", |_, cidr, addr: IpAddrWrapper| { - let addr = u128::from_be_bytes(addr.to_v6().octets()); - let cidr = cidr.to_v6(); - if cidr.prefix == 0 { - // Global network - return Ok(true) - } - let mask = (!0u128) << (128 - cidr.prefix); - let network = u128::from_be_bytes(IpAddrWrapper(cidr.addr).to_v6().octets()); - - Ok(addr & mask == network & mask) - }); - - methods.add_meta_method("__tostring", |_, cidr, ()| Ok(cidr.to_string())); - methods.add_meta_method("__eq", |_, cidr, other: Cidr| { - let cidr = cidr.to_v6(); - let other = other.to_v6(); - Ok(cidr.prefix == other.prefix && cidr.addr == other.addr) - }); - - methods.add_meta_function("__call", |lua, args: mlua::MultiValue| { - if args.len() == 0 { - Ok(Cidr{addr: IpAddr::V6(Ipv6Addr::UNSPECIFIED), prefix: 0}) - } else if args.len() == 1 { - Cidr::from_lua_args(args, 0, Some("stftpd.Cidr"), lua) - } else { - let (IpAddrWrapper(addr), prefix) = <(IpAddrWrapper, u8)>::from_lua_args(args, 0, Some("stftpd.Cidr"), lua)?; - let max_prefix = if addr.is_ipv4() { 32 } else { 64 }; - if prefix > max_prefix { - return Err(mlua::Error::runtime("Invalid prefix")); - } - Ok(Cidr{addr, prefix}) - } - }); - } -} - -impl<'lua> FromLua<'lua> for Cidr { - fn from_lua(value: Value<'lua>, _: &'lua Lua) -> mlua::Result { - return if let Some(ud) = value.as_userdata() { - Ok(ud.borrow::()?.clone()) - } else if let Some(s) = value.as_str() { - let (addr, prefix) = s.split_once("/") - .map(|(a,p)| (a,Some(p))) - .unwrap_or((s, None)); - let addr = IpAddr::from_str(addr)?; - let prefix = prefix.map(u8::from_str) - .transpose() - .map_err(mlua::Error::external)?; - - let max_prefix = if addr.is_ipv4() { 32 } else { 128 }; - let prefix = prefix.unwrap_or(max_prefix); - if prefix > max_prefix { - return Err(mlua::Error::runtime("CIDR prefix to long")) - } - - Ok(Cidr { - addr, prefix - }) - } else { - Err(mlua::Error::FromLuaConversionError { - from: value.type_name(), - to: "Cidr", - message: None, - }) - } - } -} impl EngineImpl { pub(crate) fn init(&mut self) -> anyhow::Result<()> { - let lua = &* self.lua; + let lua = self.lua; lua.load_from_std_lib(mlua::StdLib::ALL)?; { @@ -332,10 +118,7 @@ impl EngineImpl { lua.globals().set("state", lua.create_table()?)?; // Construct data types - let stftpd = lua.create_table()?; - stftpd.set("Cidr", lua.create_proxy::()?)?; - stftpd.set("IpAddr", lua.create_proxy::()?)?; - + ipaddr::register(lua)?; // } diff --git a/src/engine/ipaddr.rs b/src/engine/ipaddr.rs new file mode 100644 index 0000000..f6aa7bc --- /dev/null +++ b/src/engine/ipaddr.rs @@ -0,0 +1,238 @@ +use std::fmt::Display; +use std::net::{IpAddr, Ipv6Addr}; +use std::str::FromStr; +use std::fmt::Formatter; +use mlua::{FromLua, FromLuaMulti, Lua, UserData, UserDataFields, UserDataMethods, Value}; + +#[derive(Copy, Clone)] +struct Cidr { + addr: IpAddr, + prefix: u8, +} + +impl Display for Cidr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}/{}", self.addr, self.prefix) + } +} + +impl Cidr { + fn to_v6(&self) -> Self { + match self.addr { + IpAddr::V4(ip) => Cidr { + addr: IpAddr::V6(ip.to_ipv6_mapped()), + prefix: self.prefix + 96, + }, + IpAddr::V6(_) => *self, + } + } +} + +#[repr(transparent)] +#[derive(Copy, Clone)] +struct IpAddrWrapper(IpAddr); + +impl IpAddrWrapper { + fn octet_len(&self) -> usize { + match &self.0 { + IpAddr::V4(_) => 4, + IpAddr::V6(_) => 16, + } + } + + fn to_v6(&self) -> Ipv6Addr { + match self.0 { + IpAddr::V4(v4) => v4.to_ipv6_mapped(), + IpAddr::V6(v6) => v6 + } + } +} + +impl UserData for IpAddrWrapper { + fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("version", |_, ip| Ok(if ip.0.is_ipv4() { 4} else {6})); + fields.add_field_method_get("is_ipv4", |_, ip| Ok(ip.0.is_ipv4())); + fields.add_field_method_get("is_ipv6", |_, ip| Ok(ip.0.is_ipv6())); + fields.add_field_method_get("bytes", |lua, ip| match &ip.0 { + IpAddr::V4(v4) => lua.create_string(v4.octets()), + IpAddr::V6(v6) => lua.create_string(v6.octets()), + }); + fields.add_meta_field("__name", "IpAddr"); + } + + fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + methods.add_method("to_v6", |_, ip, ()| match &ip.0 { + IpAddr::V4(v4) => Ok(IpAddrWrapper(IpAddr::V6(v4.to_ipv6_mapped()))), + IpAddr::V6(v6) => Ok(IpAddrWrapper(IpAddr::V6(*v6))), + }); + methods.add_meta_method("__index", |_, ip, index: usize| match &ip.0 { + IpAddr::V4(ip) if 1 <= index && index <= 4 => Ok(ip.octets()[index-1]), + IpAddr::V6(ip) if 1 <= index && index <= 16 => Ok(ip.octets()[index-1]), + _ => Err(mlua::Error::runtime("Index out of range")), + }); + methods.add_meta_method_mut("__newindex", |_, ip, (index, value): (usize, u8)| { + match &mut ip.0 { + IpAddr::V4(ip) if 1 <= index && index <= 4 => { + let mut octets = ip.octets(); + octets[index-1] = value; + IpAddr::V4(octets.into()); + }, + IpAddr::V6(ip) if 1 <= index && index <= 16 => { + let mut octets = ip.octets(); + octets[index-1] = value; + IpAddr::V6(octets.into()); + }, + _ => return Err(mlua::Error::runtime("Index out of range")), + }; + Ok(()) + }); + methods.add_meta_method("__tostring", |_, ip, ()| Ok(ip.0.to_string())); + methods.add_meta_function("__call", |lua, args: mlua::MultiValue| Ok(IpAddrWrapper({ + eprintln!("IpAddr __call Received {} args", args.len()); + if args.len() == 4 { + // IPV4 direct bytes + let (a,b,c,d) = <(u8,u8,u8,u8)>::from_lua_args(args, 0, Some("sftpd.IpAddr"), lua)?; + IpAddr::V4([a,b,c,d].into()) + } else if args.len() == 16 { + let mut octets = [0u8;16]; + for i in 0..16 { + octets[i] = u8::from_lua(args.get(i).unwrap().clone(), lua)?; + } + IpAddr::V6(octets.into()) + } else if args.len() == 1 { + return IpAddrWrapper::from_lua(args[0].clone(), lua); + } else { + return Err(mlua::Error::runtime("Invalid arguments to stfptd.IpAddr")) + } + }))); + methods.add_meta_method("__eq", |_, me, other: IpAddrWrapper| Ok(me.to_v6() == other.to_v6())) + } +} + +impl<'lua> FromLua<'lua> for IpAddrWrapper { + fn from_lua(value: Value<'lua>, _: &'lua Lua) -> mlua::Result { + return if let Some(ud) = value.as_userdata() { + Ok(ud.borrow::()?.clone()) + } else if let Some(s) = value.as_str() { + IpAddr::from_str(s) + .map(IpAddrWrapper) + .map_err(mlua::Error::external) + } else { + Err(mlua::Error::FromLuaConversionError { + from: value.type_name(), + to: "IpAddr", + message: None, + }) + } + } +} + +impl UserData for Cidr { + fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("addr", |_lua, cidr| Ok(IpAddrWrapper(cidr.addr))); + fields.add_field_method_set("addr", |_lua, cidr, value: IpAddrWrapper| { + if cidr.addr.is_ipv4() ^ value.0.is_ipv4() { + return Err(mlua::Error::runtime(format!( + "Cannot assign v{} addr to v{} CIDR", + if value.0.is_ipv4() { 4 } else { 6 }, + if cidr.addr.is_ipv4() { 4 } else { 6 }, + ))); + } + cidr.addr = value.0; + Ok(()) + }); + fields.add_field_method_get("prefix", |_lua, cidr| Ok(cidr.prefix)); + fields.add_field_method_set("prefix", |_lua, cidr, value: u8| { + if value >= if cidr.addr.is_ipv4() { 32 } else { 128 } { + return Err(mlua::Error::runtime("Invalid prefix length")); + } + cidr.prefix = value; + Ok(()) + }); + } + + fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + methods.add_method("contains", |_, cidr, addr: IpAddrWrapper| { + let addr = u128::from_be_bytes(addr.to_v6().octets()); + let cidr = cidr.to_v6(); + if cidr.prefix == 0 { + // Global network + return Ok(true) + } + let mask = (!0u128) << (128 - cidr.prefix); + let network = u128::from_be_bytes(IpAddrWrapper(cidr.addr).to_v6().octets()); + + Ok(addr & mask == network & mask) + }); + + methods.add_meta_method("__tostring", |_, cidr, ()| Ok(cidr.to_string())); + methods.add_meta_method("__eq", |_, cidr, other: Cidr| { + let cidr = cidr.to_v6(); + let other = other.to_v6(); + Ok(cidr.prefix == other.prefix && cidr.addr == other.addr) + }); + + methods.add_meta_function("__call", |lua, args: mlua::MultiValue| { + if args.len() == 0 { + Ok(Cidr{addr: IpAddr::V6(Ipv6Addr::UNSPECIFIED), prefix: 0}) + } else if args.len() == 1 { + Cidr::from_lua_args(args, 0, Some("stftpd.Cidr"), lua) + } else { + let (IpAddrWrapper(addr), prefix) = <(IpAddrWrapper, u8)>::from_lua_args(args, 0, Some("stftpd.Cidr"), lua)?; + let max_prefix = if addr.is_ipv4() { 32 } else { 64 }; + if prefix > max_prefix { + return Err(mlua::Error::runtime("Invalid prefix")); + } + Ok(Cidr{addr, prefix}) + } + }); + } +} + +impl<'lua> FromLua<'lua> for Cidr { + fn from_lua(value: Value<'lua>, _: &'lua Lua) -> mlua::Result { + return if let Some(ud) = value.as_userdata() { + Ok(ud.borrow::()?.clone()) + } else if let Some(s) = value.as_str() { + let (addr, prefix) = s.split_once("/") + .map(|(a,p)| (a,Some(p))) + .unwrap_or((s, None)); + let addr = IpAddr::from_str(addr)?; + let prefix = prefix.map(u8::from_str) + .transpose() + .map_err(mlua::Error::external)?; + + let max_prefix = if addr.is_ipv4() { 32 } else { 128 }; + let prefix = prefix.unwrap_or(max_prefix); + if prefix > max_prefix { + return Err(mlua::Error::runtime("CIDR prefix to long")) + } + + Ok(Cidr { + addr, prefix + }) + } else { + Err(mlua::Error::FromLuaConversionError { + from: value.type_name(), + to: "Cidr", + message: None, + }) + } + } +} + +pub fn register(lua: &'static Lua) -> anyhow::Result<()> { + let globals = lua.globals(); + let stftpd: Value = globals.get("stftpd")?; + let stftpd = if stftpd.is_nil() { + let newtab = lua.create_table()?; + globals.set("stftpd", newtab.clone())?; + newtab + } else { + stftpd.as_table().unwrap().clone() + }; + stftpd.set("Cidr", lua.create_proxy::()?)?; + stftpd.set("IpAddr", lua.create_proxy::()?)?; + + Ok(()) +}