Skip to content

Commit

Permalink
allow programs to bring their own transport to the connection
Browse files Browse the repository at this point in the history
  • Loading branch information
icewind1991 committed Nov 2, 2024
1 parent 1b5e73c commit 4f16b82
Show file tree
Hide file tree
Showing 10 changed files with 130 additions and 25 deletions.
57 changes: 57 additions & 0 deletions examples/custom_transport.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use futures_util::{Sink, SinkExt, Stream, StreamExt, TryStreamExt};
use rustls::{ClientConfig, KeyLogFile, RootCertStore};
use std::error::Error;
use std::future::ready;
use std::sync::Arc;
use steam_vent::connection::UnAuthenticatedConnection;
use steam_vent::message::flatten_multi;
use steam_vent::{NetworkError, RawNetMessage, ServerList};
use tokio_tungstenite::tungstenite::Message as WsMessage;
use tokio_tungstenite::{connect_async_tls_with_config, Connector};

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
tracing_subscriber::fmt::init();

let server_list = ServerList::discover().await?;
let (sender, receiver) = connect(&server_list.pick_ws()).await?;
let connection = UnAuthenticatedConnection::from_sender_receiver(sender, receiver).await?;
let _connection = connection.anonymous().await?;

Ok(())
}

// this is just a copy of the standard websocket transport implementation, functioning as an example
// how to implement a websocket transport
pub async fn connect(
addr: &str,
) -> Result<
(
impl Sink<RawNetMessage, Error = NetworkError>,
impl Stream<Item = Result<RawNetMessage, NetworkError>>,
),
NetworkError,
> {
rustls::crypto::aws_lc_rs::default_provider()
.install_default()
.ok(); // can only be once called
let mut root_store = RootCertStore::empty();
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let mut tls_config = ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
tls_config.key_log = Arc::new(KeyLogFile::new());
let tls_config = Connector::Rustls(Arc::new(tls_config));
let (stream, _) = connect_async_tls_with_config(addr, None, false, Some(tls_config)).await?;
let (raw_write, raw_read) = stream.split();

Ok((
raw_write.with(|msg: RawNetMessage| ready(Ok(WsMessage::binary(msg.into_bytes())))),
flatten_multi(
raw_read
.map_err(NetworkError::from)
.map_ok(|raw| raw.into_data())
.map(|res| res.and_then(RawNetMessage::read)),
),
))
}
5 changes: 5 additions & 0 deletions protobuf/common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub trait RpcMessageWithKind: RpcMessage {
const KIND: Self::KindEnum;
}

/// A generic wrapper for "kind" constants used by network messages
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct MsgKind(pub i32);

