Skip to content

Commit 6978bd5

Browse files
committed
rename to set_prior_value
1 parent 8ef30a5 commit 6978bd5

File tree

6 files changed

+27
-25
lines changed

6 files changed

+27
-25
lines changed

visual-tree-search-app/components/ControlPanelMCTS.tsx

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ interface SearchParams {
1010
goal: string;
1111
maxDepth: number;
1212
iterations: number;
13-
prior_value: boolean;
13+
set_prior_value: boolean;
1414
}
1515

1616
interface ControlPanelProps {
@@ -139,12 +139,12 @@ const ControlPanelMCTS: React.FC<ControlPanelProps> = ({
139139
<div className="mt-4">
140140
<div className="flex items-center space-x-2">
141141
<Checkbox
142-
id="prior_value"
143-
checked={searchParams.prior_value}
144-
onCheckedChange={(checked) => handleParamChange('prior_value', checked === true)}
142+
id="set_prior_value"
143+
checked={searchParams.set_prior_value}
144+
onCheckedChange={(checked) => handleParamChange('set_prior_value', checked === true)}
145145
/>
146146
<Label
147-
htmlFor="prior_value"
147+
htmlFor="set_prior_value"
148148
className="text-slate-700 dark:text-slate-300 font-medium cursor-pointer"
149149
>
150150
Use Prior Value

visual-tree-search-app/pages/MCTSAgent.tsx

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ interface SearchParams {
1010
goal: string;
1111
maxDepth: number;
1212
iterations: number;
13-
prior_value: boolean;
13+
set_prior_value: boolean;
1414
}
1515

1616
interface Message {
@@ -32,7 +32,7 @@ const MCTSAgent = () => {
3232
startingUrl: 'http://xwebarena.pathonai.org:7770/',
3333
goal: 'search running shoes, click on the first result',
3434
maxDepth: 3,
35-
prior_value: false,
35+
set_prior_value: false,
3636
iterations: 1
3737
});
3838

@@ -97,7 +97,7 @@ const MCTSAgent = () => {
9797
goal: searchParams.goal,
9898
search_algorithm: "mcts",
9999
iterations: searchParams.iterations,
100-
prior_value: searchParams.prior_value,
100+
set_prior_value: searchParams.set_prior_value,
101101
max_depth: searchParams.maxDepth
102102
};
103103
console.log(request);
@@ -145,7 +145,7 @@ const MCTSAgent = () => {
145145
goal: searchParams.goal,
146146
search_algorithm: "mcts",
147147
iterations: searchParams.iterations,
148-
prior_value: searchParams.prior_value,
148+
set_prior_value: searchParams.set_prior_value,
149149
max_depth: searchParams.maxDepth
150150
};
151151

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/base_agent.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -486,10 +486,11 @@ async def node_evaluation(self, node: LATSNode) -> None:
486486
## TODO: check the logic of updating value/ reward, is the input value?
487487
def backpropagate(self, node: LATSNode, value: float) -> None:
488488
while node:
489-
node.visits += 1
490-
# Calculate running average: newAvg = oldAvg + (value - oldAvg) / newCount
491-
node.value += (value - node.value) / node.visits
492-
node = node.parent
489+
if node.depth != 0:
490+
node.visits += 1
491+
# Calculate running average: newAvg = oldAvg + (value - oldAvg) / newCount
492+
node.value += (value - node.value) / node.visits
493+
node = node.parent
493494

494495
# shared
495496
async def simulation(self, node: LATSNode, websocket=None) -> tuple[float, LATSNode]:

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/mcts_agent.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
291291

292292

293293
# optional: prior value
294-
if self.config.prior_value:
294+
if self.config.set_prior_value:
295295
await self.websocket_step_start(step=2, step_name="node_children_evaluation", websocket=websocket)
296296
await self.node_children_evaluation(selected_node)
297297
tree_data = self._get_tree_data()
@@ -348,14 +348,15 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
348348
print(f"{GREEN}Step 5: Backpropagation{RESET}")
349349
await self.websocket_step_start(step=5, step_name="backpropagation", websocket=websocket)
350350
for node in path:
351-
old_value = node.value
352-
node.visits += 1
353-
node.value += (score - node.value) / node.visits
354-
# consiste with lats backpropagation
355-
#node.value = (node.value * (node.visits - 1) + score) / node.visits
356-
print(f"Node {node.action}:")
357-
print(f" Visits: {node.visits}")
358-
print(f" Value: {old_value:.3f} -> {node.value:.3f}")
351+
if node != self.root_node:
352+
old_value = node.value
353+
node.visits += 1
354+
node.value += (score - node.value) / node.visits
355+
# consiste with lats backpropagation
356+
#node.value = (node.value * (node.visits - 1) + score) / node.visits
357+
print(f"Node {node.action}:")
358+
print(f" Visits: {node.visits}")
359+
print(f" Value: {old_value:.3f} -> {node.value:.3f}")
359360
# add websocket information, just use websocket here
360361
# if websocket:
361362
# await websocket.send_json({

visual-tree-search-backend/app/api/lwats/core_async/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class AgentConfig:
3030

3131
# for MCTS
3232
reflection_score: float = 0.75
33-
prior_value: bool = False
33+
set_prior_value: bool = False
3434

3535
# Features
3636
features: List[str] = field(default_factory=lambda: ['axtree'])

visual-tree-search-backend/app/api/routes/tree_search_websocket.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ async def handle_search_request(websocket: WebSocket, message: Dict[str, Any]):
7979
storage_state = message.get("storage_state", "app/api/shopping.json")
8080
iterations = message.get("iterations", 1) # Extract iterations parameter
8181
num_simulations=message.get("num_simulations", 1)
82-
prior_value = message.get("prior_value", False)
82+
set_prior_value = message.get("set_prior_value", False)
8383

8484
# Send status update
8585
await websocket.send_json({
@@ -97,7 +97,7 @@ async def handle_search_request(websocket: WebSocket, message: Dict[str, Any]):
9797
headless=False,
9898
iterations=iterations,
9999
num_simulations=num_simulations,
100-
prior_value=prior_value
100+
set_prior_value=set_prior_value
101101
)
102102
print(config)
103103

0 commit comments

Comments
 (0)