diff --git a/src/hf.ts b/src/hf.ts index 6e67046..8aabeeb 100644 --- a/src/hf.ts +++ b/src/hf.ts @@ -1,4 +1,5 @@ interface ModelMappingItem { + _id?: string; task: string; hfModel: string; providerModel: string; @@ -6,6 +7,7 @@ interface ModelMappingItem { } interface TagFilterMappingItem { + _id?: string; task: string; providerModel: string; status?: 'live' | 'staging'; @@ -122,6 +124,32 @@ class HFInferenceProviderClient { ); } + async getMappingIdByHfModel(hfModel: string): Promise { + const mappings = await this.listMappingItems(); + for (const taskMappings of Object.values(mappings)) { + const mapping = taskMappings[hfModel]; + if (mapping && mapping._id) { + return mapping._id; + } + } + return undefined; + } + + async createHfModelToMappingIdMap(): Promise> { + const mappings = await this.listMappingItems(); + const map = new Map(); + + for (const taskMappings of Object.values(mappings)) { + for (const [hfModel, mapping] of Object.entries(taskMappings)) { + if (mapping._id) { + map.set(hfModel, mapping._id); + } + } + } + + return map; + } + async getMappingsByProvider(provider: string): Promise>> { const url = `${this.baseUrl}/api/partners/${provider}/models`; return this.request>>(url, { diff --git a/src/index.ts b/src/index.ts index eec7857..8be8782 100644 --- a/src/index.ts +++ b/src/index.ts @@ -81,6 +81,8 @@ if (unsupportedModels.length > 0) { // Use only supported models for mapping operations const existingHFModelIds = await hf.listMappingIds(); +// Create a lookup map for efficient mapping ID retrieval during status updates +const hfModelToMappingIdMap = await hf.createHfModelToMappingIdMap(); console.log("\n\nExisting HF model IDs:"); console.log(existingHFModelIds); @@ -102,10 +104,23 @@ if (existingMappings.length > 0) { console.log(`\n\nUpdating statuses for ${existingMappings.length} existing mappings:`); for (const model of existingMappings) { console.log(`${model.hfModel} - ${model.status}`); - await hf.updateMappingItemStatus({ - hfModel: model.hfModel, - status: model.status, - }); + const mappingId = hfModelToMappingIdMap.get(model.hfModel); + if (mappingId) { + try { + await hf.updateMappingItemStatus({ + mappingId: mappingId, + status: model.status, + }); + } catch (error) { + if (error instanceof Error && error.message.includes('does not support task')) { + console.log(`Skipping ${model.hfModel}: ${model.task} task not supported for status updates`); + } else { + throw error; + } + } + } else { + console.error(`Could not find mapping ID for ${model.hfModel}`); + } } } else { console.log("\n\nNo existing mappings to update.");