I really like this model and I'm wondering how to train a model like this.
#2
by
kikouousya
- opened
I would appreciate any further information (a paper, tutorial, project, or anything else) about that. Thanks!
I tried to write one, but I wonder whether using a Sequential model, a Dense layer, sigmoid activation, the adam optimizer, and the categorical_crossentropy loss function is
OK?
If there's any advice, I would appreciate it.
import tensorflow as tf
import tf2onnx
def generate_mixed_modal(
    model_path,
    feature_extraction_layer,
    batch_size,  # number of samples per gradient step when fitting the head
    # Images fed to the feature extractor's predict(). The original author
    # describes them as (N, 448, 448, 3); confirm against the loaded model's
    # input shape. Augmentation (flips, rotation, crops, colour jitter,
    # noise) would have to be applied by the caller before passing them in.
    new_data,
    # Tag matrix of shape (N, num_classes), one column per tag, e.g.
    # [[0, 0, 1], [0, 1, 0], ...] for a 3-tag problem. Several tags may be
    # set on the same row. Values in [0, 1] (soft labels) are also valid
    # targets for binary cross-entropy.
    new_labels,
    num_classes,  # total number of tags == width of the new output layer
    num_epochs,  # number of full passes over the extracted features
    new_modal_name,
):
    """Train a new multi-label tag head on top of a pretrained Keras model.

    The model at ``model_path`` is loaded and truncated at a chosen layer to
    form a feature extractor. Features for ``new_data`` are computed once, a
    single sigmoid ``Dense`` layer is fitted on them, and the extractor plus
    the trained head are stacked into one model that is saved both through
    Keras and as ONNX.

    Args:
        model_path: Path passed to ``tf.keras.models.load_model``.
        feature_extraction_layer: Either the name of the layer whose output
            is used as features, or a sequence of layer names, in which case
            the second-to-last entry is used (the original behaviour).
        batch_size: Batch size used when fitting the head.
        new_data: Input images for the loaded model.
        new_labels: Multi-hot (or soft) tag targets, one row per image.
        num_classes: Number of output tags.
        num_epochs: Number of training epochs for the head.
        new_modal_name: Output path prefix; ``'.pd'`` and ``'.onnx'`` are
            appended for the Keras and ONNX exports respectively.
    """
    full_model = tf.keras.models.load_model(model_path)

    # A bare string indexed with [-2] would yield a single character, which
    # is never a valid layer name, so treat a string as the layer name itself
    # and keep the [-2] lookup only for sequences of names.
    if isinstance(feature_extraction_layer, str):
        layer_name = feature_extraction_layer
    else:
        layer_name = feature_extraction_layer[-2]

    # Truncate the pretrained network at the selected layer.
    feature_extractor = tf.keras.models.Model(
        full_model.inputs, full_model.get_layer(layer_name).output
    )

    # Features are computed once up front, so only the new head is trained;
    # the pretrained weights are never updated by the fit() below.
    features = feature_extractor.predict(new_data)
    print(features.shape)

    # One independent sigmoid per tag: each output is its own yes/no decision.
    classifier = tf.keras.models.Sequential([
        tf.keras.layers.Dense(num_classes, activation='sigmoid'),
    ])

    # binary_crossentropy is the loss that matches independent sigmoid
    # outputs. categorical_crossentropy assumes exactly one class per sample
    # and is meant to be paired with softmax.
    classifier.compile(
        optimizer='adam',
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )
    classifier.fit(
        features, new_labels, epochs=num_epochs, batch_size=batch_size
    )

    # Stack the extractor and the trained head into one end-to-end model.
    new_model = tf.keras.models.Sequential([
        feature_extractor,
        classifier,
    ])
    # Compiled with the same settings as the head so the saved model carries
    # a consistent loss/metric configuration.
    new_model.compile(
        optimizer='adam',
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )

    # NOTE(review): the '.pd' suffix is kept unchanged; it may have been
    # meant as '.pb', and newer Keras releases may insist on '.keras' or
    # '.h5' -- confirm against the TensorFlow version in use.
    tf.keras.models.save_model(new_model, new_modal_name + '.pd')

    # Also export to ONNX.
    tf2onnx.convert.from_keras(new_model, output_path=new_modal_name + '.onnx')