Expand All @@ -59,6 +60,10 @@ impl From<MsgKind> for i32 {

pub const PROTO_MASK: u32 = 0x80000000;

/// An enum containing "kind" constants used by the network messages
///
/// Though it is possible to use the generic [`MsgKind`] struct. Applications shipping their own protobufs
/// are encouraged to create their own enums containing the constants in use for ease of use.
pub trait MsgKindEnum: Enum + Debug {
fn enum_value(&self) -> i32 {
<Self as Enum>::value(self)
Expand Down
8 changes: 4 additions & 4 deletions src/connection/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::net::{JobId, RawNetMessage};
use dashmap::DashMap;
use futures_util::Stream;
use std::collections::VecDeque;
use std::pin::pin;
use std::sync::{Arc, Mutex};
use steam_vent_proto::enums_clientserver::EMsg;
use steam_vent_proto::MsgKind;
Expand Down Expand Up @@ -59,10 +60,8 @@ pub struct MessageFilter {
}

impl MessageFilter {
pub fn new<
Input: Stream<Item = crate::connection::Result<RawNetMessage>> + Send + Unpin + 'static,
>(
mut source: Input,
pub fn new<Input: Stream<Item = crate::connection::Result<RawNetMessage>> + Send + 'static>(
source: Input,
) -> Self {
let filter = MessageFilter {
job_id_filters: Default::default(),
Expand All @@ -75,6 +74,7 @@ impl MessageFilter {

let filter_send = filter.clone();
spawn(async move {
let mut source = pin!(source);
while let Some(res) = source.next().await {
match res {
Ok(message) => {
Expand Down
19 changes: 15 additions & 4 deletions src/connection/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use crate::message::EncodableMessage;
use crate::net::{NetMessageHeader, RawNetMessage};
use crate::session::{hello, Session};
use crate::transport::websocket::connect;
use crate::{ConnectionError, ServerList};
use crate::{ConnectionError, NetworkError, ServerList};
use futures_util::{Sink, Stream};
use std::fmt::{Debug, Formatter};
use std::sync::Arc;
use std::time::Duration;
Expand Down Expand Up @@ -34,14 +35,24 @@ impl Debug for RawConnection {

impl RawConnection {
pub async fn connect(server_list: &ServerList) -> Result<Self, ConnectionError> {
let (read, write) = connect(&server_list.pick_ws()).await?;
let filter = MessageFilter::new(read);
let (sender, receiver) = connect(&server_list.pick_ws()).await?;
Self::from_sender_receiver(sender, receiver).await
}

pub async fn from_sender_receiver<
Sender: Sink<RawNetMessage, Error = NetworkError> + Send + 'static,
Receiver: Stream<Item = Result<RawNetMessage>> + Send + 'static,
>(
sender: Sender,
receiver: Receiver,
) -> Result<Self, ConnectionError> {
let filter = MessageFilter::new(receiver);
let heartbeat_cancellation_token = CancellationToken::new();
let mut connection = RawConnection {
session: Session::default(),
filter,
sender: MessageSender {
write: Arc::new(Mutex::new(write)),
write: Arc::new(Mutex::new(Box::pin(sender))),
},
timeout: Duration::from_secs(10),
heartbeat_cancellation_token: heartbeat_cancellation_token.clone(),
Expand Down
22 changes: 21 additions & 1 deletion src/connection/unauthenticated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use crate::service_method::ServiceMethodRequest;
use crate::session::{anonymous, login};
use crate::{Connection, ConnectionError, NetMessage, NetworkError, ServerList};
use futures_util::future::{select, Either};
use futures_util::FutureExt;
use futures_util::Stream;
use futures_util::{FutureExt, Sink};
use std::future::Future;
use std::pin::pin;
use steam_vent_proto::enums_clientserver::EMsg;
Expand All @@ -22,12 +22,30 @@ use tracing::{debug, error};
pub struct UnAuthenticatedConnection(RawConnection);

impl UnAuthenticatedConnection {
/// Create a connection from a sender, receiver pair.
///
/// This allows customizing the transport used by the connection. For example to customize the
/// TLS configuration, use an existing websocket client or use a proxy.
pub async fn from_sender_receiver<
Sender: Sink<RawNetMessage, Error = NetworkError> + Send + 'static,
Receiver: Stream<Item = Result<RawNetMessage>> + Send + 'static,
>(
sender: Sender,
receiver: Receiver,
) -> Result<Self, ConnectionError> {
Ok(UnAuthenticatedConnection(
RawConnection::from_sender_receiver(sender, receiver).await?,
))
}

/// Connect to a server from the server list using the default websocket transport
pub async fn connect(server_list: &ServerList) -> Result<Self, ConnectionError> {
Ok(UnAuthenticatedConnection(
RawConnection::connect(server_list).await?,
))
}

/// Start an anonymous client session with this connection
pub async fn anonymous(self) -> Result<Connection, ConnectionError> {
let mut raw = self.0;
raw.session = anonymous(&raw, AccountType::AnonUser).await?;
Expand All @@ -37,6 +55,7 @@ impl UnAuthenticatedConnection {
Ok(connection)
}

/// Start an anonymous server session with this connection
pub async fn anonymous_server(self) -> Result<Connection, ConnectionError> {
let mut raw = self.0;
raw.session = anonymous(&raw, AccountType::AnonGameServer).await?;
Expand All @@ -46,6 +65,7 @@ impl UnAuthenticatedConnection {
Ok(connection)
}

/// Start a client session with this connection
pub async fn login<H: AuthConfirmationHandler, G: GuardDataStore>(
self,
account: &str,
Expand Down
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ pub mod auth;
pub mod connection;
mod eresult;
mod game_coordinator;
mod message;
pub mod message;
mod net;
mod serverlist;
mod service_method;
Expand All @@ -15,6 +15,6 @@ pub use connection::{Connection, ConnectionTrait, ReadonlyConnection};
pub use eresult::EResult;
pub use game_coordinator::GameCoordinator;
pub use message::NetMessage;
pub use net::NetworkError;
pub use net::{NetworkError, RawNetMessage};
pub use serverlist::{DiscoverOptions, ServerDiscoveryError, ServerList};
pub use session::{ConnectionError, LoginError};
21 changes: 15 additions & 6 deletions src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use thiserror::Error;
use tokio_stream::Stream;
use tracing::{debug, trace};

/// Malformed message body
#[derive(Error, Debug)]
#[error("Malformed message body for {0:?}: {1}")]
pub struct MalformedBody(MsgKind, MessageBodyError);
Expand All @@ -32,6 +33,7 @@ impl MalformedBody {
}
}

/// Error while parsing the message body
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum MessageBodyError {
Expand All @@ -55,6 +57,11 @@ impl From<String> for MessageBodyError {
}
}

/// A message which can be encoded and/or decoded
///
/// Applications can implement this trait on a struct to allow sending it using
/// [`raw_send_with_kind`](crate::ConnectionTrait::raw_send_with_kind). To use the higher level messages a struct also needs to implement
/// [`NetMessage`]
pub trait EncodableMessage: Sized + Debug + Send {
fn read_body(_data: BytesMut, _header: &NetMessageHeader) -> Result<Self, MalformedBody> {
panic!("Reading not implemented for {}", type_name::<Self>())
Expand All @@ -71,14 +78,15 @@ pub trait EncodableMessage: Sized + Debug + Send {
fn process_header(&self, _header: &mut NetMessageHeader) {}
}

/// A message with associated kind
pub trait NetMessage: EncodableMessage {
type KindEnum: MsgKindEnum;
const KIND: Self::KindEnum;
const IS_PROTOBUF: bool = false;
}

#[derive(Debug, BinRead)]
pub struct ChannelEncryptRequest {
pub(crate) struct ChannelEncryptRequest {
pub protocol: u32,
#[allow(dead_code)]
pub universe: u32,
Expand All @@ -99,7 +107,7 @@ impl NetMessage for ChannelEncryptRequest {
}

#[derive(Debug, BinRead)]
pub struct ChannelEncryptResult {
pub(crate) struct ChannelEncryptResult {
pub result: u32,
}

Expand All @@ -117,7 +125,7 @@ impl NetMessage for ChannelEncryptResult {
}

#[derive(Debug)]
pub struct ClientEncryptResponse {
pub(crate) struct ClientEncryptResponse {
pub protocol: u32,
pub encrypted_key: Vec<u8>,
}
Expand Down Expand Up @@ -164,6 +172,7 @@ impl Read for MaybeZipReader {
}
}

/// Flatten any "multi" messages in a stream of raw messages
pub fn flatten_multi<S: Stream<Item = Result<RawNetMessage, NetworkError>>>(
source: S,
) -> impl Stream<Item = Result<RawNetMessage, NetworkError>> {
Expand Down Expand Up @@ -226,7 +235,7 @@ impl<R: Read> Iterator for MultiBodyIter<R> {
}

#[derive(Debug)]
pub struct ServiceMethodMessage<Request: Debug>(pub Request);
pub(crate) struct ServiceMethodMessage<Request: Debug>(pub Request);

impl<Request: ServiceMethodRequest + Debug> EncodableMessage for ServiceMethodMessage<Request> {
fn read_body(data: BytesMut, _header: &NetMessageHeader) -> Result<Self, MalformedBody> {
Expand Down Expand Up @@ -259,7 +268,7 @@ impl<Request: ServiceMethodRequest + Debug> NetMessage for ServiceMethodMessage<
}

#[derive(Debug)]
pub struct ServiceMethodResponseMessage {
pub(crate) struct ServiceMethodResponseMessage {
job_name: String,
body: BytesMut,
}
Expand Down Expand Up @@ -301,7 +310,7 @@ impl NetMessage for ServiceMethodResponseMessage {
}

#[derive(Debug, Clone)]
pub struct ServiceMethodNotification {
pub(crate) struct ServiceMethodNotification {
pub(crate) job_name: String,
body: BytesMut,
}
Expand Down
7 changes: 7 additions & 0 deletions src/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,13 @@ impl RawNetMessage {
header_buffer,
})
}

/// Return a buffer containing the raw message bytes
pub fn into_bytes(self) -> BytesMut {
let mut body = self.header_buffer;
body.unsplit(self.data);
body
}
}

impl RawNetMessage {
Expand Down
4 changes: 2 additions & 2 deletions src/transport/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ pub async fn encode_message<T: NetMessage, S: Sink<BytesMut, Error = NetworkErro
pub async fn connect<A: ToSocketAddrs + Debug>(
addr: A,
) -> Result<(
impl Stream<Item = Result<RawNetMessage>>,
impl Sink<RawNetMessage, Error = NetworkError>,
impl Stream<Item = Result<RawNetMessage>>,
)> {
let stream = TcpStream::connect(addr).await?;
debug!("connected to server");
Expand Down Expand Up @@ -190,6 +190,7 @@ pub async fn connect<A: ToSocketAddrs + Debug>(
let key = key.plain;

Ok((
FramedWrite::new(raw_writer.into_inner(), RawMessageEncoder { key }),
flatten_multi(
raw_reader
.and_then(move |encrypted| {
Expand All @@ -201,6 +202,5 @@ pub async fn connect<A: ToSocketAddrs + Debug>(
})
.and_then(|raw| ready(RawNetMessage::read(raw))),
),
FramedWrite::new(raw_writer.into_inner(), RawMessageEncoder { key }),
))
}
8 changes: 2 additions & 6 deletions src/transport/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ type Result<T, E = NetworkError> = std::result::Result<T, E>;
pub async fn connect(
addr: &str,
) -> Result<(
impl Stream<Item = Result<RawNetMessage>>,
impl Sink<RawNetMessage, Error = NetworkError>,
impl Stream<Item = Result<RawNetMessage>>,
)> {
rustls::crypto::aws_lc_rs::default_provider()
.install_default()
Expand All @@ -33,16 +33,12 @@ pub async fn connect(
let (raw_write, raw_read) = stream.split();

Ok((
raw_write.with(|msg: RawNetMessage| ready(Ok(WsMessage::binary(msg.into_bytes())))),
flatten_multi(
raw_read
.map_err(NetworkError::from)
.map_ok(|raw| raw.into_data())
.map(|res| res.and_then(RawNetMessage::read)),
),
raw_write.with(|msg: RawNetMessage| {
let mut body = msg.header_buffer;
body.unsplit(msg.data);
ready(Ok(WsMessage::binary(body)))
}),
))
}

0 comments on commit 4f16b82

Please sign in to comment.