Skip to content

Commit f3906ce

Browse files
committed
fix: Update initialize_forum_qa logic
1 parent 048638b commit f3906ce

File tree

1 file changed

+61
-31
lines changed

1 file changed

+61
-31
lines changed

sources/gc-qa-rag-etl/etlapp/ved/initialize_forum_qa.py

Lines changed: 61 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,9 @@ class GroupObject:
4343

4444
@dataclass
class ForumObject:
    """Represents a forum post as a list of Q&A groups."""

    # Parsed groups; each group carries its own summary and Q&A pairs.
    groups: List[GroupObject]
5049

5150

5251
def transform_sparse(embedding: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
@@ -92,6 +91,49 @@ def extract_object(text: str) -> ForumObject:
9291
return ForumObject(summary="", possible_qa=[])
9392

9493

94+
def _parse_embedding(raw: Any) -> Optional[EmbeddingData]:
    """Build an EmbeddingData from a raw JSON mapping.

    Returns None when the value is missing or is not a JSON object
    (for example an explicit ``null``), so callers can treat the Q&A pair
    as having no embedding instead of crashing.
    """
    if not isinstance(raw, dict):
        return None
    return EmbeddingData(
        embedding=raw.get("embedding", []),
        sparse_embedding=raw.get("sparse_embedding", []),
    )


def extract_object(text: str) -> ForumObject:
    """Extract and parse a forum object from JSON text.

    The expected layout is a top-level object with a ``"Groups"`` list; each
    group has a ``"Summary"`` string and a ``"PossibleQA"`` list whose items
    hold ``"Question"``, ``"Answer"`` and optional ``"QuestionEmbedding"`` /
    ``"AnswerEmbedding"`` objects.

    Args:
        text: JSON document to parse.

    Returns:
        A ForumObject with one GroupObject per well-formed group. An empty
        ForumObject is returned when the text is not valid JSON or the
        top-level value is not a JSON object.
    """
    # Keep the try body minimal: only json.loads can raise JSONDecodeError.
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        # NOTE(review): the message says "tutorial" although this module
        # parses forum data; kept unchanged to avoid altering log output.
        logger.error("Failed to parse JSON, returning empty tutorial object")
        return ForumObject(groups=[])

    # Valid JSON that is not an object (list, null, number, ...) has no groups.
    if not isinstance(data, dict):
        logger.error("Expected a JSON object at top level, returning empty forum object")
        return ForumObject(groups=[])

    groups: List[GroupObject] = []
    # `or []` also covers an explicit null, which .get's default would not.
    for group in data.get("Groups") or []:
        if not isinstance(group, dict):
            continue
        qa_objects = [
            QAObject(
                question=qa.get("Question", ""),
                answer=qa.get("Answer", ""),
                question_embedding=_parse_embedding(qa.get("QuestionEmbedding")),
                answer_embedding=_parse_embedding(qa.get("AnswerEmbedding")),
            )
            for qa in group.get("PossibleQA") or []
            if isinstance(qa, dict)
        ]
        groups.append(
            GroupObject(summary=group.get("Summary", ""), possible_qa=qa_objects)
        )
    return ForumObject(groups=groups)
135+
136+
95137
def create_point(qa: QAObject, metadata: Dict[str, Any]) -> Optional[PointStruct]:
96138
"""Create a point structure for vector storage from a Q&A pair."""
97139
if not qa.question_embedding or not qa.answer_embedding:
@@ -115,19 +157,19 @@ def create_point(qa: QAObject, metadata: Dict[str, Any]) -> Optional[PointStruct
115157

116158

117159
def process_forum_object(
118-
forum: ForumObject, file_index: str, question_index: int, metadata: Dict[str, Any]
160+
group: GroupObject, file_index: str, question_index: int, metadata: Dict[str, Any]
119161
) -> List[PointStruct]:
120-
"""Process a forum object and create points for vector storage."""
162+
"""Process a group object and create points for vector storage."""
121163
points = []
122164

123-
for qa in forum.possible_qa:
165+
for qa in group.possible_qa:
124166
point = create_point(
125167
qa=qa,
126168
metadata={
127169
**metadata,
128170
"file_index": file_index,
129171
"question_index": question_index,
130-
"summary": forum.summary,
172+
"summary": group.summary,
131173
},
132174
)
133175
if point:
@@ -151,7 +193,9 @@ def start_initialize_forum_qa(context: EtlRagContext) -> None:
151193

152194
thread_list = json.loads(read_text_from_file(forum_file_path))['threads']
153195
thread_dict = {
154-
f"{thread['tid']}_{thread['postDate']}": thread for thread in thread_list
196+
f"{thread['tid']}_{thread['postDate']}": thread
197+
for thread in thread_list
198+
if thread['postDate'] >= 1609459200
155199
}
156200

157201
for file_index in thread_dict:
@@ -163,13 +207,6 @@ def start_initialize_forum_qa(context: EtlRagContext) -> None:
163207
logger.warning(f"File does not exist: {file_path}, skipping")
164208
continue
165209

166-
# Check if post date is within specified range (after 2021.01.01 12:00:00 AM)
167-
if thread_dict[file_index]["postDate"] < 1609459200:
168-
logger.info(
169-
f"Post date is outside specified range for file: {file_path}, skipping"
170-
)
171-
continue
172-
173210
content = read_text_from_file(file_path)
174211
forum = extract_object(content)
175212

@@ -181,20 +218,13 @@ def start_initialize_forum_qa(context: EtlRagContext) -> None:
181218
"date": thread_dict[file_index]["postDate"],
182219
}
183220

184-
points = process_forum_object(
185-
forum=forum, file_index=file_index, question_index=0, metadata=metadata
186-
)
221+
for group_index, group in enumerate(forum.groups):
222+
points = process_forum_object(
223+
group=group,
224+
file_index=file_index,
225+
question_index=group_index,
226+
metadata=metadata,
227+
)
187228

188-
if points:
189-
# Retry up to 3 times
190-
max_retries = 3
191-
for attempt in range(max_retries):
192-
try:
193-
client.insert_to_collection(collection_name, points)
194-
break
195-
except Exception as e:
196-
if attempt == max_retries - 1:
197-
logger.error(
198-
f"Failed to insert points after {max_retries} attempts: {str(e)}"
199-
)
200-
time.sleep(1)
229+
if points:
230+
client.insert_to_collection(collection_name, points)

0 commit comments

Comments
 (0)