diff --git a/libs/core-ui/src/index.ts b/libs/core-ui/src/index.ts index c67e37bb25..c9f63fc47a 100644 --- a/libs/core-ui/src/index.ts +++ b/libs/core-ui/src/index.ts @@ -95,6 +95,7 @@ export * from "./lib/Interfaces/IModelExplanationData"; export * from "./lib/Interfaces/IConfusionMatrixData"; export * from "./lib/Interfaces/IVisionModelExplanationData"; export * from "./lib/Interfaces/IWeightedDropdownContext"; +export * from "./lib/Interfaces/ITokenDropdownContext"; export * from "./lib/Interfaces/IFilter"; export * from "./lib/Interfaces/IPreBuiltFilter"; export * from "./lib/Interfaces/ICohort"; diff --git a/libs/core-ui/src/lib/Interfaces/ITokenDropdownContext.ts b/libs/core-ui/src/lib/Interfaces/ITokenDropdownContext.ts new file mode 100644 index 0000000000..79ae0d7702 --- /dev/null +++ b/libs/core-ui/src/lib/Interfaces/ITokenDropdownContext.ts @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { IComboBox, IComboBoxOption } from "@fluentui/react"; +import React from "react"; + +export type TokenOption = number + + +export interface ITokenDropdownContext { + options: IComboBoxOption[]; + selectedKey: TokenOption; + onSelection: ( + event: React.FormEvent, + item?: IComboBoxOption + ) => void; +} diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/CommonUtils.ts b/libs/interpret-text/src/lib/TextExplanationDashboard/CommonUtils.ts index f5d705c2d7..03ca37e2de 100644 --- a/libs/interpret-text/src/lib/TextExplanationDashboard/CommonUtils.ts +++ b/libs/interpret-text/src/lib/TextExplanationDashboard/CommonUtils.ts @@ -10,6 +10,11 @@ export enum RadioKeys { Neg = "neg" } +export enum QAExplanationType { + Start = "start", + End = "end" +} + export class Utils { public static argsort(toSort: number[]): number[] { /* diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.tsx b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.tsx index 31bf384112..70023e1a6e 100644 --- a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.tsx +++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.tsx @@ -11,11 +11,11 @@ import { Text } from "@fluentui/react"; import { WeightVectorOption, WeightVectors } from "@responsible-ai/core-ui"; -import { ClassImportanceWeights } from "@responsible-ai/interpret"; +import { ClassImportanceWeights, TokenImportance } from "@responsible-ai/interpret"; import { localization } from "@responsible-ai/localization"; import React from "react"; -import { RadioKeys, Utils } from "../../CommonUtils"; +import { RadioKeys, QAExplanationType, Utils } from "../../CommonUtils"; import { ITextExplanationViewProps } from "../../Interfaces/IExplanationViewProps"; import { BarChart } from "../BarChart/BarChart"; import { TextFeatureLegend } from "../TextFeatureLegend/TextFeatureLegend"; @@ -30,8 +30,13 @@ export interface ITextExplanationViewState { maxK: number; topK: number; radio: string; + qaRadio?: string; importances: number[]; + singleTokenImportances: number[]; + selectedToken: number; + tokenIndexes: number[]; text: string[]; + isQA: boolean // temporal flag for identifying QA } const options: IChoiceGroupOption[] = [ @@ -43,6 +48,15 @@ const options: IChoiceGroupOption[] = [ { key: RadioKeys.Neg, text: localization.InterpretText.View.negButton } ]; +const qaOptions: IChoiceGroupOption[] = [ + /* + * Creates the choices for the QA prediction radio button(local testing) + * TODO: move text under localization.InterpretText.View + */ + { key: QAExplanationType.Start, text: "STARTING POSITION" }, + { key: QAExplanationType.End, text: "ENDING POSITION" }, +]; + const componentStackTokens: IStackTokens = { childrenGap: "m", padding: "m" @@ -59,19 +73,28 @@ export class TextExplanationView extends React.PureComponent< * Initializes the text view with its state */ super(props); - const weightVector = this.props.selectedWeightVector; - const importances = this.computeImportancesForWeightVector( + + const isQA = false; // FIXME: temporally hardcode the flag, should use prop instead + + const importances = this.computeImportancesForAllTokens( this.props.dataSummary.localExplanations, - weightVector ); + + const selectedToken = 0; // default to the first token + const singleTokenImportances = this.getImportanceForSingleToken(selectedToken); const maxK = this.calculateMaxKImportances(importances); const topK = this.calculateTopKImportances(importances); this.state = { - importances, + importances: importances, + singleTokenImportances: singleTokenImportances, + selectedToken: selectedToken, + tokenIndexes: Array.from(this.props.dataSummary.text, (_, index) => index), maxK, radio: RadioKeys.All, + qaRadio: "starting", text: this.props.dataSummary.text, - topK + topK, + isQA }; } @@ -81,16 +104,37 @@ export class TextExplanationView extends React.PureComponent< this.props.dataSummary.localExplanations !== prevProps.dataSummary.localExplanations ) { - this.updateImportances(this.props.selectedWeightVector); + if (this.state.isQA) { + this.setState({ //update token dropdown + tokenIndexes: Array.from(this.props.dataSummary.text, (_, index) => index), + selectedToken: 0 + }, + () => { + this.updateTokenImportances(); + this.updateSingleTokenImportances(); + }) + } else { + this.updateImportances(this.props.selectedWeightVector); + } } } public render(): React.ReactNode { const classNames = textExplanationDashboardStyles(); + const qaDescription = 'The left text box and the bar chart display the predictions of the model.' + + 'The right textbox shows the feature importance associated with a selected token. Positive feature ' + + 'importances represent the extent that the words were important towards marking the selected token' + + 'as the starting/ending position of the answer.' + return ( - {localization.InterpretText.View.legendText} + { + this.state.isQA? + {qaDescription} : + {localization.InterpretText.View.legendText} + } + @@ -103,6 +147,8 @@ export class TextExplanationView extends React.PureComponent< + + { !this.state.isQA && ( // classfication {localization.InterpretText.View.label + @@ -113,9 +159,23 @@ export class TextExplanationView extends React.PureComponent< )} + )} + + { this.state.isQA && ( // select starting/ending for QA + + + + )} + + - + + { this.state.isQA? + ( + + ) + : + ( + ) + } + {this.props.selectedWeightVector !== WeightVectors.AbsAvg && ( + + + { this.state.isQA && ( + + + + )} + + ); } @@ -174,11 +270,21 @@ export class TextExplanationView extends React.PureComponent< this.props.onWeightChange(weightOption); }; + private onSelectedTokenChange = (newIndex: number): void => { + + this.setState( + { selectedToken: newIndex }, + () => { + this.updateSingleTokenImportances(); + }); + }; + private updateImportances(weightOption: WeightVectorOption): void { const importances = this.computeImportancesForWeightVector( this.props.dataSummary.localExplanations, weightOption ); + const topK = this.calculateTopKImportances(importances); const maxK = this.calculateMaxKImportances(importances); this.setState({ @@ -189,6 +295,28 @@ export class TextExplanationView extends React.PureComponent< }); } + + // for QA + private updateTokenImportances(): void { + + const importances = this.computeImportancesForAllTokens( + this.props.dataSummary.localExplanations, + ); + const topK = this.calculateTopKImportances(importances); + const maxK = this.calculateMaxKImportances(importances); + this.setState({ + importances, + maxK, + topK, + text: this.props.dataSummary.text + }); + } + + private updateSingleTokenImportances(): void { + const singleTokenImportances = this.getImportanceForSingleToken(this.state.selectedToken); + this.setState({singleTokenImportances: singleTokenImportances}); + } + private calculateTopKImportances(importances: number[]): number { return Math.min( MaxImportantWords, @@ -222,6 +350,27 @@ export class TextExplanationView extends React.PureComponent< ); } + private computeImportancesForAllTokens( + importances: number[][] + ): number[] { + /* + * sum the tokens importance + * TODO: add base values? + */ + + const sumImportances = importances[0].map((_, index) => + importances.reduce((sum, row) => sum + row[index], 0) + ); + + return sumImportances; + } + + private getImportanceForSingleToken( + index: number + ): number[] { + return this.props.dataSummary.localExplanations.map(row => row[index]); + } + private setTopK = (newNumber: number): void => { /* * Changes the state of K @@ -240,4 +389,18 @@ export class TextExplanationView extends React.PureComponent< this.setState({ radio: item.key }); } }; + + private switchQAprediction = ( + _event?: React.FormEvent, + item?: IChoiceGroupOption + ): void => { + /* + * switch to the target predictions(starting or ending) + * TODO: add logic for switching explanation data + */ + if (item?.key !== undefined) { + this.setState({ qaRadio: item.key }); + } + }; + } diff --git a/libs/interpret/src/index.ts b/libs/interpret/src/index.ts index 0ee6695c72..f30eda7c5b 100644 --- a/libs/interpret/src/index.ts +++ b/libs/interpret/src/index.ts @@ -6,6 +6,7 @@ export * from "./lib/MLIDashboard/NewExplanationDashboard"; export * from "./lib/MLIDashboard/Interfaces/IExplanationDashboardProps"; export * from "./lib/MLIDashboard/Interfaces/IStringsParam"; export * from "./lib/MLIDashboard/Controls/ClassImportanceWeights/ClassImportanceWeights"; +export * from "./lib/MLIDashboard/Controls/TokenImportance/TokenImportance"; export * from "./lib/MLIDashboard/Controls/GlobalExplanationTab/GlobalExplanationTab"; export * from "./lib/MLIDashboard/Controls/GlobalExplanationTab/IGlobalSeries"; export * from "./lib/MLIDashboard/Controls/ModelPerformanceTab/ModelPerformanceTab"; diff --git a/libs/interpret/src/lib/MLIDashboard/Controls/TokenImportance/TokenImportance.styles.ts b/libs/interpret/src/lib/MLIDashboard/Controls/TokenImportance/TokenImportance.styles.ts new file mode 100644 index 0000000000..f10de55dcc --- /dev/null +++ b/libs/interpret/src/lib/MLIDashboard/Controls/TokenImportance/TokenImportance.styles.ts @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { IStyle, mergeStyleSets, IProcessedStyleSet } from "@fluentui/react"; + +export interface ILabelWithCalloutStyles { + tokenLabel: IStyle; + tokenLabelText: IStyle; +} + +export const tokenImportanceStyles: () => IProcessedStyleSet = + () => { + return mergeStyleSets({ + tokenLabel: { + display: "inline-flex", + paddingTop: "10px" + }, + tokenLabelText: { + fontWeight: "600", + paddingTop: "5px" + } + }); + }; diff --git a/libs/interpret/src/lib/MLIDashboard/Controls/TokenImportance/TokenImportance.tsx b/libs/interpret/src/lib/MLIDashboard/Controls/TokenImportance/TokenImportance.tsx new file mode 100644 index 0000000000..f8e6683cac --- /dev/null +++ b/libs/interpret/src/lib/MLIDashboard/Controls/TokenImportance/TokenImportance.tsx @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { + Dropdown, + IDropdownOption, + Text + } from "@fluentui/react"; + import { TokenOption } from "@responsible-ai/core-ui"; + import React from "react"; + + import { tokenImportanceStyles } from "./TokenImportance.styles"; + + export interface ITokenImportanceProps { + onTokenChange: (option: TokenOption) => void; + selectedToken: TokenOption; + tokenOptions: TokenOption[]; + tokenLabels: any; + disabled?: boolean; + } + interface ITokenImportanceState { + tokenOptionsMap: { key: TokenOption; text: any }[]; + } + + export class TokenImportance extends React.Component< + ITokenImportanceProps, + ITokenImportanceState + > { + public constructor(props: ITokenImportanceProps) { + super(props); + + this.state = { + tokenOptionsMap: this.props.tokenOptions?.map((option) => { + return { + key: option, + text: this.props.tokenLabels[option] + }; + }) + }; + } + + public componentDidUpdate(prevProps: ITokenImportanceProps): void { + if (this.props.tokenOptions !== prevProps.tokenOptions || + this.props.tokenLabels !== prevProps.tokenLabels + ) { + this.setState({ + tokenOptionsMap: this.props.tokenOptions?.map((option) => { + return { + key: option, + text: this.props.tokenLabels[option] + }; + }) + }) + } + } + + public render(): React.ReactNode { + const tokenNames = tokenImportanceStyles(); + return ( +
+
+ + {"Selected Token"} + + +
+ {this.state.tokenOptionsMap && ( + + )} +
+ ); + } + + private setTokenOption = ( + _event: React.FormEvent, + item?: IDropdownOption + ): void => { + if (item?.key === undefined) { + return; + } + + const newIndex = item.key as TokenOption; + this.props.onTokenChange(newIndex); + }; + } +