動態時間扭曲(dynamic time warping, DTW)不同於線性伸縮法,主要是允許查詢輸入的節奏變化可能會不太均勻,因此不能用單純的伸縮來調整節奏,而是改用動態規劃(dynamic programming, DP)的方式來處理。本篇將假設各位對 DP 已經有基本的瞭解,例如已經修過演算法的課程,或者知道怎樣解決「最長共同序列(longest common subsequence, LCS)」的問題。以下內容,將主要介紹如何將 DP 運用在哼唱選歌當中。

以下是一個 DTW 的範例,此範例會將查詢輸入與資料庫中的某個音高向量之間進行 DTW 對位,並展示直接對位和 DTW 對位的結果,以及 DTW 運算中的過程:

import numpy as np
import matplotlib.pyplot as plt


def dtw(input_pitch, db_pitch):
	len_in = len(input_pitch)
	len_db = len(db_pitch)
	table_dp = np.zeros((len_db, len_in))
	table_dir = np.zeros((len_db, len_in))
	table_dist = np.zeros((len_db, len_in))
	for i, db in enumerate(db_pitch):
		for j, inp in enumerate(input_pitch):
			table_dist[i, j] = np.abs(db-inp)
	# First element
	table_dp[0, 0] = np.abs(db_pitch[0] - input_pitch[0])
	# First row
	for j in range(1, len_in):
		table_dp[0, j] = table_dp[0, j-1] + np.abs(db_pitch[0] - input_pitch[j])
		table_dir[0, j] = 1
	# First column
	for i in range(1, len_db):
		table_dp[i, 0] = table_dp[i-1, 0] + np.abs(db_pitch[i] - input_pitch[0])
		table_dir[i, 0] = 2
	# Remaining
	for i in range(1, len_db):
		for j in range(1, len_in):
			vec = np.abs(db_pitch[i] - input_pitch[j]) + np.array([
				table_dp[i-1, j-1], table_dp[i, j-1], table_dp[i-1, j],
			])
			table_dp[i, j] = np.min(vec)
			table_dir[i, j] = np.argmin(vec)
	# Backtrack
	start_i = np.argmin(table_dp[:, -1])
	start_j = len_in - 1
	ret_list = [[start_i, start_j]]
	while start_i >= 0 or start_j >= 0:
		if table_dir[start_i, start_j] == 0:
			start_i -= 1
			start_j -= 1
		elif table_dir[start_i, start_j] == 1:
			start_j -= 1
		elif table_dir[start_i, start_j] == 2:
			start_i -= 1
		ret_list = [[start_i, start_j]] + ret_list
	return ret_list[1:], table_dp, table_dist


inputPitch = np.array([48.04, 49.36, 50.13, 50.62, 51.11, 51.49, 51.49, 50.82, 50.15, 49.89, 50.67, 51.06, 49.9, 49.52, 49.52, 51.14, 51.14, 51.52, 51.27, 51.18, 52.02, 51.49, 51.33, 51.14, 51.14, 51.14, 51.14, 51.14, 50.55, 50.47, 50.64, 50.39, 48.35, 51.35, 51.49, 51.27, 50.89, 51.49, 51.49, 51.49, 55.78, 55.14, 54.92, 55.35, 55.35, 55.35, 55.35, 55.35, 54.01, 58.46, 59.63, 59.76, 59.76, 58.06, 57.99, 58.68, 58.68, 57.94, 55.06, 55.37, 55.79, 55.79, 54.73, 56.14, 55.35, 55.35, 55.04, 54.51, 53.29, 50.16, 50.9, 51.14, 51.07, 50.81, 50.54, 51.25, 51.49, 51.49, 51.49, 52.08, 52.56, 52.56, 52.2, 52.0, 53.23, 53.31, 52.98, 53.06, 53.0, 52.45, 52.77, 53.21, 52.14, 52.79, 53.18, 52.87, 52.79, 52.89, 52.56, 52.6, 53.31, 53.31, 53.13, 52.93, 52.77, 52.9, 52.93, 52.93, 52.93, 52.93, 53.31, 53.31, 52.19, 52.19, 53.88, 52.93, 52.93, 52.93, 51.14, 51.14, 48.94, 49.52, 49.84, 49.24, 49.48, 48.85, 48.33, 48.33, 48.33, 49.62, 52.58, 53.31, 53.03, 52.93, 53.19, 52.93, 52.69, 52.39, 52.2, 49.26, 50.01, 49.83, 49.08, 48.62, 49.16, 49.84, 49.84, 47.09, 46.68, 46.25, 45.66, 45.85, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 50.84, 51.49, 51.14])
dbPitch = 1.0 * np.array([60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60])

