動態時間扭曲(dynamic time warping, DTW)不同於線性伸縮法,主要是允許查詢輸入的節奏變化可能會不太均勻,因此不能用單純的伸縮來調整節奏,而是改用動態規劃(dynamic programming, DP)的方式來處理。本篇將假設各位對 DP 已經有基本的瞭解,例如已經修過演算法的課程,或者知道怎樣解決「最長共同序列(longest common subsequence, LCS)」的問題。以下內容,將主要介紹如何將 DP 運用在哼唱選歌當中。
以下是一個 DTW 的範例,此範例會將查詢輸入與資料庫中的某個音高向量之間進行 DTW 對位,並展示直接對位和 DTW 對位的結果,以及 DTW 運算中的過程:
import numpy as np import matplotlib.pyplot as plt def dtw(input_pitch, db_pitch): len_in = len(input_pitch) len_db = len(db_pitch) table_dp = np.zeros((len_db, len_in)) table_dir = np.zeros((len_db, len_in)) table_dist = np.zeros((len_db, len_in)) for i, db in enumerate(db_pitch): for j, inp in enumerate(input_pitch): table_dist[i, j] = np.abs(db-inp) # First element table_dp[0, 0] = np.abs(db_pitch[0] - input_pitch[0]) # First row for j in range(1, len_in): table_dp[0, j] = table_dp[0, j-1] + np.abs(db_pitch[0] - input_pitch[j]) table_dir[0, j] = 1 # First column for i in range(1, len_db): table_dp[i, 0] = table_dp[i-1, 0] + np.abs(db_pitch[i] - input_pitch[0]) table_dir[i, 0] = 2 # Remaining for i in range(1, len_db): for j in range(1, len_in): vec = np.abs(db_pitch[i] - input_pitch[j]) + np.array([ table_dp[i-1, j-1], table_dp[i, j-1], table_dp[i-1, j], ]) table_dp[i, j] = np.min(vec) table_dir[i, j] = np.argmin(vec) # Backtrack start_i = np.argmin(table_dp[:, -1]) start_j = len_in - 1 ret_list = [[start_i, start_j]] while start_i >= 0 or start_j >= 0: if table_dir[start_i, start_j] == 0: start_i -= 1 start_j -= 1 elif table_dir[start_i, start_j] == 1: start_j -= 1 elif table_dir[start_i, start_j] == 2: start_i -= 1 ret_list = [[start_i, start_j]] + ret_list return ret_list[1:], table_dp, table_dist inputPitch = np.array([48.04, 49.36, 50.13, 50.62, 51.11, 51.49, 51.49, 50.82, 50.15, 49.89, 50.67, 51.06, 49.9, 49.52, 49.52, 51.14, 51.14, 51.52, 51.27, 51.18, 52.02, 51.49, 51.33, 51.14, 51.14, 51.14, 51.14, 51.14, 50.55, 50.47, 50.64, 50.39, 48.35, 51.35, 51.49, 51.27, 50.89, 51.49, 51.49, 51.49, 55.78, 55.14, 54.92, 55.35, 55.35, 55.35, 55.35, 55.35, 54.01, 58.46, 59.63, 59.76, 59.76, 58.06, 57.99, 58.68, 58.68, 57.94, 55.06, 55.37, 55.79, 55.79, 54.73, 56.14, 55.35, 55.35, 55.04, 54.51, 53.29, 50.16, 50.9, 51.14, 51.07, 50.81, 50.54, 51.25, 51.49, 51.49, 51.49, 52.08, 52.56, 52.56, 52.2, 52.0, 53.23, 53.31, 52.98, 53.06, 53.0, 52.45, 52.77, 53.21, 52.14, 52.79, 53.18, 52.87, 52.79, 52.89, 52.56, 52.6, 53.31, 53.31, 53.13, 52.93, 52.77, 52.9, 52.93, 52.93, 52.93, 52.93, 53.31, 53.31, 52.19, 52.19, 53.88, 52.93, 52.93, 52.93, 51.14, 51.14, 48.94, 49.52, 49.84, 49.24, 49.48, 48.85, 48.33, 48.33, 48.33, 49.62, 52.58, 53.31, 53.03, 52.93, 53.19, 52.93, 52.69, 52.39, 52.2, 49.26, 50.01, 49.83, 49.08, 48.62, 49.16, 49.84, 49.84, 47.09, 46.68, 46.25, 45.66, 45.85, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 50.84, 51.49, 51.14]) dbPitch = 1.0 * np.array([60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60]) dbPitch -= np.mean(dbPitch) inputPitch -= np.mean(inputPitch) alignment, table_dp, table_dist = dtw(inputPitch, dbPitch) plt.subplot(2, 2, 1) plt.plot(dbPitch) plt.plot(inputPitch-10) for i, (db, inp) in enumerate(zip(dbPitch, inputPitch)): if i % 5 == 0: plt.plot([i, i], [db, inp-10], 'r--') plt.yticks([-10, 0], ['Input pitch', 'Target pitch']) plt.title('Direct Alignment') plt.subplot(2, 2, 3) plt.plot(dbPitch) for i, (idx_db, idx_qry) in enumerate(alignment): if i % 5 == 0: plt.plot([idx_db, idx_qry], [dbPitch[idx_db], inputPitch[idx_qry]-10], 'r--') plt.plot(inputPitch-10) plt.yticks([-10, 0], ['Input pitch', 'Target pitch']) plt.title('DTW Alignment') plt.subplot(1, 2, 2) plt.imshow(table_dp) plt.plot([a[1] for a in alignment], [a[0] for a in alignment], 'r-') plt.xlabel('Input pitch Index') plt.ylabel('Target pitch Index') plt.colorbar() plt.title('DTW Path') plt.show()在上述的程式碼中:
- 將查詢輸入的音高序列和資料庫的音高序列都平移到平均為 0 的原因,在此範例中主要是顯示上的方便;在實際比對上我們也需要將音高向量做平移,原因及做法請見下段探討。
- 範例中的填表的方式為索引值 [i, j-1]、索引值 [i-1, j-1],以及索引值 [i-1, j] 取最小值,此種方法稱為 0°-45°-90° 的填表;若你對允許對位的路徑有不同的想法,也可以使用其他對位方式,例如用索引值 [i-1, j-2]、索引值 [i-1, j-1],以及索引值 [i-2, j-1] 取最小值,即 27°-45°-63° 的填表方式,但請記得適當的修改邊界條件,即一開始的幾個 rows 和 columns 的處理方式。
- 因為本範例假設查詢輸入是從頭開始唱歌,但是不知道唱到哪裡結束,所以希望的比對方式是開頭要對齊,而尾巴不必(簡稱「頭對頭,尾自由」),所以在進行路徑回溯的時候,是從最後一個 column 找最小值。若你對允許對位的方式有不同的想法,也可以從其他地方來尋找回溯的起點。
- 填表時處理到的較靠近邊角的區域,或者路徑回溯時經過了比較邊角的區域,都代表了輸入查詢的節奏,發生了由很快到很慢或者反過來的變化。如果你認為不需要處理這種變化,可以修改程式實作,以便對填表的範圍進行限制,例如只限制在對角線附近。
- 實務上,除非你有如本範例的研究觀察用途等特殊需求,或者系統設計上希望讓使用者知道是比對到哪個片段,否則只需要從最後一個 column,或者任何你根據自身想法定義的起點來找最小值做為距離的計算結果即可,不需要進行回溯求取對位路徑。
DTW 在移調上的做法跟 LS 比較不一樣。LS 因為是在移調之前,就很明確的知道查詢輸入跟資料庫中的某片段是相同位置一對一的對應,所以可以直接計算出一個平移方式;而在 DTW 當中,因為在比對之前不知道查詢輸入跟資料庫中的某片段是如何對應,因此通常會在音高平均為 0 附近的某個(例如加減兩個半音)範圍內進行搜尋,來找出能造成最小距離的平移方式。下面的範例,會對查詢輸入以二元搜尋的原則進行三次搜尋,且每次都縮小搜尋範圍,來求取其跟資料庫中某片段的最小 DTW 距離:
import numpy as np def dtw(input_pitch, db_pitch): len_in = len(input_pitch) len_db = len(db_pitch) table_dp = np.zeros((len_db, len_in)) table_dir = np.zeros((len_db, len_in)) table_dist = np.zeros((len_db, len_in)) for i, db in enumerate(db_pitch): for j, inp in enumerate(input_pitch): table_dist[i, j] = np.abs(db-inp) # First element table_dp[0, 0] = np.abs(db_pitch[0] - input_pitch[0]) # First row for j in range(1, len_in): table_dp[0, j] = table_dp[0, j-1] + np.abs(db_pitch[0] - input_pitch[j]) table_dir[0, j] = 1 # First column for i in range(1, len_db): table_dp[i, 0] = table_dp[i-1, 0] + np.abs(db_pitch[i] - input_pitch[0]) table_dir[i, 0] = 2 # Remaining for i in range(1, len_db): for j in range(1, len_in): vec = np.abs(db_pitch[i] - input_pitch[j]) + np.array([ table_dp[i-1, j-1], table_dp[i, j-1], table_dp[i-1, j], ]) table_dp[i, j] = np.min(vec) table_dir[i, j] = np.argmin(vec) # Return return np.min(table_dp[:, -1]) def dtw_for_one_song(input_pitch, db_pitch): db_pitch -= np.mean(db_pitch) input_pitch -= np.mean(input_pitch) shift_base = 0 shift_offset = 2 dist = dtw(input_pitch + shift_base, db_pitch) new_base = 0 print(shift_base, 0, dist) for _ in range(3): dist_plus = dtw(input_pitch + shift_base + shift_offset, db_pitch) print(shift_base, shift_offset, dist_plus) if dist_plus < dist: dist = dist_plus new_base = -1 dist_minus = dtw(input_pitch + shift_base - shift_offset, db_pitch) print(shift_base, -shift_offset, dist_minus) if dist_minus < dist: dist = dist_minus new_base = 1 shift_base = new_base shift_offset /= 2 return dist input_pitch = np.array([48.04, 49.36, 50.13, 50.62, 51.11, 51.49, 51.49, 50.82, 50.15, 49.89, 50.67, 51.06, 49.9, 49.52, 49.52, 51.14, 51.14, 51.52, 51.27, 51.18, 52.02, 51.49, 51.33, 51.14, 51.14, 51.14, 51.14, 51.14, 50.55, 50.47, 50.64, 50.39, 48.35, 51.35, 51.49, 51.27, 50.89, 51.49, 51.49, 51.49, 55.78, 55.14, 54.92, 55.35, 55.35, 55.35, 55.35, 55.35, 54.01, 58.46, 59.63, 59.76, 59.76, 58.06, 57.99, 58.68, 58.68, 57.94, 55.06, 55.37, 55.79, 55.79, 54.73, 56.14, 55.35, 55.35, 55.04, 54.51, 53.29, 50.16, 50.9, 51.14, 51.07, 50.81, 50.54, 51.25, 51.49, 51.49, 51.49, 52.08, 52.56, 52.56, 52.2, 52.0, 53.23, 53.31, 52.98, 53.06, 53.0, 52.45, 52.77, 53.21, 52.14, 52.79, 53.18, 52.87, 52.79, 52.89, 52.56, 52.6, 53.31, 53.31, 53.13, 52.93, 52.77, 52.9, 52.93, 52.93, 52.93, 52.93, 53.31, 53.31, 52.19, 52.19, 53.88, 52.93, 52.93, 52.93, 51.14, 51.14, 48.94, 49.52, 49.84, 49.24, 49.48, 48.85, 48.33, 48.33, 48.33, 49.62, 52.58, 53.31, 53.03, 52.93, 53.19, 52.93, 52.69, 52.39, 52.2, 49.26, 50.01, 49.83, 49.08, 48.62, 49.16, 49.84, 49.84, 47.09, 46.68, 46.25, 45.66, 45.85, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 50.84, 51.49, 51.14]) db_pitch = 1.0 * np.array([60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60]) dist_one_song = dtw_for_one_song(input_pitch, db_pitch) print(dist_one_song)在上述的程式碼中:
- 如先前所述,此範例因為不需要知道路徑,因此將回溯相關的程式碼移除。
- 此範例需要進行多次的 DTW,因此你可以觀察到 DTW 的運算明顯會花費較多時間,因此在搜尋最佳音高偏移值時,要避免重複去計算基準位置的 DTW 距離。
- 各位可以改變音高偏移值的搜尋範圍,觀察看看結果是否會有變化。
對於一個輸入查詢,要做出系統該回傳給使用者的結果候選清單的方法,則跟 LS 相同,把輸入查詢跟資料庫中的所有歌曲都比對過一遍,並依照它與資料庫中每首歌的最小距離進行遞增排序即可:
import os import numpy as np import pretty_midi def read_database(path, fs=31.25): with open(os.path.join(path, 'songList.txt'), 'r') as fin: cnt = fin.read().splitlines() db_song_names = [' '.join(line.split('\t')[1:3]) for line in cnt] midi_files = sorted(os.listdir(path)) db_pitches = [] for mf in midi_files: if not mf.endswith('.mid'): continue midi = pretty_midi.PrettyMIDI(os.path.join(path, mf)) piano_roll = midi.get_piano_roll(fs=fs) # Shape: (semitone, time_step) pitches = np.argmax(piano_roll, axis=0) db_pitches.append(pitches) return db_pitches, db_song_names def dtw(input_pitch, db_pitch): len_in = len(input_pitch) len_db = len(db_pitch) table_dp = np.zeros((len_db, len_in)) + np.inf table_dir = np.zeros((len_db, len_in)) table_dist = np.zeros((len_db, len_in)) for i, db in enumerate(db_pitch): for j, inp in enumerate(input_pitch): table_dist[i, j] = np.abs(db-inp) # First element table_dp[0, 0] = np.abs(db_pitch[0] - input_pitch[0]) # First row for j in range(1, len_in): table_dp[0, j] = table_dp[0, j-1] + np.abs(db_pitch[0] - input_pitch[j]) table_dir[0, j] = 1 # First column for i in range(1, len_db): table_dp[i, 0] = table_dp[i-1, 0] + np.abs(db_pitch[i] - input_pitch[0]) table_dir[i, 0] = 2 # Remaining for i in range(1, len_db): for j in range(max(0, i-len_in//2), min(len_in, i+len_in//2)): vec = np.abs(db_pitch[i] - input_pitch[j]) + np.array([ table_dp[i-1, j-1], table_dp[i, j-1], table_dp[i-1, j], ]) table_dp[i, j] = np.min(vec) table_dir[i, j] = np.argmin(vec) # Return return np.min(table_dp[:, -1]) def dtw_for_one_song(input_pitch, db_pitch): db_pitch -= np.mean(db_pitch) shift_base = 0 shift_offset = 2 dist = dtw(input_pitch + shift_base, db_pitch) new_base = 0 for _ in range(2): dist_plus = dtw(input_pitch + shift_base + shift_offset, db_pitch) if dist_plus < dist: dist = dist_plus new_base = -1 dist_minus = dtw(input_pitch + shift_base - shift_offset, db_pitch) if dist_minus < dist: dist = dist_minus new_base = 1 shift_base = new_base shift_offset /= 2 return dist def compare_to_whole_db(input_pitch, db_pitches): input_pitch -= np.mean(input_pitch) all_dists = [] for s_idx, db_pitch in enumerate(db_pitches): print('Comparing song {}/{}'.format(s_idx, len(db_pitches))) all_dists.append(dtw_for_one_song(input_pitch, db_pitch.astype('float'))) return np.argsort(all_dists) input_pitch = np.array([48.04, 49.36, 50.13, 50.62, 51.11, 51.49, 51.49, 50.82, 50.15, 49.89, 50.67, 51.06, 49.9, 49.52, 49.52, 51.14, 51.14, 51.52, 51.27, 51.18, 52.02, 51.49, 51.33, 51.14, 51.14, 51.14, 51.14, 51.14, 50.55, 50.47, 50.64, 50.39, 48.35, 51.35, 51.49, 51.27, 50.89, 51.49, 51.49, 51.49, 55.78, 55.14, 54.92, 55.35, 55.35, 55.35, 55.35, 55.35, 54.01, 58.46, 59.63, 59.76, 59.76, 58.06, 57.99, 58.68, 58.68, 57.94, 55.06, 55.37, 55.79, 55.79, 54.73, 56.14, 55.35, 55.35, 55.04, 54.51, 53.29, 50.16, 50.9, 51.14, 51.07, 50.81, 50.54, 51.25, 51.49, 51.49, 51.49, 52.08, 52.56, 52.56, 52.2, 52.0, 53.23, 53.31, 52.98, 53.06, 53.0, 52.45, 52.77, 53.21, 52.14, 52.79, 53.18, 52.87, 52.79, 52.89, 52.56, 52.6, 53.31, 53.31, 53.13, 52.93, 52.77, 52.9, 52.93, 52.93, 52.93, 52.93, 53.31, 53.31, 52.19, 52.19, 53.88, 52.93, 52.93, 52.93, 51.14, 51.14, 48.94, 49.52, 49.84, 49.24, 49.48, 48.85, 48.33, 48.33, 48.33, 49.62, 52.58, 53.31, 53.03, 52.93, 53.19, 52.93, 52.69, 52.39, 52.2, 49.26, 50.01, 49.83, 49.08, 48.62, 49.16, 49.84, 49.84, 47.09, 46.68, 46.25, 45.66, 45.85, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 50.84, 51.49, 51.14]) db_pitches, db_song_names = read_database('MIR-QBSH/midiFile') sorted_idx = compare_to_whole_db(input_pitch, db_pitches) for idx in sorted_idx[:10]: print(db_song_names[idx])在上述的程式碼中,因為 DTW 的執行時間顯著的比較久,因此跟之前的範例相比,做了一些或大或小的改變:
- 為了顯示目前正在比對哪首歌曲,以做為比對進度的表示,此處改用一般的迴圈寫法,而非 LS 篇章範例中的 list comprehension。
- 此處的 DTW 在比對時,做了一個簡單的範圍限制,即比對範圍不會超過對角線左右各一半長度的查詢輸入這麼長。如果你對查詢輸入的節奏變化有進一步的假設,可以調整範圍限制的方式。
- 音高偏移值搜尋時,只搜尋兩輪。
與 LS 篇章相同,你可以把上述固定寫死的 input_pitch,換成自己從音檔抽取出來的音高向量,就可以是一個簡單的哼唱選歌系統。如果有需要評估系統的辨識效果,也可以把簡介中提到的 top-n accuracy 或 MRR 實作出來,加到你的程式碼當中。