From a5f94bc4abf0afd10390ca2151aefd56e9a2af9d Mon Sep 17 00:00:00 2001 From: OMGeeky Date: Sun, 17 Nov 2024 17:22:45 +0100 Subject: [PATCH] create macro to not duplicate as much code and improve the PackageId parsing with it --- Cargo.lock | 10 ++ Cargo.toml | 2 + mc-rust-server-macros/Cargo.lock | 47 +++++ mc-rust-server-macros/Cargo.toml | 14 ++ mc-rust-server-macros/src/lib.rs | 9 + mc-rust-server-macros/src/mc_protocol.rs | 78 +++++++++ src/main.rs | 35 +--- src/protocols.rs | 56 ++---- src/protocols/custom_report_details.rs | 2 + src/protocols/handshake.rs | 2 + src/protocols/ping.rs | 2 + src/protocols/status.rs | 3 +- src/types/package.rs | 210 +++++++++++------------ 13 files changed, 287 insertions(+), 183 deletions(-) create mode 100644 mc-rust-server-macros/Cargo.lock create mode 100644 mc-rust-server-macros/Cargo.toml create mode 100644 mc-rust-server-macros/src/lib.rs create mode 100644 mc-rust-server-macros/src/mc_protocol.rs diff --git a/Cargo.lock b/Cargo.lock index 5e8a59e..efe85d5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -149,6 +149,7 @@ dependencies = [ name = "mc-rust-server" version = "0.1.0" dependencies = [ + "mc-rust-server-macros", "num-derive", "num-traits", "serde", @@ -158,6 +159,15 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "mc-rust-server-macros" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "memchr" version = "2.7.4" diff --git a/Cargo.toml b/Cargo.toml index 23daa29..a439589 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,8 @@ version = "0.1.0" edition = "2021" [dependencies] +mc-rust-server-macros = { path = "mc-rust-server-macros" } + num-derive = "0.4.2" num-traits = "0.2.19" tokio = { version = "1.41.1", features = ["rt", "rt-multi-thread", "macros", "full", "net"] } diff --git a/mc-rust-server-macros/Cargo.lock b/mc-rust-server-macros/Cargo.lock new file mode 100644 index 0000000..53cc6d0 --- /dev/null +++ b/mc-rust-server-macros/Cargo.lock @@ -0,0 +1,47 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "mc-rust-server-macros" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "proc-macro2" +version = "1.0.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "syn" +version = "2.0.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "unicode-ident" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" diff --git a/mc-rust-server-macros/Cargo.toml b/mc-rust-server-macros/Cargo.toml new file mode 100644 index 0000000..f583329 --- /dev/null +++ b/mc-rust-server-macros/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "mc-rust-server-macros" +version = "0.1.0" +edition = "2021" + +[lib] +name = "mc_rust_server_macros" +path = "src/lib.rs" +proc-macro = true + +[dependencies] +quote = "*" +syn = { version = "*", features = ["extra-traits"] } +proc-macro2 = "*" diff --git a/mc-rust-server-macros/src/lib.rs b/mc-rust-server-macros/src/lib.rs new file mode 100644 index 0000000..fbf58fb --- /dev/null +++ b/mc-rust-server-macros/src/lib.rs @@ -0,0 +1,9 @@ +extern crate proc_macro; + +use syn::{parse_macro_input, DeriveInput}; + +#[proc_macro_derive(McProtocol, attributes(protocol_read, protocol_write))] +pub fn proc_macro_protocol(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + mc_protocol::proc_macro_protocol(parse_macro_input!(input as DeriveInput)).into() +} +mod mc_protocol; diff --git a/mc-rust-server-macros/src/mc_protocol.rs b/mc-rust-server-macros/src/mc_protocol.rs new file mode 100644 index 0000000..6befe1d --- /dev/null +++ b/mc-rust-server-macros/src/mc_protocol.rs @@ -0,0 +1,78 @@ +use proc_macro2::{Ident, TokenStream}; +use quote::quote; +use syn::parse::Parse; +use syn::{Data, DeriveInput, Lit, Token}; + +pub fn proc_macro_protocol(input: DeriveInput) -> TokenStream { + match input.data { + Data::Enum(e) => { + let variants = e.variants.into_iter(); + let mut variant_data = vec![]; + for v in variants { + let attr = v.attrs.into_iter().find_map(|a| { + if a.path().is_ident("protocol_read") { + Some(a.parse_args::().ok()?) + } else { + None + } + }); + if let Some(attr) = attr { + variant_data.push(ProtocolVariant { + attr, + ident: v.ident, + }) + } + } + + let variant_match = variant_data.into_iter().map(|v| { + let state = v.attr.state; + let id = v.attr.packet_id; + let name = v.ident; + quote! {(#state, #id)=>Self::#name (read_protocol_data(stream).await?)} + }); + let variants_match = quote! {#(#variant_match),*}; + + let enum_name = input.ident; + quote! { + #[automatically_derived] + impl #enum_name { + pub async fn read_protocol_data( + protocol_id: VarInt, + connection_state: ConnectionState, + stream: &mut RWStreamWithLimit<'_, T>, + )-> Result { + Ok(match (connection_state, protocol_id.to_rs()){ + #variants_match, + (other_state, other_id) => { + return Err(format!("Unrecognized protocol+state combination: {:?}/{}", other_state, other_id)); + } + }) + } + } + } + } + _ => { + panic!("This macro is only supported for Enums") + } + } +} + +#[derive(Debug)] +struct ProtocolVariant { + attr: ProtocolAttribute, + ident: Ident, +} +#[derive(Debug)] +struct ProtocolAttribute { + state: syn::ExprPath, + packet_id: Lit, +} +impl Parse for ProtocolAttribute { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let state = input.parse()?; + input.parse::()?; + let packet_id = input.parse()?; + + Ok(ProtocolAttribute { state, packet_id }) + } +} diff --git a/src/main.rs b/src/main.rs index 440cc36..c37a488 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,7 @@ pub mod protocols; pub mod types; pub mod utils; -use crate::protocols::ProtocolId; +use crate::types::package::IncomingPackage; use crate::types::string::McString; use crate::types::var_int::VarInt; use crate::types::{McRead, McRustRepr}; @@ -118,36 +118,9 @@ impl Connection { connection_state: ConnectionState, compression: bool, ) -> Result { - let packet_id = VarInt::read_stream(stream).await?; - - println!( - "Handling new Package with id: {:0>2x} =======================", - packet_id.as_rs() - ); - match ProtocolId::from_id_and_state(packet_id.to_rs(), connection_state) { - Some(protocol) => { - let res = types::package::Package::handle(protocol, stream).await; - match res { - Ok(connection_state_change) => { - println!("Success!"); - if let Some(connection_state_change) = connection_state_change { - return Ok(connection_state_change); - } - } - Err(terminate_connection) => { - if terminate_connection { - return Err("Something terrible has happened!".to_string()); - } else { - stream.discard_unread().await.map_err(|x| x.to_string())?; - } - println!("Failure :("); - } - } - } - None => { - stream.discard_unread().await.map_err(|x| x.to_string())?; - println!("I don't know this protocol yet, so Im gonna ignore it..."); - } + let res = IncomingPackage::handle_incoming(stream, connection_state).await?; + if let Some(new_connection_state) = res { + return Ok(new_connection_state); } Ok(connection_state) } diff --git a/src/protocols.rs b/src/protocols.rs index 5d9e723..ae2d308 100644 --- a/src/protocols.rs +++ b/src/protocols.rs @@ -1,59 +1,27 @@ -use crate::utils::RWStreamWithLimit; +use crate::types::McRead; +use crate::utils::{MyAsyncReadExt, RWStreamWithLimit}; use crate::ConnectionState; use num_derive::{FromPrimitive, ToPrimitive}; use num_traits::FromPrimitive; use tokio::io::{AsyncRead, AsyncWrite}; -#[derive(FromPrimitive, ToPrimitive, Debug, Copy, Clone)] -pub enum NotConnectedProtocolIds { - Handshake = 0x00, -} -#[derive(FromPrimitive, ToPrimitive, Debug, Copy, Clone)] -pub enum StatusProtocolIds { - Status = 0x00, - Ping = 0x01, -} -#[derive(FromPrimitive, ToPrimitive, Debug, Copy, Clone)] -pub enum LoginProtocolIds { - LoginStart = 0x00, - EncryptionResponse = 0x01, - LoginPluginResponse = 0x02, - CookieResponse = 0x04, -} -#[derive(FromPrimitive, ToPrimitive, Debug, Copy, Clone)] -pub enum TransferProtocolIds {} -#[derive(FromPrimitive, ToPrimitive, Debug, Copy, Clone)] -pub enum ConfigurationProtocolIds {} -#[derive(FromPrimitive, ToPrimitive, Debug, Copy, Clone)] -pub enum PlayProtocolIds {} -#[derive(Debug, Copy, Clone)] -pub enum ProtocolId { - NotConnected(NotConnectedProtocolIds), - Status(StatusProtocolIds), - Login(LoginProtocolIds), - Transfer(TransferProtocolIds), - Configuration(ConfigurationProtocolIds), - Play(PlayProtocolIds), - // CustomReportDetails = 0x7a, -} #[derive(FromPrimitive, ToPrimitive, Debug, Copy, Clone)] pub enum ProtocolResponseId { Status = 0x00, Ping = 0x01, } -impl ProtocolId { - pub(crate) fn from_id_and_state(id: i32, state: ConnectionState) -> Option { - Some(match state { - ConnectionState::NotConnected => Self::NotConnected(FromPrimitive::from_i32(id)?), - ConnectionState::Status => Self::Status(FromPrimitive::from_i32(id)?), - ConnectionState::Login => Self::Login(FromPrimitive::from_i32(id)?), - ConnectionState::Transfer => Self::Transfer(FromPrimitive::from_i32(id)?), - ConnectionState::Configuration => Self::Configuration(FromPrimitive::from_i32(id)?), - ConnectionState::Play => Self::Play(FromPrimitive::from_i32(id)?), - ConnectionState::Closed => return None, - }) +#[derive(Debug, Clone)] +pub struct NotImplementedData {} +impl McRead for NotImplementedData { + async fn read_stream(stream: &mut T) -> Result + where + Self: Sized, + { + Err("Did not implement this protocol yet".to_string()) } } + +impl crate::types::package::ProtocolDataMarker for NotImplementedData {} pub(crate) mod custom_report_details; pub(crate) mod handshake; pub(crate) mod ping; diff --git a/src/protocols/custom_report_details.rs b/src/protocols/custom_report_details.rs index 88a1203..b0ab449 100644 --- a/src/protocols/custom_report_details.rs +++ b/src/protocols/custom_report_details.rs @@ -11,6 +11,8 @@ pub struct Protocol {} pub struct Data { details: Vec<(McString<128>, McString<4096>)>, } +impl crate::types::package::ProtocolDataMarker for Data {} + impl McRead for Data { async fn read_stream(stream: &mut T) -> Result where diff --git a/src/protocols/handshake.rs b/src/protocols/handshake.rs index 86c3c04..ad3bd52 100644 --- a/src/protocols/handshake.rs +++ b/src/protocols/handshake.rs @@ -15,6 +15,8 @@ pub struct Data { server_port: u16, pub(crate) next_state: ConnectionState, } +impl crate::types::package::ProtocolDataMarker for Data {} + impl McRead for Data { async fn read_stream(b: &mut T) -> Result { println!("Reading Handshake"); diff --git a/src/protocols/ping.rs b/src/protocols/ping.rs index c94a126..87b0a90 100644 --- a/src/protocols/ping.rs +++ b/src/protocols/ping.rs @@ -1,4 +1,5 @@ use crate::types::long::Long; +use crate::types::package::ProtocolData; use crate::types::string::McString; use crate::types::var_int::VarInt; use crate::types::var_long::VarLong; @@ -23,6 +24,7 @@ impl McRead for Data { } } +impl crate::types::package::ProtocolDataMarker for Data {} #[derive(Debug, Clone)] pub struct ResponseData { pub(crate) timespan: Long, diff --git a/src/protocols/status.rs b/src/protocols/status.rs index b1ee56e..c5c281a 100644 --- a/src/protocols/status.rs +++ b/src/protocols/status.rs @@ -1,5 +1,5 @@ use crate::types::long::Long; -use crate::types::package::{OutgoingPackage, OutgoingPackageContent, Package}; +use crate::types::package::{OutgoingPackage, OutgoingPackageContent, Package, ProtocolData}; use crate::types::string::McString; use crate::types::var_int::VarInt; use crate::types::var_long::VarLong; @@ -21,6 +21,7 @@ impl McRead for Data { Ok(Self {}) } } +impl crate::types::package::ProtocolDataMarker for Data {} #[derive(Debug, Clone)] pub struct ResponseData { diff --git a/src/types/package.rs b/src/types/package.rs index 6d05c66..0121722 100644 --- a/src/types/package.rs +++ b/src/types/package.rs @@ -1,12 +1,10 @@ -use crate::protocols::{ - self, LoginProtocolIds, NotConnectedProtocolIds, ProtocolId, ProtocolResponseId, - StatusProtocolIds, -}; +use crate::protocols::{self, NotImplementedData, ProtocolResponseId}; use crate::types::string::McString; use crate::types::var_int::VarInt; -use crate::types::{McRead, McWrite}; +use crate::types::{McRead, McRustRepr, McWrite}; use crate::utils::RWStreamWithLimit; -use crate::ConnectionState; +use crate::{types, ConnectionState}; +use mc_rust_server_macros::McProtocol; use num_traits::ToPrimitive; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufWriter}; @@ -17,7 +15,6 @@ pub enum Package { } #[derive(Debug, Clone)] pub struct IncomingPackage { - pub(crate) protocol: ProtocolId, pub(crate) content: IncomingPackageContent, } #[derive(Debug, Clone)] @@ -25,19 +22,24 @@ pub struct OutgoingPackage { pub(crate) protocol: ProtocolResponseId, pub(crate) content: OutgoingPackageContent, } -impl OutgoingPackage { - pub fn empty() {} -} -#[derive(Debug, Clone)] +#[derive(Debug, Clone, McProtocol)] pub enum IncomingPackageContent { + #[protocol_read(ConnectionState::NotConnected, 0x00)] Handshake(protocols::handshake::Data), + #[protocol_read(ConnectionState::Status, 0x00)] Status(protocols::status::Data), + #[protocol_read(ConnectionState::Status, 0x01)] Ping(protocols::ping::Data), + #[protocol_read(ConnectionState::Play, 0x7a)] CustomReportDetails(protocols::custom_report_details::Data), - LoginStart(), - EncryptionResponse(), - LoginPluginResponse(), - CookieResponse(), + #[protocol_read(ConnectionState::Login, 0x00)] + LoginStart(NotImplementedData), + #[protocol_read(ConnectionState::Login, 0x01)] + EncryptionResponse(NotImplementedData), + #[protocol_read(ConnectionState::Login, 0x02)] + LoginPluginResponse(NotImplementedData), + #[protocol_read(ConnectionState::Login, 0x03)] + CookieResponse(NotImplementedData), } #[derive(Debug, Clone)] pub enum OutgoingPackageContent { @@ -109,48 +111,76 @@ impl McWrite for OutgoingPackage { } } impl IncomingPackage { + pub(crate) async fn handle_incoming( + stream: &mut RWStreamWithLimit<'_, T>, + connection_state: ConnectionState, + ) -> Result, String> { + let packet_id = VarInt::read_stream(stream) + .await + .map_err(|e| e.to_string())?; + + println!( + "Handling new Package with id: {:0>2x} =======================", + packet_id.as_rs() + ); + + let incoming_content = + IncomingPackageContent::read_protocol_data(packet_id, connection_state, stream).await; + + let incoming = IncomingPackage { + content: match incoming_content { + Ok(incoming) => incoming, + Err(e) => { + stream.discard_unread().await.map_err(|x| x.to_string())?; + return Err(e); + } + }, + }; + let res = incoming.answer(stream).await; + match res { + Ok(connection_state_change) => { + println!("Success!"); + return Ok(connection_state_change); + } + Err(terminate_connection) => { + if terminate_connection { + return Err("Something terrible has happened!".to_string()); + } else { + stream.discard_unread().await.map_err(|x| x.to_string())?; + } + println!("Failure :("); + } + } + + Ok(None) + } async fn answer( &self, stream: &mut RWStreamWithLimit<'_, T>, ) -> Result, bool> { - let (answer, changed_connection_state) = match self.protocol { - ProtocolId::NotConnected(protocol_id) => match &self.content { - (IncomingPackageContent::Handshake(handshake_data)) => { - (None, Some(handshake_data.next_state)) - } - _ => (None, None), //Ignore all packets that do not belong here - }, - ProtocolId::Status(protocol_id) => match &self.content { - IncomingPackageContent::Status(_) => ( - Some(OutgoingPackage { - protocol: ProtocolResponseId::Status, - content: OutgoingPackageContent::StatusResponse( - protocols::status::ResponseData::default(), - ), + let (answer, changed_connection_state) = match &self.content { + (IncomingPackageContent::Handshake(handshake_data)) => { + (None, Some(handshake_data.next_state)) + } + IncomingPackageContent::Status(_) => ( + Some(OutgoingPackage { + protocol: ProtocolResponseId::Status, + content: OutgoingPackageContent::StatusResponse( + protocols::status::ResponseData::default(), + ), + }), + None, + ), + (IncomingPackageContent::Ping(ping_data)) => ( + Some(OutgoingPackage { + protocol: ProtocolResponseId::Ping, + content: OutgoingPackageContent::PingResponse(protocols::ping::ResponseData { + timespan: ping_data.timespan, }), - None, - ), - (IncomingPackageContent::Ping(ping_data)) => ( - Some(OutgoingPackage { - protocol: ProtocolResponseId::Ping, - content: OutgoingPackageContent::PingResponse( - protocols::ping::ResponseData { - timespan: ping_data.timespan, - }, - ), - }), - None, - ), - _ => (None, None), //Ignore all packets that do not belong here - }, - - // ProtocolId::Login(protocol_id) => match (protocol_id, &self.content) { - // (LoginProtocolIds::LoginStart, _) => (None, None), - // (LoginProtocolIds::EncryptionResponse, _) => (None, None), - // (LoginProtocolIds::LoginPluginResponse, _) => (None, None), - // (LoginProtocolIds::CookieResponse, _) => (None, None), - // }, - _ => (None, None), //TODO: implement the other ProtocolId variants (based on current ConnectionState) + }), + None, + ), + _ => (None, None), //TODO: implement the other IncomingPackageContent variants }; if let Some(outgoing_package) = answer { outgoing_package.write_stream(stream).await.map_err(|e| { @@ -161,61 +191,27 @@ impl IncomingPackage { Ok(changed_connection_state) } } -impl Package { - pub(crate) async fn handle( - protocol_id: ProtocolId, - stream: &mut RWStreamWithLimit<'_, T>, - ) -> Result, bool> { - let incoming_content = read_data(protocol_id, stream).await.map_err(|e| { - dbg!(e); - true - })?; - let incoming = IncomingPackage { - protocol: protocol_id, - content: incoming_content, - }; - incoming.answer(stream).await + +pub(crate) trait ProtocolDataMarker {} +pub(crate) trait ProtocolData { + async fn read_data(stream: &mut T) -> Result + where + Self: Sized; +} +impl ProtocolData for T +where + T: ProtocolDataMarker + McRead, +{ + async fn read_data(stream: &mut Stream) -> Result + where + Self: Sized, + { + Self::read_stream(stream).await } } - -pub async fn read_data( - protocol_id: ProtocolId, - stream: &mut RWStreamWithLimit<'_, T>, -) -> Result { - Ok(match protocol_id { - ProtocolId::NotConnected(protocol_id) => match protocol_id { - NotConnectedProtocolIds::Handshake => IncomingPackageContent::Handshake( - protocols::handshake::Data::read_stream(stream).await?, - ), - }, - ProtocolId::Status(protocol_id) => match protocol_id { - StatusProtocolIds::Status => { - IncomingPackageContent::Status(protocols::status::Data::read_stream(stream).await?) - } - StatusProtocolIds::Ping => { - IncomingPackageContent::Ping(protocols::ping::Data::read_stream(stream).await?) - } - }, - ProtocolId::Login(protocol_id) => match protocol_id { - LoginProtocolIds::LoginStart => { - stream.discard_unread().await.map_err(|e| e.to_string())?; - IncomingPackageContent::LoginStart() - } - LoginProtocolIds::EncryptionResponse => { - stream.discard_unread().await.map_err(|e| e.to_string())?; - IncomingPackageContent::EncryptionResponse() - } - LoginProtocolIds::LoginPluginResponse => { - stream.discard_unread().await.map_err(|e| e.to_string())?; - IncomingPackageContent::LoginPluginResponse() - } - LoginProtocolIds::CookieResponse => { - stream.discard_unread().await.map_err(|e| e.to_string())?; - IncomingPackageContent::CookieResponse() - } - }, - ProtocolId::Transfer(protocol_id) => match protocol_id {}, - ProtocolId::Configuration(protocol_id) => match protocol_id {}, - ProtocolId::Play(protocol_id) => match protocol_id {}, - }) +async fn read_protocol_data(stream: &mut T) -> Result +where + S: Sized + ProtocolData, +{ + S::read_data::(stream).await }