Skip to content

Commit

Permalink
allow using unmatched messages from connection
Browse files Browse the repository at this point in the history
  • Loading branch information
icewind1991 committed Oct 20, 2024
1 parent 8312f7a commit 1443ae9
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 3 deletions.
47 changes: 47 additions & 0 deletions examples/auth_ticket.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use futures_util::StreamExt;
use std::env::args;
use std::error::Error;
use steam_vent::auth::{
AuthConfirmationHandler, ConsoleAuthConfirmationHandler, DeviceConfirmationHandler,
FileGuardDataStore,
};
use steam_vent::{Connection, ConnectionTrait, ServerList};
use steam_vent_proto::steammessages_clientserver::CMsgClientGameConnectTokens;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let mut args = args().skip(1);
let account = args.next().expect("no account");
let password = args.next().expect("no password");

let server_list = ServerList::discover().await?;

let connection = Connection::login(
&server_list,
&account,
&password,
FileGuardDataStore::user_cache(),
ConsoleAuthConfirmationHandler::default().or(DeviceConfirmationHandler),
)
.await?;

let tokens_messages = connection.on::<CMsgClientGameConnectTokens>();

// also process the messages that were received before we registered our filter
let old_token_messages = connection
.take_unprocessed()
.into_iter()
.filter_map(|raw| raw.into_message::<CMsgClientGameConnectTokens>().ok())
.map(Ok);
let mut tokens_messages = futures_util::stream::iter(old_token_messages).chain(tokens_messages);

while let Some(Ok(tokens_message)) = tokens_messages.next().await {
println!("got {} token from message", tokens_message.tokens.len());
for token in tokens_message.tokens.into_iter() {
println!("\t{}", BASE64_STANDARD.encode(token));
}
}
Ok(())
}
49 changes: 46 additions & 3 deletions src/connection/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,59 @@ use crate::message::ServiceMethodNotification;
use crate::net::{JobId, RawNetMessage};
use dashmap::DashMap;
use futures_util::Stream;
use std::sync::Arc;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use steam_vent_proto::enums_clientserver::EMsg;
use steam_vent_proto::MsgKind;
use tokio::spawn;
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio_stream::StreamExt;
use tracing::{debug, error};

#[derive(Clone)]
pub struct RingBuffer<T>(Arc<Mutex<VecDeque<T>>>);

impl<T> RingBuffer<T> {
pub fn new(capacity: usize) -> Self {
Self(Arc::new(Mutex::new(VecDeque::with_capacity(capacity))))
}

pub fn push(&self, item: T) -> Option<T> {
let mut deque = self.0.lock().unwrap();
if deque.len() == deque.capacity() {
let popped = deque.pop_front();
deque.push_back(item);
debug_assert!(deque.len() == deque.capacity());
popped
} else {
deque.push_back(item);
None
}
}

#[expect(dead_code)]
pub fn pop(&self) -> Option<T> {
self.0.lock().unwrap().pop_front()
}
}

impl<T: Clone> RingBuffer<T> {
pub fn take(&self) -> Vec<T> {
let mut dequeu = self.0.lock().unwrap();
let items = dequeu.make_contiguous().to_vec();
dequeu.clear();
items
}
}

#[derive(Clone)]
pub struct MessageFilter {
job_id_filters: Arc<DashMap<JobId, oneshot::Sender<RawNetMessage>>>,
job_id_multi_filters: Arc<DashMap<JobId, mpsc::Sender<RawNetMessage>>>,
notification_filters: Arc<DashMap<&'static str, broadcast::Sender<ServiceMethodNotification>>>,
kind_filters: Arc<DashMap<MsgKind, broadcast::Sender<RawNetMessage>>>,
oneshot_kind_filters: Arc<DashMap<MsgKind, oneshot::Sender<RawNetMessage>>>,
rest: RingBuffer<RawNetMessage>,
}

impl MessageFilter {
Expand All @@ -31,6 +69,7 @@ impl MessageFilter {
kind_filters: Default::default(),
notification_filters: Default::default(),
oneshot_kind_filters: Default::default(),
rest: RingBuffer::new(32),
};

let filter_send = filter.clone();
Expand Down Expand Up @@ -71,8 +110,8 @@ impl MessageFilter {
}
} else if let Some(tx) = filter_send.kind_filters.get(&message.kind) {
tx.send(message).ok();
} else {
debug!(kind = ?message.kind, "Unhandled message");
} else if let Some(popped) = filter_send.rest.push(message) {
debug!(kind = ?popped.kind, "Unhandled message");
}
}
Err(err) => {
Expand Down Expand Up @@ -124,4 +163,8 @@ impl MessageFilter {
self.oneshot_kind_filters.insert(kind.into(), tx);
rx
}

pub fn unprocessed(&self) -> Vec<RawNetMessage> {
self.rest.take()
}
}
8 changes: 8 additions & 0 deletions src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,14 @@ impl Connection {
.into_message::<ServiceMethodResponseMessage>()?;
message.into_response::<Msg>()
}

/// Get all messages that haven't been filtered by any of the filters
///
/// Note that at most 32 unprocessed connections are stored and calling
/// this method clears the buffer
pub fn take_unprocessed(&self) -> Vec<RawNetMessage> {
self.filter.unprocessed()
}
}

pub(crate) trait ConnectionImpl: Sync + Debug {
Expand Down

0 comments on commit 1443ae9

Please sign in to comment.