自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

圖像相似度估計 | 結(jié)合三元組損失的暹羅網(wǎng)絡(luò)

人工智能 機器學(xué)習(xí)
在本文中我們將探索如何構(gòu)建和訓(xùn)練暹羅網(wǎng)絡(luò)以估計圖像相似度,并通過一個來自GitHub倉庫的實際示例進行說明。

在機器學(xué)習(xí)領(lǐng)域,確定圖像之間的相似度在各種應(yīng)用中至關(guān)重要,從檢測重復(fù)項到面部識別。解決這個問題的一個強大方法是使用暹羅網(wǎng)絡(luò)結(jié)合三元組損失函數(shù)。在本文中,我們將探索如何構(gòu)建和訓(xùn)練暹羅網(wǎng)絡(luò)以估計圖像相似度,并通過一個來自GitHub倉庫的實際示例進行說明。

什么是暹羅網(wǎng)絡(luò)?

暹羅網(wǎng)絡(luò)是一種包含兩個或更多相同子網(wǎng)絡(luò)的神經(jīng)網(wǎng)絡(luò)架構(gòu)。這些子網(wǎng)絡(luò)旨在為每個輸入生成特征向量,然后可以比較這些向量以估計相似度。關(guān)鍵思想是使用相同的網(wǎng)絡(luò)處理每個輸入,確保輸出一致且可比較。

這種架構(gòu)特別適合于檢測重復(fù)項、尋找異常和面部識別等任務(wù)。在我們將要探索的實現(xiàn)中,網(wǎng)絡(luò)設(shè)置有三個相同的子網(wǎng)絡(luò)。每個網(wǎng)絡(luò)處理三張圖像中的一張:錨點圖像、正樣本(與錨點相似)和負(fù)樣本(與錨點無關(guān))。

什么是三元組損失?

為了有效地訓(xùn)練暹羅網(wǎng)絡(luò),我們使用三元組損失函數(shù)。這種損失函數(shù)鼓勵網(wǎng)絡(luò)在特征空間中拉近錨點和正樣本的距離,同時將錨點和負(fù)樣本推得更遠(yuǎn)。損失函數(shù)定義如下:

L(A, P, N) = max(‖f(A) — f(P)‖2 — ‖f(A) — f(N)‖2 + margin, 0)

這里,A是錨點圖像,P是正圖像,N是負(fù)圖像。函數(shù)f(x)代表網(wǎng)絡(luò)生成的embedding,而margin是一個小的正值,有助于確保網(wǎng)絡(luò)不會將所有嵌入壓縮到同一點。

設(shè)置暹羅網(wǎng)絡(luò)

在這次實現(xiàn)中,我們首先加載Totally Looks Like數(shù)據(jù)集,其中包含我們用來創(chuàng)建訓(xùn)練網(wǎng)絡(luò)的三元組圖像。

1. 數(shù)據(jù)準(zhǔn)備

使用TensorFlow的tf.data API處理數(shù)據(jù)集以創(chuàng)建圖像三元組。這涉及到設(shè)置一個數(shù)據(jù)管道,其中每個三元組由錨點、正樣本和負(fù)樣本圖像組成。通過調(diào)整圖像大小到目標(biāo)形狀并歸一化像素值來預(yù)處理圖像。

def preprocess_image(filename):
    image_string = tf.io.read_file(filename)
    image = tf.image.decode_jpeg(image_string, channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, target_shape)
    return image

def preprocess_triplets(anchor, positive, negative):
    return (
        preprocess_image(anchor),
        preprocess_image(positive),
        preprocess_image(negative),
    )

以下是從數(shù)據(jù)集中生成的三元組示例,每行的前兩張圖像相似(錨點和正樣本),第三張不同(負(fù)樣本):

圖1:在數(shù)據(jù)準(zhǔn)備期間生成的三元組。每行的前兩張圖像相似(錨點和正樣本),第三張不同(負(fù)樣本)

2.構(gòu)建 embedding 生成器

我們暹羅網(wǎng)絡(luò)的核心是嵌入生成器,它使用在ImageNet上預(yù)訓(xùn)練的ResNet50模型構(gòu)建。通過凍結(jié)ResNet50中的大部分層的權(quán)重,并且僅微調(diào)最后幾層,我們可以利用遷移學(xué)習(xí)來減少訓(xùn)練時間并提高性能。

base_cnn = resnet.ResNet50(
    weights="imagenet", input_shape=target_shape + (3,), include_top=False
)

flatten = layers.Flatten()(base_cnn.output)
dense1 = layers.Dense(512, activation="relu")(flatten)
dense1 = layers.BatchNormalization()(dense1)
dense2 = layers.Dense(256, activation="relu")(dense1)
dense2 = layers.BatchNormalization()(dense2)
output = layers.Dense(256)(dense2)

embedding = Model(base_cnn.input, output, name="Embedding")

# Freeze all layers until the layer conv5_block1_out
trainable = False
for layer in base_cnn.layers:
    if layer.name == "conv5_block1_out":
        trainable = True
    layer.trainable = trainable

