pcl_loss | Prototypical Verbalizer for Prompt-based Few-shot Tuning | Prototypical Contrastive Learning #322

@Charon-HN

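I read your paper, Prototypical Verbalizer for Prompt-based Few-shot Tuning, and found it very good, but I was curious how it is implemented in code. I first needed to locate the program's entry point; after searching for a while, it appears to be experiments/cli.py. That file reads a configuration file, and the configuration for prototype learning is experiments/classification_proto_verbalizer.yaml.

I eventually found the pcl_loss computation in openprompt/prompts/prototypical_verbalizer.py. According to the paper, it consists of two losses. Reading the code against my own understanding of the paper left me with some questions. Here is the pcl_loss computation: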

def pcl_loss(self, v_ins):
    # instance-prototype loss
    sim_mat = torch.exp(self.sim(v_ins, self.proto))
    num = sim_mat.shape[1]
    loss = 0.
    for i in range(num):
        pos_score = torch.diag(sim_mat[:,i,:])
        neg_score = (sim_mat[:,i,:].sum(1) - pos_score)
        loss += - torch.log(pos_score / (pos_score + neg_score)).sum()
    loss = loss / (num * self.num_classes * self.num_classes)

    # instance-instance loss
    loss_ins = 0.
    for i in range(v_ins.shape[0]):
        sim_instance = torch.exp(self.sim(v_ins, v_ins[i]))
        pos_ins = sim_instance[i]
        neg_ins = (sim_instance.sum(0) - pos_ins).sum(0)
        loss_ins += - torch.log(pos_ins / (pos_ins + neg_ins)).sum()
    loss_ins = loss_ins / (num * self.num_classes * num * self.num_classes)
    loss = loss + loss_ins

    return loss

I am confused about the instance-prototype loss in particular, so I added some notes to the code:

# instance-prototype loss
sim_mat = torch.exp(self.sim(v_ins, self.proto)) # shape should be (num_classes, batch_size, num_classes)
num = sim_mat.shape[1] # get batch_size
loss = 0.
for i in range(num):
    pos_score = torch.diag(sim_mat[:,i,:])
    neg_score = (sim_mat[:,i,:].sum(1) - pos_score)
    loss += - torch.log(pos_score / (pos_score + neg_score)).sum()
loss = loss / (num * self.num_classes * self.num_classes)
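The shape annotations above can be sanity-checked on a toy tensor. The sketch below is not the OpenPrompt code: `cosine_sim` is a hypothetical stand-in for `self.sim` (assumed here to be a cosine similarity), `proto` stands in for `self.proto`, and the layout in which `v_ins[c, i]` is the i-th instance of class `c` is an assumption.

```python
import torch
import torch.nn.functional as F


def cosine_sim(x, y):
    # Hypothetical stand-in for self.sim: cosine similarity between the
    # vectors on the last axis of x and every row of y.
    return F.normalize(x, dim=-1) @ F.normalize(y, dim=-1).T


torch.manual_seed(0)
num_classes, shots, dim = 3, 4, 8

# Assumed layout: v_ins[c, i] is the i-th instance of class c.
v_ins = torch.randn(num_classes, shots, dim)
proto = torch.randn(num_classes, dim)  # one prototype per class

sim_mat = torch.exp(cosine_sim(v_ins, proto))
sim_shape = tuple(sim_mat.shape)             # (num_classes, shots, num_classes)
slice_shape = tuple(sim_mat[:, 0, :].shape)  # (num_classes, num_classes)

# Row c of sim_mat[:, 0, :] compares instance 0 of class c with all prototypes,
# so the diagonal picks that instance's similarity to its own class prototype.
pos_score = torch.diag(sim_mat[:, 0, :])
own_class = torch.stack([sim_mat[c, 0, c] for c in range(num_classes)])

print(sim_shape, slice_shape, torch.allclose(pos_score, own_class))
```

Under these assumptions, the second axis of `sim_mat` would be the number of instances per class rather than a flat batch size, and `torch.diag` would select only the own-class term for each instance.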

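Questions

  1. Doesn't this implementation simply reduce to loss += - torch.log(pos_score / (sim_mat[:,i,:].sum(1))).sum()?
  2. pos_score = torch.diag(sim_mat[:,i,:]) seems to give every sample's loss on every class. But in the formula, $\mathbf{v}_i^n$ denotes that the representation $\mathbf{v}_i$ belongs to class $n$, so shouldn't only the class that sample i belongs to be considered?
  3. Where does the exp from the paper show up in the code?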
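The first and third questions can be probed numerically. The following is a minimal sketch under stated assumptions, not the library's implementation: `logits` is a random matrix standing in for `self.sim(...)[:, i, :]`, and `block` stands in for `sim_mat[:, i, :]`. It compares the loss as written, the simplified form from question 1, and a softmax cross-entropy whose target for row `c` is class `c`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes = 3

# Stand-ins (assumptions): raw similarities for one slice, and their exponent.
logits = torch.randn(num_classes, num_classes)
block = torch.exp(logits)  # plays the role of sim_mat[:, i, :]

# Loss exactly as written in pcl_loss.
pos_score = torch.diag(block)
neg_score = block.sum(1) - pos_score
loss_as_written = -torch.log(pos_score / (pos_score + neg_score)).sum()

# Simplified form from question 1: pos + neg is just the row sum.
loss_simplified = -torch.log(pos_score / block.sum(1)).sum()

# Softmax cross-entropy on the raw similarities, with row c labelled as class c.
labels = torch.arange(num_classes)
loss_ce = F.cross_entropy(logits, labels, reduction="sum")

print(loss_as_written.item(), loss_simplified.item(), loss_ce.item())
```

If the three values agree, the rewrite in question 1 is algebraically the same loss, and the exp in question 3 would correspond to the torch.exp wrapped around self.sim when sim_mat is built, which is the exponent inside a softmax.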
