圖像相似度估計 | 結(jié)合三元組損失的暹羅網(wǎng)絡(luò)
在機器學(xué)習(xí)領(lǐng)域,確定圖像之間的相似度在各種應(yīng)用中至關(guān)重要,從檢測重復(fù)項到面部識別。解決這個問題的一個強大方法是使用暹羅網(wǎng)絡(luò)結(jié)合三元組損失函數(shù)。在本文中,我們將探索如何構(gòu)建和訓(xùn)練暹羅網(wǎng)絡(luò)以估計圖像相似度,并通過一個來自GitHub倉庫的實際示例進行說明。
什么是暹羅網(wǎng)絡(luò)?
暹羅網(wǎng)絡(luò)是一種包含兩個或更多相同子網(wǎng)絡(luò)的神經(jīng)網(wǎng)絡(luò)架構(gòu)。這些子網(wǎng)絡(luò)旨在為每個輸入生成特征向量,然后可以比較這些向量以估計相似度。關(guān)鍵思想是使用相同的網(wǎng)絡(luò)處理每個輸入,確保輸出一致且可比較。
這種架構(gòu)特別適合于檢測重復(fù)項、尋找異常和面部識別等任務(wù)。在我們將要探索的實現(xiàn)中,網(wǎng)絡(luò)設(shè)置有三個相同的子網(wǎng)絡(luò)。每個網(wǎng)絡(luò)處理三張圖像中的一張:錨點圖像、正樣本(與錨點相似)和負(fù)樣本(與錨點無關(guān))。
什么是三元組損失?
為了有效地訓(xùn)練暹羅網(wǎng)絡(luò),我們使用三元組損失函數(shù)。這種損失函數(shù)鼓勵網(wǎng)絡(luò)在特征空間中拉近錨點和正樣本的距離,同時將錨點和負(fù)樣本推得更遠(yuǎn)。損失函數(shù)定義如下:
L(A, P, N) = max(‖f(A) — f(P)‖2 — ‖f(A) — f(N)‖2 + margin, 0)
這里,A是錨點圖像,P是正圖像,N是負(fù)圖像。函數(shù)f(x)代表網(wǎng)絡(luò)生成的embedding,而margin是一個小的正值,有助于確保網(wǎng)絡(luò)不會將所有嵌入壓縮到同一點。
設(shè)置暹羅網(wǎng)絡(luò)
在這次實現(xiàn)中,我們首先加載Totally Looks Like數(shù)據(jù)集,其中包含我們用來創(chuàng)建訓(xùn)練網(wǎng)絡(luò)的三元組圖像。
1. 數(shù)據(jù)準(zhǔn)備
使用TensorFlow的tf.data API處理數(shù)據(jù)集以創(chuàng)建圖像三元組。這涉及到設(shè)置一個數(shù)據(jù)管道,其中每個三元組由錨點、正樣本和負(fù)樣本圖像組成。通過調(diào)整圖像大小到目標(biāo)形狀并歸一化像素值來預(yù)處理圖像。
def preprocess_image(filename):
image_string = tf.io.read_file(filename)
image = tf.image.decode_jpeg(image_string, channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, target_shape)
return image
def preprocess_triplets(anchor, positive, negative):
return (
preprocess_image(anchor),
preprocess_image(positive),
preprocess_image(negative),
)
以下是從數(shù)據(jù)集中生成的三元組示例,每行的前兩張圖像相似(錨點和正樣本),第三張不同(負(fù)樣本):
圖1:在數(shù)據(jù)準(zhǔn)備期間生成的三元組。每行的前兩張圖像相似(錨點和正樣本),第三張不同(負(fù)樣本)
2.構(gòu)建 embedding 生成器
我們暹羅網(wǎng)絡(luò)的核心是嵌入生成器,它使用在ImageNet上預(yù)訓(xùn)練的ResNet50模型構(gòu)建。通過凍結(jié)ResNet50中的大部分層的權(quán)重,并且僅微調(diào)最后幾層,我們可以利用遷移學(xué)習(xí)來減少訓(xùn)練時間并提高性能。
base_cnn = resnet.ResNet50(
weights="imagenet", input_shape=target_shape + (3,), include_top=False
)
flatten = layers.Flatten()(base_cnn.output)
dense1 = layers.Dense(512, activation="relu")(flatten)
dense1 = layers.BatchNormalization()(dense1)
dense2 = layers.Dense(256, activation="relu")(dense1)
dense2 = layers.BatchNormalization()(dense2)
output = layers.Dense(256)(dense2)
embedding = Model(base_cnn.input, output, name="Embedding")
# Freeze all layers until the layer conv5_block1_out
trainable = False
for layer in base_cnn.layers:
if layer.name == "conv5_block1_out":
trainable = True
layer.trainable = trainable
3.構(gòu)建暹羅網(wǎng)絡(luò)
暹羅網(wǎng)絡(luò)設(shè)置為一次輸入三張圖像(錨點、正樣本和負(fù)樣本)。自定義的DistanceLayer計算錨點-正樣本對和錨點-負(fù)樣本對之間的距離。然后訓(xùn)練模型以最小化相似圖像之間的距離,并最大化不相似圖像之間的距離。
class DistanceLayer(layers.Layer):
def call(self, anchor, positive, negative):
ap_distance = tf.reduce_sum(tf.square(anchor - positive), -1)
an_distance = tf.reduce_sum(tf.square(anchor - negative), -1)
return (ap_distance, an_distance)
anchor_input = layers.Input(name="anchor", shape=target_shape + (3,))
positive_input = layers.Input(name="positive", shape=target_shape + (3,))
negative_input = layers.Input(name="negative", shape=target_shape + (3,))
distances = DistanceLayer()(
embedding(resnet.preprocess_input(anchor_input)),
embedding(resnet.preprocess_input(positive_input)),
embedding(resnet.preprocess_input(negative_input)),
)
siamese_network = Model(
inputs=[anchor_input, positive_input, negative_input], outputs=distances
)
4.訓(xùn)練和評估
模型使用自定義訓(xùn)練循環(huán)進行訓(xùn)練,其中計算三元組損失并用于更新網(wǎng)絡(luò)的權(quán)重。仔細(xì)監(jiān)控訓(xùn)練過程,并通過對學(xué)習(xí)到的嵌入進行檢查來評估模型的性能。
class SiameseModel(Model):
def __init__(self, siamese_network, margin=0.5):
super(SiameseModel, self).__init__()
self.siamese_network = siamese_network
self.margin = margin
self.loss_tracker = metrics.Mean(name="loss")
def train_step(self, data):
with tf.GradientTape() as tape:
loss = self._compute_loss(data)
gradients = tape.gradient(loss, self.siamese_network.trainable_weights)
self.optimizer.apply_gradients(
zip(gradients, self.siamese_network.trainable_weights)
)
self.loss_tracker.update_state(loss)
return {"loss": self.loss_tracker.result()}
def _compute_loss(self, data):
ap_distance, an_distance = self.siamese_network(data)
loss = ap_distance - an_distance
loss = tf.maximum(loss + self.margin, 0.0)
return loss
5.檢查結(jié)果
訓(xùn)練完成后,我們可以通過比較錨點-正樣本對和錨點-負(fù)樣本對的嵌入之間的余弦相似度來評估網(wǎng)絡(luò)學(xué)習(xí)分離相似和不相似圖像的能力。
cosine_similarity = metrics.CosineSimilarity()
positive_similarity = cosine_similarity(anchor_embedding, positive_embedding)
print("Positive similarity:", positive_similarity.numpy())
negative_similarity = cosine_similarity(anchor_embedding, negative_embedding)
print("Negative similarity:", negative_similarity.numpy())
以下是經(jīng)過訓(xùn)練的模型評估的三元組示例。網(wǎng)絡(luò)成功識別出圖像之間的相似性和差異:
圖2:經(jīng)過訓(xùn)練的暹羅網(wǎng)絡(luò)的輸出,其中每行的前兩張圖像被模型識別為相似,第三張為不同
結(jié)論
本文展示了使用三元組損失的暹羅網(wǎng)絡(luò)如何有效地估計圖像相似度。通過使用預(yù)訓(xùn)練的ResNet50模型并微調(diào)其層,我們可以創(chuàng)建一個可以應(yīng)用于需要相似度估計的各種任務(wù)。
完整代碼和解釋,參考:https://github.com/elcaiseri/Siamese-Network