dbPitch -= np.mean(dbPitch)
inputPitch -= np.mean(inputPitch)

alignment, table_dp, table_dist = dtw(inputPitch, dbPitch)

plt.subplot(2, 2, 1)
plt.plot(dbPitch)
plt.plot(inputPitch-10)
for i, (db, inp) in enumerate(zip(dbPitch, inputPitch)):
	if i % 5 == 0:
		plt.plot([i, i], [db, inp-10], 'r--')
plt.yticks([-10, 0], ['Input pitch', 'Target pitch'])
plt.title('Direct Alignment')

plt.subplot(2, 2, 3)
plt.plot(dbPitch)
for i, (idx_db, idx_qry) in enumerate(alignment):
	if i % 5 == 0:
		plt.plot([idx_db, idx_qry], [dbPitch[idx_db], inputPitch[idx_qry]-10], 'r--')
plt.plot(inputPitch-10)
plt.yticks([-10, 0], ['Input pitch', 'Target pitch'])
plt.title('DTW Alignment')

plt.subplot(1, 2, 2)
plt.imshow(table_dp)
plt.plot([a[1] for a in alignment], [a[0] for a in alignment], 'r-')
plt.xlabel('Input pitch Index')
plt.ylabel('Target pitch Index')
plt.colorbar()
plt.title('DTW Path')

plt.show()

在上述的程式碼中:

DTW 在移調上的做法跟 LS 比較不一樣。LS 因為是在移調之前,就很明確的知道查詢輸入跟資料庫中的某片段是相同位置一對一的對應,所以可以直接計算出一個平移方式;而在 DTW 當中,因為在比對之前不知道查詢輸入跟資料庫中的某片段是如何對應,因此通常會在音高平均為 0 附近的某個(例如加減兩個半音)範圍內進行搜尋,來找出能造成最小距離的平移方式。下面的範例,會對查詢輸入以二元搜尋的原則進行三次搜尋,且每次都縮小搜尋範圍,來求取其跟資料庫中某片段的最小 DTW 距離:

import numpy as np


def dtw(input_pitch, db_pitch):
	len_in = len(input_pitch)
	len_db = len(db_pitch)
	table_dp = np.zeros((len_db, len_in))
	table_dir = np.zeros((len_db, len_in))
	table_dist = np.zeros((len_db, len_in))
	for i, db in enumerate(db_pitch):
		for j, inp in enumerate(input_pitch):
			table_dist[i, j] = np.abs(db-inp)
	# First element
	table_dp[0, 0] = np.abs(db_pitch[0] - input_pitch[0])
	# First row
	for j in range(1, len_in):
		table_dp[0, j] = table_dp[0, j-1] + np.abs(db_pitch[0] - input_pitch[j])
		table_dir[0, j] = 1
	# First column
	for i in range(1, len_db):
		table_dp[i, 0] = table_dp[i-1, 0] + np.abs(db_pitch[i] - input_pitch[0])
		table_dir[i, 0] = 2
	# Remaining
	for i in range(1, len_db):
		for j in range(1, len_in):
			vec = np.abs(db_pitch[i] - input_pitch[j]) + np.array([
				table_dp[i-1, j-1], table_dp[i, j-1], table_dp[i-1, j],
			])
			table_dp[i, j] = np.min(vec)
			table_dir[i, j] = np.argmin(vec)
	# Return
	return np.min(table_dp[:, -1])


def dtw_for_one_song(input_pitch, db_pitch):
	db_pitch -= np.mean(db_pitch)
	input_pitch -= np.mean(input_pitch)

	shift_base = 0
	shift_offset = 2
	dist = dtw(input_pitch + shift_base, db_pitch)
	new_base = 0
	print(shift_base, 0, dist)
	for _ in range(3):
		dist_plus = dtw(input_pitch + shift_base + shift_offset, db_pitch)
		print(shift_base, shift_offset, dist_plus)
		if dist_plus < dist:
			dist = dist_plus
			new_base = -1
		dist_minus = dtw(input_pitch + shift_base - shift_offset, db_pitch)
		print(shift_base, -shift_offset, dist_minus)
		if dist_minus < dist:
			dist = dist_minus
			new_base = 1
		shift_base = new_base
		shift_offset /= 2
	return dist


