Skip to content

Commit

Permalink
Merge pull request tensorflow#39525 from tensorflow/mm-fix-major-api-…
Browse files Browse the repository at this point in the history
…version

Fix tests when `tf._major_api_version` does not exist
  • Loading branch information
mihaimaruseac authored May 14, 2020
2 parents 90d1854 + d0106f7 commit 3ffdb91
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions tensorflow/tools/api/tests/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,20 +74,22 @@ def testSummaryMerged(self):
tf.summary.image
# If we use v2 API, check for create_file_writer,
# otherwise check for FileWriter.
if hasattr(tf, '_major_api_version') and tf._major_api_version == 2:
tf.summary.create_file_writer
else:
tf.summary.FileWriter
if hasattr(tf, '_major_api_version'):
if tf._major_api_version == 2:
tf.summary.create_file_writer
else:
tf.summary.FileWriter
# pylint: enable=pointless-statement

def testInternalKerasImport(self):
normalization_parent = layers.BatchNormalization.__module__.split('.')[-1]
if tf._major_api_version == 2:
self.assertEqual('normalization_v2', normalization_parent)
self.assertTrue(layers.BatchNormalization._USE_V2_BEHAVIOR)
else:
self.assertEqual('normalization', normalization_parent)
self.assertFalse(layers.BatchNormalization._USE_V2_BEHAVIOR)
if hasattr(tf, '_major_api_version'):
if tf._major_api_version == 2:
self.assertEqual('normalization_v2', normalization_parent)
self.assertTrue(layers.BatchNormalization._USE_V2_BEHAVIOR)
else:
self.assertEqual('normalization', normalization_parent)
self.assertFalse(layers.BatchNormalization._USE_V2_BEHAVIOR)


if __name__ == '__main__':
Expand Down

0 comments on commit 3ffdb91

Please sign in to comment.