If you want to use a neural network for music retrieval, one of the common approaches nowadays (as of 2023) is contrastive learning. The advantage of this approach is that when the retrieval database changes, in particular when new songs are added, the model does not need to be retrained, unlike a classification model, which must be retrained whenever a new class is added.
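
The intuition is that the encoders map both the database songs and the queries into a shared embedding space, so retrieval becomes a nearest-neighbor search in that space, and indexing a new song only costs one forward pass through the trained encoder. A minimal sketch of this idea, with hypothetical random embeddings standing in for encoder outputs:

import numpy as np

def retrieve(query_emb, db_embs):
	# Rank database songs by cosine similarity to the query embedding
	sims = (db_embs @ query_emb) / (np.linalg.norm(db_embs, axis=1) * np.linalg.norm(query_emb))
	return np.argsort(sims)[::-1]

db_embs = np.random.randn(48, 4) # 48 songs, 4-dim embeddings
new_song = np.random.randn(4) # Embed a new song once...
db_embs = np.vstack([db_embs, new_song]) # ...and append it; no retraining needed
print(retrieve(np.random.randn(4), db_embs)[:10])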

We first convert the wav files into spectrograms. Since extracting the spectrograms takes some time (although with the parameter settings in this example it may only take a few seconds to run), we split this step into its own script, as follows:

import os

import librosa
import numpy as np
import scipy.io.wavfile

DATASET_PATH = 'MIR-QBSH/waveFile'
TRAIN_DIRS = ['year2003', 'year2004a', 'year2004b', 'year2005']
TEST_DIRS = ['year2006a']
FRAME_SIZE = 256
HOP_SIZE = 256

data_dirs = os.listdir(DATASET_PATH)
mags_tr = []
mags_te = []
ans_tr = []
ans_te = []
for dd in data_dirs:
	if dd not in TRAIN_DIRS + TEST_DIRS:
		continue
	people = os.listdir(os.path.join(DATASET_PATH, dd))
	for person in people:
		if person.endswith('.txt'):
			continue
		wav_files = os.listdir(os.path.join(DATASET_PATH, dd, person))
		print('Processing {}/{}...'.format(dd, person))
		for wf in wav_files:
			if not wf.endswith('.wav'):
				continue
			wav_path = os.path.join(DATASET_PATH, dd, person, wf)
			fs, y = scipy.io.wavfile.read(wav_path)
			y = y / 32768 # Normalize 16-bit PCM samples to [-1, 1)
			# Magnitude spectrogram, transposed to (time, freq)
			mag = np.abs(librosa.stft(y, n_fft=FRAME_SIZE, hop_length=HOP_SIZE, center=False)).T
			if dd in TRAIN_DIRS:
				mags_tr.append(mag)
				ans_tr.append(int(wf[:-4]) - 1) # Numeric file stem minus 1 = zero-based song index
			else:
				mags_te.append(mag)
				ans_te.append(int(wf[:-4]) - 1)

mags_tr = np.stack(mags_tr)
mags_te = np.stack(mags_te)
ans_tr = np.array(ans_tr)
ans_te = np.array(ans_te)
print('TR shapes:', mags_tr.shape, ans_tr.shape)
print('TE shapes:', mags_te.shape, ans_te.shape)
np.save('mags_tr', mags_tr)
np.save('mags_te', mags_te)
np.save('ans_tr', ans_tr)
np.save('ans_te', ans_te)

In the code above, we walk through the selected year directories of the MIR-QBSH waveFile folder, read each .wav file, normalize the 16-bit samples to [-1, 1), and compute a magnitude spectrogram with librosa.stft, transposed to shape (time, freq). The numeric file name (minus the .wav extension, minus 1) is used as the zero-based index of the ground-truth song. Finally, the spectrograms and answers of the training and test splits are stacked and saved as .npy files so the later scripts can load them directly.
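
Because HOP_SIZE equals FRAME_SIZE and center=False, consecutive frames do not overlap; the number of frames is 1 + (len(y) - FRAME_SIZE) // HOP_SIZE and the number of frequency bins is FRAME_SIZE // 2 + 1 = 129, which together determine the (time, freq) shape that the spectrogram encoder below will see. A quick way to confirm this on a single file (the path here is hypothetical):

import numpy as np
import librosa
import scipy.io.wavfile

FRAME_SIZE = HOP_SIZE = 256
fs, y = scipy.io.wavfile.read('some_query.wav') # Hypothetical path
mag = np.abs(librosa.stft(y / 32768, n_fft=FRAME_SIZE, hop_length=HOP_SIZE, center=False)).T
assert mag.shape == (1 + (len(y) - FRAME_SIZE) // HOP_SIZE, FRAME_SIZE // 2 + 1)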

The complete training code is as follows:

import os
import time

import numpy as np
import torch
from pytorch_metric_learning import losses
from torch import optim

import util
import model

SAVE_MODEL_DIR = 'model'
NUM_EPOCH = 20
OUT_NUM = 4
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

db_pitches, db_song_names = util.read_database('MIR-QBSH/midiFile')
min_len = min([len(p) for p in db_pitches])
db_pitches = np.array([p[:min_len] for p in db_pitches]) # Truncate to the shortest so the pitch sequences stack into one array
print('DB shape:', db_pitches.shape)

mags_tr = np.load('mags_tr.npy')
ans_tr = np.load('ans_tr.npy')[:, np.newaxis] # Shape (N, 1), as Data2Torch indexes into each answer
print('Training data shapes:', mags_tr.shape, ans_tr.shape)

print('Setting network')
midi_encoder = model.MIDIENCODER(OUT_NUM)
spec_encoder = model.SPECENCODER(OUT_NUM)
midi_encoder.to(device)
spec_encoder.to(device)
optimizer_midi = optim.Adam(list(midi_encoder.parameters()), lr=0.0005)
optimizer_spec = optim.Adam(list(spec_encoder.parameters()), lr=0.0005)

loss_func = losses.NTXentLoss()

if not os.path.exists(SAVE_MODEL_DIR):
	os.makedirs(SAVE_MODEL_DIR, 0o755)

data_loader_train = torch.utils.data.DataLoader(
	util.Data2Torch(mags_tr, db_pitches, ans_tr),
	shuffle=True,
	batch_size=64,
)

total_time = 0
fout = open(os.path.join(SAVE_MODEL_DIR, 'train_report.txt'), 'w')
for epoch in range(NUM_EPOCH):
	util.print_and_write_file(fout, 'epoch {}/{}...'.format(epoch + 1, NUM_EPOCH))
	tic = time.time()
	# --- Batch training
	midi_encoder.train()
	spec_encoder.train()
	training_loss = 0
	n_batch = 0
	for idx, data in enumerate(data_loader_train):
		emb_midi = midi_encoder(data['database'].to(device))
		emb_spec = spec_encoder(data['query'].to(device))
		ans = data['ans'].to(device)
		# Concatenate both views; shared labels mark (MIDI, query) positive pairs for NTXentLoss
		loss = loss_func(torch.cat((emb_midi, emb_spec)), torch.cat((ans, ans)).flatten())
		optimizer_midi.zero_grad()
		optimizer_spec.zero_grad()
		loss.backward()
		optimizer_midi.step()
		optimizer_spec.step()
		training_loss += loss.item()
		n_batch += 1
	# --- Training loss
	training_loss_avg = training_loss / n_batch
	util.print_and_write_file(
		fout, '\tTraining loss (avg over batch): {}, {}, {}'.format(
			training_loss_avg, training_loss, n_batch
		)
	)
	# --- Time
	toc = time.time()
	total_time += toc - tic
	util.print_and_write_file(fout, '\tTime: {:.3f} sec, estimated remaining: {:.3f} hr'.format(
		toc - tic,
		1.0 * total_time / (epoch + 1) * (NUM_EPOCH - (epoch + 1)) / 3600
	))
	fout.flush()

torch.save(midi_encoder.state_dict(), os.path.join(SAVE_MODEL_DIR, 'midi_encoder'))
torch.save(spec_encoder.state_dict(), os.path.join(SAVE_MODEL_DIR, 'spec_encoder'))
fout.close()
print('Model saved in {}'.format(SAVE_MODEL_DIR))
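
A note on the loss: losses.NTXentLoss treats embeddings that share a label as positive pairs, so concatenating the MIDI embeddings and the spectrogram embeddings with the same answer labels pulls each query toward its ground-truth song while pushing it away from the other songs in the batch. A minimal, self-contained sketch of this behavior, with made-up embeddings:

import torch
from pytorch_metric_learning import losses

loss_func = losses.NTXentLoss()
# Three songs, each contributing one MIDI embedding and one query embedding;
# equal labels mark the (MIDI, query) positive pairs
emb = torch.randn(6, 4)
labels = torch.tensor([0, 1, 2, 0, 1, 2])
print(loss_func(emb, labels))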

The imported model.py contains the following:

import torch
import torch.nn.functional as F
from torch import nn


class MIDIENCODER(nn.Module):
	def __init__(self, out_num):
		super(MIDIENCODER, self).__init__()
		self.cnn = nn.Sequential(
			nn.Conv1d(1, 16, 5),
			nn.BatchNorm1d(16),
			nn.ReLU(),
			nn.MaxPool1d(4),
			nn.Conv1d(16, 32, 5),
			nn.BatchNorm1d(32),
			nn.ReLU(),
			nn.MaxPool1d(4),
			nn.Conv1d(32, 32, 5),
			nn.BatchNorm1d(32),
			nn.ReLU(),
			nn.MaxPool1d(4),
		)
		self.mlp = nn.Sequential(
			nn.Linear(in_features=32, out_features=out_num),
			nn.BatchNorm1d(out_num),
		)

	def forward(self, x):
		h = torch.unsqueeze(x, 1) # (batch, time) to (batch, ch=1, time)
		h = self.cnn(h)
		h = F.max_pool1d(h, kernel_size=h.size()[2])[:, :, 0] # Global max pooling over time
		return self.mlp(h)


class SPECENCODER(nn.Module):
	def __init__(self, out_num):
		super(SPECENCODER, self).__init__()
		self.cnn = nn.Sequential(
			nn.Conv2d(1, 16, 3),
			nn.BatchNorm2d(16),
			nn.ReLU(),
			nn.MaxPool2d(2),
			nn.Conv2d(16, 32, 3),
			nn.BatchNorm2d(32),
			nn.ReLU(),
			nn.MaxPool2d(2),
			nn.Conv2d(32, 64, 3),
			nn.BatchNorm2d(64),
			nn.ReLU(),
			nn.MaxPool2d(2),
			nn.Conv2d(64, 128, 3),
			nn.BatchNorm2d(128),
			nn.ReLU(),
			nn.MaxPool2d(2),
		)
		self.mlp = nn.Sequential(
			nn.Linear(in_features=128, out_features=out_num),
			nn.BatchNorm1d(out_num),
		)

	def forward(self, x):
		h = torch.unsqueeze(x, 1) # (batch, time, freq) to (batch, ch=1, time, freq)
		h = self.cnn(h)
		h = F.max_pool2d(h, kernel_size=h.size()[2:])[:, :, 0, 0] # Global max pooling over time and frequency
		return self.mlp(h)

The contents of util.py are as follows:

import os

import numpy as np
import pretty_midi
import torch
from sklearn.metrics.pairwise import cosine_similarity
from torch.utils.data import Dataset


class Data2Torch(Dataset):
	def __init__(self, query, database, ans):
		self.query = query
		self.database = database
		self.ans = ans

	def __getitem__(self, index):
		ans = self.ans[index]
		return {
			'query': torch.from_numpy(self.query[index]).float(),
			'database': torch.from_numpy(self.database[ans[0]]).float(),
			'ans': torch.from_numpy(ans).long(),
		}

	def __len__(self):
		return self.ans.shape[0]


class Data2TorchDB(Dataset):
	def __init__(self, database):
		self.database = database

	def __getitem__(self, index):
		return {
			'database': torch.from_numpy(self.database[index]).float(),
		}

	def __len__(self):
		return self.database.shape[0]


class Data2TorchQUERY(Dataset):
	def __init__(self, query):
		self.query = query

	def __getitem__(self, index):
		return {
			'query': torch.from_numpy(self.query[index]).float(),
		}

	def __len__(self):
		return self.query.shape[0]


def print_and_write_file(fout, cnt, fout_end='\n'):
	print(cnt)
	if fout is not None:
		fout.write(cnt + fout_end)
		fout.flush()


def read_database(path, fs=31.25): # 31.25 pitch values per second (= 8000 Hz / 256-sample hop)
	with open(os.path.join(path, 'songList.txt'), 'r') as fin:
		cnt = fin.read().splitlines()
	db_song_names = [' '.join(line.split('\t')[1:3]) for line in cnt]
	midi_files = sorted(os.listdir(path))
	db_pitches = []
	for mf in midi_files:
		if not mf.endswith('.mid'):
			continue
		midi = pretty_midi.PrettyMIDI(os.path.join(path, mf))
		piano_roll = midi.get_piano_roll(fs=fs) # Shape: (semitone, time_step)
		pitches = np.argmax(piano_roll, axis=0) # Loudest active pitch per frame (0 where silent)
		db_pitches.append(pitches)
	return db_pitches, db_song_names


def get_eval_results(qry_emb_all, db_emb_all, ans_all):
	# Rank database songs by cosine similarity for each query, then
	# compute Recall@1, Recall@5, Recall@10, and MRR@10
	cos_sim_mat = cosine_similarity(qry_emb_all, db_emb_all)
	
	r1_all = []
	r5_all = []
	r10_all = []
	mrr10_all = []
	for qry_idx, row in enumerate(cos_sim_mat):
		db_idx = ans_all[qry_idx]
		idx_sort = np.argsort(row)[::-1]
		rank = np.where(idx_sort == db_idx)[0][0]
		r1_all.append(rank < 1)
		r5_all.append(rank < 5)
		r10_all.append(rank < 10)
		mrr10_all.append(1/(rank+1) if rank < 10 else 0)
	return np.mean(r1_all), np.mean(r5_all), np.mean(r10_all), np.mean(mrr10_all)

In the code above, Data2Torch pairs each training query spectrogram with the pitch vector of its ground-truth song, while Data2TorchDB and Data2TorchQUERY wrap the database and the queries separately for inference. read_database reads the song names from songList.txt and converts each MIDI file into a pitch sequence by taking the argmax of its piano roll. get_eval_results ranks the database songs by cosine similarity for each query and reports Recall@1, Recall@5, Recall@10, and MRR@10.
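
As a quick illustration of the metrics, here is a hypothetical toy example with two queries and three database songs; both queries rank their ground-truth song first, so every metric comes out as 1.0:

import numpy as np
import util

db_emb = np.eye(3, 4) # 3 database songs, 4-dim embeddings
qry_emb = np.array([[0.9, 0.1, 0.0, 0.0],
                    [0.2, 0.0, 0.9, 0.0]])
ans = np.array([0, 2]) # Ground-truth database index of each query
print(util.get_eval_results(qry_emb, db_emb, ans)) # (1.0, 1.0, 1.0, 1.0)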

The test code is as follows:

import os

import numpy as np
import torch

import util
import model

MODEL_DIR = 'model'
OUT_NUM = 4
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

db_pitches, db_song_names = util.read_database('MIR-QBSH/midiFile')
min_len = min([len(p) for p in db_pitches])
db_pitches = np.array([p[:min_len] for p in db_pitches])
print('DB shape:', db_pitches.shape)

mags_te = np.load('mags_te.npy')
ans_te = np.load('ans_te.npy')
print('Test data shapes:', mags_te.shape, ans_te.shape)

print('Setting network')
save_dic_midi = torch.load(os.path.join(MODEL_DIR, 'midi_encoder'), map_location=device)
save_dic_spec = torch.load(os.path.join(MODEL_DIR, 'spec_encoder'), map_location=device)
midi_encoder = model.MIDIENCODER(OUT_NUM)
spec_encoder = model.SPECENCODER(OUT_NUM)
midi_encoder.load_state_dict(save_dic_midi)
spec_encoder.load_state_dict(save_dic_spec)
midi_encoder.to(device)
spec_encoder.to(device)
midi_encoder.eval()
spec_encoder.eval()

data_loader_test_db = torch.utils.data.DataLoader(
	util.Data2TorchDB(db_pitches), batch_size=32,
)

db_emb_all = []
for idx, data in enumerate(data_loader_test_db):
	with torch.no_grad():
		pred = midi_encoder(data['database'].to(device))
	db_emb_all.append(pred.cpu().numpy())
db_emb_all = np.vstack(db_emb_all)
print('db_emb_shape:', db_emb_all.shape)

data_loader_test_query = torch.utils.data.DataLoader(
	util.Data2TorchQUERY(mags_te), batch_size=64,
)

qry_emb_all = []
for idx, data in enumerate(data_loader_test_query):
	with torch.no_grad():
		pred = spec_encoder(data['query'].to(device))
	qry_emb_all.append(pred.cpu().numpy())
qry_emb_all = np.vstack(qry_emb_all)
print('query_emb_shape:', qry_emb_all.shape)

r1, r5, r10, mrr10 = util.get_eval_results(qry_emb_all, db_emb_all, ans_te)
print('R@1: {}, R@5: {}, R@10: {}, MRR@10: {}'.format(r1, r5, r10, mrr10))

In the code above, the MIDI encoder embeds every database song once, the spectrogram encoder embeds every test query, and get_eval_results ranks the database songs by cosine similarity to each query embedding to report Recall@1, Recall@5, Recall@10, and MRR@10.
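
To present retrieval results to a user, a minimal sketch that reuses qry_emb_all, db_emb_all, and db_song_names from the script above could look like this:

from sklearn.metrics.pairwise import cosine_similarity

# Print the top-3 database songs for the first test query
sims = cosine_similarity(qry_emb_all[:1], db_emb_all)[0]
for rank, idx in enumerate(np.argsort(sims)[::-1][:3]):
	print('{}. {} (similarity: {:.3f})'.format(rank + 1, db_song_names[idx], sims[idx]))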