Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PgBindIter for encoding and use it as the implementation encoding &[T] #3651

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
154 changes: 154 additions & 0 deletions sqlx-postgres/src/bind_iter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
use crate::{type_info::PgType, PgArgumentBuffer, PgHasArrayType, PgTypeInfo, Postgres};
use core::cell::Cell;
use sqlx_core::{
database::Database,
encode::{Encode, IsNull},
error::BoxDynError,
types::Type,
};

// not exported but pub because it is used in the extension trait
pub struct PgBindIter<I>(Cell<Option<I>>);

/// Iterator extension trait enabling iterators to encode arrays in Postgres.
///
/// Because of the blanket impl of `PgHasArrayType` for all references
/// we can borrow instead of needing to clone or copy in the iterators
/// and it still works
///
/// Previously, 3 separate arrays would be needed in this example which
/// requires iterating 3 times to collect items into the array and then
/// iterating over them again to encode.
///
/// This now requires only iterating over the array once for each field
/// while using less memory giving both speed and memory usage improvements
/// along with allowing much more flexibility in the underlying collection.
///
/// ```rust,no_run
/// # async fn test_bind_iter() -> Result<(), sqlx::error::BoxDynError> {
/// # use sqlx::types::chrono::{DateTime, Utc};
/// # use sqlx::Connection;
/// # fn people() -> &'static [Person] {
/// # &[]
/// # }
/// # let mut conn = <sqlx::Postgres as sqlx::Database>::Connection::connect("dummyurl").await?;
/// use sqlx::postgres::PgBindIterExt;
///
/// #[derive(sqlx::FromRow)]
/// struct Person {
/// id: i64,
/// name: String,
/// birthdate: DateTime<Utc>,
/// }
///
/// # let people: &[Person] = people();
/// sqlx::query("insert into person(id, name, birthdate) select * from unnest($1, $2, $3)")
/// .bind(people.iter().map(|p| p.id).bind_iter())
/// .bind(people.iter().map(|p| &p.name).bind_iter())
/// .bind(people.iter().map(|p| &p.birthdate).bind_iter())
/// .execute(&mut conn)
/// .await?;
///
/// # Ok(())
/// # }
/// ```
pub trait PgBindIterExt: Iterator + Sized {
fn bind_iter(self) -> PgBindIter<Self>;
}

impl<I: Iterator + Sized> PgBindIterExt for I {
fn bind_iter(self) -> PgBindIter<I> {
PgBindIter(Cell::new(Some(self)))
}
}

impl<I> Type<Postgres> for PgBindIter<I>
where
I: Iterator,
<I as Iterator>::Item: Type<Postgres> + PgHasArrayType,
{
fn type_info() -> <Postgres as Database>::TypeInfo {
<I as Iterator>::Item::array_type_info()
}
fn compatible(ty: &PgTypeInfo) -> bool {
<I as Iterator>::Item::array_compatible(ty)
}
}

impl<'q, I> PgBindIter<I>
where
I: Iterator,
<I as Iterator>::Item: Type<Postgres> + Encode<'q, Postgres>,
{
fn encode_inner(
// need ownership to iterate
mut iter: I,
buf: &mut PgArgumentBuffer,
) -> Result<IsNull, BoxDynError> {
let lower_size_hint = iter.size_hint().0;
let first = iter.next();
let type_info = first
.as_ref()
.and_then(Encode::produces)
.unwrap_or_else(<I as Iterator>::Item::type_info);

buf.extend(&1_i32.to_be_bytes()); // number of dimensions
buf.extend(&0_i32.to_be_bytes()); // flags

match type_info.0 {
PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
PgType::DeclareArrayOf(array) => buf.patch_array_type(array),

ty => {
buf.extend(&ty.oid().0.to_be_bytes());
}
}

let len_start = buf.len();
buf.extend(0_i32.to_be_bytes()); // len (unknown so far)
buf.extend(1_i32.to_be_bytes()); // lower bound

match first {
Some(first) => buf.encode(first)?,
None => return Ok(IsNull::No),
}

let mut count = 1_i32;
const MAX: usize = i32::MAX as usize - 1;

for value in (&mut iter).take(MAX) {
buf.encode(value)?;
count += 1;
}

const OVERFLOW: usize = i32::MAX as usize + 1;
if iter.next().is_some() {
let iter_size = std::cmp::max(lower_size_hint, OVERFLOW);
return Err(format!("encoded iterator is too large for Postgres: {iter_size}").into());
}

// set the length now that we know what it is.
buf[len_start..(len_start + 4)].copy_from_slice(&count.to_be_bytes());

Ok(IsNull::No)
}
}