3.構(gòu)建暹羅網(wǎng)絡(luò)

暹羅網(wǎng)絡(luò)設(shè)置為一次輸入三張圖像(錨點、正樣本和負(fù)樣本)。自定義的DistanceLayer計算錨點-正樣本對和錨點-負(fù)樣本對之間的距離。然后訓(xùn)練模型以最小化相似圖像之間的距離,并最大化不相似圖像之間的距離。

class DistanceLayer(layers.Layer):
    def call(self, anchor, positive, negative):
        ap_distance = tf.reduce_sum(tf.square(anchor - positive), -1)
        an_distance = tf.reduce_sum(tf.square(anchor - negative), -1)
        return (ap_distance, an_distance)

anchor_input = layers.Input(name="anchor", shape=target_shape + (3,))
positive_input = layers.Input(name="positive", shape=target_shape + (3,))
negative_input = layers.Input(name="negative", shape=target_shape + (3,))

distances = DistanceLayer()(
    embedding(resnet.preprocess_input(anchor_input)),
    embedding(resnet.preprocess_input(positive_input)),
    embedding(resnet.preprocess_input(negative_input)),
)

siamese_network = Model(
    inputs=[anchor_input, positive_input, negative_input], outputs=distances
)

4.訓(xùn)練和評估

模型使用自定義訓(xùn)練循環(huán)進行訓(xùn)練,其中計算三元組損失并用于更新網(wǎng)絡(luò)的權(quán)重。仔細(xì)監(jiān)控訓(xùn)練過程,并通過對學(xué)習(xí)到的嵌入進行檢查來評估模型的性能。

class SiameseModel(Model):
    def __init__(self, siamese_network, margin=0.5):
        super(SiameseModel, self).__init__()
        self.siamese_network = siamese_network
        self.margin = margin
        self.loss_tracker = metrics.Mean(name="loss")

    def train_step(self, data):
        with tf.GradientTape() as tape:
            loss = self._compute_loss(data)
        gradients = tape.gradient(loss, self.siamese_network.trainable_weights)
        self.optimizer.apply_gradients(
            zip(gradients, self.siamese_network.trainable_weights)
        )
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def _compute_loss(self, data):
        ap_distance, an_distance = self.siamese_network(data)
        loss = ap_distance - an_distance
        loss = tf.maximum(loss + self.margin, 0.0)
        return loss

5.檢查結(jié)果

訓(xùn)練完成后,我們可以通過比較錨點-正樣本對和錨點-負(fù)樣本對的嵌入之間的余弦相似度來評估網(wǎng)絡(luò)學(xué)習(xí)分離相似和不相似圖像的能力。

cosine_similarity = metrics.CosineSimilarity()

positive_similarity = cosine_similarity(anchor_embedding, positive_embedding)
print("Positive similarity:", positive_similarity.numpy())

negative_similarity = cosine_similarity(anchor_embedding, negative_embedding)
print("Negative similarity:", negative_similarity.numpy())

以下是經(jīng)過訓(xùn)練的模型評估的三元組示例。網(wǎng)絡(luò)成功識別出圖像之間的相似性和差異:

圖2:經(jīng)過訓(xùn)練的暹羅網(wǎng)絡(luò)的輸出,其中每行的前兩張圖像被模型識別為相似,第三張為不同

結(jié)論

本文展示了使用三元組損失的暹羅網(wǎng)絡(luò)如何有效地估計圖像相似度。通過使用預(yù)訓(xùn)練的ResNet50模型并微調(diào)其層,我們可以創(chuàng)建一個可以應(yīng)用于需要相似度估計的各種任務(wù)。

完整代碼和解釋,參考:https://github.com/elcaiseri/Siamese-Network

責(zé)任編輯:趙寧寧 來源: 小白玩轉(zhuǎn)Python
相關(guān)推薦

2023-12-10 15:15:18

開源模型工具

2024-06-24 13:06:04

2024-04-02 10:05:28

Siamese神經(jīng)網(wǎng)絡(luò)人工智能

2025-04-28 09:28:14

2020-10-14 10:18:05

Python三元表達式代碼

2023-11-21 16:06:04

計算機視覺人工智能

2021-10-19 10:09:21

三角形個數(shù)數(shù)組

2023-11-30 08:30:12

Python三元表達

2023-09-06 09:40:29

2025-05-06 09:41:06

2015-10-15 10:27:12

文本相似度判定

2018-10-08 08:00:00

前端ReactJavaScript

2023-09-07 08:05:32

三元表達式自動

2009-08-19 17:26:28

C# 操作符

2025-03-11 11:40:00

三元運算符代碼JavaScript

2016-01-12 17:21:54

金稅工程曙光

2017-07-28 11:31:20

交通三要素高德平臺

2009-05-19 09:57:16

次貸危機運維管理摩卡軟件

2025-01-14 13:51:44

2013-01-10 15:21:09

三元食品辦公自動化IBM
點贊
收藏

51CTO技術(shù)棧公眾號