create macro to not duplicate as much code and improve the PackageId parsing with it

This commit is contained in:
OMGeeky
2024-11-17 17:22:45 +01:00
parent ef8f7e8383
commit a5f94bc4ab
13 changed files with 287 additions and 183 deletions

10
Cargo.lock generated
View File

@@ -149,6 +149,7 @@ dependencies = [
name = "mc-rust-server"
version = "0.1.0"
dependencies = [
"mc-rust-server-macros",
"num-derive",
"num-traits",
"serde",
@@ -158,6 +159,15 @@ dependencies = [
"tokio-util",
]
[[package]]
name = "mc-rust-server-macros"
version = "0.1.0"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "memchr"
version = "2.7.4"

View File

@@ -4,6 +4,8 @@ version = "0.1.0"
edition = "2021"
[dependencies]
mc-rust-server-macros = { path = "mc-rust-server-macros" }
num-derive = "0.4.2"
num-traits = "0.2.19"
tokio = { version = "1.41.1", features = ["rt", "rt-multi-thread", "macros", "full", "net"] }

47
mc-rust-server-macros/Cargo.lock generated Normal file
View File

@@ -0,0 +1,47 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "mc-rust-server-macros"
version = "0.1.0"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "proc-macro2"
version = "1.0.89"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af"
dependencies = [
"proc-macro2",
]
[[package]]
name = "syn"
version = "2.0.87"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "unicode-ident"
version = "1.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe"

View File

@@ -0,0 +1,14 @@
[package]
name = "mc-rust-server-macros"
version = "0.1.0"
edition = "2021"
[lib]
name = "mc_rust_server_macros"
path = "src/lib.rs"
proc-macro = true
[dependencies]
quote = "*"
syn = { version = "*", features = ["extra-traits"] }
proc-macro2 = "*"

View File

@@ -0,0 +1,9 @@
extern crate proc_macro;
use syn::{parse_macro_input, DeriveInput};
#[proc_macro_derive(McProtocol, attributes(protocol_read, protocol_write))]
pub fn proc_macro_protocol(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
mc_protocol::proc_macro_protocol(parse_macro_input!(input as DeriveInput)).into()
}
mod mc_protocol;

View File

@@ -0,0 +1,78 @@
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::parse::Parse;
use syn::{Data, DeriveInput, Lit, Token};
pub fn proc_macro_protocol(input: DeriveInput) -> TokenStream {
match input.data {
Data::Enum(e) => {
let variants = e.variants.into_iter();
let mut variant_data = vec![];
for v in variants {
let attr = v.attrs.into_iter().find_map(|a| {
if a.path().is_ident("protocol_read") {
Some(a.parse_args::<ProtocolAttribute>().ok()?)
} else {
None
}
});
if let Some(attr) = attr {
variant_data.push(ProtocolVariant {
attr,
ident: v.ident,
})
}
}
let variant_match = variant_data.into_iter().map(|v| {
let state = v.attr.state;
let id = v.attr.packet_id;
let name = v.ident;
quote! {(#state, #id)=>Self::#name (read_protocol_data(stream).await?)}
});
let variants_match = quote! {#(#variant_match),*};
let enum_name = input.ident;
quote! {
#[automatically_derived]
impl #enum_name {
pub async fn read_protocol_data<T:AsyncRead + AsyncWrite + Unpin>(
protocol_id: VarInt,
connection_state: ConnectionState,
stream: &mut RWStreamWithLimit<'_, T>,
)-> Result<IncomingPackageContent, String> {
Ok(match (connection_state, protocol_id.to_rs()){
#variants_match,
(other_state, other_id) => {
return Err(format!("Unrecognized protocol+state combination: {:?}/{}", other_state, other_id));
}
})
}
}
}
}
_ => {
panic!("This macro is only supported for Enums")
}
}
}
#[derive(Debug)]
struct ProtocolVariant {
attr: ProtocolAttribute,
ident: Ident,
}
#[derive(Debug)]
struct ProtocolAttribute {
state: syn::ExprPath,
packet_id: Lit,
}
impl Parse for ProtocolAttribute {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let state = input.parse()?;
input.parse::<Token![,]>()?;
let packet_id = input.parse()?;
Ok(ProtocolAttribute { state, packet_id })
}
}

View File

@@ -3,7 +3,7 @@ pub mod protocols;
pub mod types;
pub mod utils;
use crate::protocols::ProtocolId;
use crate::types::package::IncomingPackage;
use crate::types::string::McString;
use crate::types::var_int::VarInt;
use crate::types::{McRead, McRustRepr};
@@ -118,36 +118,9 @@ impl Connection {
connection_state: ConnectionState,
compression: bool,
) -> Result<ConnectionState, String> {
let packet_id = VarInt::read_stream(stream).await?;
println!(
"Handling new Package with id: {:0>2x} =======================",
packet_id.as_rs()
);
match ProtocolId::from_id_and_state(packet_id.to_rs(), connection_state) {
Some(protocol) => {
let res = types::package::Package::handle(protocol, stream).await;
match res {
Ok(connection_state_change) => {
println!("Success!");
if let Some(connection_state_change) = connection_state_change {
return Ok(connection_state_change);
}
}
Err(terminate_connection) => {
if terminate_connection {
return Err("Something terrible has happened!".to_string());
} else {
stream.discard_unread().await.map_err(|x| x.to_string())?;
}
println!("Failure :(");
}
}
}
None => {
stream.discard_unread().await.map_err(|x| x.to_string())?;
println!("I don't know this protocol yet, so Im gonna ignore it...");
}
let res = IncomingPackage::handle_incoming(stream, connection_state).await?;
if let Some(new_connection_state) = res {
return Ok(new_connection_state);
}
Ok(connection_state)
}

View File

@@ -1,59 +1,27 @@
use crate::utils::RWStreamWithLimit;
use crate::types::McRead;
use crate::utils::{MyAsyncReadExt, RWStreamWithLimit};
use crate::ConnectionState;
use num_derive::{FromPrimitive, ToPrimitive};
use num_traits::FromPrimitive;
use tokio::io::{AsyncRead, AsyncWrite};
#[derive(FromPrimitive, ToPrimitive, Debug, Copy, Clone)]
pub enum NotConnectedProtocolIds {
Handshake = 0x00,
}
#[derive(FromPrimitive, ToPrimitive, Debug, Copy, Clone)]
pub enum StatusProtocolIds {
Status = 0x00,
Ping = 0x01,
}
#[derive(FromPrimitive, ToPrimitive, Debug, Copy, Clone)]
pub enum LoginProtocolIds {
LoginStart = 0x00,
EncryptionResponse = 0x01,
LoginPluginResponse = 0x02,
CookieResponse = 0x04,
}
#[derive(FromPrimitive, ToPrimitive, Debug, Copy, Clone)]
pub enum TransferProtocolIds {}
#[derive(FromPrimitive, ToPrimitive, Debug, Copy, Clone)]
pub enum ConfigurationProtocolIds {}
#[derive(FromPrimitive, ToPrimitive, Debug, Copy, Clone)]
pub enum PlayProtocolIds {}
#[derive(Debug, Copy, Clone)]
pub enum ProtocolId {
NotConnected(NotConnectedProtocolIds),
Status(StatusProtocolIds),
Login(LoginProtocolIds),
Transfer(TransferProtocolIds),
Configuration(ConfigurationProtocolIds),
Play(PlayProtocolIds),
// CustomReportDetails = 0x7a,
}
#[derive(FromPrimitive, ToPrimitive, Debug, Copy, Clone)]
pub enum ProtocolResponseId {
Status = 0x00,
Ping = 0x01,
}
impl ProtocolId {
pub(crate) fn from_id_and_state(id: i32, state: ConnectionState) -> Option<Self> {
Some(match state {
ConnectionState::NotConnected => Self::NotConnected(FromPrimitive::from_i32(id)?),
ConnectionState::Status => Self::Status(FromPrimitive::from_i32(id)?),
ConnectionState::Login => Self::Login(FromPrimitive::from_i32(id)?),
ConnectionState::Transfer => Self::Transfer(FromPrimitive::from_i32(id)?),
ConnectionState::Configuration => Self::Configuration(FromPrimitive::from_i32(id)?),
ConnectionState::Play => Self::Play(FromPrimitive::from_i32(id)?),
ConnectionState::Closed => return None,
})
#[derive(Debug, Clone)]
pub struct NotImplementedData {}
impl McRead for NotImplementedData {
async fn read_stream<T: AsyncRead + Unpin>(stream: &mut T) -> Result<Self, String>
where
Self: Sized,
{
Err("Did not implement this protocol yet".to_string())
}
}
impl crate::types::package::ProtocolDataMarker for NotImplementedData {}
pub(crate) mod custom_report_details;
pub(crate) mod handshake;
pub(crate) mod ping;

View File

@@ -11,6 +11,8 @@ pub struct Protocol {}
pub struct Data {
details: Vec<(McString<128>, McString<4096>)>,
}
impl crate::types::package::ProtocolDataMarker for Data {}
impl McRead for Data {
async fn read_stream<T: AsyncRead + Unpin>(stream: &mut T) -> Result<Self, String>
where

View File

@@ -15,6 +15,8 @@ pub struct Data {
server_port: u16,
pub(crate) next_state: ConnectionState,
}
impl crate::types::package::ProtocolDataMarker for Data {}
impl McRead for Data {
async fn read_stream<T: AsyncRead + Unpin>(b: &mut T) -> Result<Self, String> {
println!("Reading Handshake");

View File

@@ -1,4 +1,5 @@
use crate::types::long::Long;
use crate::types::package::ProtocolData;
use crate::types::string::McString;
use crate::types::var_int::VarInt;
use crate::types::var_long::VarLong;
@@ -23,6 +24,7 @@ impl McRead for Data {
}
}
impl crate::types::package::ProtocolDataMarker for Data {}
#[derive(Debug, Clone)]
pub struct ResponseData {
pub(crate) timespan: Long,

View File

@@ -1,5 +1,5 @@
use crate::types::long::Long;
use crate::types::package::{OutgoingPackage, OutgoingPackageContent, Package};
use crate::types::package::{OutgoingPackage, OutgoingPackageContent, Package, ProtocolData};
use crate::types::string::McString;
use crate::types::var_int::VarInt;
use crate::types::var_long::VarLong;
@@ -21,6 +21,7 @@ impl McRead for Data {
Ok(Self {})
}
}
impl crate::types::package::ProtocolDataMarker for Data {}
#[derive(Debug, Clone)]
pub struct ResponseData {

View File

@@ -1,12 +1,10 @@
use crate::protocols::{
self, LoginProtocolIds, NotConnectedProtocolIds, ProtocolId, ProtocolResponseId,
StatusProtocolIds,
};
use crate::protocols::{self, NotImplementedData, ProtocolResponseId};
use crate::types::string::McString;
use crate::types::var_int::VarInt;
use crate::types::{McRead, McWrite};
use crate::types::{McRead, McRustRepr, McWrite};
use crate::utils::RWStreamWithLimit;
use crate::ConnectionState;
use crate::{types, ConnectionState};
use mc_rust_server_macros::McProtocol;
use num_traits::ToPrimitive;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufWriter};
@@ -17,7 +15,6 @@ pub enum Package {
}
#[derive(Debug, Clone)]
pub struct IncomingPackage {
pub(crate) protocol: ProtocolId,
pub(crate) content: IncomingPackageContent,
}
#[derive(Debug, Clone)]
@@ -25,19 +22,24 @@ pub struct OutgoingPackage {
pub(crate) protocol: ProtocolResponseId,
pub(crate) content: OutgoingPackageContent,
}
impl OutgoingPackage {
pub fn empty() {}
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, McProtocol)]
pub enum IncomingPackageContent {
#[protocol_read(ConnectionState::NotConnected, 0x00)]
Handshake(protocols::handshake::Data),
#[protocol_read(ConnectionState::Status, 0x00)]
Status(protocols::status::Data),
#[protocol_read(ConnectionState::Status, 0x01)]
Ping(protocols::ping::Data),
#[protocol_read(ConnectionState::Play, 0x7a)]
CustomReportDetails(protocols::custom_report_details::Data),
LoginStart(),
EncryptionResponse(),
LoginPluginResponse(),
CookieResponse(),
#[protocol_read(ConnectionState::Login, 0x00)]
LoginStart(NotImplementedData),
#[protocol_read(ConnectionState::Login, 0x01)]
EncryptionResponse(NotImplementedData),
#[protocol_read(ConnectionState::Login, 0x02)]
LoginPluginResponse(NotImplementedData),
#[protocol_read(ConnectionState::Login, 0x03)]
CookieResponse(NotImplementedData),
}
#[derive(Debug, Clone)]
pub enum OutgoingPackageContent {
@@ -109,48 +111,76 @@ impl McWrite for OutgoingPackage {
}
}
impl IncomingPackage {
pub(crate) async fn handle_incoming<T: AsyncRead + AsyncWrite + Unpin>(
stream: &mut RWStreamWithLimit<'_, T>,
connection_state: ConnectionState,
) -> Result<Option<ConnectionState>, String> {
let packet_id = VarInt::read_stream(stream)
.await
.map_err(|e| e.to_string())?;
println!(
"Handling new Package with id: {:0>2x} =======================",
packet_id.as_rs()
);
let incoming_content =
IncomingPackageContent::read_protocol_data(packet_id, connection_state, stream).await;
let incoming = IncomingPackage {
content: match incoming_content {
Ok(incoming) => incoming,
Err(e) => {
stream.discard_unread().await.map_err(|x| x.to_string())?;
return Err(e);
}
},
};
let res = incoming.answer(stream).await;
match res {
Ok(connection_state_change) => {
println!("Success!");
return Ok(connection_state_change);
}
Err(terminate_connection) => {
if terminate_connection {
return Err("Something terrible has happened!".to_string());
} else {
stream.discard_unread().await.map_err(|x| x.to_string())?;
}
println!("Failure :(");
}
}
Ok(None)
}
async fn answer<T: AsyncRead + AsyncWrite + Unpin>(
&self,
stream: &mut RWStreamWithLimit<'_, T>,
) -> Result<Option<ConnectionState>, bool> {
let (answer, changed_connection_state) = match self.protocol {
ProtocolId::NotConnected(protocol_id) => match &self.content {
(IncomingPackageContent::Handshake(handshake_data)) => {
(None, Some(handshake_data.next_state))
}
_ => (None, None), //Ignore all packets that do not belong here
},
ProtocolId::Status(protocol_id) => match &self.content {
IncomingPackageContent::Status(_) => (
Some(OutgoingPackage {
protocol: ProtocolResponseId::Status,
content: OutgoingPackageContent::StatusResponse(
protocols::status::ResponseData::default(),
),
let (answer, changed_connection_state) = match &self.content {
(IncomingPackageContent::Handshake(handshake_data)) => {
(None, Some(handshake_data.next_state))
}
IncomingPackageContent::Status(_) => (
Some(OutgoingPackage {
protocol: ProtocolResponseId::Status,
content: OutgoingPackageContent::StatusResponse(
protocols::status::ResponseData::default(),
),
}),
None,
),
(IncomingPackageContent::Ping(ping_data)) => (
Some(OutgoingPackage {
protocol: ProtocolResponseId::Ping,
content: OutgoingPackageContent::PingResponse(protocols::ping::ResponseData {
timespan: ping_data.timespan,
}),
None,
),
(IncomingPackageContent::Ping(ping_data)) => (
Some(OutgoingPackage {
protocol: ProtocolResponseId::Ping,
content: OutgoingPackageContent::PingResponse(
protocols::ping::ResponseData {
timespan: ping_data.timespan,
},
),
}),
None,
),
_ => (None, None), //Ignore all packets that do not belong here
},
// ProtocolId::Login(protocol_id) => match (protocol_id, &self.content) {
// (LoginProtocolIds::LoginStart, _) => (None, None),
// (LoginProtocolIds::EncryptionResponse, _) => (None, None),
// (LoginProtocolIds::LoginPluginResponse, _) => (None, None),
// (LoginProtocolIds::CookieResponse, _) => (None, None),
// },
_ => (None, None), //TODO: implement the other ProtocolId variants (based on current ConnectionState)
}),
None,
),
_ => (None, None), //TODO: implement the other IncomingPackageContent variants
};
if let Some(outgoing_package) = answer {
outgoing_package.write_stream(stream).await.map_err(|e| {
@@ -161,61 +191,27 @@ impl IncomingPackage {
Ok(changed_connection_state)
}
}
impl Package {
pub(crate) async fn handle<T: AsyncRead + AsyncWrite + Unpin>(
protocol_id: ProtocolId,
stream: &mut RWStreamWithLimit<'_, T>,
) -> Result<Option<ConnectionState>, bool> {
let incoming_content = read_data(protocol_id, stream).await.map_err(|e| {
dbg!(e);
true
})?;
let incoming = IncomingPackage {
protocol: protocol_id,
content: incoming_content,
};
incoming.answer(stream).await
pub(crate) trait ProtocolDataMarker {}
pub(crate) trait ProtocolData {
async fn read_data<T: AsyncRead + Unpin>(stream: &mut T) -> Result<Self, String>
where
Self: Sized;
}
impl<T> ProtocolData for T
where
T: ProtocolDataMarker + McRead,
{
async fn read_data<Stream: AsyncRead + Unpin>(stream: &mut Stream) -> Result<Self, String>
where
Self: Sized,
{
Self::read_stream(stream).await
}
}
pub async fn read_data<T: AsyncRead + AsyncWrite + Unpin>(
protocol_id: ProtocolId,
stream: &mut RWStreamWithLimit<'_, T>,
) -> Result<IncomingPackageContent, String> {
Ok(match protocol_id {
ProtocolId::NotConnected(protocol_id) => match protocol_id {
NotConnectedProtocolIds::Handshake => IncomingPackageContent::Handshake(
protocols::handshake::Data::read_stream(stream).await?,
),
},
ProtocolId::Status(protocol_id) => match protocol_id {
StatusProtocolIds::Status => {
IncomingPackageContent::Status(protocols::status::Data::read_stream(stream).await?)
}
StatusProtocolIds::Ping => {
IncomingPackageContent::Ping(protocols::ping::Data::read_stream(stream).await?)
}
},
ProtocolId::Login(protocol_id) => match protocol_id {
LoginProtocolIds::LoginStart => {
stream.discard_unread().await.map_err(|e| e.to_string())?;
IncomingPackageContent::LoginStart()
}
LoginProtocolIds::EncryptionResponse => {
stream.discard_unread().await.map_err(|e| e.to_string())?;
IncomingPackageContent::EncryptionResponse()
}
LoginProtocolIds::LoginPluginResponse => {
stream.discard_unread().await.map_err(|e| e.to_string())?;
IncomingPackageContent::LoginPluginResponse()
}
LoginProtocolIds::CookieResponse => {
stream.discard_unread().await.map_err(|e| e.to_string())?;
IncomingPackageContent::CookieResponse()
}
},
ProtocolId::Transfer(protocol_id) => match protocol_id {},
ProtocolId::Configuration(protocol_id) => match protocol_id {},
ProtocolId::Play(protocol_id) => match protocol_id {},
})
async fn read_protocol_data<S, T: AsyncRead + Unpin>(stream: &mut T) -> Result<S, String>
where
S: Sized + ProtocolData,
{
S::read_data::<T>(stream).await
}