前言
该项目是自动文本摘要论文的一个副产物,就不做过多介绍了,Python代码并不复杂,主要是阅读OpenAI官方的API文档,并理解每个接口参数的含义,以及挑选一个合适的过墙工具,毕竟国内被狠狠制裁了,能顺利跑下来就算胜利。
使用前请先确认脚本中用到的文件夹(如 rewrite 输出目录)已经建立,否则写文件时会报错!
具体代码
主函数逻辑 chatgpt_rewrite
主函数实现逻辑,API调用方法,使用gpt-3.5-turbo模型:
import os
import re
import sys
import time

import openai
# API key: read from the OPENAI_API_KEY environment variable so it never has to be
# pasted into (and committed with) the script; falls back to "" -- the original
# placeholder -- when the variable is unset.
openai.api_key = os.environ.get("OPENAI_API_KEY", "")
# Findings to be rewritten, one per line.
findings_path = "G:\\A\\Desktop\\CoNT Work\\ChatGPT_Aug-main\\second_for_chatgpt\\second_for_chatgpt_findings.txt"
# Impressions; line i must correspond to line i of findings_path.
impression_path = "G:\\A\\Desktop\\CoNT Work\\ChatGPT_Aug-main\\second_for_chatgpt\\second_for_chatgpt_impression.txt"
# Output: the ChatGPT rewrites, one sentence per line.
rewrite_path = "G:\\A\\Desktop\\CoNT Work\\ChatGPT_Aug-main\\rewrite\\rewrite_findings.txt"
# Output: human-readable log pairing each prompt with its impression and rewrites.
rewrite_check_path = "G:\\A\\Desktop\\CoNT Work\\ChatGPT_Aug-main\\rewrite\\rewrite_findings_check.txt"
def read_txt(txt_path, encoding=None):
    """Read a text file and return its lines without the trailing newline.

    Args:
        txt_path: Path of the file to read.
        encoding: Text encoding passed to ``open``. The default ``None`` keeps
            the platform's locale encoding, exactly as before.

    Returns:
        A list with one string per line; blank lines are kept as "".
    """
    # ``with`` guarantees the handle is closed; the original never closed it.
    with open(txt_path, encoding=encoding) as txtfile:
        return [line.rstrip("\n") for line in txtfile]
def chatgpt_completion(model_new="gpt-3.5-turbo", prompt_new="hi", temperature_new=1, top_p_new=1, n_new=1,
                       max_tokens_new=100, presence_penalty_new=0, frequency_penalty_new=0):
    """Send *prompt_new* as a single user message and return the raw API response.

    Uses ``openai.ChatCompletion.create``, the module-level interface of
    openai-python < 1.0 (it was removed in 1.x, so pin ``openai<1`` to run
    this as written).

    Args:
        model_new: Chat model name.
        prompt_new: Content of the single "user" message.
        temperature_new: Forwarded as ``temperature``.
        top_p_new: Forwarded as ``top_p``.
        n_new: Forwarded as ``n`` (number of choices).
        max_tokens_new: Forwarded as ``max_tokens``.
        presence_penalty_new: Forwarded as ``presence_penalty``; the default 0
            matches the value that used to be hard-coded.
        frequency_penalty_new: Forwarded as ``frequency_penalty``; the default 0
            matches the value that used to be hard-coded.

    Returns:
        The response object; callers read ``.choices[i].message.content``.
    """
    return openai.ChatCompletion.create(
        model=model_new,
        messages=[
            {"role": "user", "content": prompt_new}
        ],
        temperature=temperature_new,
        top_p=top_p_new,
        n=n_new,
        max_tokens=max_tokens_new,
        presence_penalty=presence_penalty_new,
        frequency_penalty=frequency_penalty_new,
    )
