为何要生成SQL语句

在实际业务开发过程中,面对越来越多的用户在使用一些工具的时候,他们认为界面上展示出来的数据太过于单一了,无论是数据维度还是数据过滤方式,都被限制在了几个功能按钮上。数据聚合跟嵌套查询的想法是大多数人都有的,但是奈何操作者可能并不是编程行业的人,并且就算是也根本不了解当前自己使用的产品有什么数据表可以提供怎样的查询方式,这就限制了一个产品在数据展示维度上的欠缺。因此需要一款程序用于优化这个痛点。根据自然语言文本生成SQL语句的想法应运而生。

方案选择

知识库

生成SQL语句的时候需要大模型理解当前产品的各种特性与表设计,知识库无疑是最好的解决方案,将当前产品的文档灌入到知识库当中,根据用户的问题查询对应的知识库,再将知识库的结果作为参考,结合问题一起让大模型返回结果。但是这里有很多问题,知识库作为参考只能提高模型对当前产品的理解能力,虽然生成sql语句在产品的大致逻辑方向上是正确的,但是却不能够明显提升SQL语句的生成能力,这导致SQL语句出现了大量的隐藏错误,使用体验非常糟糕。

提示词

提示词在GLM提示词当中有做简单讲解,这里就不再赘述了,提示词可以引导大模型按照特定的身份,朝着特定的方向生成高质量的回答,但是提示词本身是有限的,只能够引导并不能够教会模型怎么做或者告诉模型自己正在基于怎样的产品在做什么,并且提示词复杂之后,提示词的逻辑会严重影响到生成的质量,所以提示词的提示能力是有限的。

模型训练

模型训练可以对模型进行微调,基于我们自己的数据集进行模型训练,将模型在某一个方向上的回答能力进行显著的拉升。但是这里有更多的问题,没有足够的数据集模型对文本的理解能力就会不够,但是很少有人能够手持那么大量的优质数据集,如果只是微调,即使保证了模型微调数据集的质量,模型微调的结果也很难把控,模型可能会调得不错,也可能变成一个傻子,更有可能变成偏科生,某一方面理解能力强,其他方面就是纯纯智障,模型的训练结果在一般公司里太不可控了。

混合

最后是混合方案,就是结合之前的所有方案,知识库+提示词+模型训练。从知识库当中查询相关知识,将相关知识结合提示词灌入简单训练过的模型当中,在共同作用下,经过实际测试,有非常好的使用体验。

Vanna框架

Vanna是一个基于知识库与大模型生成SQL语句的python库,也是采用混合方案,将文档、sql语句、dll表作为知识库,配合提示词模板构建历史对话,引导大模型生成SQL语句。由于SQL语句的生成有知识库、提示词、对话历史的引导,这能让大模型生成的SQL变得更准确,并且SQL的生成并不会对数据库有太多挑剔,无论是sqlite、ch数据库、mysql数据库还是mongodb数据库,生成的效果都非常好。框架本身支持对模型对话记录的收集,正确生成的结果将再次被加入到知识库当中,因此一旦启用,随着使用次数的增多,生成SQL会越来越准确。
演示代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from vanna.remote import VannaDefault
from vanna.openai.openai_chat import OpenAI_Chat
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
from vanna.flask import VannaFlaskApp
from clickhouse_driver import Client
from openai import OpenAI
import pandas as pd
import json
import os
import pymysql
client = OpenAI(
api_key="EMPTY",
base_url="http://127.0.0.1:8000/v1/",
default_headers = {"x-foo": "true"}
)
class MyVanna(ChromaDB_VectorStore,OpenAI_Chat):
def __init__(self,client=None,config=None):
ChromaDB_VectorStore.__init__(self,config=config)
OpenAI_Chat.__init__(self,client=client,config=config)
vn = MyVanna(client=client,config={"model": "chatglm3-6b"}) #chatglm2-6b
vn.max_tokens = 800
vn.temperature = 0.5

