diff --git a/package.json b/package.json index 976fc3e6..851d3f88 100644 --- a/package.json +++ b/package.json @@ -192,6 +192,7 @@ "rimraf": "5.0.7", "safe-buffer": "5.2.1", "scanf": "1.2.0", + "sqlite3": "5.1.7", "sequelize": "6.28.2", "sequelize-cli": "6.6.0", "sequelize-typescript": "2.1.5", diff --git a/packages/gpt.gblib/services/ChatServices.ts b/packages/gpt.gblib/services/ChatServices.ts index 0fba8a07..9201c23f 100644 --- a/packages/gpt.gblib/services/ChatServices.ts +++ b/packages/gpt.gblib/services/ChatServices.ts @@ -29,7 +29,7 @@ \*****************************************************************************/ 'use strict'; - +import { PromptTemplate } from '@langchain/core/prompts'; import { WikipediaQueryRun } from '@langchain/community/tools/wikipedia_query_run'; import { HNSWLib } from '@langchain/community/vectorstores/hnswlib'; import { BaseCallbackHandler } from '@langchain/core/callbacks/base'; @@ -48,7 +48,7 @@ import { convertToOpenAITool } from '@langchain/core/utils/function_calling'; import { ChatOpenAI, OpenAI } from '@langchain/openai'; import { SqlDatabaseChain } from 'langchain/chains/sql_db'; import { SqlDatabase } from 'langchain/sql_db'; -import {DataSource } from 'typeorm'; +import { DataSource } from 'typeorm'; import { GBMinInstance } from 'botlib'; import * as Fs from 'fs'; import { jsonSchemaToZod } from 'json-schema-to-zod'; @@ -469,20 +469,18 @@ export class ChatServices { }); } else if (LLMMode === 'sql') { // const con = min[`llmconnection`]; - + // const dialect = con['storageDriver']; // const host = con['storageServer']; // const port = con['storagePort']; // const storageName = con['storageName']; // const username = con['storageUsername']; // const password = con['storagePassword']; - - const dataSource = new DataSource({ - type: 'sqlite', - database: "/home/gbadmin3910/DATA/BotServer/work/frukigbot1.gbai/data.db", - }); - + const dataSource = new DataSource({ + type: 'sqlite', + database: '/home/gbadmin3910/DATA/BotServer/work/frukigbot1.gbai/data.db' + }); // const dataSource = new DataSource({ // type: dialect as any, @@ -491,7 +489,7 @@ export class ChatServices { // database: storageName, // username: username, // password: password, - // synchronize: false, + // synchronize: false, // logging: true, // }); @@ -499,12 +497,78 @@ export class ChatServices { appDataSource: dataSource }); - const chain = new SqlDatabaseChain({ - llm: model, - database: db, + const prompt = + PromptTemplate.fromTemplate(`Based on the provided SQL table schema below, write a SQL query that would answer the user's question. + ------------ + SCHEMA: {schema} + ------------ + QUESTION: {question} + ------------ + SQL QUERY:`); + + /** + * Create a new RunnableSequence where we pipe the output from `db.getTableInfo()` + * and the users question, into the prompt template, and then into the llm. + * We're also applying a stop condition to the llm, so that it stops when it + * sees the `\nSQLResult:` token. + */ + const sqlQueryChain = RunnableSequence.from([ + { + schema: async () => db.getTableInfo(), + question: (input: { question: string }) => input.question + }, + prompt, + model.bind({ stop: ['\nSQLResult:'] }), + new StringOutputParser() + ]); + + /** + * Create the final prompt template which is tasked with getting the natural + * language response to the SQL query. + */ + const finalResponsePrompt = + PromptTemplate.fromTemplate(`Based on the table schema below, question, SQL query, and SQL response, write a natural language response: + ------------ + SCHEMA: {schema} + ------------ + QUESTION: {question} + ------------ + SQL QUERY: {query} + ------------ + SQL RESPONSE: {response} + ------------ + NATURAL LANGUAGE RESPONSE:`); + + /** + * Create a new RunnableSequence where we pipe the output from the previous chain, the users question, + * and the SQL query, into the prompt template, and then into the llm. + * Using the result from the `sqlQueryChain` we can run the SQL query via `db.run(input.query)`. + * + * Lastly we're piping the result of the first chain (the outputted SQL query) so it is + * logged along with the natural language response. + */ + const finalChain = RunnableSequence.from([ + { + question: input => input.question, + query: sqlQueryChain + }, + { + schema: async () => db.getTableInfo(), + question: input => input.question, + query: input => input.query, + response: input => db.run(input.query) + }, + { + result: finalResponsePrompt.pipe(model).pipe(new StringOutputParser()), + // Pipe the query through here unchanged so it gets logged alongside the result. + sql: previousStepResult => previousStepResult.query + } + ]); + + result = await finalChain.invoke({ + question: question }); - result = await chain.run(question); } else if (LLMMode === 'nochain') { result = await (tools.length > 0 ? modelWithTools : model).invoke(` ${systemPrompt}