Skip to content

HuggingFace Datasets

简介

HuggingFace Datasets 提供统一的数据集加载、处理、格式转换接口,是 LLM 微调数据准备的标准工具。

bash
pip install datasets

加载数据集

python
from datasets import load_dataset, Dataset, DatasetDict

# Load a named dataset (with a specific config) from the HuggingFace Hub.
dataset = load_dataset("financial_phrasebank", "sentences_allagree")

# Load from local files; data_files may be a single path or a split->path mapping.
# NOTE: each assignment below rebinds `dataset` — these are alternative
# loading methods shown one after another, not cumulative steps.
dataset = load_dataset("json", data_files="finance_qa.jsonl")
dataset = load_dataset("csv", data_files={"train": "train.csv", "test": "test.csv"})

# Build a Dataset directly from an in-memory list of record dicts.
data = [
    {"instruction": "解释P/E比率", "output": "P/E比率是..."},
    {"instruction": "什么是不良贷款", "output": "不良贷款是..."},
]
dataset = Dataset.from_list(data)

数据处理

python
# map 转换
def format_for_training(example):
    """Merge the instruction/output fields of one example into a single
    prompt string stored under the "text" key (for use with Dataset.map)."""
    prompt = f"### 问题:{example['instruction']}\n### 回答:{example['output']}"
    return {"text": prompt}

# map(): apply the transform to every example; remove_columns drops the
# source columns so only the new "text" column remains.
processed = dataset.map(format_for_training, remove_columns=["instruction", "output"])

# filter(): keep only examples whose answer is longer than 50 characters.
filtered = dataset.filter(lambda x: len(x["output"]) > 50)

# Deterministic 90/10 train/test split (fixed seed makes the shuffle reproducible).
split = dataset.train_test_split(test_size=0.1, seed=42)
train_data = split["train"]
test_data = split["test"]

# Tokenize
def tokenize(examples, tokenizer, max_length=2048):
    """Tokenize the "text" column of a batch, truncating and padding every
    sequence to exactly *max_length* tokens."""
    texts = examples["text"]
    encoded = tokenizer(
        texts,
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    return encoded

# batched=True passes a dict of column lists (a whole batch) to the lambda
# instead of one example at a time; the raw "text" column is dropped after
# tokenization.
# NOTE(review): assumes a `tokenizer` object is defined elsewhere in the
# surrounding context (e.g. loaded via transformers) — confirm before running.
tokenized = processed.map(
    lambda x: tokenize(x, tokenizer),
    batched=True,
    remove_columns=["text"]
)

保存与加载

python
# 保存到磁盘
# Persist the dataset to a local directory.
dataset.save_to_disk("./finance_dataset")

# Reload a dataset previously written with save_to_disk.
from datasets import load_from_disk
dataset = load_from_disk("./finance_dataset")

# Publish the dataset to the HuggingFace Hub (requires authentication).
dataset.push_to_hub("your-username/finance-qa-dataset")

数据增强

python
from openai import OpenAI

# OpenAI-compatible client pointed at a custom endpoint.
# NOTE(review): "sk-xxx" / "..." are placeholders — load real credentials from
# an environment variable; never hard-code an API key in source.
client = OpenAI(api_key="sk-xxx", base_url="...")

def augment_qa_pair(question: str) -> dict:
    """Ask the LLM to answer *question* and package the pair as one
    instruction/output training example.

    Uses the module-level ``client`` with a fixed model; network call —
    returns whatever text the model produces.
    """
    user_message = {
        "role": "user",
        "content": f"用专业金融语言回答:{question}",
    }
    completion = client.chat.completions.create(
        model="qwen-turbo",
        messages=[user_message],
    )
    answer = completion.choices[0].message.content
    return {"instruction": question, "output": answer}

# 批量生成
# Generate an answer for each seed question (one LLM call per question)
# and wrap the resulting QA pairs as a new Dataset.
questions = ["什么是资本充足率?", "如何计算净息差?", "什么是系统性风险?"]
augmented_data = [augment_qa_pair(q) for q in questions]
new_dataset = Dataset.from_list(augmented_data)

本站内容由 褚成志 整理编写,仅供学习参考