Skip to content

Commit d9752ce

Browse files
authored
Merge pull request #32 from yuchen0cc/main
avoid client init race in threading
2 parents 06c02aa + 85f8b76 commit d9752ce

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

oss-torch-connector/osstorchconnector/_oss_client.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
from typing import Iterator, Iterable, Any
33
import logging
4+
import threading
45

56
log = logging.getLogger(__name__)
67

@@ -30,15 +31,18 @@ def __init__(self, endpoint: str, cred_path: str = "", config_path: str = "", uu
3031
self._total = total
3132
self._cred_provider = cred_provider
3233

34+
_lock = threading.Lock()
3335
@property
3436
def _client(self) -> DataSet:
35-
if self._client_pid is None or self._client_pid != os.getpid():
36-
# does OSS client survive forking ? NO
37-
if self._client_pid != os.getpid() and self._real_client is not None:
38-
log.info("OssClient delete dataset")
39-
# del self._real_client
40-
self._client_pid = os.getpid()
41-
self._real_client = self._client_builder()
37+
with OssClient._lock:
38+
if self._client_pid is None or self._client_pid != os.getpid() :
39+
# does OSS client survive forking ? NO
40+
if self._client_pid != os.getpid() and self._real_client is not None:
41+
log.info("OssClient delete dataset")
42+
# del self._real_client
43+
self._client_pid = os.getpid()
44+
self._real_client = self._client_builder()
45+
4246
return self._real_client
4347

4448
def _client_builder(self) -> DataSet:

oss-torch-connector/osstorchconnector/oss_filesystem.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,8 @@ def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
173173
return False
174174
return True
175175

176-
def writer(self, path : str) -> OssStorageWriter:
177-
return OssStorageWriter(self, path)
176+
def writer(self, path : str, **kwargs) -> OssStorageWriter:
177+
return OssStorageWriter(self, path, **kwargs)
178178

179179
def reader(self, path : str) -> OssStorageReader:
180180
return OssStorageReader(self, path)

0 commit comments

Comments
 (0)