from drig.datum_config import CALTECH101Config as config
from drig.config import VGGNetImage
from drig.feature import FeatureExtractor
from drig.utils import display_image_data, visualize_network, display_prediction, plot_network, \
grab_random_image, display_image, grab_confusion_mesh, plot_confusion_mesh, grab_image_class_names
from keras.applications import VGG16, imagenet_utils
from keras.applications import vgg16 as vgg16_preprocessor
from keras.preprocessing.image import img_to_array, load_img
from imutils import paths
import numpy as np
import os
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report
import pickle
import h5py
from sklearn.model_selection import RandomizedSearchCV
# Peek at the raw CALTECH-101 image data on disk.
display_image_data(config.DATASET_PATH)

# Full ImageNet VGG16 (classification head included), built here only so its
# architecture can be inspected.
vgg16 = VGG16(weights="imagenet")
for inspect_network in (visualize_network, plot_network):
    inspect_network(vgg16)

# Headless VGG16 (include_top=False) sized to the configured input image; this
# convolutional base is the feature extractor reused by the later cells.
input_cast = (VGGNetImage.HEIGHT, VGGNetImage.WIDTH, VGGNetImage.DEPTH)
vgg16_net = VGG16(weights="imagenet", input_shape=input_cast, include_top=False)
for inspect_network in (visualize_network, plot_network):
    inspect_network(vgg16_net)

# Push the dataset images through vgg16_net; the results go to
# config.VGG16_FEATURE_DATUM_PATH, which is read back below with h5py.
feature_extractor = FeatureExtractor(
    feature_datum_path=config.VGG16_FEATURE_DATUM_PATH,
    class_index=config.CLASS_INDEX,
    network=vgg16_net,
    net_input_cast=input_cast,
    image_datum_path=config.DATASET_PATH,
    preprocessor=vgg16_preprocessor,
    batch_size=config.BATCH_SIZE,
    buffer_size=config.BUFFER_SIZE,
    image_net=True,
)
feature_extractor.extract_features()
INFO:root:ENCODING LABELS INFO:root:CLASSES : 101 INFO:root:INITIALZING FEATURE CONDENSER Extracting Features: 100% |❆❆❆❆❆❆❆❆❆❆❆❆❆❆❆❆❆❆❆❆❆❆❆❆| 8677 of 8677 Time: 0:05:31
# Re-open the extracted VGG16 features read-only. The handle stays open because
# later cells keep reading "features", "classes" and "class_names" from it.
vgg16_caltech_features = h5py.File(config.VGG16_FEATURE_DATUM_PATH, mode="r")
classes = vgg16_caltech_features["class_names"]

# Rows [0, training_set_index) train the classifier; the remaining rows are held out.
# NOTE(review): TRAIN_SIZE is presumably a fraction in (0, 1). The split follows row
# order in the file, so it is random only if FeatureExtractor shuffled the rows.
# TODO confirm.
sample_count = vgg16_caltech_features["classes"].shape[0]
training_set_index = int(sample_count * config.TRAIN_SIZE)

from sklearn.svm import SVC

svc = SVC(verbose=1)
svc.fit(
    vgg16_caltech_features["features"][:training_set_index],
    vgg16_caltech_features["classes"][:training_set_index],
)
[LibSVM]
SVC(verbose=1)
# Score the SVC on the held-out rows (from training_set_index to the end).
predictions = svc.predict(vgg16_caltech_features["features"][training_set_index:])

# Pass ``labels`` explicitly so every entry of ``target_names`` has a matching label.
# Without it, classification_report only uses the labels present in y_true/y_pred.
# It then raises ValueError if any class is missing from the test split and the
# predictions.
# NOTE(review): assumes labels are encoded as 0..len(classes)-1 in the same order as
# the "class_names" dataset. Confirm against FeatureExtractor's label encoding.
print(classification_report(vgg16_caltech_features["classes"][training_set_index:],
                            predictions,
                            labels=np.arange(len(classes)),
                            target_names=classes))
