Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions python/flask_wrapper_analysis_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from flask import Flask, request, jsonify
from subprocess import Popen, PIPE, STDOUT
import threading
import json
import os

app = Flask(__name__)

# Store the process globally
cli_process = None
lock = threading.Lock()

def start_cli():
global cli_process
command = os.environ.get('KATAGO_COMMAND')

if not command:
raise ValueError('Environment variable KATAGO_COMMAND not set')

with lock:
if cli_process is None:
cli_process = Popen(command, shell=True, stdin=PIPE, stdout=PIPE, stderr=STDOUT, text=True)
wait_for_ready_state(cli_process)

def wait_for_ready_state(process):
"""Reads lines from the CLI program until it outputs the ready state message."""
try:
while True:
line = process.stdout.readline()
if not line:
break
print(f"Initial CLI Output: {line.strip()}") # Print the initial output for debugging

if "Started, ready to begin handling requests" in line:
print("CLI program is ready.")
break # Stop reading once the ready state is detected
except Exception as e:
print(f"Error reading initial output: {str(e)}")

@app.route('/run', methods=['POST'])
def run_cli():
global cli_process
input_data = request.json

if not input_data:
return jsonify({'error': 'No input provided'}), 400

analyze_turns = input_data.get("analyzeTurns", [])
num_responses_expected = len(analyze_turns)

with lock:
if cli_process is None:
return jsonify({'error': 'CLI program is not running'}), 400

try:
# Send input to the process
input_json = json.dumps(input_data)
cli_process.stdin.write(input_json + '\n')
cli_process.stdin.flush()

# Read the output from the process until all responses are received
responses = []
while len(responses) < num_responses_expected:
output_line = cli_process.stdout.readline()
if not output_line:
break
output_line = output_line.strip()

try:
# Attempt to parse the output as JSON
response_data = json.loads(output_line)
responses.append(response_data)
except json.JSONDecodeError:
# Ignore lines that are not valid JSON
print(f"Non-JSON line ignored: {output_line}")

if len(responses) != num_responses_expected:
return jsonify({'error': 'Did not receive the expected number of responses from CLI program'}), 500

return jsonify(responses)
except Exception as e:
return jsonify({'error': f'An error occurred: {str(e)}'}), 500

if __name__ == '__main__':
start_cli() # Start the CLI program when the server starts
app.run(host='0.0.0.0', port=5000)
119 changes: 119 additions & 0 deletions python/flask_wrapper_human_sl_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from flask import Flask, request, jsonify
from subprocess import Popen, PIPE
import time
import threading
import json
from flask_cors import CORS, cross_origin
import os

app = Flask(__name__)
cors = CORS(app)
app.config['CORS_HEADERS'] = 'Content-Type'

app = Flask(__name__)

# Store the process globally
cli_process = None
lock = threading.Lock()

def start_cli():
global cli_process
global monitor_output_thread
command = "python ./humanslnet_server.py -checkpoint ./b18c384nbt-humanv0.ckpt -device cpu -webserver True"

with lock:
if cli_process is None:
cli_process = Popen(command, shell=True, stdin=PIPE, stdout=PIPE, stderr=PIPE, text=True)

wait_for_ready_state(cli_process)
stderr_thread = threading.Thread(target=monitor_stderr, args=(cli_process,))

stderr_thread.daemon = True

stderr_thread.start()


def wait_for_ready_state(process):
"""Reads lines from the CLI program until it outputs the ready state message."""
try:
while True:
line = process.stdout.readline()
if not line:
break
print(f"Initial CLI Output: {line.strip()}") # Print the initial output for debugging

if "Ready to receive input" in line:
print("CLI program is ready.")
break # Stop reading once the ready state is detected
except Exception as e:
print(f"Error reading initial output: {str(e)}")

def monitor_stderr(process):
global cli_process
while True:
error_output = process.stderr.readline()
if error_output == '' and process.poll() is not None:
break
if error_output:
print(f"ERROR: {error_output.strip()}")
if "ERROR" in error_output or "Exception" in error_output or "Error" in error_output:
print("Error detected, quitting the process...")
os._exit(1)
time.sleep(0.1)


