使用 LLaMA-Factory 实现对大模型函数调用功能

节前,我们星球组织了一场算法岗技术&面试讨论会,邀请了一些互联网大厂朋友、参加社招和校招面试的同学。

针对算法岗技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何准备、面试常考点分享等热门话题进行了深入的讨论。

合集:

《大模型面试宝典》(2024版) 正式发布!


大模型函数调用(function calling)功能,能让大模型调用成千上万的工具API,赋予大模型更多的外部知识,使得大模型能力变得更加强大。

本文将会介绍如何使用 LLaMa-Factory 这个大模型微调框架,对 Qwen1.5-4B 模型进行微调,实现 function calling 功能,使得大模型具有工具调用能力。

如何微调大模型的function calling能力?

首先,训练数据集是关键,我们在这里使用Glaive AI生成的工具调用数据集,也可以在HuggingFace找到function calling相关的数据集,该数据集包含用户(human)、模型(gpt)、工具调用(function_call)和工具调用结果(observation)四种不同角色,以及工具列表(tools)字段。

同时,我们还选择了alpaca_gpt4_en、alpaca_gpt4_zh 和 oaast_sft_zh这三种数据集,以增强大模型的通用对话能力。

其中一条样本为:

{
  "conversations": [
    {
      "from": "human",
      "value": "I saw a dress that I liked. It was originally priced at $200 but it's on sale for 20% off. Can you tell me how much it will cost after the discount?"
    },
    {
      "from": "function_call",
      "value": "{\"name\": \"calculate_discount\", \"arguments\": {\"original_price\": 200, \"discount_percentage\": 20}}"
    },
    {
      "from": "observation",
      "value": "{\"discounted_price\": 160}"
    },
    {
      "from": "gpt",
      "value": "The dress will cost you $160 after the 20% discount."
    }
  ],
  "tools": "[{\"name\": \"calculate_discount\", \"description\": \"Calculate the discounted price\", \"parameters\": {\"type\": \"object\", \"properties\": {\"original_price\": {\"type\": \"number\", \"description\": \"The original price of the item\"}, \"discount_percentage\": {\"type\": \"number\", \"description\": \"The percentage of discount\"}}, \"required\": [\"original_price\", \"discount_percentage\"]}}]"
}

其加工成对话样本后的格式如下:

<|im_start|>system
You are a helpful assistant.You have access to the following tools:
> Tool Name: calculate_discount
Tool Description: Calculate the discounted price
Tool Args:
  - original_price (number, required): The original price of the item
  - discount_percentage (number, required): The percentage of discount