impl<'q, I> Encode<'q, Postgres> for PgBindIter<I>
where
I: Iterator,
<I as Iterator>::Item: Type<Postgres> + Encode<'q, Postgres>,
{
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
Self::encode_inner(self.0.take().expect("PgBindIter is only used once"), buf)
}
fn encode(self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError>
where
Self: Sized,
{
Self::encode_inner(
self.0.into_inner().expect("PgBindIter is only used once"),
buf,
)
}
}
2 changes: 2 additions & 0 deletions sqlx-postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::executor::Executor;

mod advisory_lock;
mod arguments;
mod bind_iter;
mod column;
mod connection;
mod copy;
Expand Down Expand Up @@ -44,6 +45,7 @@ pub(crate) use sqlx_core::driver_prelude::*;

pub use advisory_lock::{PgAdvisoryLock, PgAdvisoryLockGuard, PgAdvisoryLockKey};
pub use arguments::{PgArgumentBuffer, PgArguments};
pub use bind_iter::PgBindIterExt;
pub use column::PgColumn;
pub use connection::PgConnection;
pub use copy::{PgCopyIn, PgPoolCopyExt};
Expand Down
32 changes: 3 additions & 29 deletions sqlx-postgres/src/types/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use std::borrow::Cow;
use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::type_info::PgType;
use crate::types::Oid;
use crate::types::Type;
use crate::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
Expand Down Expand Up @@ -156,39 +155,14 @@ where
T: Encode<'q, Postgres> + Type<Postgres>,
{
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
let type_info = self
.first()
.and_then(Encode::produces)
.unwrap_or_else(T::type_info);

buf.extend(&1_i32.to_be_bytes()); // number of dimensions
buf.extend(&0_i32.to_be_bytes()); // flags

// element type
match type_info.0 {
PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
PgType::DeclareArrayOf(array) => buf.patch_array_type(array),

ty => {
buf.extend(&ty.oid().0.to_be_bytes());
}
}

let array_len = i32::try_from(self.len()).map_err(|_| {
// do the length check early to avoid doing unnecessary work
i32::try_from(self.len()).map_err(|_| {
format!(
"encoded array length is too large for Postgres: {}",
self.len()
)
})?;

buf.extend(array_len.to_be_bytes()); // len
buf.extend(&1_i32.to_be_bytes()); // lower bound

for element in self.iter() {
buf.encode(element)?;
}

Ok(IsNull::No)
crate::PgBindIterExt::bind_iter(self.iter()).encode(buf)
}
}

Expand Down
58 changes: 58 additions & 0 deletions tests/postgres/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2042,3 +2042,61 @@ async fn test_issue_3052() {
"expected encode error, got {too_large_error:?}",
);
}

#[sqlx_macros::test]
async fn test_bind_iter() -> anyhow::Result<()> {
use sqlx::postgres::PgBindIterExt;
use sqlx::types::chrono::{DateTime, Utc};

let mut conn = new::<Postgres>().await?;

#[derive(sqlx::FromRow, PartialEq, Debug)]
struct Person {
id: i64,
name: String,
birthdate: DateTime<Utc>,
}

let people: Vec<Person> = vec![
Person {
id: 1,
name: "Alice".into(),
birthdate: "1984-01-01T00:00:00Z".parse().unwrap(),
},
Person {
id: 2,
name: "Bob".into(),
birthdate: "2000-01-01T00:00:00Z".parse().unwrap(),
},
];

sqlx::query(
r#"
create temporary table person(
id int8 primary key,
name text not null,
birthdate timestamptz not null
)"#,
)
.execute(&mut conn)
.await?;

let rows_affected =
sqlx::query("insert into person(id, name, birthdate) select * from unnest($1, $2, $3)")
// owned value
.bind(people.iter().map(|p| p.id).bind_iter())
// borrowed value
.bind(people.iter().map(|p| &p.name).bind_iter())
.bind(people.iter().map(|p| &p.birthdate).bind_iter())
.execute(&mut conn)
.await?
.rows_affected();
assert_eq!(rows_affected, 2);

let p_query = sqlx::query_as::<_, Person>("select * from person order by id")
.fetch_all(&mut conn)
.await?;

assert_eq!(people, p_query);
Ok(())
}
Loading