def restart_process():
global cli_process
with lock:
if cli_process:
cli_process.kill()
cli_process = None
start_cli()

@app.route('/', methods=['POST'])
@cross_origin()
def run_cli():
global cli_process
input_data = request.json

if not input_data:
return jsonify({'error': 'No input provided'}), 400

with lock:
if cli_process is None:
start_cli()
return jsonify({'error': 'CLI program was not running, restarted'}), 400

try:
# Send input to the process
input_json = json.dumps(input_data)
cli_process.stdin.write(input_json + '\n')
cli_process.stdin.flush()

# Read the output from the process until all responses are received
num_responses_expected = 1
responses = []
while len(responses) < num_responses_expected:
output_line = cli_process.stdout.readline()
if not output_line:
break
output_line = output_line.strip()

try:
# Attempt to parse the output as JSON
response_data = json.loads(output_line)
responses.append(response_data)
except json.JSONDecodeError:
# Ignore lines that are not valid JSON
print(f"Non-JSON line ignored: {output_line}")

if len(responses) != num_responses_expected:
return jsonify({'error': 'Did not receive the expected number of responses from CLI program'}), 500

return jsonify(responses)
except Exception as e:
return jsonify({'error': f'An error occurred: {str(e)}'}), 500

if __name__ == '__main__':
start_cli() # Start the CLI program when the server starts
app.run(host='0.0.0.0', port=5000)
7 changes: 7 additions & 0 deletions python/forever_run_flask_wrapper_sl_server.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/bin/bash

