動態時間扭曲(dynamic time warping, DTW)不同於線性伸縮法(linear scaling, LS)之處,主要在於它允許查詢輸入的節奏變化不均勻;此時無法用單純的伸縮來調整節奏,因此改用動態規劃(dynamic programming, DP)的方式來處理。本篇假設各位對 DP 已經有基本的瞭解,例如已經修過演算法的課程,或者知道怎樣解決「最長共同子序列(longest common subsequence, LCS)」的問題。以下內容將主要介紹如何將 DP 運用在哼唱選歌當中。

以下是一個 DTW 的範例,此範例會將查詢輸入與資料庫中的某個音高向量之間進行 DTW 對位,並展示直接對位和 DTW 對位的結果,以及 DTW 運算中的過程:

範例程式碼(共 82 行):
import numpy as np
import matplotlib.pyplot as plt
 
 
def dtw(input_pitch, db_pitch):
    """Align a query pitch vector to a database pitch vector with DTW.

    The warping path starts at cell ``(0, 0)`` and must consume the whole
    query, but it may end on any database frame: the end point is the
    cheapest cell in the last query column.  The local cost is the absolute
    pitch difference and every step comes from the diagonal, left or upper
    neighbour.

    Args:
        input_pitch: 1-D sequence of query pitch values.
        db_pitch: 1-D sequence of database (target) pitch values.

    Returns:
        A tuple ``(path, table_dp, table_dist)``.  ``path`` is a list of
        ``[db_index, input_index]`` pairs ordered from the start of the
        alignment to its end.  ``table_dp`` is the accumulated-cost table
        and ``table_dist`` the local-distance table, both of shape
        ``(len(db_pitch), len(input_pitch))``.

    Raises:
        ValueError: If either sequence is empty.
    """
    input_pitch = np.asarray(input_pitch, dtype=float)
    db_pitch = np.asarray(db_pitch, dtype=float)
    len_in = len(input_pitch)
    len_db = len(db_pitch)
    if len_in == 0 or len_db == 0:
        raise ValueError('input_pitch and db_pitch must not be empty')

    # Local distance of every (database frame, query frame) pair, computed
    # once by broadcasting and reused by all the accumulation steps below.
    table_dist = np.abs(db_pitch[:, np.newaxis] - input_pitch[np.newaxis, :])
    table_dp = np.zeros((len_db, len_in))
    # Step that reached each cell: 0 = diagonal, 1 = left, 2 = up.
    table_dir = np.zeros((len_db, len_in), dtype=int)

    # Border cells have a single possible predecessor, so their accumulated
    # cost is a running sum along the first row / first column.
    table_dp[0, :] = np.cumsum(table_dist[0, :])
    table_dp[:, 0] = np.cumsum(table_dist[:, 0])
    table_dir[0, 1:] = 1
    table_dir[1:, 0] = 2

    for i in range(1, len_db):
        for j in range(1, len_in):
            vec = table_dist[i, j] + np.array([
                table_dp[i-1, j-1], table_dp[i, j-1], table_dp[i-1, j],
            ])
            step = np.argmin(vec)
            table_dp[i, j] = vec[step]
            table_dir[i, j] = step

    # Backtrack from the cheapest cell of the last column to the origin.
    # The border directions guarantee the walk reaches (0, 0) without ever
    # leaving the table, so the loop can stop exactly there.
    i = int(np.argmin(table_dp[:, -1]))
    j = len_in - 1
    path = [[i, j]]
    while i > 0 or j > 0:
        step = table_dir[i, j]
        if step == 0:
            i -= 1
            j -= 1
        elif step == 1:
            j -= 1
        else:
            i -= 1
        path.append([i, j])
    path.reverse()
    return path, table_dp, table_dist
 
 
inputPitch = np.array([48.04, 49.36, 50.13, 50.62, 51.11, 51.49, 51.49, 50.82, 50.15, 49.89, 50.67, 51.06, 49.9, 49.52, 49.52, 51.14, 51.14, 51.52, 51.27, 51.18, 52.02, 51.49, 51.33, 51.14, 51.14, 51.14, 51.14, 51.14, 50.55, 50.47, 50.64, 50.39, 48.35, 51.35, 51.49, 51.27, 50.89, 51.49, 51.49, 51.49, 55.78, 55.14, 54.92, 55.35, 55.35, 55.35, 55.35, 55.35, 54.01, 58.46, 59.63, 59.76, 59.76, 58.06, 57.99, 58.68, 58.68, 57.94, 55.06, 55.37, 55.79, 55.79, 54.73, 56.14, 55.35, 55.35, 55.04, 54.51, 53.29, 50.16, 50.9, 51.14, 51.07, 50.81, 50.54, 51.25, 51.49, 51.49, 51.49, 52.08, 52.56, 52.56, 52.2, 52.0, 53.23, 53.31, 52.98, 53.06, 53.0, 52.45, 52.77, 53.21, 52.14, 52.79, 53.18, 52.87, 52.79, 52.89, 52.56, 52.6, 53.31, 53.31, 53.13, 52.93, 52.77, 52.9, 52.93, 52.93, 52.93, 52.93, 53.31, 53.31, 52.19, 52.19, 53.88, 52.93, 52.93, 52.93, 51.14, 51.14, 48.94, 49.52, 49.84, 49.24, 49.48, 48.85, 48.33, 48.33, 48.33, 49.62, 52.58, 53.31, 53.03, 52.93, 53.19, 52.93, 52.69, 52.39, 52.2, 49.26, 50.01, 49.83, 49.08, 48.62, 49.16, 49.84, 49.84, 47.09, 46.68, 46.25, 45.66, 45.85, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 50.84, 51.49, 51.14])
dbPitch = 1.0 * np.array([60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60])
 
# Remove each vector's own mean (in place) before aligning them.
dbPitch -= np.mean(dbPitch)
inputPitch -= np.mean(inputPitch)

# Run DTW; table_dist is returned as well but is not used below.
alignment, table_dp, table_dist = dtw(inputPitch, dbPitch)

# Top-left panel: both curves, with the query drawn 10 units lower so the
# two do not overlap. Every 5th frame is joined to the frame with the same
# index in the other curve; zip() stops at the shorter of the two vectors.
plt.subplot(2, 2, 1)
plt.plot(dbPitch)
plt.plot(inputPitch-10)
for i, (db, inp) in enumerate(zip(dbPitch, inputPitch)):
    if i % 5 == 0:
        plt.plot([i, i], [db, inp-10], 'r--')
plt.yticks([-10, 0], ['Input pitch', 'Target pitch'])
plt.title('Direct Alignment')

# Bottom-left panel: same curves, but every 5th pair on the DTW path is
# joined, from database index idx_db to query index idx_qry.
plt.subplot(2, 2, 3)
plt.plot(dbPitch)
for i, (idx_db, idx_qry) in enumerate(alignment):
    if i % 5 == 0:
        plt.plot([idx_db, idx_qry], [dbPitch[idx_db], inputPitch[idx_qry]-10], 'r--')
plt.plot(inputPitch-10)
plt.yticks([-10, 0], ['Input pitch', 'Target pitch'])
plt.title('DTW Alignment')

# Right half: accumulated-cost table as an image with the warping path on
# top (x = query index a[1], y = database index a[0]).
plt.subplot(1, 2, 2)
plt.imshow(table_dp)
plt.plot([a[1] for a in alignment], [a[0] for a in alignment], 'r-')
plt.xlabel('Input pitch Index')
plt.ylabel('Target pitch Index')
plt.colorbar()
plt.title('DTW Path')

plt.show()

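在上述的程式碼中:

- `table_dist[i, j]` 記錄資料庫第 i 個音高與查詢輸入第 j 個音高的絕對差;`table_dp[i, j]` 是從 (0, 0) 走到 (i, j) 的最小累積距離;`table_dir[i, j]` 則記錄該格是由左上(0)、左方(1)或上方(2)走來。
- 路徑的起點固定在 (0, 0),終點則取最後一個直行(也就是查詢輸入的最後一個音高)中累積距離最小的那一格,因此查詢輸入可以對應到資料庫片段中的任意結束位置。
- 最後從終點依照 `table_dir` 往回追溯,即可得到完整的對位路徑。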

DTW 在移調上的做法跟 LS 比較不一樣。LS 因為是在移調之前,就很明確的知道查詢輸入跟資料庫中的某片段是相同位置一對一的對應,所以可以直接計算出一個平移方式;而在 DTW 當中,因為在比對之前不知道查詢輸入跟資料庫中的某片段是如何對應,因此通常會在音高平均為 0 附近的某個(例如加減兩個半音)範圍內進行搜尋,來找出能造成最小距離的平移方式。下面的範例,會對查詢輸入以二元搜尋的原則進行三次搜尋,且每次都縮小搜尋範圍,來求取其跟資料庫中某片段的最小 DTW 距離:

範例程式碼(共 64 行):
import numpy as np
 
 
def dtw(input_pitch, db_pitch):
    """Return the DTW distance between a query and a database pitch vector.

    The warping path starts at cell ``(0, 0)`` and must consume the whole
    query, but it may end on any database frame.  The local cost is the
    absolute pitch difference and every step comes from the diagonal, left
    or upper neighbour.

    Args:
        input_pitch: 1-D sequence of query pitch values.
        db_pitch: 1-D sequence of database (target) pitch values.

    Returns:
        The smallest accumulated cost found in the last query column.

    Raises:
        ValueError: If either sequence is empty.
    """
    input_pitch = np.asarray(input_pitch, dtype=float)
    db_pitch = np.asarray(db_pitch, dtype=float)
    len_in = len(input_pitch)
    len_db = len(db_pitch)
    if len_in == 0 or len_db == 0:
        raise ValueError('input_pitch and db_pitch must not be empty')

    # Local distance of every (database frame, query frame) pair.  Only the
    # distance is returned, so no direction table is kept.
    cost = np.abs(db_pitch[:, np.newaxis] - input_pitch[np.newaxis, :])
    table_dp = np.zeros((len_db, len_in))

    # Border cells have a single possible predecessor, so their accumulated
    # cost is a running sum along the first row / first column.
    table_dp[0, :] = np.cumsum(cost[0, :])
    table_dp[:, 0] = np.cumsum(cost[:, 0])

    for i in range(1, len_db):
        for j in range(1, len_in):
            vec = cost[i, j] + np.array([
                table_dp[i-1, j-1], table_dp[i, j-1], table_dp[i-1, j],
            ])
            table_dp[i, j] = np.min(vec)

    # The alignment may stop at any database frame.
    return np.min(table_dp[:, -1])
 
 
def dtw_for_one_song(input_pitch, db_pitch):
    """Search for the key shift that minimises the DTW distance to one song.

    Both vectors are centred on zero mean.  The query is then shifted by
    ``shift_base +/- shift_offset`` for three rounds; after each round the
    search centre moves to the best shift found so far and the offset is
    halved (2, 1, 0.5).  Every evaluated shift is printed as
    ``base offset distance``.

    Args:
        input_pitch: 1-D sequence of query pitch values.  Not modified.
        db_pitch: 1-D sequence of database pitch values.  Not modified.

    Returns:
        The smallest DTW distance over all evaluated shifts.
    """
    # Work on centred copies so the caller's arrays stay untouched; this
    # also lets integer arrays through, which in-place ``-=`` would reject.
    db_pitch = np.asarray(db_pitch, dtype=float)
    db_pitch = db_pitch - np.mean(db_pitch)
    input_pitch = np.asarray(input_pitch, dtype=float)
    input_pitch = input_pitch - np.mean(input_pitch)

    shift_base = 0
    shift_offset = 2
    dist = dtw(input_pitch + shift_base, db_pitch)
    print(shift_base, 0, dist)
    for _ in range(3):
        new_base = shift_base
        for signed_offset in (shift_offset, -shift_offset):
            candidate = dtw(input_pitch + shift_base + signed_offset, db_pitch)
            print(shift_base, signed_offset, candidate)
            if candidate < dist:
                dist = candidate
                # Re-centre on the shift that produced the improvement, so
                # the next (finer) round refines around it.
                new_base = shift_base + signed_offset
        shift_base = new_base
        shift_offset /= 2
    return dist
 
 
input_pitch = np.array([48.04, 49.36, 50.13, 50.62, 51.11, 51.49, 51.49, 50.82, 50.15, 49.89, 50.67, 51.06, 49.9, 49.52, 49.52, 51.14, 51.14, 51.52, 51.27, 51.18, 52.02, 51.49, 51.33, 51.14, 51.14, 51.14, 51.14, 51.14, 50.55, 50.47, 50.64, 50.39, 48.35, 51.35, 51.49, 51.27, 50.89, 51.49, 51.49, 51.49, 55.78, 55.14, 54.92, 55.35, 55.35, 55.35, 55.35, 55.35, 54.01, 58.46, 59.63, 59.76, 59.76, 58.06, 57.99, 58.68, 58.68, 57.94, 55.06, 55.37, 55.79, 55.79, 54.73, 56.14, 55.35, 55.35, 55.04, 54.51, 53.29, 50.16, 50.9, 51.14, 51.07, 50.81, 50.54, 51.25, 51.49, 51.49, 51.49, 52.08, 52.56, 52.56, 52.2, 52.0, 53.23, 53.31, 52.98, 53.06, 53.0, 52.45, 52.77, 53.21, 52.14, 52.79, 53.18, 52.87, 52.79, 52.89, 52.56, 52.6, 53.31, 53.31, 53.13, 52.93, 52.77, 52.9, 52.93, 52.93, 52.93, 52.93, 53.31, 53.31, 52.19, 52.19, 53.88, 52.93, 52.93, 52.93, 51.14, 51.14, 48.94, 49.52, 49.84, 49.24, 49.48, 48.85, 48.33, 48.33, 48.33, 49.62, 52.58, 53.31, 53.03, 52.93, 53.19, 52.93, 52.69, 52.39, 52.2, 49.26, 50.01, 49.83, 49.08, 48.62, 49.16, 49.84, 49.84, 47.09, 46.68, 46.25, 45.66, 45.85, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 50.84, 51.49, 51.14])
db_pitch = 1.0 * np.array([60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60])
 
# Best DTW distance between the query and this one database segment over the
# key shifts tried by dtw_for_one_song (which also prints each trial).
dist_one_song = dtw_for_one_song(input_pitch, db_pitch)
print(dist_one_song)

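在上述的程式碼中:

- 先將查詢輸入與資料庫片段各自減去平均值,並以平移量 0 計算一次 DTW 距離作為基準。
- 接著進行三輪搜尋,每一輪分別嘗試加上與減去 `shift_offset` 的平移量,保留目前為止最小的距離,然後將 `shift_offset` 減半(2 → 1 → 0.5)。
- 此版本的 `dtw` 只回傳最小距離,不再回傳對位路徑。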

對於一個輸入查詢,要做出系統該回傳給使用者的結果候選清單的方法,則跟 LS 相同,把輸入查詢跟資料庫中的所有歌曲都比對過一遍,並依照它與資料庫中每首歌的最小距離進行遞增排序即可:

範例程式碼(共 89 行):
import os
 
import numpy as np
import pretty_midi
 
 
def read_database(path, fs=31.25):
    """Load song names and per-frame pitch vectors from a MIDI directory.

    Args:
        path: Directory holding ``songList.txt`` and the ``.mid`` files.
        fs: Sampling rate handed to ``get_piano_roll``.

    Returns:
        A tuple ``(db_pitches, db_song_names)``.  ``db_pitches`` has one
        array per ``.mid`` file, taken in sorted file-name order.
        ``db_song_names`` has one string per line of ``songList.txt``, made
        of the line's 2nd and 3rd tab-separated fields joined by a space.
    """
    # NOTE(review): names follow songList.txt order while pitch vectors
    # follow sorted file names; presumably the two orders agree -- verify.
    list_path = os.path.join(path, 'songList.txt')
    with open(list_path, 'r') as fin:
        lines = fin.read().splitlines()
    db_song_names = [' '.join(line.split('\t')[1:3]) for line in lines]

    db_pitches = []
    for file_name in sorted(os.listdir(path)):
        if file_name.endswith('.mid'):
            midi = pretty_midi.PrettyMIDI(os.path.join(path, file_name))
            # Piano roll shape: (semitone, time_step).
            roll = midi.get_piano_roll(fs=fs)
            # Index of the largest entry per time step; a step whose column
            # is all equal (e.g. all zero) yields index 0.
            db_pitches.append(np.argmax(roll, axis=0))
    return db_pitches, db_song_names
 
 
def dtw(input_pitch, db_pitch):
    """Return the band-limited DTW distance between query and database.

    Apart from the first row and first column, only cells whose query index
    ``j`` satisfies ``i - len_in // 2 <= j < i + len_in // 2`` (``i`` being
    the database index) are evaluated; all others keep an infinite cost.
    The warping path starts at cell ``(0, 0)``, must consume the whole
    query, and may end on any database frame.

    Args:
        input_pitch: 1-D sequence of query pitch values.
        db_pitch: 1-D sequence of database (target) pitch values.

    Returns:
        The smallest accumulated cost found in the last query column.

    Raises:
        ValueError: If either sequence is empty.
    """
    input_pitch = np.asarray(input_pitch, dtype=float)
    db_pitch = np.asarray(db_pitch, dtype=float)
    len_in = len(input_pitch)
    len_db = len(db_pitch)
    if len_in == 0 or len_db == 0:
        raise ValueError('input_pitch and db_pitch must not be empty')

    # Local distance of every (database frame, query frame) pair.  Only the
    # distance is returned, so no direction table is kept.
    cost = np.abs(db_pitch[:, np.newaxis] - input_pitch[np.newaxis, :])
    # Cells outside the band stay at infinity and can never be chosen.
    table_dp = np.full((len_db, len_in), np.inf)

    # Border cells have a single possible predecessor, so their accumulated
    # cost is a running sum along the first row / first column.
    table_dp[0, :] = np.cumsum(cost[0, :])
    table_dp[:, 0] = np.cumsum(cost[:, 0])

    half_band = len_in // 2
    for i in range(1, len_db):
        # Start at column 1: column 0 is already final, and j == 0 would
        # make ``j - 1`` wrap around to the last column.
        for j in range(max(1, i - half_band), min(len_in, i + half_band)):
            vec = cost[i, j] + np.array([
                table_dp[i-1, j-1], table_dp[i, j-1], table_dp[i-1, j],
            ])
            table_dp[i, j] = np.min(vec)

    # The alignment may stop at any database frame.
    return np.min(table_dp[:, -1])
 
 
def dtw_for_one_song(input_pitch, db_pitch):
    """Search for the key shift that minimises the DTW distance to one song.

    The database vector is centred on zero mean; the query is used as given
    and is not centred here.  The query is shifted by
    ``shift_base +/- shift_offset`` for two rounds; after each round the
    search centre moves to the best shift found so far and the offset is
    halved (2, 1).

    Args:
        input_pitch: 1-D sequence of query pitch values.  Not modified.
        db_pitch: 1-D sequence of database pitch values.  Not modified.

    Returns:
        The smallest DTW distance over all evaluated shifts.
    """
    input_pitch = np.asarray(input_pitch, dtype=float)
    # Centre a copy so the caller's array stays untouched; this also lets
    # integer arrays through, which in-place ``-=`` would reject.
    db_pitch = np.asarray(db_pitch, dtype=float)
    db_pitch = db_pitch - np.mean(db_pitch)

    shift_base = 0
    shift_offset = 2
    dist = dtw(input_pitch + shift_base, db_pitch)
    for _ in range(2):
        new_base = shift_base
        for signed_offset in (shift_offset, -shift_offset):
            candidate = dtw(input_pitch + shift_base + signed_offset, db_pitch)
            if candidate < dist:
                dist = candidate
                # Re-centre on the shift that produced the improvement, so
                # the next (finer) round refines around it.
                new_base = shift_base + signed_offset
        shift_base = new_base
        shift_offset /= 2
    return dist
 
 
def compare_to_whole_db(input_pitch, db_pitches):
    """Rank every database song by its DTW distance to the query.

    Args:
        input_pitch: 1-D sequence of query pitch values.  Not modified.
        db_pitches: Sequence of NumPy pitch arrays, one per database song.

    Returns:
        Array of indices into ``db_pitches`` sorted by increasing distance.
    """
    # Centre a copy once for all songs; the caller's array stays untouched
    # and integer input is accepted (in-place ``-=`` would reject it).
    input_pitch = np.asarray(input_pitch, dtype=float)
    input_pitch = input_pitch - np.mean(input_pitch)

    all_dists = []
    for s_idx, db_pitch in enumerate(db_pitches):
        print('Comparing song {}/{}'.format(s_idx, len(db_pitches)))
        # Hand over a float copy of the song's pitch vector.
        all_dists.append(dtw_for_one_song(input_pitch, db_pitch.astype('float')))
    return np.argsort(all_dists)
 
 
input_pitch = np.array([48.04, 49.36, 50.13, 50.62, 51.11, 51.49, 51.49, 50.82, 50.15, 49.89, 50.67, 51.06, 49.9, 49.52, 49.52, 51.14, 51.14, 51.52, 51.27, 51.18, 52.02, 51.49, 51.33, 51.14, 51.14, 51.14, 51.14, 51.14, 50.55, 50.47, 50.64, 50.39, 48.35, 51.35, 51.49, 51.27, 50.89, 51.49, 51.49, 51.49, 55.78, 55.14, 54.92, 55.35, 55.35, 55.35, 55.35, 55.35, 54.01, 58.46, 59.63, 59.76, 59.76, 58.06, 57.99, 58.68, 58.68, 57.94, 55.06, 55.37, 55.79, 55.79, 54.73, 56.14, 55.35, 55.35, 55.04, 54.51, 53.29, 50.16, 50.9, 51.14, 51.07, 50.81, 50.54, 51.25, 51.49, 51.49, 51.49, 52.08, 52.56, 52.56, 52.2, 52.0, 53.23, 53.31, 52.98, 53.06, 53.0, 52.45, 52.77, 53.21, 52.14, 52.79, 53.18, 52.87, 52.79, 52.89, 52.56, 52.6, 53.31, 53.31, 53.13, 52.93, 52.77, 52.9, 52.93, 52.93, 52.93, 52.93, 53.31, 53.31, 52.19, 52.19, 53.88, 52.93, 52.93, 52.93, 51.14, 51.14, 48.94, 49.52, 49.84, 49.24, 49.48, 48.85, 48.33, 48.33, 48.33, 49.62, 52.58, 53.31, 53.03, 52.93, 53.19, 52.93, 52.69, 52.39, 52.2, 49.26, 50.01, 49.83, 49.08, 48.62, 49.16, 49.84, 49.84, 47.09, 46.68, 46.25, 45.66, 45.85, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 50.84, 51.49, 51.14])
# Load every song's pitch vector and display name from the MIDI directory.
db_pitches, db_song_names = read_database('MIR-QBSH/midiFile')

# Rank all songs by distance to the query and print the 10 closest names.
sorted_idx = compare_to_whole_db(input_pitch, db_pitches)
for idx in sorted_idx[:10]:
    print(db_song_names[idx])

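在上述的程式碼中,因為 DTW 的執行時間明顯較久,因此跟之前的範例相比,做了一些或大或小的改變:

- `dtw` 中的 `table_dp` 先全部設為無限大;除了第一個橫列與第一個直行之外,只計算 j 落在 i ± `len_in // 2` 範圍內的格子,以減少運算量。
- 平移搜尋的輪數由三輪減為兩輪。
- 查詢輸入減去平均值的動作移到 `compare_to_whole_db` 中,只做一次,而不是對每首歌都重做。
- 移除了搜尋過程中的 `print`,改為只顯示目前比對到第幾首歌。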

與 LS 篇章相同,你可以把上述固定寫死的 input_pitch 換成自己從音檔抽取出來的音高向量,如此就成為一個簡單的哼唱選歌系統。如果需要評估系統的辨識效果,也可以把簡介中提到的 top-n accuracy 或 MRR 實作出來,加到你的程式碼當中。