Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/labeler.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ backend:
- appserver/**/*
- statistical-enrichment/*
- statistical-enrichment/**/*
- llm/*
- llm/**/*
- cache-invalidator/*
- cache-invalidator/**/*

Expand Down
3 changes: 3 additions & 0 deletions appserver/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ class Base:
SE_HOST = os.environ.get('SE_HOST', 'statistical-enrichment')
SE_PORT = os.environ.get('SE_PORT', '5010')

LLM_HOST = os.environ.get('LLM_HOST', 'llm')
LLM_PORT = os.environ.get('LLM_PORT', '5000')

NLP_SERVICE_ENDPOINT = os.environ.get(
'NLP_SERVICE_ENDPOINT', 'https://nlp-api.lifelike.bio/v1/predict'
)
Expand Down
23 changes: 23 additions & 0 deletions appserver/neo4japp/blueprints/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from neo4japp.schemas.context import ContextRelationshipRequestSchema
from neo4japp.services.chat_gpt import ChatGPT
from neo4japp.services.llm import LLM
from neo4japp.utils.globals import get_current_username

bp = Blueprint('chat-gpt-api', __name__, url_prefix='/explain')
Expand Down Expand Up @@ -50,3 +51,25 @@ def relationship(params):
"result": choice.get('message').get('content').strip(),
"query_params": create_params,
}


@bp.route('/relationship/graph', methods=['POST'])
@use_args(ContextRelationshipRequestSchema)
def relationship_graph(params):
    """Answer a relationship question between entities via the LLM graph-QA service.

    Builds a natural-language question from the requested entities (plus an
    optional context clause) and forwards it to :meth:`LLM.graph_qa`.
    """
    entity_list = params.get('entities', [])
    relationship_context = params.get('context')
    options = params.get('options', {})

    # Assemble e.g. "What is the relationship between A, B, C?"
    question_parts = ['What is the relationship between ', ', '.join(entity_list)]
    if relationship_context:
        question_parts.append(f', {relationship_context}')
    question_parts.append('?')

    return LLM.graph_qa(
        model="gpt-3.5-turbo",
        query=''.join(question_parts),
        temperature=options.get('temperature', 0),
        max_tokens=2000,
        # Hashed so the raw username is not sent to the downstream service.
        user=str(hash(get_current_username())),
        timeout=60,
    )
87 changes: 87 additions & 0 deletions appserver/neo4japp/services/llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import json
from http import HTTPStatus
from typing import cast

import requests
from flask import current_app, Response, g

from neo4japp.exceptions import ServerException
from neo4japp.utils.globals import config
from neo4japp.utils.transaction_id import transaction_id


class LLM:
    """Thin HTTP proxy for the internal LLM microservice.

    Forwards requests to the host/port configured via ``LLM_HOST`` /
    ``LLM_PORT`` and converts transport or service failures into
    ``ServerException``s the API layer knows how to render.
    """

    # Hop-by-hop headers that must not be copied onto the proxied response.
    _EXCLUDED_HEADERS = frozenset(
        ('content-encoding', 'content-length', 'transfer-encoding', 'connection')
    )

    @classmethod
    def request(cls, path, **request_args) -> Response:
        """Forward an HTTP request to the LLM service.

        :param path: URL path on the LLM service (e.g. ``'/graph'``).
        :param request_args: keyword arguments for ``requests.request``
            (``method``, ``data``, ``headers``, ``timeout``, ...).
        :returns: a Flask ``Response`` mirroring the service response body,
            status and (non-hop-by-hop) headers.
        :raises ServerException: on connection failure or any 4xx/5xx status
            returned by the service.
        """
        host = config.get('LLM_HOST')
        port = config.get('LLM_PORT')
        url = f'http://{host}:{port}{path}'
        try:
            request_args['url'] = url
            resp = requests.request(**request_args)
            headers = [
                (name, value)
                for (name, value) in resp.raw.headers.items()
                if name.lower() not in cls._EXCLUDED_HEADERS
            ]
        except Exception as e:
            # Exclude headers/cookies from the error payload — they may
            # carry credentials.
            raise ServerException(
                'An unexpected error occurred while connecting to service.',
                fields={
                    arg: request_args[arg]
                    for arg in request_args
                    if arg not in ('headers', 'cookies')
                },
            ) from e

        # A 500 should contain a message from the service, so try to surface it.
        if resp.status_code == HTTPStatus.INTERNAL_SERVER_ERROR:
            try:
                decoded_error_message = json.loads(resp.content)['message']
            except Exception as e:
                # Log and proceed so the general error below can be raised.
                current_app.logger.error(
                    'Could not process 500 error response from forwarded request.',
                    exc_info=e,
                )
            else:
                raise ServerException(
                    'Service error',
                    decoded_error_message,
                    code=cast(HTTPStatus, resp.status_code),
                )

        # All remaining errors, including a 500 whose message failed to parse.
        if 400 <= resp.status_code < 600:
            raise ServerException(
                'Unable to process request',
                'An internal error of service occurred.',
                code=cast(HTTPStatus, resp.status_code),
            )

        return Response(resp.content, resp.status_code, headers)

    @classmethod
    def graph_qa(cls, query, user, **data):
        """Ask the LLM service to answer ``query`` over the knowledge graph.

        :param query: natural-language question to answer.
        :param user: opaque identifier of the requesting user.
        :param data: extra options. ``timeout`` (seconds, default 60) is
            forwarded to the HTTP call; other keys (``model``,
            ``temperature``, ``max_tokens``, ...) are accepted for interface
            compatibility but are not currently sent to the service.
        """
        # BUGFIX: previously every caller option (including timeout=60) was
        # silently swallowed by **data, leaving the proxied request with no
        # timeout at all — a slow LLM service could hang the worker forever.
        timeout = data.pop('timeout', 60)
        return cls.request(
            '/graph',
            data=json.dumps(
                {
                    'query': query,
                    'transaction_id': getattr(g, 'transaction_id'),
                    'user': user,
                    'graph': {
                        'database_type': 'arango',
                    },
                }
            ),
            method='POST',
            headers={'Content-Type': 'application/json'},
            timeout=timeout,
        )
