前言

2024 年更新:项目代码已经迁移到 Gitee 上,方便各位取用:
测试版本
正式版本

该项目是自动文本摘要论文的一个副产物,就不做过多介绍了,Python代码并不复杂,主要是阅读OpenAI官方的API文档,并理解每个接口参数的含义,以及挑选一个合适的翻墙工具,毕竟国内被狠狠制裁了,能顺利跑下来就算胜利。

使用前注意相关文件夹是否已经建立,否则可能报错!

文件夹结构

具体代码

主函数逻辑 chatgpt_rewrite

主函数实现逻辑,API调用方法,使用gpt-3.5-turbo模型:

import os
import openai
import time
import sys

# OpenAI API key — must be filled in before running; an empty string will fail authentication.
openai.api_key = ""

findings_path = "G:\\A\\Desktop\\CoNT Work\\ChatGPT_Aug-main\\second_for_chatgpt\\second_for_chatgpt_findings.txt"  # path to the findings file (text to be rewritten)
impression_path = "G:\\A\\Desktop\\CoNT Work\\ChatGPT_Aug-main\\second_for_chatgpt\\second_for_chatgpt_impression.txt"  # path to the impressions file (must pair 1:1 with findings)
rewrite_path = "G:\\A\\Desktop\\CoNT Work\\ChatGPT_Aug-main\\rewrite\\rewrite_findings.txt"  # output path for the ChatGPT-rewritten text
rewrite_check_path = "G:\\A\\Desktop\\CoNT Work\\ChatGPT_Aug-main\\rewrite\\rewrite_findings_check.txt"  # human-readable log pairing each prompt with its output

def read_txt(txt_path):
    """Read a text file and return its lines as a list of strings.

    Args:
        txt_path: Path of the text file to read.

    Returns:
        List of lines with trailing newlines stripped.
    """
    # Context manager guarantees the handle is closed; the original version
    # opened the file and never closed it.
    with open(txt_path) as txtfile:
        return [line.strip("\n") for line in txtfile]

def chatgpt_completion(model_new="gpt-3.5-turbo", prompt_new="hi", temperature_new=1, top_p_new=1, n_new=1,
                       max_tokens_new=100):
    """Send a single-turn chat request to the OpenAI ChatCompletion API.

    The prompt is wrapped as one user message; sampling parameters are
    forwarded unchanged and both penalties are fixed at 0.

    Returns:
        The raw ChatCompletion response object.
    """
    request = {
        "model": model_new,
        "messages": [{"role": "user", "content": prompt_new}],
        "temperature": temperature_new,
        "top_p": top_p_new,
        "n": n_new,
        "max_tokens": max_tokens_new,
        "presence_penalty": 0,
        "frequency_penalty": 0,
    }
    return openai.ChatCompletion.create(**request)

if __name__ == '__main__':
    findings = read_txt(findings_path)
    impression = read_txt(impression_path)

    # Remove stale output from a previous run so the append-mode writes
    # below always start from empty files.
    if os.path.isfile(rewrite_path):
        os.remove(rewrite_path)
    if os.path.isfile(rewrite_check_path):
        os.remove(rewrite_check_path)
    for i in range(len(findings)):
        # This string becomes the `content` of the user message.
        prompt = "give me 3 similar sentences like this:\n" + findings[i]
        completion = chatgpt_completion(prompt_new=prompt, max_tokens_new=400)
        rewrite_finding = ""

        for line in completion.choices[0].message.content.splitlines():
            if line != "":
                # Strip the leading enumeration marker ("1." or "1)") that the
                # model prepends.  BUGFIX: the original `.split(". "[1], 1)[1]`
                # indexed the separator string by mistake (". "[1] == " ") and
                # raised IndexError on any line without that separator; we
                # split on the intended ". " and keep the whole line when no
                # marker is present.
                normalized = line.replace(")", ".")
                _, sep, tail = normalized.partition(". ")
                sentence = tail if sep else normalized
                rewrite_finding = rewrite_finding + sentence + "\n"

        with open(rewrite_path, "a") as f:
            f.write(rewrite_finding)
        with open(rewrite_check_path, "a") as f:
            f.write("-----------第" + str(i + 1) + "个-----------\n")
            f.write("impression:" + impression[i]+" \n\n")
            f.write(prompt+"\n\n")
            f.write(rewrite_finding)

        print("-----------第" + str(i + 1) + "个-----------\n")
        print("impression:" + impression[i]+" \n\n")
        print(prompt+"\n\n")
        print(rewrite_finding)

        # Throttle requests: >=10-15 s spacing was measured to sustain 100+
        # consecutive calls from inside China.
        time.sleep(30)

