ML - 李宏毅 (Hung-yi Lee) - Lecture 11 - Adaptation

Study notes for 李宏毅 (Hung-yi Lee)'s Machine Learning 2022 course.

Preparation

[Machine Learning 2021] An Overview of Domain Adaptation


Domain shift: the training set and the test set follow different distributions. The remedy: domain adaptation.

Transfer learning: ML Lecture 19: Transfer Learning - YouTube


Domain Shift comes in three forms:

  • The training and testing data have different input distributions. We call the domain of the training data the source domain and the domain of the testing data the target domain (this lecture considers only this case).
  • The training and testing data have different label distributions.
  • The training and testing data have different label sets.

Domain Adaptation

If we know something about the target domain: little but labeled (few samples, but correctly annotated).

  • Idea: train a model on the source data, then fine-tune the model on the target data

  • Challenge: only limited target data, so be careful about overfitting


The basic idea of domain adaptation: design a feature extractor that extracts features from the source domain and the target domain such that the extracted features share the same distribution.


Split the network into two parts: a Feature Extractor and a Label Predictor.


Following a GAN-like idea, design a Domain Classifier that performs binary classification on the features produced by the Feature Extractor, judging whether a feature comes from the source domain or the target domain. Its objective: $\theta^*_d=\min_{\theta_d}L_d$.

The Label Predictor, with objective $\theta^*_p=\min_{\theta_p}L$, still performs class prediction.

The Feature Extractor must both fool the Domain Classifier and extract informative features: $\theta^*_f=\min_{\theta_f}(L-L_d)$. A minimal gradient-reversal sketch follows.
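This min-max game is usually implemented with a gradient reversal layer (GRL): identity in the forward pass, and the gradient multiplied by a negative factor in the backward pass, so a single backward call trains the Domain Classifier on $L_d$ while pushing the Feature Extractor toward $L-L_d$. A minimal PyTorch sketch, assuming the class name GradReverse and the factor lamb (both illustrative, not from the notes):

python
import torch
from torch.autograd import Function

class GradReverse(Function):
    # Identity in the forward pass; flips (and scales) the gradient in the backward pass.
    @staticmethod
    def forward(ctx, x, lamb):
        ctx.lamb = lamb
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Multiply the incoming gradient by -lamb; lamb itself gets no gradient.
        return -ctx.lamb * grad_output, None

def grad_reverse(x, lamb=1.0):
    return GradReverse.apply(x, lamb)

# Usage sketch: domain_logits = domain_classifier(grad_reverse(feature, lamb)),
# so one backward pass on the domain loss updates the domain classifier
# normally while training the feature extractor to fool it.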


The earliest work on Domain Adversarial Training: [1409.7495] Unsupervised Domain Adaptation by Backpropagation (arxiv.org)


Suppose the samples fall into two classes. The labeled training set then clearly splits into two clusters; for the unlabeled test set, we want its distribution to be as close as possible to that of the training set, as shown in the figure on the right.


Extending this idea to the earlier handwriting-recognition example: feeding an image into the network yields a vector of per-class probabilities, and we want the test samples to lie as far from the decision boundary as possible. In other words, the output vector should concentrate its probability mass on one class rather than assigning similar probability to every class. A small sketch of this idea as a loss follows.
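One common way to encode this preference on unlabeled target data is an entropy-minimization term added to the training objective; the loss below is my illustration of that idea, not code from the course:

python
import torch
import torch.nn.functional as F

def entropy_loss(logits):
    # Low entropy means the prediction concentrates on one class,
    # i.e. the sample sits far from the decision boundary.
    p = F.softmax(logits, dim=1)
    return -(p * F.log_softmax(logits, dim=1)).sum(dim=1).mean()

# Usage sketch (beta is a hypothetical weight):
# total_loss = class_loss_on_source + beta * entropy_loss(target_logits)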


Summarizing by our knowledge of the target domain:


Research on Domain Generalization:

Training covers many domains, testing is a narrow domain: Domain Generalization with Adversarial Feature Learning | IEEE Conference Publication | IEEE Xplore

Training is a narrow domain, testing covers many domains: [1409.7495] Unsupervised Domain Adaptation by Backpropagation (arxiv.org)

ML Lecture 19-Transfer Learning


Transfer learning covers two situations:

  • Similar domain, different tasks
  • Different domains, same task

Applications of transfer learning: speech recognition, image recognition, text analysis.


An analogy between a graduate student and a manga artist:

  • graduate student → manga artist
  • advisor → editor-in-charge
  • running experiments → drawing storyboards
  • submitting to journals → submitting to Jump

For transfer learning, depending on whether the Source Data (not directly related to the task) and the Target Data are labeled, the strategies are:

| Target Data \ Source Data | labelled | unlabeled |
| --- | --- | --- |
| labelled | Fine-tuning; Multitask Learning | Self-taught learning: icml07-selftaughtlearning.pdf (stanford.edu) |
| unlabeled | Domain-adversarial training; Zero-shot learning | Self-taught Clustering: icml.dvi (machinelearning.org) |

Model Fine-tuning

  • Task description

    • Target data: $(x^t, y^t)$, very limited
    • Source data: $(x^s, y^s)$, plentiful
  • Example: (supervised) speaker adaptation

    • Target data: audio data and its transcriptions from a specific user

    • Source data: audio data and transcriptions from many speakers

  • Idea: train a model on the source data, then fine-tune the model on the target data

    • Challenge: only limited target data, so be careful about overfitting


Conservative Training:

  • First train a model on the source data.
  • Then, instead of simply taking that model as a pre-trained starting point and training a new model on the small amount of target data,
  • we add a regularization term (also called a constraint) so that, given the same input, the new model's output stays as close as possible to the old model's.
  • Why do this? The reason is simple: if we train the new model without this regularizer, then feeding the source data into the new model shows that it has broken down completely; it no longer retains the old model's performance on the source data. This is a frequent and important problem in ML: catastrophic forgetting. A minimal sketch of such a regularizer follows this list.
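A possible form of the regularizer keeps the fine-tuned model's outputs close to the frozen source model's outputs on the same inputs; the KL-based penalty and the weight alpha below are my assumptions for illustration, not the course's prescription:

python
import copy
import torch
import torch.nn.functional as F

def conservative_loss(new_model, old_model, x, y, alpha=0.5):
    # Task loss on the (small) target dataset ...
    logits = new_model(x)
    task_loss = F.cross_entropy(logits, y)
    # ... plus a penalty keeping the new outputs close to the old model's
    # outputs on the same inputs (guards against catastrophic forgetting).
    with torch.no_grad():
        old_logits = old_model(x)
    reg = F.kl_div(F.log_softmax(logits, dim=1),
                   F.softmax(old_logits, dim=1),
                   reduction='batchmean')
    return task_loss + alpha * reg

# old_model = copy.deepcopy(source_model); old_model.eval()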

Layer Transfer

  • 首先还是和 Conservative Training 一样,通过 Source data 去 train 一个 model
  • 然后将该 model 中某些层的 parameters 直接复制进去新的 model 中
  • 对于新 model 中那些没有得到 parameters 的 layer,我们固定其他层的参数,通过 Source data 对那些没有被 transfer 到 parameter 的 layer 进行训练
  • 最后,如果 target data 的数据量比较充足,那么我们就可以在对整个网络进行 fine-tuning 一下,可以进一步提升模型的性能。
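A sketch of layer transfer in PyTorch: copy selected layers from the source model into the target model, freeze them, and optimize only the rest; the architectures and the choice of which layer to copy are hypothetical:

python
import torch
import torch.nn as nn

# Hypothetical source/target models sharing one architecture.
source_model = nn.Sequential(nn.Linear(784, 512), nn.ReLU(),
                             nn.Linear(512, 512), nn.ReLU(),
                             nn.Linear(512, 10))
target_model = nn.Sequential(nn.Linear(784, 512), nn.ReLU(),
                             nn.Linear(512, 512), nn.ReLU(),
                             nn.Linear(512, 10))

# Copy the first layer's parameters from the source model and freeze them.
target_model[0].load_state_dict(source_model[0].state_dict())
for p in target_model[0].parameters():
    p.requires_grad = False

# Optimize only the layers that were not transferred.
trainable = [p for p in target_model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable)
# If the target data is plentiful, unfreeze everything afterwards and fine-tune end to end.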
  • For speech tasks, the last few layers are usually copied.

  • For image tasks, the first few layers are usually copied.


Research on Layer Transfer:


Multitask Learning

  • 再来回顾下 fine-tuning 的过程,在做 fine-tuning 的时候,我们更加关注的是 model 在 target domain 上做的好不好,至于在 source domain 上做的到底怎么样,哪怕是将 source data 输入进这个新的 model 中,model 都坏掉了,也不要紧。只要这个新的 model 在 target domain 上做的很出色就够了。
  • 而 multitask learning 和 fine-tuning 的过程就不同了,multitask 是说,不仅要求我们的最终 model 在 target domain 上表现的相当出色,而且在 source domain 上同样也要表现的相当出色。
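A common multitask setup shares the lower layers and gives each task its own head, optimizing both losses jointly; the module names, sizes, and the equal loss weighting below are my assumptions:

python
import torch
import torch.nn as nn

class MultitaskNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Shared layers must learn features useful for both tasks.
        self.shared = nn.Sequential(nn.Linear(784, 256), nn.ReLU())
        # Task-specific heads.
        self.head_a = nn.Linear(256, 10)  # e.g. the source task
        self.head_b = nn.Linear(256, 10)  # e.g. the target task

    def forward(self, x):
        h = self.shared(x)
        return self.head_a(h), self.head_b(h)

model = MultitaskNet()
criterion = nn.CrossEntropyLoss()
# Joint loss over a batch from each task:
# loss = criterion(logits_a, y_a) + criterion(logits_b, y_b)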

Research on multitask learning for multilingual machine translation: Multi-Task Learning for Multiple Language Translation (aclanthology.org)


Progressive Neural Networks

  • The method in this paper is still fairly new. First, train a model for task 1.
  • Then combine the output of layer i of the task-1 network with the output of layer i of the task-2 network to form the input of task 2's layer i+1.
  • More generally, with k networks, the k-th network draws on the information of the preceding k-1 networks. A rough sketch of the lateral connection follows this list.
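A minimal sketch of such a lateral connection, where a frozen task-1 column feeds its hidden activations into the task-2 column; the layer sizes and the simple additive adapter are my simplification of the paper:

python
import torch
import torch.nn as nn

class Column(nn.Module):
    # A one-hidden-layer "column" for task 1 (hypothetical sizes).
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(784, 256)
        self.l2 = nn.Linear(256, 10)

    def forward(self, x):
        return self.l2(torch.relu(self.l1(x)))

class Column2(nn.Module):
    # The task-2 column, with a lateral connection from the frozen task-1 column.
    def __init__(self, col1):
        super().__init__()
        self.col1 = col1
        for p in self.col1.parameters():
            p.requires_grad = False  # the task-1 column stays frozen
        self.l1 = nn.Linear(784, 256)
        self.adapter = nn.Linear(256, 256)  # maps col1's layer-1 output into this column
        self.l2 = nn.Linear(256, 10)

    def forward(self, x):
        h1_col1 = torch.relu(self.col1.l1(x))  # reuse task-1's layer-1 features
        h1 = torch.relu(self.l1(x))
        # Layer 2 sees its own layer-1 output plus the adapted task-1 output.
        return self.l2(h1 + self.adapter(h1_col1))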

When the source data and target data live in different domains (a mismatch):


Domain-adversarial training: design a domain classifier that forces the feature extractor to produce features with the same distribution across domains.


Domain-adversarial training

  • The first part, the green feature extractor, extracts features from the source and target data such that downstream classification built on those features achieves high accuracy, while also mixing the mismatched data together thoroughly enough that the domain classifier cannot correctly tell which domain a feature came from.
  • The second part, the blue label predictor, pushes classification accuracy as high as possible.
  • The third part, the red domain classifier, tries its best to separate the features produced by the feature extractor and assign each one to the domain it belongs to.

Of course, training this network is easy to describe; in practice, just as with GANs, it involves plenty of tricks.


Zero-shot learning: the source data and the target data are different tasks.


This is fairly common in NLP, where word embeddings can serve as the shared attribute space; a small sketch follows.
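A hedged sketch of embedding-based zero-shot classification: embed the image and every class name into a shared space and pick the nearest class; the encoder output and the embedding table below are placeholders:

python
import torch
import torch.nn.functional as F

def zero_shot_predict(image_feat, class_embeddings):
    """image_feat: (d,) embedding of an image from a trained encoder.
    class_embeddings: (num_classes, d) word embeddings of the class names,
    which may include classes never seen during training."""
    sims = F.cosine_similarity(image_feat.unsqueeze(0), class_embeddings, dim=1)
    return sims.argmax().item()  # index of the most similar class name

# Unseen classes can be predicted as long as their names have word embeddings.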


Experiments on zero-shot learning: [1312.5650v3] Zero-Shot Learning by Convex Combination of Semantic Embeddings (arxiv.org)


Self-taught learning

  • Learning to extract better representations from the source data (an unsupervised approach)
  • Extracting better representations for the target data

[Machine Learning 2022] Three stories of messing with the self-supervised model BERT


How versatile are self-supervised models?

  • Cross-lingual
  • Cross-discipline
  • Pre-training with artificial data

Cross-lingual


Cross-discipline


Using BERT for DNA classification.


Pre-training BERT with artificial data


hw11_domain_adaptation

Scenario, and why Domain Adversarial Training

We now have labeled source data and unlabeled target data, where the source data may be related to the target data. We want to train a model using only the source data and then test it on the target data.

What can go wrong if we do this? Having studied anomaly detection, we now know that if we test the model on data that never appeared in the source data, the trained model will most likely perform poorly, because it is unfamiliar with that data.

For example, suppose we have a model consisting of a Feature Extractor and a Classifier:

When the model is trained on the source data, the Feature Extractor extracts meaningful features because it is familiar with their distribution. As the figure below shows, the blue points (the distribution of the source data) have already gathered into distinct clusters, so the Classifier can predict labels based on these clusters.

However, when testing on the target data, the Feature Extractor cannot extract meaningful features that follow the source feature distribution, so the classifier learned for the source domain cannot be applied to the target domain.

Domain-Adversarial Training of Neural Networks (DaNN)

To address the problem above, DaNN builds a mapping between the source (training-time) and target (test-time) domains, so that the classifier learned for the source domain, composed with the learned mapping between domains, can also be applied to the target domain.

In DaNN, the authors add a domain classifier, a deep, discriminatively trained classifier inside the training framework, to distinguish data from different domains using the features produced by the feature extractor. As training proceeds, the approach promotes a domain classifier that discriminates between the source and target domains, together with a feature extractor that yields features discriminative for the main learning task on the source domain yet indiscriminate with respect to the shift between domains.

The feature extractor will likely outperform the domain classifier, since its input is generated by the feature extractor itself, and the tasks of domain classification and label classification are not in conflict.

This approach leads to the emergence of features that are domain-invariant and lie on the same feature distribution.

Dataset

Our task uses real photos as the source data and hand-drawn doodles as the target data.

We will train a model on the photos and their labels, and try to predict the labels of the hand-drawn doodles.

The data can be downloaded here. The code below downloads and visualizes the data.

Note: both the source and target data are balanced; you can make use of this information.

python
# Download dataset
!wget "https://github.com/redxouls/ml2020spring-hw11-dataset/releases/download/v1.0.0/real_or_drawing.zip" -O real_or_drawing.zip
 
# Download from mirrored dataset link
# !wget "https://github.com/redxouls/ml2020spring-hw11-dataset/releases/download/v1.0.1/real_or_drawing.zip" -O real_or_drawing.zip
# !wget "https://github.com/redxouls/ml2020spring-hw11-dataset/releases/download/v1.0.2/real_or_drawing.zip" -O real_or_drawing.zip
 
# Unzip the files
!unzip real_or_drawing.zip
  inflating: real_or_drawing/train_data/0/106.bmp  
  inflating: real_or_drawing/train_data/0/107.bmp  
  inflating: real_or_drawing/train_data/0/108.bmp  
  inflating: real_or_drawing/train_data/0/109.bmp  
  inflating: real_or_drawing/train_data/0/11.bmp  
...
python
import matplotlib.pyplot as plt
 
def no_axis_show(img, title='', cmap=None):
    # imshow, and set the interpolation mode to be "nearest".
    fig = plt.imshow(img, interpolation='nearest', cmap=cmap)
    # do not show the axes in the images.
    fig.axes.get_xaxis().set_visible(False)
    fig.axes.get_yaxis().set_visible(False)
    plt.title(title)
 
titles = ['horse', 'bed', 'clock', 'apple', 'cat', 'plane', 'television', 'dog', 'dolphin', 'spider']
plt.figure(figsize=(18, 18))
for i in range(10):
    plt.subplot(1, 10, i+1)
    fig = no_axis_show(plt.imread(f'real_or_drawing/train_data/{i}/{500*i}.bmp'), title=titles[i])
python
plt.figure(figsize=(18, 18))
for i in range(10):
    plt.subplot(1, 10, i+1)
    fig = no_axis_show(plt.imread(f'real_or_drawing/test_data/0/' + str(i).rjust(5, '0') + '.bmp'))

Special domain knowledge

When we doodle, we usually draw only the outlines, so we can apply edge detection to the source data to make it more similar to the target data.

Canny edge detection

An implementation of Canny edge detection follows. The algorithm itself is not described in detail here; if you are interested, refer to the wiki or here.

Only two parameters are needed to run Canny edge detection with cv2: low_threshold and high_threshold.

python
cv2.Canny(image, low_threshold, high_threshold)

Simply put, when the edge response exceeds high_threshold, the pixel is marked as an edge. If the response only exceeds low_threshold, the pixel is kept as an edge only if it is connected to a strong edge (hysteresis thresholding).

Let's apply it to the source data.

python
import cv2
import matplotlib.pyplot as plt
titles = ['horse', 'bed', 'clock', 'apple', 'cat', 'plane', 'television', 'dog', 'dolphin', 'spider']
plt.figure(figsize=(18, 18))
 
original_img = plt.imread(f'real_or_drawing/train_data/0/0.bmp')
plt.subplot(1, 5, 1)
no_axis_show(original_img, title='original')
 
gray_img = cv2.cvtColor(original_img, cv2.COLOR_RGB2GRAY)
plt.subplot(1, 5, 2)
no_axis_show(gray_img, title='gray scale', cmap='gray')
 
canny_50100 = cv2.Canny(gray_img, 50, 100)
plt.subplot(1, 5, 3)
no_axis_show(canny_50100, title='Canny(50, 100)', cmap='gray')
 
canny_150200 = cv2.Canny(gray_img, 150, 200)
plt.subplot(1, 5, 4)
no_axis_show(canny_150200, title='Canny(150, 200)', cmap='gray')
 
canny_250300 = cv2.Canny(gray_img, 250, 300)
plt.subplot(1, 5, 5)
no_axis_show(canny_250300, title='Canny(250, 300)', cmap='gray')
png

数据处理

The data is laid out for torchvision.ImageFolder, so you can create the datasets with torchvision.ImageFolder. For details of the image augmentation, see the comments in the code below.

python
import numpy as np
import torch  # tensor operations
import torch.nn as nn  # neural-network layers
import torch.nn.functional as F
from torch.autograd import Function  # custom autograd Functions
 
import torch.optim as optim  # optimizers
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
 
# source_transform runs Canny edge detection, then applies augmentations such as flipping and rotation.
source_transform = transforms.Compose([
    # Turn RGB to grayscale. (Because Canny does not support RGB images.)
    transforms.Grayscale(),
    # cv2 does not operate on PIL images, so we convert to np.array
    # and then apply the cv2.Canny algorithm.
    transforms.Lambda(lambda x: cv2.Canny(np.array(x), 170, 300)),
    # Transform the np.array back into a PIL image.
    transforms.ToPILImage(),
    # 50% horizontal flip (for augmentation).
    transforms.RandomHorizontalFlip(),
    # Rotate ±15 degrees (for augmentation); fill empty pixels
    # with zero after rotation.
    transforms.RandomRotation(15, fill=(0,)),
    # Transform to tensor for model inputs.
    transforms.ToTensor(),
])
 
# target_transform skips edge detection; instead it resizes the images (from 28x28 to 32x32) to match the source data.
target_transform = transforms.Compose([
    # Turn RGB to grayscale.
    transforms.Grayscale(),
    # Resize: the source data is 32x32, so we enlarge the
    # target data from 28x28 to 32x32.
    transforms.Resize((32, 32)),
    # 50% Horizontal Flip. (For Augmentation)
    transforms.RandomHorizontalFlip(),
    # Rotate +- 15 degrees. (For Augmentation), and filled with zero 
    # if there's empty pixel after rotation.
    transforms.RandomRotation(15, fill=(0,)),
    # Transform to tensor for model inputs.
    transforms.ToTensor(),
])
 
# ImageFolder loads the data from the given folders and applies source_transform / target_transform to the images.
source_dataset = ImageFolder('real_or_drawing/train_data', transform=source_transform)
target_dataset = ImageFolder('real_or_drawing/test_data', transform=target_transform)
 
# The DataLoaders load the data in batches.
# batch_size=32 means 32 images per batch; shuffle=True randomizes the order of the training set, which helps reduce overfitting.
# test_dataloader loads the test data with shuffle=False, so the test order is preserved.
source_dataloader = DataLoader(source_dataset, batch_size=32, shuffle=True)
target_dataloader = DataLoader(target_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(target_dataset, batch_size=128, shuffle=False)

Model

Feature Extractor: a classic VGG-like architecture.

Label predictor / domain classifier: linear models.

python
class FeatureExtractor(nn.Module):
 
    def __init__(self):
        super(FeatureExtractor, self).__init__()
 
        # FeatureExtractor is a CNN that extracts high-dimensional features from the input image. It stacks 5 convolutional blocks, each containing:
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, 3, 1, 1),  # a convolution (Conv2d) extracting spatial features;
            nn.BatchNorm2d(64),  # batch normalization (BatchNorm2d) to stabilize training;
            nn.ReLU(),  # a ReLU nonlinearity;
            nn.MaxPool2d(2),  # max pooling (MaxPool2d) to downsample the feature map.
 
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
 
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2),
 
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2),
 
            nn.Conv2d(256, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        
    def forward(self, x):
        x = self.conv(x).squeeze()
        return x
 
class LabelPredictor(nn.Module):
 
    def __init__(self):
        super(LabelPredictor, self).__init__()
 
        self.layer = nn.Sequential(
            nn.Linear(512, 512),  # the first two layers are fully connected, 512-dim in and 512-dim out, each followed by ReLU;
            nn.ReLU(),
 
            nn.Linear(512, 512),
            nn.ReLU(),
 
            nn.Linear(512, 10),  # the final fully connected layer outputs 10 logits for classification.
        )
 
    def forward(self, h):
        c = self.layer(h)
        return c
 
class DomainClassifier(nn.Module):
 
    def __init__(self):
        super(DomainClassifier, self).__init__()
 
        self.layer = nn.Sequential(
            nn.Linear(512, 512),  # each hidden layer is followed by batch normalization (BatchNorm1d) and ReLU.
            nn.BatchNorm1d(512),
            nn.ReLU(),
 
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
 
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
 
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
 
            nn.Linear(512, 1),  # the final layer is a single-neuron output (Linear(512, 1)) that predicts the domain label.
        )
 
    def forward(self, h):
        y = self.layer(h)
        return y

Preprocessing

Here we use Adam as our optimizer.

python
# Initialize the models and move them to the GPU
feature_extractor = FeatureExtractor().cuda()
label_predictor = LabelPredictor().cuda()
domain_classifier = DomainClassifier().cuda()
 
# Define the loss functions
class_criterion = nn.CrossEntropyLoss()  # cross-entropy loss for the classification task, computed between the LabelPredictor's class scores and the true labels.
domain_criterion = nn.BCEWithLogitsLoss()  # binary cross-entropy with logits (the sigmoid is handled internally) for the DomainClassifier's binary task: deciding which domain, i.e. which data distribution, a sample comes from.
 
# Each module gets its own Adam optimizer (optim.Adam), well suited to the many parameters and noisy gradients of deep models.
optimizer_F = optim.Adam(feature_extractor.parameters())
optimizer_C = optim.Adam(label_predictor.parameters())
optimizer_D = optim.Adam(domain_classifier.parameters())

Start training

DaNN implementation

The original paper uses a Gradient Reversal Layer, and the Feature Extractor, Label Predictor, and Domain Classifier are all trained simultaneously. In this code, we instead train the Domain Classifier first and then our Feature Extractor (the same concept as training the Discriminator and Generator in a GAN).

Reminders

  • The lambda controlling the domain-adversarial loss is adaptive in the original paper; you can refer to the paper. Here lambda is fixed at 0.1. A sketch of the paper's schedule follows this list.
  • We have no labels for the target data; you can only evaluate your model by uploading the results to Kaggle :)
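For reference, the schedule in the original paper anneals lambda from 0 to 1 as training progresses; a sketch of that formula (gamma = 10 as in the paper; the helper name is mine):

python
import numpy as np

def adaptive_lambda(progress, gamma=10.0):
    """progress: training progress in [0, 1] (e.g. epoch / total_epochs).
    Anneals lambda from 0 toward 1, as in the DaNN paper."""
    return 2.0 / (1.0 + np.exp(-gamma * progress)) - 1.0

# Usage sketch:
# train_epoch(source_dataloader, target_dataloader, lamb=adaptive_lambda(epoch / 200))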
python
def train_epoch(source_dataloader, target_dataloader, lamb):
    '''
      Args:
        source_dataloader: dataloader for the source data
        target_dataloader: dataloader for the target data
        lamb: controls the balance between domain adaptation and classification
    '''
 
    # D loss: the Domain Classifier's loss
    # F loss: the Feature Extractor & Label Predictor's loss
    # running_D_loss accumulates the domain classifier's loss;
    # running_F_loss accumulates the feature extractor's and label predictor's loss.
    running_D_loss, running_F_loss = 0.0, 0.0
    # total_hit and total_num track the classification accuracy on the source domain.
    total_hit, total_num = 0.0, 0.0
 
    for i, ((source_data, source_label), (target_data, _)) in enumerate(zip(source_dataloader, target_dataloader)):
        # Iterate over the source and target batches in lockstep: source_data / source_label are the source images and labels, target_data the target images.
        source_data = source_data.cuda()
        source_label = source_label.cuda()
        target_data = target_data.cuda()
        
        # Mix the source and target data; otherwise batch norm's running statistics
        # would be misled (the running mean/var of source and target data differ).
        mixed_data = torch.cat([source_data, target_data], dim=0)
        # Build domain labels for the mixed batch: 1 for source samples, 0 for target samples.
        domain_label = torch.zeros([source_data.shape[0] + target_data.shape[0], 1]).cuda()
        # set domain label of source data to be 1.
        domain_label[:source_data.shape[0]] = 1
 
        # Step 1: train the domain classifier.
        # Extract features for the mixed batch.
        feature = feature_extractor(mixed_data)
        # We don't need to train the feature extractor in step 1,
        # so we detach the feature tensor to avoid backpropagating into it.
        domain_logits = domain_classifier(feature.detach())
        # Compute the domain-classification loss from domain_label and the predicted domain_logits.
        loss = domain_criterion(domain_logits, domain_label)
        running_D_loss += loss.item()
        loss.backward()
        optimizer_D.step()
 
        # Step 2: train the feature extractor and label predictor.
        # class_logits are the label predictor's outputs on the source portion of the batch.
        class_logits = label_predictor(feature[:source_data.shape[0]])
        # domain_logits are recomputed, this time without detaching.
        domain_logits = domain_classifier(feature)
        # loss = classification cross-entropy - lamb * domain binary cross-entropy.
        # The subtraction plays the same role as the generator loss against the
        # discriminator in a GAN: the extractor is rewarded for fooling the domain classifier.
        loss = class_criterion(class_logits, source_label) - lamb * domain_criterion(domain_logits, domain_label)
        running_F_loss += loss.item()
        loss.backward()
        # Update feature_extractor and label_predictor, then zero all gradients
        # at the end of the batch and track the source-domain accuracy.
        optimizer_F.step()
        optimizer_C.step()
 
        optimizer_D.zero_grad()
        optimizer_F.zero_grad()
        optimizer_C.zero_grad()
 
        total_hit += torch.sum(torch.argmax(class_logits, dim=1) == source_label).item()
        total_num += source_data.shape[0]
        print(i, end='\r')
 
    return running_D_loss / (i+1), running_F_loss / (i+1), total_hit / total_num
 
# Train for 200 epochs.
# After each epoch, save the model parameters and print the training losses and accuracy.
for epoch in range(200):
  
    train_D_loss, train_F_loss, train_acc = train_epoch(source_dataloader, target_dataloader, lamb=0.1)
 
    torch.save(feature_extractor.state_dict(), f'extractor_model.bin')
    torch.save(label_predictor.state_dict(), f'predictor_model.bin')
 
    print('epoch {:>3d}: train D loss: {:6.4f}, train F loss: {:6.4f}, acc {:6.4f}'.format(epoch, train_D_loss, train_F_loss, train_acc))
epoch   0: train D loss: 0.6715, train F loss: 1.8669, acc 0.2928
epoch   1: train D loss: 0.6264, train F loss: 1.5707, acc 0.4166
epoch   2: train D loss: 0.5412, train F loss: 1.4445, acc 0.4794
epoch   3: train D loss: 0.5390, train F loss: 1.3692, acc 0.4992
epoch   4: train D loss: 0.5540, train F loss: 1.3243, acc 0.5140
epoch   5: train D loss: 0.5439, train F loss: 1.2459, acc 0.5480
epoch   6: train D loss: 0.5538, train F loss: 1.2264, acc 0.5482
epoch   7: train D loss: 0.5369, train F loss: 1.1544, acc 0.5800
epoch   8: train D loss: 0.5194, train F loss: 1.1397, acc 0.5838
epoch   9: train D loss: 0.5368, train F loss: 1.0921, acc 0.5950
epoch  10: train D loss: 0.5298, train F loss: 1.0657, acc 0.6070
epoch  11: train D loss: 0.5146, train F loss: 1.0287, acc 0.6186
epoch  12: train D loss: 0.5331, train F loss: 0.9963, acc 0.6338
epoch  13: train D loss: 0.5301, train F loss: 0.9842, acc 0.6412
epoch  14: train D loss: 0.5383, train F loss: 0.9447, acc 0.6488
epoch  15: train D loss: 0.5252, train F loss: 0.9263, acc 0.6560
epoch  16: train D loss: 0.5268, train F loss: 0.8820, acc 0.6748
epoch  17: train D loss: 0.5110, train F loss: 0.8503, acc 0.6848
epoch  18: train D loss: 0.4955, train F loss: 0.8061, acc 0.7070
epoch  19: train D loss: 0.5145, train F loss: 0.7806, acc 0.7096
epoch  20: train D loss: 0.4760, train F loss: 0.7562, acc 0.7194
epoch  21: train D loss: 0.4721, train F loss: 0.7087, acc 0.7350
epoch  22: train D loss: 0.4876, train F loss: 0.6906, acc 0.7458
epoch  23: train D loss: 0.4821, train F loss: 0.6563, acc 0.7580
epoch  24: train D loss: 0.4547, train F loss: 0.6063, acc 0.7780
epoch  25: train D loss: 0.4642, train F loss: 0.6035, acc 0.7788
epoch  26: train D loss: 0.4758, train F loss: 0.5768, acc 0.7826
epoch  27: train D loss: 0.4539, train F loss: 0.5465, acc 0.7956
epoch  28: train D loss: 0.4447, train F loss: 0.4864, acc 0.8144
epoch  29: train D loss: 0.4610, train F loss: 0.5191, acc 0.8064
epoch  30: train D loss: 0.4341, train F loss: 0.4504, acc 0.8372
epoch  31: train D loss: 0.4363, train F loss: 0.4291, acc 0.8380
epoch  32: train D loss: 0.4493, train F loss: 0.4082, acc 0.8508
epoch  33: train D loss: 0.4308, train F loss: 0.3958, acc 0.8506
epoch  34: train D loss: 0.4318, train F loss: 0.3513, acc 0.8658
epoch  35: train D loss: 0.4356, train F loss: 0.3378, acc 0.8708
epoch  36: train D loss: 0.3975, train F loss: 0.3467, acc 0.8684
epoch  37: train D loss: 0.4213, train F loss: 0.3099, acc 0.8794
epoch  38: train D loss: 0.3939, train F loss: 0.2874, acc 0.8900
epoch  39: train D loss: 0.4279, train F loss: 0.3113, acc 0.8826
epoch  40: train D loss: 0.4045, train F loss: 0.2726, acc 0.8916
epoch  41: train D loss: 0.4015, train F loss: 0.2682, acc 0.8974
epoch  42: train D loss: 0.3976, train F loss: 0.2458, acc 0.9062
epoch  43: train D loss: 0.4092, train F loss: 0.2502, acc 0.9026
epoch  44: train D loss: 0.3983, train F loss: 0.2196, acc 0.9120
epoch  45: train D loss: 0.3920, train F loss: 0.2242, acc 0.9158
epoch  46: train D loss: 0.4072, train F loss: 0.2050, acc 0.9168
epoch  47: train D loss: 0.3964, train F loss: 0.1852, acc 0.9272
epoch  48: train D loss: 0.4001, train F loss: 0.2130, acc 0.9172
epoch  49: train D loss: 0.3910, train F loss: 0.1914, acc 0.9248
epoch  50: train D loss: 0.3924, train F loss: 0.1978, acc 0.9228
epoch  51: train D loss: 0.3916, train F loss: 0.1758, acc 0.9262
epoch  52: train D loss: 0.3843, train F loss: 0.1651, acc 0.9314
epoch  53: train D loss: 0.3681, train F loss: 0.1555, acc 0.9352
epoch  54: train D loss: 0.3960, train F loss: 0.1557, acc 0.9320
epoch  55: train D loss: 0.3765, train F loss: 0.1543, acc 0.9356
epoch  56: train D loss: 0.3789, train F loss: 0.1420, acc 0.9406
epoch  57: train D loss: 0.3878, train F loss: 0.1423, acc 0.9418
epoch  58: train D loss: 0.3799, train F loss: 0.1477, acc 0.9396
epoch  59: train D loss: 0.3710, train F loss: 0.1316, acc 0.9450
epoch  60: train D loss: 0.3815, train F loss: 0.1294, acc 0.9456
epoch  61: train D loss: 0.3789, train F loss: 0.1300, acc 0.9466
epoch  62: train D loss: 0.3912, train F loss: 0.1273, acc 0.9472
epoch  63: train D loss: 0.4002, train F loss: 0.1206, acc 0.9492
epoch  64: train D loss: 0.3895, train F loss: 0.1332, acc 0.9432
epoch  65: train D loss: 0.3853, train F loss: 0.1152, acc 0.9518
epoch  66: train D loss: 0.3878, train F loss: 0.1420, acc 0.9424
epoch  67: train D loss: 0.3823, train F loss: 0.1158, acc 0.9478
epoch  68: train D loss: 0.3798, train F loss: 0.1131, acc 0.9514
epoch  69: train D loss: 0.3736, train F loss: 0.1022, acc 0.9508
epoch  70: train D loss: 0.3749, train F loss: 0.1215, acc 0.9498
epoch  71: train D loss: 0.3752, train F loss: 0.0972, acc 0.9572
epoch  72: train D loss: 0.3745, train F loss: 0.1077, acc 0.9558
epoch  73: train D loss: 0.3694, train F loss: 0.1041, acc 0.9562
epoch  74: train D loss: 0.3717, train F loss: 0.0976, acc 0.9534
epoch  75: train D loss: 0.3718, train F loss: 0.1092, acc 0.9552
epoch  76: train D loss: 0.3717, train F loss: 0.0744, acc 0.9648
epoch  77: train D loss: 0.3794, train F loss: 0.0861, acc 0.9590
epoch  78: train D loss: 0.3652, train F loss: 0.1077, acc 0.9586
epoch  79: train D loss: 0.3774, train F loss: 0.0617, acc 0.9674
epoch  80: train D loss: 0.3712, train F loss: 0.0974, acc 0.9582
epoch  81: train D loss: 0.3725, train F loss: 0.1011, acc 0.9546
epoch  82: train D loss: 0.3812, train F loss: 0.0931, acc 0.9596
epoch  83: train D loss: 0.3720, train F loss: 0.0634, acc 0.9668
epoch  84: train D loss: 0.3752, train F loss: 0.0738, acc 0.9666
epoch  85: train D loss: 0.3851, train F loss: 0.1143, acc 0.9536
epoch  86: train D loss: 0.3821, train F loss: 0.0813, acc 0.9618
epoch  87: train D loss: 0.3911, train F loss: 0.0735, acc 0.9648
epoch  88: train D loss: 0.3837, train F loss: 0.0832, acc 0.9604
epoch  89: train D loss: 0.3884, train F loss: 0.0757, acc 0.9624
epoch  90: train D loss: 0.3728, train F loss: 0.0761, acc 0.9640
epoch  91: train D loss: 0.3969, train F loss: 0.0718, acc 0.9632
epoch  92: train D loss: 0.3646, train F loss: 0.0668, acc 0.9632
epoch  93: train D loss: 0.3808, train F loss: 0.0756, acc 0.9662
epoch  94: train D loss: 0.3650, train F loss: 0.0818, acc 0.9628
epoch  95: train D loss: 0.3781, train F loss: 0.0610, acc 0.9682
epoch  96: train D loss: 0.3837, train F loss: 0.0587, acc 0.9684
epoch  97: train D loss: 0.3809, train F loss: 0.0591, acc 0.9680
epoch  98: train D loss: 0.3714, train F loss: 0.0626, acc 0.9670
epoch  99: train D loss: 0.3909, train F loss: 0.0753, acc 0.9632
epoch 100: train D loss: 0.3641, train F loss: 0.0607, acc 0.9696
epoch 101: train D loss: 0.3730, train F loss: 0.0853, acc 0.9612
epoch 102: train D loss: 0.3746, train F loss: 0.0511, acc 0.9706
epoch 103: train D loss: 0.3831, train F loss: 0.0493, acc 0.9700
epoch 104: train D loss: 0.3882, train F loss: 0.0751, acc 0.9622
epoch 105: train D loss: 0.3777, train F loss: 0.0508, acc 0.9726
epoch 106: train D loss: 0.3702, train F loss: 0.0462, acc 0.9732
epoch 107: train D loss: 0.3694, train F loss: 0.0542, acc 0.9734
epoch 108: train D loss: 0.3700, train F loss: 0.0520, acc 0.9712
epoch 109: train D loss: 0.3596, train F loss: 0.0439, acc 0.9738
epoch 110: train D loss: 0.3681, train F loss: 0.0544, acc 0.9688
epoch 111: train D loss: 0.3840, train F loss: 0.0592, acc 0.9674
epoch 112: train D loss: 0.3770, train F loss: 0.0624, acc 0.9682
epoch 113: train D loss: 0.3644, train F loss: 0.0531, acc 0.9720
epoch 114: train D loss: 0.3787, train F loss: 0.0566, acc 0.9712
epoch 115: train D loss: 0.3720, train F loss: 0.0429, acc 0.9746
epoch 116: train D loss: 0.3768, train F loss: 0.0489, acc 0.9732
epoch 117: train D loss: 0.3765, train F loss: 0.0412, acc 0.9748
epoch 118: train D loss: 0.3820, train F loss: 0.0450, acc 0.9724
epoch 119: train D loss: 0.3735, train F loss: 0.0386, acc 0.9768
epoch 120: train D loss: 0.3774, train F loss: 0.0436, acc 0.9736
epoch 121: train D loss: 0.3816, train F loss: 0.0491, acc 0.9708
epoch 122: train D loss: 0.3717, train F loss: 0.0587, acc 0.9686
epoch 123: train D loss: 0.3802, train F loss: 0.0538, acc 0.9714
epoch 124: train D loss: 0.3878, train F loss: 0.0432, acc 0.9762
epoch 125: train D loss: 0.3785, train F loss: 0.0453, acc 0.9746
epoch 126: train D loss: 0.3749, train F loss: 0.0423, acc 0.9774
epoch 127: train D loss: 0.3925, train F loss: 0.0328, acc 0.9766
epoch 128: train D loss: 0.3874, train F loss: 0.0546, acc 0.9682
epoch 129: train D loss: 0.3843, train F loss: 0.0482, acc 0.9712
epoch 130: train D loss: 0.3698, train F loss: 0.0500, acc 0.9736
epoch 131: train D loss: 0.3752, train F loss: 0.0368, acc 0.9762
epoch 132: train D loss: 0.3818, train F loss: 0.0303, acc 0.9784
epoch 133: train D loss: 0.3838, train F loss: 0.0490, acc 0.9722
epoch 134: train D loss: 0.3744, train F loss: 0.0332, acc 0.9792
epoch 135: train D loss: 0.3743, train F loss: 0.0311, acc 0.9786
epoch 136: train D loss: 0.3838, train F loss: 0.0419, acc 0.9728
epoch 137: train D loss: 0.3951, train F loss: 0.0352, acc 0.9760
epoch 138: train D loss: 0.3878, train F loss: 0.0439, acc 0.9732
epoch 139: train D loss: 0.3879, train F loss: 0.0419, acc 0.9736
epoch 140: train D loss: 0.3871, train F loss: 0.0355, acc 0.9758
epoch 141: train D loss: 0.3819, train F loss: 0.0392, acc 0.9746
epoch 142: train D loss: 0.3905, train F loss: 0.0578, acc 0.9722
epoch 143: train D loss: 0.3816, train F loss: 0.0350, acc 0.9758
epoch 144: train D loss: 0.3899, train F loss: 0.0175, acc 0.9822
epoch 145: train D loss: 0.4025, train F loss: 0.0469, acc 0.9748
epoch 146: train D loss: 0.3715, train F loss: 0.0345, acc 0.9748
epoch 147: train D loss: 0.3841, train F loss: 0.0375, acc 0.9744
epoch 148: train D loss: 0.3833, train F loss: 0.0310, acc 0.9802
epoch 149: train D loss: 0.3805, train F loss: 0.0263, acc 0.9764
epoch 150: train D loss: 0.3763, train F loss: 0.0352, acc 0.9760
epoch 151: train D loss: 0.3861, train F loss: 0.0330, acc 0.9778
epoch 152: train D loss: 0.3844, train F loss: 0.0340, acc 0.9764
epoch 153: train D loss: 0.3902, train F loss: 0.0311, acc 0.9764
epoch 154: train D loss: 0.3782, train F loss: 0.0387, acc 0.9760
epoch 155: train D loss: 0.3950, train F loss: 0.0180, acc 0.9808
epoch 156: train D loss: 0.4017, train F loss: 0.0205, acc 0.9808
epoch 157: train D loss: 0.3952, train F loss: 0.0484, acc 0.9734
epoch 158: train D loss: 0.3885, train F loss: 0.0346, acc 0.9776
epoch 159: train D loss: 0.3916, train F loss: 0.0202, acc 0.9812
epoch 160: train D loss: 0.3980, train F loss: 0.0306, acc 0.9774
epoch 161: train D loss: 0.3897, train F loss: 0.0306, acc 0.9800
epoch 162: train D loss: 0.3909, train F loss: 0.0164, acc 0.9816
epoch 163: train D loss: 0.3911, train F loss: 0.0273, acc 0.9806
epoch 164: train D loss: 0.3737, train F loss: 0.0133, acc 0.9830
epoch 165: train D loss: 0.4064, train F loss: 0.0520, acc 0.9706
epoch 166: train D loss: 0.3951, train F loss: 0.0242, acc 0.9810
epoch 167: train D loss: 0.3865, train F loss: 0.0287, acc 0.9810
epoch 168: train D loss: 0.3921, train F loss: 0.0141, acc 0.9814
epoch 169: train D loss: 0.3862, train F loss: 0.0130, acc 0.9836
epoch 170: train D loss: 0.4018, train F loss: 0.0273, acc 0.9764
epoch 171: train D loss: 0.4053, train F loss: 0.0254, acc 0.9774
epoch 172: train D loss: 0.4040, train F loss: 0.0169, acc 0.9810
epoch 173: train D loss: 0.3935, train F loss: 0.0463, acc 0.9734
epoch 174: train D loss: 0.3991, train F loss: 0.0199, acc 0.9804
epoch 175: train D loss: 0.3919, train F loss: 0.0275, acc 0.9800
epoch 176: train D loss: 0.4021, train F loss: 0.0315, acc 0.9780
epoch 177: train D loss: 0.3856, train F loss: 0.0289, acc 0.9796
epoch 178: train D loss: 0.3880, train F loss: 0.0171, acc 0.9812
epoch 179: train D loss: 0.3874, train F loss: 0.0200, acc 0.9824
epoch 180: train D loss: 0.3974, train F loss: 0.0243, acc 0.9826
epoch 181: train D loss: 0.3981, train F loss: 0.0191, acc 0.9812
epoch 182: train D loss: 0.4048, train F loss: 0.0159, acc 0.9822
epoch 183: train D loss: 0.3929, train F loss: 0.0212, acc 0.9796
epoch 184: train D loss: 0.3944, train F loss: 0.0130, acc 0.9822
epoch 185: train D loss: 0.3895, train F loss: 0.0402, acc 0.9752
epoch 186: train D loss: 0.3849, train F loss: 0.0136, acc 0.9826
epoch 187: train D loss: 0.3791, train F loss: 0.0222, acc 0.9814
epoch 188: train D loss: 0.3990, train F loss: 0.0190, acc 0.9812
epoch 189: train D loss: 0.3964, train F loss: 0.0317, acc 0.9794
epoch 190: train D loss: 0.3935, train F loss: 0.0385, acc 0.9788
epoch 191: train D loss: 0.3914, train F loss: 0.0218, acc 0.9812
epoch 192: train D loss: 0.3764, train F loss: 0.0212, acc 0.9822
epoch 193: train D loss: 0.3782, train F loss: 0.0193, acc 0.9836
epoch 194: train D loss: 0.3787, train F loss: 0.0111, acc 0.9832
epoch 195: train D loss: 0.4000, train F loss: 0.0239, acc 0.9808
epoch 196: train D loss: 0.3830, train F loss: 0.0201, acc 0.9836
epoch 197: train D loss: 0.4085, train F loss: 0.0230, acc 0.9802
epoch 198: train D loss: 0.3908, train F loss: 0.0197, acc 0.9802
epoch 199: train D loss: 0.3981, train F loss: 0.0170, acc 0.9820

Plotting the training curves:

python
import re
import matplotlib.pyplot as plt
 
# the logged training output, pasted as a string
data = """
epoch   0: train D loss: 0.6715, train F loss: 1.8669, acc 0.2928
epoch   1: train D loss: 0.6264, train F loss: 1.5707, acc 0.4166
epoch   2: train D loss: 0.5412, train F loss: 1.4445, acc 0.4794
epoch   3: train D loss: 0.5390, train F loss: 1.3692, acc 0.4992
...
"""
 
# extract the numbers with a regular expression
pattern = r"epoch\s+(\d+): train D loss: ([\d.]+), train F loss: ([\d.]+), acc ([\d.]+)"
matches = re.findall(pattern, data)
 
# convert the matches into numeric lists
epochs = []
d_losses = []
f_losses = []
accuracies = []
 
for match in matches:
    epoch, d_loss, f_loss, acc = map(float, match)
    epochs.append(int(epoch))
    d_losses.append(d_loss)
    f_losses.append(f_loss)
    accuracies.append(acc)
 
# plot the curves
plt.figure(figsize=(10, 6))
 
# plot D loss and F loss
plt.subplot(2, 1, 1)
plt.plot(epochs, d_losses, label='D Loss', color='blue')
plt.plot(epochs, f_losses, label='F Loss', color='orange')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Train Losses Over Epochs')
plt.legend()
 
# plot the accuracy
plt.subplot(2, 1, 2)
plt.plot(epochs, accuracies, label='Accuracy', color='green')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Accuracy Over Epochs')
plt.legend()
 
plt.tight_layout()
plt.show()

Inference

We use pandas to generate the csv file.

By the way, the performance of a model trained for 200 epochs may be unstable; you can train for more epochs to get more stable performance.

python
# Prepare the result list and set the models to evaluation mode
result = []
label_predictor.eval()
feature_extractor.eval()
 
# Loop over the test data and make predictions
for i, (test_data, _) in enumerate(test_dataloader):
    test_data = test_data.cuda()
 
    class_logits = label_predictor(feature_extractor(test_data))
 
    x = torch.argmax(class_logits, dim=1).cpu().detach().numpy()
    result.append(x)
 
import pandas as pd
result = np.concatenate(result)
 
# Generate your submission
# assemble the predictions into a submission dataframe
df = pd.DataFrame({'id': np.arange(0,len(result)), 'label': result})
df.to_csv('DaNN_submission.csv',index=False)

Visualization

We use a t-SNE plot to inspect the distribution of the extracted features.

python
import numpy as np
import matplotlib.pyplot as plt
from sklearn import manifold
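The notes break off here. As a hedged sketch of what such a visualization might look like, one could gather features from one batch per domain and project them with t-SNE; the feature-collection and styling choices below are my assumptions, not the original notebook's code:

python
# Gather features for one batch from each domain, project to 2-D
# with t-SNE, and color the points by domain.
feature_extractor.eval()
feats, domains = [], []
with torch.no_grad():
    for (x, _), d in [(next(iter(source_dataloader)), 0),
                      (next(iter(test_dataloader)), 1)]:
        feats.append(feature_extractor(x.cuda()).cpu().numpy())
        domains.append(np.full(len(x), d))
feats = np.concatenate(feats)
domains = np.concatenate(domains)

tsne = manifold.TSNE(n_components=2, init='pca', random_state=0)
emb = tsne.fit_transform(feats)  # (N, 2) embedding
plt.scatter(emb[:, 0], emb[:, 1], c=domains, cmap='coolwarm', s=5)
plt.title('t-SNE of extracted features (source vs. target)')
plt.show()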