new(basic.gblib): New batch features.

This commit is contained in:
Rodrigo Rodriguez 2024-08-13 21:12:58 -03:00
parent 139b28337e
commit 3b03cf4bcd

View file

@ -69,7 +69,7 @@ import {
SQL_SQLITE_PROMPT, SQL_SQLITE_PROMPT,
SQL_MSSQL_PROMPT, SQL_MSSQL_PROMPT,
SQL_MYSQL_PROMPT SQL_MYSQL_PROMPT
} from "langchain/chains/sql_db"; } from 'langchain/chains/sql_db';
export interface CustomOutputParserFields {} export interface CustomOutputParserFields {}
export type ExpectedOutput = any; export type ExpectedOutput = any;
@ -525,9 +525,9 @@ export class ChatServices {
const sqlQueryChain = RunnableSequence.from([ const sqlQueryChain = RunnableSequence.from([
{ {
schema: async () => db.getTableInfo(), schema: async () => db.getTableInfo(),
question: (input: { question: string }) => input.question , question: (input: { question: string }) => input.question,
top_k: ()=>10, top_k: () => 10,
table_info: ()=>'any' table_info: () => 'any'
}, },
prompt, prompt,
model, model,
@ -538,9 +538,9 @@ export class ChatServices {
* Create the final prompt template which is tasked with getting the natural * Create the final prompt template which is tasked with getting the natural
* language response to the SQL query. * language response to the SQL query.
*/ */
const finalResponsePrompt = SQL_SQLITE_PROMPT; const finalResponsePrompt =
PromptTemplate.fromTemplate(`Based on the table schema below, question, SQL query, and SQL response, write a natural language response: PromptTemplate.fromTemplate(`Based on the table schema below, question, SQL query, and SQL response, write a natural language response:
Optimize answers for KPI people. Optimize answers for KPI people. ${systemPrompt}
------------ ------------
SCHEMA: {schema} SCHEMA: {schema}
------------ ------------
@ -563,14 +563,15 @@ Optimize answers for KPI people.
const finalChain = RunnableSequence.from([ const finalChain = RunnableSequence.from([
{ {
input: input => input.question, input: input => input.question,
query: sqlQueryChain, query: sqlQueryChain
}, },
{ {
schema: async () => db.getTableInfo(), schema: async () => db.getTableInfo(),
input: input => input.question, input: input => input.question,
query: input => input.query, query: input => input.query,
response: input => db.run(input.query), response: input => db.run(input.query),
top_k: ()=>10, table_info: ()=>'any' top_k: () => 10,
table_info: () => 'any'
}, },
{ {
result: finalResponsePrompt.pipe(model).pipe(new StringOutputParser()), result: finalResponsePrompt.pipe(model).pipe(new StringOutputParser()),
@ -582,7 +583,6 @@ Optimize answers for KPI people.
result = await finalChain.invoke({ result = await finalChain.invoke({
question: question question: question
}); });
} else if (LLMMode === 'nochain') { } else if (LLMMode === 'nochain') {
result = await (tools.length > 0 ? modelWithTools : model).invoke(` result = await (tools.length > 0 ? modelWithTools : model).invoke(`
${systemPrompt} ${systemPrompt}