Use the following format if using a tool:
Action: tool name (one of [calculate_discount]).```
Action Input: the input to the tool, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```).```
<|im_end|>
<|im_start|>user
I saw a dress that I liked. It was originally priced at $200 but it's on sale for 20% off. Can you tell me how much it will cost after the discount?<|im_end|>
<|im_start|>assistant
Action: calculate_discount
Action Input: {"original_price": 200, "discount_percentage": 20}<|im_end|>
<|im_start|>user
{"discounted_price": 160}<|im_end|>
<|im_start|>assistant
The dress will cost you $160 after the 20% discount.<|im_end|>

微调的基座模型选择Qwen1.5-4B,每个数据集最大样本量为50000,训练2轮,训练命令如下:

python src/train_bash.py     \
--stage sft     \
--do_train True     \
--model_name_or_path /models/Qwen1.5-4B     \
--finetuning_type lora     \
--template qwen     \
--dataset_dir data     \
--dataset glaive_toolcall,alpaca_gpt4_en,alpaca_gpt4_zh,oaast_sft_zh     \
--cutoff_len 1024     \
--learning_rate 5e-05     \
--num_train_epochs 2.0     \
--max_samples 50000     \
--per_device_train_batch_size 2     \
--gradient_accumulation_steps 4     \
--lr_scheduler_type cosine     \
--max_grad_norm 1.0     \
--logging_steps 100     \
--save_steps 1000     \
--warmup_steps 0     \
--optim adamw_torch     \
--report_to none     \
--output_dir saves/Qwen1.5-4B/lora/train_2024-04-20-15-30-29     \
--fp16 True     \
--lora_rank 8     \
--lora_alpha 16     \
--lora_dropout 0.1     \
--lora_target all     \
--plot_loss True

在笔者的GPU上大约训练了14个小时(同时还在运行其它任务)。训练完后,将lora部分的参数与原始模型进行合并,形成新的训练后的模型(Qwen1.5-4B-agent),此时,新模型已经具有了function calling的调用能力。

测试微调后的大模型的function calling

我们来测试下训练后的大模型的function calling的能力。模型服务的部署命令如下:

python -m llmtuner.api.app --model_name_or_path /models/Qwen1.5-4B-agent --template qwen

笔者找了三个API工具来进行测试,它们的作用分别为生活垃圾分类,动漫信息查询,歌曲信息查询,API具体的入参、出参可以参考网址为:https://apifox.com/apidoc/shared-faff130e-7aa3-42da-9f93-574b16c8acda。

测试脚本如下:

# -*- coding: utf-8 -*-
# @place: Pudong, Shanghai
# @file: llama_factory_agent_test.py
import os
import json
from openai import OpenAI
from typing import Sequence
import requests

os.environ["OPENAI_BASE_URL"] = "http://localhost:50079/v1"
os.environ["OPENAI_API_KEY"] = "0"


def get_rubbish_category(keyword):
    url = f"https://api.timelessq.com/garbage?keyword={keyword}"
    response = requests.request("GET", url)
    output_str_list = []
    for item in response.json()['data']:
        output_str_list.append(f"{item['name']}: {item['categroy']}")
    return '\n'.join(output_str_list)


def get_song_information(keyword):
    url = f"https://api.timelessq.com/music/tencent/search?keyword={keyword}"
    response = requests.request("GET", url)
    song_infor = response.json()['data']['list'][0]
    singer = '' if not song_infor['singer'] else song_infor['singer'][0]['name']
    return f"歌曲: {keyword}\n歌手: {singer}\n时长: {song_infor['interval']}秒\n专辑名称: {song_infor['albumname']}"


def get_cartoon_information(title):
    url = f"https://api.timelessq.com/bangumi?title={title}"
    response = requests.request("GET", url)
    data = response.json()['data'][0]
    return f"标题: {data['title']}\n类型:{data['type']}\n语言:{data['lang']}\n出品方:{data['officialSite']}\n上映时间:{data['begin']}\n完结事件:{data['end']}"


tool_map = {"get_rubbish_category": get_rubbish_category,
            "get_song_information": get_song_information,
            "get_cartoon_information": get_cartoon_information}


if __name__ == "__main__":
    client = OpenAI()
    tools = [
        {
            "type": "function",
            "function": {
                            "name": "get_rubbish_category",
                            "description": "适用于生活垃圾分类时,判断物品属于哪种类型的垃圾?",
                            "parameters": {
                                "type": "object",
                                "properties": {
                                    "keyword": {
                                        "type": "string",
                                        "description": "物品名称,用于垃圾分类",
                                    },
                                },
                                "required": ["keyword"],
                            }
                        }
        },
        {
            "type": "function",
            "function": {
                            "name": "get_cartoon_information",
                            "description": "根据用户提供的动漫标题,查询该动漫的相关信息。",
                            "parameters": {
                                "type": "object",
                                "properties": {
                                    "title": {
                                        "type": "string",
                                        "description": "动漫",
                                    },
                                },
                                "required": ["title"],
                            }
                        }
        },
        {
            "type": "function",
            "function": {
                            "name": "get_song_information",
                            "description": "根据用户提供的歌曲名称,查询歌曲相关信息,包括歌手、时长、专辑名称等。",
                            "parameters": {
                                "type": "object",
                                "properties": {
                                    "keyword": {
                                        "type": "string",
                                        "description": "歌曲名称",
                                    },
                                },
                                "required": ["keyword"],
                            }
                        }
        }
    ]

    messages = []
    messages.append({"role": "system", "content": "你是一个有用的小助手,请调用下面的工具来回答用户的问题,参考工具输出进行回答。"})
    # messages.append({"role": "user", "content": "鸡蛋壳属于哪种类型的垃圾?"})
    # messages.append({"role": "user", "content": "爱在西元前是谁唱的,来自哪张专辑?"})
    messages.append({"role": "user", "content": "动漫《棋魂》是哪个国家的,什么时候上映的?"})
    result = client.chat.completions.create(messages=messages, model="Qwen1.5-4B-agent", tools=tools)
    tool_call = result.choices[0].message.tool_calls[0].function
    print(tool_call)
    name, arguments = tool_call.name, json.loads(tool_call.arguments)
    messages.append({"role": "function", "content": json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)})
    tool_result = tool_map[name](**arguments)
    messages.append({"role": "tool", "content": "工具输出结果为: " + tool_result})
    for msg in messages:
        print('--->', msg)
    result = client.chat.completions.create(messages=messages, model="Qwen1.5-4B-agent")
    print("Answer: ", result.choices[0].message.content)

测试结果如下:

  • 问题: 鸡蛋壳属于哪种类型的垃圾?

输出:

---> {'role': 'system', 'content': '你是一个有用的小助手,请调用下面的工具来回答用户的问题,参考工具输出进行回答。'}
---> {'role': 'user', 'content': '鸡蛋壳属于哪种类型的垃圾?'}
---> {'role': 'function', 'content': '{"name": "get_rubbish_category", "argument": {"keyword": "鸡蛋壳"}}'}
---> {'role': 'tool', 'content': '工具输出结果为: 熟鸡蛋壳: 湿垃圾\n生鸡蛋壳: 湿垃圾\n鸡蛋壳: 湿垃圾\n包裹着鸡蛋壳的餐巾纸: 干垃圾'}
Answer:  鸡蛋壳属于湿垃圾。
  • 问题: 爱在西元前是谁唱的,来自哪张专辑?

输出:

---> {'role': 'system', 'content': '你是一个有用的小助手,请调用下面的工具来回答用户的问题,参考工具输出进行回答。'}
---> {'role': 'user', 'content': '爱在西元前是谁唱的,来自哪张专辑?'}
---> {'role': 'function', 'content': '{"name": "get_song_information", "argument": {"keyword": "爱在西元前"}}'}
---> {'role': 'tool', 'content': '工具输出结果为: 歌曲: 爱在西元前\n歌手: 周杰伦\n时长: 234秒\n专辑名称: 范特西'}
Answer:  歌曲《爱在西元前》的演唱者是周杰伦,来自专辑《范特西》。
  • 问题: 动漫《棋魂》是哪个国家的,什么时候上映的?

输出:

---> {'role': 'system', 'content': '你是一个有用的小助手,请调用下面的工具来回答用户的问题,参考工具输出进行回答。'}
---> {'role': 'user', 'content': '动漫《棋魂》是哪个国家的,什么时候上映的?'}
---> {'role': 'function', 'content': '{"name": "get_cartoon_information", "argument": {"title": "棋魂"}}'}
---> {'role': 'tool', 'content': '工具输出结果为: 标题: ヒカルの碁\n类型:tv\n语言:ja\n出品方:http://www.tv-tokyo.co.jp/anime/hikaru/\n上映时间:2001-10-10T10:27:00.000Z\n完结事件:2003-03-26T10:55:00.000Z'}
Answer:  动漫《棋魂》是日本的,它于2001年10月10日上映。

总结

OpenAI模型的function calling能力无疑是让人惊讶的,但自己实现大模型的function calling能力也是值得开心的。

本文重点介绍了如何使用 LLaMa-Factory 微调框架来自己实现 function calling 能力,并在测试中验证了大模型的工具调用能力。

相关推荐

  1. 使用 LLaMA-Factory 实现模型函数调用功能

    2024-06-09 07:44:01       3 阅读
  2. 使用 LLaMA Factory 微调 Llama-3 中文对话模型

    2024-06-09 07:44:01       12 阅读

最近更新

  1. 利用STM32F103驱动舵机的指南(使用HAL库)

    2024-06-09 07:44:01       0 阅读
  2. web前端defer:深度解析与实用指南

    2024-06-09 07:44:01       0 阅读
  3. c语言如何进行文件错误检查

    2024-06-09 07:44:01       0 阅读
  4. C语言从头学18——类型的自动转换

    2024-06-09 07:44:01       0 阅读
  5. OpenCV之cv::Scalar

    2024-06-09 07:44:01       0 阅读
  6. python显示神经网络训练时的1batch数据

    2024-06-09 07:44:01       0 阅读
  7. 运维监控系统

    2024-06-09 07:44:01       0 阅读
  8. ArrayList<Integer>()转为int[]的几种方式

    2024-06-09 07:44:01       0 阅读

热门阅读

  1. 二叉树----7-3 列出叶结点

    2024-06-09 07:44:01       3 阅读
  2. bat指令踩坑记录

    2024-06-09 07:44:01       3 阅读
  3. Web Dart前端:探索、挑战与未来展望

    2024-06-09 07:44:01       3 阅读
  4. 计算机视觉中的low-level与 high-level任务

    2024-06-09 07:44:01       5 阅读
  5. python记录之字符串

    2024-06-09 07:44:01       4 阅读
  6. Playwright 这个强大的自动化测试工具

    2024-06-09 07:44:01       3 阅读
  7. 安装 hbase(伪分布式)

    2024-06-09 07:44:01       3 阅读
  8. 密码学基本概念

    2024-06-09 07:44:01       3 阅读
  9. Python为项目中添加上彩色日志

    2024-06-09 07:44:01       4 阅读
  10. perl use HTTP::Server::Simple 轻量级 http server

    2024-06-09 07:44:01       2 阅读
  11. 面试 Redis 八股文十问十答第二期

    2024-06-09 07:44:01       2 阅读
  12. ASP.NET Core 中使用基本消息的 RabbitMQ 消费者

    2024-06-09 07:44:01       3 阅读