if __name__ == '__main__':
findings = read_txt(findings_path)
impression = read_txt(impression_path)
if os.path.isfile(rewrite_path): # 如果原先有生成的文本就先删除
os.remove(rewrite_path)
if os.path.isfile(rewrite_check_path): # 如果原先有生成的文本就先删除
os.remove(rewrite_check_path)
for i in range(len(findings)):
prompt = "give me 3 similar sentences like this:\n" + findings[i] # 即输入到messages的content里的内容
completion = chatgpt_completion(prompt_new=prompt, max_tokens_new=400)
rewrite_finding = ""
for line in completion.choices[0].message.content.splitlines():
if line != "":
sentence = line.replace(")", ".").split(". "[1], 1)[1]
rewrite_finding = rewrite_finding + sentence + "\n"
with open(rewrite_path, "a") as f:
f.write(rewrite_finding)
with open(rewrite_check_path, "a") as f:
f.write("-----------第" + str(i + 1) + "个-----------\n")
f.write("impression:" + impression[i]+" \n\n")
f.write(prompt+"\n\n")
f.write(rewrite_finding)
print("-----------第" + str(i + 1) + "个-----------\n")
print("impression:" + impression[i]+" \n\n")
print(prompt+"\n\n")
print(rewrite_finding)
time.sleep(30) # 国内测试10-15s的请求间隔以上可以稳定请求100次以上
数据集处理 txt2jsonl
使用的原始数据集为 OPEN-I(以 XML 格式为主)。下面这段代码把 ChatGPT 重写得到的 txt 文本逐行与对应的 impression 配对,转换成 jsonl 格式({"source": 重写的 findings, "target": impression}),可供参考:
import os
import json
# Input: ChatGPT-rewritten findings produced by the rewrite script, one per line.
rewrite_txt_path = "G:\\A\\Desktop\\CoNT Work\\ChatGPT_Aug-main\\rewrite\\rewrite_findings.txt"
# Input: impressions, one per original finding. The folder name is lower-case to
# match the rewrite script's paths (the old "second_for_ChatGPT" only resolved on
# case-insensitive filesystems).
impression_txt_path = "G:\\A\\Desktop\\CoNT Work\\ChatGPT_Aug-main\\second_for_chatgpt\\second_for_chatgpt_impression.txt"
# Output: the processed jsonl dataset.
rewrite_jsonl_path = "G:\\A\\Desktop\\CoNT Work\\ChatGPT_Aug-main\\rewrite\\rewrite_findings_all.jsonl"
def read_txt(txt_path, encoding=None):
    """Read a text file and return its lines without the trailing newline.

    Args:
        txt_path: Path of the file to read.
        encoding: Text encoding passed to ``open``. The default ``None`` keeps
            the platform's locale encoding, exactly as before.

    Returns:
        A list with one string per line; blank lines are kept as "".
    """
    # ``with`` guarantees the handle is closed; the original never closed it.
    with open(txt_path, encoding=encoding) as txtfile:
        return [line.rstrip("\n") for line in txtfile]
# 开始
if os.path.isfile(rewrite_jsonl_path):
os.remove(rewrite_jsonl_path)
rewrite_findings = read_txt(rewrite_txt_path)
impressions = read_txt(impression_txt_path)
count = 0
for rewrite_finding in rewrite_findings:
rewrite_jsonl = {'source': rewrite_finding, 'target': impressions[int(count/5)]}
print("第"+str(int(count/5))+"个")
print(rewrite_jsonl)
with open(rewrite_jsonl_path, "a") as f:
json.dump(rewrite_jsonl, f)
f.write('\n')
count = count + 1
绘制loss曲线
import matplotlib.pyplot as plt
import os
imgPath = 'loss/img'
nll_loss = []
cl_loss = []
all_loss = []
nll_loss_smp = []
cl_loss_smp = []
all_loss_smp = []
nll_loss_stp = []
cl_loss_stp = []
all_loss_stp = []
nll_loss_txt = open('loss/nll_loss.txt')
cl_loss_txt = open('loss/cl_loss.txt')
for line in nll_loss_txt:
nll_loss.append(line.strip("\n"))
for line in cl_loss_txt:
cl_loss.append(line.strip("\n"))
for item in range(0, len(nll_loss) - 1, 100):
nll_loss_smp.append(float(nll_loss[item]))
nll_loss_stp.append(item)
for item in range(0, len(cl_loss) - 1, 200):
cl_loss_smp.append(float(cl_loss[item]))
cl_loss_stp.append(item)
# print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
# epoch_losses.append(epoch_loss)
# plt.plot(nll_loss_stp,nll_loss_smp,'g-',label=u'Dense_Unet(block layer=5)')
# ‘’g‘’代表“green”,表示画出的曲线是绿色,“-”代表画的曲线是实线,可自行选择,label代表的是图例的名称,一般要在名称前面加一个u,如果名称是中文,会显示不出来,目前还不知道怎么解决。
plt.figure(1)
p1 = plt.plot(nll_loss_stp, nll_loss_smp, 'b-', label=u'nll_loss')
plt.legend()
plt.xlabel(u'iters')
plt.ylabel(u'loss')
plt.title('loss for nll in training')
plt.savefig(os.path.join(imgPath, "nll_loss.png"))
plt.figure(2)
p2 = plt.plot(cl_loss_stp, cl_loss_smp, 'g-', label=u'cl_loss')
plt.legend()
plt.xlabel(u'iters')
plt.ylabel(u'loss')
plt.title('loss for contrastive learning in training')
plt.savefig(os.path.join(imgPath, "cl_loss.png"))
效果如下图所示:
Comments NOTHING