构建 SQL 代理¶
在本教程中,我们将逐步介绍如何构建一个能够回答有关 SQL 数据库问题的代理。
从高层次来看,该代理将:
- 从数据库中获取可用的表
- 确定哪些表与问题相关
- 获取相关表的模式
- 根据问题和模式信息生成查询
- 使用 LLM 双重检查查询以避免常见错误
- 执行查询并返回结果
- 纠正数据库引擎发现的错误,直到查询成功
- 根据结果制定响应
安全提示
构建 SQL 数据库的问答系统需要执行模型生成的 SQL 查询。这样做存在固有风险。请确保您的数据库连接权限始终根据代理的需求尽可能缩小范围。这将减轻(尽管不能完全消除)构建模型驱动系统的风险。
1. 设置¶
首先安装一些依赖项。本教程使用了 langchain-community 中的 SQL 数据库和工具抽象。我们还需要一个 LangChain 聊天模型。
Tip
注册 LangSmith 可以快速发现问题并提高 LangGraph 项目的性能。LangSmith 让您可以使用跟踪数据来调试、测试和监控使用 LangGraph 构建的 LLM 应用程序。
选择 LLM¶
首先我们初始化我们的 LLM。任何支持工具调用的模型都应该可以工作。下面我们使用 OpenAI。
API Reference: init_chat_model
配置数据库¶
我们将为本教程创建一个 SQLite 数据库。SQLite 是一个轻量级数据库,易于设置和使用。我们将加载 chinook
数据库,这是一个代表数字媒体店的示例数据库。
有关该数据库的更多信息,请参见这里。
为了方便起见,我们已将数据库 (Chinook.db
) 托管在一个公共的 GCS 存储桶中。
import requests
url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"
response = requests.get(url)
if response.status_code == 200:
# 以二进制写入模式打开本地文件
with open("Chinook.db", "wb") as file:
# 将响应内容(文件)写入本地文件
file.write(response.content)
print("文件已下载并保存为 Chinook.db")
else:
print(f"下载文件失败。状态码: {response.status_code}")
我们将使用 langchain_community
包中一个方便的 SQL 数据库包装器来与数据库交互。该包装器提供了一个简单的接口来执行 SQL 查询和获取结果:
API Reference: SQLDatabase
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(f"方言: {db.dialect}")
print(f"可用表: {db.get_usable_table_names()}")
print(f'示例输出: {db.run("SELECT * FROM Artist LIMIT 5;")}')
输出:
方言: sqlite
可用表: ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
示例输出: [(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains')]
用于数据库交互的工具¶
langchain-community
实现了一些内置工具来与我们的 SQLDatabase
交互,包括列出表、读取表模式以及检查和运行查询的工具:
API Reference: SQLDatabaseToolkit
from langchain_community.agent_toolkits import SQLDatabaseToolkit
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()
for tool in tools:
print(f"{tool.name}: {tool.description}\n")
输出:
sql_db_query: 此工具的输入是一个详细且正确的 SQL 查询,输出是数据库的结果。如果查询不正确,将返回错误消息。如果返回错误,请重写查询,检查查询,然后重试。如果遇到 Unknown column 'xxxx' in 'field list' 的问题,请使用 sql_db_schema 查询正确的表字段。
sql_db_schema: 此工具的输入是一个逗号分隔的表列表,输出是这些表的模式和示例行。务必先调用 sql_db_list_tables 以确保表确实存在!示例输入: table1, table2, table3
sql_db_list_tables: 输入是一个空字符串,输出是数据库中逗号分隔的表列表。
sql_db_query_checker: 使用此工具在执行查询前双重检查查询是否正确。在使用 sql_db_query 执行查询之前,务必使用此工具!
2. 使用预构建的代理¶
有了这些工具,我们可以在一行代码中初始化一个预构建的代理。为了自定义代理的行为,我们编写一个描述性的系统提示。
API Reference: create_react_agent
from langgraph.prebuilt import create_react_agent
system_prompt = """
您是一个设计用于与 SQL 数据库交互的代理。
给定一个输入问题,创建一个语法正确的 {dialect} 查询来运行,
然后查看查询结果并返回答案。除非用户
指定他们希望获取的特定数量的示例,否则始终将查询限制为最多 {top_k} 个结果。
您可以按相关列对结果进行排序,以返回数据库中最有趣的示例。永远不要查询特定表的所有列,只询问给定问题相关的列。
在执行查询之前,您必须双重检查您的查询。如果在执行查询时遇到错误,请重写查询并重试。
不要对数据库进行任何 DML 语句(INSERT、UPDATE、DELETE、DROP 等)。
开始时,您应该始终查看数据库中的表,以了解可以查询的内容。不要跳过此步骤。
然后您应该查询最相关表的模式。
""".format(
dialect=db.dialect,
top_k=5,
)
agent = create_react_agent(
llm,
tools,
prompt=system_prompt,
)
Note
此系统提示包含许多指令,例如始终在之前或之后运行特定工具。在下一节中,我们将通过图的结构强制执行这些行为,为我们提供更大程度的控制,并允许我们简化提示。
让我们在一个示例查询上运行此代理并观察其行为:
question = "哪种流派的曲目平均长度最长?"
for step in agent.stream(
{"messages": [{"role": "user", "content": question}]},
stream_mode="values",
):
step["messages"][-1].pretty_print()
输出:
================================ Human Message =================================
哪种流派的曲目平均长度最长?
================================== Ai Message ==================================
工具调用:
sql_db_list_tables (call_d8lCgywSroCgpVl558nmXKwA)
调用 ID: call_d8lCgywSroCgpVl558nmXKwA
参数:
================================= Tool Message =================================
名称: sql_db_list_tables
Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
================================== Ai Message ==================================
工具调用:
sql_db_schema (call_nNf6IIUcwMYLIkE0l6uWkZHe)
调用 ID: call_nNf6IIUcwMYLIkE0l6uWkZHe
参数:
table_names: Genre, Track
================================= Tool Message =================================
名称: sql_db_schema
CREATE TABLE "Genre" (
"GenreId" INTEGER NOT NULL,
"Name" NVARCHAR(120),
PRIMARY KEY ("GenreId")
)
/*
3 rows from Genre table:
GenreId Name
1 Rock
2 Jazz
3 Metal
*/
CREATE TABLE "Track" (
"TrackId" INTEGER NOT NULL,
"Name" NVARCHAR(200) NOT NULL,
"AlbumId" INTEGER,
"MediaTypeId" INTEGER NOT NULL,
"GenreId" INTEGER,
"Composer" NVARCHAR(220),
"Milliseconds" INTEGER NOT NULL,
"Bytes" INTEGER,
"UnitPrice" NUMERIC(10, 2) NOT NULL,
PRIMARY KEY ("TrackId"),
FOREIGN KEY("MediaTypeId") REFERENCES "MediaType" ("MediaTypeId"),
FOREIGN KEY("GenreId") REFERENCES "Genre" ("GenreId"),
FOREIGN KEY("AlbumId") REFERENCES "Album" ("AlbumId")
)
/*
3 rows from Track table:
TrackId Name AlbumId MediaTypeId GenreId Composer Milliseconds Bytes UnitPrice
1 For Those About To Rock (We Salute You) 1 1 1 Angus Young, Malcolm Young, Brian Johnson 343719 11170334 0.99
2 Balls to the Wall 2 2 1 None 342562 5510424 0.99
3 Fast As a Shark 3 2 1 F. Baltes, S. Kaufman, U. Dirkscneider & W. Hoffman 230619 3990994 0.99
*/
================================== Ai Message ==================================
工具调用:
sql_db_query_checker (call_urTRmtiGtTxkwHtscec7Fd2K)
调用 ID: call_urTRmtiGtTxkwHtscec7Fd2K
参数:
query: SELECT Genre.Name, AVG(Track.Milliseconds) AS AvgMilliseconds
FROM Track
JOIN Genre ON Track.GenreId = Genre.GenreId
GROUP BY Genre.Name
ORDER BY AvgMilliseconds DESC
LIMIT 1;
================================= Tool Message =================================
名称: sql_db_query_checker
\`\`\`sql
SELECT Genre.Name, AVG(Track.Milliseconds) AS AvgMilliseconds
FROM Track
JOIN Genre ON Track.GenreId = Genre.GenreId
GROUP BY Genre.Name
ORDER BY AvgMilliseconds DESC
LIMIT 1;
\`\`\`
================================== Ai Message ==================================
工具调用:
sql_db_query (call_RNMqyUEMv0rvy0UxSwrXY2AV)
调用 ID: call_RNMqyUEMv0rvy0UxSwrXY2AV
参数:
query: SELECT Genre.Name, AVG(Track.Milliseconds) AS AvgMilliseconds
FROM Track
JOIN Genre ON Track.GenreId = Genre.GenreId
GROUP BY Genre.Name
ORDER BY AvgMilliseconds DESC
LIMIT 1;
================================= Tool Message =================================
名称: sql_db_query
[('Sci Fi & Fantasy', 2911783.0384615385)]
================================== Ai Message ==================================
平均曲目长度最长的流派是 "Sci Fi & Fantasy",平均每首曲目时长约为 2,911,783 毫秒(约 48.5 分钟)。
这工作得足够好:代理正确列出了表,获取了模式,编写了查询,检查了查询,并运行它以形成最终响应。
Tip
您可以在 LangSmith trace 中检查上述运行的所有方面,包括采取的步骤、调用的工具、LLM 看到的提示等。
3. 自定义代理¶
预构建的代理让我们可以快速开始,但在每一步代理都可以访问完整的工具集。上面,我们依靠系统提示来约束其行为——例如,我们指示代理始终以“列出表”工具开始,并在执行查询之前始终运行查询检查器工具。
在 LangGraph 中,我们可以通过自定义代理来强制执行更高程度的控制。下面,我们实现了一个简单的 ReAct 代理设置,为特定的工具调用使用专用节点。我们将使用与预构建代理相同的状态。
我们为以下步骤构建专用节点:
- 列出数据库表
- 调用“获取模式”工具
- 生成查询
- 检查查询
将这些步骤放入专用节点让我们可以(1)在需要时强制进行工具调用,以及(2)自定义与每个步骤关联的提示。
API Reference: AIMessage | RunnableConfig | END | START | StateGraph | ToolNode
from typing import Literal
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode
get_schema_tool = next(tool for tool in tools if tool.name == "sql_db_schema")
get_schema_node = ToolNode([get_schema_tool], name="get_schema")
run_query_tool = next(tool for tool in tools if tool.name == "sql_db_query")
run_query_node = ToolNode([run_query_tool], name="run_query")
# 示例:创建一个预定的工具调用
def list_tables(state: MessagesState):
tool_call = {
"name": "sql_db_list_tables",
"args": {},
"id": "abc123",
"type": "tool_call",
}
tool_call_message = AIMessage(content="", tool_calls=[tool_call])
list_tables_tool = next(tool for tool in tools if tool.name == "sql_db_list_tables")
tool_message = list_tables_tool.invoke(tool_call)
response = AIMessage(f"可用表: {tool_message.content}")
return {"messages": [tool_call_message, tool_message, response]}
# 示例:强制模型创建工具调用
def call_get_schema(state: MessagesState):
# 注意 LangChain 强制所有模型接受 `tool_choice="any"`
# 以及 `tool_choice=<工具名称字符串>`。
llm_with_tools = llm.bind_tools([get_schema_tool], tool_choice="any")
response = llm_with_tools.invoke(state["messages"])
return {"messages": [response]}
generate_query_system_prompt = """
您是一个设计用于与 SQL 数据库交互的代理。
给定一个输入问题,创建一个语法正确的 {dialect} 查询来运行,
然后查看查询结果并返回答案。除非用户
指定他们希望获取的特定数量的示例,否则始终将查询限制为最多 {top_k} 个结果。
您可以按相关列对结果进行排序,以返回数据库中最有趣的示例。永远不要查询特定表的所有列,只询问给定问题相关的列。
不要对数据库进行任何 DML 语句(INSERT、UPDATE、DELETE、DROP 等)。
""".format(
dialect=db.dialect,
top_k=5,
)
def generate_query(state: MessagesState):
system_message = {
"role": "system",
"content": generate_query_system_prompt,
}
# 这里我们不强制工具调用,允许模型在获得解决方案时自然响应。
llm_with_tools = llm.bind_tools([run_query_tool])
response = llm_with_tools.invoke([system_message] + state["messages"])
return {"messages": [response]}
check_query_system_prompt = """
您是一位注重细节的 SQL 专家。
双重检查 {dialect} 查询中的常见错误,包括:
- 在 NULL 值中使用 NOT IN
- 在应该使用 UNION ALL 时使用了 UNION
- 对独占范围使用 BETWEEN
- 谓词中的数据类型不匹配
- 正确引用标识符
- 使用正确数量的函数参数
- 转换为正确的数据类型
- 使用正确的列进行连接
如果存在任何上述错误,请重写查询。如果没有错误,只需重现原始查询。
运行此检查后,您将调用适当的工具来执行查询。
""".format(dialect=db.dialect)
def check_query(state: MessagesState):
system_message = {
"role": "system",
"content": check_query_system_prompt,
}
# 生成一个用于检查的人工用户消息
tool_call = state["messages"][-1].tool_calls[0]
user_message = {"role": "user", "content": tool_call["args"]["query"]}
llm_with_tools = llm.bind_tools([run_query_tool], tool_choice="any")
response = llm_with_tools.invoke([system_message, user_message])
response.id = state["messages"][-1].id
return {"messages": [response]}
最后,我们使用 Graph API 将这些步骤组装成一个工作流。我们在查询生成步骤定义了一个条件边,如果生成了查询,将路由到查询检查器,或者如果没有工具调用(例如 LLM 已交付对查询的响应)则结束。
def should_continue(state: MessagesState) -> Literal[END, "check_query"]:
messages = state["messages"]
last_message = messages[-1]
if not last_message.tool_calls:
return END
else:
return "check_query"
builder = StateGraph(MessagesState)
builder.add_node(list_tables)
builder.add_node(call_get_schema)
builder.add_node(get_schema_node, "get_schema")
builder.add_node(generate_query)
builder.add_node(check_query)
builder.add_node(run_query_node, "run_query")
builder.add_edge(START, "list_tables")
builder.add_edge("list_tables", "call_get_schema")
builder.add_edge("call_get_schema", "get_schema")
builder.add_edge("get_schema", "generate_query")
builder.add_conditional_edges(
"generate_query",
should_continue,
)
builder.add_edge("check_query", "run_query")
builder.add_edge("run_query", "generate_query")
agent = builder.compile()
我们在下面可视化应用程序:
API Reference: CurveStyle | MermaidDrawMethod | NodeStyles
from IPython.display import Image, display
from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod, NodeStyles
display(Image(agent.get_graph().draw_mermaid_png()))
注意: 当您运行此代码时,它将生成并显示 SQL 代理图的可视化表示,显示不同节点之间的流程(list_tables → call_get_schema → get_schema → generate_query → check_query → run_query)。
我们现在可以像以前一样调用图:
question = "哪种流派的曲目平均长度最长?"
for step in agent.stream(
{"messages": [{"role": "user", "content": question}]},
stream_mode="values",
):
step["messages"][-1].pretty_print()
输出:
================================ Human Message =================================
哪种流派的曲目平均长度最长?
================================== Ai Message ==================================
可用表: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
================================== Ai Message ==================================
工具调用:
sql_db_schema (call_qxKtYiHgf93AiTDin9ez5wFp)
调用 ID: call_qxKtYiHgf93AiTDin9ez5wFp
参数:
table_names: Genre,Track
================================= Tool Message =================================
名称: sql_db_schema
CREATE TABLE "Genre" (
"GenreId" INTEGER NOT NULL,
"Name" NVARCHAR(120),
PRIMARY KEY ("GenreId")
)
/*
3 rows from Genre table:
GenreId Name
1 Rock
2 Jazz
3 Metal
*/
CREATE TABLE "Track" (
"TrackId" INTEGER NOT NULL,
"Name" NVARCHAR(200) NOT NULL,
"AlbumId" INTEGER,
"MediaTypeId" INTEGER NOT NULL,
"GenreId" INTEGER,
"Composer" NVARCHAR(220),
"Milliseconds" INTEGER NOT NULL,
"Bytes" INTEGER,
"UnitPrice" NUMERIC(10, 2) NOT NULL,
PRIMARY KEY ("TrackId"),
FOREIGN KEY("MediaTypeId") REFERENCES "MediaType" ("MediaTypeId"),
FOREIGN KEY("GenreId") REFERENCES "Genre" ("GenreId"),
FOREIGN KEY("AlbumId") REFERENCES "Album" ("AlbumId")
)
/*
3 rows from Track table:
TrackId Name AlbumId MediaTypeId GenreId Composer Milliseconds Bytes UnitPrice
1 For Those About To Rock (We Salute You) 1 1 1 Angus Young, Malcolm Young, Brian Johnson 343719 11170334 0.99
2 Balls to the Wall 2 2 1 None 342562 5510424 0.99
3 Fast As a Shark 3 2 1 F. Baltes, S. Kaufman, U. Dirkscneider & W. Hoffman 230619 3990994 0.99
*/
================================== Ai Message ==================================
工具调用:
sql_db_query (call_RPN3GABMfb6DTaFTLlwnZxVN)
调用 ID: call_RPN3GABMfb6DTaFTLlwnZxVN
参数:
query: SELECT Genre.Name, AVG(Track.Milliseconds) AS AvgTrackLength
FROM Track
JOIN Genre ON Track.GenreId = Genre.GenreId
GROUP BY Genre.GenreId
ORDER BY AvgTrackLength DESC
LIMIT 1;
================================== Ai Message ==================================
工具调用:
sql_db_query (call_PR4s8ymiF3ZQLaoZADXtdqcl)
调用 ID: call_PR4s8ymiF3ZQLaoZADXtdqcl
参数:
query: SELECT Genre.Name, AVG(Track.Milliseconds) AS AvgTrackLength
FROM Track
JOIN Genre ON Track.GenreId = Genre.GenreId
GROUP BY Genre.GenreId
ORDER BY AvgTrackLength DESC
LIMIT 1;
================================= Tool Message =================================
名称: sql_db_query
[('Sci Fi & Fantasy', 2911783.0384615385)]
================================== Ai Message ==================================
平均曲目长度最长的流派是 "Sci Fi & Fantasy",平均曲目长度约为 2,911,783 毫秒。
Tip
参见上述运行的 LangSmith trace。
下一步¶
查看本指南,了解如何使用 LangSmith 评估 LangGraph 应用程序,包括像这样的 SQL 代理。