Connect LLM to SQL database with LangChain SQLChain
How-to tutorial for using LangChain SQLChain
Last updated
How-to tutorial for using LangChain SQLChain
Last updated
!pip install -r requirements.txt
import psycopg2
import json
import psycopg2
import json
# Load the JSON data from the file.
# JSON is UTF-8 by specification, so the encoding is pinned rather than left
# to the platform default.
with open("./data/my_database.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# Connect to the PostgreSQL database
conn = psycopg2.connect(
    host="localhost",
    database="postgres",
    user="user",
    password="password",
    port="5432",
)

# The connection is closed in ``finally`` so it is released even when one of
# the statements below fails. If commit() is never reached, nothing is
# committed by this script.
try:
    # The cursor context manager closes the cursor when the block exits.
    with conn.cursor() as cur:
        # Create the table
        cur.execute("""
            CREATE TABLE IF NOT EXISTS person (
                id SERIAL PRIMARY KEY,
                name VARCHAR(100),
                address TEXT,
                email VARCHAR(100),
                phone_number VARCHAR(30),
                birthdate DATE,
                job VARCHAR(100)
            )
        """)

        # Build one parameter tuple per record under the "_default" key.
        # Only the record values are needed; the record ids are not inserted
        # because the table generates its own SERIAL id.
        people = [
            (
                person_data.get("name", ""),
                person_data.get("address", ""),
                person_data.get("email", ""),
                person_data.get("phone_number", ""),
                # None (SQL NULL) rather than "" so a missing date stays valid
                # for the DATE column.
                person_data.get("birthdate", None),
                person_data.get("job", ""),
            )
            for person_data in data["_default"].values()
        ]

        # Insert the data into the table (parameterized, one row per tuple)
        cur.executemany(
            "INSERT INTO person (name, address, email, phone_number, birthdate, job) VALUES (%s, %s, %s, %s, %s, %s)",
            people,
        )

        cur.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'")
        table_names = [row[0] for row in cur.fetchall()]

    # Print the table names
    print("Table names:")
    for table_name in table_names:
        print(table_name)

    # Commit the changes
    conn.commit()
finally:
    # Close the connection
    conn.close()
from dotenv import load_dotenv

env_content = """
GROQ_API_KEY=your groq api key without ''
"""
# Write the content to a file named .env
# NOTE: mode 'w' truncates the file, so an existing .env is overwritten with
# the placeholder above every time this runs.
with open('.env', 'w') as f:
    f.write(env_content)
# Load environment variables
load_dotenv(".env")

# --- Set OpenAI API environment variables ---
import os

# Set OpenAI API environment variables (if needed)
# The OpenAI-named variables are pointed at Groq's OpenAI-compatible endpoint
# and filled with the Groq key read from the environment.
os.environ["OPENAI_API_BASE"] = "https://api.groq.com/openai/v1"
os.environ["OPENAI_MODEL_NAME"] = "llama3-8b-8192"
os.environ["OPENAI_API_KEY"] = os.getenv("GROQ_API_KEY")

from langchain_community.utilities import SQLDatabase
from urllib.parse import quote

# NOTE(review): the psycopg2 loader earlier in this tutorial connects as
# user "user", while this URI uses "postgres" - confirm which role is intended.
username = "postgres"
password = "password"  # plain (unescaped) text
host = "localhost"  # Hostname or IP address of the PostgreSQL server
port = "5432"  # Port number
mydatabase = "postgres"

# Percent-encode the credentials so characters such as '@', ':' or '/' in a
# password cannot be mistaken for URI delimiters. For purely alphanumeric
# values the resulting URI is unchanged.
pg_uri = (
    f"postgresql+psycopg2://{quote(username, safe='')}:{quote(password, safe='')}"
    f"@{host}:{port}/{mydatabase}"
)
db = SQLDatabase.from_uri(pg_uri)

import os
from langchain.prompts.prompt import PromptTemplate
from langchain_openai import ChatOpenAI

# The Groq key is read from the environment and handed to the OpenAI-style
# chat client under the name it expects.
# NOTE(review): presumably the client picks up the base URL from the
# OPENAI_API_BASE environment variable - confirm.
OPENAI_API_KEY = os.environ.get("GROQ_API_KEY")

llm = ChatOpenAI(model_name="llama3-8b-8192", openai_api_key=OPENAI_API_KEY)
from langchain_community.llms import OpenAI
from langchain_experimental.sql import SQLDatabaseChain, SQLDatabaseSequentialChain
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS

# SQL chain built from the chat model and the database wrapper.
sql_chain = SQLDatabaseChain.from_llm(llm, db)

# NOTE(review): this assignment rebinds PROMPT, shadowing the object imported
# from langchain.chains.sql_database.prompt above with a plain string.
# Confirm that whatever receives it accepts a raw str with a {question} slot.
PROMPT = """
Given an input question, first create a syntactically correct postgresql query to run,
then look at the results of the query and return the answer.
The question: {question}
"""

# LLM chain driven by DECIDER_PROMPT; its output is stored under "table_names".
decider_chain = LLMChain(llm=llm, prompt=DECIDER_PROMPT, output_key="table_names")
class CustomSQLDatabaseSequentialChain(SQLDatabaseSequentialChain):
    """Sequential SQL chain that also prints the result of every run."""

    def run(self, question):
        """Run the parent chain on ``question``, print the result, and return it.

        Args:
            question: The input passed straight through to the parent ``run``.

        Returns:
            Whatever the parent ``run`` returned, unchanged.
        """
        result = super().run(question)
        print(f"Query result: {result}")
        return result
# Assemble the sequential chain from the decider chain and the SQL chain.
# NOTE(review): confirm which of these keyword arguments the sequential chain
# actually consumes (llm, database, top_k and prompt in particular).
db_chain = CustomSQLDatabaseSequentialChain(
    llm=llm,
    database=db,
    verbose=True,
    top_k=3,
    decider_chain=decider_chain,
    sql_chain=sql_chain,
    prompt=PROMPT,
)

question = "Whom the person has a name starting with 'S'?"
db_chain.run(question)

from sqlalchemy import create_engine, text
# Assuming db_chain is your CustomSQLDatabaseSequentialChain instance
# (this invokes the chain again for the same question).
query = db_chain.run(question)

# Extract the SQL query from the combined string.
# Fail with an explicit message when the marker is missing, instead of an
# opaque IndexError from indexing the split result.
if 'SQLQuery:' not in query:
    raise ValueError(f"No 'SQLQuery:' marker found in chain output: {query!r}")
sql_query = query.split('SQLQuery:')[1].strip()

# Establish a connection to your database
engine = create_engine(pg_uri)

# Execute the SQL query
# WARNING: sql_query is text taken from the chain output and is executed
# verbatim. Review it, or connect with a read-only / least-privileged role,
# before running this against a database that matters.
with engine.connect() as connection:
    result = connection.execute(text(sql_query))
    # Print the results while the connection is still open
    for row in result:
        print(row)