Skip to content

Commit

Permalink
Move tensorflow lite python calls to ai-edge-litert.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 688768903
  • Loading branch information
pak-laura authored and copybara-github committed Oct 23, 2024
1 parent 1b654c1 commit c89d05c
Show file tree
Hide file tree
Showing 4 changed files with 964 additions and 795 deletions.
4 changes: 1 addition & 3 deletions chirp/projects/zoo/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@
import tensorflow.compat.v1 as tf1
import tensorflow_hub as hub

from ai_edge_litert import interpreter as tfl_interpreter # pylint: disable=g-direct-tensorflow-import


@dataclasses.dataclass
class SeparateEmbedModel(zoo_interface.EmbeddingModel):
Expand Down Expand Up @@ -328,7 +326,7 @@ def from_config(cls, config: config_dict.ConfigDict) -> 'BirdNet':
with tempfile.NamedTemporaryFile() as tmpf:
model_file = epath.Path(config.model_path)
model_file.copy(tmpf.name, overwrite=True)
model = tfl_interpreter.Interpreter(
model = tf.lite.Interpreter(
tmpf.name, num_threads=config.num_tflite_threads
)
model.allocate_tensors()
Expand Down
5 changes: 2 additions & 3 deletions chirp/train_tests/frontend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

from absl.testing import absltest
from absl.testing import parameterized
from ai_edge_litert import interpreter as tfl_interpreter # pylint: disable=g-direct-tensorflow-import


class FrontendTest(parameterized.TestCase):
Expand Down Expand Up @@ -201,7 +200,7 @@ def test_tflite_stft_export(
tflite_float_model = converter.convert()

# Use the converted TFLite model.
interpreter = tfl_interpreter.Interpreter(model_content=tflite_float_model)
interpreter = tf.lite.Interpreter(model_content=tflite_float_model)
interpreter.allocate_tensors()
input_tensor = interpreter.get_input_details()[0]
output_tensor = interpreter.get_output_details()[0]
Expand Down Expand Up @@ -248,7 +247,7 @@ def test_simple_melspec(self):
tflite_float_model = converter.convert()

# Use the converted TFLite model.
interpreter = tfl_interpreter.Interpreter(model_content=tflite_float_model)
interpreter = tf.lite.Interpreter(model_content=tflite_float_model)
interpreter.allocate_tensors()
input_tensor = interpreter.get_input_details()[0]
output_tensor = interpreter.get_output_details()[0]
Expand Down
Loading

0 comments on commit c89d05c

Please sign in to comment.