Skip to content

Commit 6a38f16

Browse files
committed
init
Signed-off-by: chandr-andr (Kiselev Aleksandr) <chandr@chandr.net>
1 parent 12f3c63 commit 6a38f16

File tree

11 files changed

+439
-60
lines changed

11 files changed

+439
-60
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,5 @@ pg_interval = { git = "https://github.com/chandr-andr/rust-postgres-interval.git
5959
pgvector = { git = "https://github.com/chandr-andr/pgvector-rust.git", branch = "psqlpy", features = [
6060
"postgres",
6161
] }
62+
futures-channel = "0.3.31"
63+
futures = "0.3.31"

python/psqlpy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
Cursor,
77
IsolationLevel,
88
KeepaliveConfig,
9+
Listener,
910
LoadBalanceHosts,
1011
QueryResult,
1112
ReadVariant,
@@ -25,6 +26,7 @@
2526
"Cursor",
2627
"IsolationLevel",
2728
"KeepaliveConfig",
29+
"Listener",
2830
"LoadBalanceHosts",
2931
"QueryResult",
3032
"ReadVariant",

python/psqlpy/_internal/__init__.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1748,3 +1748,6 @@ class ConnectionPoolBuilder:
17481748
### Returns:
17491749
`ConnectionPoolBuilder`
17501750
"""
1751+
1752+
class Listener:
1753+
"""Result."""

src/driver/connection_pool.rs

Lines changed: 119 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
use crate::runtime::tokio_runtime;
22
use deadpool_postgres::{Manager, ManagerConfig, Object, Pool, RecyclingMethod};
3-
use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
4-
use postgres_openssl::MakeTlsConnector;
5-
use pyo3::{pyclass, pyfunction, pymethods, Py, PyAny};
6-
use std::{sync::Arc, vec};
7-
use tokio_postgres::NoTls;
3+
use futures::{stream, FutureExt, StreamExt, TryStreamExt};
4+
use pyo3::{pyclass, pyfunction, pymethods, Py, PyAny, Python};
5+
use std::{sync::Arc, time::Duration, vec};
6+
use tokio::time::sleep;
7+
use tokio_postgres::{Config, NoTls};
88

99
use crate::{
1010
exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult},
@@ -13,9 +13,10 @@ use crate::{
1313
};
1414

1515
use super::{
16-
common_options::{self, ConnRecyclingMethod, LoadBalanceHosts, SslMode, TargetSessionAttrs},
16+
common_options::{ConnRecyclingMethod, LoadBalanceHosts, SslMode, TargetSessionAttrs},
1717
connection::Connection,
18-
utils::build_connection_config,
18+
listener::Listener,
19+
utils::{build_connection_config, build_manager, build_tls, ConfiguredTLS},
1920
};
2021

2122
/// Make new connection pool.
@@ -77,7 +78,6 @@ pub fn connect(
7778
load_balance_hosts: Option<LoadBalanceHosts>,
7879
ssl_mode: Option<SslMode>,
7980
ca_file: Option<String>,
80-
8181
max_db_pool_size: Option<usize>,
8282
conn_recycling_method: Option<ConnRecyclingMethod>,
8383
) -> RustPSQLDriverPyResult<ConnectionPool> {
@@ -126,33 +126,25 @@ pub fn connect(
126126
};
127127
}
128128

129-
let mgr: Manager;
130-
if let Some(ca_file) = ca_file {
131-
let mut builder = SslConnector::builder(SslMethod::tls())?;
132-
builder.set_ca_file(ca_file)?;
133-
let tls_connector = MakeTlsConnector::new(builder.build());
134-
mgr = Manager::from_config(pg_config, tls_connector, mgr_config);
135-
} else if let Some(ssl_mode) = ssl_mode {
136-
if ssl_mode == common_options::SslMode::Require {
137-
let mut builder = SslConnector::builder(SslMethod::tls())?;
138-
builder.set_verify(SslVerifyMode::NONE);
139-
let tls_connector = MakeTlsConnector::new(builder.build());
140-
mgr = Manager::from_config(pg_config, tls_connector, mgr_config);
141-
} else {
142-
mgr = Manager::from_config(pg_config, NoTls, mgr_config);
143-
}
144-
} else {
145-
mgr = Manager::from_config(pg_config, NoTls, mgr_config);
146-
}
129+
let mgr: Manager = build_manager(
130+
mgr_config,
131+
pg_config.clone(),
132+
build_tls(&ca_file, ssl_mode)?,
133+
);
147134

148135
let mut db_pool_builder = Pool::builder(mgr);
149136
if let Some(max_db_pool_size) = max_db_pool_size {
150137
db_pool_builder = db_pool_builder.max_size(max_db_pool_size);
151138
}
152139

153-
let db_pool = db_pool_builder.build()?;
140+
let pool = db_pool_builder.build()?;
154141

155-
Ok(ConnectionPool(db_pool))
142+
Ok(ConnectionPool {
143+
pool,
144+
pg_config,
145+
ca_file,
146+
ssl_mode,
147+
})
156148
}
157149

158150
#[pyclass]
@@ -212,8 +204,31 @@ impl ConnectionPoolStatus {
212204
}
213205
}
214206

207+
// #[pyclass(subclass)]
208+
// pub struct ConnectionPool(pub Pool);
215209
#[pyclass(subclass)]
216-
pub struct ConnectionPool(pub Pool);
210+
pub struct ConnectionPool {
211+
pool: Pool,
212+
pg_config: Config,
213+
ca_file: Option<String>,
214+
ssl_mode: Option<SslMode>,
215+
}
216+
217+
impl ConnectionPool {
218+
pub fn build(
219+
pool: Pool,
220+
pg_config: Config,
221+
ca_file: Option<String>,
222+
ssl_mode: Option<SslMode>,
223+
) -> Self {
224+
ConnectionPool {
225+
pool,
226+
pg_config,
227+
ca_file,
228+
ssl_mode,
229+
}
230+
}
231+
}
217232

218233
#[pymethods]
219234
impl ConnectionPool {
@@ -333,7 +348,7 @@ impl ConnectionPool {
333348

334349
#[must_use]
335350
pub fn status(&self) -> ConnectionPoolStatus {
336-
let inner_status = self.0.status();
351+
let inner_status = self.pool.status();
337352

338353
ConnectionPoolStatus::new(
339354
inner_status.max_size,
@@ -344,7 +359,7 @@ impl ConnectionPool {
344359
}
345360

346361
pub fn resize(&self, new_max_size: usize) {
347-
self.0.resize(new_max_size);
362+
self.pool.resize(new_max_size);
348363
}
349364

350365
/// Execute querystring with parameters.
@@ -361,7 +376,7 @@ impl ConnectionPool {
361376
parameters: Option<pyo3::Py<PyAny>>,
362377
prepared: Option<bool>,
363378
) -> RustPSQLDriverPyResult<PSQLDriverPyQueryResult> {
364-
let db_pool = pyo3::Python::with_gil(|gil| self_.borrow(gil).0.clone());
379+
let db_pool = pyo3::Python::with_gil(|gil| self_.borrow(gil).pool.clone());
365380

366381
let db_pool_manager = tokio_runtime()
367382
.spawn(async move { Ok::<Object, RustPSQLDriverError>(db_pool.get().await?) })
@@ -430,7 +445,7 @@ impl ConnectionPool {
430445
parameters: Option<pyo3::Py<PyAny>>,
431446
prepared: Option<bool>,
432447
) -> RustPSQLDriverPyResult<PSQLDriverPyQueryResult> {
433-
let db_pool = pyo3::Python::with_gil(|gil| self_.borrow(gil).0.clone());
448+
let db_pool = pyo3::Python::with_gil(|gil| self_.borrow(gil).pool.clone());
434449

435450
let db_pool_manager = tokio_runtime()
436451
.spawn(async move { Ok::<Object, RustPSQLDriverError>(db_pool.get().await?) })
@@ -484,15 +499,83 @@ impl ConnectionPool {
484499

485500
#[must_use]
486501
pub fn acquire(&self) -> Connection {
487-
Connection::new(None, Some(self.0.clone()))
502+
Connection::new(None, Some(self.pool.clone()))
503+
}
504+
505+
pub async fn add_listener(
506+
self_: pyo3::Py<Self>,
507+
callback: Py<PyAny>,
508+
) -> RustPSQLDriverPyResult<Listener> {
509+
let (pg_config, ca_file, ssl_mode) = pyo3::Python::with_gil(|gil| {
510+
let b_gil = self_.borrow(gil);
511+
(
512+
b_gil.pg_config.clone(),
513+
b_gil.ca_file.clone(),
514+
b_gil.ssl_mode,
515+
)
516+
});
517+
518+
// let tls_ = build_tls(&ca_file, Some(SslMode::Disable)).unwrap();
519+
520+
// match tls_ {
521+
// ConfiguredTLS::NoTls => {
522+
// let a = pg_config.connect(NoTls).await.unwrap();
523+
// },
524+
// ConfiguredTLS::TlsConnector(connector) => {
525+
// let a = pg_config.connect(connector).await.unwrap();
526+
// }
527+
// }
528+
529+
// let (client, mut connection) = tokio_runtime()
530+
// .spawn(async move { pg_config.connect(NoTls).await.unwrap() })
531+
// .await?;
532+
533+
// // Make transmitter and receiver.
534+
// let (tx, mut rx) = futures_channel::mpsc::unbounded();
535+
// let stream =
536+
// stream::poll_fn(move |cx| connection.poll_message(cx)).map_err(|e| panic!("{}", e));
537+
// let connection = stream.forward(tx).map(|r| r.unwrap());
538+
// tokio_runtime().spawn(connection);
539+
540+
// // Wait for notifications in separate thread.
541+
// tokio_runtime().spawn(async move {
542+
// client
543+
// .batch_execute(
544+
// "LISTEN test_notifications;
545+
// LISTEN test_notifications2;",
546+
// )
547+
// .await
548+
// .unwrap();
549+
550+
// loop {
551+
// let next_element = rx.next().await;
552+
// client.batch_execute("LISTEN test_notifications3;").await.unwrap();
553+
// match next_element {
554+
// Some(n) => {
555+
// match n {
556+
// tokio_postgres::AsyncMessage::Notification(n) => {
557+
// Python::with_gil(|gil| {
558+
// callback.call0(gil);
559+
// });
560+
// println!("Notification {:?}", n);
561+
// },
562+
// _ => {println!("in_in {:?}", n)}
563+
// }
564+
// },
565+
// _ => {println!("in {:?}", next_element)}
566+
// }
567+
// }
568+
// });
569+
570+
Ok(Listener::new(pg_config, ca_file, ssl_mode))
488571
}
489572

490573
/// Return new single connection.
491574
///
492575
/// # Errors
493576
/// May return Err Result if cannot get new connection from the pool.
494577
pub async fn connection(self_: pyo3::Py<Self>) -> RustPSQLDriverPyResult<Connection> {
495-
let db_pool = pyo3::Python::with_gil(|gil| self_.borrow(gil).0.clone());
578+
let db_pool = pyo3::Python::with_gil(|gil| self_.borrow(gil).pool.clone());
496579
let db_connection = tokio_runtime()
497580
.spawn(async move {
498581
Ok::<deadpool_postgres::Object, RustPSQLDriverError>(db_pool.get().await?)
@@ -507,7 +590,7 @@ impl ConnectionPool {
507590
/// # Errors
508591
/// May return Err Result if cannot get new connection from the pool.
509592
pub fn close(&self) {
510-
let db_pool = self.0.clone();
593+
let db_pool = self.pool.clone();
511594

512595
db_pool.close();
513596
}

src/driver/connection_pool_builder.rs

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
use std::{net::IpAddr, time::Duration};
22

33
use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod};
4-
use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
5-
use postgres_openssl::MakeTlsConnector;
64
use pyo3::{pyclass, pymethods, Py, Python};
7-
use tokio_postgres::NoTls;
85

96
use crate::exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult};
107

11-
use super::{common_options, connection_pool::ConnectionPool};
8+
use super::{
9+
common_options,
10+
connection_pool::ConnectionPool,
11+
utils::{build_manager, build_tls},
12+
};
1213

1314
#[pyclass]
1415
pub struct ConnectionPoolBuilder {
@@ -49,24 +50,11 @@ impl ConnectionPoolBuilder {
4950
};
5051
};
5152

52-
let mgr: Manager;
53-
if let Some(ca_file) = &self.ca_file {
54-
let mut builder = SslConnector::builder(SslMethod::tls())?;
55-
builder.set_ca_file(ca_file)?;
56-
let tls_connector = MakeTlsConnector::new(builder.build());
57-
mgr = Manager::from_config(self.config.clone(), tls_connector, mgr_config);
58-
} else if let Some(ssl_mode) = self.ssl_mode {
59-
if ssl_mode == common_options::SslMode::Require {
60-
let mut builder = SslConnector::builder(SslMethod::tls())?;
61-
builder.set_verify(SslVerifyMode::NONE);
62-
let tls_connector = MakeTlsConnector::new(builder.build());
63-
mgr = Manager::from_config(self.config.clone(), tls_connector, mgr_config);
64-
} else {
65-
mgr = Manager::from_config(self.config.clone(), NoTls, mgr_config);
66-
}
67-
} else {
68-
mgr = Manager::from_config(self.config.clone(), NoTls, mgr_config);
69-
}
53+
let mgr: Manager = build_manager(
54+
mgr_config,
55+
self.config.clone(),
56+
build_tls(&self.ca_file, self.ssl_mode)?,
57+
);
7058

7159
let mut db_pool_builder = Pool::builder(mgr);
7260
if let Some(max_db_pool_size) = self.max_db_pool_size {
@@ -75,7 +63,12 @@ impl ConnectionPoolBuilder {
7563

7664
let db_pool = db_pool_builder.build()?;
7765

78-
Ok(ConnectionPool(db_pool))
66+
Ok(ConnectionPool::build(
67+
db_pool,
68+
self.config.clone(),
69+
self.ca_file.clone(),
70+
self.ssl_mode,
71+
))
7972
}
8073

8174
/// Set ca_file for ssl_mode in PostgreSQL.

0 commit comments

Comments
 (0)