input_pitch = np.array([48.04, 49.36, 50.13, 50.62, 51.11, 51.49, 51.49, 50.82, 50.15, 49.89, 50.67, 51.06, 49.9, 49.52, 49.52, 51.14, 51.14, 51.52, 51.27, 51.18, 52.02, 51.49, 51.33, 51.14, 51.14, 51.14, 51.14, 51.14, 50.55, 50.47, 50.64, 50.39, 48.35, 51.35, 51.49, 51.27, 50.89, 51.49, 51.49, 51.49, 55.78, 55.14, 54.92, 55.35, 55.35, 55.35, 55.35, 55.35, 54.01, 58.46, 59.63, 59.76, 59.76, 58.06, 57.99, 58.68, 58.68, 57.94, 55.06, 55.37, 55.79, 55.79, 54.73, 56.14, 55.35, 55.35, 55.04, 54.51, 53.29, 50.16, 50.9, 51.14, 51.07, 50.81, 50.54, 51.25, 51.49, 51.49, 51.49, 52.08, 52.56, 52.56, 52.2, 52.0, 53.23, 53.31, 52.98, 53.06, 53.0, 52.45, 52.77, 53.21, 52.14, 52.79, 53.18, 52.87, 52.79, 52.89, 52.56, 52.6, 53.31, 53.31, 53.13, 52.93, 52.77, 52.9, 52.93, 52.93, 52.93, 52.93, 53.31, 53.31, 52.19, 52.19, 53.88, 52.93, 52.93, 52.93, 51.14, 51.14, 48.94, 49.52, 49.84, 49.24, 49.48, 48.85, 48.33, 48.33, 48.33, 49.62, 52.58, 53.31, 53.03, 52.93, 53.19, 52.93, 52.69, 52.39, 52.2, 49.26, 50.01, 49.83, 49.08, 48.62, 49.16, 49.84, 49.84, 47.09, 46.68, 46.25, 45.66, 45.85, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 50.84, 51.49, 51.14])
db_pitch = 1.0 * np.array([60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60])

dist_one_song = dtw_for_one_song(input_pitch, db_pitch)
print(dist_one_song)

在上述的程式碼中:

對於一個輸入查詢,要做出系統該回傳給使用者的結果候選清單的方法,則跟 LS 相同,把輸入查詢跟資料庫中的所有歌曲都比對過一遍,並依照它與資料庫中每首歌的最小距離進行遞增排序即可:

import os

import numpy as np
import pretty_midi


def read_database(path, fs=31.25):
	with open(os.path.join(path, 'songList.txt'), 'r') as fin:
		cnt = fin.read().splitlines()
	db_song_names = [' '.join(line.split('\t')[1:3]) for line in cnt]
	midi_files = sorted(os.listdir(path))
	db_pitches = []
	for mf in midi_files:
		if not mf.endswith('.mid'):
			continue
		midi = pretty_midi.PrettyMIDI(os.path.join(path, mf))
		piano_roll = midi.get_piano_roll(fs=fs) # Shape: (semitone, time_step)
		pitches = np.argmax(piano_roll, axis=0)
		db_pitches.append(pitches)
	return db_pitches, db_song_names