93 changes: 73 additions & 20 deletions client/src/app/drawing-tool/components/prompt/prompt.component.html
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@
<ng-container
*ngIf="{
explain: explain$ | async,
explanation: explanation$ | async
explanation: explanation$ | async,
graph_explanation: graphExplanation$ | async,
graph: graph$ | async
} as state"
>
<button
Expand All @@ -80,30 +82,81 @@
{{ state.explain ? 'Regenerate' : 'Generate' }} explanation
</ng-template>
</button>
<br />
<div class="card mt-2" style="white-space: pre-line" *ngIf="state.explanation?.value">
<div class="card-body p-3 bg-light" style="white-space: pre-wrap">
<span
style="float: right"
class="ml-2 mb-2"
(click)="copyToClipboard(state.explanation.value.result)"
>
<i class="fas fa-copy"></i>
</span>
<span
style="float: right"
class="ml-2 mb-2"
(click)="openInfo(state.explanation.value.query_params)"
>
<i class="fas fa-info-circle"></i>
</span>
{{ state.explanation.value.result }}
<ng-container *ngIf="state.explain">
<div class="card mt-2" style="white-space: pre-line">
<div class="card-body p-3 bg-light" style="white-space: pre-wrap">
<div>
ChatGPT
<span
style="float: right"
class="ml-2 mb-2"
(click)="copyToClipboard(state.explanation.value.result)"
>
<i class="fas fa-copy"></i>
</span>
<span
style="float: right"
class="ml-2"
*ngIf="state.explanation?.value?.query_params"
(click)="openInfo(state.explanation?.value?.query_params)"
>
<i class="fas fa-info-circle"></i>
</span>
</div>
<ng-container *ngIf="state.explanation?.loading; else settled_chat">
<i class="fa fa-spinner fa-spin mr-1"></i>
</ng-container>
<ng-template #settled_chat>
{{ state.explanation.value.result }}
</ng-template>
</div>
</div>
</div>
<div class="card mt-2" style="white-space: pre-line">
<div class="card-body p-3 bg-light" style="white-space: pre-wrap">
<div>
ChatGPT over graph
<span
style="float: right"
class="ml-2 mb-2"
(click)="copyToClipboard(state.graph_explanation.value?.result?.result)"
>
<i class="fas fa-copy"></i>
</span>
<span
style="float: right"
class="ml-2 mb-2"
*ngIf="state.graph_explanation?.value?.result?.query"
(click)="openInfo({ query: state.graph_explanation?.value?.result?.query })"
>
<i class="fas fa-info-circle"></i>
</span>
</div>
<ng-container *ngIf="state.graph_explanation?.loading; else settled_graph">
<i class="fa fa-spinner fa-spin mr-1"></i>
</ng-container>
<ng-template #settled_graph>
{{ state.graph_explanation.value?.result?.result }}
</ng-template>
<app-vis-network
style="height: 200px; resize: vertical; overflow: hidden"
class="mt-1"
*ngIf="state.graph?.nodes?.length"
[data]="state.graph"
[options]="networkConfig"
></app-vis-network>
</div>
</div>
</ng-container>

