From ef8f7e8383f36a8e76ba9dec1553d844e85a5141 Mon Sep 17 00:00:00 2001 From: OMGeeky Date: Sat, 16 Nov 2024 14:56:42 +0100 Subject: [PATCH] make package_id parsing connection state dependent --- src/main.rs | 57 +++++------------- src/protocols.rs | 55 +++++++++++++----- src/types/package.rs | 135 +++++++++++++++++++++++++++++-------------- 3 files changed, 148 insertions(+), 99 deletions(-) diff --git a/src/main.rs b/src/main.rs index 66cd513..440cc36 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,7 @@ pub mod protocols; pub mod types; pub mod utils; +use crate::protocols::ProtocolId; use crate::types::string::McString; use crate::types::var_int::VarInt; use crate::types::{McRead, McRustRepr}; @@ -34,11 +35,13 @@ async fn main() -> Result<(), ()> { } } #[derive(FromPrimitive, Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq)] -enum ConnectionState { +pub(crate) enum ConnectionState { NotConnected = 0, Status = 1, Login = 2, Transfer = 3, + Configuration = 4, + Play = 5, ///Internal use Closed = -1, } @@ -48,6 +51,10 @@ struct Connection { compression_active: bool, } impl Connection { + async fn shutdown(&mut self) { + self.connection_state = ConnectionState::Closed; + self.tcp_stream.shutdown().await; + } async fn handle(&mut self) -> Result<(), String> { while self.connection_state != ConnectionState::Closed { let x = self.tcp_stream.peek(&mut [0]).await; //see if we have at least one byte available @@ -72,11 +79,7 @@ impl Connection { if *length == 0xFE { //Legacy Ping (see https://wiki.vg/Server_List_Ping#1.6) let x = handle_legacy_ping(&mut self.tcp_stream).await; - self.connection_state = ConnectionState::Closed; - self.tcp_stream.shutdown().await.map_err(|e| { - dbg!(e); - "?" - })?; + self.shutdown().await; continue; } println!("packet length: {}", length.as_rs()); @@ -102,8 +105,7 @@ impl Connection { self.connection_state = new_connection_state; } Err(e) => { - self.connection_state = ConnectionState::Closed; - dbg!(&self.tcp_stream.shutdown().await); + self.shutdown().await; println!("Got an error during package handling: {e}"); } } @@ -111,32 +113,6 @@ impl Connection { Ok(()) } - async fn handshake( - stream: &mut T, - _compression: bool, - // bytes_left_in_package: &mut i32, - ) -> Result { - let handshake_data = protocols::handshake::Data::read_stream(stream).await?; - // dbg!(&handshake_data); - Ok(handshake_data.next_state) - // let protocol_version = VarInt::read_stream(stream).await?; - // println!("protocol version: {}", protocol_version.as_rs()); - // let address: McString<255> = McString::read_stream(stream) - // .await - // .map_err(|_| "Could not read string".to_string())?; - // println!("address: '{}'", address.as_rs()); - // stream.discard(2).await.unwrap(); //server port. Unused - // let next_state_id = VarInt::read_stream(stream).await?; - // println!("next state: {}", next_state_id.as_rs()); - // let next_state = FromPrimitive::from_i32(next_state_id.to_rs()); - // match next_state { - // Some(next_state) => Ok(next_state), - // None => Err(format!( - // "Got an unknown next state: {}", - // next_state_id.as_rs() - // )), - // } - } async fn handle_package( stream: &mut RWStreamWithLimit<'_, T>, connection_state: ConnectionState, @@ -148,16 +124,15 @@ impl Connection { "Handling new Package with id: {:0>2x} =======================", packet_id.as_rs() ); - if connection_state == ConnectionState::NotConnected && packet_id.to_rs() == 0x00 { - return Self::handshake(stream, compression).await; - } - match FromPrimitive::from_i32(packet_id.to_rs()) { + match ProtocolId::from_id_and_state(packet_id.to_rs(), connection_state) { Some(protocol) => { let res = types::package::Package::handle(protocol, stream).await; - // let res = protocols::handle(protocol, stream).await; match res { - Ok(_) => { + Ok(connection_state_change) => { println!("Success!"); + if let Some(connection_state_change) = connection_state_change { + return Ok(connection_state_change); + } } Err(terminate_connection) => { if terminate_connection { @@ -171,8 +146,6 @@ impl Connection { } None => { stream.discard_unread().await.map_err(|x| x.to_string())?; - // *bytes_left_in_package -= discard_read(stream, *bytes_left_in_package as u8) - // .map_err(|x| x.to_string())? as i32; println!("I don't know this protocol yet, so Im gonna ignore it..."); } } diff --git a/src/protocols.rs b/src/protocols.rs index 837c174..5d9e723 100644 --- a/src/protocols.rs +++ b/src/protocols.rs @@ -1,29 +1,58 @@ use crate::utils::RWStreamWithLimit; +use crate::ConnectionState; use num_derive::{FromPrimitive, ToPrimitive}; +use num_traits::FromPrimitive; use tokio::io::{AsyncRead, AsyncWrite}; #[derive(FromPrimitive, ToPrimitive, Debug, Copy, Clone)] -pub enum ProtocolId { +pub enum NotConnectedProtocolIds { + Handshake = 0x00, +} +#[derive(FromPrimitive, ToPrimitive, Debug, Copy, Clone)] +pub enum StatusProtocolIds { Status = 0x00, Ping = 0x01, - CustomReportDetails = 0x7a, +} +#[derive(FromPrimitive, ToPrimitive, Debug, Copy, Clone)] +pub enum LoginProtocolIds { + LoginStart = 0x00, + EncryptionResponse = 0x01, + LoginPluginResponse = 0x02, + CookieResponse = 0x04, +} +#[derive(FromPrimitive, ToPrimitive, Debug, Copy, Clone)] +pub enum TransferProtocolIds {} +#[derive(FromPrimitive, ToPrimitive, Debug, Copy, Clone)] +pub enum ConfigurationProtocolIds {} +#[derive(FromPrimitive, ToPrimitive, Debug, Copy, Clone)] +pub enum PlayProtocolIds {} +#[derive(Debug, Copy, Clone)] +pub enum ProtocolId { + NotConnected(NotConnectedProtocolIds), + Status(StatusProtocolIds), + Login(LoginProtocolIds), + Transfer(TransferProtocolIds), + Configuration(ConfigurationProtocolIds), + Play(PlayProtocolIds), + // CustomReportDetails = 0x7a, } #[derive(FromPrimitive, ToPrimitive, Debug, Copy, Clone)] pub enum ProtocolResponseId { Status = 0x00, Ping = 0x01, } -pub async fn handle( - protocol: ProtocolId, - stream: &mut RWStreamWithLimit<'_, T>, - // bytes_left_in_package: &mut i32, -) -> Result<(), bool> { - match protocol { - ProtocolId::Status => status::Protocol::handle(stream).await?, - ProtocolId::Ping => ping::Protocol::handle(stream).await?, - ProtocolId::CustomReportDetails => custom_report_details::Protocol::handle(stream).await?, - }; - Ok(()) +impl ProtocolId { + pub(crate) fn from_id_and_state(id: i32, state: ConnectionState) -> Option { + Some(match state { + ConnectionState::NotConnected => Self::NotConnected(FromPrimitive::from_i32(id)?), + ConnectionState::Status => Self::Status(FromPrimitive::from_i32(id)?), + ConnectionState::Login => Self::Login(FromPrimitive::from_i32(id)?), + ConnectionState::Transfer => Self::Transfer(FromPrimitive::from_i32(id)?), + ConnectionState::Configuration => Self::Configuration(FromPrimitive::from_i32(id)?), + ConnectionState::Play => Self::Play(FromPrimitive::from_i32(id)?), + ConnectionState::Closed => return None, + }) + } } pub(crate) mod custom_report_details; pub(crate) mod handshake; diff --git a/src/types/package.rs b/src/types/package.rs index e0f3957..6d05c66 100644 --- a/src/types/package.rs +++ b/src/types/package.rs @@ -1,8 +1,12 @@ -use crate::protocols::{self, ProtocolId, ProtocolResponseId}; +use crate::protocols::{ + self, LoginProtocolIds, NotConnectedProtocolIds, ProtocolId, ProtocolResponseId, + StatusProtocolIds, +}; use crate::types::string::McString; use crate::types::var_int::VarInt; use crate::types::{McRead, McWrite}; use crate::utils::RWStreamWithLimit; +use crate::ConnectionState; use num_traits::ToPrimitive; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufWriter}; @@ -26,15 +30,19 @@ impl OutgoingPackage { } #[derive(Debug, Clone)] pub enum IncomingPackageContent { - Handshake(crate::protocols::handshake::Data), - Status(crate::protocols::status::Data), - Ping(crate::protocols::ping::Data), - CustomReportDetails(crate::protocols::custom_report_details::Data), + Handshake(protocols::handshake::Data), + Status(protocols::status::Data), + Ping(protocols::ping::Data), + CustomReportDetails(protocols::custom_report_details::Data), + LoginStart(), + EncryptionResponse(), + LoginPluginResponse(), + CookieResponse(), } #[derive(Debug, Clone)] pub enum OutgoingPackageContent { - StatusResponse(crate::protocols::status::ResponseData), - PingResponse(crate::protocols::ping::ResponseData), + StatusResponse(protocols::status::ResponseData), + PingResponse(protocols::ping::ResponseData), } impl McWrite for OutgoingPackageContent { type Error = String; @@ -104,22 +112,45 @@ impl IncomingPackage { async fn answer( &self, stream: &mut RWStreamWithLimit<'_, T>, - ) -> Result<(), bool> { - let answer = match (&self.protocol, &self.content) { - (ProtocolId::Status, _) => Some(OutgoingPackage { - protocol: ProtocolResponseId::Status, - content: OutgoingPackageContent::StatusResponse( - protocols::status::ResponseData::default(), + ) -> Result, bool> { + let (answer, changed_connection_state) = match self.protocol { + ProtocolId::NotConnected(protocol_id) => match &self.content { + (IncomingPackageContent::Handshake(handshake_data)) => { + (None, Some(handshake_data.next_state)) + } + _ => (None, None), //Ignore all packets that do not belong here + }, + ProtocolId::Status(protocol_id) => match &self.content { + IncomingPackageContent::Status(_) => ( + Some(OutgoingPackage { + protocol: ProtocolResponseId::Status, + content: OutgoingPackageContent::StatusResponse( + protocols::status::ResponseData::default(), + ), + }), + None, ), - }), - (ProtocolId::Ping, IncomingPackageContent::Ping(ping_data)) => Some(OutgoingPackage { - protocol: ProtocolResponseId::Ping, - content: OutgoingPackageContent::PingResponse(protocols::ping::ResponseData { - timespan: ping_data.timespan, - }), - }), - (ProtocolId::Ping, _) => unreachable!(), - (ProtocolId::CustomReportDetails, _) => None, + (IncomingPackageContent::Ping(ping_data)) => ( + Some(OutgoingPackage { + protocol: ProtocolResponseId::Ping, + content: OutgoingPackageContent::PingResponse( + protocols::ping::ResponseData { + timespan: ping_data.timespan, + }, + ), + }), + None, + ), + _ => (None, None), //Ignore all packets that do not belong here + }, + + // ProtocolId::Login(protocol_id) => match (protocol_id, &self.content) { + // (LoginProtocolIds::LoginStart, _) => (None, None), + // (LoginProtocolIds::EncryptionResponse, _) => (None, None), + // (LoginProtocolIds::LoginPluginResponse, _) => (None, None), + // (LoginProtocolIds::CookieResponse, _) => (None, None), + // }, + _ => (None, None), //TODO: implement the other ProtocolId variants (based on current ConnectionState) }; if let Some(outgoing_package) = answer { outgoing_package.write_stream(stream).await.map_err(|e| { @@ -127,14 +158,14 @@ impl IncomingPackage { false })?; } - Ok(()) + Ok(changed_connection_state) } } impl Package { - pub async fn handle( + pub(crate) async fn handle( protocol_id: ProtocolId, stream: &mut RWStreamWithLimit<'_, T>, - ) -> Result<(), bool> { + ) -> Result, bool> { let incoming_content = read_data(protocol_id, stream).await.map_err(|e| { dbg!(e); true @@ -143,8 +174,7 @@ impl Package { protocol: protocol_id, content: incoming_content, }; - incoming.answer(stream).await?; - Ok(()) + incoming.answer(stream).await } } @@ -153,22 +183,39 @@ pub async fn read_data( stream: &mut RWStreamWithLimit<'_, T>, ) -> Result { Ok(match protocol_id { - ProtocolId::Status => { - IncomingPackageContent::Status(protocols::status::Data::read_stream(stream).await?) - } - ProtocolId::Ping => { - IncomingPackageContent::Ping(protocols::ping::Data::read_stream(stream).await?) - } - ProtocolId::CustomReportDetails => { - // return Err("Not implemented".to_string()); - let x = IncomingPackageContent::CustomReportDetails( - protocols::custom_report_details::Data::read_stream(stream).await?, - ); - stream.discard_unread().await.map_err(|e| { - dbg!(e); - "Could not discard unused stuff" - })?; - x - } + ProtocolId::NotConnected(protocol_id) => match protocol_id { + NotConnectedProtocolIds::Handshake => IncomingPackageContent::Handshake( + protocols::handshake::Data::read_stream(stream).await?, + ), + }, + ProtocolId::Status(protocol_id) => match protocol_id { + StatusProtocolIds::Status => { + IncomingPackageContent::Status(protocols::status::Data::read_stream(stream).await?) + } + StatusProtocolIds::Ping => { + IncomingPackageContent::Ping(protocols::ping::Data::read_stream(stream).await?) + } + }, + ProtocolId::Login(protocol_id) => match protocol_id { + LoginProtocolIds::LoginStart => { + stream.discard_unread().await.map_err(|e| e.to_string())?; + IncomingPackageContent::LoginStart() + } + LoginProtocolIds::EncryptionResponse => { + stream.discard_unread().await.map_err(|e| e.to_string())?; + IncomingPackageContent::EncryptionResponse() + } + LoginProtocolIds::LoginPluginResponse => { + stream.discard_unread().await.map_err(|e| e.to_string())?; + IncomingPackageContent::LoginPluginResponse() + } + LoginProtocolIds::CookieResponse => { + stream.discard_unread().await.map_err(|e| e.to_string())?; + IncomingPackageContent::CookieResponse() + } + }, + ProtocolId::Transfer(protocol_id) => match protocol_id {}, + ProtocolId::Configuration(protocol_id) => match protocol_id {}, + ProtocolId::Play(protocol_id) => match protocol_id {}, }) }