LLAMA3 + groq + tool-use
(llama3 结合 groq + tool-use 微调模型使用)[https://wow.groq.com/introducing-llama-3-groq-tool-use-models/]
安装依赖
sh
# 安装依赖
pip install groq
pip install tavily-python # 使用 tavily 实现联网搜索功能
定义环境变量
python
# 定义 groq 和 tavily_api 的 key
os.environ["GROQ_API_KEY"] = ''
os.environ["TAVILY_API_KEY"] = ''
实际执行
python
# 导入所需的库
from groq import Groq # 导入Groq API客户端
import json # 用于JSON数据处理
import os # 用于环境变量操作
import pprint
# 初始化Groq客户端,使用环境变量中的API密钥
client = Groq(
api_key = os.environ.get("GROQ_API_KEY"),
)
# 定义使用的模型名称
MODEL = 'llama3-groq-70b-8192-tool-use-preview'
def calculate(expression):
"""计算数学表达式"""
try:
# 使用eval函数评估表达式
result = eval(expression)
# 返回JSON格式的结果
return json.dumps({"result": result})
except:
# 如果计算出错,返回错误信息
return json.dumps({"error": "Invalid expression"})
def run_conversation(user_prompt):
# 定义对话的消息列表
messages=[
{
"role": "system",
"content": "你是一个计算器助手。使用计算函数执行数学运算并提供结果."
},
{
"role": "user",
"content": user_prompt,
}
]
# 定义可用的工具(函数)
tools = [
{
"type": "function",
"function": {
"name": "calculate",
"description": "计算数学表达式",
"parameters": {
"type": "object",
"properties": {
"expression": {
"type": "string",
"description": "要评估的数学表达式",
}
},
"required": ["expression"],
},
},
}
]
print('第一次信息输出 \n')
print(messages)
print('\n')
# 发送第一次请求到Groq API
# 作用和目的:
# 初始化对话:将用户的问题发送给 AI 模型。
# 提供工具信息:告诉模型可以使用哪些工具(在这里是 calculate 函数)。
# 获取模型的初步响应:模型可能会直接回答,或者决定使用提供的工具。
# 特点:
# 包含了初始的对话历史(系统提示和用户问题)。
# 提供了 tools 参数,定义了可用的函数。
# 使用 tool_choice="auto",允许模型自主决定是否使用工具。
response = client.chat.completions.create(
model=MODEL,
messages=messages,
tools=tools,
tool_choice="auto",
max_tokens=4096
)
print('\n')
print('输出response \n')
print(response)
print('\n')
# 获取响应消息和工具调用
response_message = response.choices[0].message
print('\n')
print('第一次响应输出 \n')
print(response_message)
print('\n')
tool_calls = response_message.tool_calls
print('输出tool_calls信息: \n')
pprint.pprint(tool_calls)
print('\n')
# 如果有工具调用
if tool_calls:
# 定义可用的函数字典
available_functions = {
"calculate": calculate,
}
# 将响应消息添加到对话历史
messages.append(response_message)
# 处理每个工具调用
for tool_call in tool_calls:
function_name = tool_call.function.name
function_to_call = available_functions[function_name]
# 解析函数参数
function_args = json.loads(tool_call.function.arguments)
# 调用函数并获取响应
function_response = function_to_call(
expression=function_args.get("expression")
)
print('\n输出function_response '+function_response +'\n')
# 将函数调用结果添加到对话历史
messages.append(
{
"tool_call_id": tool_call.id,
"role": "tool",
"name": function_name,
"content": function_response,
}
)
print('第二次信息输出 \n')
print(messages)
print('\n')
# 发送第二次请求到Groq API,包含函数调用结果
# 作用和目的:
# 处理工具调用的结果:将计算结果反馈给模型。
# 获取最终响应:让模型基于计算结果生成人类可读的回答。
# 特点:
# 包含了更新后的对话历史,包括第一次响应和工具调用的结果。
# 没有提供 tools 参数,因为此时不需要再次使用工具。
# 目的是获取最终的、格式化的回答。
second_response = client.chat.completions.create(
model=MODEL,
messages=messages
)
# 返回最终响应内容
return second_response.choices[0].message.content
# 定义用户提示
user_prompt = "计算25.6602988 * 4/0.259484 + 5.69560456 -398.11287180等于多少?这个数字有什么特殊意义吗?用中文回答."
# user_prompt = "1+1 等于多少?"
# 运行对话并打印结果
print('第二次响应输出 \n'+run_conversation(user_prompt))
sql 执行工具模拟
python
import sqlite3
import random
from datetime import datetime, timedelta
# 连接到SQLite数据库(如果不存在则创建)
conn = sqlite3.connect('demo_users.db')
cursor = conn.cursor()
# 创建用户表
cursor.execute('''
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
age INTEGER,
email TEXT UNIQUE,
registration_date DATE,
last_login DATETIME
)
''')
# 生成示例数据
names = ["Alice", "Bob", "Charlie", "David", "Eva", "Frank", "Grace", "Henry", "Ivy", "Jack"]
domains = ["gmail.com", "yahoo.com", "hotmail.com", "example.com"]
for i in range(50): # 创建50个用户记录
name = random.choice(names)
age = random.randint(18, 70)
email = f"{name.lower()}{random.randint(1, 100)}@{random.choice(domains)}"
registration_date = datetime.now() - timedelta(days=random.randint(1, 1000))
last_login = registration_date + timedelta(days=random.randint(1, 500))
cursor.execute('''
INSERT INTO users (name, age, email, registration_date, last_login)
VALUES (?, ?, ?, ?, ?)
''', (name, age, email, registration_date.date(), last_login))
# 提交更改并关闭连接
conn.commit()
conn.close()
print("Demo database 'demo_users.db' created successfully with sample data.")
# 函数用于显示表格内容
def display_table_contents():
conn = sqlite3.connect('demo_users.db')
cursor = conn.cursor()
cursor.execute("SELECT * FROM users LIMIT 5")
rows = cursor.fetchall()
print("\nSample data from the users table:")
for row in rows:
print(row)
conn.close()
display_table_contents()
import os
import json
import sqlite3
from groq import Groq
from datetime import datetime, timedelta
# 初始化Groq客户端
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
# 数据库连接函数
def get_db_connection():
"""创建并返回到SQLite数据库的连接"""
conn = sqlite3.connect('demo_users.db')
conn.row_factory = sqlite3.Row
return conn
def execute_sql(sql_query):
"""执行SQL查询并返回结果"""
conn = get_db_connection()
cursor = conn.cursor()
try:
cursor.execute(sql_query)
results = [dict(row) for row in cursor.fetchall()]
return results
except sqlite3.Error as e:
return f"数据库错误: {e}"
finally:
conn.close()
def generate_sql(table_info, conditions, select_fields="*"):
"""
生成SQL查询
:param table_info: 表信息
:param conditions: WHERE子句的条件
:param select_fields: 要选择的字段,默认为所有字段
:return: 生成的SQL查询字符串
"""
return f"SELECT {select_fields} FROM users WHERE {conditions}"
def format_results(results, fields=None):
"""
格式化查询结果
:param results: 查询返回的结果列表
:param fields: 要显示的字段列表,如果为None则显示所有字段
:return: 格式化后的结果字符串
"""
if isinstance(results, str): # 如果结果是错误消息
return results
if not results:
return "没有找到匹配的记录。"
if fields:
formatted = [", ".join(str(row.get(field, "N/A")) for field in fields) for row in results]
else:
formatted = [json.dumps(row, ensure_ascii=False, indent=2) for row in results]
return "\n".join(formatted)
def run_text2sql_conversation(user_prompt):
"""
运行text2sql对话
:param user_prompt: 用户输入的查询
:return: 查询结果
"""
table_info = "users(id INTEGER, name TEXT, age INTEGER, email TEXT, registration_date DATE, last_login DATETIME)"
messages = [
{
"role": "system",
"content": f"你是一个SQL助手。使用generate_sql函数根据用户请求创建SQL查询。可用的表: {table_info}。准确理解用户需求,包括他们想要查询的具体字段。"
},
{
"role": "user",
"content": user_prompt,
}
]
tools = [
{
"type": "function",
"function": {
"name": "generate_sql",
"description": "根据用户请求生成SQL查询",
"parameters": {
"type": "object",
"properties": {
"table_info": {
"type": "string",
"description": "表结构信息",
},
"conditions": {
"type": "string",
"description": "WHERE子句的具体查询条件",
},
"select_fields": {
"type": "string",
"description": "要选择的字段,用逗号分隔",
}
},
"required": ["table_info", "conditions", "select_fields"],
},
},
}
]
try:
response = client.chat.completions.create(
model="llama3-groq-70b-8192-tool-use-preview",
messages=messages,
tools=tools,
tool_choice="auto",
max_tokens=4096
)
assistant_message = response.choices[0].message
if assistant_message.tool_calls:
for tool_call in assistant_message.tool_calls:
if tool_call.function.name == "generate_sql":
function_args = json.loads(tool_call.function.arguments)
sql_query = generate_sql(
function_args["table_info"],
function_args["conditions"],
function_args["select_fields"]
)
results = execute_sql(sql_query)
formatted_results = format_results(results, function_args["select_fields"].split(", ") if function_args["select_fields"] != "*" else None)
return f"生成的SQL查询: {sql_query}\n\n结果:\n{formatted_results}"
return "无法生成SQL查询。请尝试重新表述您的问题。"
except Exception as e:
return f"发生错误: {str(e)}"
# 主程序
if __name__ == "__main__":
print("欢迎使用Text2SQL系统!")
print("您可以用自然语言询问有关用户表的问题。")
print("输入'quit'退出程序。")
while True:
user_input = input("\n请输入您的查询 (或 'quit' 退出): ")
if user_input.lower() == 'quit':
print("谢谢使用,再见!")
break
result = run_text2sql_conversation(user_input)
print("\n" + "="*50)
print(result)
print("="*50)
结合 tavily 查询网络
python
# 导入所需的库
from groq import Groq
from tavily import TavilyClient
import json
import os
import pprint
# 初始化Groq客户端和Tavily客户端,使用环境变量中的API密钥
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API_KEY"))
# 定义使用的模型名称
MODEL = 'llama3-groq-70b-8192-tool-use-preview'
def tavily_search(query):
"""执行Tavily搜索"""
try:
response = tavily_client.search(query)
# 返回前5个结果的标题、URL和内容摘要
results = [{
"title": r["title"],
"url": r["url"],
"content": r["content"][:200] + "..." # 限制内容长度
} for r in response["results"][:5]]
return json.dumps({"results": results})
except Exception as e:
return json.dumps({"error": str(e)})
def run_conversation(user_prompt):
messages = [
{
"role": "system",
"content": "你是一个智能助手,能够进行在线搜索以回答问题。使用tavily_search函数来获取最新、最相关的信息。请基于搜索结果提供详细、准确的回答,并在适当的时候引用信息来源。"
},
{
"role": "user",
"content": user_prompt,
}
]
tools = [
{
"type": "function",
"function": {
"name": "tavily_search",
"description": "执行在线搜索查询,获取最新信息",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "搜索查询",
}
},
"required": ["query"],
},
},
}
]
print('初始消息:\n')
pprint.pprint(messages)
print('\n')
response = client.chat.completions.create(
model=MODEL,
messages=messages,
tools=tools,
tool_choice="auto",
max_tokens=4096
)
print('Groq API 响应:\n')
pprint.pprint(response)
print('\n')
response_message = response.choices[0].message
print('AI 初始响应:\n')
pprint.pprint(response_message)
print('\n')
tool_calls = response_message.tool_calls
if tool_calls:
print('工具调用信息:\n')
pprint.pprint(tool_calls)
print('\n')
messages.append(response_message)
for tool_call in tool_calls:
function_args = json.loads(tool_call.function.arguments)
function_response = tavily_search(**function_args)
print(f'\nTavily 搜索响应:\n{function_response}\n')
messages.append(
{
"tool_call_id": tool_call.id,
"role": "tool",
"name": "tavily_search",
"content": function_response,
}
)
print('更新后的消息:\n')
pprint.pprint(messages)
print('\n')
second_response = client.chat.completions.create(
model=MODEL,
messages=messages
)
return second_response.choices[0].message.content
else:
return response_message.content
# 用户提示示例
user_prompt = "今天有哪些新闻?"
# 运行对话并打印结果
print('最终AI响应:\n' + run_conversation(user_prompt))