Update python scripts to use the ai-edge-litert package instead of tflite_runtime / tf.lite
PiperOrigin-RevId: 681624902
ecalubaquib authored and copybara-github committed Oct 2, 2024
1 parent 46fcd97 commit 9084c2e
Showing 9 changed files with 18 additions and 8 deletions.
@@ -22,7 +22,7 @@
 # pylint: disable=g-import-not-at-top
 try:
   # Import TFLite interpreter from tflite_runtime package if it's available.
-  from tflite_runtime.interpreter import Interpreter
+  from ai_edge_litert.interpreter import Interpreter
 except ImportError:
   # If not, fallback to use the TFLite interpreter from the full TF package.
   import tensorflow as tf
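Only the changed import line sits inside this hunk; the code that binds the fallback name lives below the visible context. A minimal sketch of how the complete pattern is typically written (the Interpreter = tf.lite.Interpreter binding and the model path are assumptions for illustration, not part of this diff):

```python
# Sketch of the full import fallback these example scripts rely on.
# ASSUMPTION: the fallback binding below is outside the diff context shown above.
# pylint: disable=g-import-not-at-top
try:
  # Prefer the standalone LiteRT interpreter when ai-edge-litert is installed.
  from ai_edge_litert.interpreter import Interpreter
except ImportError:
  # Otherwise fall back to the interpreter bundled with the full TF package.
  import tensorflow as tf
  Interpreter = tf.lite.Interpreter

# Either path exposes the same constructor ('model.tflite' is a placeholder).
interpreter = Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()
```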
lite/examples/pose_estimation/raspberry_pi/ml/movenet.py (1 addition, 1 deletion)
@@ -25,7 +25,7 @@
 # pylint: disable=g-import-not-at-top
 try:
   # Import TFLite interpreter from tflite_runtime package if it's available.
-  from tflite_runtime.interpreter import Interpreter
+  from ai_edge_litert.interpreter import Interpreter
 except ImportError:
   # If not, fallback to use the TFLite interpreter from the full TF package.
   import tensorflow as tf
@@ -32,7 +32,7 @@
 # pylint: disable=g-import-not-at-top
 try:
   # Import TFLite interpreter from tflite_runtime package if it's available.
-  from tflite_runtime.interpreter import Interpreter
+  from ai_edge_litert.interpreter import Interpreter
 except ImportError:
   # If not, fallback to use the TFLite interpreter from the full TF package.
   import tensorflow as tf
lite/examples/pose_estimation/raspberry_pi/ml/posenet.py (1 addition, 1 deletion)
@@ -23,7 +23,7 @@
 # pylint: disable=g-import-not-at-top
 try:
   # Import TFLite interpreter from tflite_runtime package if it's available.
-  from tflite_runtime.interpreter import Interpreter
+  from ai_edge_litert.interpreter import Interpreter
 except ImportError:
   # If not, fallback to use the TFLite interpreter from the full TF package.
   import tensorflow as tf
@@ -18,6 +18,9 @@
 from model import input_pipeline
 from model import recommendation_model_launcher as launcher
 from google.protobuf import text_format
+# pylint: disable=g-direct-tensorflow-import
+from ai_edge_litert import interpreter as tfl_interpreter
+# pylint: enable=g-direct-tensorflow-import

 FLAGS = flags.FLAGS

@@ -192,7 +195,7 @@ def testModelTrainEvalExport(self):
     tflite_model_path = os.path.join(export_dir, 'model.tflite')
     self.assertTrue(os.path.exists(tflite_model_path))
     f = open(tflite_model_path, 'rb')
-    interpreter = tf.lite.Interpreter(model_content=f.read())
+    interpreter = tfl_interpreter.Interpreter(model_content=f.read())
     interpreter.allocate_tensors()
     inference_signature = interpreter.get_signature_list()['serving_default']
     self.assertAllEqual(
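The test now builds the interpreter from the ai_edge_litert package and inspects its exported signatures. A hedged sketch of that inspection path, with the model path and the commented-out input name chosen only for illustration:

```python
from ai_edge_litert import interpreter as tfl_interpreter

# Build the interpreter from serialized model bytes, as the updated test does.
with open('model.tflite', 'rb') as f:  # placeholder path
  interpreter = tfl_interpreter.Interpreter(model_content=f.read())
interpreter.allocate_tensors()

# get_signature_list() maps each signature key to its input/output tensor names.
serving = interpreter.get_signature_list()['serving_default']
print(serving['inputs'], serving['outputs'])

# A signature runner invokes the model by keyword, one argument per input name.
runner = interpreter.get_signature_runner('serving_default')
# outputs = runner(context=example_batch)  # 'context' is an assumed input name
```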
@@ -21,7 +21,7 @@
 # pylint: disable=g-import-not-at-top
 try:
   # Import TFLite interpreter from tflite_runtime package if it's available.
-  from tflite_runtime.interpreter import Interpreter
+  from ai_edge_litert.interpreter import Interpreter
 except ImportError:
   # If not, fallback to use the TFLite interpreter from the full TF package.
   import tensorflow as tf
tensorflow_examples/lite/model_maker/core/task/model_util.py (4 additions, 1 deletion)
@@ -26,6 +26,9 @@

 from tensorflow_examples.lite.model_maker.core import compat
 from tensorflowjs.converters import converter as tfjs_converter
+# pylint: disable=g-direct-tensorflow-import
+from ai_edge_litert import interpreter as tfl_interpreter
+# pylint: enable=g-direct-tensorflow-import
 from tflite_support import metadata as _metadata

 DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0

@@ -222,7 +225,7 @@ def __init__(self,
     """
     with tf.io.gfile.GFile(tflite_filepath, 'rb') as f:
       tflite_model = f.read()
-    self.interpreter = tf.lite.Interpreter(model_content=tflite_model)
+    self.interpreter = tfl_interpreter.Interpreter(model_content=tflite_model)
     self.interpreter.allocate_tensors()

     # Gets the indexed of the input tensors.
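Once allocate_tensors() has run, the swapped-in interpreter is driven through the same tensor API as tf.lite.Interpreter. A minimal sketch (the model path and zero-filled input are illustrative assumptions, not code from model_util.py):

```python
import numpy as np
from ai_edge_litert import interpreter as tfl_interpreter

# Load serialized model bytes and allocate tensors, as model_util.py now does.
with open('model.tflite', 'rb') as f:  # placeholder path
  interpreter = tfl_interpreter.Interpreter(model_content=f.read())
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed zeros shaped like the first input; real callers pass preprocessed data.
dummy = np.zeros(input_details[0]['shape'], dtype=input_details[0]['dtype'])
interpreter.set_tensor(input_details[0]['index'], dummy)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]['index'])
print(result.shape)
```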
tensorflow_examples/lite/model_maker/requirements.txt (1 addition)
@@ -27,3 +27,4 @@ neural-structured-learning>=1.3.1
 tensorflow-model-optimization>=0.5
 Cython>=0.29.13
 scann==1.2.6
+ai-edge-litert>=1.0.1
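Because some of the updated modules now import ai_edge_litert unconditionally, the new requirements pin matters at runtime. A small sketch for checking the installed distribution against that pin (the messages are illustrative):

```python
# Check that the ai-edge-litert pin from requirements.txt is satisfied.
from importlib.metadata import PackageNotFoundError, version

try:
  installed = version('ai-edge-litert')  # distribution name, not the import name
  print(f'ai-edge-litert {installed} is installed')
except PackageNotFoundError:
  print("missing dependency; install with: pip install 'ai-edge-litert>=1.0.1'")
```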
@@ -18,6 +18,9 @@
 from tensorflow_examples.lite.model_maker.third_party.recommendation.ml.model import input_pipeline
 from tensorflow_examples.lite.model_maker.third_party.recommendation.ml.model import recommendation_model_launcher as launcher
 from google.protobuf import text_format
+# pylint: disable=g-direct-tensorflow-import
+from ai_edge_litert import interpreter as tfl_interpreter
+# pylint: enable=g-direct-tensorflow-import

 FLAGS = flags.FLAGS

@@ -192,7 +195,7 @@ def testModelTrainEvalExport(self):
     tflite_model_path = os.path.join(export_dir, 'model.tflite')
     self.assertTrue(os.path.exists(tflite_model_path))
     f = open(tflite_model_path, 'rb')
-    interpreter = tf.lite.Interpreter(model_content=f.read())
+    interpreter = tfl_interpreter.Interpreter(model_content=f.read())
     interpreter.allocate_tensors()
     inference_signature = interpreter.get_signature_list()['serving_default']
     self.assertAllEqual(