client= Client(database='default', # 数据库的名称为 第一个参数
user='default', # 数据库 用户名
password='admin', # 数据库 密码
host='127.0.0.1', port=9000)

def run_sql(sql: str) -> pd.DataFrame:
result = client.execute(sql)
df = pd.DataFrame(result) # type: ignore
df.to_csv('temp.csv', encoding='utf-8', index=False)
df = pd.read_csv('temp.csv', encoding='utf-8')
return df

# 将函数设置到vn.run_sql中
vn.run_sql = run_sql # type: ignore
vn.run_sql_is_set = True

vn.train(ddl="""
CREATE TABLE netlink_5.DnsLog
(
`uniqueId` UInt64 COMMENT '',
)
COMMENT '日志表'
ENGINE = MergeTree
PARTITION BY toStartOfDay(toDateTime(time))
ORDER BY time
SETTINGS index_granularity = 8192
""")
# res = vn.ask('查找链路2.dns日志表,计算过去一个月内每个工作日每个服务端IP的请求次数标准差,找出标准差最大的服务端IP。',print_results=False)
VannaFlaskApp(vn,allow_llm_to_see_data=True).run(host="0.0.0.0", port=9999)

源码分析

Vanna的知识库是基于ChromaDB进行开发的,源码支持对dll、doc、sql的训练,将知识库源码进行剥离之后得到一个简单的知识库管理类:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import json
from typing import List,Union,Optional

import chromadb
import pandas as pd
from chromadb.config import Settings
from chromadb.utils import embedding_functions
import hashlib
import uuid
from chromadb.api.types import IDs,OneOrMany,Document


default_ef = embedding_functions.DefaultEmbeddingFunction()

def deterministic_uuid(content: Union[str, bytes]) -> str:
"""
根据输入内容的 SHA-256 哈希生成确定性的 UUID。
:param content: `deterministic_uuid` 函数中的 `content` 参数可以是字符串或字节数据类型。
:type content: Union[str, bytes]
:return: 以字符串形式返回 UUID(通用唯一标识符)。
"""
if isinstance(content, str):
content_bytes = content.encode("utf-8")
elif isinstance(content, bytes):
content_bytes = content
else:
raise ValueError(f"Content type {type(content)} not supported !")
hash_object = hashlib.sha256(content_bytes)
hash_hex = hash_object.hexdigest()
namespace = uuid.UUID("00000000-0000-0000-0000-000000000000")
content_uuid = str(uuid.uuid5(namespace, hash_hex))
return content_uuid



# 管理带有嵌入的文档集合以及根据给定问题查询相关文档的方法。
class ChromaDB_VectorStore():
def __init__(self, config=None):
if config is None:
config = {}

path = config.get("path", "CHROMADB") # 向量数据库保存地址
self.embedding_function = config.get("embedding_function", default_ef) # 向量化方法
curr_client = config.get("client", "persistent") # 客户端类型,允许直接传入一个客户端
collection_metadata = config.get("collection_metadata", None) # metadata
self.n_results = config.get("n_result", 10) # 最大查询返回数量
self.collection_name = config.get("collection","documentation") # 当前数据库的名称
if curr_client == "persistent":
self.chroma_client = chromadb.PersistentClient(
path=path, settings=Settings(anonymized_telemetry=False)
)
elif curr_client == "in-memory":
self.chroma_client = chromadb.EphemeralClient(
settings=Settings(anonymized_telemetry=False)
)
elif isinstance(curr_client, chromadb.api.client.Client): # type: ignore
self.chroma_client = curr_client# 允许直接提供客户端
else:
raise ValueError(f"在配置中设置了不支持的客户端: {curr_client}")

self.collection = self.chroma_client.get_or_create_collection(
name=self.collection_name,
embedding_function=self.embedding_function,
metadata=collection_metadata,
)

