TensorFlow Lite

Quantization-Aware Training

from tensorflow_model_optimization.quantization.keras import quantize_model

# Wrap the base model with fake-quantization nodes.
# QAT is typically applied to an already-trained model, then fine-tuned.
q_aware_model = quantize_model(model)

# quantize_model returns a new model, so it must be recompiled
q_aware_model.compile(
    # ... optimizer, loss, metrics
)

q_aware_model.fit(train_batches, epochs=1)  # fine-tune; train_batches is your training pipeline
q_aware_model.evaluate(test_batches)
q_aware_model.predict(test_batches)
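`quantize_model` quantizes the whole model. To quantize only some layers (for example, to protect accuracy-sensitive ones), `tensorflow_model_optimization` also exposes `quantize_annotate_layer` and `quantize_apply`. A minimal sketch, assuming a made-up two-layer model with input shape `(2,)`:

import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Annotate only the layers that should be quantized; unannotated layers stay float
annotated_model = tf.keras.Sequential([
    tfmot.quantization.keras.quantize_annotate_layer(
        tf.keras.layers.Dense(64, activation="relu", input_shape=(2,))
    ),
    tf.keras.layers.Dense(10),
])

# quantize_apply inserts fake-quant nodes only around the annotated layers
q_aware_partial = tfmot.quantization.keras.quantize_apply(annotated_model)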

After fine-tuning, convert the quantization-aware model with the TFLite converter and apply post-training optimization, as in the next section.

Post-Training Optimization

import tensorflow as tf

tf_lite_converter = tf.lite.TFLiteConverter.from_keras_model(model)

# ... set optimization and quantization options here, before convert() (see below)

tflite_model = tf_lite_converter.convert()

# convert() returns the serialized flatbuffer; write it to disk
with open("model.tflite", "wb") as f:
    f.write(tflite_model)
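For reference, the lightest post-training option is dynamic-range quantization, which needs no calibration data at all. A minimal sketch, reusing the same trained Keras model:

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_keras_model(model)
# Weights are stored as int8; activations are computed in float at runtime
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_dynamic_model = converter.convert()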

Quantization

# All converter settings below must be applied before calling convert()

tf_lite_converter.optimizations = [
    tf.lite.Optimize.DEFAULT,
    # OPTIMIZE_FOR_SIZE and OPTIMIZE_FOR_LATENCY are deprecated aliases of DEFAULT
]

tf_lite_converter.target_spec.supported_types = [
    tf.int8  # preferred constant type; use tf.float16 here for float16 quantization
]
tf_lite_converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS_INT8  # restrict to integer-only kernels
]

# The representative dataset lets the converter calibrate activation ranges
def representative_data_gen():
    for input_value, _ in test_batches.take(100):
        yield [input_value]

tf_lite_converter.representative_dataset = representative_data_gen
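If the target runtime expects integer tensors end-to-end (e.g. microcontrollers or an Edge TPU), the model's interface can additionally be forced to int8. A sketch continuing the converter configuration above, again set before `convert()`:

# Force the model's input/output tensors to int8 (otherwise they stay
# float32 with quantize/dequantize ops at the boundary)
tf_lite_converter.inference_input_type = tf.int8
tf_lite_converter.inference_output_type = tf.int8

tflite_int8_model = tf_lite_converter.convert()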

Evaluating the Model

Testing the model without an edge device

import numpy as np

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

print(input_details)   # shape, dtype, and quantization parameters per input
print(output_details)

x_new = np.array(
    [
        [10.0, 5.0],
        [5.0, 5.0],
    ],
    dtype=np.float32
)

# Resize the input if the batch size differs from the model's default, then re-allocate
interpreter.resize_tensor_input(input_details[0]["index"], x_new.shape)
interpreter.allocate_tensors()

interpreter.set_tensor(
    input_details[0]["index"],
    x_new
)

interpreter.invoke()

tflite_results = interpreter.get_tensor(output_details[0]["index"])
print(tflite_results)
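Beyond a single batch, accuracy can be estimated by looping the interpreter over a labelled test set. A sketch assuming `test_batches` yields `(input, label)` pairs for a classification model with integer labels and float inputs:

import numpy as np

correct, total = 0, 0
for input_value, label in test_batches:
    x = np.asarray(input_value, dtype=np.float32)
    # Re-allocate per batch in case batch sizes vary
    interpreter.resize_tensor_input(input_details[0]["index"], x.shape)
    interpreter.allocate_tensors()
    interpreter.set_tensor(input_details[0]["index"], x)
    interpreter.invoke()
    prediction = interpreter.get_tensor(output_details[0]["index"])
    correct += int(np.sum(np.argmax(prediction, axis=-1) == np.asarray(label)))
    total += len(x)

print("TFLite accuracy:", correct / total)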