Engine knock detection AI part 3/5
Training a convolutional neural network to identify knocking engine noises
Background
With the training dataset created, the model can be trained. The code is copied almost verbatim from the second lesson of Practical Deep Learning for Coders, found in the GitHub repo for the third installment of the course. The steps are described quite well in the linked notebook.
!curl -s https://course.fast.ai/setup/colab | bash
from google.colab import drive
drive.mount('/content/drive')
from fastai.vision import *
classes = ['knocking','normal']
path = Path('/content/drive/My Drive/Colab Notebooks/fast.ai/KnockKnock/data')
for c in classes:
    print(c)
    verify_images(path/c)
np.random.seed(42)
data = ImageDataBunch.from_folder(path, train=".", valid_pct=0.2, size=224, num_workers=4).normalize(imagenet_stats)
data.classes
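Seeding the random generator before building the data bunch keeps the random 20% validation split the same on every run, so results stay comparable between sessions. A minimal sketch of the idea using only the standard library; `split_items` and the file names are hypothetical, and fastai draws its split from NumPy's generator, so the actual items selected will differ.

```python
import random

def split_items(items, valid_pct=0.2, seed=42):
    """Shuffle a copy of items with a fixed seed and hold out valid_pct of them."""
    rng = random.Random(seed)          # local generator, global random state is untouched
    shuffled = list(items)
    rng.shuffle(shuffled)
    n_valid = int(len(shuffled) * valid_pct)
    return shuffled[n_valid:], shuffled[:n_valid]   # (train, valid)

# Hypothetical file names standing in for the spectrogram images.
files = [f"slice_{i:03d}.png" for i in range(50)]
train, valid = split_items(files)
train_again, valid_again = split_items(files)

print(len(train), len(valid))          # 40 10
print(valid == valid_again)            # True: same seed, same split
```

Without a fixed seed, a different set of images would land in the validation set each time, and the error rate would shift from run to run for reasons unrelated to the model.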
Display images from dataset
As these images demonstrate, knocking shows up as vertical spikes in the middle of the spectrum. Some spectrograms for non-knocking engines show rhythmic components in the lower frequencies (top row, middle). It will be interesting to see how well the model can distinguish these from actual knocking.
data.show_batch(rows=3, figsize=(7,8))
data.classes, data.c, len(data.train_ds), len(data.valid_ds)
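A short, sharp sound such as a knock spreads its energy over many frequencies at a single instant, which is why it appears as a vertical line in a spectrogram, while a steady tone occupies a narrow horizontal band. The toy sketch below illustrates this with a naive DFT on two synthetic frames. These are not the engine recordings, this is not the spectrogram code used to build the dataset, and `dft_magnitudes` and `active_bins` are hypothetical helpers.

```python
import cmath
import math

def dft_magnitudes(frame):
    """Magnitudes of the first half of a naive discrete Fourier transform."""
    n = len(frame)
    return [
        abs(sum(x * cmath.exp(-2j * math.pi * k * t / n) for t, x in enumerate(frame)))
        for k in range(n // 2)
    ]

def active_bins(frame, threshold=0.5):
    """Count frequency bins holding at least threshold times the peak magnitude."""
    mags = dft_magnitudes(frame)
    peak = max(mags)
    return sum(m >= threshold * peak for m in mags)

n = 64
tone = [math.sin(2 * math.pi * 8 * t / n) for t in range(n)]   # steady tone in bin 8
click = [1.0 if t == 10 else 0.0 for t in range(n)]            # single sharp transient

print(active_bins(tone))    # 1: energy concentrated in one frequency bin
print(active_bins(click))   # 32: energy spread over every bin of the frame
```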
learn = cnn_learner(data, models.resnet34, metrics=error_rate)
learn.fit_one_cycle(4)
learn.save('stage-1')
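The `error_rate` metric reported during training is the fraction of validation samples whose predicted class differs from the label, in other words one minus accuracy. A small sketch of that computation with made-up predictions; `error_rate_sketch` and the numbers are illustrative only and are not output from this model.

```python
def error_rate_sketch(probabilities, targets):
    """Fraction of samples whose highest-probability class differs from the target."""
    wrong = 0
    for probs, target in zip(probabilities, targets):
        predicted = max(range(len(probs)), key=probs.__getitem__)  # index of the largest probability
        wrong += predicted != target
    return wrong / len(targets)

# Made-up class probabilities in the order [knocking, normal];
# targets use 0 = knocking, 1 = normal.
probs = [[0.9, 0.1], [0.3, 0.7], [0.6, 0.4], [0.2, 0.8]]
targets = [0, 1, 1, 1]
print(error_rate_sketch(probs, targets))   # 0.25: one of four samples is misclassified
```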
learn.unfreeze()
learn.lr_find()
# learn.lr_find(start_lr=1e-5, end_lr=1e-1)
learn.recorder.plot()
learn.fit_one_cycle(2, max_lr=slice(4e-6,4e-4))
learn.save('stage-2')
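Passing `max_lr=slice(4e-6,4e-4)` after unfreezing gives the earliest layers the smallest learning rate and the final layers the largest, with the groups in between spaced by a constant factor. A sketch of that spacing, assuming the learner is split into three layer groups (an assumption about what `cnn_learner` builds here); `geometric_lrs` is a hypothetical helper, not a fastai function.

```python
def geometric_lrs(start, stop, n_groups):
    """Return n_groups learning rates spaced by a constant factor from start to stop."""
    if n_groups == 1:
        return [stop]
    factor = (stop / start) ** (1 / (n_groups - 1))
    return [start * factor ** i for i in range(n_groups)]

# Assumed: three layer groups (early body, late body, head).
lrs = geometric_lrs(4e-6, 4e-4, 3)
for group, lr in enumerate(lrs):
    print(f"layer group {group}: lr = {lr:.1e}")
```

The pretrained early layers already detect generic features such as edges, so they are nudged only slightly, while the later layers get larger updates to adapt to spectrograms.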
learn.load('stage-2');
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()
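The confusion matrix counts, for each actual class, how many validation samples were predicted as each class, so the off-diagonal cells are the mistakes. A minimal sketch with made-up labels (not results from this model); `confusion_matrix` here is a hypothetical helper, not the fastai implementation.

```python
from collections import Counter

classes = ['knocking', 'normal']

def confusion_matrix(actual, predicted, classes):
    """Rows are actual classes, columns are predicted classes."""
    counts = Counter(zip(actual, predicted))
    return [[counts[(a, p)] for p in classes] for a in classes]

# Made-up labels for illustration only.
actual    = ['knocking', 'knocking', 'normal', 'normal', 'normal',   'knocking']
predicted = ['knocking', 'normal',   'normal', 'normal', 'knocking', 'knocking']

matrix = confusion_matrix(actual, predicted, classes)
for name, row in zip(classes, matrix):
    print(f"{name:>8}: {row}")
```

For this use case the two kinds of error are not equally costly: a knocking engine predicted as normal (top right cell) is the one to watch.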
Plot top losses
Looking at the cases with the highest loss, where the model was most wrong or least sure, also gives the impression that it's quite good. The top row would be hard for me to classify correctly based on the spectrograms. Looking closely at the third one (top right), you can suspect that something is going on from the vertical lines in the top middle of the image.
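losses,idxs = interp.top_losses(10)  # keep only the 10 highest-loss validation samples
# This check is True only if top_losses returns every validation sample;
# with k=10 it holds only when the validation set has exactly 10 items.
len(data.valid_ds)==len(losses)==len(idxs)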
interp.plot_top_losses(9)
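The ranking behind `top_losses` sorts validation samples by their loss, which for a classifier is typically cross-entropy: the negative log of the probability assigned to the true class. A sketch with made-up probabilities showing that confident mistakes come first and borderline predictions follow; `cross_entropy` and `top_losses_sketch` are hypothetical helpers.

```python
import math

def cross_entropy(prob_of_true_class):
    """Loss for one sample: negative log of the probability given to the correct class."""
    return -math.log(prob_of_true_class)

def top_losses_sketch(true_class_probs, k):
    """Return (loss, index) pairs for the k largest losses, highest first."""
    losses = [(cross_entropy(p), i) for i, p in enumerate(true_class_probs)]
    return sorted(losses, reverse=True)[:k]

# Made-up probabilities the model assigned to the correct class of each sample.
true_class_probs = [0.98, 0.05, 0.55, 0.90, 0.45]
top = top_losses_sketch(true_class_probs, k=3)
for loss, idx in top:
    print(f"sample {idx}: p(true)={true_class_probs[idx]:.2f}, loss={loss:.2f}")
```

Sample 1 is a confident mistake and dominates the list, while samples 4 and 2 are near coin flips with much smaller losses.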
import IPython.display as ipd
import os
for img_path in data.valid_ds.items[idxs]:
    filepath, extension = os.path.splitext(img_path)
    audio_slice_path = filepath + '.wav'
    print(filepath)
    ipd.display(ipd.Audio(audio_slice_path))
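The loop above relies on each spectrogram image sitting next to an audio slice with the same base name, so swapping the extension for `.wav` locates the sound. The same mapping can be sketched with `pathlib`; the function `audio_path_for` and the file name are hypothetical.

```python
from pathlib import Path

def audio_path_for(image_path):
    """Swap the image extension for .wav, keeping the folder and base name."""
    return Path(image_path).with_suffix('.wav')

# Hypothetical file name; the real names come from data.valid_ds.items.
wav = audio_path_for('data/knocking/clip_007.png')
print(wav.name)   # clip_007.wav
```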
learn.export()