Skip to content
On this page

llama-index text-to-sql

mermaid
graph LR
    %% Text-to-SQL flow: user question, LLM, SQL generation and execution, then the LLM turns the result into an answer.
    A[用户输入自然语言查询] --> B[大型语言模型 LLM]
    B --> C{理解查询意图}
    C --> D[结合数据库结构信息]
    D --> E[生成SQL查询]
    E --> F[执行SQL查询]
    F --> G[获取查询结果]
    G --> H[LLM解释结果]
    H --> I[生成自然语言回答]
    I --> J[向用户展示结果]

    %% Style classes: fill, stroke and text colour for each node role.
    classDef default fill:#e6e6fa,stroke:#4b0082,stroke-width:2px,color:#4b0082;
    classDef llm fill:#4b0082,stroke:#4b0082,stroke-width:2px,color:#e6e6fa;
    classDef process fill:#4169e1,stroke:#4b0082,stroke-width:2px,color:#e6e6fa;
    classDef decision fill:#6a5acd,stroke:#4b0082,stroke-width:2px,color:#e6e6fa;

    %% Assign every node to one of the style classes defined above.
    class A,J default;
    class B,H llm;
    class C decision;
    class D,E,F,G,I process;
python
%pip install llama-index-llms-openrouter

!pip install llama-index

import os
from llama_index.llms.openrouter import OpenRouter
from llama_index.core.llms import ChatMessage

# Keep an API key that is already exported in the environment; fall back to
# the placeholder below (replace it with your own OpenRouter key) only when
# none is set, so a real key is never overwritten by the placeholder.
os.environ.setdefault("OPENROUTER_API_KEY", "sk-or-v1-")

from IPython.display import Markdown, display

from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    Float,
    select,
    insert,
)

# In-memory SQLite database plus the metadata registry for its tables.
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()

# Product inventory table: the columns are declared first, then attached to
# the table in the same order.
table_name = "product_inventory"
inventory_columns = (
    Column("product_id", String(16), primary_key=True),
    Column("product_name", String(50), nullable=False),
    Column("quantity", Integer),
    Column("unit_price", Float),
    Column("category", String(20)),
)
product_inventory_table = Table(table_name, metadata_obj, *inventory_columns)

# Create every table registered on metadata_obj in the database.
metadata_obj.create_all(engine)


from llama_index.core import SQLDatabase

# Initialise the LLM served through OpenRouter; the API key is read from the
# OPENROUTER_API_KEY environment variable.
llm = OpenRouter(
    api_key=os.environ.get("OPENROUTER_API_KEY"),
    max_tokens=4096,
    context_window=131072,
    model="qwen/qwen-2.5-72b-instruct",
)

# Wrap the engine for llama-index, exposing only the product inventory table.
# Constructed once; building a second identical object is redundant.
sql_database = SQLDatabase(engine, include_tables=["product_inventory"])

# Sample data: supply your own rows (about 50 are suggested). Each row is a
# dict keyed by column name, because it is expanded with **row below, e.g.
#   {"product_id": "P001", "product_name": "Widget", "quantity": 10,
#    "unit_price": 9.99, "category": "Tools"}
rows = []

# Insert every row inside a single transaction: the connection is opened and
# committed once instead of once per row, and a failing row rolls back the
# whole load rather than leaving the table half-filled.
with engine.begin() as connection:
    for row in rows:
        connection.execute(insert(product_inventory_table).values(**row))


# Build a SELECT over the inventory table's columns to inspect its contents.
cols = product_inventory_table.c
all_columns = (
    cols.product_id,
    cols.product_name,
    cols.quantity,
    cols.unit_price,
    cols.category,
)
stmt = select(*all_columns).select_from(product_inventory_table)


# 执行上面构建的 SELECT 语句(stmt)并逐行打印结果;若要直接执行原始 SQL 字符串,可使用下面导入的 text()
from sqlalchemy import text

# Fetch all rows while the connection is open; the rows are already
# materialised by fetchall(), so they can be printed after it closes.
with engine.connect() as conn:
    results = conn.execute(stmt).fetchall()

for record in results:
    print(record)

from llama_index.core.query_engine import NLSQLTableQueryEngine

# Natural-language-to-SQL query engine restricted to the inventory table.
query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database,
    tables=["product_inventory"],
    llm=llm,
)

# Example question: "Which category has the most distinct products?"
query_str = "哪个类别的产品种类最多?"
response = query_engine.query(query_str)

# Render the response in bold in the notebook output.
answer_html = f"<b>{response}</b>"
display(Markdown(answer_html))

Released under the MIT License.