如何创建一个可轻松转换为TensorFlow Lite的模型?

最后发布: 2020-08-03


问题

如何创建一个可以转换为TensorFlow Lite(tflite)并能在Android应用中使用的TensorFlow模型?

按照Google ML Crash Course中的例子,我已经创建了一个分类器并训练了一个模型。我已经将模型导出为 保存模型. 我想把模型转换为 .tflite 文件,并将其用于 推断.

很快(其实是后来)我就明白,我的模型使用了 不支持的操作 - ParseExampleV2.

这是我用来训练模型的分类器。

classifier = tf.estimator.DNNClassifier(
        feature_columns=[tf.feature_column.numeric_column('pixels', shape=WIDTH * HEIGHT)],
        n_classes=NUMBER_OF_CLASSES,
        hidden_units=[40, 40],
        optimizer=my_optimizer,
        config=tf.estimator.RunConfig(keep_checkpoint_max=1),
        model_dir=MODEL_DIR)

有没有一种方法可以训练一个不用这个的模型 tf.ParseExampleV2 操作员?

TensorFlow Lite Model

tensorflow tensorflow2.0 tensorflow-lite
回答

使用方法 Keras 序列式 API 而不是 估算器API.

如果你的模型比较复杂,可以尝试 Keras功能API.

估算器是一个高级的API,它为模型增加了额外的复杂性。

这里是一个顺序模型。

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(1024, input_dim=WIDTH*HEIGHT, activation='relu'))
model.add(tf.keras.layers.Dense(1024, activation='relu'))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

optimizer = tf.keras.optimizers.Adam(learning_rate=rate)
model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])

和它的模式。将其与问题中的模型进行比较。

enter image description here

对于完整的例子如何转换模型 萤火虫 看我的项目 零八分类.