<app-module-error
class="d-block mt-4"
*ngIf="state.explanation?.error as error"
[error]="error"
></app-module-error>

<app-module-error
class="d-block mt-4"
*ngIf="state.graph_explanation?.error as error"
[error]="error"
></app-module-error>
</ng-container>
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,16 @@ import {
} from 'app/playground/components/form/drawing-tool-prompt-form/drawing-tool-prompt-form.component';
import { OpenPlaygroundParams } from 'app/playground/components/open-playground/open-playground.component';
import { OpenFileProvider } from 'app/shared/providers/open-file/open-file.provider';
import { ExplainService } from 'app/shared/services/explain.service';
import { DropdownController } from 'app/shared/utils/dropdown.controller.factory';
import { ExplainService, GraphChatGPTResponse } from 'app/shared/services/explain.service';
import { openModal } from 'app/shared/utils/modals';
import { PlaygroundComponent } from 'app/playground/components/playground.component';
import { ChatgptResponseInfoModalComponent } from 'app/shared/components/chatgpt-response-info-modal/chatgpt-response-info-modal.component';
import { ChatGPTResponse } from 'app/enrichment/services/enrichment-visualisation.service';
import { addStatus, PipeStatus } from 'app/shared/pipes/add-status.pipe';
import { ClipboardService } from 'app/shared/services/clipboard.service';
import { AuthenticationService } from 'app/auth/services/authentication.service';
import { annotationTypesMap } from 'app/shared/annotation-styles';

import { MapStoreService, setContext } from '../../services/map-store.service';

Expand All @@ -46,6 +47,44 @@ export class DrawingToolPromptComponent implements OnDestroy, OnChanges {
readonly authService: AuthenticationService
) {}

networkConfig = {
interaction: {
hover: true,
multiselect: true,
selectConnectedEdges: false,
},
physics: {
enabled: true,
solver: 'barnesHut',
},
edges: {
font: {
size: 12,
},
length: 250,
widthConstraint: {
maximum: 90,
},
},
nodes: {
scaling: {
min: 25,
max: 50,
label: {
enabled: true,
min: 12,
max: 72,
maxVisible: 72,
drawThreshold: 5,
},
},
shape: 'box',
widthConstraint: {
maximum: 180,
},
},
};

private readonly destroy$: Subject<void> = new Subject();

@Input() entities!: Iterable<string>;
Expand Down Expand Up @@ -110,6 +149,38 @@ export class DrawingToolPromptComponent implements OnDestroy, OnChanges {
)
);

readonly graphExplanation$: Observable<PipeStatus<GraphChatGPTResponse>> = this.params$.pipe(
switchMap((params) =>
this.explain$.pipe(
map(() => params),
takeUntil(this.destroy$),
switchMap(({ entities, temperature, context }) =>
this.explainService
.relationshipGraph(entities, context, { temperature })
.pipe(addStatus())
),
startWith(undefined)
)
)
);

readonly graph$ = this.graphExplanation$.pipe(
map((explanation) => explanation?.value?.graph),
map((graph) =>
graph
? {
edges: graph.edges,
nodes: graph.nodes.map(({ eid, type, entityType, displayName }) => ({
id: eid,
label: displayName,
color: annotationTypesMap.get((type ?? entityType)?.toLowerCase())?.color,
})),
}
: undefined
),
shareReplay({ bufferSize: 1, refCount: true })
);

readonly playgroundParams$: Observable<OpenPlaygroundParams<DrawingToolPromptFormParams>> =
combineLatest([this.params$, this.contexts$]).pipe(
map(([{ temperature, entities, context }, contexts]) => ({
Expand Down
4 changes: 3 additions & 1 deletion client/src/app/shared/directives/shareddirectives.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import {
InteractiveInterfaceHasPlaceholderDirective,
ShowPlaceholderDirective,
} from './placeholder.directive';
import { VisNetworkDirective } from './vis-network.directive';

const directives = [
HasPlaceholderDirective,
Expand All @@ -43,6 +44,7 @@ const directives = [
DebounceInputDirective,
LinkWithoutHrefDirective,
LinkWithHrefDirective,
VisNetworkDirective,
FormInputDirective,
AutoFocusDirective,
InnerXMLDirective,
Expand All @@ -67,7 +69,7 @@ const directives = [

@NgModule({
imports: [],
declarations: [...directives],
declarations: [...directives, VisNetworkDirective],
exports: [...directives],
})
export class SharedDirectivesModule {}
Loading
Loading