precision recall f1-score support Faces 1.00 1.00 1.00 121 Faces_easy 1.00 1.00 1.00 105 Leopards 0.92 1.00 0.96 49 Motorbikes 1.00 1.00 1.00 202 accordion 1.00 1.00 1.00 14 airplanes 0.99 1.00 0.99 186 anchor 0.67 0.80 0.73 5 ant 1.00 1.00 1.00 8 barrel 1.00 1.00 1.00 8 bass 1.00 0.92 0.96 13 beaver 0.75 0.60 0.67 10 binocular 1.00 0.80 0.89 5 bonsai 0.72 1.00 0.84 38 brain 0.92 0.96 0.94 25 brontosaurus 1.00 0.75 0.86 8 buddha 1.00 0.93 0.97 15 butterfly 0.50 1.00 0.67 21 camera 1.00 0.82 0.90 11 cannon 1.00 0.70 0.82 10 car_side 1.00 1.00 1.00 33 ceiling_fan 0.92 0.85 0.88 13 cellphone 0.95 0.90 0.92 20 chair 0.71 0.86 0.77 14 chandelier 0.77 0.95 0.85 21 cougar_body 0.62 0.73 0.67 11 cougar_face 0.83 0.71 0.77 14 crab 0.81 1.00 0.89 25 crayfish 0.58 0.88 0.70 16 crocodile 0.55 0.79 0.65 14 crocodile_head 0.83 0.91 0.87 11 cup 1.00 0.87 0.93 15 dalmatian 1.00 1.00 1.00 13 dollar_bill 1.00 1.00 1.00 10 dolphin 0.83 1.00 0.91 15 dragonfly 1.00 1.00 1.00 21 electric_guitar 0.96 1.00 0.98 26 elephant 0.94 0.89 0.91 18 emu 1.00 0.86 0.92 14 euphonium 1.00 1.00 1.00 16 ewer 1.00 0.96 0.98 24 ferry 1.00 1.00 1.00 17 flamingo 1.00 1.00 1.00 13 flamingo_head 1.00 0.92 0.96 13 garfield 1.00 0.82 0.90 11 gerenuk 1.00 0.18 0.31 11 gramophone 1.00 0.93 0.96 14 grand_piano 1.00 1.00 1.00 17 hawksbill 0.94 0.97 0.95 32 headphone 0.89 0.89 0.89 9 hedgehog 0.93 0.72 0.81 18 helicopter 0.95 0.95 0.95 22 ibis 1.00 1.00 1.00 20 inline_skate 1.00 1.00 1.00 7 joshua_tree 0.68 0.93 0.79 14 kangaroo 0.75 0.94 0.83 16 ketch 0.80 0.97 0.88 29 lamp 1.00 0.80 0.89 20 laptop 1.00 1.00 1.00 20 llama 0.67 0.82 0.74 17 lobster 1.00 0.33 0.50 12 lotus 0.74 0.93 0.82 15 mandolin 1.00 0.92 0.96 13 mayfly 0.90 1.00 0.95 9 menorah 0.96 0.96 0.96 25 metronome 1.00 0.69 0.82 13 minaret 0.96 0.96 0.96 25 nautilus 1.00 0.77 0.87 13 octopus 1.00 0.40 0.57 10 okapi 1.00 0.85 0.92 13 pagoda 1.00 0.85 0.92 13 panda 1.00 0.64 0.78 11 pigeon 1.00 0.86 0.92 7 pizza 0.86 0.75 0.80 8 platypus 1.00 0.62 0.77 8 
pyramid 0.81 0.87 0.84 15 revolver 1.00 0.95 0.97 19 rhino 0.93 0.88 0.90 16 rooster 1.00 0.75 0.86 12 saxophone 1.00 0.60 0.75 5 schooner 1.00 0.60 0.75 20 scissors 1.00 0.77 0.87 13 scorpion 1.00 0.89 0.94 18 sea_horse 0.64 0.88 0.74 16 snoopy 0.82 0.82 0.82 11 soccer_ball 1.00 0.93 0.96 14 stapler 1.00 0.88 0.94 17 starfish 0.92 1.00 0.96 22 stegosaurus 1.00 0.64 0.78 11 stop_sign 1.00 0.81 0.90 16 strawberry 1.00 0.70 0.82 10 sunflower 1.00 1.00 1.00 15 tick 1.00 0.67 0.80 15 trilobite 1.00 1.00 1.00 19 umbrella 1.00 0.92 0.96 13 watch 0.74 1.00 0.85 67 water_lilly 1.00 0.12 0.22 8 wheelchair 1.00 0.95 0.97 19 wild_cat 1.00 0.30 0.46 10 windsor_chair 0.91 0.77 0.83 13 wrench 0.88 0.70 0.78 10 yin_yang 0.92 0.85 0.88 13 accuracy 0.92 2170 macro avg 0.92 0.85 0.87 2170 weighted avg 0.93 0.92 0.91 2170
# Draw one random image from the dataset (with its path) to spot-check the classifier.
test_image, class_name, test_image_path = grab_random_image(
    dataset_path=config.DATASET_PATH,
    class_index=config.CLASS_INDEX,
    return_image_path=True,
)
display_image(image=test_image)
# Bare expression: echoes the sampled image's class as the notebook cell output.
class_name
'cougar_face'
# Classify the sampled image with the same headless-VGG16 feature pipeline
# (vgg16_net, input_cast) that produced the SVC's training features.
test_predict = svc.predict(
    FeatureExtractor.unit_image_feature(
        vgg16_net,
        test_image_path,
        net_input_cast=input_cast,
        image_net=True,
    )
)
display_prediction(test_predict, classes, font_scale=0.7, image_path=test_image_path)

# Class names read from the dataset directory, plus their encoded form; both are
# passed to grab_confusion_mesh below.
class_names, encoded_class_names = grab_image_class_names(
    dataset_path=config.DATASET_PATH,
    class_index=config.CLASS_INDEX,
    encode_classes=True,
)
# NOTE(review): ``class_name`` (the sampled image's true class) is passed in and then
# rebound to the helper's second return value. Confirm this is intended.
mesh, class_name, _ = grab_confusion_mesh(
    vgg16_caltech_features["classes"][training_set_index:],
    predictions,
    encoded_class_names,
    class_name=class_name,
    classes=class_names.tolist(),
)
plot_confusion_mesh(mesh, class_name)

# This is the last read from the HDF5 feature file, so release the handle rather
# than holding it for the rest of the session. An open read handle can block
# rewriting the file, e.g. when re-running feature extraction. ``classes`` is an
# h5py dataset backed by this file and is unusable after the close.
vgg16_caltech_features.close()