def generate_embedding(self, data: str, **kwargs) -> List[float]:
"""
接受字符串输入,使用指定函数生成嵌入,并将嵌入作为浮点数列表返回。

:param data: `generate_embedding` 函数中的 `data` 参数是一个字符串,表示要生成嵌入的输入数据。
:type data: str
:return: 如果 `embedding` 列表的长度为 1,则该函数将返回列表的第一个元素。否则,它将返回整个 `embedding` 列表。
"""
embedding = self.embedding_function([data])
if len(embedding) == 1:
return embedding[0]
return embedding

def add_documentation(self, documentation: str, **kwargs) -> str:
"""
将文档及其相应的嵌入添加到集合中并返回生成的 ID。

:param documentation: `add_documentation` 方法接受一个 `documentation` 参数,该参数是包含要添加的文档的字符串。此方法还接受其他关键字参数
(**kwargs) 以提高灵活性。
:type documentation: str
:return: 返回添加到集合中的文档的 `id`。
"""
id = deterministic_uuid(documentation)
self.collection.add(
documents=documentation,
embeddings=self.generate_embedding(documentation),
ids=id,
)
return id

def get_training_data(self, **kwargs) -> pd.DataFrame:
collection_data = self.collection.get()
df = pd.DataFrame()
if collection_data is not None:
documents = [doc for doc in collection_data["documents"]]
ids = collection_data["ids"]
df_doc = pd.DataFrame({"id": ids,
"question": [None for doc in documents],
"content": [doc for doc in documents]}
)
df_doc["training_data_type"] = self.collection_name
df = pd.concat([df, df_doc])
return df

def remove_training_data(self, id: IDs, **kwargs) -> bool:
"""
根据提供的 ID 从集合中删除训练数据。
:param id: `remove_training_data` 方法中的 `id` 参数为 `IDs` 类型。它用于指定需要从集合中移除的训练数据的标识符。
:type id: IDs
:return: 返回一个布尔值。如果成功删除了指定 ID 的训练数据,则返回 `True`;如果删除过程中出现错误,则返回 `False`。
"""
try:
self.collection.delete(ids=id)
return True
except:
return False

def remove_collection(self) -> bool:
"""
尝试删除一个集合并在必要时重新创建它,如果成功则返回 True,否则返回 False。
:return: `remove_collection` 方法返回一个布尔值 - 如果集合被成功删除并重新创建,则返回 `True`;如果在此过程中发生异常,则返回 `False`。
"""
try:
self.chroma_client.delete_collection(name=self.collection_name)
self.collection = self.chroma_client.get_or_create_collection(
name=self.collection_name, embedding_function=self.embedding_function
)
return True
except:
return False

@staticmethod
def _extract_documents(query_results,collection_name) -> list: # type: ignore
"""
该函数从查询结果中提取文档,处理文档嵌套在列表中的情况。

:param query_results:
您提供的代码片段似乎是一个名为“_extract_documents”的函数,它以“query_results”作为参数,并应该返回从查询结果中提取的文档列表。但是,代码片段中有几个问题需要解决:
:return: 变量“document”,但 return 语句似乎有拼写错误。应该是“return documents”,而不是“return document”。
"""
if query_results is None:
return []

if "documents" in query_results:
documents = query_results["documents"]
if len(documents) == 1 and isinstance(documents[0], list):
try:
documents = [json.loads(doc) for doc in documents[0]]
except Exception as e:
return documents[0]
return documents

def get_related(self, question: Optional[OneOrMany[Document]], **kwargs) -> list:
"""
使用给定的问题查询集合并返回相关文档的列表。

:param question: `get_related` 方法中的 `question` 参数属于 `Optional[OneOrMany[Document]]` 类型。这意味着它可以接受单个
`Document` 对象或 `Document` 对象集合,也可以为 `None`。
:type question: Optional[OneOrMany[Document]]
:return: 返回文件清单。
"""
res = self.collection.query(query_texts=question,n_results=self.n_results)
return ChromaDB_VectorStore._extract_documents(res,self.collection_name)

if __name__ == "__main__":
config = {"path":"",
"embedding_function":"",
"client":object,
"collection_metadata":None,
"n_result":10,
"collection":"doc"}