diff --git a/app/src/main/kotlin/org/stypox/dicio/skills/calculator/CalculatorSkill.kt b/app/src/main/kotlin/org/stypox/dicio/skills/calculator/CalculatorSkill.kt index c2927d1c..d0d4f12c 100644 --- a/app/src/main/kotlin/org/stypox/dicio/skills/calculator/CalculatorSkill.kt +++ b/app/src/main/kotlin/org/stypox/dicio/skills/calculator/CalculatorSkill.kt @@ -7,6 +7,7 @@ import org.dicio.skill.skill.SkillInfo import org.dicio.skill.skill.SkillOutput import org.dicio.skill.standard.StandardRecognizerData import org.dicio.skill.standard.StandardRecognizerSkill +import org.dicio.skill.standard.util.MatchHelper import org.stypox.dicio.sentences.Sentences.Calculator import org.stypox.dicio.sentences.Sentences.CalculatorOperators import java.text.DecimalFormat @@ -20,7 +21,8 @@ class CalculatorSkill(correspondingSkillInfo: SkillInfo, data: StandardRecognize operatorSection: StandardRecognizerData, text: String ): CalculatorOperators? { - val (score, result) = operatorSection.score(ctx, text) + val helper = MatchHelper(ctx.parserFormatter, text) + val (score, result) = operatorSection.score(helper, text) return if (score.scoreIn01Range() < 0.3) { null } else { diff --git a/skill/src/main/java/org/dicio/skill/standard/StandardRecognizerData.kt b/skill/src/main/java/org/dicio/skill/standard/StandardRecognizerData.kt index 1bce4076..7c5cb36a 100644 --- a/skill/src/main/java/org/dicio/skill/standard/StandardRecognizerData.kt +++ b/skill/src/main/java/org/dicio/skill/standard/StandardRecognizerData.kt @@ -3,6 +3,7 @@ package org.dicio.skill.standard import org.dicio.skill.context.SkillContext import org.dicio.skill.skill.Specificity import org.dicio.skill.standard.construct.Construct +import org.dicio.skill.standard.util.MatchHelper import org.dicio.skill.standard.util.initialMemToEnd open class StandardRecognizerData( @@ -11,7 +12,10 @@ open class StandardRecognizerData( private val sentencesWithId: List>, ) { fun score(ctx: SkillContext, input: String): Pair { - val helper = ctx.standardMatchHelper!! // surely != null, see its javadoc + return score(ctx.standardMatchHelper!!, input) + } + + fun score(helper: MatchHelper, input: String): Pair { val cumulativeWeight = helper.cumulativeWeight var bestRes: Pair? = null diff --git a/skill/src/test/java/org/dicio/skill/standard/StandardRecognizerDataTest.kt b/skill/src/test/java/org/dicio/skill/standard/StandardRecognizerDataTest.kt new file mode 100644 index 00000000..a82c6011 --- /dev/null +++ b/skill/src/test/java/org/dicio/skill/standard/StandardRecognizerDataTest.kt @@ -0,0 +1,51 @@ +package org.dicio.skill.standard + +import io.kotest.core.spec.style.StringSpec +import io.kotest.matchers.shouldNotBe +import org.dicio.skill.skill.Specificity +import org.dicio.skill.standard.construct.WordConstruct +import org.dicio.skill.standard.construct.CompositeConstruct +import org.dicio.skill.standard.util.MatchHelper + +/** + * Regression test: StandardRecognizerData.score(MatchHelper, input) must work + * without ctx.standardMatchHelper being set. This is needed when skills call + * score() during generateOutput (e.g. CalculatorSkill matching operators). + */ +class StandardRecognizerDataTest : StringSpec({ + + "score with explicit MatchHelper does not require ctx" { + // Build a simple recognizer that matches the word "plus" + val construct = CompositeConstruct( + listOf(WordConstruct("plus", false, false, 1.0f)) + ) + val data = StandardRecognizerData( + specificity = Specificity.HIGH, + converter = { input, sentenceId, _ -> sentenceId }, + sentencesWithId = listOf(Pair("plus_sentence", construct)), + ) + + val helper = MatchHelper(parserFormatter = null, userInput = "plus") + val (score, result) = data.score(helper, "plus") + + score shouldNotBe null + result shouldNotBe null + } + + "score with explicit MatchHelper returns low score for non-matching input" { + val construct = CompositeConstruct( + listOf(WordConstruct("plus", false, false, 1.0f)) + ) + val data = StandardRecognizerData( + specificity = Specificity.HIGH, + converter = { input, sentenceId, _ -> sentenceId }, + sentencesWithId = listOf(Pair("plus_sentence", construct)), + ) + + val helper = MatchHelper(parserFormatter = null, userInput = "banana") + val (score, _) = data.score(helper, "banana") + + // should score poorly since "banana" doesn't match "plus" + assert(score.scoreIn01Range() < 0.5f) + } +})