If you want to use neural networks for music retrieval, one of the common approaches today (as of 2023) is contrastive learning. A benefit of this approach is that when the retrieval database changes, especially when new songs are added, there is no need to retrain the model, unlike a classification setup where adding a class forces retraining, as sketched below.
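To make that benefit concrete, here is a minimal sketch (not part of the scripts below) of what adding a song to the database looks like once the encoders are trained: we only embed the new song and append one row. The names `midi_encoder`, `db_emb_all`, and `new_pitches` are placeholders for objects like the ones defined later in this walkthrough.

```python
import numpy as np
import torch

def add_song_to_database(midi_encoder, db_emb_all, new_pitches, device='cpu'):
    """Append the embedding of one new pitch sequence to the database."""
    midi_encoder.eval()
    with torch.no_grad():
        x = torch.from_numpy(new_pitches[np.newaxis, :]).float().to(device)
        emb = midi_encoder(x).cpu().numpy()
    # A classifier would need a new output class (and retraining) here;
    # the contrastive setup only stacks one more embedding row.
    return np.vstack([db_emb_all, emb])
```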
We first convert the wav files into spectrograms. Since extracting the spectrograms also takes some time (although with this example's parameter settings it may only take a few seconds), we separate these steps into their own script, as follows:
```python
import os

import librosa
import numpy as np
import scipy.io.wavfile

DATASET_PATH = 'MIR-QBSH/waveFile'
TRAIN_DIRS = ['year2003', 'year2004a', 'year2004b', 'year2005']
TEST_DIRS = ['year2006a']
FRAME_SIZE = 256
HOP_SIZE = 256

data_dirs = os.listdir(DATASET_PATH)
mags_tr = []
mags_te = []
ans_tr = []
ans_te = []
for dd in data_dirs:
    if dd not in TRAIN_DIRS + TEST_DIRS:
        continue
    people = os.listdir(os.path.join(DATASET_PATH, dd))
    for person in people:
        if person.endswith('.txt'):
            continue
        wav_files = os.listdir(os.path.join(DATASET_PATH, dd, person))
        print('Processing {}/{}...'.format(dd, person))
        for wf in wav_files:
            if not wf.endswith('.wav'):
                continue
            wav_path = os.path.join(DATASET_PATH, dd, person, wf)
            fs, y = scipy.io.wavfile.read(wav_path)
            y = y / 32768  # 16-bit PCM to float in [-1, 1)
            mag = np.abs(librosa.stft(
                y, n_fft=FRAME_SIZE, hop_length=HOP_SIZE, center=False)).T
            if dd in TRAIN_DIRS:
                mags_tr.append(mag)
                ans_tr.append(int(wf[:-4]) - 1)  # '00001.wav' -> index 0
            else:
                mags_te.append(mag)
                ans_te.append(int(wf[:-4]) - 1)

mags_tr = np.stack(mags_tr)
mags_te = np.stack(mags_te)
ans_tr = np.array(ans_tr)
ans_te = np.array(ans_te)
print('TR shapes:', mags_tr.shape, ans_tr.shape)
print('TE shapes:', mags_te.shape, ans_te.shape)
np.save('mags_tr', mags_tr)
np.save('mags_te', mags_te)
np.save('ans_tr', ans_tr)
np.save('ans_te', ans_te)
```

In the code above:
- The values of FRAME_SIZE and HOP_SIZE are chosen to match the sampling rate used when parsing the MIDI files that serve as the database (31.25 points per second; see the sanity check after this list). Matching them exactly is not strictly necessary here, though, and you can experiment with other values.
- The MIDI and wav files are originally named 00001 through 00048; the code subtracts 1 so that they become zero-based indices.
- The audio files in this dataset all happen to be exactly eight seconds long, so there is no need to truncate or pad the spectrograms before calling np.stack (a padding sketch for the general case also follows this list).
- To keep the example short, this walkthrough does not split off a validation set. If you want to experiment, you can split one off yourself.
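As a quick sanity check of the frame-rate claim in the first bullet: the MIR-QBSH recordings are sampled at 8 kHz, so a hop of 256 samples yields exactly the 31.25 frames per second used when parsing the MIDI files.

```python
SAMPLE_RATE = 8000  # sampling rate of the MIR-QBSH wav files
HOP_SIZE = 256

print(SAMPLE_RATE / HOP_SIZE)  # 31.25 frames per second, matching fs=31.25
```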
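And if your own recordings were not all the same length, you would have to equalize the spectrogram lengths before np.stack. A minimal sketch (the target length `max_len` and the zero-padding policy are assumptions, not part of this example's pipeline):

```python
import numpy as np

def pad_or_truncate(mag, max_len):
    """Force a (time, freq) spectrogram to exactly max_len frames."""
    if mag.shape[0] >= max_len:
        return mag[:max_len]
    pad = np.zeros((max_len - mag.shape[0], mag.shape[1]), dtype=mag.dtype)
    return np.concatenate([mag, pad], axis=0)
```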
The complete training code is as follows:
```python
import os
import time

import numpy as np
import torch
from pytorch_metric_learning import losses
from torch import optim
from torch.autograd import Variable

import util
import model

SAVE_MODEL_DIR = 'model'
NUM_EPOCH = 20
OUT_NUM = 4

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Truncate every database song to the length of the shortest one.
db_pitches, db_song_names = util.read_database('MIR-QBSH/midiFile')
min_len = min([len(p) for p in db_pitches])
db_pitches = np.array([p[:min_len] for p in db_pitches])
print('DB shape:', db_pitches.shape)

mags_tr = np.load('mags_tr.npy')
ans_tr = np.load('ans_tr.npy')[:, np.newaxis]
print('Training data shapes:', mags_tr.shape, ans_tr.shape)

print('Setting network')
midi_encoder = model.MIDIENCODER(OUT_NUM)
spec_encoder = model.SPECENCODER(OUT_NUM)
midi_encoder.to(device)
spec_encoder.to(device)
optimizer_midi = optim.Adam(list(midi_encoder.parameters()), lr=0.0005)
optimizer_spec = optim.Adam(list(spec_encoder.parameters()), lr=0.0005)
loss_func = losses.NTXentLoss()

if not os.path.exists(SAVE_MODEL_DIR):
    os.makedirs(SAVE_MODEL_DIR, 0o755)

data_loader_train = torch.utils.data.DataLoader(
    util.Data2Torch(mags_tr, db_pitches, ans_tr),
    shuffle=True,
    batch_size=64,
)

totalTime = 0
fout = open(os.path.join(SAVE_MODEL_DIR, 'train_report.txt'), 'w')
for epoch in range(NUM_EPOCH):
    util.print_and_write_file(fout, 'epoch {}/{}...'.format(epoch + 1, NUM_EPOCH))
    tic = time.time()

    # --- Batch training
    midi_encoder.train()
    spec_encoder.train()
    training_loss = 0
    n_batch = 0
    for idx, data in enumerate(data_loader_train):
        emb_midi = midi_encoder(Variable(data['database'].to(device)))
        emb_spec = spec_encoder(Variable(data['query'].to(device)))
        ans = Variable(data['ans'].to(device))
        # Matching MIDI/query rows share a label, so NTXentLoss treats
        # them as positive pairs and all other rows as negatives.
        loss = loss_func(torch.cat((emb_midi, emb_spec)),
                         torch.cat((ans, ans)).flatten())
        optimizer_midi.zero_grad()
        optimizer_spec.zero_grad()
        loss.backward()
        optimizer_midi.step()
        optimizer_spec.step()
        training_loss += loss.data
        n_batch += 1

    # --- Training loss
    training_loss_avg = training_loss / n_batch
    util.print_and_write_file(
        fout,
        '\tTraining loss (avg over batch): {}, {}, {}'.format(
            training_loss_avg, training_loss, n_batch
        )
    )

    # --- Time
    toc = time.time()
    totalTime += toc - tic
    util.print_and_write_file(fout, '\tTime: {:.3f} sec, estimated remaining: {:.3} hr'.format(
        toc - tic, 1.0 * totalTime / (epoch + 1) * (NUM_EPOCH - (epoch + 1)) / 3600
    ))
    fout.flush()

torch.save(midi_encoder.state_dict(), os.path.join(SAVE_MODEL_DIR, 'midi_encoder'))
torch.save(spec_encoder.state_dict(), os.path.join(SAVE_MODEL_DIR, 'spec_encoder'))
fout.close()
print('Model saved in {}'.format(SAVE_MODEL_DIR))
```

The model.py module imported above has the following content:
```python
import torch
import torch.nn.functional as F
from torch import nn


class MIDIENCODER(nn.Module):
    """Encodes a pitch sequence (from MIDI) into an embedding vector."""

    def __init__(self, out_num):
        super(MIDIENCODER, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv1d(1, 16, 5), nn.BatchNorm1d(16), nn.ReLU(), nn.MaxPool1d(4),
            nn.Conv1d(16, 32, 5), nn.BatchNorm1d(32), nn.ReLU(), nn.MaxPool1d(4),
            nn.Conv1d(32, 32, 5), nn.BatchNorm1d(32), nn.ReLU(), nn.MaxPool1d(4),
        )
        self.mlp = nn.Sequential(
            nn.Linear(in_features=32, out_features=out_num),
            nn.BatchNorm1d(out_num),
        )

    def forward(self, x):
        h = torch.unsqueeze(x, 1)  # (batch, time) to (batch, ch=1, time)
        h = self.cnn(h)
        # Global max pooling over the remaining time axis.
        h = F.max_pool1d(h, kernel_size=h.size()[2])[:, :, 0]
        return self.mlp(h)


class SPECENCODER(nn.Module):
    """Encodes a magnitude spectrogram (from wav) into an embedding vector."""

    def __init__(self, out_num):
        super(SPECENCODER, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, 3), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.mlp = nn.Sequential(
            nn.Linear(in_features=128, out_features=out_num),
            nn.BatchNorm1d(out_num),
        )

    def forward(self, x):
        h = torch.unsqueeze(x, 1)  # (batch, time, freq) to (batch, ch=1, time, freq)
        h = self.cnn(h)
        # Global max pooling over the remaining time-frequency grid.
        h = F.max_pool2d(h, kernel_size=h.size()[2:])[:, :, 0, 0]
        return self.mlp(h)
```

And the content of util.py is as follows:
```python
import os

import numpy as np
import pretty_midi
import torch
from sklearn.metrics.pairwise import cosine_similarity
from torch.utils.data import Dataset


class Data2Torch(Dataset):
    """Pairs each query spectrogram with the pitch vector of its answer song."""

    def __init__(self, query, database, ans):
        self.query = query
        self.database = database
        self.ans = ans

    def __getitem__(self, index):
        ans = self.ans[index]
        return {
            'query': torch.from_numpy(self.query[index]).float(),
            'database': torch.from_numpy(self.database[ans[0]]).float(),
            'ans': torch.from_numpy(ans).long(),
        }

    def __len__(self):
        return self.ans.shape[0]


class Data2TorchDB(Dataset):
    def __init__(self, database):
        self.database = database

    def __getitem__(self, index):
        return {
            'database': torch.from_numpy(self.database[index]).float(),
        }

    def __len__(self):
        return self.database.shape[0]


class Data2TorchQUERY(Dataset):
    def __init__(self, query):
        self.query = query

    def __getitem__(self, index):
        return {
            'query': torch.from_numpy(self.query[index]).float(),
        }

    def __len__(self):
        return self.query.shape[0]


def print_and_write_file(fout, cnt, fout_end='\n'):
    print(cnt)
    if fout is not None:
        fout.write(cnt + fout_end)
        fout.flush()


def read_database(path, fs=31.25):
    with open(os.path.join(path, 'songList.txt'), 'r') as fin:
        cnt = fin.read().splitlines()
    db_song_names = [' '.join(line.split('\t')[1:3]) for line in cnt]
    midi_files = sorted(os.listdir(path))
    db_pitches = []
    for mf in midi_files:
        if not mf.endswith('.mid'):
            continue
        midi = pretty_midi.PrettyMIDI(os.path.join(path, mf))
        piano_roll = midi.get_piano_roll(fs=fs)  # Shape: (semitone, time_step)
        pitches = np.argmax(piano_roll, axis=0)
        db_pitches.append(pitches)
    return db_pitches, db_song_names


def get_eval_results(qry_emb_all, db_emb_all, ans_all):
    cos_sim_mat = cosine_similarity(qry_emb_all, db_emb_all)
    r1_all = []
    r5_all = []
    r10_all = []
    mrr10_all = []
    for qry_idx, row in enumerate(cos_sim_mat):
        db_idx = ans_all[qry_idx]
        idx_sort = np.argsort(row)[::-1]
        rank = np.where(idx_sort == db_idx)[0][0]
        r1_all.append(rank < 1)
        r5_all.append(rank < 5)
        r10_all.append(rank < 10)
        mrr10_all.append(1 / (rank + 1) if rank < 10 else 0)
    return np.mean(r1_all), np.mean(r5_all), np.mean(r10_all), np.mean(mrr10_all)
```

In the code above:
- As in the walkthrough of the traditional approach, we still assume that users hum from the beginning of a song, so when preparing the database we make all songs the same length by truncating their tails.
- The idea behind contrastive learning is to have each of the two networks output an embedding vector, and to use the loss function to pull the embeddings of matching pairs together (making the angle between them as small as possible) while pushing non-matching pairs apart. That is why we design a separate encoder for the wav files that serve as queries and for the MIDI files that serve as the database; a toy illustration of the loss follows this list.
- The network architecture, the embedding dimensionality, the number of epochs, the learning rate, and so on form one set of parameters that reaches a top-10 retrieval rate of about 60% at test time. If you are interested in experimenting, feel free to tune them however you like.
- util.get_eval_results evaluates the test results; it is explained below, after the test code.
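As a toy illustration of the loss (random numbers, not part of the training script): pytorch_metric_learning's NTXentLoss only needs one batch of embeddings plus a label vector. Rows that share a label, here the MIDI and spectrogram embeddings of the same song, are treated as positive pairs, and everything else as negatives.

```python
import torch
from pytorch_metric_learning import losses

loss_func = losses.NTXentLoss()
emb_midi = torch.randn(8, 4)  # 8 database embeddings, OUT_NUM = 4
emb_spec = torch.randn(8, 4)  # 8 query embeddings for the same 8 songs
labels = torch.arange(8)      # song indices, shared by both sides
loss = loss_func(torch.cat((emb_midi, emb_spec)),
                 torch.cat((labels, labels)))
print(loss)
```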
The test code is as follows:
```python
import os

import numpy as np
import torch
from torch.autograd import Variable

import util
import model

MODEL_DIR = 'model'
OUT_NUM = 4

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

db_pitches, db_song_names = util.read_database('MIR-QBSH/midiFile')
min_len = min([len(p) for p in db_pitches])
db_pitches = np.array([p[:min_len] for p in db_pitches])
print('DB shape:', db_pitches.shape)

mags_te = np.load('mags_te.npy')
ans_te = np.load('ans_te.npy')
print('Test data shapes:', mags_te.shape, ans_te.shape)

print('Setting network')
save_dic_midi = torch.load(os.path.join(MODEL_DIR, 'midi_encoder'))
save_dic_spec = torch.load(os.path.join(MODEL_DIR, 'spec_encoder'))
midi_encoder = model.MIDIENCODER(OUT_NUM)
spec_encoder = model.SPECENCODER(OUT_NUM)
midi_encoder.load_state_dict(save_dic_midi)
spec_encoder.load_state_dict(save_dic_spec)
midi_encoder.to(device)
spec_encoder.to(device)
midi_encoder.eval()
spec_encoder.eval()

# Embed every database song once.
data_loader_test_db = torch.utils.data.DataLoader(
    util.Data2TorchDB(db_pitches),
    batch_size=32,
)
db_emb_all = []
for idx, data in enumerate(data_loader_test_db):
    with torch.no_grad():
        pred = midi_encoder(Variable(data['database'].to(device)))
    db_emb_all.append(pred.detach().cpu().numpy())
db_emb_all = np.vstack(db_emb_all)
print('db_emb_shape:', db_emb_all.shape)

# Embed every query.
data_loader_test_query = torch.utils.data.DataLoader(
    util.Data2TorchQUERY(mags_te),
    batch_size=64,
)
qry_emb_all = []
for idx, data in enumerate(data_loader_test_query):
    with torch.no_grad():
        pred = spec_encoder(Variable(data['query'].to(device)))
    qry_emb_all.append(pred.detach().cpu().numpy())
qry_emb_all = np.vstack(qry_emb_all)
print('query_emb_shape:', qry_emb_all.shape)

r1, r5, r10, mrr10 = util.get_eval_results(qry_emb_all, db_emb_all, ans_te)
print(r1, r5, r10, mrr10)
```

In the code above:
- The embedding dimensionality at test time must, of course, match the one used during training. If you would rather not repeat the same parameters in both the training and test scripts, you can use a library such as pyyaml to save the various parameters to a file (a minimal sketch follows this list).
- At test time, a fixed input always yields the same embedding vector, and the two encoders for the database and the queries are independent of each other, so we can compute all the embeddings up front and only then run the similarity comparison.
- Since we use cosine similarity to measure how similar two vectors are, for each query we sort its cosine similarities against every song in the database in descending order, and the position of the ground-truth song in that order is its rank (a worked example follows below).
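For the parameter-sharing idea in the first bullet, here is a minimal sketch with pyyaml (the file name config.yaml and the key names are hypothetical):

```python
import yaml

# Training script: write the parameters once.
config = {'OUT_NUM': 4, 'NUM_EPOCH': 20, 'lr': 0.0005}
with open('config.yaml', 'w') as fout:
    yaml.safe_dump(config, fout)

# Test script: read them back instead of repeating them.
with open('config.yaml') as fin:
    config = yaml.safe_load(fin)
print(config['OUT_NUM'])
```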
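And here is a worked example (with made-up numbers) of the ranking step inside util.get_eval_results: the ground-truth song's position in the descending similarity order is its rank, which in turn drives the recall and MRR statistics.

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

qry_emb = np.array([[0.9, 0.1]])  # one query embedding (2-D for readability)
db_embs = np.array([[1.0, 0.0],   # song 0 (the ground truth)
                    [0.0, 1.0],   # song 1
                    [0.7, 0.7]])  # song 2
sims = cosine_similarity(qry_emb, db_embs)[0]
idx_sort = np.argsort(sims)[::-1]     # database indices, most similar first
rank = np.where(idx_sort == 0)[0][0]  # where the ground truth landed
print(sims, idx_sort, rank)           # rank 0 -> a top-1 hit, MRR term 1/(0+1)
```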