while true; do
python flask_wrapper_human_sl_server.py
echo "The server crashed with exit code $?. Respawning.." >&2
sleep 1
done
84 changes: 80 additions & 4 deletions python/gamestate.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,6 @@ def get_model_outputs(self, model: "Model", sgfmeta: Optional[SGFMetadata] = Non
features = Features(model.config, model.pos_len)

bin_input_data, global_input_data = self.get_input_features(features)
# Currently we don't actually do any symmetries
# symmetry = 0
# model_outputs = model(apply_symmetry(batch["binaryInputNCHW"],symmetry),batch["globalInputNC"])

input_meta = None
if sgfmeta is not None:
Expand Down Expand Up @@ -278,7 +275,38 @@ def get_model_outputs(self, model: "Model", sgfmeta: Optional[SGFMetadata] = Non
for name in list(extra_outputs.returned.keys()):
if name.endswith(".attention"):
extra_outputs.returned[name] = torch.transpose(extra_outputs.returned[name],1,2)

print({
"policy0": policy0,
"policy1": policy1,
"moves_and_probs0": moves_and_probs0,
"moves_and_probs1": moves_and_probs1,
"value": value,
"td_value": td_value,
"td_value2": td_value2,
"td_value3": td_value3,
"scoremean": scoremean,
"td_score": td_score,
"scorestdev": scorestdev,
"lead": lead,
"vtime": vtime,
"estv": estv,
"ests": ests,
"ownership": ownership,
"ownership_by_loc": ownership_by_loc,
"scoring": scoring,
"scoring_by_loc": scoring_by_loc,
"futurepos": futurepos,
"futurepos0_by_loc": futurepos0_by_loc,
"futurepos1_by_loc": futurepos1_by_loc,
"seki": seki,
"seki_by_loc": seki_by_loc,
"seki2": seki2,
"seki_by_loc2": seki_by_loc2,
"scorebelief": scorebelief,
"genmove_result": genmove_result,
**{ name:activation[0].numpy() for name, activation in extra_outputs.returned.items() },
"available_extra_outputs": available_extra_outputs,
})
return {
"policy0": policy0,
"policy1": policy1,
Expand Down Expand Up @@ -311,4 +339,52 @@ def get_model_outputs(self, model: "Model", sgfmeta: Optional[SGFMetadata] = Non
**{ name:activation[0].numpy() for name, activation in extra_outputs.returned.items() },
"available_extra_outputs": available_extra_outputs,
}

def run_monte_carlo_tree_search(self, model, sgfmeta, max_visits=20):
self.children = []
model_outputs = self.get_model_outputs(model, sgfmeta)
self.move_probabilities = model_outputs["moves_and_probs0"]
for _ in range(max_visits):
self.do_visit(model, sgfmeta)
self.visits = 0

def choose_move(self):
cumulative_probabilities = np.cumsum([self.moves_and_probs0[:][1]])
random_number = np.random.rand()
for i in range(len(cumulative_probabilities)):
if cumulative_probabilities[i] >= random_number:
return self.moves_and_probs0[i]

def get_best_move(self):
best_move_index = np.argmax([self.moves_and_probs0[:][1]])
return self.moves_and_probs0[best_move_index]

def do_visit(self, model, sgfmeta):
move = move = self.choose_move()

if not hasattr(self.children, move[0]):
child = self.clone()
child.play(move)
self.children[move[0]] = child
else:
self.children[move[0]].do_visit(model, sgfmeta)

if self.parent is not None:
self.parent.backpropagate(self)

def backpropagate(self, child_node):
if(child_node.best_move_win_chance > self.best_move_win_chance):
self.best_move = child_node.move
self.best_move_win_chance = child_node.best_move_win_chance
self.visits += child_node.visits
self.forward_propagate_visits()

def clone(self):
game_state = GameState(self.board_size, self.rules)
game_state.redo_stack = self.redo_stack[:]
game_state.moves = self.moves[:]
game_state.boards = self.boards[:]
game_state.board = self.board
game_state.parent = self
return game_state

24 changes: 19 additions & 5 deletions python/humanslnet_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,27 @@ def numpy_array_encoder(obj):
return float(obj)
raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')

def write(output):
sys.stdout.write(json.dumps(output,default=numpy_array_encoder) + "\n")
sys.stdout.flush()

def main():
parser = argparse.ArgumentParser()
parser.add_argument('-checkpoint', help='Checkpoint to test', required=True)
parser.add_argument('-use-swa', help='Use SWA model', action="store_true", required=False)
parser.add_argument('-device', help='Device to use, such as cpu or cuda:0', required=True)
parser.add_argument('-webserver', help='set if is used for the flask wrapper', required=False)
args = parser.parse_args()

model, swa_model, _ = load_model(args.checkpoint, use_swa=args.use_swa, device=args.device, pos_len=19, verbose=False)
if swa_model is not None:
model = swa_model
game_state = None

def write(output):
sys.stdout.write(json.dumps(output,default=numpy_array_encoder) + "\n")
sys.stdout.flush()
if args.webserver:
write("Ready to receive input")


# DEBUGGING
# game_state = GameState(board_size=19, rules=GameState.RULES_JAPANESE)
# sgfmeta = SGFMetadata()
Expand Down Expand Up @@ -65,17 +70,26 @@ def write(output):

elif data["command"] == "get_model_outputs":
sgfmeta = SGFMetadata.of_dict(data["sgfmeta"])
# features = Features(model.config, model.pos_len)
# foo = game_state.get_input_features(features)
outputs = game_state.get_model_outputs(model, sgfmeta=sgfmeta)
filtered_outputs = {}
for key in outputs:
if key in ["moves_and_probs0", "value", "lead", "scorestdev"]:
filtered_outputs[key] = outputs[key]
write(dict(outputs=filtered_outputs))

elif data["command"] == "get_best_move":
sgfmeta = SGFMetadata.of_dict(data["sgfmeta"])

# Run Monte Carlo Tree Search with a specified number of visits
visits = data.get("visits", 100) # Default to 100 visits if not specified
game_state.run_monte_carlo_tree_search(model, sgfmeta, visits)

# Write the refined outputs back to the output
write({"best_move": game_state.best_move})

else:
raise ValueError(f"Unknown command: {data['command']}")


if __name__ == "__main__":
main()
5 changes: 5 additions & 0 deletions python/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
matplotlib==3.9.0
numpy==2.0.1
torch==2.4.0
wxPython==4.2.1
flask_cors==4.0.1