From 635f3ec0745868a4e737bf0cdc889e66deeaad8d Mon Sep 17 00:00:00 2001 From: Neutrinowo <86736305+JarvisG495@users.noreply.github.com> Date: Thu, 1 Jun 2023 15:40:26 -0400 Subject: [PATCH 1/3] add individual feature importance for QA temporally use a dropdown to select token --- libs/core-ui/src/index.ts | 1 + .../lib/Interfaces/ITokenDropdownContext.ts | 17 ++ .../TextExplanationView.tsx | 165 +++++++++++++++++- libs/interpret/src/index.ts | 1 + .../TokenImportance/TokenImportance.styles.ts | 23 +++ .../TokenImportance/TokenImportance.tsx | 95 ++++++++++ 6 files changed, 295 insertions(+), 7 deletions(-) create mode 100644 libs/core-ui/src/lib/Interfaces/ITokenDropdownContext.ts create mode 100644 libs/interpret/src/lib/MLIDashboard/Controls/TokenImportance/TokenImportance.styles.ts create mode 100644 libs/interpret/src/lib/MLIDashboard/Controls/TokenImportance/TokenImportance.tsx diff --git a/libs/core-ui/src/index.ts b/libs/core-ui/src/index.ts index c67e37bb25..c9f63fc47a 100644 --- a/libs/core-ui/src/index.ts +++ b/libs/core-ui/src/index.ts @@ -95,6 +95,7 @@ export * from "./lib/Interfaces/IModelExplanationData"; export * from "./lib/Interfaces/IConfusionMatrixData"; export * from "./lib/Interfaces/IVisionModelExplanationData"; export * from "./lib/Interfaces/IWeightedDropdownContext"; +export * from "./lib/Interfaces/ITokenDropdownContext"; export * from "./lib/Interfaces/IFilter"; export * from "./lib/Interfaces/IPreBuiltFilter"; export * from "./lib/Interfaces/ICohort"; diff --git a/libs/core-ui/src/lib/Interfaces/ITokenDropdownContext.ts b/libs/core-ui/src/lib/Interfaces/ITokenDropdownContext.ts new file mode 100644 index 0000000000..79ae0d7702 --- /dev/null +++ b/libs/core-ui/src/lib/Interfaces/ITokenDropdownContext.ts @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { IComboBox, IComboBoxOption } from "@fluentui/react"; +import React from "react"; + +export type TokenOption = number + + +export interface ITokenDropdownContext { + options: IComboBoxOption[]; + selectedKey: TokenOption; + onSelection: ( + event: React.FormEvent, + item?: IComboBoxOption + ) => void; +} diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.tsx b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.tsx index 31bf384112..74a790b45f 100644 --- a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.tsx +++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.tsx @@ -10,8 +10,9 @@ import { Stack, Text } from "@fluentui/react"; -import { WeightVectorOption, WeightVectors } from "@responsible-ai/core-ui"; -import { ClassImportanceWeights } from "@responsible-ai/interpret"; +//import { WeightVectorOption, WeightVectors } from "@responsible-ai/core-ui"; +import { WeightVectors } from "@responsible-ai/core-ui"; +import { TokenImportance } from "@responsible-ai/interpret"; import { localization } from "@responsible-ai/localization"; import React from "react"; @@ -30,7 +31,11 @@ export interface ITextExplanationViewState { maxK: number; topK: number; radio: string; + qaRadio?: string; importances: number[]; + singleTokenImportances: number[]; + selectedToken: number; + tokenIndexes: number[]; text: string[]; } @@ -43,6 +48,15 @@ const options: IChoiceGroupOption[] = [ { key: RadioKeys.Neg, text: localization.InterpretText.View.negButton } ]; +const qaOptions: IChoiceGroupOption[] = [ + /* + * Creates the choices for the QA prediction radio button(local testing) + * TODO: move keys to CommonUtils, and text to localization.InterpretText.View.* + */ + { key: "starting", text: "STARTING POSITION" }, + { key: "ending", text: "ENDING POSITION" }, +]; + const componentStackTokens: IStackTokens = { childrenGap: "m", padding: "m" @@ -59,17 +73,23 @@ export class TextExplanationView extends React.PureComponent< * Initializes the text view with its state */ super(props); - const weightVector = this.props.selectedWeightVector; - const importances = this.computeImportancesForWeightVector( + + const importances = this.computeImportancesForAllTokens( this.props.dataSummary.localExplanations, - weightVector ); + + const selectedToken = 0; // default to the first token + const singleTokenImportances = this.getImportanceForSingleToken(selectedToken); const maxK = this.calculateMaxKImportances(importances); const topK = this.calculateTopKImportances(importances); this.state = { - importances, + importances: importances, + singleTokenImportances: singleTokenImportances, + selectedToken: selectedToken, + tokenIndexes: Array.from(this.props.dataSummary.text, (_, index) => index), maxK, radio: RadioKeys.All, + qaRadio: "starting", text: this.props.dataSummary.text, topK }; @@ -81,7 +101,17 @@ export class TextExplanationView extends React.PureComponent< this.props.dataSummary.localExplanations !== prevProps.dataSummary.localExplanations ) { - this.updateImportances(this.props.selectedWeightVector); + // TODO: this should be conditionally disabled (when we're explaining QA) + //this.updateImportances(this.props.selectedWeightVector); + + this.setState({ //update token dropdown + tokenIndexes: Array.from(this.props.dataSummary.text, (_, index) => index), + selectedToken: 0 + }, + () => { + this.updateTokenImportances(); + this.updateSingleTokenImportances(); + }) } } @@ -103,6 +133,8 @@ export class TextExplanationView extends React.PureComponent< + + {/* TODO: this item should be conditionally disabled (when we're explaining QA) {localization.InterpretText.View.label + @@ -113,9 +145,22 @@ export class TextExplanationView extends React.PureComponent< )} + */} + + {/* TODO: this should be conditionally enabled (when we're explaining QA)*/} + + + + + + + {/* TODO: this should be conditionally disabled (when we're explaining QA) + */} + + + + + {this.props.selectedWeightVector !== WeightVectors.AbsAvg && ( + + + + + + + + ); } + /* TODO: move the func back when conditional rendering is ready (commented for passing build) private onWeightVectorChange = (weightOption: WeightVectorOption): void => { this.updateImportances(weightOption); this.props.onWeightChange(weightOption); }; + */ + + private onSelectedTokenChange = (newIndex: number): void => { + + this.setState( + { selectedToken: newIndex }, + () => { + this.updateSingleTokenImportances(); + }); + }; + /* TODO: move the func back when conditional rendering is ready (commented for passing build) private updateImportances(weightOption: WeightVectorOption): void { const importances = this.computeImportancesForWeightVector( this.props.dataSummary.localExplanations, weightOption ); + const topK = this.calculateTopKImportances(importances); const maxK = this.calculateMaxKImportances(importances); this.setState({ @@ -188,6 +280,28 @@ export class TextExplanationView extends React.PureComponent< topK }); } + */ + + // for QA + private updateTokenImportances(): void { + + const importances = this.computeImportancesForAllTokens( + this.props.dataSummary.localExplanations, + ); + const topK = this.calculateTopKImportances(importances); + const maxK = this.calculateMaxKImportances(importances); + this.setState({ + importances, + maxK, + topK, + text: this.props.dataSummary.text + }); + } + + private updateSingleTokenImportances(): void { + const singleTokenImportances = this.getImportanceForSingleToken(this.state.selectedToken); + this.setState({singleTokenImportances: singleTokenImportances}); + } private calculateTopKImportances(importances: number[]): number { return Math.min( @@ -203,6 +317,7 @@ export class TextExplanationView extends React.PureComponent< ); } + /* TODO: move the func back when conditional rendering is ready (commented for passing build) private computeImportancesForWeightVector( importances: number[][], weightVector: WeightVectorOption @@ -221,6 +336,28 @@ export class TextExplanationView extends React.PureComponent< (perClassImportances) => perClassImportances[weightVector as number] ); } + */ + + private computeImportancesForAllTokens( + importances: number[][] + ): number[] { + /* + * sum the tokens importance + * TODO: add base values + */ + const sumImportances = importances.map((row) => + row.reduce((a, b): number => { + return (a + b); + }, 0) + ); + return sumImportances; + } + + private getImportanceForSingleToken( + index: number + ): number[] { + return this.props.dataSummary.localExplanations.map(row => row[index]); + } private setTopK = (newNumber: number): void => { /* @@ -240,4 +377,18 @@ export class TextExplanationView extends React.PureComponent< this.setState({ radio: item.key }); } }; + + // TODO: add logic for switching explanation data + private switchQAprediction = ( + _event?: React.FormEvent, + item?: IChoiceGroupOption + ): void => { + /* + * switch to the target predictions(starting or ending) + */ + if (item?.key !== undefined) { + this.setState({ qaRadio: item.key }); + } + }; + } diff --git a/libs/interpret/src/index.ts b/libs/interpret/src/index.ts index 0ee6695c72..f30eda7c5b 100644 --- a/libs/interpret/src/index.ts +++ b/libs/interpret/src/index.ts @@ -6,6 +6,7 @@ export * from "./lib/MLIDashboard/NewExplanationDashboard"; export * from "./lib/MLIDashboard/Interfaces/IExplanationDashboardProps"; export * from "./lib/MLIDashboard/Interfaces/IStringsParam"; export * from "./lib/MLIDashboard/Controls/ClassImportanceWeights/ClassImportanceWeights"; +export * from "./lib/MLIDashboard/Controls/TokenImportance/TokenImportance"; export * from "./lib/MLIDashboard/Controls/GlobalExplanationTab/GlobalExplanationTab"; export * from "./lib/MLIDashboard/Controls/GlobalExplanationTab/IGlobalSeries"; export * from "./lib/MLIDashboard/Controls/ModelPerformanceTab/ModelPerformanceTab"; diff --git a/libs/interpret/src/lib/MLIDashboard/Controls/TokenImportance/TokenImportance.styles.ts b/libs/interpret/src/lib/MLIDashboard/Controls/TokenImportance/TokenImportance.styles.ts new file mode 100644 index 0000000000..f10de55dcc --- /dev/null +++ b/libs/interpret/src/lib/MLIDashboard/Controls/TokenImportance/TokenImportance.styles.ts @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { IStyle, mergeStyleSets, IProcessedStyleSet } from "@fluentui/react"; + +export interface ILabelWithCalloutStyles { + tokenLabel: IStyle; + tokenLabelText: IStyle; +} + +export const tokenImportanceStyles: () => IProcessedStyleSet = + () => { + return mergeStyleSets({ + tokenLabel: { + display: "inline-flex", + paddingTop: "10px" + }, + tokenLabelText: { + fontWeight: "600", + paddingTop: "5px" + } + }); + }; diff --git a/libs/interpret/src/lib/MLIDashboard/Controls/TokenImportance/TokenImportance.tsx b/libs/interpret/src/lib/MLIDashboard/Controls/TokenImportance/TokenImportance.tsx new file mode 100644 index 0000000000..f8e6683cac --- /dev/null +++ b/libs/interpret/src/lib/MLIDashboard/Controls/TokenImportance/TokenImportance.tsx @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { + Dropdown, + IDropdownOption, + Text + } from "@fluentui/react"; + import { TokenOption } from "@responsible-ai/core-ui"; + import React from "react"; + + import { tokenImportanceStyles } from "./TokenImportance.styles"; + + export interface ITokenImportanceProps { + onTokenChange: (option: TokenOption) => void; + selectedToken: TokenOption; + tokenOptions: TokenOption[]; + tokenLabels: any; + disabled?: boolean; + } + interface ITokenImportanceState { + tokenOptionsMap: { key: TokenOption; text: any }[]; + } + + export class TokenImportance extends React.Component< + ITokenImportanceProps, + ITokenImportanceState + > { + public constructor(props: ITokenImportanceProps) { + super(props); + + this.state = { + tokenOptionsMap: this.props.tokenOptions?.map((option) => { + return { + key: option, + text: this.props.tokenLabels[option] + }; + }) + }; + } + + public componentDidUpdate(prevProps: ITokenImportanceProps): void { + if (this.props.tokenOptions !== prevProps.tokenOptions || + this.props.tokenLabels !== prevProps.tokenLabels + ) { + this.setState({ + tokenOptionsMap: this.props.tokenOptions?.map((option) => { + return { + key: option, + text: this.props.tokenLabels[option] + }; + }) + }) + } + } + + public render(): React.ReactNode { + const tokenNames = tokenImportanceStyles(); + return ( +
+
+ + {"Selected Token"} + + +
+ {this.state.tokenOptionsMap && ( + + )} +
+ ); + } + + private setTokenOption = ( + _event: React.FormEvent, + item?: IDropdownOption + ): void => { + if (item?.key === undefined) { + return; + } + + const newIndex = item.key as TokenOption; + this.props.onTokenChange(newIndex); + }; + } + From 99dbafef9365d09ac8feb8adb06676040460cc52 Mon Sep 17 00:00:00 2001 From: Neutrinowo <86736305+JarvisG495@users.noreply.github.com> Date: Thu, 1 Jun 2023 16:49:19 -0400 Subject: [PATCH 2/3] make QA UI compatible with text classification temporally using hardcode isQA flag as the prop for identifying qa from an upper level is TBD --- .../TextExplanationDashboard/CommonUtils.ts | 5 ++ .../TextExplanationView.tsx | 88 ++++++++++--------- 2 files changed, 50 insertions(+), 43 deletions(-) diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/CommonUtils.ts b/libs/interpret-text/src/lib/TextExplanationDashboard/CommonUtils.ts index f5d705c2d7..03ca37e2de 100644 --- a/libs/interpret-text/src/lib/TextExplanationDashboard/CommonUtils.ts +++ b/libs/interpret-text/src/lib/TextExplanationDashboard/CommonUtils.ts @@ -10,6 +10,11 @@ export enum RadioKeys { Neg = "neg" } +export enum QAExplanationType { + Start = "start", + End = "end" +} + export class Utils { public static argsort(toSort: number[]): number[] { /* diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.tsx b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.tsx index 74a790b45f..8a693ac7b0 100644 --- a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.tsx +++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.tsx @@ -10,13 +10,12 @@ import { Stack, Text } from "@fluentui/react"; -//import { WeightVectorOption, WeightVectors } from "@responsible-ai/core-ui"; -import { WeightVectors } from "@responsible-ai/core-ui"; -import { TokenImportance } from "@responsible-ai/interpret"; +import { WeightVectorOption, WeightVectors } from "@responsible-ai/core-ui"; +import { ClassImportanceWeights, TokenImportance } from "@responsible-ai/interpret"; import { localization } from "@responsible-ai/localization"; import React from "react"; -import { RadioKeys, Utils } from "../../CommonUtils"; +import { RadioKeys, QAExplanationType, Utils } from "../../CommonUtils"; import { ITextExplanationViewProps } from "../../Interfaces/IExplanationViewProps"; import { BarChart } from "../BarChart/BarChart"; import { TextFeatureLegend } from "../TextFeatureLegend/TextFeatureLegend"; @@ -37,6 +36,7 @@ export interface ITextExplanationViewState { selectedToken: number; tokenIndexes: number[]; text: string[]; + isQA: boolean // temporal flag for identifying QA } const options: IChoiceGroupOption[] = [ @@ -51,10 +51,10 @@ const options: IChoiceGroupOption[] = [ const qaOptions: IChoiceGroupOption[] = [ /* * Creates the choices for the QA prediction radio button(local testing) - * TODO: move keys to CommonUtils, and text to localization.InterpretText.View.* + * TODO: move text under localization.InterpretText.View */ - { key: "starting", text: "STARTING POSITION" }, - { key: "ending", text: "ENDING POSITION" }, + { key: QAExplanationType.Start, text: "STARTING POSITION" }, + { key: QAExplanationType.End, text: "ENDING POSITION" }, ]; const componentStackTokens: IStackTokens = { @@ -74,6 +74,8 @@ export class TextExplanationView extends React.PureComponent< */ super(props); + const isQA = false; // FIXME: temporally hardcode the flag, should use prop instead + const importances = this.computeImportancesForAllTokens( this.props.dataSummary.localExplanations, ); @@ -91,7 +93,8 @@ export class TextExplanationView extends React.PureComponent< radio: RadioKeys.All, qaRadio: "starting", text: this.props.dataSummary.text, - topK + topK, + isQA }; } @@ -101,17 +104,18 @@ export class TextExplanationView extends React.PureComponent< this.props.dataSummary.localExplanations !== prevProps.dataSummary.localExplanations ) { - // TODO: this should be conditionally disabled (when we're explaining QA) - //this.updateImportances(this.props.selectedWeightVector); - - this.setState({ //update token dropdown - tokenIndexes: Array.from(this.props.dataSummary.text, (_, index) => index), - selectedToken: 0 - }, - () => { - this.updateTokenImportances(); - this.updateSingleTokenImportances(); - }) + if (this.state.isQA) { + this.setState({ //update token dropdown + tokenIndexes: Array.from(this.props.dataSummary.text, (_, index) => index), + selectedToken: 0 + }, + () => { + this.updateTokenImportances(); + this.updateSingleTokenImportances(); + }) + } else { + this.updateImportances(this.props.selectedWeightVector); + } } } @@ -134,7 +138,7 @@ export class TextExplanationView extends React.PureComponent< - {/* TODO: this item should be conditionally disabled (when we're explaining QA) + { !this.state.isQA && ( // classfication {localization.InterpretText.View.label + @@ -145,9 +149,9 @@ export class TextExplanationView extends React.PureComponent< )} - */} + )} - {/* TODO: this should be conditionally enabled (when we're explaining QA)*/} + { this.state.isQA && ( // select starting/ending for QA + )} @@ -172,25 +177,26 @@ export class TextExplanationView extends React.PureComponent< /> - {/* TODO: this should be conditionally disabled (when we're explaining QA) - - - - */} - - + { this.state.isQA? + ( + ) + : + ( + + ) + } {this.props.selectedWeightVector !== WeightVectors.AbsAvg && ( @@ -220,9 +226,8 @@ export class TextExplanationView extends React.PureComponent< radio={this.state.radio} /> - - + { this.state.isQA && ( + )} + @@ -248,12 +255,10 @@ export class TextExplanationView extends React.PureComponent< ); } - /* TODO: move the func back when conditional rendering is ready (commented for passing build) private onWeightVectorChange = (weightOption: WeightVectorOption): void => { this.updateImportances(weightOption); this.props.onWeightChange(weightOption); }; - */ private onSelectedTokenChange = (newIndex: number): void => { @@ -264,7 +269,6 @@ export class TextExplanationView extends React.PureComponent< }); }; - /* TODO: move the func back when conditional rendering is ready (commented for passing build) private updateImportances(weightOption: WeightVectorOption): void { const importances = this.computeImportancesForWeightVector( this.props.dataSummary.localExplanations, @@ -280,7 +284,7 @@ export class TextExplanationView extends React.PureComponent< topK }); } - */ + // for QA private updateTokenImportances(): void { @@ -317,7 +321,6 @@ export class TextExplanationView extends React.PureComponent< ); } - /* TODO: move the func back when conditional rendering is ready (commented for passing build) private computeImportancesForWeightVector( importances: number[][], weightVector: WeightVectorOption @@ -336,7 +339,6 @@ export class TextExplanationView extends React.PureComponent< (perClassImportances) => perClassImportances[weightVector as number] ); } - */ private computeImportancesForAllTokens( importances: number[][] @@ -378,13 +380,13 @@ export class TextExplanationView extends React.PureComponent< } }; - // TODO: add logic for switching explanation data private switchQAprediction = ( _event?: React.FormEvent, item?: IChoiceGroupOption ): void => { /* * switch to the target predictions(starting or ending) + * TODO: add logic for switching explanation data */ if (item?.key !== undefined) { this.setState({ qaRadio: item.key }); From ee601721ca05c6d8b88bb894006fa367cf7691a5 Mon Sep 17 00:00:00 2001 From: Neutrinowo <86736305+JarvisG495@users.noreply.github.com> Date: Mon, 5 Jun 2023 14:12:15 -0400 Subject: [PATCH 3/3] fix left textbox data fix left textbox importance data and add description for qa interpret --- .../TextExplanationView.tsx | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.tsx b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.tsx index 8a693ac7b0..70023e1a6e 100644 --- a/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.tsx +++ b/libs/interpret-text/src/lib/TextExplanationDashboard/Control/TextExplanationView/TextExplanationView.tsx @@ -121,10 +121,20 @@ export class TextExplanationView extends React.PureComponent< public render(): React.ReactNode { const classNames = textExplanationDashboardStyles(); + const qaDescription = 'The left text box and the bar chart display the predictions of the model.' + + 'The right textbox shows the feature importance associated with a selected token. Positive feature ' + + 'importances represent the extent that the words were important towards marking the selected token' + + 'as the starting/ending position of the answer.' + return ( - {localization.InterpretText.View.legendText} + { + this.state.isQA? + {qaDescription} : + {localization.InterpretText.View.legendText} + } + @@ -345,13 +355,13 @@ export class TextExplanationView extends React.PureComponent< ): number[] { /* * sum the tokens importance - * TODO: add base values + * TODO: add base values? */ - const sumImportances = importances.map((row) => - row.reduce((a, b): number => { - return (a + b); - }, 0) + + const sumImportances = importances[0].map((_, index) => + importances.reduce((sum, row) => sum + row[index], 0) ); + return sumImportances; }