1 | |
/**
 * Cohere resource credentials.
 *
 * `apiKey` is the secret used to authenticate API requests.
 */
type Cohere = { apiKey: string };
5 | |
/**
 * Trains and deploys a fine-tuned model.
 *
 * Sends `body` as JSON in a POST request to Cohere's
 * `/v1/finetuning/finetuned-models` endpoint, authenticating with the
 * resource's API key as a bearer token.
 *
 * @param auth - Cohere resource whose `apiKey` is sent in the `Authorization` header.
 * @param body - Fine-tuned model definition; serialized verbatim with `JSON.stringify`.
 * @returns The parsed JSON body of the successful response.
 * @throws Error whose message is `"<status> <response text>"` when the
 *   response is not OK (status outside 200-299).
 */
export async function main(
  auth: Cohere,
  body: {
    id?: string;
    name?: string;
    creator_id?: string;
    organization_id?: string;
    settings?: {
      base_model: {
        name?: string;
        version?: string;
        base_type:
          | "BASE_TYPE_UNSPECIFIED"
          | "BASE_TYPE_GENERATIVE"
          | "BASE_TYPE_CLASSIFICATION"
          | "BASE_TYPE_RERANK"
          | "BASE_TYPE_CHAT";
        strategy?:
          | "STRATEGY_UNSPECIFIED"
          | "STRATEGY_VANILLA"
          | "STRATEGY_TFEW";
      };
      dataset_id: string;
      hyperparameters?: {
        early_stopping_patience?: number;
        early_stopping_threshold?: number;
        train_batch_size?: number;
        train_epochs?: number;
        learning_rate?: number;
        lora_alpha?: number;
        lora_rank?: number;
        lora_target_modules?:
          | "LORA_TARGET_MODULES_UNSPECIFIED"
          | "LORA_TARGET_MODULES_QV"
          | "LORA_TARGET_MODULES_QKVO"
          | "LORA_TARGET_MODULES_QKVO_FFN";
      };
      multi_label?: boolean;
      wandb?: { project: string; api_key: string; entity?: string };
    };
    status?:
      | "STATUS_UNSPECIFIED"
      | "STATUS_FINETUNING"
      | "STATUS_DEPLOYING_API"
      | "STATUS_READY"
      | "STATUS_FAILED"
      | "STATUS_DELETED"
      | "STATUS_TEMPORARILY_OFFLINE"
      | "STATUS_PAUSED"
      | "STATUS_QUEUED";
    created_at?: string;
    updated_at?: string;
    completed_at?: string;
    last_used?: string;
  },
) {
  const url = new URL("https://api.cohere.com/v1/finetuning/finetuned-models");

  const response = await fetch(url, {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Authorization: `Bearer ${auth.apiKey}`,
    },
    body: JSON.stringify(body),
  });
  if (!response.ok) {
    // Include the status and the raw error body so callers can see why the request was rejected.
    const text = await response.text();
    throw new Error(`${response.status} ${text}`);
  }
  return await response.json();
}
81 |
|