transition to async stream & try to get some other protocol stuff to work

This commit is contained in:
OMGeeky
2024-11-13 21:36:01 +01:00
parent 621249d50d
commit 002e8e8e9b
10 changed files with 645 additions and 150 deletions

View File

@@ -2,47 +2,34 @@ pub mod protocols;
pub mod types;
pub mod utils;
use crate::protocols::handle;
use crate::types::string::McString;
use crate::types::var_int::VarInt;
use crate::types::{McRead, McRustRepr};
use crate::utils::RWStreamWithLimit;
use num_derive::FromPrimitive;
use num_traits::{FromPrimitive, ToPrimitive};
use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::thread;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::net::TcpStream;
fn main() {
#[tokio::main]
async fn main() -> Result<(), ()> {
println!("Hello, world!");
let listener = TcpListener::bind("127.0.0.1:25565").unwrap();
// let listener = TcpListener::bind("127.0.0.1:25565").unwrap();
let listener = tokio::net::TcpListener::bind("127.0.0.1:25565")
.await
.unwrap();
println!("Listening started.");
for stream in listener.incoming() {
match stream {
Ok(stream) => {
thread::spawn(|| {
println!("===============START=====================");
loop {
let (stream, socket) = listener.accept().await.map_err(|x| {
dbg!(x);
})?;
stream
.set_read_timeout(Some(Duration::from_secs(3)))
.unwrap();
stream
.set_write_timeout(Some(Duration::from_secs(3)))
.unwrap();
println!(
"Timeout for connection: {:?}/{:?}",
stream.read_timeout(),
stream.write_timeout()
);
handle_connection(stream);
println!("===============DONE======================");
});
}
Err(err) => {
dbg!(err);
}
}
tokio::spawn(async move {
println!("===============START=====================");
dbg!(&socket);
handle_connection(stream).await;
println!("===============DONE======================");
});
}
}
#[derive(FromPrimitive, Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq)]
@@ -59,12 +46,11 @@ struct Connection {
compression_active: bool,
}
impl Connection {
fn handle(&mut self) -> Result<(), String> {
async fn handle(&mut self) -> Result<(), String> {
while self.connection_state != ConnectionState::Closed {
let x = self.tcp_stream.peek(&mut [0]); //see if we have at least one byte available
let x = self.tcp_stream.peek(&mut [0]).await; //see if we have at least one byte available
match x {
Ok(size) => {
println!("we should have 1 here: {size}");
if size == 0 {
println!("Reached end of stream.");
self.connection_state = ConnectionState::Closed;
@@ -76,7 +62,7 @@ impl Connection {
}
}
let length = VarInt::read_stream(&mut self.tcp_stream)?;
let length = VarInt::read_stream(&mut self.tcp_stream).await?;
println!("packet length: {}", length.as_rs());
let bytes_left_in_package = length.to_rs();
@@ -88,7 +74,8 @@ impl Connection {
&mut package_stream,
self.connection_state,
self.compression_active,
);
)
.await;
match result {
Ok(new_connection_state) => {
assert_eq!(
@@ -99,10 +86,7 @@ impl Connection {
self.connection_state = new_connection_state;
}
Err(e) => {
//discard rest of package for failed ones
discard_read(&mut self.tcp_stream, bytes_left_in_package.to_u8().unwrap())
.map_err(|x| x.to_string())?;
self.connection_state = ConnectionState::Closed;
println!("Got an error during package handling: {e}");
}
}
@@ -110,26 +94,21 @@ impl Connection {
Ok(())
}
fn handshake<T: Read + Write>(
async fn handshake<T: AsyncRead + AsyncWrite + Unpin>(
stream: &mut T,
_compression: bool,
// bytes_left_in_package: &mut i32,
) -> Result<ConnectionState, String> {
// println!("bytes left:{}", bytes_left_in_package);
let protocol_version = VarInt::read_stream(stream)?;
// *bytes_left_in_package -= read as i32;
println!("Handshake");
let protocol_version = VarInt::read_stream(stream).await?;
println!("protocol version: {}", protocol_version.as_rs());
// println!("bytes left:{}", bytes_left_in_package);
let address =
McString::read_stream(stream).map_err(|_| "Could not read string".to_string())?;
// *bytes_left_in_package -= read as i32;
let address: McString<255> = McString::read_stream(stream)
.await
.map_err(|_| "Could not read string".to_string())?;
println!("address: '{}'", address.as_rs());
stream.read_exact(&mut [0, 2]).unwrap(); //server port. Unused
// *bytes_left_in_package -= 2;
let next_state_id = VarInt::read_stream(stream)?;
// *bytes_left_in_package -= read as i32;
stream.read_exact(&mut [0, 2]).await.unwrap(); //server port. Unused
let next_state_id = VarInt::read_stream(stream).await?;
println!("next state: {}", next_state_id.as_rs());
// println!("bytes left:{}", bytes_left_in_package);
let next_state = FromPrimitive::from_i32(next_state_id.to_rs());
match next_state {
Some(next_state) => Ok(next_state),
@@ -139,42 +118,36 @@ impl Connection {
)),
}
}
fn handle_package<T: Read + Write>(
stream: &mut RWStreamWithLimit<T>,
async fn handle_package<T: AsyncRead + AsyncWrite + Unpin>(
stream: &mut RWStreamWithLimit<'_, T>,
connection_state: ConnectionState,
compression: bool,
// bytes_left_in_package: usize,
) -> Result<ConnectionState, String> {
// let mut stream = RWStreamWithLimit::new(stream, bytes_left_in_package);
// let stream = &mut stream;
let packet_id = VarInt::read_stream(stream)?;
// *bytes_left_in_package = i32::max(*bytes_left_in_package - read as i32, 0);
let packet_id = VarInt::read_stream(stream).await?;
println!("id: {:0>2x}", packet_id.as_rs());
if connection_state == ConnectionState::NotConnected && packet_id.to_rs() == 0x00 {
return Self::handshake(stream, compression);
return Self::handshake(stream, compression).await;
}
match FromPrimitive::from_i32(packet_id.to_rs()) {
Some(protocol) => {
// println!("bytes left:{}", bytes_left_in_package);
let res = handle(protocol, stream);
// println!("bytes left:{}", bytes_left_in_package);
let res = protocols::handle(protocol, stream).await;
match res {
Ok(_) => {
// println!("bytes left:{}", bytes_left_in_package);
println!("Success!");
}
Err(_) => {
stream.discard_unread().map_err(|x| x.to_string())?;
// println!("bytes left:{}", bytes_left_in_package);
// *bytes_left_in_package -= discard_read(stream, *bytes_left_in_package as u8)
// as i32;
Err(terminate_connection) => {
if terminate_connection {
return Err("Something terrible has happened!".to_string());
} else {
stream.discard_unread().await.map_err(|x| x.to_string())?;
}
println!("Failure :(");
}
}
}
None => {
stream.discard_unread().map_err(|x| x.to_string())?;
stream.discard_unread().await.map_err(|x| x.to_string())?;
// *bytes_left_in_package -= discard_read(stream, *bytes_left_in_package as u8)
// .map_err(|x| x.to_string())? as i32;
println!("I don't know this protocol yet, so Im gonna ignore it...");
@@ -183,18 +156,14 @@ impl Connection {
Ok(connection_state)
}
}
fn handle_connection(stream: TcpStream) {
async fn handle_connection(stream: TcpStream) {
let mut connection = Connection {
connection_state: ConnectionState::NotConnected,
tcp_stream: stream,
compression_active: false,
};
let result = connection.handle();
let result = connection.handle().await;
if let Err(e) = result {
dbg!(e);
}
}
fn discard_read<T: Read>(stream: &mut T, bytes: u8) -> Result<usize, std::io::Error> {
stream.read_exact(&mut [0, bytes])?;
Ok(bytes as usize)
}

View File

@@ -1,23 +1,24 @@
use crate::protocols::status::StatusProtocol;
use crate::utils::RWStreamWithLimit;
use num_derive::FromPrimitive;
use std::io::{Read, Write};
// use num_traits::FromPrimitive;
use tokio::io::{AsyncRead, AsyncWrite};
#[derive(FromPrimitive)]
pub enum Protocols {
Status = 0x00,
Ping = 0x01,
CustomReportDetails = 0x7a,
}
pub fn handle<T: Read + Write>(
pub async fn handle<T: AsyncRead + AsyncWrite + Unpin>(
protocol: Protocols,
stream: &mut RWStreamWithLimit<T>,
stream: &mut RWStreamWithLimit<'_, T>,
// bytes_left_in_package: &mut i32,
) -> Result<(), ()> {
) -> Result<(), bool> {
match protocol {
Protocols::Status => StatusProtocol::handle(stream)?,
Protocols::Status => status::Protocol::handle(stream).await?,
Protocols::Ping => {}
Protocols::CustomReportDetails => custom_report_details::Protocol::handle(stream).await?,
};
Ok(())
}
mod custom_report_details;
mod status;

View File

@@ -0,0 +1,42 @@
use crate::types::var_int::VarInt;
use crate::types::McRead;
use crate::utils::RWStreamWithLimit;
use tokio::io::{AsyncRead, AsyncWrite};
pub struct Protocol {}
impl Protocol {
pub async fn handle<T: AsyncRead + AsyncWrite + Unpin>(
stream: &mut RWStreamWithLimit<'_, T>,
// bytes_left_in_package: &mut i32,
) -> Result<(), bool> {
let count = VarInt::read_stream(stream).await.map_err(|x| {
dbg!(x);
true
})?;
dbg!(&count);
let string_size = VarInt::read_stream(stream).await.map_err(|x| {
dbg!(x);
true
})?;
dbg!(&string_size);
stream.discard_unread().await.map_err(|x| {
dbg!(x);
true
})?;
// for i in 0..*count {
// let title = McString::read_stream(stream).await.map_err(|x| {
// dbg!(x);
// })?;
// let description = McString::read_stream(stream).await.map_err(|x| {
// dbg!(x);
// })?;
// println!(
// "Read title & description fo some custom report ({i}): {}\n{}",
// title.as_rs(),
// description.as_rs()
// );
// }
Ok(())
}
}

View File

@@ -1,23 +1,33 @@
use crate::types::string::McString;
use crate::types::var_int::VarInt;
use crate::types::McWrite;
use crate::utils::RWStreamWithLimit;
use std::io::{Read, Write};
use tokio::io::{AsyncRead, AsyncWrite};
pub struct StatusProtocol {}
pub struct Protocol {}
impl StatusProtocol {
pub fn handle<T: Read + Write>(
stream: &mut RWStreamWithLimit<T>,
impl Protocol {
pub async fn handle<T: AsyncRead + AsyncWrite + Unpin>(
stream: &mut RWStreamWithLimit<'_, T>,
// bytes_left_in_package: &mut i32,
) -> Result<(), ()> {
McString(Self::get_sample_result())
) -> Result<(), bool> {
println!("Status");
VarInt(0x01).write_stream(stream).await.map_err(|x| {
dbg!(x);
false
})?;
McString::<32767>::from_string(Self::get_sample_result())
.write_stream(stream)
.await
.map_err(|x| {
dbg!(x);
false
})?;
stream.discard_unread().map_err(|x| {
dbg!(x);
})?;
// stream.discard_unread().await.map_err(|x| {
// dbg!(x);
// false
// })?;
// *bytes_left_in_package = 0;
Ok(())
}

View File

@@ -1,14 +1,17 @@
use std::io::{Read, Write};
use tokio::io::{AsyncRead, AsyncWrite};
pub trait McRead {
pub(crate) trait McRead {
type Error;
fn read_stream<T: Read>(stream: &mut T) -> Result<Self, Self::Error>
async fn read_stream<T: AsyncRead + Unpin>(stream: &mut T) -> Result<Self, Self::Error>
where
Self: Sized;
}
pub trait McWrite {
pub(crate) trait McWrite {
type Error;
fn write_stream<T: Write>(&self, stream: &mut T) -> Result<usize, Self::Error>
async fn write_stream<T: AsyncWrite + Unpin>(
&self,
stream: &mut T,
) -> Result<usize, Self::Error>
where
Self: Sized;
}

View File

@@ -1,58 +1,77 @@
use crate::types::var_int::VarInt;
use crate::types::{McRead, McRustRepr, McWrite};
use std::io::{Read, Write};
pub struct McString(pub String);
impl McRead for McString {
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
pub struct McString<const MAX_SIZE: usize> {
pub value: String,
}
impl<const MAX_SIZE: usize> McString<MAX_SIZE> {
fn measure_size(s: &str) -> usize {
s.len()
}
pub fn from_string(s: String) -> Self {
Self { value: s }
}
}
impl<const MAX_SIZE: usize> McRead for McString<MAX_SIZE> {
type Error = ();
fn read_stream<T: Read>(b: &mut T) -> Result<Self, Self::Error>
async fn read_stream<T: AsyncRead + Unpin>(b: &mut T) -> Result<Self, Self::Error>
where
Self: Sized,
{
let size = VarInt::read_stream(b).map_err(|x| {
let max_size = VarInt::read_stream(b).await.map_err(|x| {
dbg!(x);
})?;
let size = *size as usize;
let size = *max_size as usize;
// Check if the size exceeds the maximum allowed length (n)
if size > (MAX_SIZE * 3) + 3 {
return Err(()); // Or a more specific error type
}
let mut bytes = vec![0u8; size];
let actual_size = b.read(&mut bytes).map_err(|x| {
let actual_size = b.read(&mut bytes).await.map_err(|x| {
dbg!(x);
})?;
assert_eq!(size, actual_size);
let value = String::from_utf8(bytes).map_err(|x| {
dbg!(x);
})?;
Ok(Self(value))
Ok(Self { value })
}
}
impl McWrite for McString {
impl<const MAX_SIZE: usize> McWrite for McString<MAX_SIZE> {
type Error = std::io::Error;
fn write_stream<T: Write>(&self, stream: &mut T) -> Result<usize, Self::Error>
async fn write_stream<T: AsyncWrite + Unpin>(
&self,
stream: &mut T,
) -> Result<usize, Self::Error>
where
Self: Sized,
{
let buf = self.0.as_bytes();
let length = buf.len(); //This does not actually count right (see https://wiki.vg/Protocol#Type:String)
VarInt(length as i32).write_stream(stream)?;
let buf = self.value.as_bytes();
let length = Self::measure_size(&self.value);
VarInt(length as i32).write_stream(stream).await?;
stream.write_all(buf)?;
stream.write_all(buf).await?;
Ok(length)
}
}
impl McRustRepr for McString {
impl<const MAX_SIZE: usize> McRustRepr for McString<MAX_SIZE> {
type RustRepresentation = String;
fn into_rs(self) -> Self::RustRepresentation {
self.0
self.value
}
fn to_rs(&self) -> Self::RustRepresentation {
self.0.to_owned()
self.value.to_owned()
}
fn as_rs(&self) -> &Self::RustRepresentation {
&self.0
&self.value
}
}

View File

@@ -1,6 +1,7 @@
use crate::types::{McRead, McRustRepr, McWrite};
use std::io::{Read, Write};
use std::ops::Deref;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[derive(Debug, Copy, Clone)]
pub struct VarInt(pub i32);
@@ -14,27 +15,30 @@ impl Deref for VarInt {
impl McWrite for VarInt {
type Error = std::io::Error;
fn write_stream<T: Write>(&self, stream: &mut T) -> Result<usize, Self::Error>
async fn write_stream<T: AsyncWrite + Unpin>(
&self,
stream: &mut T,
) -> Result<usize, Self::Error>
where
Self: Sized,
{
let mut value = self.0 as u32;
loop {
if (value & Self::SEGMENT_BITS as u32) == 0 {
let _ = stream.write(&[value.to_le_bytes()[0]])?;
let _ = stream.write(&[value.to_le_bytes()[0]]).await?;
return Ok(1);
}
let x = value & Self::SEGMENT_BITS as u32 | Self::CONTINUE_BIT as u32;
let x = x.to_le_bytes()[0];
let _ = stream.write(&[x])?;
let _ = stream.write(&[x]).await?;
value >>= 7;
}
}
}
impl McRead for VarInt {
type Error = String;
fn read_stream<T: Read>(b: &mut T) -> Result<Self, Self::Error> {
async fn read_stream<T: AsyncRead + Unpin>(b: &mut T) -> Result<Self, Self::Error> {
let mut value = 0i32;
let mut position = 0;
// println!("CONTINUE bit: {:0>32b}", Self::CONTINUE_BIT);
@@ -43,6 +47,7 @@ impl McRead for VarInt {
loop {
let mut current_byte = 0u8;
b.read_exact(std::slice::from_mut(&mut current_byte))
.await
.map_err(|x| x.to_string())?;
// println!(
// "b: {:0>32b}\nm: {:0>32b}\nr: {:0>32b}\n>: {:0>32b} ({position})\nv: {:0>32b}\nr2:{:0>32b}",

View File

@@ -1,22 +1,44 @@
use std::io::{ErrorKind, Read, Write};
use std::io::ErrorKind;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
pub struct RWStreamWithLimit<'a, T: Read + Write> {
pub struct RWStreamWithLimit<'a, T: AsyncRead + AsyncWrite> {
stream: &'a mut T,
read_bytes_left: usize,
}
impl<'a, T: Read + Write> RWStreamWithLimit<'a, T> {
impl<'a, T: AsyncRead + AsyncWrite + Unpin> RWStreamWithLimit<'a, T> {
pub(crate) fn new(stream: &'a mut T, read_limit: usize) -> Self {
Self {
stream,
read_bytes_left: read_limit,
}
}
pub(crate) fn discard_unread(&mut self) -> std::io::Result<usize> {
pub(crate) async fn discard_unread(&mut self) -> std::io::Result<usize> {
let mut total_read = 0;
while self.read_bytes_left > 0 {
let read = self.stream.read(&mut vec![0; self.read_bytes_left])?;
println!("Discarding {} bytes...", self.read_bytes_left);
let read = self.stream.read(&mut vec![0; self.read_bytes_left]).await?;
total_read += read;
self.read_bytes_left -= read;
println!(
"Discarded {read}/{} remaining bytes ({total_read}/{} total)",
self.read_bytes_left + read,
self.read_bytes_left + total_read
);
if read == 0 {
const ERROR: &str = "Could not read a single byte";
println!("{}", ERROR);
return Err(std::io::Error::new(ErrorKind::Other, ERROR));
}
if self.read_bytes_left > 0 {
//IDK if this makes sense to just throw an error if we don't read all in one go?
const ERROR: &str = "Couldnt read all bytes in one go";
println!("{}", ERROR);
return Err(std::io::Error::new(ErrorKind::Other, ERROR));
}
println!("Done Discarding {total_read} bytes");
}
Ok(total_read)
}
@@ -24,35 +46,59 @@ impl<'a, T: Read + Write> RWStreamWithLimit<'a, T> {
self.read_bytes_left
}
}
impl<'a, T: Read + Write> Read for RWStreamWithLimit<'a, T> {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
let bytes_read;
if self.read_bytes_left > 0 {
if self.read_bytes_left >= buf.len() {
bytes_read = self.stream.read(buf)?;
} else {
println!("wants to read more than in the readable part of the stream");
bytes_read = self.stream.read(&mut buf[0..self.read_bytes_left])?;
//TODO: decide if we wanna throw an error here or nah
}
self.read_bytes_left -= bytes_read; //TODO: maybe check if we read to much?
} else {
return Err(std::io::Error::new(
impl<'a, T: AsyncRead + AsyncWrite + Unpin> AsyncRead for RWStreamWithLimit<'a, T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
if self.read_bytes_left == 0 {
return Poll::Ready(Err(std::io::Error::new(
ErrorKind::Other,
"There is nothing more to read in this package",
));
//TODO: maybe throw an error since there is no way anything gets read anymore?
)));
}
let self_mut = self.get_mut();
let stream = &mut self_mut.stream;
if self_mut.read_bytes_left < buf.remaining() {
println!("wants to read more than in the readable part of the stream. Only read readable part, to not screw up the next few parts");
}
let bytes_to_read = std::cmp::min(self_mut.read_bytes_left, buf.remaining());
let mut inner_buf = buf.take(bytes_to_read);
let read = Pin::new(stream).poll_read(cx, &mut inner_buf);
if let Poll::Ready(Ok(())) = read {
let bytes_read = inner_buf.filled().len();
self_mut.read_bytes_left -= bytes_read;
buf.advance(bytes_read); // Important: Advance the buffer
Poll::Ready(Ok(()))
} else {
read
}
Ok(bytes_read)
}
}
impl<'a, T: Read + Write> Write for RWStreamWithLimit<'a, T> {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.read_bytes_left = 0;
self.stream.write(buf)
impl<'a, T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for RWStreamWithLimit<'a, T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
let self_mut = self.get_mut();
self_mut.read_bytes_left = 0;
let stream = &mut self_mut.stream;
Pin::new(stream).poll_write(cx, buf)
}
fn flush(&mut self) -> std::io::Result<()> {
self.stream.flush()
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
let stream = &mut self.get_mut().stream;
Pin::new(stream).poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
let stream = &mut self.get_mut().stream;
Pin::new(stream).poll_shutdown(cx)
}
}