1313# limitations under the License.
1414
1515"""Integration tests for firebase_admin.ml module."""
16- import re
1716import os
18- import shutil
1917import random
18+ import re
19+ import shutil
20+ import string
2021import tempfile
2122import pytest
2223
2324
24- from firebase_admin import ml
2525from firebase_admin import exceptions
26+ from firebase_admin import ml
2627from tests import testutils
2728
2829
3435 _TF_ENABLED = False
3536
3637
38+ def _random_identifier (prefix ):
39+ #pylint: disable=unused-variable
40+ suffix = '' .join ([random .choice (string .ascii_letters + string .digits ) for n in range (8 )])
41+ return '{0}_{1}' .format (prefix , suffix )
42+
43+
3744NAME_ONLY_ARGS = {
38- 'display_name' : 'TestModel123_{0}' . format ( random . randint ( 1111 , 9999 ) )
45+ 'display_name' : _random_identifier ( 'TestModel123_' )
3946}
4047NAME_ONLY_ARGS_UPDATED = {
41- 'display_name' : 'TestModel123_updated_{0}' . format ( random . randint ( 1111 , 9999 ) )
48+ 'display_name' : _random_identifier ( 'TestModel123_updated_' )
4249}
4350NAME_AND_TAGS_ARGS = {
44- 'display_name' : 'TestModel123_tags_{0}' . format ( random . randint ( 1111 , 9999 ) ),
51+ 'display_name' : _random_identifier ( 'TestModel123_tags_' ),
4552 'tags' : ['test_tag123' ]
4653}
4754FULL_MODEL_ARGS = {
48- 'display_name' : 'TestModel123_full_{0}' . format ( random . randint ( 1111 , 9999 ) ),
55+ 'display_name' : _random_identifier ( 'TestModel123_full_' ),
4956 'tags' : ['test_tag567' ],
5057 'file_name' : 'model1.tflite'
5158}
5259INVALID_FULL_MODEL_ARGS = {
53- 'display_name' : 'TestModel123_invalid_full_{0}' . format ( random . randint ( 1111 , 9999 ) ),
60+ 'display_name' : _random_identifier ( 'TestModel123_invalid_full_' ),
5461 'tags' : ['test_tag890' ],
5562 'file_name' : 'invalid_model.tflite'
5663}
5764
65+
5866@pytest .fixture
5967def firebase_model (request ):
6068 args = request .param
@@ -76,10 +84,11 @@ def firebase_model(request):
7684
7785@pytest .fixture
7886def model_list ():
79- ml_model_1 = ml .Model (display_name = "TestModel123" )
87+ ml_model_1 = ml .Model (display_name = _random_identifier ( 'TestModel123_list1_' ) )
8088 model_1 = ml .create_model (model = ml_model_1 )
8189
82- ml_model_2 = ml .Model (display_name = "TestModel123_tags" , tags = ['test_tag123' ])
90+ ml_model_2 = ml .Model (display_name = _random_identifier ('TestModel123_list2_' ),
91+ tags = ['test_tag123' ])
8392 model_2 = ml .create_model (model = ml_model_2 )
8493
8594 yield [model_1 , model_2 ]
@@ -124,7 +133,7 @@ def check_model(model, args):
124133 assert model .etag is not None
125134
126135
127- def check_model_format (model , has_model_format , validation_error ):
136+ def check_model_format (model , has_model_format = False , validation_error = None ):
128137 if has_model_format :
129138 assert model .validation_error == validation_error
130139 assert model .published is False
@@ -145,13 +154,13 @@ def check_model_format(model, has_model_format, validation_error):
145154@pytest .mark .parametrize ('firebase_model' , [NAME_AND_TAGS_ARGS ], indirect = True )
146155def test_create_simple_model (firebase_model ):
147156 check_model (firebase_model , NAME_AND_TAGS_ARGS )
148- check_model_format (firebase_model , False , None )
157+ check_model_format (firebase_model )
149158
150159
151160@pytest .mark .parametrize ('firebase_model' , [FULL_MODEL_ARGS ], indirect = True )
152161def test_create_full_model (firebase_model ):
153162 check_model (firebase_model , FULL_MODEL_ARGS )
154- check_model_format (firebase_model , True , None )
163+ check_model_format (firebase_model , True )
155164
156165
157166@pytest .mark .parametrize ('firebase_model' , [FULL_MODEL_ARGS ], indirect = True )
@@ -173,7 +182,7 @@ def test_create_invalid_model(firebase_model):
173182def test_get_model (firebase_model ):
174183 get_model = ml .get_model (firebase_model .model_id )
175184 check_model (get_model , NAME_AND_TAGS_ARGS )
176- check_model_format (get_model , False , None )
185+ check_model_format (get_model )
177186
178187
179188@pytest .mark .parametrize ('firebase_model' , [NAME_ONLY_ARGS ], indirect = True )
@@ -192,12 +201,12 @@ def test_update_model(firebase_model):
192201 firebase_model .display_name = new_model_name
193202 updated_model = ml .update_model (firebase_model )
194203 check_model (updated_model , NAME_ONLY_ARGS_UPDATED )
195- check_model_format (updated_model , False , None )
204+ check_model_format (updated_model )
196205
197206 # Second call with same model does not cause error
198207 updated_model2 = ml .update_model (updated_model )
199208 check_model (updated_model2 , NAME_ONLY_ARGS_UPDATED )
200- check_model_format (updated_model2 , False , None )
209+ check_model_format (updated_model2 )
201210
202211
203212@pytest .mark .parametrize ('firebase_model' , [NAME_ONLY_ARGS ], indirect = True )
@@ -304,10 +313,13 @@ def keras_model():
304313@pytest .fixture
305314def saved_model_dir (keras_model ):
306315 assert _TF_ENABLED
307- # different versions have different model conversion capability
308- # pick something that works for each version
316+ # Make a new parent directory. The child directory must not exist yet.
317+ # The child directory gets created by tf. If it exists, the tf call fails.
309318 parent = tempfile .mkdtemp ()
310319 save_dir = os .path .join (parent , 'child' )
320+
321+ # different versions have different model conversion capability
322+ # pick something that works for each version
311323 if tf .version .VERSION .startswith ('1.' ):
312324 tf .reset_default_graph ()
313325 x_var = tf .placeholder (tf .float32 , (None , 3 ), name = "x" )
@@ -331,12 +343,12 @@ def test_from_keras_model(keras_model):
331343
332344 # Validate the conversion by creating a model
333345 model_format = ml .TFLiteFormat (model_source = source )
334- model = ml .Model (display_name = "KerasModel1" , model_format = model_format )
346+ model = ml .Model (display_name = _random_identifier ( 'KerasModel_' ) , model_format = model_format )
335347 created_model = ml .create_model (model )
336348
337349 try :
338- assert created_model . model_id is not None
339- assert created_model . validation_error is None
350+ check_model ( created_model , { 'display_name' : model . display_name })
351+ check_model_format ( created_model , True )
340352 finally :
341353 _clean_up_model (created_model )
342354
@@ -351,7 +363,7 @@ def test_from_saved_model(saved_model_dir):
351363
352364 # Validate the conversion by creating a model
353365 model_format = ml .TFLiteFormat (model_source = source )
354- model = ml .Model (display_name = "SavedModel1" , model_format = model_format )
366+ model = ml .Model (display_name = _random_identifier ( 'SavedModel_' ) , model_format = model_format )
355367 created_model = ml .create_model (model )
356368
357369 try :
0 commit comments