def dtw(input_pitch, db_pitch):
	len_in = len(input_pitch)
	len_db = len(db_pitch)
	table_dp = np.zeros((len_db, len_in)) + np.inf
	table_dir = np.zeros((len_db, len_in))
	table_dist = np.zeros((len_db, len_in))
	for i, db in enumerate(db_pitch):
		for j, inp in enumerate(input_pitch):
			table_dist[i, j] = np.abs(db-inp)
	# First element
	table_dp[0, 0] = np.abs(db_pitch[0] - input_pitch[0])
	# First row
	for j in range(1, len_in):
		table_dp[0, j] = table_dp[0, j-1] + np.abs(db_pitch[0] - input_pitch[j])
		table_dir[0, j] = 1
	# First column
	for i in range(1, len_db):
		table_dp[i, 0] = table_dp[i-1, 0] + np.abs(db_pitch[i] - input_pitch[0])
		table_dir[i, 0] = 2
	# Remaining
	for i in range(1, len_db):
		for j in range(max(0, i-len_in//2), min(len_in, i+len_in//2)):
			vec = np.abs(db_pitch[i] - input_pitch[j]) + np.array([
				table_dp[i-1, j-1], table_dp[i, j-1], table_dp[i-1, j],
			])
			table_dp[i, j] = np.min(vec)
			table_dir[i, j] = np.argmin(vec)
	# Return
	return np.min(table_dp[:, -1])


def dtw_for_one_song(input_pitch, db_pitch):
	db_pitch -= np.mean(db_pitch)

	shift_base = 0
	shift_offset = 2
	dist = dtw(input_pitch + shift_base, db_pitch)
	new_base = 0
	for _ in range(2):
		dist_plus = dtw(input_pitch + shift_base + shift_offset, db_pitch)
		if dist_plus < dist:
			dist = dist_plus
			new_base = -1
		dist_minus = dtw(input_pitch + shift_base - shift_offset, db_pitch)
		if dist_minus < dist:
			dist = dist_minus
			new_base = 1
		shift_base = new_base
		shift_offset /= 2
	return dist


def compare_to_whole_db(input_pitch, db_pitches):
	input_pitch -= np.mean(input_pitch)
	all_dists = []
	for s_idx, db_pitch in enumerate(db_pitches):
		print('Comparing song {}/{}'.format(s_idx, len(db_pitches)))
		all_dists.append(dtw_for_one_song(input_pitch, db_pitch.astype('float')))
	return np.argsort(all_dists)


input_pitch = np.array([48.04, 49.36, 50.13, 50.62, 51.11, 51.49, 51.49, 50.82, 50.15, 49.89, 50.67, 51.06, 49.9, 49.52, 49.52, 51.14, 51.14, 51.52, 51.27, 51.18, 52.02, 51.49, 51.33, 51.14, 51.14, 51.14, 51.14, 51.14, 50.55, 50.47, 50.64, 50.39, 48.35, 51.35, 51.49, 51.27, 50.89, 51.49, 51.49, 51.49, 55.78, 55.14, 54.92, 55.35, 55.35, 55.35, 55.35, 55.35, 54.01, 58.46, 59.63, 59.76, 59.76, 58.06, 57.99, 58.68, 58.68, 57.94, 55.06, 55.37, 55.79, 55.79, 54.73, 56.14, 55.35, 55.35, 55.04, 54.51, 53.29, 50.16, 50.9, 51.14, 51.07, 50.81, 50.54, 51.25, 51.49, 51.49, 51.49, 52.08, 52.56, 52.56, 52.2, 52.0, 53.23, 53.31, 52.98, 53.06, 53.0, 52.45, 52.77, 53.21, 52.14, 52.79, 53.18, 52.87, 52.79, 52.89, 52.56, 52.6, 53.31, 53.31, 53.13, 52.93, 52.77, 52.9, 52.93, 52.93, 52.93, 52.93, 53.31, 53.31, 52.19, 52.19, 53.88, 52.93, 52.93, 52.93, 51.14, 51.14, 48.94, 49.52, 49.84, 49.24, 49.48, 48.85, 48.33, 48.33, 48.33, 49.62, 52.58, 53.31, 53.03, 52.93, 53.19, 52.93, 52.69, 52.39, 52.2, 49.26, 50.01, 49.83, 49.08, 48.62, 49.16, 49.84, 49.84, 47.09, 46.68, 46.25, 45.66, 45.85, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 46.16, 50.84, 51.49, 51.14])
db_pitches, db_song_names = read_database('MIR-QBSH/midiFile')

sorted_idx = compare_to_whole_db(input_pitch, db_pitches)
for idx in sorted_idx[:10]:
	print(db_song_names[idx])

在上述的程式碼中,因為 DTW 的執行時間顯著的比較久,因此跟之前的範例相比,做了一些或大或小的改變:

與 LS 篇章相同,你可以把上述固定寫死的 input_pitch,換成自己從音檔抽取出來的音高向量,就可以是一個簡單的哼唱選歌系統。如果有需要評估系統的辨識效果,也可以把簡介中提到的 top-n accuracy 或 MRR 實作出來,加到你的程式碼當中。