Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import org.dicio.skill.skill.SkillInfo
import org.dicio.skill.skill.SkillOutput
import org.dicio.skill.standard.StandardRecognizerData
import org.dicio.skill.standard.StandardRecognizerSkill
import org.dicio.skill.standard.util.MatchHelper
import org.stypox.dicio.sentences.Sentences.Calculator
import org.stypox.dicio.sentences.Sentences.CalculatorOperators
import java.text.DecimalFormat
Expand All @@ -20,7 +21,8 @@ class CalculatorSkill(correspondingSkillInfo: SkillInfo, data: StandardRecognize
operatorSection: StandardRecognizerData<CalculatorOperators>,
text: String
): CalculatorOperators? {
val (score, result) = operatorSection.score(ctx, text)
val helper = MatchHelper(ctx.parserFormatter, text)
val (score, result) = operatorSection.score(helper, text)
Comment on lines +24 to +25
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this makes sense anyway, it was using a helper set up with the wrong text earlier

return if (score.scoreIn01Range() < 0.3) {
null
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.dicio.skill.standard
import org.dicio.skill.context.SkillContext
import org.dicio.skill.skill.Specificity
import org.dicio.skill.standard.construct.Construct
import org.dicio.skill.standard.util.MatchHelper
import org.dicio.skill.standard.util.initialMemToEnd

open class StandardRecognizerData<out T>(
Expand All @@ -11,7 +12,10 @@ open class StandardRecognizerData<out T>(
private val sentencesWithId: List<Pair<String, Construct>>,
) {
fun score(ctx: SkillContext, input: String): Pair<StandardScore, T> {
val helper = ctx.standardMatchHelper!! // surely != null, see its javadoc
return score(ctx.standardMatchHelper!!, input)
}

fun score(helper: MatchHelper, input: String): Pair<StandardScore, T> {
val cumulativeWeight = helper.cumulativeWeight

var bestRes: Pair<String, StandardScore>? = null
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package org.dicio.skill.standard

import io.kotest.core.spec.style.StringSpec
import io.kotest.matchers.shouldNotBe
import org.dicio.skill.skill.Specificity
import org.dicio.skill.standard.construct.WordConstruct
import org.dicio.skill.standard.construct.CompositeConstruct
import org.dicio.skill.standard.util.MatchHelper

/**
* Regression test: StandardRecognizerData.score(MatchHelper, input) must work
* without ctx.standardMatchHelper being set. This is needed when skills call
* score() during generateOutput (e.g. CalculatorSkill matching operators).
*/
class StandardRecognizerDataTest : StringSpec({

"score with explicit MatchHelper does not require ctx" {
// Build a simple recognizer that matches the word "plus"
val construct = CompositeConstruct(
listOf(WordConstruct("plus", false, false, 1.0f))
)
val data = StandardRecognizerData(
specificity = Specificity.HIGH,
converter = { input, sentenceId, _ -> sentenceId },
sentencesWithId = listOf(Pair("plus_sentence", construct)),
)

val helper = MatchHelper(parserFormatter = null, userInput = "plus")
val (score, result) = data.score(helper, "plus")

score shouldNotBe null
result shouldNotBe null
}

"score with explicit MatchHelper returns low score for non-matching input" {
val construct = CompositeConstruct(
listOf(WordConstruct("plus", false, false, 1.0f))
)
val data = StandardRecognizerData(
specificity = Specificity.HIGH,
converter = { input, sentenceId, _ -> sentenceId },
sentencesWithId = listOf(Pair("plus_sentence", construct)),
)

val helper = MatchHelper(parserFormatter = null, userInput = "banana")
val (score, _) = data.score(helper, "banana")

// should score poorly since "banana" doesn't match "plus"
assert(score.scoreIn01Range() < 0.5f)
}
})
Loading