Skip to content

Commit 30a0278

Browse files
authored
Add custom drop for SocketHolder (#746)
1 parent 6e37e1e commit 30a0278

File tree

2 files changed

+48
-21
lines changed

2 files changed

+48
-21
lines changed

src/net.rs

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -164,51 +164,67 @@ impl UnixListenerSpec {
164164
#[cfg(not(any(windows, target_os = "linux", target_os = "freebsd")))]
165165
#[pyclass(frozen, module = "granian._granian")]
166166
pub struct SocketHolder {
167-
socket: Socket,
167+
socket: Option<Socket>,
168168
uds: bool,
169169
}
170170

171171
#[cfg(not(any(windows, target_os = "linux", target_os = "freebsd")))]
172172
impl SocketHolder {
173173
fn from_spec(spec: &ListenerSpec) -> Result<Self> {
174174
let socket = spec.as_socket()?;
175-
Ok(Self { socket, uds: false })
175+
Ok(Self {
176+
socket: Some(socket),
177+
uds: false,
178+
})
176179
}
177180

178181
fn from_unix_spec(spec: &UnixListenerSpec) -> Result<Self> {
179182
let socket = spec.as_socket()?;
180-
Ok(Self { socket, uds: true })
183+
Ok(Self {
184+
socket: Some(socket),
185+
uds: true,
186+
})
181187
}
182188

183189
#[allow(clippy::unnecessary_wraps)]
184190
pub fn as_tcp_listener(&self) -> Result<TcpListener> {
185-
let listener = unsafe { TcpListener::from_raw_fd(self.socket.as_raw_fd()) };
191+
let listener = unsafe { TcpListener::from_raw_fd(self.socket.as_ref().unwrap().as_raw_fd()) };
186192
Ok(listener)
187193
}
188194

189195
#[allow(clippy::unnecessary_wraps)]
190196
pub fn as_unix_listener(&self) -> Result<UnixListener> {
191-
let listener = unsafe { UnixListener::from_raw_fd(self.socket.as_raw_fd()) };
197+
let listener = unsafe { UnixListener::from_raw_fd(self.socket.as_ref().unwrap().as_raw_fd()) };
192198
Ok(listener)
193199
}
194200
}
195201

202+
#[cfg(not(any(windows, target_os = "linux", target_os = "freebsd")))]
203+
impl Drop for SocketHolder {
204+
fn drop(&mut self) {
205+
std::mem::forget(self.socket.take());
206+
}
207+
}
208+
196209
#[cfg(not(any(windows, target_os = "linux", target_os = "freebsd")))]
197210
#[pymethods]
198211
impl SocketHolder {
199212
#[new]
200213
pub fn new(fd: i32, uds: bool) -> Self {
201214
let socket = unsafe { Socket::from_raw_fd(fd) };
202-
Self { socket, uds }
215+
Self {
216+
socket: Some(socket),
217+
uds,
218+
}
203219
}
204220

205221
pub fn __getstate__(&self, py: Python) -> Py<PyAny> {
206-
let fd = self.socket.as_raw_fd();
222+
let fd = self.socket.as_ref().unwrap().as_raw_fd();
207223
(fd, self.uds).into_py_any(py).unwrap()
208224
}
209225

210226
pub fn get_fd(&self, py: Python) -> Py<PyAny> {
211-
self.socket.as_raw_fd().into_py_any(py).unwrap()
227+
self.socket.as_ref().unwrap().as_raw_fd().into_py_any(py).unwrap()
212228
}
213229

214230
pub fn is_uds(&self) -> bool {
@@ -219,7 +235,7 @@ impl SocketHolder {
219235
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
220236
#[pyclass(frozen, module = "granian._granian")]
221237
pub struct SocketHolder {
222-
socket: Socket,
238+
socket: Option<Socket>,
223239
uds: bool,
224240
backlog: i32,
225241
}
@@ -229,7 +245,7 @@ impl SocketHolder {
229245
fn from_spec(spec: &ListenerSpec) -> Result<Self> {
230246
let socket = spec.as_socket()?;
231247
Ok(Self {
232-
socket,
248+
socket: Some(socket),
233249
uds: false,
234250
backlog: spec.backlog,
235251
})
@@ -238,41 +254,54 @@ impl SocketHolder {
238254
fn from_unix_spec(spec: &UnixListenerSpec) -> Result<Self> {
239255
let socket = spec.as_socket()?;
240256
Ok(Self {
241-
socket,
257+
socket: Some(socket),
242258
uds: true,
243259
backlog: spec.backlog,
244260
})
245261
}
246262

247263
pub fn as_tcp_listener(&self) -> Result<TcpListener> {
248-
self.socket.listen(self.backlog)?;
249-
let listener = unsafe { TcpListener::from_raw_fd(self.socket.as_raw_fd()) };
264+
let socket = self.socket.as_ref().unwrap();
265+
socket.listen(self.backlog)?;
266+
let listener = unsafe { TcpListener::from_raw_fd(socket.as_raw_fd()) };
250267
Ok(listener)
251268
}
252269

253270
pub fn as_unix_listener(&self) -> Result<UnixListener> {
254-
self.socket.listen(self.backlog)?;
255-
let listener = unsafe { UnixListener::from_raw_fd(self.socket.as_raw_fd()) };
271+
let socket = self.socket.as_ref().unwrap();
272+
socket.listen(self.backlog)?;
273+
let listener = unsafe { UnixListener::from_raw_fd(socket.as_raw_fd()) };
256274
Ok(listener)
257275
}
258276
}
259277

278+
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
279+
impl Drop for SocketHolder {
280+
fn drop(&mut self) {
281+
std::mem::forget(self.socket.take());
282+
}
283+
}
284+
260285
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
261286
#[pymethods]
262287
impl SocketHolder {
263288
#[new]
264289
pub fn new(fd: i32, uds: bool, backlog: i32) -> Self {
265290
let socket = unsafe { Socket::from_raw_fd(fd) };
266-
Self { socket, uds, backlog }
291+
Self {
292+
socket: Some(socket),
293+
uds,
294+
backlog,
295+
}
267296
}
268297

269298
pub fn __getstate__(&self, py: Python) -> Py<PyAny> {
270-
let fd = self.socket.as_raw_fd();
299+
let fd = self.socket.as_ref().unwrap().as_raw_fd();
271300
(fd, self.uds, self.backlog).into_py_any(py).unwrap()
272301
}
273302

274303
pub fn get_fd(&self, py: Python) -> Py<PyAny> {
275-
self.socket.as_raw_fd().into_py_any(py).unwrap()
304+
self.socket.as_ref().unwrap().as_raw_fd().into_py_any(py).unwrap()
276305
}
277306

278307
pub fn is_uds(&self) -> bool {

tests/test_embed.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
import os
32

43
import httpx
54
import pytest
@@ -14,7 +13,7 @@ async def app(scope, protocol):
1413

1514
@pytest.fixture(scope='function')
1615
def loop():
17-
return asyncio.get_event_loop()
16+
return asyncio.new_event_loop()
1817

1918

2019
@pytest.fixture(scope='function')
@@ -23,7 +22,6 @@ def embed_server(server_port):
2322

2423

2524
@pytest.mark.skipif(not BUILD_GIL, reason='free-threaded Python')
26-
@pytest.mark.skipif(bool(os.environ.get('GITHUB_WORKFLOW')), reason='CI')
2725
def test_embed_server(loop, server_port, embed_server):
2826
data = {}
2927

0 commit comments

Comments
 (0)