数据集处理 txt2jsonl

使用的数据集主要为OPEN-I,以xml格式为主,需处理成jsonl格式,这部分代码可供参考:

import os
import json

rewrite_txt_path = "G:\\A\\Desktop\\CoNT Work\\ChatGPT_Aug-main\\rewrite\\rewrite_findings.txt"  # input: rewritten findings produced by the script above
impression_txt_path = "G:\\A\\Desktop\\CoNT Work\\ChatGPT_Aug-main\\second_for_ChatGPT\\second_for_chatgpt_impression.txt"  # NOTE(review): folder case differs from the first script ("second_for_chatgpt") — harmless on Windows, verify on case-sensitive filesystems
rewrite_jsonl_path = "G:\\A\\Desktop\\CoNT Work\\ChatGPT_Aug-main\\rewrite\\rewrite_findings_all.jsonl"  # output: dataset converted to jsonl

def read_txt(txt_path):
    """Read a text file and return its lines as a list of strings.

    Args:
        txt_path: Path of the text file to read.

    Returns:
        List of lines with trailing newlines stripped.
    """
    # Context manager guarantees the handle is closed; the original version
    # opened the file and never closed it.
    with open(txt_path) as txtfile:
        return [line.strip("\n") for line in txtfile]

# Start of the conversion: drop any stale output, then emit one JSON line
# per rewritten finding.
if os.path.isfile(rewrite_jsonl_path):
    os.remove(rewrite_jsonl_path)

rewrite_findings = read_txt(rewrite_txt_path)
impressions = read_txt(impression_txt_path)

# Every consecutive group of 5 rewritten findings shares one impression as
# its target (presumably 5 rewrite lines were produced per original
# finding — TODO confirm this matches the generation script's output).
# The file is opened once instead of being reopened on every iteration;
# the resulting file contents are identical.
with open(rewrite_jsonl_path, "a") as f:
    for count, rewrite_finding in enumerate(rewrite_findings):
        rewrite_jsonl = {'source': rewrite_finding, 'target': impressions[count // 5]}
        print("第" + str(count // 5) + "个")
        print(rewrite_jsonl)
        json.dump(rewrite_jsonl, f)
        f.write('\n')

绘制loss曲线

import matplotlib.pyplot as plt
import os

# Directory the loss-curve images are written to (must already exist).
imgPath = 'loss/img'

nll_loss = []      # raw NLL loss values read from file (one string per line)
cl_loss = []       # raw contrastive loss values read from file
nll_loss_smp = []  # subsampled NLL loss values (floats)
cl_loss_smp = []   # subsampled contrastive loss values (floats)
nll_loss_stp = []  # iteration indices of the sampled NLL values
cl_loss_stp = []   # iteration indices of the sampled CL values

# Context managers replace the leaked file handles of the original version;
# the unused all_loss* lists were removed.
with open('loss/nll_loss.txt') as nll_loss_txt:
    for line in nll_loss_txt:
        nll_loss.append(line.strip("\n"))
with open('loss/cl_loss.txt') as cl_loss_txt:
    for line in cl_loss_txt:
        cl_loss.append(line.strip("\n"))

# Subsample so the plots stay readable: every 100th NLL value and every
# 200th contrastive value.
for item in range(0, len(nll_loss) - 1, 100):
    nll_loss_smp.append(float(nll_loss[item]))
    nll_loss_stp.append(item)
for item in range(0, len(cl_loss) - 1, 200):
    cl_loss_smp.append(float(cl_loss[item]))
    cl_loss_stp.append(item)

# Figure 1: NLL loss over training iterations ('b-' = solid blue line).
plt.figure(1)
p1 = plt.plot(nll_loss_stp, nll_loss_smp, 'b-', label=u'nll_loss')
plt.legend()
plt.xlabel(u'iters')
plt.ylabel(u'loss')
plt.title('loss for nll in training')
plt.savefig(os.path.join(imgPath, "nll_loss.png"))

# Figure 2: contrastive-learning loss ('g-' = solid green line).
plt.figure(2)
p2 = plt.plot(cl_loss_stp, cl_loss_smp, 'g-', label=u'cl_loss')
plt.legend()
plt.xlabel(u'iters')
plt.ylabel(u'loss')
plt.title('loss for contrastive learning in training')
plt.savefig(os.path.join(imgPath, "cl_loss.png"))

效果如下图所示:

nll-loss
cl-loss