diff --git a/python/flask_wrapper_analysis_engine.py b/python/flask_wrapper_analysis_engine.py
new file mode 100644
index 000000000..9d62ef789
--- /dev/null
+++ b/python/flask_wrapper_analysis_engine.py
@@ -0,0 +1,86 @@
+from flask import Flask, request, jsonify
+from subprocess import Popen, PIPE, STDOUT
+import threading
+import json
+import os
+
+app = Flask(__name__)
+
+# Store the process globally
+cli_process = None
+lock = threading.Lock()
+
+def start_cli():
+    global cli_process
+    command = os.environ.get('KATAGO_COMMAND')
+
+    if not command:
+        raise ValueError('Environment variable KATAGO_COMMAND not set')
+
+    with lock:
+        if cli_process is None:
+            cli_process = Popen(command, shell=True, stdin=PIPE, stdout=PIPE, stderr=STDOUT, text=True)
+            wait_for_ready_state(cli_process)
+
+def wait_for_ready_state(process):
+    """Reads lines from the CLI program until it outputs the ready state message."""
+    try:
+        while True:
+            line = process.stdout.readline()
+            if not line:
+                break
+            print(f"Initial CLI Output: {line.strip()}")  # Print the initial output for debugging
+
+            if "Started, ready to begin handling requests" in line:
+                print("CLI program is ready.")
+                break  # Stop reading once the ready state is detected
+    except Exception as e:
+        print(f"Error reading initial output: {str(e)}")
+
+@app.route('/run', methods=['POST'])
+def run_cli():
+    global cli_process
+    input_data = request.json
+
+    if not input_data:
+        return jsonify({'error': 'No input provided'}), 400
+
+    analyze_turns = input_data.get("analyzeTurns", [])
+    num_responses_expected = len(analyze_turns)
+
+    with lock:
+        if cli_process is None:
+            return jsonify({'error': 'CLI program is not running'}), 400
+
+        try:
+            # Send input to the process
+            input_json = json.dumps(input_data)
+            cli_process.stdin.write(input_json + '\n')
+            cli_process.stdin.flush()
+
+            # Read the output from the process until all responses are received
+            responses = []
+            while len(responses) < num_responses_expected:
+                output_line = cli_process.stdout.readline()
+                if not output_line:
+                    break
+                output_line = output_line.strip()
+
+                try:
+                    # Attempt to parse the output as JSON
+                    response_data = json.loads(output_line)
+                    responses.append(response_data)
+                except json.JSONDecodeError:
+                    # Ignore lines that are not valid JSON
+                    print(f"Non-JSON line ignored: {output_line}")
+
+            if len(responses) != num_responses_expected:
+                return jsonify({'error': 'Did not receive the expected number of responses from CLI program'}), 500
+
+            return jsonify(responses)
+        except Exception as e:
+            return jsonify({'error': f'An error occurred: {str(e)}'}), 500
+
+if __name__ == '__main__':
+    start_cli()  # Start the CLI program when the server starts
+    app.run(host='0.0.0.0', port=5000)
diff --git a/python/flask_wrapper_human_sl_server.py b/python/flask_wrapper_human_sl_server.py
new file mode 100644
index 000000000..3685951d5
--- /dev/null
+++ b/python/flask_wrapper_human_sl_server.py
@@ -0,0 +1,118 @@
+from flask import Flask, request, jsonify
+from subprocess import Popen, PIPE
+import time
+import threading
+import json
+from flask_cors import CORS, cross_origin
+import os
+
+app = Flask(__name__)
+cors = CORS(app)
+app.config['CORS_HEADERS'] = 'Content-Type'
+
+# Store the process globally
+cli_process = None
+# Use a reentrant lock: run_cli() and restart_process() call start_cli(),
+# which acquires the lock again, while already holding it.
+lock = threading.RLock()
+
+def start_cli():
+    global cli_process
+    command = "python ./humanslnet_server.py -checkpoint ./b18c384nbt-humanv0.ckpt -device cpu -webserver True"
+
+    with lock:
+        if cli_process is None:
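+            # Launch humanslnet_server.py as a text-mode subprocess: stdout carries its
+            # JSON replies, and stderr is watched by a background thread started below.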
+            cli_process = Popen(command, shell=True, stdin=PIPE, stdout=PIPE, stderr=PIPE, text=True)
+
+            wait_for_ready_state(cli_process)
+
+            # Watch the subprocess's stderr in the background so fatal errors are noticed
+            stderr_thread = threading.Thread(target=monitor_stderr, args=(cli_process,))
+            stderr_thread.daemon = True
+            stderr_thread.start()
+
+def wait_for_ready_state(process):
+    """Reads lines from the CLI program until it outputs the ready state message."""
+    try:
+        while True:
+            line = process.stdout.readline()
+            if not line:
+                break
+            print(f"Initial CLI Output: {line.strip()}")  # Print the initial output for debugging
+
+            if "Ready to receive input" in line:
+                print("CLI program is ready.")
+                break  # Stop reading once the ready state is detected
+    except Exception as e:
+        print(f"Error reading initial output: {str(e)}")
+
+def monitor_stderr(process):
+    global cli_process
+    while True:
+        error_output = process.stderr.readline()
+        if error_output == '' and process.poll() is not None:
+            break
+        if error_output:
+            print(f"ERROR: {error_output.strip()}")
+            if "ERROR" in error_output or "Exception" in error_output or "Error" in error_output:
+                print("Error detected, quitting the process...")
+                os._exit(1)
+        time.sleep(0.1)
+
+def restart_process():
+    global cli_process
+    with lock:
+        if cli_process:
+            cli_process.kill()
+            cli_process = None
+        start_cli()
+
+@app.route('/', methods=['POST'])
+@cross_origin()
+def run_cli():
+    global cli_process
+    input_data = request.json
+
+    if not input_data:
+        return jsonify({'error': 'No input provided'}), 400
+
+    with lock:
+        if cli_process is None:
+            start_cli()
+            return jsonify({'error': 'CLI program was not running, restarted'}), 400
+
+        try:
+            # Send input to the process
+            input_json = json.dumps(input_data)
+            cli_process.stdin.write(input_json + '\n')
+            cli_process.stdin.flush()
+
+            # Read the output from the process until all responses are received
+            num_responses_expected = 1
+            responses = []
+            while len(responses) < num_responses_expected:
+                output_line = cli_process.stdout.readline()
+                if not output_line:
+                    break
+                output_line = output_line.strip()
+
+                try:
+                    # Attempt to parse the output as JSON
+                    response_data = json.loads(output_line)
+                    responses.append(response_data)
+                except json.JSONDecodeError:
+                    # Ignore lines that are not valid JSON
+                    print(f"Non-JSON line ignored: {output_line}")
+
+            if len(responses) != num_responses_expected:
+                return jsonify({'error': 'Did not receive the expected number of responses from CLI program'}), 500
+
+            return jsonify(responses)
+        except Exception as e:
+            return jsonify({'error': f'An error occurred: {str(e)}'}), 500
+
+if __name__ == '__main__':
+    start_cli()  # Start the CLI program when the server starts
+    app.run(host='0.0.0.0', port=5000)
diff --git a/python/forever_run_flask_wrapper_sl_server.sh b/python/forever_run_flask_wrapper_sl_server.sh
new file mode 100755
index 000000000..36794d53c
--- /dev/null
+++ b/python/forever_run_flask_wrapper_sl_server.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+while true; do
+    python flask_wrapper_human_sl_server.py
+    echo "The server crashed with exit code $?. Respawning.." >&2
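+    # Brief pause before respawning so a crash loop does not spin the CPU.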
+    sleep 1
+done
diff --git a/python/gamestate.py b/python/gamestate.py
index dbfec57d6..8bbe685ee 100644
--- a/python/gamestate.py
+++ b/python/gamestate.py
@@ -109,9 +109,6 @@ def get_model_outputs(self, model: "Model", sgfmeta: Optional[SGFMetadata] = Non
 
       features = Features(model.config, model.pos_len)
       bin_input_data, global_input_data = self.get_input_features(features)
-      # Currently we don't actually do any symmetries
-      # symmetry = 0
-      # model_outputs = model(apply_symmetry(batch["binaryInputNCHW"],symmetry),batch["globalInputNC"])
 
       input_meta = None
       if sgfmeta is not None:
@@ -278,7 +275,38 @@ def get_model_outputs(self, model: "Model", sgfmeta: Optional[SGFMetadata] = Non
       for name in list(extra_outputs.returned.keys()):
         if name.endswith(".attention"):
           extra_outputs.returned[name] = torch.transpose(extra_outputs.returned[name],1,2)
-
+      print({
+        "policy0": policy0,
+        "policy1": policy1,
+        "moves_and_probs0": moves_and_probs0,
+        "moves_and_probs1": moves_and_probs1,
+        "value": value,
+        "td_value": td_value,
+        "td_value2": td_value2,
+        "td_value3": td_value3,
+        "scoremean": scoremean,
+        "td_score": td_score,
+        "scorestdev": scorestdev,
+        "lead": lead,
+        "vtime": vtime,
+        "estv": estv,
+        "ests": ests,
+        "ownership": ownership,
+        "ownership_by_loc": ownership_by_loc,
+        "scoring": scoring,
+        "scoring_by_loc": scoring_by_loc,
+        "futurepos": futurepos,
+        "futurepos0_by_loc": futurepos0_by_loc,
+        "futurepos1_by_loc": futurepos1_by_loc,
+        "seki": seki,
+        "seki_by_loc": seki_by_loc,
+        "seki2": seki2,
+        "seki_by_loc2": seki_by_loc2,
+        "scorebelief": scorebelief,
+        "genmove_result": genmove_result,
+        **{ name:activation[0].numpy() for name, activation in extra_outputs.returned.items() },
+        "available_extra_outputs": available_extra_outputs,
+      })
       return {
         "policy0": policy0,
         "policy1": policy1,
@@ -311,4 +339,71 @@ def get_model_outputs(self, model: "Model", sgfmeta: Optional[SGFMetadata] = Non
         **{ name:activation[0].numpy() for name, activation in extra_outputs.returned.items() },
         "available_extra_outputs": available_extra_outputs,
       }
 
+
+  def run_monte_carlo_tree_search(self, model, sgfmeta, max_visits=20):
+    # Evaluate the root once, then repeatedly sample visits into the tree.
+    self.parent = None
+    self.children = {}
+    self.visits = 0
+    model_outputs = self.get_model_outputs(model, sgfmeta)
+    self.moves_and_probs0 = model_outputs["moves_and_probs0"]
+    # Fall back to the raw policy's best move until a visit finds something better.
+    self.best_move = self.get_best_move()[0]
+    self.best_move_win_chance = 0.0
+    for _ in range(max_visits):
+      self.do_visit(model, sgfmeta)
+
+  def choose_move(self):
+    # Sample a move according to the policy probabilities.
+    probs = np.array([prob for (_, prob) in self.moves_and_probs0], dtype=np.float64)
+    probs = probs / probs.sum()
+    index = np.random.choice(len(self.moves_and_probs0), p=probs)
+    return self.moves_and_probs0[index]
+
+  def get_best_move(self):
+    # Highest-probability (move, prob) pair according to the raw policy.
+    best_move_index = int(np.argmax([prob for (_, prob) in self.moves_and_probs0]))
+    return self.moves_and_probs0[best_move_index]
+
+  def do_visit(self, model, sgfmeta):
+    self.visits += 1
+    move = self.choose_move()
+
+    if move[0] not in self.children:
+      # Expand: clone the state, play the sampled move, and evaluate the child once.
+      child = self.clone()
+      child.play(child.board.pl, move[0])
+      child.move = move[0]
+      child.children = {}
+      child.visits = 1
+      child.best_move = None
+      child.best_move_win_chance = 0.0
+      model_outputs = child.get_model_outputs(model, sgfmeta)
+      child.moves_and_probs0 = model_outputs["moves_and_probs0"]
+      # value is (win, loss, no-result) for the side to move in the child, so the
+      # loss probability approximates the win chance of the player who just moved.
+      child.win_chance = float(model_outputs["value"][1])
+      self.children[move[0]] = child
+      self.backpropagate(child)
+    else:
+      self.children[move[0]].do_visit(model, sgfmeta)
+
+    if self.parent is not None:
+      self.parent.backpropagate(self)
+
+  def backpropagate(self, child_node):
+    # Keep the move leading to the child with the best win chance seen so far.
+    if child_node.win_chance > self.best_move_win_chance:
+      self.best_move = child_node.move
+      self.best_move_win_chance = child_node.win_chance
+
+  def clone(self):
+    game_state = GameState(self.board_size, self.rules)
+    game_state.redo_stack = self.redo_stack[:]
+    game_state.moves = self.moves[:]
+    game_state.boards = self.boards[:]
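+    # Copy the current board so that moves played in the child do not mutate this node's position.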
+    game_state.board = self.board.copy()
+    game_state.parent = self
+    return game_state
diff --git a/python/humanslnet_server.py b/python/humanslnet_server.py
index 9d6e351b6..3fa7fc77e 100644
--- a/python/humanslnet_server.py
+++ b/python/humanslnet_server.py
@@ -14,11 +14,16 @@ def numpy_array_encoder(obj):
     return float(obj)
   raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')
 
+def write(output):
+  sys.stdout.write(json.dumps(output,default=numpy_array_encoder) + "\n")
+  sys.stdout.flush()
+
 def main():
   parser = argparse.ArgumentParser()
   parser.add_argument('-checkpoint', help='Checkpoint to test', required=True)
   parser.add_argument('-use-swa', help='Use SWA model', action="store_true", required=False)
   parser.add_argument('-device', help='Device to use, such as cpu or cuda:0', required=True)
+  parser.add_argument('-webserver', help='Set when running behind the Flask wrapper so a ready message is printed at startup', required=False)
   args = parser.parse_args()
 
   model, swa_model, _ = load_model(args.checkpoint, use_swa=args.use_swa, device=args.device, pos_len=19, verbose=False)
@@ -26,10 +31,10 @@ def main():
     model = swa_model
 
   game_state = None
-  def write(output):
-    sys.stdout.write(json.dumps(output,default=numpy_array_encoder) + "\n")
-    sys.stdout.flush()
+  if args.webserver:
+    write("Ready to receive input")
+
 
   # DEBUGGING
   # game_state = GameState(board_size=19, rules=GameState.RULES_JAPANESE)
   # sgfmeta = SGFMetadata()
@@ -65,8 +70,6 @@ def write(output):
     elif data["command"] == "get_model_outputs":
       sgfmeta = SGFMetadata.of_dict(data["sgfmeta"])
 
-      # features = Features(model.config, model.pos_len)
-      # foo = game_state.get_input_features(features)
       outputs = game_state.get_model_outputs(model, sgfmeta=sgfmeta)
       filtered_outputs = {}
       for key in outputs:
@@ -74,8 +77,19 @@ def write(output):
          filtered_outputs[key] = outputs[key]
       write(dict(outputs=filtered_outputs))
 
+    elif data["command"] == "get_best_move":
+      sgfmeta = SGFMetadata.of_dict(data["sgfmeta"])
+
+      # Run the Monte Carlo tree search with the requested number of visits
+      visits = data.get("visits", 100)  # Default to 100 visits if not specified
+      game_state.run_monte_carlo_tree_search(model, sgfmeta, visits)
+
+      # Report the best move found by the search
+      write({"best_move": game_state.best_move})
+
     else:
       raise ValueError(f"Unknown command: {data['command']}")
+
 
 if __name__ == "__main__":
   main()
diff --git a/python/requirements.txt b/python/requirements.txt
new file mode 100644
index 000000000..cc1f11109
--- /dev/null
+++ b/python/requirements.txt
@@ -0,0 +1,6 @@
+matplotlib==3.9.0
+numpy==2.0.1
+torch==2.4.0
+wxPython==4.2.1
+flask_cors==4.0.1
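+# Flask itself is installed as a dependency of flask_cors.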