diff --git a/packages/modeldb-sqlite/src/ModelDB.ts b/packages/modeldb-sqlite/src/ModelDB.ts index 58b6b9d14..73861cb04 100644 --- a/packages/modeldb-sqlite/src/ModelDB.ts +++ b/packages/modeldb-sqlite/src/ModelDB.ts @@ -68,7 +68,10 @@ export class ModelDB extends AbstractModelDB { return api.get(key) as T | null } - public async *iterate = ModelValue>(modelName: string): AsyncIterable { + public async *iterate = ModelValue>( + modelName: string, + query: QueryParams = {}, + ): AsyncIterable { const api = this.#models[modelName] assert(api !== undefined, `model ${modelName} not found`) yield* api.values() as AsyncIterable diff --git a/packages/modeldb-sqlite/src/api.ts b/packages/modeldb-sqlite/src/api.ts index 18a41e15e..e1ed30d7c 100644 --- a/packages/modeldb-sqlite/src/api.ts +++ b/packages/modeldb-sqlite/src/api.ts @@ -215,6 +215,39 @@ export class ModelAPI { } public query(query: QueryParams): ModelValue[] { + const [sql, relations, params] = this.parseQuery(query) + const results = this.db.prepare(sql).all(params) as RecordValue[] + return results.map((record): ModelValue => { + const key = record[this.#primaryKeyName] + assert(typeof key === "string", 'expected typeof primaryKey === "string"') + + const value: ModelValue = {} + for (const [propertyName, propertyValue] of Object.entries(record)) { + const property = this.#properties[propertyName] + if (property.kind === "primary") { + value[propertyName] = decodePrimaryKeyValue(this.model.name, property, propertyValue) + } else if (property.kind === "primitive") { + value[propertyName] = decodePrimitiveValue(this.model.name, property, propertyValue) + } else if (property.kind === "reference") { + value[propertyName] = decodeReferenceValue(this.model.name, property, propertyValue) + } else if (property.kind === "relation") { + throw new Error("internal error") + } else { + signalInvalidType(property) + } + } + + for (const relation of relations) { + value[relation.property] = this.#relations[relation.property].get(key) + } + + return value + }) + } + + private parseQuery( + query: QueryParams, + ): [sql: string, relations: Relation[], params: Record] { // See https://www.sqlite.org/lang_select.html for railroad diagram const sql: string[] = [] @@ -262,33 +295,7 @@ export class ModelAPI { params.limit = query.offset } - const results = this.db.prepare(sql.join(" ")).all(params) as RecordValue[] - return results.map((record): ModelValue => { - const key = record[this.#primaryKeyName] - assert(typeof key === "string", 'expected typeof primaryKey === "string"') - - const value: ModelValue = {} - for (const [propertyName, propertyValue] of Object.entries(record)) { - const property = this.#properties[propertyName] - if (property.kind === "primary") { - value[propertyName] = decodePrimaryKeyValue(this.model.name, property, propertyValue) - } else if (property.kind === "primitive") { - value[propertyName] = decodePrimitiveValue(this.model.name, property, propertyValue) - } else if (property.kind === "reference") { - value[propertyName] = decodeReferenceValue(this.model.name, property, propertyValue) - } else if (property.kind === "relation") { - throw new Error("internal error") - } else { - signalInvalidType(property) - } - } - - for (const relation of relations) { - value[relation.property] = this.#relations[relation.property].get(key) - } - - return value - }) + return [sql.join(" "), relations, params] } private getSelectExpression(