2525"""
2626
2727import os
28+ import shutil
2829import socket
2930import ssl
31+ import tempfile
3032import time
3133from datetime import datetime , timedelta , timezone
3234from pathlib import Path
3335from select import select
3436from itertools import count
35- from typing import Any , Final
37+ from typing import Any , Final , cast
3638
3739import cffi # noqa # required for cryptography
3840from cryptography import x509
3941from cryptography .hazmat .primitives import hashes , serialization
4042from cryptography .hazmat .primitives .asymmetric import rsa
4143from cryptography .x509 .oid import NameOID
44+ from NVDAState import WritePaths , shouldWriteToDisk
4245from logHandler import log
4346
4447from . import configuration
4548from .protocol import RemoteMessageType
46- from .secureDesktop import getProgramDataTempPath
4749from .serializer import JSONSerializer
4850
4951
@@ -56,50 +58,47 @@ class RemoteCertificateManager:
5658 :ivar fingerprintPath: Path to the fingerprint file
5759 """
5860
59- CERT_FILE = "NvdaRemoteRelay.pem"
60- KEY_FILE = "NvdaRemoteRelay.key"
61- FINGERPRINT_FILE = "NvdaRemoteRelay.fingerprint"
62- CERT_DURATION_DAYS = 365
63- CERT_RENEWAL_THRESHOLD_DAYS = 30
61+ CERT_DIR : Final [Path ] = Path (WritePaths .remoteAccessDir , "localRelay" )
62+ CERT_PATH : Final [Path ] = CERT_DIR / "NvdaRemoteRelay.pem"
63+ KEY_PATH : Final [Path ] = CERT_DIR / "NvdaRemoteRelay.key"
64+ FINGERPRINT_PATH : Final [Path ] = CERT_DIR / "NvdaRemoteRelay.fingerprint"
65+ CERT_DURATION_DAYS : Final [int ] = 365
66+ CERT_RENEWAL_THRESHOLD_DAYS : Final [int ] = 30
6467
65- def __init__ (self , certDir : Path | None = None ):
66- """Initialize the certificate manager.
67-
68- :param certDir: Directory to store certificate files, defaults to program data temp path
69- """
70- self .certDir : Path = certDir or getProgramDataTempPath ()
71- self .certPath : Path = self .certDir / self .CERT_FILE
72- self .keyPath : Path = self .certDir / self .KEY_FILE
73- self .fingerprintPath : Path = self .certDir / self .FINGERPRINT_FILE
68+ def __init__ (self ):
69+ """Initialize the certificate manager."""
70+ self .__cert : bytes | None = None
71+ self .__key : bytes | None = None
72+ self .__fingerprint : str | None = None
7473
7574 def ensureValidCertExists (self ) -> None :
7675 """Ensures a valid certificate and key exist, regenerating if needed."""
7776 log .info ("Checking certificate validity" )
78- os .makedirs (self .certDir , exist_ok = True )
79-
8077 if self ._filesExist ():
8178 try :
8279 self ._validateCertificate ()
8380 return
8481 except Exception as e :
85- log .warning (f"Certificate validation failed: { e } " , exc_info = True )
82+ log .debug (f"Certificate validation failed: { e } " , exc_info = True )
83+ else :
84+ log .debug ("No certificate exists." )
8685
8786 self ._generateSelfSignedCert ()
8887
8988 def _filesExist (self ) -> bool :
90- """Check if both certificate and key files exist."""
91- return self .certPath .is_file () and self .keyPath .is_file ()
89+ """Check if certificate, key and fingerprint files exist."""
90+ return self .CERT_PATH .is_file () and self .KEY_PATH . is_file () and self . FINGERPRINT_PATH .is_file ()
9291
9392 def _validateCertificate (self ) -> None :
94- """Validates the existing certificate and key .
93+ """Validates the existing certificate, key and fingerprint .
9594
9695 :raises ValueError: If the current date/time is outside the certificate's validity period, or if the certificate is approaching expiration.
9796 :raises OSError: If the certificate or private key files cannot be opened.
9897 :raises ValueError: If the private key data cannot be decoded.
9998 :raises TypeError: If the private key is encrypted.
10099 """
101100 # Load and validate certificate
102- with open (self .certPath , "rb" ) as f :
101+ with open (self .CERT_PATH , "rb" ) as f :
103102 certData = f .read ()
104103 cert = x509 .load_pem_x509_certificate (certData )
105104
@@ -114,8 +113,28 @@ def _validateCertificate(self) -> None:
114113 raise ValueError ("Certificate is approaching expiration" )
115114
116115 # Verify private key can be loaded
117- with open (self .keyPath , "rb" ) as f :
118- serialization .load_pem_private_key (f .read (), password = None )
116+ with open (self .KEY_PATH , "rb" ) as f :
117+ keyData = f .read ()
118+ privKey = cast (rsa .RSAPrivateKey , serialization .load_pem_private_key (keyData , password = None ))
119+ pubKey = cast (rsa .RSAPublicKey , cert .public_key ())
120+
121+ # Verify that the private key and certificate match
122+ privNumbers = privKey .private_numbers ()
123+ if pubKey .public_numbers ().n != privNumbers .p * privNumbers .q :
124+ raise ValueError ("Invalid key: n != pq" )
125+ if privKey .public_key () != pubKey :
126+ raise ValueError ("The certificate and private keys do not match." )
127+
128+ with open (self .FINGERPRINT_PATH , "r" ) as f :
129+ fingerprintData = f .read ().strip ()
130+
131+ # Check that fingerprints match
132+ if cert .fingerprint (hashes .SHA256 ()).hex () != fingerprintData :
133+ raise ValueError ("Fingerprints do not match." )
134+
135+ self .__cert = certData
136+ self .__key = keyData
137+ self .__fingerprint = fingerprintData
119138
120139 def _generateSelfSignedCert (self ) -> None :
121140 """Generates a self-signed certificate and private key."""
@@ -169,23 +188,21 @@ def _generateSelfSignedCert(self) -> None:
169188
170189 # Calculate fingerprint
171190 fingerprint = cert .fingerprint (hashes .SHA256 ()).hex ()
172- # Write private key
173- with open (self .keyPath , "wb" ) as f :
174- f .write (
175- privateKey .private_bytes (
176- encoding = serialization .Encoding .PEM ,
177- format = serialization .PrivateFormat .PKCS8 ,
178- encryption_algorithm = serialization .NoEncryption (),
179- ),
180- )
181-
182- # Write certificate
183- with open (self .certPath , "wb" ) as f :
184- f .write (cert .public_bytes (serialization .Encoding .PEM ))
191+ # Calculate private key data
192+ keyData = privateKey .private_bytes (
193+ encoding = serialization .Encoding .PEM ,
194+ format = serialization .PrivateFormat .PKCS8 ,
195+ encryption_algorithm = serialization .NoEncryption (),
196+ )
197+ # Calculate certificate data
198+ certData = cert .public_bytes (serialization .Encoding .PEM )
185199
186- # Save fingerprint
187- with open (self .fingerprintPath , "w" ) as f :
188- f .write (fingerprint )
200+ # Store data on self
201+ self .__key = keyData
202+ self .__cert = certData
203+ self .__fingerprint = fingerprint
204+ # Attempt to persist
205+ self ._persistCertificate ()
189206
190207 # Add to trusted certificates in config
191208 config = configuration .getRemoteConfig ()
@@ -196,26 +213,60 @@ def _generateSelfSignedCert(self) -> None:
196213
197214 log .info (f"Generated new self-signed certificate for NVDA Remote. Fingerprint: { fingerprint } " )
198215
216+ def _persistCertificate (self ) -> None :
217+ if self .__key is None or self .__cert is None or self .__fingerprint is None :
218+ raise RuntimeError ("A certificate must be loaded in order to persist it." )
219+ if not shouldWriteToDisk ():
220+ log .debug ("Not persisting certificate, as shouldWriteToDisk returned False." )
221+ return
222+ try :
223+ os .makedirs (self .CERT_DIR , exist_ok = True )
224+ except Exception :
225+ log .debug ("Unable to create {self.CIRT_DIR}. Not persisting certificates." , exc_info = True )
226+ return
227+ for path in self .KEY_PATH , self .CERT_PATH , self .FINGERPRINT_PATH :
228+ if path .is_dir ():
229+ try :
230+ shutil .rmtree (path )
231+ except Exception :
232+ log .debug ("Unable to remove {path}. Not persisting certificates." , exc_info = True )
233+ return
234+ try :
235+ with (
236+ open (self .KEY_PATH , "wb" ) as keyFile ,
237+ open (self .CERT_PATH , "wb" ) as certFile ,
238+ open (self .FINGERPRINT_PATH , "wt" ) as fingerprintFile ,
239+ ):
240+ keyFile .write (self .__key )
241+ certFile .write (self .__cert )
242+ fingerprintFile .write (self .__fingerprint )
243+ except Exception :
244+ log .debug ("Unable to persist certificate." , exc_info = True )
245+
199246 def getCurrentFingerprint (self ) -> str | None :
200247 """Get the fingerprint of the current certificate."""
201- try :
202- if self .fingerprintPath .is_file ():
203- with open (self .fingerprintPath , "r" ) as f :
204- return f .read ().strip ()
205- except Exception as e :
206- log .warning (f"Error reading fingerprint: { e } " , exc_info = True )
207- return None
248+ return self .__fingerprint
208249
209250 def createSSLContext (self ) -> ssl .SSLContext :
210251 """Creates an SSL context using the certificate and key."""
252+ if self .__key is None or self .__cert is None :
253+ raise RuntimeError ("A certificate must be loaded to create an SSL context." )
211254 context = ssl .SSLContext (ssl .PROTOCOL_TLS_SERVER )
212255 # Load our certificate and private key
213- context .load_cert_chain (
214- certfile = str (self .certPath ),
215- keyfile = str (self .keyPath ),
216- )
217- # Trust our own CA for server verification
218- context .load_verify_locations (cafile = str (self .certPath ))
256+ with tempfile .NamedTemporaryFile ("w+b" , delete = False ) as f :
257+ f .write (self .__key )
258+ if not self .__key .endswith (b"\n " ):
259+ f .write (b"\n " )
260+ f .write (self .__cert )
261+ # OpenSSL will choke if the file is open, so close it manually
262+ # We don't exit the context manager, as that would (potentially) delete the file
263+ f .close ()
264+ context .load_cert_chain (f .name )
265+ # Trust our own CA for server verification
266+ context .load_verify_locations (cafile = f .name )
267+ # Explicitly delete the file, just to be sure
268+ # Exiting the context manager should do this, but it may be left up to the OS to decide when to delete it
269+ os .unlink (f .name )
219270 # Require client cert verification
220271 context .verify_mode = ssl .CERT_NONE # Don't require client certificates
221272 context .check_hostname = False # Don't verify hostname since we're using self-signed certs
@@ -249,19 +300,17 @@ def __init__(
249300 password : str ,
250301 bindHost : str = "" ,
251302 bindHost6 : str = "[::]:" ,
252- certDir : Path | None = None ,
253303 ):
254304 """Initialize the relay server.
255305
256306 :param port: Port number to listen on
257307 :param password: Channel password for client authentication
258308 :param bindHost: IPv4 address to bind to, defaults to all interfaces
259309 :param bindHost6: IPv6 address to bind to, defaults to all interfaces
260- :param certDir: Directory to store certificate files, defaults to None
261310 """
262311 self .port = port
263312 self .password = password
264- self .certManager = RemoteCertificateManager (certDir )
313+ self .certManager = RemoteCertificateManager ()
265314 self .certManager .ensureValidCertExists ()
266315
267316 # Initialize other server components
0 commit comments