|
1 | 1 | import { aiCommand } from './decorators'; |
2 | | -import { AiProvider } from './providers/ai-provider'; |
| 2 | +import { AiProvider, EmptyAiProvider } from './providers/ai-provider'; |
| 3 | +import { getAtlasAiProvider } from './providers/atlas/atlas-ai-provider'; |
3 | 4 | import { getDocsAiProvider } from './providers/docs/docs-ai-provider'; |
| 5 | +import { |
| 6 | + getAiSdkProvider, |
| 7 | + models, |
| 8 | +} from './providers/generic/ai-sdk-provider'; |
| 9 | +import { Config, ConfigSchema } from './config'; |
| 10 | +import { CliContext, wrapAllFunctions, formatHelpCommands } from './helpers'; |
4 | 11 |
|
5 | 12 | class AI { |
6 | | - constructor( |
7 | | - private readonly cliContext: any, |
8 | | - private readonly ai: AiProvider, |
9 | | - ) { |
10 | | - const methods = Object.getOwnPropertyNames( |
11 | | - Object.getPrototypeOf(this), |
12 | | - ).filter((name) => { |
13 | | - const descriptor = Object.getOwnPropertyDescriptor( |
14 | | - Object.getPrototypeOf(this), |
15 | | - name, |
16 | | - ); |
17 | | - return ( |
18 | | - descriptor && |
19 | | - typeof descriptor.value === 'function' && |
20 | | - name !== 'constructor' |
21 | | - ); |
22 | | - }); |
| 13 | + private readonly replConfig: { |
| 14 | + set: (key: string, value: any) => Promise<void>; |
| 15 | + get: <T>(key: string) => Promise<T>; |
| 16 | + }; |
23 | 17 |
|
24 | | - // for all methods, wrap them with the wrapFunction method |
25 | | - for (const methodName of methods) { |
26 | | - const method = (this as any)[methodName]; |
27 | | - if (typeof method === 'function' && method.isDirectShellCommand) { |
28 | | - this.wrapFunction(methodName, method.bind(this)); |
29 | | - } |
30 | | - } |
| 18 | + private ai: AiProvider; |
| 19 | + public config: Config; |
| 20 | + |
| 21 | + constructor(private readonly cliContext: CliContext) { |
31 | 22 | const instanceState = this.cliContext.db._mongo._instanceState; |
32 | | - instanceState.registerPlugin(this); |
33 | 23 |
|
34 | | - this.wrapFunction(undefined, this.help.bind(this)); |
| 24 | + this.replConfig = { |
| 25 | + set: (key, value) => |
| 26 | + instanceState.evaluationListener.setConfig(`snippet_ai_${key}`, value), |
| 27 | + get: (key) => |
| 28 | + instanceState.evaluationListener.getConfig(`snippet_ai_${key}`), |
| 29 | + }; |
| 30 | + |
| 31 | + this.config = new Config(this.replConfig); |
| 32 | + |
| 33 | + // Set up provider change listener |
| 34 | + this.config.on('change', (event) => { |
| 35 | + switch (event.key) { |
| 36 | + case 'provider': |
| 37 | + this.ai = this.getProvider(event.value as ConfigSchema['provider']); |
| 38 | + break; |
| 39 | + case 'model': |
| 40 | + if (Object.keys(models).includes(event.value)) { |
| 41 | + this.ai = getAiSdkProvider( |
| 42 | + models[this.config.get('provider') as keyof typeof models]( |
| 43 | + event.value, |
| 44 | + ), |
| 45 | + this.cliContext, |
| 46 | + ); |
| 47 | + } else { |
| 48 | + throw new Error(`Invalid model: ${event.value}`); |
| 49 | + } |
| 50 | + break; |
| 51 | + default: |
| 52 | + break; |
| 53 | + } |
| 54 | + }); |
| 55 | + |
| 56 | + this.ai = this.getProvider(process.env.MONGOSH_AI_PROVIDER as ConfigSchema['provider'] | undefined); |
| 57 | + wrapAllFunctions(this.cliContext, this); |
| 58 | + |
| 59 | + this.setupConfig(); |
35 | 60 | } |
36 | 61 |
|
37 | | - private wrapFunction(name: string | undefined, fn: Function) { |
38 | | - const wrapperFn = (...args: string[]) => { |
39 | | - return Object.assign(fn(...args), { |
40 | | - [Symbol.for('@@mongosh.syntheticPromise')]: true, |
41 | | - }); |
42 | | - }; |
43 | | - wrapperFn.isDirectShellCommand = true; |
44 | | - wrapperFn.returnsPromise = true; |
| 62 | + async setupConfig() { |
| 63 | + await this.config.setup(); |
45 | 64 |
|
46 | | - const instanceState = this.cliContext.db._mongo._instanceState; |
| 65 | + this.ai = this.getProvider(this.config.get('provider')); |
| 66 | + } |
| 67 | + |
| 68 | + private getProvider(provider: ConfigSchema['provider'] | undefined): AiProvider { |
| 69 | + switch (provider) { |
| 70 | + case 'docs': |
| 71 | + return getDocsAiProvider(this.cliContext); |
| 72 | + case 'atlas': |
| 73 | + return getAtlasAiProvider(this.cliContext); |
| 74 | + case 'openai': |
| 75 | + case 'mistral': |
| 76 | + case 'ollama': |
| 77 | + const model = this.config.get('model'); |
| 78 | + return getAiSdkProvider( |
| 79 | + models[provider](model === 'default' ? undefined : model), |
| 80 | + this.cliContext, |
| 81 | + ); |
| 82 | + default: |
| 83 | + return new EmptyAiProvider(this.cliContext); |
| 84 | + } |
| 85 | + } |
47 | 86 |
|
48 | | - instanceState.shellApi[name ? `ai.${name}` : 'ai'] = instanceState.context[ |
49 | | - name ? `ai.${name}` : 'ai' |
50 | | - ] = wrapperFn; |
| 87 | + @aiCommand |
| 88 | + async command(prompt: string) { |
| 89 | + await this.ai.command(prompt); |
51 | 90 | } |
52 | 91 |
|
53 | 92 | @aiCommand |
54 | | - async query(code: string) { |
55 | | - return await this.ai.query(code); |
| 93 | + async query(prompt: string) { |
| 94 | + await this.ai.query(prompt); |
56 | 95 | } |
57 | 96 |
|
58 | 97 | @aiCommand |
59 | | - async ask(code: string) { |
60 | | - return await this.ai.ask(code); |
| 98 | + async ask(prompt: string) { |
| 99 | + await this.ai.ask(prompt); |
61 | 100 | } |
62 | 101 |
|
63 | 102 | @aiCommand |
64 | | - async aggregate(code: string) { |
65 | | - return await this.ai.aggregate(code); |
| 103 | + async aggregate(prompt: string) { |
| 104 | + await this.ai.aggregate(prompt); |
66 | 105 | } |
67 | 106 |
|
68 | 107 | @aiCommand |
69 | 108 | async help(...args: string[]) { |
| 109 | + const commands = [ |
| 110 | + { cmd: 'ai.ask', desc: 'ask questions', example: 'ai.ask how do I run queries in mongosh?' }, |
| 111 | + { cmd: 'ai.command', desc: 'generate any mongosh command', example: 'ai.command create a new database' }, |
| 112 | + { cmd: 'ai.query', desc: 'generate a MongoDB query', example: 'ai.query find documents where name = "Ada"' }, |
| 113 | + { cmd: 'ai.aggregate', desc: 'generate a MongoDB aggregation', example: 'ai.aggregate find documents where name = "Ada"' }, |
| 114 | + { cmd: 'ai.config', desc: 'configure the AI commands', example: 'ai.config.set("provider", "ollama")' } |
| 115 | + ]; |
| 116 | + |
| 117 | + this.ai.respond( |
| 118 | + formatHelpCommands( |
| 119 | + commands, |
| 120 | + this.config.get('provider'), |
| 121 | + this.config.get('model') |
| 122 | + ) |
| 123 | + ); |
| 124 | + } |
| 125 | + |
| 126 | + [Symbol.for('nodejs.util.inspect.custom')]() { |
70 | 127 | this.ai.help(); |
| 128 | + return ''; |
71 | 129 | } |
72 | 130 | } |
73 | 131 |
|
74 | 132 | module.exports = (globalThis: any) => { |
75 | | - globalThis.ai = new AI(globalThis, getDocsAiProvider(globalThis)); |
| 133 | + globalThis.ai = new AI(globalThis); |
76 | 134 | }; |
0 commit comments