diff --git a/interface/lang2sql.py b/interface/lang2sql.py index dcff652..95931b6 100644 --- a/interface/lang2sql.py +++ b/interface/lang2sql.py @@ -78,14 +78,12 @@ def execute_query( def display_result( *, res: dict, - database: BaseConnector, ) -> None: """ Lang2SQL 실행 결과를 Streamlit 화면에 출력합니다. Args: res (dict): Lang2SQL 실행 결과 딕셔너리. - database (ConnectDB): SQL 쿼리 실행을 위한 데이터베이스 연결 객체. 출력 항목: - 총 토큰 사용량 @@ -240,8 +238,8 @@ def _as_float(value): if not has_query: st.info("QUERY_MAKER 없이 실행되었습니다. 검색된 테이블 정보만 표시합니다.") - if show_table_section: - st.markdown("---") + if show_table_section or show_chart_section: + database = get_db_connector() try: sql_raw = ( res["generated_query"].content @@ -251,23 +249,24 @@ def _as_float(value): if isinstance(sql_raw, str): sql = LLMResponseParser.extract_sql(sql_raw) df = database.run_sql(sql) - st.dataframe(df.head(10) if len(df) > 10 else df) else: st.error("SQL 원본이 문자열이 아닙니다.") except Exception as e: + st.markdown("---") st.error(f"쿼리 실행 중 오류 발생: {e}") + df = None - if show_chart_section: - st.markdown("---") - try: - sql_raw = ( - res["generated_query"].content - if isinstance(res["generated_query"], AIMessage) - else str(res["generated_query"]) - ) - if isinstance(sql_raw, str): - sql = LLMResponseParser.extract_sql(sql_raw) - df = database.run_sql(sql) + if df is not None and show_table_section: + st.markdown("---") + st.markdown("**쿼리 실행 결과:**") + try: + st.dataframe(df.head(10) if len(df) > 10 else df) + except Exception as e: + st.error(f"결과 테이블 생성 중 오류 발생: {e}") + + if df is not None and show_chart_section: + st.markdown("---") + try: st.markdown("**쿼리 결과 시각화:**") try: if len(res["messages"]) > 1: @@ -292,13 +291,9 @@ def _as_float(value): plotly_code=display_code.generate_plotly_code(), df=df ) st.plotly_chart(fig) - else: - st.error("SQL 원본이 문자열이 아닙니다.") - except Exception as e: - st.error(f"차트 생성 중 오류 발생: {e}") - + except Exception as e: + st.error(f"차트 생성 중 오류 발생: {e}") -db = get_db_connector() st.title(TITLE) @@ -401,4 +396,4 @@ def _as_float(value): top_n=user_top_n, device=device, ) - display_result(res=result, database=db) + display_result(res=result)