from drig.networks import TinyGoogLeNet
from drig.callbacks import LossAccuracyTracker, AlphaSchedulers
from drig.utils import display_image_data, plot_training_metrics, visualize_network, plot_network
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
from keras.callbacks import LearningRateScheduler
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD
from keras.datasets import cifar10
from keras.models import load_model
import numpy as np
import os
# --- Filesystem locations ---------------------------------------------------
# Every path hangs off the parent of the current working directory
# (os.path.pardir), so the results depend on where the script is launched.
train_image_datum_path = os.path.abspath(
    os.path.join(os.path.pardir, "datasets", "CIFAR-10", "train")
)
model_dir = os.path.abspath(
    os.path.join(os.path.pardir, "models", "TinyGoogLeNet")
)
# Create the output directory up front; a no-op when it already exists.
os.makedirs(model_dir, exist_ok=True)

# Target file for the trained model (not referenced in the visible part of
# this script).
model_save_path = os.path.join(model_dir, "tinygooglenet_cifar10.hdf5")
# The plot file name embeds the current process id -- presumably so that
# separate runs do not overwrite each other's plots.
loss_acc_plot_path = os.path.join(model_dir, f"loss_acc_plot_{os.getpid()}.png")
# Plain literal: the original f-string had no placeholders.
json_path = os.path.join(model_dir, "model_history.json")
# --- Data loading and preprocessing -----------------------------------------
# Project helper from drig.utils; presumably previews images found under the
# training folder at the given size -- confirm against its implementation.
display_image_data(train_image_datum_path, image_dim=(256, 256))

# Fetch CIFAR-10 through Keras and convert the pixels to floating point so
# the mean can be subtracted in place below.
(train_x, train_y), (test_x, test_y) = cifar10.load_data()
train_x = train_x.astype("float")
test_x = test_x.astype("float")

# Mean over axis 0 of the training images; the same training mean is removed
# from both splits.
mean = train_x.mean(axis=0)
train_x -= mean
test_x -= mean

# Binarize the labels: fit on the training labels, then reuse the fitted
# encoder for the test labels.
label_binarizer = LabelBinarizer()
train_y = label_binarizer.fit_transform(train_y)
test_y = label_binarizer.transform(test_y)

# Human-readable class names, later passed as target_names to the report.
classes = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

# Training-time augmentation: small horizontal/vertical shifts and horizontal
# flips, filling exposed pixels with the nearest value.
image_aug = ImageDataGenerator(
    horizontal_flip=True,
    width_shift_range=0.1,
    height_shift_range=0.1,
    fill_mode="nearest",
)
# --- Training configuration -------------------------------------------------
epochs = 70
base_alpha = 5e-3  # initial learning rate handed to SGD and the scheduler
batch_size = 64

# Project scheduler helper; its polynomial_decay method is handed to Keras'
# LearningRateScheduler callback below.
alpha_schedulers = AlphaSchedulers(base_alpha, epochs)
callbacks = [
    LossAccuracyTracker(loss_acc_plot_path, json_path=json_path),
    LearningRateScheduler(alpha_schedulers.polynomial_decay),
]
optimizer = SGD(momentum=0.9, learning_rate=base_alpha)

# --- Model ------------------------------------------------------------------
# Build the network for 32x32 inputs with 3 channels and 10 output classes,
# then render it with the two project visualisation helpers.
net = TinyGoogLeNet.compose(height=32, width=32, depth=3, classes=10)
visualize_network(net, scale_xy=10)
plot_network(net)
net.compile(
    optimizer=optimizer,
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

# --- Training ---------------------------------------------------------------
# Augmented batches come from image_aug.flow; the test split is also used as
# the validation data. steps_per_epoch counts whole batches only.
Z = net.fit(
    image_aug.flow(train_x, train_y, batch_size=batch_size),
    validation_data=(test_x, test_y),
    steps_per_epoch=len(train_x) // batch_size,
    epochs=epochs,
    callbacks=callbacks,
    verbose=1,
)
# --- Evaluation ---------------------------------------------------------------
# Predict on the test split and print per-class precision/recall/F1. Both the
# binarized ground truth and the predictions are reduced to class indices
# with argmax over axis 1.
cifar10_predictions = net.predict(test_x, batch_size=batch_size)
print(
    classification_report(
        test_y.argmax(axis=1),
        cifar10_predictions.argmax(axis=1),
        target_names=classes,
    )
)

# The report below was pasted into the source as raw text, which made the
# file unparseable. It is kept here as a comment for reference only; these
# are recorded numbers, not values computed by the code above at import time.
#
#               precision    recall  f1-score   support
#
#     airplane       0.91      0.92      0.92      1000
#   automobile       0.95      0.96      0.96      1000
#         bird       0.87      0.86      0.86      1000
#          cat       0.82      0.81      0.81      1000
#         deer       0.88      0.90      0.89      1000
#          dog       0.88      0.82      0.85      1000
#         frog       0.88      0.94      0.91      1000
#        horse       0.93      0.91      0.92      1000
#         ship       0.95      0.95      0.95      1000
#        truck       0.94      0.94      0.94      1000
#
#     accuracy                           0.90     10000
#    macro avg       0.90      0.90      0.90     10000
# weighted avg       0.90      0.90      0.90     10000

# Plot the metrics from the history object returned by net.fit.
plot_training_metrics(model_training_history=Z, epochs=epochs, inline=True)