From 6e52ec1f3447830e376cb04467c29dedaccb440e Mon Sep 17 00:00:00 2001 From: Jesse Herrick Date: Thu, 23 Apr 2026 19:32:01 -0600 Subject: [PATCH 1/7] Add go-to-def + hover docs for Erlang --- internal/lsp/beam_server.exs | 483 ++++++++++++++++++++++++++++++ internal/lsp/elixir.go | 11 +- internal/lsp/elixir_test.go | 33 ++ internal/lsp/formatter.go | 181 ++++++++++- internal/lsp/formatter_server.exs | 200 ------------- internal/lsp/server.go | 60 ++++ internal/lsp/server_test.go | 102 +++++++ 7 files changed, 858 insertions(+), 212 deletions(-) create mode 100644 internal/lsp/beam_server.exs delete mode 100644 internal/lsp/formatter_server.exs diff --git a/internal/lsp/beam_server.exs b/internal/lsp/beam_server.exs new file mode 100644 index 0000000..9fc3ee0 --- /dev/null +++ b/internal/lsp/beam_server.exs @@ -0,0 +1,483 @@ +# Persistent BEAM server for Dexter LSP. +# +# Boots a Supervisor with two children: +# 1. Formatter — loads .formatter.exs once and caches formatter options +# 2. CodeIntel — resolves Erlang module source locations via :code/:beam_lib +# +# Both services share a single BEAM process (one startup cost). +# Communication is via stdin/stdout with binary framing: +# +# Request envelope: 1-byte service tag + service-specific payload +# Response envelope: service-specific (formatter and code_intel share the same +# status + length + body format) +# +# Service tags: +# 0x00 = Formatter +# 0x01 = CodeIntel +# +# Formatter protocol (after service tag): +# Request: 2-byte filename length (big-endian) + filename + +# 4-byte content length (big-endian) + content +# Response: 1-byte status (0=ok, 1=error) + +# 4-byte result length (big-endian) + result +# +# CodeIntel protocol (after service tag): +# Request: 1-byte op + +# 2-byte module length (big-endian) + module + +# 2-byte function length (big-endian) + function + +# 1-byte arity (255 = unspecified) +# +# Op 0 (erlang_source) response: +# 1-byte status (0=ok, 1=not_found) + +# 2-byte file length (big-endian) + file + +# 4-byte line (big-endian, 0 if not found) +# +# Op 1 (erlang_docs) response: +# 1-byte status (0=ok, 1=not_found) + +# 4-byte doc length (big-endian) + doc (markdown string) +# +# Sends a ready signal once initialization is complete: +# 1-byte status (0=ok) + 4-byte length (0) +# +# Force raw byte mode on stdin/stdout — without this, the Erlang IO server +# applies Unicode encoding, expanding bytes > 127 to multi-byte UTF-8 and +# corrupting our binary protocol framing. +:io.setopts(:standard_io, encoding: :latin1) + +[mix_root, formatter_exs_path, project_root_arg] = System.argv() + +# In umbrella apps, _build and deps live at the umbrella root, not in +# individual app directories. Walk up from mix_root (bounded by the project +# root) to find the nearest ancestor that contains a _build directory. +expanded_mix_root = Path.expand(mix_root) +expanded_boundary = Path.expand(project_root_arg) + +project_root = + Enum.reduce_while(1..20, expanded_mix_root, fn _, dir -> + cond do + File.dir?(Path.join(dir, "_build")) -> + {:halt, dir} + dir == expanded_boundary -> + {:halt, expanded_mix_root} + true -> + parent = Path.dirname(dir) + + if parent == dir do + {:halt, expanded_mix_root} + else + {:cont, parent} + end + end + end) + +# Add the project's compiled deps to the code path so plugins are available +# without needing Mix.install +project_root +|> Path.join("_build/dev/lib/*/ebin") +|> Path.wildcard() +|> Enum.each(&Code.prepend_path/1) + +# Read .formatter.exs +raw_opts = + if File.regular?(formatter_exs_path) do + {result, _} = Code.eval_file(formatter_exs_path) + if is_list(result), do: result, else: [] + else + [] + end + +plugins = Keyword.get(raw_opts, :plugins, []) + +# Resolve locals_without_parens from import_deps by reading each dep's exported +# formatter config. Mix does this automatically in mix format, but we must +# replicate it here since we eval .formatter.exs directly. +import_deps_locals = + raw_opts + |> Keyword.get(:import_deps, []) + |> Enum.flat_map(fn dep -> + dep_formatter = Path.join([project_root, "deps", to_string(dep), ".formatter.exs"]) + + if File.regular?(dep_formatter) do + {dep_opts, _} = Code.eval_file(dep_formatter) + + if is_list(dep_opts) do + dep_opts + |> Keyword.get(:export, []) + |> Keyword.get(:locals_without_parens, []) + else + [] + end + else + [] + end + end) + +explicit_locals = Keyword.get(raw_opts, :locals_without_parens, []) +all_locals_without_parens = Enum.uniq(import_deps_locals ++ explicit_locals) + +# Extract formatting options +format_opts = + raw_opts + |> Keyword.take([ + :line_length, + :normalize_bitstring_modifiers, + :normalize_charlists_as_sigils, + :force_do_end_blocks + ]) + |> Keyword.put(:locals_without_parens, all_locals_without_parens) + +# Resolve which plugins are actually loaded +active_plugins = Enum.filter(plugins, &Code.ensure_loaded?/1) + +missing_plugins = plugins -- active_plugins + +if missing_plugins != [] do + IO.puts(:stderr, "Formatter: WARNING: could not load plugins: #{Enum.map_join(missing_plugins, ", ", &inspect/1)} (not compiled in _build?). Falling back to standard formatter.") +end + +if active_plugins != [] do + IO.puts(:stderr, "Formatter: plugins loaded: #{Enum.map_join(active_plugins, ", ", &inspect/1)}") +else + IO.puts(:stderr, "Formatter: no plugins") +end + +# ── Formatter Service ────────────────────────────────────────────────────── + +defmodule Dexter.Formatter do + def handle_request(format_opts, plugins) do + case IO.binread(:stdio, 2) do + <> -> + filename = if filename_len > 0, do: IO.binread(:stdio, filename_len), else: "" + <> = IO.binread(:stdio, 4) + content = IO.binread(:stdio, content_len) + + {status, result} = format(content, filename, format_opts, plugins) + size = byte_size(result) + IO.binwrite(:stdio, <>) + :ok + + _ -> + :eof + end + end + + defp format(content, filename, format_opts, plugins) when is_binary(content) do + try do + opts = if filename != "", do: [file: filename] ++ format_opts, else: format_opts + ext = Path.extname(filename) + + # Filter plugins to those that handle this file extension + applicable_plugins = + Enum.filter(plugins, fn plugin -> + features = plugin.features(format_opts) + extensions = Keyword.get(features, :extensions, []) + # If a plugin declares no extensions, it handles .ex/.exs by default + extensions == [] or ext in extensions + end) + + formatted = + if applicable_plugins != [] do + # Redirect group leader to stderr during plugin calls so any + # IO.puts from plugins doesn't corrupt the binary protocol on stdout. + old_gl = Process.group_leader() + Process.group_leader(self(), Process.whereis(:standard_error)) + + try do + Enum.reduce(applicable_plugins, content, fn plugin, acc -> + plugin.format(acc, opts) + end) + after + Process.group_leader(self(), old_gl) + end + else + content |> Code.format_string!(opts) |> IO.iodata_to_binary() + end + + # Ensure trailing newline to match mix format output + formatted = + if String.ends_with?(formatted, "\n"), + do: formatted, + else: formatted <> "\n" + + {0, formatted} + rescue + e -> {1, Exception.message(e)} + catch + kind, reason -> {1, "#{kind}: #{inspect(reason)}"} + end + end + + defp format(_, _, _, _), do: {1, "invalid input"} +end + +# ── CodeIntel Service ────────────────────────────────────────────────────── + +defmodule Dexter.CodeIntel do + @op_erlang_source 0 + @op_erlang_docs 1 + + def handle_request() do + case IO.binread(:stdio, 1) do + <<@op_erlang_source>> -> + handle_erlang_source() + + <<@op_erlang_docs>> -> + handle_erlang_docs() + + _ -> + :eof + end + end + + defp handle_erlang_source do + <> = IO.binread(:stdio, 2) + module_name = if module_len > 0, do: IO.binread(:stdio, module_len), else: "" + <> = IO.binread(:stdio, 2) + function_name = if function_len > 0, do: IO.binread(:stdio, function_len), else: "" + <> = IO.binread(:stdio, 1) + + {status, file, line} = resolve_erlang_source(module_name, function_name, arity) + + file_bytes = if file, do: file, else: "" + file_len = byte_size(file_bytes) + IO.binwrite(:stdio, <>) + :ok + end + + defp resolve_erlang_source(module_name, function_name, arity) do + module_atom = String.to_atom(module_name) + + case find_source_file(module_atom) do + nil -> + {1, nil, 0} + + source_file -> + line = find_function_line(module_atom, function_name, arity) + {0, source_file, line} + end + end + + defp find_source_file(module) do + case :code.get_object_code(module) do + {_module, _binary, beam_path} -> + erl_file = + beam_path + |> to_string() + |> String.replace(~r|(.+)/ebin/([^\s]+)\.beam$|, "\\1/src/\\2.erl") + + if File.exists?(erl_file, [:raw]), do: erl_file, else: nil + + :error -> + nil + end + end + + defp find_function_line(_module_atom, "", _arity), do: 0 + + defp find_function_line(module_atom, function_name, arity) do + function_atom = String.to_atom(function_name) + + # Try abstract code first for arity-accurate line numbers + line = find_line_from_abstract_code(module_atom, function_atom, arity) + + if line > 0 do + line + else + # Fallback: regex search in source file + find_line_from_source(module_atom, function_atom) + end + end + + defp find_line_from_abstract_code(module_atom, function_atom, arity) do + beam_path = :code.which(module_atom) + + if is_list(beam_path) do + case :beam_lib.chunks(beam_path, [:abstract_code]) do + {:ok, {_, [{:abstract_code, {:raw_abstract_v1, forms}}]}} -> + find_function_in_forms(forms, function_atom, arity) + + _ -> + 0 + end + else + 0 + end + end + + defp find_function_in_forms(forms, function_atom, 255 = _unspecified) do + # No arity specified — find the first clause with matching name + Enum.find_value(forms, 0, fn + {:function, anno, ^function_atom, _arity, _clauses} -> + anno_line(anno) + + _ -> + nil + end) + end + + defp find_function_in_forms(forms, function_atom, arity) do + # Exact arity match + exact = + Enum.find_value(forms, nil, fn + {:function, anno, ^function_atom, ^arity, _clauses} -> + anno_line(anno) + + _ -> + nil + end) + + # Fall back to first matching name if exact arity not found + exact || find_function_in_forms(forms, function_atom, 255) + end + + defp anno_line(anno) when is_integer(anno), do: anno + defp anno_line(anno) when is_list(anno), do: Keyword.get(anno, :line, 0) + defp anno_line(anno) when is_map(anno), do: Map.get(anno, :line, 0) + defp anno_line(_), do: 0 + + defp find_line_from_source(module_atom, function_atom) do + case find_source_file(module_atom) do + nil -> + 0 + + source_file -> + pattern = ~r/^#{Regex.escape(to_string(function_atom))}\b\(/u + + source_file + |> File.stream!() + |> Stream.with_index(1) + |> Enum.find_value(0, fn {line, line_number} -> + if Regex.match?(pattern, line), do: line_number, else: nil + end) + end + end + + # ── Erlang Docs ────────────────────────────────────────────────────────── + + defp handle_erlang_docs do + <> = IO.binread(:stdio, 2) + module_name = if module_len > 0, do: IO.binread(:stdio, module_len), else: "" + <> = IO.binread(:stdio, 2) + function_name = if function_len > 0, do: IO.binread(:stdio, function_len), else: "" + <> = IO.binread(:stdio, 1) + + {status, doc} = fetch_erlang_docs(module_name, function_name, arity) + + doc_bytes = doc || "" + doc_len = byte_size(doc_bytes) + IO.binwrite(:stdio, <>) + :ok + end + + defp fetch_erlang_docs(module_name, function_name, arity) do + module_atom = String.to_atom(module_name) + + case Code.fetch_docs(module_atom) do + {:docs_v1, _, :erlang, _format, module_doc, _metadata, docs} -> + if function_name == "" do + # Module-level docs + case extract_doc_text(module_doc) do + nil -> {1, nil} + text -> {0, text} + end + else + function_atom = String.to_atom(function_name) + find_function_doc(docs, function_atom, arity) + end + + _ -> + {1, nil} + end + end + + defp find_function_doc(docs, function_atom, arity) do + # Find matching function docs — prefer exact arity match + candidates = + Enum.filter(docs, fn + {{:function, ^function_atom, _arity}, _anno, _sig, _doc, _meta} -> true + _ -> false + end) + + match = + if arity != 255 do + Enum.find(candidates, fn + {{:function, _, ^arity}, _, _, _, _} -> true + _ -> false + end) || List.first(candidates) + else + List.first(candidates) + end + + case match do + {{:function, _, match_arity}, _anno, signatures, doc, _meta} -> + signature = format_signatures(signatures, function_atom, match_arity) + doc_text = extract_doc_text(doc) + + parts = [] + parts = if signature != "", do: parts ++ ["```erlang\n#{signature}\n```"], else: parts + parts = if doc_text, do: parts ++ [doc_text], else: parts + + case parts do + [] -> {1, nil} + _ -> {0, Enum.join(parts, "\n\n")} + end + + nil -> + {1, nil} + end + end + + defp format_signatures(signatures, function_atom, arity) when is_list(signatures) do + case signatures do + [sig | _] when is_binary(sig) -> sig + _ -> "#{function_atom}/#{arity}" + end + end + + defp format_signatures(_, function_atom, arity), do: "#{function_atom}/#{arity}" + + defp extract_doc_text(%{"en" => text}), do: text + defp extract_doc_text(:hidden), do: nil + defp extract_doc_text(:none), do: nil + defp extract_doc_text(_), do: nil +end + +# ── Main IO Loop ─────────────────────────────────────────────────────────── + +defmodule Dexter.Loop do + @service_formatter 0 + @service_code_intel 1 + + def run(format_opts, plugins, first_call?) do + # Signal ready if first call: status=0, length=0 + if first_call?, do: IO.binwrite(:stdio, <<0, 0, 0, 0, 0>>) + + case IO.binread(:stdio, 1) do + <<@service_formatter>> -> + case Dexter.Formatter.handle_request(format_opts, plugins) do + :ok -> run(format_opts, plugins, false) + :eof -> :ok + end + + <<@service_code_intel>> -> + case Dexter.CodeIntel.handle_request() do + :ok -> run(format_opts, plugins, false) + :eof -> :ok + end + + _ -> + :ok + end + end +end + +try do + Dexter.Loop.run(format_opts, active_plugins, true) +rescue + e -> + IO.puts(:stderr, "Dexter BEAM: crash in loop: #{Exception.message(e)}") + IO.puts(:stderr, Exception.format_banner(:error, e, __STACKTRACE__)) +catch + kind, reason -> + IO.puts(:stderr, "Dexter BEAM: crash in loop: #{inspect(kind)} #{inspect(reason)}") +end diff --git a/internal/lsp/elixir.go b/internal/lsp/elixir.go index 20d9200..2b0bdfa 100644 --- a/internal/lsp/elixir.go +++ b/internal/lsp/elixir.go @@ -178,9 +178,9 @@ func (c CursorContext) Empty() bool { } // isExprToken returns true for token kinds that can be part of a dotted -// expression chain (Module.function). +// expression chain (Module.function or :atom.function). func isExprToken(k parser.TokenKind) bool { - return k == parser.TokModule || k == parser.TokIdent + return k == parser.TokModule || k == parser.TokIdent || k == parser.TokAtom } // ExpressionAtCursor extracts the dotted expression at the cursor position @@ -286,11 +286,14 @@ func expressionAtCursorImpl(tokens []parser.Token, source []byte, lineStarts []i for ti := startIdx; ti <= truncEnd; ti += 2 { t := tokens[ti] text := parser.TokenText(source, t) - if t.Kind == parser.TokModule { + switch t.Kind { + case parser.TokModule, parser.TokAtom: moduleParts = append(moduleParts, text) - } else { + default: // TokIdent — this is the function name; stop here functionName = text + } + if functionName != "" { break } } diff --git a/internal/lsp/elixir_test.go b/internal/lsp/elixir_test.go index ea9fb70..a49d460 100644 --- a/internal/lsp/elixir_test.go +++ b/internal/lsp/elixir_test.go @@ -231,6 +231,39 @@ func TestExpressionAtCursor(t *testing.T) { wantMod: "Foo.Bar", wantFunc: "transform", }, + // --- Erlang atom module support --- + { + name: "erlang atom module with function", + code: " :code.all_loaded()", + line: 0, + col: 11, // 'a' in all_loaded + wantMod: ":code", + wantFunc: "all_loaded", + }, + { + name: "erlang atom module cursor on atom", + code: " :code.all_loaded()", + line: 0, + col: 6, // 'o' in code + wantMod: ":code", + wantFunc: "", + }, + { + name: "erlang atom :lists.flatten", + code: ":lists.flatten(data)", + line: 0, + col: 7, // 'f' in flatten + wantMod: ":lists", + wantFunc: "flatten", + }, + { + name: "erlang atom piped", + code: " |> :lists.flatten()", + line: 0, + col: 13, // 'f' in flatten + wantMod: ":lists", + wantFunc: "flatten", + }, } for _, tt := range tests { diff --git a/internal/lsp/formatter.go b/internal/lsp/formatter.go index ad35802..cb17dd1 100644 --- a/internal/lsp/formatter.go +++ b/internal/lsp/formatter.go @@ -21,17 +21,21 @@ import ( "go.lsp.dev/protocol" ) -//go:embed formatter_server.exs -var formatterScript string +//go:embed beam_server.exs +var beamServerScript string const ( - // How long to wait for the persistent formatter to become ready before + // How long to wait for the persistent BEAM to become ready before // falling back to mix format on a given request. formatterWaitTimeout = 5 * time.Second - // How long a not-ready formatter process is allowed to live before being + // How long a not-ready BEAM process is allowed to live before being // killed and restarted. Also used as the hard cap inside the startup // goroutine to prevent leaked goroutines. formatterStuckTimeout = 30 * time.Second + + // Service tags for the BEAM server protocol + serviceFormatter byte = 0x00 + serviceCodeIntel byte = 0x01 ) type formatterProcess struct { @@ -82,6 +86,7 @@ func (fp *formatterProcess) Format(ctx context.Context, content, filename string filenameBytes := []byte(filename) contentBytes := []byte(content) var req bytes.Buffer + req.WriteByte(serviceFormatter) // service tag _ = binary.Write(&req, binary.BigEndian, uint16(len(filenameBytes))) req.Write(filenameBytes) _ = binary.Write(&req, binary.BigEndian, uint32(len(contentBytes))) @@ -130,6 +135,144 @@ func (fp *formatterProcess) Format(ctx context.Context, content, filename string } } +// ErlangSourceResult holds the resolved source location for an Erlang function. +type ErlangSourceResult struct { + File string + Line int +} + +// ErlangSource asks the BEAM's CodeIntel service to resolve an Erlang module/function +// to its source file and line number. +func (fp *formatterProcess) ErlangSource(ctx context.Context, module, function string, arity int) (*ErlangSourceResult, error) { + fp.mu.Lock() + defer fp.mu.Unlock() + + moduleBytes := []byte(module) + functionBytes := []byte(function) + arityByte := byte(255) // 255 = unspecified + if arity >= 0 && arity < 255 { + arityByte = byte(arity) + } + + var req bytes.Buffer + req.WriteByte(serviceCodeIntel) // service tag + req.WriteByte(0) // op: erlang_source + _ = binary.Write(&req, binary.BigEndian, uint16(len(moduleBytes))) + req.Write(moduleBytes) + _ = binary.Write(&req, binary.BigEndian, uint16(len(functionBytes))) + req.Write(functionBytes) + req.WriteByte(arityByte) + if _, err := fp.stdin.Write(req.Bytes()); err != nil { + return nil, fmt.Errorf("write code_intel request: %w", err) + } + + type readResult struct { + result *ErlangSourceResult + err error + } + ch := make(chan readResult, 1) + go func() { + var status byte + if err := binary.Read(fp.stdout, binary.BigEndian, &status); err != nil { + ch <- readResult{err: fmt.Errorf("read status: %w", err)} + return + } + var fileLen uint16 + if err := binary.Read(fp.stdout, binary.BigEndian, &fileLen); err != nil { + ch <- readResult{err: fmt.Errorf("read file length: %w", err)} + return + } + fileBuf := make([]byte, fileLen) + if _, err := io.ReadFull(fp.stdout, fileBuf); err != nil { + ch <- readResult{err: fmt.Errorf("read file: %w", err)} + return + } + var line uint32 + if err := binary.Read(fp.stdout, binary.BigEndian, &line); err != nil { + ch <- readResult{err: fmt.Errorf("read line: %w", err)} + return + } + if status != 0 { + ch <- readResult{err: fmt.Errorf("erlang source not found")} + return + } + ch <- readResult{result: &ErlangSourceResult{File: string(fileBuf), Line: int(line)}} + }() + + select { + case r := <-ch: + return r.result, r.err + case <-ctx.Done(): + _ = fp.cmd.process.Kill() + <-ch + return nil, ctx.Err() + } +} + +// ErlangDocs asks the BEAM's CodeIntel service for the documentation of an +// Erlang module or function. Returns pre-formatted markdown, or empty string +// if no docs are available (e.g. OTP < 24 or undocumented function). +func (fp *formatterProcess) ErlangDocs(ctx context.Context, module, function string, arity int) (string, error) { + fp.mu.Lock() + defer fp.mu.Unlock() + + moduleBytes := []byte(module) + functionBytes := []byte(function) + arityByte := byte(255) + if arity >= 0 && arity < 255 { + arityByte = byte(arity) + } + + var req bytes.Buffer + req.WriteByte(serviceCodeIntel) + req.WriteByte(1) // op: erlang_docs + _ = binary.Write(&req, binary.BigEndian, uint16(len(moduleBytes))) + req.Write(moduleBytes) + _ = binary.Write(&req, binary.BigEndian, uint16(len(functionBytes))) + req.Write(functionBytes) + req.WriteByte(arityByte) + if _, err := fp.stdin.Write(req.Bytes()); err != nil { + return "", fmt.Errorf("write code_intel request: %w", err) + } + + type readResult struct { + doc string + err error + } + ch := make(chan readResult, 1) + go func() { + var status byte + if err := binary.Read(fp.stdout, binary.BigEndian, &status); err != nil { + ch <- readResult{err: fmt.Errorf("read status: %w", err)} + return + } + var docLen uint32 + if err := binary.Read(fp.stdout, binary.BigEndian, &docLen); err != nil { + ch <- readResult{err: fmt.Errorf("read doc length: %w", err)} + return + } + docBuf := make([]byte, docLen) + if _, err := io.ReadFull(fp.stdout, docBuf); err != nil { + ch <- readResult{err: fmt.Errorf("read doc: %w", err)} + return + } + if status != 0 { + ch <- readResult{doc: ""} + return + } + ch <- readResult{doc: string(docBuf)} + }() + + select { + case r := <-ch: + return r.doc, r.err + case <-ctx.Done(): + _ = fp.cmd.process.Kill() + <-ch + return "", ctx.Err() + } +} + // FormatError represents a formatting failure (e.g. syntax error in the source). // The persistent process is still alive — this is not a protocol/crash error. type FormatError struct { @@ -159,10 +302,10 @@ func (s *Server) startFormatterProcess(mixRoot, formatterExs string) (*formatter if err := os.MkdirAll(scriptDir, 0755); err != nil { return nil, fmt.Errorf("create script dir: %w", err) } - scriptPath := filepath.Join(scriptDir, "formatter_server.exs") - if existing, err := os.ReadFile(scriptPath); err != nil || string(existing) != formatterScript { - if err := os.WriteFile(scriptPath, []byte(formatterScript), 0644); err != nil { - return nil, fmt.Errorf("write formatter script: %w", err) + scriptPath := filepath.Join(scriptDir, "beam_server.exs") + if existing, err := os.ReadFile(scriptPath); err != nil || string(existing) != beamServerScript { + if err := os.WriteFile(scriptPath, []byte(beamServerScript), 0644); err != nil { + return nil, fmt.Errorf("write beam server script: %w", err) } } @@ -384,6 +527,28 @@ func (s *Server) evictFormatter(formatterExs string, fp *formatterProcess) { fp.Close() } +// getBeamProcess returns any alive and ready BEAM process for code intel queries. +// Since all BEAM processes have the CodeIntel service, any one will do. +func (s *Server) getBeamProcess(ctx context.Context) *formatterProcess { + s.formattersMu.Lock() + defer s.formattersMu.Unlock() + + for _, fp := range s.formatters { + if !fp.alive() { + continue + } + // Non-blocking ready check + select { + case <-fp.ready: + if fp.startErr == nil { + return fp + } + default: + } + } + return nil +} + func (s *Server) formatWithMixFormat(ctx context.Context, mixRoot, path, content string) (string, error) { if s.mixBin == "" { return "", fmt.Errorf("mix binary not found") diff --git a/internal/lsp/formatter_server.exs b/internal/lsp/formatter_server.exs deleted file mode 100644 index 12de0f1..0000000 --- a/internal/lsp/formatter_server.exs +++ /dev/null @@ -1,200 +0,0 @@ -# Persistent formatter server for Dexter LSP. -# -# Loads .formatter.exs once and caches the formatter options, then loops over -# stdin formatting requests — no VM startup cost per format. -# -# Plugins (e.g. Styler) are loaded from the project's _build directory — -# no Mix.install or Hex downloads needed. -# -# Protocol (request): 2-byte filename length (big-endian) + filename + -# 4-byte content length (big-endian) + content -# Protocol (response): 1-byte status (0=ok, 1=error) + -# 4-byte result length (big-endian) + result -# -# Sends a ready response (status=0, length=0) once initialization is complete. -# -# Force raw byte mode on stdin/stdout — without this, the Erlang IO server -# applies Unicode encoding, expanding bytes > 127 to multi-byte UTF-8 and -# corrupting our binary protocol framing. -:io.setopts(:standard_io, encoding: :latin1) - -[mix_root, formatter_exs_path, project_root_arg] = System.argv() - -# In umbrella apps, _build and deps live at the umbrella root, not in -# individual app directories. Walk up from mix_root (bounded by the project -# root) to find the nearest ancestor that contains a _build directory. -expanded_mix_root = Path.expand(mix_root) -expanded_boundary = Path.expand(project_root_arg) - -# If there are really umbrella apps with a distance greater than 20 to the root -# we can update this (or maybe make it configurable), but 20 seems like a sane -# limit. -project_root = - Enum.reduce_while(1..20, expanded_mix_root, fn _, dir -> - cond do - File.dir?(Path.join(dir, "_build")) -> - {:halt, dir} - dir == expanded_boundary -> - {:halt, expanded_mix_root} - true -> - parent = Path.dirname(dir) - - if parent == dir do - {:halt, expanded_mix_root} - else - {:cont, parent} - end - end - end) - -# Add the project's compiled deps to the code path so plugins are available -# without needing Mix.install -project_root -|> Path.join("_build/dev/lib/*/ebin") -|> Path.wildcard() -|> Enum.each(&Code.prepend_path/1) - -# Read .formatter.exs -raw_opts = - if File.regular?(formatter_exs_path) do - {result, _} = Code.eval_file(formatter_exs_path) - if is_list(result), do: result, else: [] - else - [] - end - -plugins = Keyword.get(raw_opts, :plugins, []) - -# Resolve locals_without_parens from import_deps by reading each dep's exported -# formatter config. Mix does this automatically in mix format, but we must -# replicate it here since we eval .formatter.exs directly. -import_deps_locals = - raw_opts - |> Keyword.get(:import_deps, []) - |> Enum.flat_map(fn dep -> - dep_formatter = Path.join([project_root, "deps", to_string(dep), ".formatter.exs"]) - - if File.regular?(dep_formatter) do - {dep_opts, _} = Code.eval_file(dep_formatter) - - if is_list(dep_opts) do - dep_opts - |> Keyword.get(:export, []) - |> Keyword.get(:locals_without_parens, []) - else - [] - end - else - [] - end - end) - -explicit_locals = Keyword.get(raw_opts, :locals_without_parens, []) -all_locals_without_parens = Enum.uniq(import_deps_locals ++ explicit_locals) - -# Extract formatting options -format_opts = - raw_opts - |> Keyword.take([ - :line_length, - :normalize_bitstring_modifiers, - :normalize_charlists_as_sigils, - :force_do_end_blocks - ]) - |> Keyword.put(:locals_without_parens, all_locals_without_parens) - -# Resolve which plugins are actually loaded -active_plugins = Enum.filter(plugins, &Code.ensure_loaded?/1) - -missing_plugins = plugins -- active_plugins - -if missing_plugins != [] do - IO.puts(:stderr, "Formatter: WARNING: could not load plugins: #{Enum.map_join(missing_plugins, ", ", &inspect/1)} (not compiled in _build?). Falling back to standard formatter.") -end - -if active_plugins != [] do - IO.puts(:stderr, "Formatter: plugins loaded: #{Enum.map_join(active_plugins, ", ", &inspect/1)}") -else - IO.puts(:stderr, "Formatter: no plugins") -end - -defmodule Formatter.Loop do - def run(format_opts, plugins, first_call?) do - # Signal ready if first call: status=0, length=0 - if first_call?, do: IO.binwrite(:stdio, <<0, 0, 0, 0, 0>>) - - case IO.binread(:stdio, 2) do - <> -> - filename = if filename_len > 0, do: IO.binread(:stdio, filename_len), else: "" - <> = IO.binread(:stdio, 4) - content = IO.binread(:stdio, content_len) - - {status, result} = format(content, filename, format_opts, plugins) - size = byte_size(result) - IO.binwrite(:stdio, <>) - run(format_opts, plugins, false) - - _ -> - :ok - end - end - - defp format(content, filename, format_opts, plugins) when is_binary(content) do - try do - opts = if filename != "", do: [file: filename] ++ format_opts, else: format_opts - ext = Path.extname(filename) - - # Filter plugins to those that handle this file extension - applicable_plugins = - Enum.filter(plugins, fn plugin -> - features = plugin.features(format_opts) - extensions = Keyword.get(features, :extensions, []) - # If a plugin declares no extensions, it handles .ex/.exs by default - extensions == [] or ext in extensions - end) - - formatted = - if applicable_plugins != [] do - # Redirect group leader to stderr during plugin calls so any - # IO.puts from plugins doesn't corrupt the binary protocol on stdout. - old_gl = Process.group_leader() - Process.group_leader(self(), Process.whereis(:standard_error)) - - try do - Enum.reduce(applicable_plugins, content, fn plugin, acc -> - plugin.format(acc, opts) - end) - after - Process.group_leader(self(), old_gl) - end - else - content |> Code.format_string!(opts) |> IO.iodata_to_binary() - end - - # Ensure trailing newline to match mix format output - formatted = - if String.ends_with?(formatted, "\n"), - do: formatted, - else: formatted <> "\n" - - {0, formatted} - rescue - e -> {1, Exception.message(e)} - catch - kind, reason -> {1, "#{kind}: #{inspect(reason)}"} - end - end - - defp format(_, _, _, _), do: {1, "invalid input"} -end - -try do - Formatter.Loop.run(format_opts, active_plugins, true) -rescue - e -> - IO.puts(:stderr, "Formatter: crash in loop: #{Exception.message(e)}") - IO.puts(:stderr, Exception.format_banner(:error, e, __STACKTRACE__)) -catch - kind, reason -> - IO.puts(:stderr, "Formatter: crash in loop: #{inspect(kind)} #{inspect(reason)}") -end diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 99814ea..e54db07 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -551,6 +551,14 @@ func (s *Server) Definition(ctx context.Context, params *protocol.DefinitionPara return nil, nil } + // Erlang module atom (e.g. :code.all_loaded) — resolve via BEAM process. + // Check before ExtractModuleAndFunction which doesn't handle atom-prefixed modules. + if strings.HasPrefix(exprCtx.ModuleRef, ":") { + erlModule := exprCtx.ModuleRef[1:] // strip the : prefix + s.debugf("Definition: Erlang module %q function=%q", erlModule, exprCtx.FunctionName) + return s.erlangDefinition(ctx, erlModule, exprCtx.FunctionName) + } + expr := tf.ResolveModuleExpr(exprCtx.Expr(), lineNum) moduleRef, functionName := ExtractModuleAndFunction(expr) @@ -700,6 +708,53 @@ func filterOutTypes(results []store.LookupResult) []store.LookupResult { return results } +// erlangHover fetches documentation for an Erlang module/function via the +// BEAM process's CodeIntel service. +func (s *Server) erlangHover(ctx context.Context, module, function string) (*protocol.Hover, error) { + fp := s.getBeamProcess(ctx) + if fp == nil { + return nil, nil + } + + doc, err := fp.ErlangDocs(ctx, module, function, -1) + if err != nil || doc == "" { + return nil, nil + } + + return &protocol.Hover{ + Contents: protocol.MarkupContent{ + Kind: protocol.Markdown, + Value: doc, + }, + }, nil +} + +// erlangDefinition resolves an Erlang module/function to its .erl source via +// the BEAM process's CodeIntel service. +func (s *Server) erlangDefinition(ctx context.Context, module, function string) ([]protocol.Location, error) { + fp := s.getBeamProcess(ctx) + if fp == nil { + s.debugf("Definition: no BEAM process available for Erlang resolution") + return nil, nil + } + + result, err := fp.ErlangSource(ctx, module, function, -1) + if err != nil { + s.debugf("Definition: Erlang source lookup failed: %v", err) + return nil, nil + } + + line := result.Line + if line > 0 { + line-- // convert 1-based to 0-based for LSP + } + + return []protocol.Location{{ + URI: uri.File(result.File), + Range: lineRange(line), + }}, nil +} + func lineRange(line int) protocol.Range { return protocol.Range{ Start: protocol.Position{Line: uint32(line), Character: 0}, @@ -2757,6 +2812,11 @@ func (s *Server) Hover(ctx context.Context, params *protocol.HoverParams) (*prot return nil, nil } + // Erlang module atom (e.g. :lists.flatten) — fetch docs via BEAM process + if strings.HasPrefix(exprCtx.ModuleRef, ":") { + return s.erlangHover(ctx, exprCtx.ModuleRef[1:], exprCtx.FunctionName) + } + expr := tf.ResolveModuleExpr(exprCtx.Expr(), lineNum) moduleRef, functionName := ExtractModuleAndFunction(expr) diff --git a/internal/lsp/server_test.go b/internal/lsp/server_test.go index 64813ab..3af5bdb 100644 --- a/internal/lsp/server_test.go +++ b/internal/lsp/server_test.go @@ -4628,3 +4628,105 @@ end` t.Errorf("expected definition location in contract_snapshot_schema.ex, got %v", locs) } } + +func TestDefinition_ErlangAtomModule(t *testing.T) { + server, cleanup := setupTestServer(t) + defer cleanup() + + uri := "file:///test.ex" + server.docs.Set(uri, `defmodule MyModule do + def run do + :code.all_loaded() + end +end`) + + // col=5 on ":code" — should take the Erlang path and return nil + // (no BEAM process in test, so no result — but must not crash or + // fall through to Elixir resolution) + locs := definitionAt(t, server, uri, 2, 5) + if len(locs) != 0 { + t.Errorf("expected no definition without BEAM process, got %v", locs) + } + + // col=10 on "all_loaded" function + locs = definitionAt(t, server, uri, 2, 10) + if len(locs) != 0 { + t.Errorf("expected no definition without BEAM process, got %v", locs) + } +} + +func TestDefinition_ErlangAtomDoesNotAffectElixir(t *testing.T) { + server, cleanup := setupTestServer(t) + defer cleanup() + + indexFile(t, server.store, server.projectRoot, "lib/accounts.ex", `defmodule MyApp.Accounts do + def create(attrs), do: :ok +end +`) + + uri := "file:///test.ex" + server.docs.Set(uri, `defmodule MyModule do + alias MyApp.Accounts + Accounts.create(attrs) +end`) + + // Normal Elixir go-to-definition still works + locs := definitionAt(t, server, uri, 2, 13) + if len(locs) == 0 { + t.Fatal("expected Elixir definition to still work") + } + if !strings.Contains(string(locs[0].URI), "accounts.ex") { + t.Errorf("expected accounts.ex, got %s", locs[0].URI) + } +} + +func TestHover_ErlangAtomModule(t *testing.T) { + server, cleanup := setupTestServer(t) + defer cleanup() + + uri := "file:///test.ex" + server.docs.Set(uri, `defmodule MyModule do + def run do + :lists.flatten(data) + end +end`) + + // col=5 on ":lists" — should take the Erlang path and return nil + // (no BEAM process in test) + hover := hoverAt(t, server, uri, 2, 5) + if hover != nil { + t.Errorf("expected no hover without BEAM process, got %v", hover) + } + + // col=12 on "flatten" + hover = hoverAt(t, server, uri, 2, 12) + if hover != nil { + t.Errorf("expected no hover without BEAM process, got %v", hover) + } +} + +func TestHover_ErlangAtomDoesNotAffectElixir(t *testing.T) { + server, cleanup := setupTestServer(t) + defer cleanup() + + indexFile(t, server.store, server.projectRoot, "lib/accounts.ex", `defmodule MyApp.Accounts do + @doc "Creates a new account." + def create(attrs), do: :ok +end +`) + + uri := "file:///test.ex" + server.docs.Set(uri, `defmodule MyModule do + alias MyApp.Accounts + Accounts.create(attrs) +end`) + + // Normal Elixir hover still works + hover := hoverAt(t, server, uri, 2, 13) + if hover == nil { + t.Fatal("expected Elixir hover to still work") + } + if !strings.Contains(hover.Contents.Value, "Creates a new account") { + t.Errorf("expected doc content, got %q", hover.Contents.Value) + } +} From a59b152afca49aaf5baeccf90686516bdde629e5 Mon Sep 17 00:00:00 2001 From: Jesse Herrick Date: Thu, 23 Apr 2026 20:09:06 -0600 Subject: [PATCH 2/7] Consolidate into single BEAM process with dynamic formatters MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the per-.formatter.exs BEAM process model with a single persistent BEAM process that hosts multiple formatter children on demand via DynamicSupervisor + Registry. CodeIntel (Erlang go-to-def, hover docs) is a singleton — no more picking an arbitrary process for code intel queries. - Formatter GenServer starts lazily per .formatter.exs path, loads config and plugins independently, caches opts for subsequent requests - Walk findFormatterConfig up to projectRoot instead of mixRoot, fixing umbrella projects with a root-only .formatter.exs (#34) - Add bp.Ready() guard before Erlang go-to-def/hover to prevent stdout race between the ready signal reader and request readers - Collapse Server.formatters map + formattersMu into single beam/beamMu --- internal/lsp/beam_server.exs | 255 +++++++++++++++++----------- internal/lsp/formatter.go | 301 +++++++++++++++------------------ internal/lsp/formatter_test.go | 48 ++++++ internal/lsp/server.go | 32 ++-- internal/lsp/server_test.go | 54 +++--- 5 files changed, 385 insertions(+), 305 deletions(-) diff --git a/internal/lsp/beam_server.exs b/internal/lsp/beam_server.exs index 9fc3ee0..7495f6e 100644 --- a/internal/lsp/beam_server.exs +++ b/internal/lsp/beam_server.exs @@ -1,22 +1,21 @@ # Persistent BEAM server for Dexter LSP. # -# Boots a Supervisor with two children: -# 1. Formatter — loads .formatter.exs once and caches formatter options -# 2. CodeIntel — resolves Erlang module source locations via :code/:beam_lib +# Single process that hosts: +# - Multiple formatter instances (one per .formatter.exs, started on demand) +# - A singleton CodeIntel service for Erlang source/docs lookups # -# Both services share a single BEAM process (one startup cost). # Communication is via stdin/stdout with binary framing: # # Request envelope: 1-byte service tag + service-specific payload -# Response envelope: service-specific (formatter and code_intel share the same -# status + length + body format) +# Response envelope: service-specific (see below) # # Service tags: # 0x00 = Formatter # 0x01 = CodeIntel # # Formatter protocol (after service tag): -# Request: 2-byte filename length (big-endian) + filename + +# Request: 2-byte formatter_exs path length (big-endian) + formatter_exs path + +# 2-byte filename length (big-endian) + filename + # 4-byte content length (big-endian) + content # Response: 1-byte status (0=ok, 1=error) + # 4-byte result length (big-endian) + result @@ -44,26 +43,24 @@ # corrupting our binary protocol framing. :io.setopts(:standard_io, encoding: :latin1) -[mix_root, formatter_exs_path, project_root_arg] = System.argv() +[project_root_arg] = System.argv() # In umbrella apps, _build and deps live at the umbrella root, not in -# individual app directories. Walk up from mix_root (bounded by the project -# root) to find the nearest ancestor that contains a _build directory. -expanded_mix_root = Path.expand(mix_root) -expanded_boundary = Path.expand(project_root_arg) +# individual app directories. Walk up from project_root (bounded to 20 levels) +# to find the nearest ancestor that contains a _build directory. +expanded_project_root = Path.expand(project_root_arg) -project_root = - Enum.reduce_while(1..20, expanded_mix_root, fn _, dir -> +build_root = + Enum.reduce_while(1..20, expanded_project_root, fn _, dir -> cond do File.dir?(Path.join(dir, "_build")) -> {:halt, dir} - dir == expanded_boundary -> - {:halt, expanded_mix_root} + true -> parent = Path.dirname(dir) if parent == dir do - {:halt, expanded_mix_root} + {:halt, expanded_project_root} else {:cont, parent} end @@ -72,96 +69,121 @@ project_root = # Add the project's compiled deps to the code path so plugins are available # without needing Mix.install -project_root +build_root |> Path.join("_build/dev/lib/*/ebin") |> Path.wildcard() |> Enum.each(&Code.prepend_path/1) -# Read .formatter.exs -raw_opts = - if File.regular?(formatter_exs_path) do - {result, _} = Code.eval_file(formatter_exs_path) - if is_list(result), do: result, else: [] - else - [] +# Formatter Service + +defmodule Dexter.Formatter do + use GenServer + + def start_link(formatter_exs_path) do + GenServer.start_link(__MODULE__, formatter_exs_path, name: via(formatter_exs_path)) end -plugins = Keyword.get(raw_opts, :plugins, []) + def format(formatter_exs_path, content, filename) do + GenServer.call(via(formatter_exs_path), {:format, content, filename}, :infinity) + end -# Resolve locals_without_parens from import_deps by reading each dep's exported -# formatter config. Mix does this automatically in mix format, but we must -# replicate it here since we eval .formatter.exs directly. -import_deps_locals = - raw_opts - |> Keyword.get(:import_deps, []) - |> Enum.flat_map(fn dep -> - dep_formatter = Path.join([project_root, "deps", to_string(dep), ".formatter.exs"]) + defp via(formatter_exs_path) do + {:via, Registry, {Dexter.FormatterRegistry, formatter_exs_path}} + end - if File.regular?(dep_formatter) do - {dep_opts, _} = Code.eval_file(dep_formatter) + @impl true + def init(formatter_exs_path) do + {format_opts, active_plugins} = load_formatter_config(formatter_exs_path) + {:ok, %{format_opts: format_opts, plugins: active_plugins, path: formatter_exs_path}} + end - if is_list(dep_opts) do - dep_opts - |> Keyword.get(:export, []) - |> Keyword.get(:locals_without_parens, []) + @impl true + def handle_call({:format, content, filename}, _from, state) do + result = do_format(content, filename, state.format_opts, state.plugins) + {:reply, result, state} + end + + defp load_formatter_config(formatter_exs_path) do + # Find the build root by looking for _build from the formatter's directory + formatter_dir = Path.dirname(formatter_exs_path) + + project_root = + Enum.reduce_while(1..20, formatter_dir, fn _, dir -> + cond do + File.dir?(Path.join(dir, "_build")) -> {:halt, dir} + true -> + parent = Path.dirname(dir) + if parent == dir, do: {:halt, formatter_dir}, else: {:cont, parent} + end + end) + + # Ensure compiled deps are on the code path so plugins can be loaded + project_root + |> Path.join("_build/dev/lib/*/ebin") + |> Path.wildcard() + |> Enum.each(&Code.prepend_path/1) + + raw_opts = + if File.regular?(formatter_exs_path) do + {result, _} = Code.eval_file(formatter_exs_path) + if is_list(result), do: result, else: [] else [] end - else - [] - end - end) -explicit_locals = Keyword.get(raw_opts, :locals_without_parens, []) -all_locals_without_parens = Enum.uniq(import_deps_locals ++ explicit_locals) + plugins = Keyword.get(raw_opts, :plugins, []) -# Extract formatting options -format_opts = - raw_opts - |> Keyword.take([ - :line_length, - :normalize_bitstring_modifiers, - :normalize_charlists_as_sigils, - :force_do_end_blocks - ]) - |> Keyword.put(:locals_without_parens, all_locals_without_parens) + # Resolve locals_without_parens from import_deps + import_deps_locals = + raw_opts + |> Keyword.get(:import_deps, []) + |> Enum.flat_map(fn dep -> + dep_formatter = Path.join([project_root, "deps", to_string(dep), ".formatter.exs"]) -# Resolve which plugins are actually loaded -active_plugins = Enum.filter(plugins, &Code.ensure_loaded?/1) + if File.regular?(dep_formatter) do + {dep_opts, _} = Code.eval_file(dep_formatter) -missing_plugins = plugins -- active_plugins + if is_list(dep_opts) do + dep_opts + |> Keyword.get(:export, []) + |> Keyword.get(:locals_without_parens, []) + else + [] + end + else + [] + end + end) -if missing_plugins != [] do - IO.puts(:stderr, "Formatter: WARNING: could not load plugins: #{Enum.map_join(missing_plugins, ", ", &inspect/1)} (not compiled in _build?). Falling back to standard formatter.") -end + explicit_locals = Keyword.get(raw_opts, :locals_without_parens, []) + all_locals_without_parens = Enum.uniq(import_deps_locals ++ explicit_locals) -if active_plugins != [] do - IO.puts(:stderr, "Formatter: plugins loaded: #{Enum.map_join(active_plugins, ", ", &inspect/1)}") -else - IO.puts(:stderr, "Formatter: no plugins") -end + format_opts = + raw_opts + |> Keyword.take([ + :line_length, + :normalize_bitstring_modifiers, + :normalize_charlists_as_sigils, + :force_do_end_blocks + ]) + |> Keyword.put(:locals_without_parens, all_locals_without_parens) -# ── Formatter Service ────────────────────────────────────────────────────── + active_plugins = Enum.filter(plugins, &Code.ensure_loaded?/1) -defmodule Dexter.Formatter do - def handle_request(format_opts, plugins) do - case IO.binread(:stdio, 2) do - <> -> - filename = if filename_len > 0, do: IO.binread(:stdio, filename_len), else: "" - <> = IO.binread(:stdio, 4) - content = IO.binread(:stdio, content_len) + missing_plugins = plugins -- active_plugins - {status, result} = format(content, filename, format_opts, plugins) - size = byte_size(result) - IO.binwrite(:stdio, <>) - :ok + if missing_plugins != [] do + IO.puts(:stderr, "Formatter: WARNING: could not load plugins: #{Enum.map_join(missing_plugins, ", ", &inspect/1)} (not compiled in _build?). Falling back to standard formatter.") + end - _ -> - :eof + if active_plugins != [] do + IO.puts(:stderr, "Formatter: plugins loaded for #{formatter_exs_path}: #{Enum.map_join(active_plugins, ", ", &inspect/1)}") end + + {format_opts, active_plugins} end - defp format(content, filename, format_opts, plugins) when is_binary(content) do + defp do_format(content, filename, format_opts, plugins) when is_binary(content) do try do opts = if filename != "", do: [file: filename] ++ format_opts, else: format_opts ext = Path.extname(filename) @@ -171,7 +193,6 @@ defmodule Dexter.Formatter do Enum.filter(plugins, fn plugin -> features = plugin.features(format_opts) extensions = Keyword.get(features, :extensions, []) - # If a plugin declares no extensions, it handles .ex/.exs by default extensions == [] or ext in extensions end) @@ -207,10 +228,10 @@ defmodule Dexter.Formatter do end end - defp format(_, _, _, _), do: {1, "invalid input"} + defp do_format(_, _, _, _), do: {1, "invalid input"} end -# ── CodeIntel Service ────────────────────────────────────────────────────── +# CodeIntel Service defmodule Dexter.CodeIntel do @op_erlang_source 0 @@ -305,7 +326,6 @@ defmodule Dexter.CodeIntel do end defp find_function_in_forms(forms, function_atom, 255 = _unspecified) do - # No arity specified — find the first clause with matching name Enum.find_value(forms, 0, fn {:function, anno, ^function_atom, _arity, _clauses} -> anno_line(anno) @@ -316,7 +336,6 @@ defmodule Dexter.CodeIntel do end defp find_function_in_forms(forms, function_atom, arity) do - # Exact arity match exact = Enum.find_value(forms, nil, fn {:function, anno, ^function_atom, ^arity, _clauses} -> @@ -326,7 +345,6 @@ defmodule Dexter.CodeIntel do nil end) - # Fall back to first matching name if exact arity not found exact || find_function_in_forms(forms, function_atom, 255) end @@ -352,7 +370,7 @@ defmodule Dexter.CodeIntel do end end - # ── Erlang Docs ────────────────────────────────────────────────────────── + # Erlang Docs defp handle_erlang_docs do <> = IO.binread(:stdio, 2) @@ -375,7 +393,6 @@ defmodule Dexter.CodeIntel do case Code.fetch_docs(module_atom) do {:docs_v1, _, :erlang, _format, module_doc, _metadata, docs} -> if function_name == "" do - # Module-level docs case extract_doc_text(module_doc) do nil -> {1, nil} text -> {0, text} @@ -391,7 +408,6 @@ defmodule Dexter.CodeIntel do end defp find_function_doc(docs, function_atom, arity) do - # Find matching function docs — prefer exact arity match candidates = Enum.filter(docs, fn {{:function, ^function_atom, _arity}, _anno, _sig, _doc, _meta} -> true @@ -442,26 +458,25 @@ defmodule Dexter.CodeIntel do defp extract_doc_text(_), do: nil end -# ── Main IO Loop ─────────────────────────────────────────────────────────── +# Main IO Loop defmodule Dexter.Loop do @service_formatter 0 @service_code_intel 1 - def run(format_opts, plugins, first_call?) do - # Signal ready if first call: status=0, length=0 + def run(first_call?) do if first_call?, do: IO.binwrite(:stdio, <<0, 0, 0, 0, 0>>) case IO.binread(:stdio, 1) do <<@service_formatter>> -> - case Dexter.Formatter.handle_request(format_opts, plugins) do - :ok -> run(format_opts, plugins, false) + case handle_format_request() do + :ok -> run(false) :eof -> :ok end <<@service_code_intel>> -> case Dexter.CodeIntel.handle_request() do - :ok -> run(format_opts, plugins, false) + :ok -> run(false) :eof -> :ok end @@ -469,10 +484,54 @@ defmodule Dexter.Loop do :ok end end + + defp handle_format_request do + case IO.binread(:stdio, 2) do + <> -> + config_path = if config_path_len > 0, do: IO.binread(:stdio, config_path_len), else: "" + <> = IO.binread(:stdio, 2) + filename = if filename_len > 0, do: IO.binread(:stdio, filename_len), else: "" + <> = IO.binread(:stdio, 4) + content = IO.binread(:stdio, content_len) + + # Start a formatter for this config if we haven't seen it before + ensure_formatter(config_path) + + {status, result} = Dexter.Formatter.format(config_path, content, filename) + size = byte_size(result) + IO.binwrite(:stdio, <>) + :ok + + _ -> + :eof + end + end + + defp ensure_formatter(config_path) do + case Registry.lookup(Dexter.FormatterRegistry, config_path) do + [{_pid, _}] -> + :ok + + [] -> + DynamicSupervisor.start_child( + Dexter.FormatterSup, + {Dexter.Formatter, config_path} + ) + + :ok + end + end end +# Boot + +{:ok, _} = Registry.start_link(keys: :unique, name: Dexter.FormatterRegistry) +{:ok, _} = DynamicSupervisor.start_link(strategy: :one_for_one, name: Dexter.FormatterSup) + +IO.puts(:stderr, "Dexter BEAM: started (pid #{System.pid()})") + try do - Dexter.Loop.run(format_opts, active_plugins, true) + Dexter.Loop.run(true) rescue e -> IO.puts(:stderr, "Dexter BEAM: crash in loop: #{Exception.message(e)}") diff --git a/internal/lsp/formatter.go b/internal/lsp/formatter.go index cb17dd1..3056dab 100644 --- a/internal/lsp/formatter.go +++ b/internal/lsp/formatter.go @@ -27,27 +27,26 @@ var beamServerScript string const ( // How long to wait for the persistent BEAM to become ready before // falling back to mix format on a given request. - formatterWaitTimeout = 5 * time.Second + beamWaitTimeout = 5 * time.Second // How long a not-ready BEAM process is allowed to live before being // killed and restarted. Also used as the hard cap inside the startup // goroutine to prevent leaked goroutines. - formatterStuckTimeout = 30 * time.Second + beamStuckTimeout = 30 * time.Second // Service tags for the BEAM server protocol serviceFormatter byte = 0x00 serviceCodeIntel byte = 0x01 ) -type formatterProcess struct { - cmd *commandHandle - stdin io.WriteCloser - stdout io.ReadCloser - mu sync.Mutex - formatterMtime time.Time // mtime of .formatter.exs when process started - startedAt time.Time // when the process was launched - ready chan struct{} // closed when the BEAM has sent the ready signal - startErr error // non-nil if startup failed; set before ready is closed - closed chan struct{} // closed by Close(); makes alive() return false immediately +type beamProcess struct { + cmd *commandHandle + stdin io.WriteCloser + stdout io.ReadCloser + mu sync.Mutex + startedAt time.Time // when the process was launched + ready chan struct{} // closed when the BEAM has sent the ready signal + startErr error // non-nil if startup failed; set before ready is closed + closed chan struct{} // closed by Close(); makes alive() return false immediately } // commandHandle wraps the process so we can check liveness. @@ -56,11 +55,11 @@ type commandHandle struct { done chan struct{} } -func (fp *formatterProcess) alive() bool { +func (bp *beamProcess) alive() bool { select { - case <-fp.cmd.done: + case <-bp.cmd.done: return false - case <-fp.closed: + case <-bp.closed: return false default: return true @@ -69,29 +68,34 @@ func (fp *formatterProcess) alive() bool { // Ready blocks until the process has finished startup. Returns startErr if // the BEAM failed to initialize, or ctx.Err() if the caller gives up first. -func (fp *formatterProcess) Ready(ctx context.Context) error { +func (bp *beamProcess) Ready(ctx context.Context) error { select { - case <-fp.ready: - return fp.startErr + case <-bp.ready: + return bp.startErr case <-ctx.Done(): return ctx.Err() } } -func (fp *formatterProcess) Format(ctx context.Context, content, filename string) (string, error) { - fp.mu.Lock() - defer fp.mu.Unlock() +// Format sends a format request to the BEAM process. The formatterExs path +// tells the BEAM which .formatter.exs config to use (starting a new formatter +// child if needed). +func (bp *beamProcess) Format(ctx context.Context, content, filename, formatterExs string) (string, error) { + bp.mu.Lock() + defer bp.mu.Unlock() - // Build the entire request as a single buffer to avoid partial writes + configPathBytes := []byte(formatterExs) filenameBytes := []byte(filename) contentBytes := []byte(content) var req bytes.Buffer - req.WriteByte(serviceFormatter) // service tag + req.WriteByte(serviceFormatter) + _ = binary.Write(&req, binary.BigEndian, uint16(len(configPathBytes))) + req.Write(configPathBytes) _ = binary.Write(&req, binary.BigEndian, uint16(len(filenameBytes))) req.Write(filenameBytes) _ = binary.Write(&req, binary.BigEndian, uint32(len(contentBytes))) req.Write(contentBytes) - if _, err := fp.stdin.Write(req.Bytes()); err != nil { + if _, err := bp.stdin.Write(req.Bytes()); err != nil { return "", fmt.Errorf("write request: %w", err) } @@ -102,17 +106,17 @@ func (fp *formatterProcess) Format(ctx context.Context, content, filename string ch := make(chan readResult, 1) go func() { var status byte - if err := binary.Read(fp.stdout, binary.BigEndian, &status); err != nil { + if err := binary.Read(bp.stdout, binary.BigEndian, &status); err != nil { ch <- readResult{err: fmt.Errorf("read status: %w", err)} return } var respLen uint32 - if err := binary.Read(fp.stdout, binary.BigEndian, &respLen); err != nil { + if err := binary.Read(bp.stdout, binary.BigEndian, &respLen); err != nil { ch <- readResult{err: fmt.Errorf("read length: %w", err)} return } buf := make([]byte, respLen) - if _, err := io.ReadFull(fp.stdout, buf); err != nil { + if _, err := io.ReadFull(bp.stdout, buf); err != nil { ch <- readResult{err: fmt.Errorf("read data: %w", err)} return } @@ -127,9 +131,7 @@ func (fp *formatterProcess) Format(ctx context.Context, content, filename string case r := <-ch: return r.text, r.err case <-ctx.Done(): - // Kill the process to unblock the reader goroutine — the pipe reads - // will fail once the process exits, preventing a leaked goroutine. - _ = fp.cmd.process.Kill() + _ = bp.cmd.process.Kill() <-ch return "", ctx.Err() } @@ -143,9 +145,9 @@ type ErlangSourceResult struct { // ErlangSource asks the BEAM's CodeIntel service to resolve an Erlang module/function // to its source file and line number. -func (fp *formatterProcess) ErlangSource(ctx context.Context, module, function string, arity int) (*ErlangSourceResult, error) { - fp.mu.Lock() - defer fp.mu.Unlock() +func (bp *beamProcess) ErlangSource(ctx context.Context, module, function string, arity int) (*ErlangSourceResult, error) { + bp.mu.Lock() + defer bp.mu.Unlock() moduleBytes := []byte(module) functionBytes := []byte(function) @@ -162,7 +164,7 @@ func (fp *formatterProcess) ErlangSource(ctx context.Context, module, function s _ = binary.Write(&req, binary.BigEndian, uint16(len(functionBytes))) req.Write(functionBytes) req.WriteByte(arityByte) - if _, err := fp.stdin.Write(req.Bytes()); err != nil { + if _, err := bp.stdin.Write(req.Bytes()); err != nil { return nil, fmt.Errorf("write code_intel request: %w", err) } @@ -173,22 +175,22 @@ func (fp *formatterProcess) ErlangSource(ctx context.Context, module, function s ch := make(chan readResult, 1) go func() { var status byte - if err := binary.Read(fp.stdout, binary.BigEndian, &status); err != nil { + if err := binary.Read(bp.stdout, binary.BigEndian, &status); err != nil { ch <- readResult{err: fmt.Errorf("read status: %w", err)} return } var fileLen uint16 - if err := binary.Read(fp.stdout, binary.BigEndian, &fileLen); err != nil { + if err := binary.Read(bp.stdout, binary.BigEndian, &fileLen); err != nil { ch <- readResult{err: fmt.Errorf("read file length: %w", err)} return } fileBuf := make([]byte, fileLen) - if _, err := io.ReadFull(fp.stdout, fileBuf); err != nil { + if _, err := io.ReadFull(bp.stdout, fileBuf); err != nil { ch <- readResult{err: fmt.Errorf("read file: %w", err)} return } var line uint32 - if err := binary.Read(fp.stdout, binary.BigEndian, &line); err != nil { + if err := binary.Read(bp.stdout, binary.BigEndian, &line); err != nil { ch <- readResult{err: fmt.Errorf("read line: %w", err)} return } @@ -203,7 +205,7 @@ func (fp *formatterProcess) ErlangSource(ctx context.Context, module, function s case r := <-ch: return r.result, r.err case <-ctx.Done(): - _ = fp.cmd.process.Kill() + _ = bp.cmd.process.Kill() <-ch return nil, ctx.Err() } @@ -212,9 +214,9 @@ func (fp *formatterProcess) ErlangSource(ctx context.Context, module, function s // ErlangDocs asks the BEAM's CodeIntel service for the documentation of an // Erlang module or function. Returns pre-formatted markdown, or empty string // if no docs are available (e.g. OTP < 24 or undocumented function). -func (fp *formatterProcess) ErlangDocs(ctx context.Context, module, function string, arity int) (string, error) { - fp.mu.Lock() - defer fp.mu.Unlock() +func (bp *beamProcess) ErlangDocs(ctx context.Context, module, function string, arity int) (string, error) { + bp.mu.Lock() + defer bp.mu.Unlock() moduleBytes := []byte(module) functionBytes := []byte(function) @@ -231,7 +233,7 @@ func (fp *formatterProcess) ErlangDocs(ctx context.Context, module, function str _ = binary.Write(&req, binary.BigEndian, uint16(len(functionBytes))) req.Write(functionBytes) req.WriteByte(arityByte) - if _, err := fp.stdin.Write(req.Bytes()); err != nil { + if _, err := bp.stdin.Write(req.Bytes()); err != nil { return "", fmt.Errorf("write code_intel request: %w", err) } @@ -242,17 +244,17 @@ func (fp *formatterProcess) ErlangDocs(ctx context.Context, module, function str ch := make(chan readResult, 1) go func() { var status byte - if err := binary.Read(fp.stdout, binary.BigEndian, &status); err != nil { + if err := binary.Read(bp.stdout, binary.BigEndian, &status); err != nil { ch <- readResult{err: fmt.Errorf("read status: %w", err)} return } var docLen uint32 - if err := binary.Read(fp.stdout, binary.BigEndian, &docLen); err != nil { + if err := binary.Read(bp.stdout, binary.BigEndian, &docLen); err != nil { ch <- readResult{err: fmt.Errorf("read doc length: %w", err)} return } docBuf := make([]byte, docLen) - if _, err := io.ReadFull(fp.stdout, docBuf); err != nil { + if _, err := io.ReadFull(bp.stdout, docBuf); err != nil { ch <- readResult{err: fmt.Errorf("read doc: %w", err)} return } @@ -267,7 +269,7 @@ func (fp *formatterProcess) ErlangDocs(ctx context.Context, module, function str case r := <-ch: return r.doc, r.err case <-ctx.Done(): - _ = fp.cmd.process.Kill() + _ = bp.cmd.process.Kill() <-ch return "", ctx.Err() } @@ -283,21 +285,20 @@ func (e *FormatError) Error() string { return e.Message } -func (fp *formatterProcess) Close() { +func (bp *beamProcess) Close() { select { - case <-fp.closed: + case <-bp.closed: default: - close(fp.closed) + close(bp.closed) } - _ = fp.stdin.Close() - _ = fp.cmd.process.Kill() + _ = bp.stdin.Close() + _ = bp.cmd.process.Kill() } -// startFormatterProcess launches the BEAM process and returns immediately. -// The returned process may not be ready yet — callers must check fp.Ready() -// before calling fp.Format(). Returns error only for immediate launch failures -// (missing binary, can't create pipes). -func (s *Server) startFormatterProcess(mixRoot, formatterExs string) (*formatterProcess, error) { +// startBeamProcess launches the single BEAM process and returns immediately. +// The returned process may not be ready yet — callers must check bp.Ready() +// before sending requests. +func (s *Server) startBeamProcess() (*beamProcess, error) { scriptDir := filepath.Join(os.TempDir(), "dexter") if err := os.MkdirAll(scriptDir, 0755); err != nil { return nil, fmt.Errorf("create script dir: %w", err) @@ -309,14 +310,9 @@ func (s *Server) startFormatterProcess(mixRoot, formatterExs string) (*formatter } } - var mtime time.Time - if info, err := os.Stat(formatterExs); err == nil { - mtime = info.ModTime() - } - elixirBin := filepath.Join(filepath.Dir(s.mixBin), "elixir") - cmd := exec.Command(elixirBin, scriptPath, mixRoot, formatterExs, s.projectRoot) - cmd.Dir = mixRoot + cmd := exec.Command(elixirBin, scriptPath, s.projectRoot) + cmd.Dir = s.projectRoot stdin, err := cmd.StdinPipe() if err != nil { @@ -330,7 +326,7 @@ func (s *Server) startFormatterProcess(mixRoot, formatterExs string) (*formatter cmd.Stderr = io.MultiWriter(os.Stderr, &stderrBuf) if err := cmd.Start(); err != nil { - return nil, fmt.Errorf("start formatter: %w", err) + return nil, fmt.Errorf("start BEAM: %w", err) } done := make(chan struct{}) @@ -341,18 +337,15 @@ func (s *Server) startFormatterProcess(mixRoot, formatterExs string) (*formatter handle := &commandHandle{process: cmd.Process, done: done} - fp := &formatterProcess{ - cmd: handle, - stdin: stdin, - stdout: stdout, - formatterMtime: mtime, - startedAt: time.Now(), - ready: make(chan struct{}), - closed: make(chan struct{}), + bp := &beamProcess{ + cmd: handle, + stdin: stdin, + stdout: stdout, + startedAt: time.Now(), + ready: make(chan struct{}), + closed: make(chan struct{}), } - // Wait for the BEAM's ready signal asynchronously. Callers use fp.Ready() - // to wait with their own timeout go func() { type readyResult struct { status byte @@ -376,66 +369,61 @@ func (s *Server) startFormatterProcess(mixRoot, formatterExs string) (*formatter select { case r := <-readyCh: if r.err != nil { - fp.startErr = fmt.Errorf("formatter ready: %w", r.err) + bp.startErr = fmt.Errorf("BEAM ready: %w", r.err) _ = cmd.Process.Kill() - <-done // wait for cmd.Wait() to finish copying stderr + <-done s.notifyOTPMismatch(stderrBuf.String()) } else if r.status != 0 { - fp.startErr = fmt.Errorf("formatter failed to initialize (status %d)", r.status) + bp.startErr = fmt.Errorf("BEAM failed to initialize (status %d)", r.status) _ = cmd.Process.Kill() } else { - log.Printf("Formatter: started persistent process for %s (pid %d)", formatterExs, cmd.Process.Pid) + log.Printf("BEAM: started persistent process (pid %d)", cmd.Process.Pid) } - case <-time.After(formatterStuckTimeout): - fp.startErr = fmt.Errorf("formatter startup timed out") + case <-time.After(beamStuckTimeout): + bp.startErr = fmt.Errorf("BEAM startup timed out") _ = cmd.Process.Kill() } - close(fp.ready) + close(bp.ready) }() - return fp, nil + return bp, nil } -// getFormatter returns a cached formatter process (which may still be starting -// up). If none exists, it launches one and caches it immediately. The mutex is -// only held briefly — the slow BEAM startup happens asynchronously. Callers -// must call fp.Ready() before fp.Format() to wait for the process to be usable. -func (s *Server) getFormatter(mixRoot, formatterExs string) (*formatterProcess, error) { - s.formattersMu.Lock() - defer s.formattersMu.Unlock() - - if fp, ok := s.formatters[formatterExs]; ok && fp.alive() { - // Restart if .formatter.exs has changed - if info, err := os.Stat(formatterExs); err == nil && info.ModTime().After(fp.formatterMtime) { - fp.Close() - delete(s.formatters, formatterExs) - } else { - return fp, nil - } +// getBeamProcess returns the single BEAM process, starting it if needed. +func (s *Server) getBeamProcess(ctx context.Context) *beamProcess { + s.beamMu.Lock() + defer s.beamMu.Unlock() + + if s.beam != nil && s.beam.alive() { + return s.beam } - fp, err := s.startFormatterProcess(mixRoot, formatterExs) - if err != nil { - return nil, err + if s.mixBin == "" { + return nil } - if s.formatters == nil { - s.formatters = make(map[string]*formatterProcess) + + bp, err := s.startBeamProcess() + if err != nil { + log.Printf("BEAM: failed to start: %v", err) + return nil } - s.formatters[formatterExs] = fp - return fp, nil + s.beam = bp + return bp } -// findFormatterConfig walks from the file's directory up to the mix root, +// findFormatterConfig walks from the file's directory up to the project root, // returning the path to the nearest .formatter.exs. This handles subdirectory -// configs (e.g. config/.formatter.exs with different rules than the root). -func findFormatterConfig(filePath, mixRoot string) string { +// configs (e.g. config/.formatter.exs with different rules than the root) and +// umbrella projects where .formatter.exs lives at the umbrella root above the +// app's mix root. +func findFormatterConfig(filePath, projectRoot string) string { dir := filepath.Dir(filePath) for { candidate := filepath.Join(dir, ".formatter.exs") if _, err := os.Stat(candidate); err == nil { return candidate } - if dir == mixRoot { + if dir == projectRoot { break } parent := filepath.Dir(dir) @@ -444,54 +432,51 @@ func findFormatterConfig(filePath, mixRoot string) string { } dir = parent } - return filepath.Join(mixRoot, ".formatter.exs") + return filepath.Join(projectRoot, ".formatter.exs") } -// formatContent tries the persistent formatter, falling back to mix format. +// formatContent tries the persistent BEAM process, falling back to mix format. // // Startup-age policy: // - <5s old: wait for the process to become ready, then use it // - 5s–30s old: don't wait, fall back to mix format immediately // - >30s old and still not ready: kill and restart the stuck process func (s *Server) formatContent(ctx context.Context, mixRoot, path, content string) (string, error) { - formatterExs := findFormatterConfig(path, mixRoot) - fp, err := s.getFormatter(mixRoot, formatterExs) - if err != nil { - log.Printf("Formatting: persistent formatter unavailable, falling back to mix format: %v", err) + formatterExs := findFormatterConfig(path, s.projectRoot) + bp := s.getBeamProcess(ctx) + if bp == nil { + log.Printf("Formatting: BEAM process unavailable, falling back to mix format") return s.formatWithMixFormat(ctx, mixRoot, path, content) } // Check if already ready (non-blocking) select { - case <-fp.ready: - if fp.startErr != nil { - s.evictFormatter(formatterExs, fp) - log.Printf("Formatting: persistent formatter failed to start, falling back to mix format: %v", fp.startErr) + case <-bp.ready: + if bp.startErr != nil { + s.evictBeam(bp) + log.Printf("Formatting: BEAM process failed to start, falling back to mix format: %v", bp.startErr) return s.formatWithMixFormat(ctx, mixRoot, path, content) } default: // Not ready yet — decide based on how long it's been starting - age := time.Since(fp.startedAt) + age := time.Since(bp.startedAt) switch { - case age > formatterStuckTimeout: - // Stuck — kill and restart so the next request gets a fresh process - log.Printf("Formatting: persistent formatter stuck (started %s ago), restarting", age.Truncate(time.Second)) - s.evictFormatter(formatterExs, fp) + case age > beamStuckTimeout: + log.Printf("Formatting: BEAM process stuck (started %s ago), restarting", age.Truncate(time.Second)) + s.evictBeam(bp) return s.formatWithMixFormat(ctx, mixRoot, path, content) - case age > formatterWaitTimeout: - // Taking too long — fall back without waiting - log.Printf("Formatting: persistent formatter not ready after %s, falling back to mix format", age.Truncate(time.Millisecond)) + case age > beamWaitTimeout: + log.Printf("Formatting: BEAM process not ready after %s, falling back to mix format", age.Truncate(time.Millisecond)) return s.formatWithMixFormat(ctx, mixRoot, path, content) default: - // Recently started — wait for it - if err := fp.Ready(ctx); err != nil { + if err := bp.Ready(ctx); err != nil { if ctx.Err() != nil { return "", err } - s.evictFormatter(formatterExs, fp) - log.Printf("Formatting: persistent formatter failed to start, falling back to mix format: %v", err) + s.evictBeam(bp) + log.Printf("Formatting: BEAM process failed to start, falling back to mix format: %v", err) return s.formatWithMixFormat(ctx, mixRoot, path, content) } } @@ -502,14 +487,14 @@ func (s *Server) formatContent(ctx context.Context, mixRoot, path, content strin } start := time.Now() - result, err := fp.Format(ctx, content, path) + result, err := bp.Format(ctx, content, path, formatterExs) if err != nil { var formatErr *FormatError if errors.As(err, &formatErr) { log.Printf("Formatting: %s failed: %s", path, formatErr.Message) } else { - s.evictFormatter(formatterExs, fp) - log.Printf("Formatting: persistent formatter crashed: %v", err) + s.evictBeam(bp) + log.Printf("Formatting: BEAM process crashed: %v", err) } return "", err } @@ -518,35 +503,13 @@ func (s *Server) formatContent(ctx context.Context, mixRoot, path, content strin return result, nil } -func (s *Server) evictFormatter(formatterExs string, fp *formatterProcess) { - s.formattersMu.Lock() - if s.formatters[formatterExs] == fp { - delete(s.formatters, formatterExs) - } - s.formattersMu.Unlock() - fp.Close() -} - -// getBeamProcess returns any alive and ready BEAM process for code intel queries. -// Since all BEAM processes have the CodeIntel service, any one will do. -func (s *Server) getBeamProcess(ctx context.Context) *formatterProcess { - s.formattersMu.Lock() - defer s.formattersMu.Unlock() - - for _, fp := range s.formatters { - if !fp.alive() { - continue - } - // Non-blocking ready check - select { - case <-fp.ready: - if fp.startErr == nil { - return fp - } - default: - } +func (s *Server) evictBeam(bp *beamProcess) { + s.beamMu.Lock() + if s.beam == bp { + s.beam = nil } - return nil + s.beamMu.Unlock() + bp.Close() } func (s *Server) formatWithMixFormat(ctx context.Context, mixRoot, path, content string) (string, error) { @@ -720,11 +683,11 @@ func computeMinimalEdits(original, formatted string) []protocol.TextEdit { } } -func (s *Server) closeFormatters() { - s.formattersMu.Lock() - defer s.formattersMu.Unlock() - for _, fp := range s.formatters { - fp.Close() +func (s *Server) closeBeam() { + s.beamMu.Lock() + defer s.beamMu.Unlock() + if s.beam != nil { + s.beam.Close() + s.beam = nil } - s.formatters = nil } diff --git a/internal/lsp/formatter_test.go b/internal/lsp/formatter_test.go index b06ed34..c2e8185 100644 --- a/internal/lsp/formatter_test.go +++ b/internal/lsp/formatter_test.go @@ -432,3 +432,51 @@ func TestComputeMinimalEdits(t *testing.T) { } }) } + +func TestFindFormatterConfig_UmbrellaRootOnly(t *testing.T) { + // Simulate an umbrella where only the root has .formatter.exs + // root/ + // .formatter.exs + // apps/ + // my_app/ + // mix.exs + // lib/ + // foo.ex + root := t.TempDir() + appDir := filepath.Join(root, "apps", "my_app", "lib") + if err := os.MkdirAll(appDir, 0755); err != nil { + t.Fatal(err) + } + rootFormatter := filepath.Join(root, ".formatter.exs") + if err := os.WriteFile(rootFormatter, []byte("[]"), 0644); err != nil { + t.Fatal(err) + } + + filePath := filepath.Join(appDir, "foo.ex") + got := findFormatterConfig(filePath, root) + if got != rootFormatter { + t.Errorf("expected %s, got %s", rootFormatter, got) + } +} + +func TestFindFormatterConfig_PerAppOverridesRoot(t *testing.T) { + // Both root and app have .formatter.exs — the app's should win + root := t.TempDir() + appDir := filepath.Join(root, "apps", "my_app") + if err := os.MkdirAll(filepath.Join(appDir, "lib"), 0755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(root, ".formatter.exs"), []byte("[]"), 0644); err != nil { + t.Fatal(err) + } + appFormatter := filepath.Join(appDir, ".formatter.exs") + if err := os.WriteFile(appFormatter, []byte("[]"), 0644); err != nil { + t.Fatal(err) + } + + filePath := filepath.Join(appDir, "lib", "foo.ex") + got := findFormatterConfig(filePath, root) + if got != appFormatter { + t.Errorf("expected app-level %s, got %s", appFormatter, got) + } +} diff --git a/internal/lsp/server.go b/internal/lsp/server.go index e54db07..2accfd4 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -63,8 +63,8 @@ type Server struct { debug bool mixBin string // resolved path to the mix binary - formatters map[string]*formatterProcess // formatterExs path → persistent formatter - formattersMu sync.Mutex + beam *beamProcess // single persistent BEAM process for formatting + code intel + beamMu sync.Mutex usingCache map[string]*usingCacheEntry // module name → parsed __using__ result usingCacheMu sync.RWMutex @@ -423,7 +423,7 @@ func (s *Server) Initialized(ctx context.Context, params *protocol.InitializedPa } func (s *Server) Shutdown(ctx context.Context) error { - s.closeFormatters() + s.closeBeam() return nil } @@ -437,15 +437,12 @@ func (s *Server) Exit(ctx context.Context) error { func (s *Server) DidOpen(ctx context.Context, params *protocol.DidOpenTextDocumentParams) error { s.docs.Set(string(params.TextDocument.URI), params.TextDocument.Text) - // Eagerly start the persistent formatter so the first format is instant. + // Eagerly start the persistent BEAM process so the first format is instant. // Skip deps and stdlib files — we don't format those. path := uriToPath(params.TextDocument.URI) if path != "" && isFormattableFile(path) && s.isProjectFile(path) && !s.isDepsFile(path) { go func() { - if mixRoot := findMixRoot(filepath.Dir(path)); mixRoot != "" { - formatterExs := findFormatterConfig(path, mixRoot) - _, _ = s.getFormatter(mixRoot, formatterExs) - } + _ = s.getBeamProcess(context.Background()) }() } @@ -711,12 +708,15 @@ func filterOutTypes(results []store.LookupResult) []store.LookupResult { // erlangHover fetches documentation for an Erlang module/function via the // BEAM process's CodeIntel service. func (s *Server) erlangHover(ctx context.Context, module, function string) (*protocol.Hover, error) { - fp := s.getBeamProcess(ctx) - if fp == nil { + bp := s.getBeamProcess(ctx) + if bp == nil { + return nil, nil + } + if err := bp.Ready(ctx); err != nil { return nil, nil } - doc, err := fp.ErlangDocs(ctx, module, function, -1) + doc, err := bp.ErlangDocs(ctx, module, function, -1) if err != nil || doc == "" { return nil, nil } @@ -732,13 +732,17 @@ func (s *Server) erlangHover(ctx context.Context, module, function string) (*pro // erlangDefinition resolves an Erlang module/function to its .erl source via // the BEAM process's CodeIntel service. func (s *Server) erlangDefinition(ctx context.Context, module, function string) ([]protocol.Location, error) { - fp := s.getBeamProcess(ctx) - if fp == nil { + bp := s.getBeamProcess(ctx) + if bp == nil { s.debugf("Definition: no BEAM process available for Erlang resolution") return nil, nil } + if err := bp.Ready(ctx); err != nil { + s.debugf("Definition: BEAM process not ready: %v", err) + return nil, nil + } - result, err := fp.ErlangSource(ctx, module, function, -1) + result, err := bp.ErlangSource(ctx, module, function, -1) if err != nil { s.debugf("Definition: Erlang source lookup failed: %v", err) return nil, nil diff --git a/internal/lsp/server_test.go b/internal/lsp/server_test.go index 3af5bdb..1449ced 100644 --- a/internal/lsp/server_test.go +++ b/internal/lsp/server_test.go @@ -3002,14 +3002,13 @@ func TestFormatter_RestartAfterCrash(t *testing.T) { t.Fatal(err) } - // Kill the persistent process - mixRoot := findMixRoot(filepath.Dir(filePath)) - formatterExs := findFormatterConfig(filePath, mixRoot) - server.formattersMu.Lock() - if fp, ok := server.formatters[formatterExs]; ok { - fp.Close() + // Kill the persistent BEAM process + server.beamMu.Lock() + if server.beam != nil { + server.beam.Close() + server.beam = nil } - server.formattersMu.Unlock() + server.beamMu.Unlock() // Next format should recover (restart or fall back) server.docs.Set(docURI, "defmodule Test2 do\nend\n") @@ -3080,12 +3079,12 @@ func TestFormatter_DidOpen_SkipsDepsFiles(t *testing.T) { }, }) - server.formattersMu.Lock() - count := len(server.formatters) - server.formattersMu.Unlock() + server.beamMu.Lock() + hasBeam := server.beam != nil + server.beamMu.Unlock() - if count != 0 { - t.Errorf("expected no formatter processes for dep file, got %d", count) + if hasBeam { + t.Errorf("expected no BEAM process for dep file, got one") } } @@ -4640,18 +4639,23 @@ func TestDefinition_ErlangAtomModule(t *testing.T) { end end`) - // col=5 on ":code" — should take the Erlang path and return nil - // (no BEAM process in test, so no result — but must not crash or - // fall through to Elixir resolution) + // col=5 on ":code" — should take the Erlang path, not fall through + // to Elixir resolution. If a BEAM process is available, we get a result + // pointing to the .erl source; if not, we get nil. Either way, it must + // not crash or produce an Elixir result. locs := definitionAt(t, server, uri, 2, 5) - if len(locs) != 0 { - t.Errorf("expected no definition without BEAM process, got %v", locs) + for _, loc := range locs { + if !strings.HasSuffix(string(loc.URI), ".erl") { + t.Errorf("expected .erl file or no result, got %s", loc.URI) + } } // col=10 on "all_loaded" function locs = definitionAt(t, server, uri, 2, 10) - if len(locs) != 0 { - t.Errorf("expected no definition without BEAM process, got %v", locs) + for _, loc := range locs { + if !strings.HasSuffix(string(loc.URI), ".erl") { + t.Errorf("expected .erl file or no result, got %s", loc.URI) + } } } @@ -4691,17 +4695,19 @@ func TestHover_ErlangAtomModule(t *testing.T) { end end`) - // col=5 on ":lists" — should take the Erlang path and return nil - // (no BEAM process in test) + // col=5 on ":lists" — should take the Erlang path. If a BEAM process + // is available, we get Erlang docs; if not, nil. Must not crash. hover := hoverAt(t, server, uri, 2, 5) - if hover != nil { - t.Errorf("expected no hover without BEAM process, got %v", hover) + if hover != nil && hover.Contents.Kind != protocol.Markdown { + t.Errorf("expected markdown or nil, got %v", hover.Contents.Kind) } // col=12 on "flatten" hover = hoverAt(t, server, uri, 2, 12) if hover != nil { - t.Errorf("expected no hover without BEAM process, got %v", hover) + if !strings.Contains(hover.Contents.Value, "flatten") { + t.Errorf("expected hover about flatten, got %q", hover.Contents.Value) + } } } From 68030f5ef158e5a13415d916c679d7b7d26dc43f Mon Sep 17 00:00:00 2001 From: Jesse Herrick Date: Thu, 23 Apr 2026 20:33:03 -0600 Subject: [PATCH 3/7] Use one BEAM process per _build root --- internal/lsp/formatter.go | 85 ++++++++++++++++++++++++++++--------- internal/lsp/server.go | 11 ++--- internal/lsp/server_test.go | 14 +++--- 3 files changed, 78 insertions(+), 32 deletions(-) diff --git a/internal/lsp/formatter.go b/internal/lsp/formatter.go index 3056dab..94b35ad 100644 --- a/internal/lsp/formatter.go +++ b/internal/lsp/formatter.go @@ -295,10 +295,10 @@ func (bp *beamProcess) Close() { _ = bp.cmd.process.Kill() } -// startBeamProcess launches the single BEAM process and returns immediately. -// The returned process may not be ready yet — callers must check bp.Ready() -// before sending requests. -func (s *Server) startBeamProcess() (*beamProcess, error) { +// startBeamProcess launches a BEAM process for the given build root and returns +// immediately. The returned process may not be ready yet — callers must check +// bp.Ready() before sending requests. +func (s *Server) startBeamProcess(buildRoot string) (*beamProcess, error) { scriptDir := filepath.Join(os.TempDir(), "dexter") if err := os.MkdirAll(scriptDir, 0755); err != nil { return nil, fmt.Errorf("create script dir: %w", err) @@ -311,8 +311,8 @@ func (s *Server) startBeamProcess() (*beamProcess, error) { } elixirBin := filepath.Join(filepath.Dir(s.mixBin), "elixir") - cmd := exec.Command(elixirBin, scriptPath, s.projectRoot) - cmd.Dir = s.projectRoot + cmd := exec.Command(elixirBin, scriptPath, buildRoot) + cmd.Dir = buildRoot stdin, err := cmd.StdinPipe() if err != nil { @@ -389,28 +389,69 @@ func (s *Server) startBeamProcess() (*beamProcess, error) { return bp, nil } -// getBeamProcess returns the single BEAM process, starting it if needed. -func (s *Server) getBeamProcess(ctx context.Context) *beamProcess { +// findBuildRoot walks up from dir looking for a _build directory, bounded by +// projectRoot. Returns the directory containing _build, or projectRoot if none +// is found. +func (s *Server) findBuildRoot(dir string) string { + for { + if _, err := os.Stat(filepath.Join(dir, "_build")); err == nil { + return dir + } + if dir == s.projectRoot { + break + } + parent := filepath.Dir(dir) + if parent == dir { + break + } + dir = parent + } + return s.projectRoot +} + +// getBeamProcess returns a BEAM process for the given build root, starting one +// if needed. Files sharing the same _build share the same BEAM process. +func (s *Server) getBeamProcess(ctx context.Context, buildRoot string) *beamProcess { s.beamMu.Lock() defer s.beamMu.Unlock() - if s.beam != nil && s.beam.alive() { - return s.beam + if bp, ok := s.beams[buildRoot]; ok && bp.alive() { + return bp } if s.mixBin == "" { return nil } - bp, err := s.startBeamProcess() + bp, err := s.startBeamProcess(buildRoot) if err != nil { - log.Printf("BEAM: failed to start: %v", err) + log.Printf("BEAM: failed to start for %s: %v", buildRoot, err) return nil } - s.beam = bp + if s.beams == nil { + s.beams = make(map[string]*beamProcess) + } + s.beams[buildRoot] = bp return bp } +// getAnyBeamProcess returns any alive and ready BEAM process. Used for code +// intel queries (like Erlang go-to-def) where OTP is the same regardless of +// build root. Falls back to starting one for the project root if none exists. +func (s *Server) getAnyBeamProcess(ctx context.Context) *beamProcess { + s.beamMu.Lock() + for _, bp := range s.beams { + if bp.alive() { + s.beamMu.Unlock() + return bp + } + } + s.beamMu.Unlock() + + // No existing process — start one for the project root + return s.getBeamProcess(ctx, s.projectRoot) +} + // findFormatterConfig walks from the file's directory up to the project root, // returning the path to the nearest .formatter.exs. This handles subdirectory // configs (e.g. config/.formatter.exs with different rules than the root) and @@ -443,7 +484,8 @@ func findFormatterConfig(filePath, projectRoot string) string { // - >30s old and still not ready: kill and restart the stuck process func (s *Server) formatContent(ctx context.Context, mixRoot, path, content string) (string, error) { formatterExs := findFormatterConfig(path, s.projectRoot) - bp := s.getBeamProcess(ctx) + buildRoot := s.findBuildRoot(filepath.Dir(path)) + bp := s.getBeamProcess(ctx, buildRoot) if bp == nil { log.Printf("Formatting: BEAM process unavailable, falling back to mix format") return s.formatWithMixFormat(ctx, mixRoot, path, content) @@ -505,8 +547,11 @@ func (s *Server) formatContent(ctx context.Context, mixRoot, path, content strin func (s *Server) evictBeam(bp *beamProcess) { s.beamMu.Lock() - if s.beam == bp { - s.beam = nil + for key, b := range s.beams { + if b == bp { + delete(s.beams, key) + break + } } s.beamMu.Unlock() bp.Close() @@ -683,11 +728,11 @@ func computeMinimalEdits(original, formatted string) []protocol.TextEdit { } } -func (s *Server) closeBeam() { +func (s *Server) closeBeams() { s.beamMu.Lock() defer s.beamMu.Unlock() - if s.beam != nil { - s.beam.Close() - s.beam = nil + for _, bp := range s.beams { + bp.Close() } + s.beams = nil } diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 2accfd4..1344aa9 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -63,7 +63,7 @@ type Server struct { debug bool mixBin string // resolved path to the mix binary - beam *beamProcess // single persistent BEAM process for formatting + code intel + beams map[string]*beamProcess // build root → persistent BEAM process beamMu sync.Mutex usingCache map[string]*usingCacheEntry // module name → parsed __using__ result @@ -423,7 +423,7 @@ func (s *Server) Initialized(ctx context.Context, params *protocol.InitializedPa } func (s *Server) Shutdown(ctx context.Context) error { - s.closeBeam() + s.closeBeams() return nil } @@ -442,7 +442,8 @@ func (s *Server) DidOpen(ctx context.Context, params *protocol.DidOpenTextDocume path := uriToPath(params.TextDocument.URI) if path != "" && isFormattableFile(path) && s.isProjectFile(path) && !s.isDepsFile(path) { go func() { - _ = s.getBeamProcess(context.Background()) + buildRoot := s.findBuildRoot(filepath.Dir(path)) + _ = s.getBeamProcess(context.Background(), buildRoot) }() } @@ -708,7 +709,7 @@ func filterOutTypes(results []store.LookupResult) []store.LookupResult { // erlangHover fetches documentation for an Erlang module/function via the // BEAM process's CodeIntel service. func (s *Server) erlangHover(ctx context.Context, module, function string) (*protocol.Hover, error) { - bp := s.getBeamProcess(ctx) + bp := s.getAnyBeamProcess(ctx) if bp == nil { return nil, nil } @@ -732,7 +733,7 @@ func (s *Server) erlangHover(ctx context.Context, module, function string) (*pro // erlangDefinition resolves an Erlang module/function to its .erl source via // the BEAM process's CodeIntel service. func (s *Server) erlangDefinition(ctx context.Context, module, function string) ([]protocol.Location, error) { - bp := s.getBeamProcess(ctx) + bp := s.getAnyBeamProcess(ctx) if bp == nil { s.debugf("Definition: no BEAM process available for Erlang resolution") return nil, nil diff --git a/internal/lsp/server_test.go b/internal/lsp/server_test.go index 1449ced..51568a4 100644 --- a/internal/lsp/server_test.go +++ b/internal/lsp/server_test.go @@ -3002,11 +3002,11 @@ func TestFormatter_RestartAfterCrash(t *testing.T) { t.Fatal(err) } - // Kill the persistent BEAM process + // Kill all persistent BEAM processes server.beamMu.Lock() - if server.beam != nil { - server.beam.Close() - server.beam = nil + for key, bp := range server.beams { + bp.Close() + delete(server.beams, key) } server.beamMu.Unlock() @@ -3080,11 +3080,11 @@ func TestFormatter_DidOpen_SkipsDepsFiles(t *testing.T) { }) server.beamMu.Lock() - hasBeam := server.beam != nil + beamCount := len(server.beams) server.beamMu.Unlock() - if hasBeam { - t.Errorf("expected no BEAM process for dep file, got one") + if beamCount != 0 { + t.Errorf("expected no BEAM processes for dep file, got %d", beamCount) } } From d4238273f4d7528606050deabf6dc87a075d34d5 Mon Sep 17 00:00:00 2001 From: Jesse Herrick Date: Thu, 23 Apr 2026 20:59:24 -0600 Subject: [PATCH 4/7] Restart beam when .formatter.exs changes --- internal/lsp/formatter.go | 17 ----------------- internal/lsp/server.go | 34 +++++++++++++++++++++++++++------- 2 files changed, 27 insertions(+), 24 deletions(-) diff --git a/internal/lsp/formatter.go b/internal/lsp/formatter.go index 94b35ad..e0ce4aa 100644 --- a/internal/lsp/formatter.go +++ b/internal/lsp/formatter.go @@ -435,23 +435,6 @@ func (s *Server) getBeamProcess(ctx context.Context, buildRoot string) *beamProc return bp } -// getAnyBeamProcess returns any alive and ready BEAM process. Used for code -// intel queries (like Erlang go-to-def) where OTP is the same regardless of -// build root. Falls back to starting one for the project root if none exists. -func (s *Server) getAnyBeamProcess(ctx context.Context) *beamProcess { - s.beamMu.Lock() - for _, bp := range s.beams { - if bp.alive() { - s.beamMu.Unlock() - return bp - } - } - s.beamMu.Unlock() - - // No existing process — start one for the project root - return s.getBeamProcess(ctx, s.projectRoot) -} - // findFormatterConfig walks from the file's directory up to the project root, // returning the path to the nearest .formatter.exs. This handles subdirectory // configs (e.g. config/.formatter.exs with different rules than the root) and diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 1344aa9..dc0a0e2 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -465,7 +465,25 @@ func (s *Server) DidClose(ctx context.Context, params *protocol.DidCloseTextDocu func (s *Server) DidSave(ctx context.Context, params *protocol.DidSaveTextDocumentParams) error { path := uriToPath(params.TextDocument.URI) - if path == "" || !parser.IsElixirFile(path) { + if path == "" { + return nil + } + + // Restart the BEAM process when .formatter.exs changes so the new + // config is picked up on the next format request. + if filepath.Base(path) == ".formatter.exs" { + buildRoot := s.findBuildRoot(filepath.Dir(path)) + s.beamMu.Lock() + if bp, ok := s.beams[buildRoot]; ok { + delete(s.beams, buildRoot) + bp.Close() + log.Printf("BEAM: restarting for %s (.formatter.exs changed)", buildRoot) + } + s.beamMu.Unlock() + return nil + } + + if !parser.IsElixirFile(path) { return nil } @@ -554,7 +572,7 @@ func (s *Server) Definition(ctx context.Context, params *protocol.DefinitionPara if strings.HasPrefix(exprCtx.ModuleRef, ":") { erlModule := exprCtx.ModuleRef[1:] // strip the : prefix s.debugf("Definition: Erlang module %q function=%q", erlModule, exprCtx.FunctionName) - return s.erlangDefinition(ctx, erlModule, exprCtx.FunctionName) + return s.erlangDefinition(ctx, uriToPath(params.TextDocument.URI), erlModule, exprCtx.FunctionName) } expr := tf.ResolveModuleExpr(exprCtx.Expr(), lineNum) @@ -708,8 +726,9 @@ func filterOutTypes(results []store.LookupResult) []store.LookupResult { // erlangHover fetches documentation for an Erlang module/function via the // BEAM process's CodeIntel service. -func (s *Server) erlangHover(ctx context.Context, module, function string) (*protocol.Hover, error) { - bp := s.getAnyBeamProcess(ctx) +func (s *Server) erlangHover(ctx context.Context, filePath, module, function string) (*protocol.Hover, error) { + buildRoot := s.findBuildRoot(filepath.Dir(filePath)) + bp := s.getBeamProcess(ctx, buildRoot) if bp == nil { return nil, nil } @@ -732,8 +751,9 @@ func (s *Server) erlangHover(ctx context.Context, module, function string) (*pro // erlangDefinition resolves an Erlang module/function to its .erl source via // the BEAM process's CodeIntel service. -func (s *Server) erlangDefinition(ctx context.Context, module, function string) ([]protocol.Location, error) { - bp := s.getAnyBeamProcess(ctx) +func (s *Server) erlangDefinition(ctx context.Context, filePath, module, function string) ([]protocol.Location, error) { + buildRoot := s.findBuildRoot(filepath.Dir(filePath)) + bp := s.getBeamProcess(ctx, buildRoot) if bp == nil { s.debugf("Definition: no BEAM process available for Erlang resolution") return nil, nil @@ -2819,7 +2839,7 @@ func (s *Server) Hover(ctx context.Context, params *protocol.HoverParams) (*prot // Erlang module atom (e.g. :lists.flatten) — fetch docs via BEAM process if strings.HasPrefix(exprCtx.ModuleRef, ":") { - return s.erlangHover(ctx, exprCtx.ModuleRef[1:], exprCtx.FunctionName) + return s.erlangHover(ctx, uriToPath(params.TextDocument.URI), exprCtx.ModuleRef[1:], exprCtx.FunctionName) } expr := tf.ResolveModuleExpr(exprCtx.Expr(), lineNum) From f99f2d46ac958bc3c961d8a132db6ab804cb3b2c Mon Sep 17 00:00:00 2001 From: Jesse Herrick Date: Thu, 23 Apr 2026 21:02:08 -0600 Subject: [PATCH 5/7] Wait for any in-flight requests --- internal/lsp/formatter.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/internal/lsp/formatter.go b/internal/lsp/formatter.go index e0ce4aa..05895c9 100644 --- a/internal/lsp/formatter.go +++ b/internal/lsp/formatter.go @@ -286,6 +286,11 @@ func (e *FormatError) Error() string { } func (bp *beamProcess) Close() { + // Acquire mu to wait for any in-flight request to finish before killing + // the process — prevents broken pipe errors on concurrent format + close. + bp.mu.Lock() + defer bp.mu.Unlock() + select { case <-bp.closed: default: From 58c2b4f2a8725547b3963cd90b1fb8990e14d4a7 Mon Sep 17 00:00:00 2001 From: Jesse Herrick Date: Sun, 26 Apr 2026 19:07:46 -0600 Subject: [PATCH 6/7] Improvements --- internal/lsp/beam_server.exs | 664 ++++++++++++++++++++++++++----- internal/lsp/elixir.go | 5 + internal/lsp/elixir_test.go | 28 ++ internal/lsp/formatter.go | 698 ++++++++++++++++++++++++--------- internal/lsp/formatter_test.go | 274 +++++++++++++ internal/lsp/server.go | 561 ++++++++++++++++++++++++-- internal/lsp/server_test.go | 506 +++++++++++++++++++++++- 7 files changed, 2409 insertions(+), 327 deletions(-) diff --git a/internal/lsp/beam_server.exs b/internal/lsp/beam_server.exs index 7495f6e..e5d8b7e 100644 --- a/internal/lsp/beam_server.exs +++ b/internal/lsp/beam_server.exs @@ -4,39 +4,51 @@ # - Multiple formatter instances (one per .formatter.exs, started on demand) # - A singleton CodeIntel service for Erlang source/docs lookups # -# Communication is via stdin/stdout with binary framing: +# Communication is via stdin/stdout with framed binary messages: # -# Request envelope: 1-byte service tag + service-specific payload -# Response envelope: service-specific (see below) +# Frame 0x00 = request: +# request_id(u32) + service(u8) + op(u8) + payload_len(u32) + payload +# +# Frame 0x01 = response: +# request_id(u32) + status(u8) + payload_len(u32) + payload +# +# Frame 0x02 = notification: +# op(u8) + payload_len(u32) + payload +# +# Frame 0x03 = ready: +# status(u8) + payload_len(u32) + payload # # Service tags: # 0x00 = Formatter # 0x01 = CodeIntel # -# Formatter protocol (after service tag): -# Request: 2-byte formatter_exs path length (big-endian) + formatter_exs path + -# 2-byte filename length (big-endian) + filename + -# 4-byte content length (big-endian) + content -# Response: 1-byte status (0=ok, 1=error) + -# 4-byte result length (big-endian) + result +# Formatter op 0 (format) payload: +# 2-byte formatter_exs path length (big-endian) + formatter_exs path + +# 2-byte filename length (big-endian) + filename + +# 4-byte content length (big-endian) + content +# +# CodeIntel op 0 (erlang_source) payload: +# 2-byte module length (big-endian) + module + +# 2-byte function length (big-endian) + function + +# 1-byte arity (255 = unspecified) # -# CodeIntel protocol (after service tag): -# Request: 1-byte op + -# 2-byte module length (big-endian) + module + -# 2-byte function length (big-endian) + function + -# 1-byte arity (255 = unspecified) +# CodeIntel op 1 (erlang_docs) payload: +# Same as erlang_source # -# Op 0 (erlang_source) response: -# 1-byte status (0=ok, 1=not_found) + -# 2-byte file length (big-endian) + file + -# 4-byte line (big-endian, 0 if not found) +# CodeIntel op 2 (warm_otp_modules) payload: +# empty; results arrive asynchronously via notification 0 # -# Op 1 (erlang_docs) response: -# 1-byte status (0=ok, 1=not_found) + -# 4-byte doc length (big-endian) + doc (markdown string) +# CodeIntel op 3 (erlang_exports) payload: +# 2-byte module length (big-endian) + module # -# Sends a ready signal once initialization is complete: -# 1-byte status (0=ok) + 4-byte length (0) +# CodeIntel op 4 (runtime_info) payload: +# empty +# +# Notification 0 (otp_modules_ready) payload: +# 2-byte module_count (big-endian) + [name_len(u16) name] +# +# Notification 1 (otp_modules_failed) payload: +# error string # # Force raw byte mode on stdin/stdout — without this, the Erlang IO server # applies Unicode encoding, expanding bytes > 127 to multi-byte UTF-8 and @@ -231,38 +243,188 @@ defmodule Dexter.Formatter do defp do_format(_, _, _, _), do: {1, "invalid input"} end -# CodeIntel Service +# Protocol Writer -defmodule Dexter.CodeIntel do - @op_erlang_source 0 - @op_erlang_docs 1 +defmodule Dexter.Writer do + use GenServer - def handle_request() do - case IO.binread(:stdio, 1) do - <<@op_erlang_source>> -> - handle_erlang_source() + @frame_response 1 + @frame_notification 2 + @frame_ready 3 - <<@op_erlang_docs>> -> - handle_erlang_docs() + @notif_otp_modules_ready 0 + @notif_otp_modules_failed 1 - _ -> - :eof + def start_link() do + GenServer.start_link(__MODULE__, :ok, name: __MODULE__) + end + + def send_ready(status, payload \\ <<>>) when is_binary(payload) do + GenServer.call(__MODULE__, {:write, ready_frame(status, payload)}, :infinity) + end + + def send_response(req_id, status, payload) when is_binary(payload) do + GenServer.cast(__MODULE__, {:write, response_frame(req_id, status, payload)}) + end + + def send_otp_modules_ready(names) do + GenServer.cast(__MODULE__, {:write, notification_frame(@notif_otp_modules_ready, encode_module_names(names))}) + end + + def send_otp_modules_failed(message) do + payload = if message, do: to_string(message), else: "" + GenServer.cast(__MODULE__, {:write, notification_frame(@notif_otp_modules_failed, payload)}) + end + + @impl true + def init(:ok), do: {:ok, nil} + + @impl true + def handle_call({:write, frame}, _from, state) do + write_frame(frame) + {:reply, :ok, state} + end + + @impl true + def handle_cast({:write, frame}, state) do + write_frame(frame) + {:noreply, state} + end + + defp write_frame(frame) do + case IO.binwrite(:stdio, frame) do + :ok -> :ok + {:error, reason} -> exit({:write_failed, reason}) end end - defp handle_erlang_source do - <> = IO.binread(:stdio, 2) - module_name = if module_len > 0, do: IO.binread(:stdio, module_len), else: "" - <> = IO.binread(:stdio, 2) - function_name = if function_len > 0, do: IO.binread(:stdio, function_len), else: "" - <> = IO.binread(:stdio, 1) + defp response_frame(req_id, status, payload) do + <<@frame_response::8, req_id::unsigned-big-32, status::8, byte_size(payload)::unsigned-big-32, + payload::binary>> + end - {status, file, line} = resolve_erlang_source(module_name, function_name, arity) + defp notification_frame(op, payload) do + <<@frame_notification::8, op::8, byte_size(payload)::unsigned-big-32, payload::binary>> + end - file_bytes = if file, do: file, else: "" - file_len = byte_size(file_bytes) - IO.binwrite(:stdio, <>) - :ok + defp ready_frame(status, payload) do + <<@frame_ready::8, status::8, byte_size(payload)::unsigned-big-32, payload::binary>> + end + + defp encode_module_names(names) do + payload = + for name <- names, into: <<>> do + <> + end + + <> + end +end + +# CodeIntel Service + +defmodule Dexter.CodeIntelCache do + use GenServer + + def start_link() do + GenServer.start_link(__MODULE__, %{}, name: __MODULE__) + end + + def warm_otp_modules() do + GenServer.call(__MODULE__, :warm_otp_modules) + end + + @impl true + def init(_state) do + {:ok, %{otp_modules: nil, loading: false}} + end + + @impl true + def handle_call(:warm_otp_modules, _from, %{otp_modules: names} = state) when is_list(names) do + Dexter.Writer.send_otp_modules_ready(names) + {:reply, :ok, state} + end + + def handle_call(:warm_otp_modules, _from, %{loading: true} = state) do + {:reply, :ok, state} + end + + def handle_call(:warm_otp_modules, _from, state) do + {:ok, _pid} = + Task.Supervisor.start_child(Dexter.TaskSup, fn -> + result = + try do + {:ok, compute_otp_module_names()} + rescue + error -> {:error, {:error, error, __STACKTRACE__}} + catch + kind, reason -> {:error, {kind, reason}} + end + + GenServer.cast(__MODULE__, {:otp_module_result, result}) + end) + + {:reply, :ok, %{state | loading: true}} + end + + @impl true + def handle_cast({:otp_module_result, {:ok, names}}, state) do + Dexter.Writer.send_otp_modules_ready(names) + {:noreply, %{state | otp_modules: names, loading: false}} + end + + def handle_cast({:otp_module_result, {:error, reason}}, state) do + IO.puts(:stderr, "CodeIntelCache: failed to load OTP modules: #{inspect(reason)}") + Dexter.Writer.send_otp_modules_failed(inspect(reason)) + {:noreply, %{state | loading: false}} + end + + defp compute_otp_module_names do + otp_root = :code.lib_dir() |> to_string() + + :code.all_available() + |> Enum.reduce([], fn {name, path, _loaded}, acc -> + mod_name = to_string(name) + + if is_list(path) and String.starts_with?(to_string(path), otp_root) and + not String.starts_with?(mod_name, "Elixir.") do + [mod_name | acc] + else + acc + end + end) + |> Enum.sort() + end +end + +defmodule Dexter.CodeIntel do + @op_erlang_source 0 + @op_erlang_docs 1 + @op_warm_otp_modules 2 + @op_erlang_exports 3 + @op_runtime_info 4 + + def handle_request(op, payload) do + case op do + @op_erlang_source -> handle_erlang_source(payload) + @op_erlang_docs -> handle_erlang_docs(payload) + @op_warm_otp_modules -> handle_warm_otp_modules(payload) + @op_erlang_exports -> handle_erlang_exports(payload) + @op_runtime_info -> handle_runtime_info(payload) + _ -> {1, "unknown code intel op: #{inspect(op)}"} + end + end + + defp handle_erlang_source(payload) do + case parse_module_function_arity(payload) do + {:ok, module_name, function_name, arity} -> + {status, file, line} = resolve_erlang_source(module_name, function_name, arity) + file_bytes = if file, do: file, else: "" + {status, <>} + + :error -> + {1, "invalid erlang_source payload"} + end end defp resolve_erlang_source(module_name, function_name, arity) do @@ -298,13 +460,11 @@ defmodule Dexter.CodeIntel do defp find_function_line(module_atom, function_name, arity) do function_atom = String.to_atom(function_name) - # Try abstract code first for arity-accurate line numbers line = find_line_from_abstract_code(module_atom, function_atom, arity) if line > 0 do line else - # Fallback: regex search in source file find_line_from_source(module_atom, function_atom) end end @@ -370,21 +530,16 @@ defmodule Dexter.CodeIntel do end end - # Erlang Docs - - defp handle_erlang_docs do - <> = IO.binread(:stdio, 2) - module_name = if module_len > 0, do: IO.binread(:stdio, module_len), else: "" - <> = IO.binread(:stdio, 2) - function_name = if function_len > 0, do: IO.binread(:stdio, function_len), else: "" - <> = IO.binread(:stdio, 1) + defp handle_erlang_docs(payload) do + case parse_module_function_arity(payload) do + {:ok, module_name, function_name, arity} -> + {status, doc} = fetch_erlang_docs(module_name, function_name, arity) + doc_bytes = doc || "" + {status, <>} - {status, doc} = fetch_erlang_docs(module_name, function_name, arity) - - doc_bytes = doc || "" - doc_len = byte_size(doc_bytes) - IO.binwrite(:stdio, <>) - :ok + :error -> + {1, "invalid erlang_docs payload"} + end end defp fetch_erlang_docs(module_name, function_name, arity) do @@ -407,17 +562,24 @@ defmodule Dexter.CodeIntel do end end - defp find_function_doc(docs, function_atom, arity) do + defp find_function_doc(docs, name_atom, arity) do + case find_doc_entry(docs, :function, name_atom, arity) do + nil -> find_doc_entry(docs, :type, name_atom, arity) || {1, nil} + result -> result + end + end + + defp find_doc_entry(docs, kind, name_atom, arity) do candidates = Enum.filter(docs, fn - {{:function, ^function_atom, _arity}, _anno, _sig, _doc, _meta} -> true + {{^kind, ^name_atom, _arity}, _anno, _sig, _doc, _meta} -> true _ -> false end) match = if arity != 255 do Enum.find(candidates, fn - {{:function, _, ^arity}, _, _, _, _} -> true + {{_, _, ^arity}, _, _, _, _} -> true _ -> false end) || List.first(candidates) else @@ -425,8 +587,8 @@ defmodule Dexter.CodeIntel do end case match do - {{:function, _, match_arity}, _anno, signatures, doc, _meta} -> - signature = format_signatures(signatures, function_atom, match_arity) + {{_, _, match_arity}, _anno, signatures, doc, _meta} -> + signature = format_signatures(signatures, name_atom, match_arity) doc_text = extract_doc_text(doc) parts = [] @@ -434,15 +596,212 @@ defmodule Dexter.CodeIntel do parts = if doc_text, do: parts ++ [doc_text], else: parts case parts do - [] -> {1, nil} + [] -> nil _ -> {0, Enum.join(parts, "\n\n")} end nil -> - {1, nil} + nil end end + defp handle_warm_otp_modules(_payload) do + :ok = Dexter.CodeIntelCache.warm_otp_modules() + {0, <<>>} + end + + defp handle_erlang_exports(payload) do + case parse_module(payload) do + {:ok, module_name} -> + mod_atom = String.to_atom(module_name) + export_params = export_param_names(mod_atom) + + exports = + case :code.ensure_loaded(mod_atom) do + {:module, _} -> + mod_atom.module_info(:exports) + |> Enum.reject(fn {f, _} -> f in [:module_info, :behaviour_info] end) + + _ -> + [] + end + + exports_payload = + for {func, arity} <- exports, into: <<>> do + func_str = to_string(func) + params = Map.get(export_params, {func, arity}, "") + + <> + end + + {0, <>} + + :error -> + {1, "invalid erlang_exports payload"} + end + end + + defp export_param_names(module_atom) do + case Code.fetch_docs(module_atom) do + {:docs_v1, _, :erlang, _format, _module_doc, _metadata, docs} -> + Enum.reduce(docs, %{}, fn + {{:function, name, arity}, _anno, signatures, _doc, _meta}, acc -> + case signature_params(signatures, arity) do + "" -> acc + params -> Map.put(acc, {name, arity}, params) + end + + _other, acc -> + acc + end) + + _ -> + %{} + end + end + + defp signature_params(signatures, arity) when is_list(signatures) do + signatures + |> Enum.find_value("", fn + sig when is_binary(sig) -> + case extract_signature_args(sig) do + {:ok, args} -> + params = + args + |> split_signature_args() + |> Enum.with_index(1) + |> Enum.map(fn {param, index} -> normalize_signature_param(param, index) end) + + if length(params) == arity, do: Enum.join(params, ","), else: nil + + :error -> + nil + end + + _ -> + nil + end) + end + + defp signature_params(_, _arity), do: "" + + defp extract_signature_args(signature) do + case :binary.match(signature, "(") do + {start, 1} -> + rest = binary_part(signature, start + 1, byte_size(signature) - start - 1) + collect_signature_args(rest, 0, []) + + :nomatch -> + :error + end + end + + defp collect_signature_args(<<>>, _depth, _acc), do: :error + + defp collect_signature_args(<<")", _rest::binary>>, 0, acc) do + {:ok, acc |> Enum.reverse() |> IO.iodata_to_binary()} + end + + defp collect_signature_args(<<"(", rest::binary>>, depth, acc), + do: collect_signature_args(rest, depth + 1, ["(" | acc]) + + defp collect_signature_args(<<")", rest::binary>>, depth, acc), + do: collect_signature_args(rest, depth - 1, [")" | acc]) + + defp collect_signature_args(<>, depth, acc), + do: collect_signature_args(rest, depth, [<> | acc]) + + defp split_signature_args(""), do: [] + + defp split_signature_args(args) do + {parts, current, _depths} = + args + |> String.to_charlist() + |> Enum.reduce({[], [], {0, 0, 0}}, fn char, {parts, current, {paren, bracket, brace}} -> + case char do + ?, when paren == 0 and bracket == 0 and brace == 0 -> + part = current |> Enum.reverse() |> to_string() |> String.trim() + {[part | parts], [], {paren, bracket, brace}} + + ?( -> + {parts, [char | current], {paren + 1, bracket, brace}} + + ?) -> + {parts, [char | current], {paren - 1, bracket, brace}} + + ?[ -> + {parts, [char | current], {paren, bracket + 1, brace}} + + ?] -> + {parts, [char | current], {paren, bracket - 1, brace}} + + ?{ -> + {parts, [char | current], {paren, bracket, brace + 1}} + + ?} -> + {parts, [char | current], {paren, bracket, brace - 1}} + + _ -> + {parts, [char | current], {paren, bracket, brace}} + end + end) + + last = current |> Enum.reverse() |> to_string() |> String.trim() + + (if last == "", do: parts, else: [last | parts]) + |> Enum.reverse() + end + + defp normalize_signature_param(param, index) do + name = + Regex.scan(~r/[A-Za-z][A-Za-z0-9_]*/, param) + |> List.flatten() + |> Enum.map(&Macro.underscore/1) + |> Enum.join("_") + |> String.replace(~r/_+/, "_") + |> String.trim("_") + + if name == "", do: "arg#{index}", else: name + end + + defp handle_runtime_info(_payload) do + otp_release = :erlang.system_info(:otp_release) |> to_string() + code_root_dir = :code.root_dir() |> to_string() + + {0, + <>} + end + + defp parse_module_function_arity(payload) do + with <> <- payload, + {:ok, module_name, rest} <- take_string(rest, module_len), + <> <- rest, + {:ok, function_name, rest} <- take_string(rest, function_len), + <> <- rest do + {:ok, module_name, function_name, arity} + else + _ -> :error + end + end + + defp parse_module(payload) do + with <> <- payload, + {:ok, module_name, <<>>} <- take_string(rest, module_len) do + {:ok, module_name} + else + _ -> :error + end + end + + defp take_string(binary, size) when byte_size(binary) >= size do + <> = binary + {:ok, value, rest} + end + + defp take_string(_binary, _size), do: :error + defp format_signatures(signatures, function_atom, arity) when is_list(signatures) do case signatures do [sig | _] when is_binary(sig) -> sig @@ -461,77 +820,168 @@ end # Main IO Loop defmodule Dexter.Loop do + @frame_request 0 @service_formatter 0 @service_code_intel 1 + @formatter_op_format 0 - def run(first_call?) do - if first_call?, do: IO.binwrite(:stdio, <<0, 0, 0, 0, 0>>) + def run do + case read_request_frame() do + {:ok, req_id, service, op, payload} -> + dispatch_request(req_id, service, op, payload) + run() - case IO.binread(:stdio, 1) do - <<@service_formatter>> -> - case handle_format_request() do - :ok -> run(false) - :eof -> :ok - end - - <<@service_code_intel>> -> - case Dexter.CodeIntel.handle_request() do - :ok -> run(false) - :eof -> :ok - end + :eof -> + :ok - _ -> + {:error, reason} -> + :erlang.display({:beam_loop_read_error, reason}) :ok end end - defp handle_format_request do - case IO.binread(:stdio, 2) do - <> -> - config_path = if config_path_len > 0, do: IO.binread(:stdio, config_path_len), else: "" - <> = IO.binread(:stdio, 2) - filename = if filename_len > 0, do: IO.binread(:stdio, filename_len), else: "" - <> = IO.binread(:stdio, 4) - content = IO.binread(:stdio, content_len) + defp dispatch_request(req_id, service, op, payload) do + {:ok, _pid} = + Task.Supervisor.start_child(Dexter.TaskSup, fn -> + {status, response_payload} = + try do + case {service, op} do + {@service_formatter, @formatter_op_format} -> + handle_format_request(payload) + + {@service_code_intel, _} -> + Dexter.CodeIntel.handle_request(op, payload) + + _ -> + {1, "unknown request #{service}/#{op}"} + end + rescue + error -> + :erlang.display({:beam_request_crash, service, op, error, __STACKTRACE__}) + {1, Exception.message(error)} + catch + kind, reason -> + :erlang.display({:beam_request_crash, service, op, kind, reason}) + {1, inspect({kind, reason})} + end - # Start a formatter for this config if we haven't seen it before - ensure_formatter(config_path) + Dexter.Writer.send_response(req_id, status, response_payload) + end) - {status, result} = Dexter.Formatter.format(config_path, content, filename) - size = byte_size(result) - IO.binwrite(:stdio, <>) - :ok + :ok + end - _ -> - :eof + defp read_request_frame do + with {:ok, @frame_request} <- read_byte(), + {:ok, <>} <- read_exact(4), + {:ok, <>} <- read_exact(6), + {:ok, payload} <- read_exact(payload_len) do + {:ok, req_id, service, op, payload} + else + :eof -> :eof + {:error, :eof} -> :eof + {:ok, other} -> {:error, {:unexpected_frame, other}} + {:error, reason} -> {:error, reason} + end + end + + defp read_byte do + case IO.binread(:stdio, 1) do + :eof -> :eof + <> -> {:ok, byte} + other -> {:error, {:bad_read, other}} + end + end + + defp read_exact(0), do: {:ok, <<>>} + + defp read_exact(size) when size > 0 do + case IO.binread(:stdio, size) do + :eof -> + {:error, :eof} + + data when is_binary(data) and byte_size(data) == size -> + {:ok, data} + + data when is_binary(data) -> + {:error, {:short_read, size, byte_size(data)}} + + other -> + {:error, {:bad_read, other}} + end + end + + defp handle_format_request(payload) do + case parse_format_payload(payload) do + {:ok, config_path, filename, content} -> + with :ok <- ensure_formatter(config_path) do + Dexter.Formatter.format(config_path, content, filename) + else + {:error, reason} -> {1, format_formatter_start_error(reason)} + end + + :error -> + {1, "invalid format payload"} end end + defp parse_format_payload(payload) do + with <> <- payload, + {:ok, config_path, rest} <- take_string(rest, config_path_len), + <> <- rest, + {:ok, filename, rest} <- take_string(rest, filename_len), + <> <- rest, + {:ok, content, <<>>} <- take_string(rest, content_len) do + {:ok, config_path, filename, content} + else + _ -> :error + end + end + + defp take_string(binary, size) when byte_size(binary) >= size do + <> = binary + {:ok, value, rest} + end + + defp take_string(_binary, _size), do: :error + defp ensure_formatter(config_path) do case Registry.lookup(Dexter.FormatterRegistry, config_path) do [{_pid, _}] -> :ok [] -> - DynamicSupervisor.start_child( - Dexter.FormatterSup, - {Dexter.Formatter, config_path} - ) - - :ok + case DynamicSupervisor.start_child( + Dexter.FormatterSup, + {Dexter.Formatter, config_path} + ) do + {:ok, _pid} -> :ok + {:error, {:already_started, _pid}} -> :ok + {:error, reason} -> {:error, reason} + end end end + + defp format_formatter_start_error({%_{} = error, _stacktrace}) do + Exception.message(error) + end + + defp format_formatter_start_error(reason), do: inspect(reason) end # Boot {:ok, _} = Registry.start_link(keys: :unique, name: Dexter.FormatterRegistry) {:ok, _} = DynamicSupervisor.start_link(strategy: :one_for_one, name: Dexter.FormatterSup) +{:ok, _} = Task.Supervisor.start_link(name: Dexter.TaskSup) +{:ok, _} = Dexter.Writer.start_link() +{:ok, _} = Dexter.CodeIntelCache.start_link() IO.puts(:stderr, "Dexter BEAM: started (pid #{System.pid()})") try do - Dexter.Loop.run(true) + :ok = Dexter.Writer.send_ready(0) + Dexter.Loop.run() rescue e -> IO.puts(:stderr, "Dexter BEAM: crash in loop: #{Exception.message(e)}") diff --git a/internal/lsp/elixir.go b/internal/lsp/elixir.go index 2b0bdfa..cb9005b 100644 --- a/internal/lsp/elixir.go +++ b/internal/lsp/elixir.go @@ -372,6 +372,11 @@ func ExtractCompletionContext(line string, col int) (prefix string, afterDot boo start-- } + // Include a leading colon for Erlang module references (:lists, :ets, etc.) + if start > 0 && line[start-1] == ':' { + start-- + } + raw := line[start : end+1] // Trim trailing dots — "Foo." means afterDot=true, prefix="Foo" diff --git a/internal/lsp/elixir_test.go b/internal/lsp/elixir_test.go index a49d460..b0d05d8 100644 --- a/internal/lsp/elixir_test.go +++ b/internal/lsp/elixir_test.go @@ -1251,6 +1251,34 @@ func TestExtractCompletionContext(t *testing.T) { wantPrefix: "Enum.map", wantAfterDot: false, }, + { + name: "erlang module prefix", + line: " :lis", + col: 6, + wantPrefix: ":lis", + wantAfterDot: false, + }, + { + name: "erlang module dot", + line: " :lists.", + col: 9, + wantPrefix: ":lists", + wantAfterDot: true, + }, + { + name: "erlang module function prefix", + line: " :lists.fla", + col: 12, + wantPrefix: ":lists.fla", + wantAfterDot: false, + }, + { + name: "bare colon — no completion", + line: " :", + col: 3, + wantPrefix: "", + wantAfterDot: false, + }, } for _, tt := range tests { diff --git a/internal/lsp/formatter.go b/internal/lsp/formatter.go index 05895c9..9e523e9 100644 --- a/internal/lsp/formatter.go +++ b/internal/lsp/formatter.go @@ -36,17 +36,51 @@ const ( // Service tags for the BEAM server protocol serviceFormatter byte = 0x00 serviceCodeIntel byte = 0x01 + + // Frame types for the multiplexed BEAM protocol + frameRequest byte = 0x00 + frameResponse byte = 0x01 + frameNotification byte = 0x02 + frameReady byte = 0x03 + + formatterOpFormat byte = 0x00 + + codeIntelOpErlangSource byte = 0x00 + codeIntelOpErlangDocs byte = 0x01 + codeIntelOpWarmOTPModules byte = 0x02 + codeIntelOpErlangExports byte = 0x03 + codeIntelOpRuntimeInfo byte = 0x04 + + beamNotificationOTPModulesReady byte = 0x00 + beamNotificationOTPModulesFailed byte = 0x01 ) type beamProcess struct { cmd *commandHandle stdin io.WriteCloser stdout io.ReadCloser - mu sync.Mutex + stderr *bytes.Buffer // rolling stderr capture for crash diagnostics + writeMu sync.Mutex + pendingMu sync.Mutex + pending map[uint32]chan beamResponse + nextReqID uint32 startedAt time.Time // when the process was launched ready chan struct{} // closed when the BEAM has sent the ready signal startErr error // non-nil if startup failed; set before ready is closed + startOnce sync.Once closed chan struct{} // closed by Close(); makes alive() return false immediately + notify func(beamNotification) +} + +type beamResponse struct { + status byte + payload []byte + err error +} + +type beamNotification struct { + op byte + payload []byte } // commandHandle wraps the process so we can check liveness. @@ -66,6 +100,33 @@ func (bp *beamProcess) alive() bool { } } +// recentStderr returns the tail of the captured stderr buffer (up to 512 bytes) +// for inclusion in error messages when the BEAM process crashes. +func (bp *beamProcess) recentStderr() string { + if bp.stderr == nil { + return "" + } + b := bp.stderr.Bytes() + if len(b) > 512 { + b = b[len(b)-512:] + } + return strings.TrimSpace(string(b)) +} + +// wrapError annotates a read/write error with recent stderr output if the +// BEAM process has died, making crash diagnostics visible in log messages. +func (bp *beamProcess) wrapError(err error) error { + if err == nil { + return nil + } + if !bp.alive() { + if stderr := bp.recentStderr(); stderr != "" { + return fmt.Errorf("%w\nBEAM stderr:\n%s", err, stderr) + } + } + return err +} + // Ready blocks until the process has finished startup. Returns startErr if // the BEAM failed to initialize, or ctx.Err() if the caller gives up first. func (bp *beamProcess) Ready(ctx context.Context) error { @@ -77,66 +138,255 @@ func (bp *beamProcess) Ready(ctx context.Context) error { } } -// Format sends a format request to the BEAM process. The formatterExs path -// tells the BEAM which .formatter.exs config to use (starting a new formatter -// child if needed). -func (bp *beamProcess) Format(ctx context.Context, content, filename, formatterExs string) (string, error) { - bp.mu.Lock() - defer bp.mu.Unlock() +func (bp *beamProcess) finishStartup(err error) { + bp.startOnce.Do(func() { + bp.startErr = err + close(bp.ready) + }) +} - configPathBytes := []byte(formatterExs) - filenameBytes := []byte(filename) - contentBytes := []byte(content) - var req bytes.Buffer - req.WriteByte(serviceFormatter) - _ = binary.Write(&req, binary.BigEndian, uint16(len(configPathBytes))) - req.Write(configPathBytes) - _ = binary.Write(&req, binary.BigEndian, uint16(len(filenameBytes))) - req.Write(filenameBytes) - _ = binary.Write(&req, binary.BigEndian, uint32(len(contentBytes))) - req.Write(contentBytes) - if _, err := bp.stdin.Write(req.Bytes()); err != nil { - return "", fmt.Errorf("write request: %w", err) - } - - type readResult struct { - text string - err error - } - ch := make(chan readResult, 1) - go func() { - var status byte - if err := binary.Read(bp.stdout, binary.BigEndian, &status); err != nil { - ch <- readResult{err: fmt.Errorf("read status: %w", err)} - return - } - var respLen uint32 - if err := binary.Read(bp.stdout, binary.BigEndian, &respLen); err != nil { - ch <- readResult{err: fmt.Errorf("read length: %w", err)} - return - } - buf := make([]byte, respLen) - if _, err := io.ReadFull(bp.stdout, buf); err != nil { - ch <- readResult{err: fmt.Errorf("read data: %w", err)} +func (bp *beamProcess) addPending() (uint32, chan beamResponse) { + bp.pendingMu.Lock() + defer bp.pendingMu.Unlock() + + bp.nextReqID++ + reqID := bp.nextReqID + respCh := make(chan beamResponse, 1) + if bp.pending == nil { + bp.pending = make(map[uint32]chan beamResponse) + } + bp.pending[reqID] = respCh + return reqID, respCh +} + +func (bp *beamProcess) removePending(reqID uint32) chan beamResponse { + bp.pendingMu.Lock() + defer bp.pendingMu.Unlock() + + respCh := bp.pending[reqID] + delete(bp.pending, reqID) + return respCh +} + +func (bp *beamProcess) failPending(err error) { + err = bp.wrapError(err) + + bp.pendingMu.Lock() + pending := bp.pending + bp.pending = make(map[uint32]chan beamResponse) + bp.pendingMu.Unlock() + + for _, respCh := range pending { + respCh <- beamResponse{err: err} + } +} + +func readByte(r io.Reader) (byte, error) { + var value [1]byte + if _, err := io.ReadFull(r, value[:]); err != nil { + return 0, err + } + return value[0], nil +} + +func readUint32(r io.Reader) (uint32, error) { + var value uint32 + if err := binary.Read(r, binary.BigEndian, &value); err != nil { + return 0, err + } + return value, nil +} + +func readPayload(r io.Reader, size uint32) ([]byte, error) { + buf := make([]byte, size) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, err + } + return buf, nil +} + +func readStatusPayload(r io.Reader) (byte, []byte, error) { + status, err := readByte(r) + if err != nil { + return 0, nil, err + } + size, err := readUint32(r) + if err != nil { + return 0, nil, err + } + payload, err := readPayload(r, size) + if err != nil { + return 0, nil, err + } + return status, payload, nil +} + +func (bp *beamProcess) readLoop() { + for { + frameType, err := readByte(bp.stdout) + if err != nil { + if !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrClosed) { + err = fmt.Errorf("read frame type: %w", err) + } + bp.finishStartup(fmt.Errorf("BEAM read loop: %w", err)) + bp.failPending(err) return } - if status != 0 { - ch <- readResult{err: &FormatError{Message: string(buf)}} + + switch frameType { + case frameReady: + status, payload, err := readStatusPayload(bp.stdout) + if err != nil { + startErr := fmt.Errorf("read ready frame: %w", err) + bp.finishStartup(startErr) + bp.failPending(startErr) + return + } + if status != 0 { + msg := "BEAM failed to initialize" + if len(payload) > 0 { + msg = fmt.Sprintf("%s: %s", msg, strings.TrimSpace(string(payload))) + } + startErr := errors.New(msg) + bp.finishStartup(startErr) + bp.failPending(startErr) + return + } + log.Printf("BEAM: started persistent process (pid %d)", bp.cmd.process.Pid) + bp.finishStartup(nil) + + case frameResponse: + reqID, err := readUint32(bp.stdout) + if err != nil { + respErr := fmt.Errorf("read response request id: %w", err) + bp.finishStartup(respErr) + bp.failPending(respErr) + return + } + status, payload, err := readStatusPayload(bp.stdout) + if err != nil { + respErr := fmt.Errorf("read response payload: %w", err) + bp.finishStartup(respErr) + bp.failPending(respErr) + return + } + if respCh := bp.removePending(reqID); respCh != nil { + respCh <- beamResponse{status: status, payload: payload} + } + + case frameNotification: + op, err := readByte(bp.stdout) + if err != nil { + notifErr := fmt.Errorf("read notification op: %w", err) + bp.finishStartup(notifErr) + bp.failPending(notifErr) + return + } + size, err := readUint32(bp.stdout) + if err != nil { + notifErr := fmt.Errorf("read notification payload length: %w", err) + bp.finishStartup(notifErr) + bp.failPending(notifErr) + return + } + payload, err := readPayload(bp.stdout, size) + if err != nil { + notifErr := fmt.Errorf("read notification payload: %w", err) + bp.finishStartup(notifErr) + bp.failPending(notifErr) + return + } + if bp.notify != nil { + bp.notify(beamNotification{op: op, payload: payload}) + } + + default: + protocolErr := fmt.Errorf("unexpected BEAM frame type: %d", frameType) + bp.finishStartup(protocolErr) + bp.failPending(protocolErr) return } - ch <- readResult{text: string(buf)} - }() + } +} + +// doRequest sends a framed request to the BEAM and waits for the matching +// response. The permanent read loop demultiplexes responses by request ID and +// routes notifications to the cache layer. +func (bp *beamProcess) doRequest(ctx context.Context, service, op byte, payload []byte, handleResp func(status byte, payload []byte) error) error { + if ctx.Err() != nil { + return ctx.Err() + } + + reqID, respCh := bp.addPending() + + bp.writeMu.Lock() + if ctx.Err() != nil { + bp.writeMu.Unlock() + bp.removePending(reqID) + return ctx.Err() + } + if !bp.alive() { + bp.writeMu.Unlock() + bp.removePending(reqID) + return bp.wrapError(fmt.Errorf("BEAM process is not alive")) + } + + var frame bytes.Buffer + frame.WriteByte(frameRequest) + _ = binary.Write(&frame, binary.BigEndian, reqID) + frame.WriteByte(service) + frame.WriteByte(op) + _ = binary.Write(&frame, binary.BigEndian, uint32(len(payload))) + frame.Write(payload) + + _, err := bp.stdin.Write(frame.Bytes()) + bp.writeMu.Unlock() + if err != nil { + bp.removePending(reqID) + return bp.wrapError(fmt.Errorf("write request: %w", err)) + } select { - case r := <-ch: - return r.text, r.err + case resp := <-respCh: + if resp.err != nil { + return resp.err + } + return handleResp(resp.status, resp.payload) case <-ctx.Done(): - _ = bp.cmd.process.Kill() - <-ch - return "", ctx.Err() + bp.removePending(reqID) + return ctx.Err() + case <-bp.closed: + bp.removePending(reqID) + return bp.wrapError(fmt.Errorf("BEAM process closed")) } } +// Format sends a format request to the BEAM process. The formatterExs path +// tells the BEAM which .formatter.exs config to use (starting a new formatter +// child if needed). +func (bp *beamProcess) Format(ctx context.Context, content, filename, formatterExs string) (string, error) { + var result string + configPathBytes := []byte(formatterExs) + filenameBytes := []byte(filename) + contentBytes := []byte(content) + var payload bytes.Buffer + _ = binary.Write(&payload, binary.BigEndian, uint16(len(configPathBytes))) + payload.Write(configPathBytes) + _ = binary.Write(&payload, binary.BigEndian, uint16(len(filenameBytes))) + payload.Write(filenameBytes) + _ = binary.Write(&payload, binary.BigEndian, uint32(len(contentBytes))) + payload.Write(contentBytes) + + err := bp.doRequest(ctx, serviceFormatter, formatterOpFormat, payload.Bytes(), func(status byte, payload []byte) error { + if status != 0 { + return &FormatError{Message: string(payload)} + } + result = string(payload) + return nil + }) + return result, err +} + // ErlangSourceResult holds the resolved source location for an Erlang function. type ErlangSourceResult struct { File string @@ -146,133 +396,198 @@ type ErlangSourceResult struct { // ErlangSource asks the BEAM's CodeIntel service to resolve an Erlang module/function // to its source file and line number. func (bp *beamProcess) ErlangSource(ctx context.Context, module, function string, arity int) (*ErlangSourceResult, error) { - bp.mu.Lock() - defer bp.mu.Unlock() - - moduleBytes := []byte(module) - functionBytes := []byte(function) - arityByte := byte(255) // 255 = unspecified + var result *ErlangSourceResult + arityByte := byte(255) if arity >= 0 && arity < 255 { arityByte = byte(arity) } - var req bytes.Buffer - req.WriteByte(serviceCodeIntel) // service tag - req.WriteByte(0) // op: erlang_source - _ = binary.Write(&req, binary.BigEndian, uint16(len(moduleBytes))) - req.Write(moduleBytes) - _ = binary.Write(&req, binary.BigEndian, uint16(len(functionBytes))) - req.Write(functionBytes) - req.WriteByte(arityByte) - if _, err := bp.stdin.Write(req.Bytes()); err != nil { - return nil, fmt.Errorf("write code_intel request: %w", err) - } + var payload bytes.Buffer + _ = binary.Write(&payload, binary.BigEndian, uint16(len(module))) + payload.WriteString(module) + _ = binary.Write(&payload, binary.BigEndian, uint16(len(function))) + payload.WriteString(function) + payload.WriteByte(arityByte) - type readResult struct { - result *ErlangSourceResult - err error - } - ch := make(chan readResult, 1) - go func() { - var status byte - if err := binary.Read(bp.stdout, binary.BigEndian, &status); err != nil { - ch <- readResult{err: fmt.Errorf("read status: %w", err)} - return - } + err := bp.doRequest(ctx, serviceCodeIntel, codeIntelOpErlangSource, payload.Bytes(), func(status byte, payload []byte) error { + reader := bytes.NewReader(payload) var fileLen uint16 - if err := binary.Read(bp.stdout, binary.BigEndian, &fileLen); err != nil { - ch <- readResult{err: fmt.Errorf("read file length: %w", err)} - return + if err := binary.Read(reader, binary.BigEndian, &fileLen); err != nil { + return fmt.Errorf("read file length: %w", err) } fileBuf := make([]byte, fileLen) - if _, err := io.ReadFull(bp.stdout, fileBuf); err != nil { - ch <- readResult{err: fmt.Errorf("read file: %w", err)} - return + if _, err := io.ReadFull(reader, fileBuf); err != nil { + return fmt.Errorf("read file: %w", err) } var line uint32 - if err := binary.Read(bp.stdout, binary.BigEndian, &line); err != nil { - ch <- readResult{err: fmt.Errorf("read line: %w", err)} - return + if err := binary.Read(reader, binary.BigEndian, &line); err != nil { + return fmt.Errorf("read line: %w", err) } if status != 0 { - ch <- readResult{err: fmt.Errorf("erlang source not found")} - return + return fmt.Errorf("erlang source not found") } - ch <- readResult{result: &ErlangSourceResult{File: string(fileBuf), Line: int(line)}} - }() - - select { - case r := <-ch: - return r.result, r.err - case <-ctx.Done(): - _ = bp.cmd.process.Kill() - <-ch - return nil, ctx.Err() - } + result = &ErlangSourceResult{File: string(fileBuf), Line: int(line)} + return nil + }) + return result, err } // ErlangDocs asks the BEAM's CodeIntel service for the documentation of an // Erlang module or function. Returns pre-formatted markdown, or empty string // if no docs are available (e.g. OTP < 24 or undocumented function). func (bp *beamProcess) ErlangDocs(ctx context.Context, module, function string, arity int) (string, error) { - bp.mu.Lock() - defer bp.mu.Unlock() - - moduleBytes := []byte(module) - functionBytes := []byte(function) + var doc string arityByte := byte(255) if arity >= 0 && arity < 255 { arityByte = byte(arity) } - var req bytes.Buffer - req.WriteByte(serviceCodeIntel) - req.WriteByte(1) // op: erlang_docs - _ = binary.Write(&req, binary.BigEndian, uint16(len(moduleBytes))) - req.Write(moduleBytes) - _ = binary.Write(&req, binary.BigEndian, uint16(len(functionBytes))) - req.Write(functionBytes) - req.WriteByte(arityByte) - if _, err := bp.stdin.Write(req.Bytes()); err != nil { - return "", fmt.Errorf("write code_intel request: %w", err) - } + var payload bytes.Buffer + _ = binary.Write(&payload, binary.BigEndian, uint16(len(module))) + payload.WriteString(module) + _ = binary.Write(&payload, binary.BigEndian, uint16(len(function))) + payload.WriteString(function) + payload.WriteByte(arityByte) - type readResult struct { - doc string - err error - } - ch := make(chan readResult, 1) - go func() { - var status byte - if err := binary.Read(bp.stdout, binary.BigEndian, &status); err != nil { - ch <- readResult{err: fmt.Errorf("read status: %w", err)} - return - } + err := bp.doRequest(ctx, serviceCodeIntel, codeIntelOpErlangDocs, payload.Bytes(), func(status byte, payload []byte) error { + reader := bytes.NewReader(payload) var docLen uint32 - if err := binary.Read(bp.stdout, binary.BigEndian, &docLen); err != nil { - ch <- readResult{err: fmt.Errorf("read doc length: %w", err)} - return + if err := binary.Read(reader, binary.BigEndian, &docLen); err != nil { + return fmt.Errorf("read doc length: %w", err) } docBuf := make([]byte, docLen) - if _, err := io.ReadFull(bp.stdout, docBuf); err != nil { - ch <- readResult{err: fmt.Errorf("read doc: %w", err)} - return + if _, err := io.ReadFull(reader, docBuf); err != nil { + return fmt.Errorf("read doc: %w", err) + } + if status == 0 { + doc = string(docBuf) } + return nil + }) + return doc, err +} + +// ErlangExport represents a single exported function from an Erlang module. +type ErlangExport struct { + Function string + Arity int + Params string +} + +// ErlangRuntimeInfo identifies the BEAM runtime backing a process. +type ErlangRuntimeInfo struct { + OTPRelease string + CodeRootDir string +} + +// ErlangRuntimeInfo asks the BEAM's CodeIntel service for a stable runtime +// fingerprint. This lets the LSP share OTP completion caches across build +// roots that resolve to the same OTP install. +func (bp *beamProcess) ErlangRuntimeInfo(ctx context.Context) (*ErlangRuntimeInfo, error) { + var info *ErlangRuntimeInfo + err := bp.doRequest(ctx, serviceCodeIntel, codeIntelOpRuntimeInfo, nil, func(status byte, payload []byte) error { if status != 0 { - ch <- readResult{doc: ""} - return + if len(payload) > 0 { + return fmt.Errorf("runtime info failed: %s", strings.TrimSpace(string(payload))) + } + return fmt.Errorf("runtime info failed") } - ch <- readResult{doc: string(docBuf)} - }() - select { - case r := <-ch: - return r.doc, r.err - case <-ctx.Done(): - _ = bp.cmd.process.Kill() - <-ch - return "", ctx.Err() - } + reader := bytes.NewReader(payload) + var releaseLen uint16 + if err := binary.Read(reader, binary.BigEndian, &releaseLen); err != nil { + return fmt.Errorf("read otp release length: %w", err) + } + releaseBuf := make([]byte, releaseLen) + if _, err := io.ReadFull(reader, releaseBuf); err != nil { + return fmt.Errorf("read otp release: %w", err) + } + + var rootLen uint16 + if err := binary.Read(reader, binary.BigEndian, &rootLen); err != nil { + return fmt.Errorf("read code root length: %w", err) + } + rootBuf := make([]byte, rootLen) + if _, err := io.ReadFull(reader, rootBuf); err != nil { + return fmt.Errorf("read code root: %w", err) + } + + info = &ErlangRuntimeInfo{ + OTPRelease: string(releaseBuf), + CodeRootDir: string(rootBuf), + } + return nil + }) + return info, err +} + +// WarmOTPModuleNames asks the BEAM's CodeIntel service to ensure OTP Erlang +// modules are loaded. Completion data is pushed back asynchronously via a +// notification frame once the background warmup finishes. +func (bp *beamProcess) WarmOTPModuleNames(ctx context.Context) error { + return bp.doRequest(ctx, serviceCodeIntel, codeIntelOpWarmOTPModules, nil, func(status byte, payload []byte) error { + if status != 0 { + if len(payload) > 0 { + return fmt.Errorf("warm OTP modules: %s", strings.TrimSpace(string(payload))) + } + return fmt.Errorf("warm OTP modules failed") + } + return nil + }) +} + +// ErlangExports asks the BEAM's CodeIntel service for the exported functions +// of a single Erlang module. +func (bp *beamProcess) ErlangExports(ctx context.Context, module string) ([]ErlangExport, error) { + var exports []ErlangExport + var payload bytes.Buffer + _ = binary.Write(&payload, binary.BigEndian, uint16(len(module))) + payload.WriteString(module) + + err := bp.doRequest(ctx, serviceCodeIntel, codeIntelOpErlangExports, payload.Bytes(), func(status byte, payload []byte) error { + if status != 0 { + if len(payload) > 0 { + return fmt.Errorf("erlang exports failed: %s", strings.TrimSpace(string(payload))) + } + return fmt.Errorf("erlang exports failed") + } + + reader := bytes.NewReader(payload) + var exportCount uint16 + if err := binary.Read(reader, binary.BigEndian, &exportCount); err != nil { + return fmt.Errorf("read export count: %w", err) + } + exports = make([]ErlangExport, 0, exportCount) + for i := 0; i < int(exportCount); i++ { + var funcLen uint16 + if err := binary.Read(reader, binary.BigEndian, &funcLen); err != nil { + return fmt.Errorf("read func name length: %w", err) + } + funcBuf := make([]byte, funcLen) + if _, err := io.ReadFull(reader, funcBuf); err != nil { + return fmt.Errorf("read func name: %w", err) + } + var arity uint8 + if err := binary.Read(reader, binary.BigEndian, &arity); err != nil { + return fmt.Errorf("read arity: %w", err) + } + var paramsLen uint16 + if err := binary.Read(reader, binary.BigEndian, ¶msLen); err != nil { + return fmt.Errorf("read params length: %w", err) + } + paramsBuf := make([]byte, paramsLen) + if _, err := io.ReadFull(reader, paramsBuf); err != nil { + return fmt.Errorf("read params: %w", err) + } + exports = append(exports, ErlangExport{ + Function: string(funcBuf), + Arity: int(arity), + Params: string(paramsBuf), + }) + } + return nil + }) + return exports, err } // FormatError represents a formatting failure (e.g. syntax error in the source). @@ -286,16 +601,25 @@ func (e *FormatError) Error() string { } func (bp *beamProcess) Close() { - // Acquire mu to wait for any in-flight request to finish before killing - // the process — prevents broken pipe errors on concurrent format + close. - bp.mu.Lock() - defer bp.mu.Unlock() + bp.closeWithReason("caller did not provide a reason") +} + +func (bp *beamProcess) closeWithReason(reason string) { + if reason == "" { + reason = "no reason provided" + } + + bp.writeMu.Lock() + defer bp.writeMu.Unlock() select { case <-bp.closed: + return // already closed default: close(bp.closed) } + bp.finishStartup(fmt.Errorf("BEAM closed")) + log.Printf("BEAM: closing process (pid %d): %s", bp.cmd.process.Pid, reason) _ = bp.stdin.Close() _ = bp.cmd.process.Kill() } @@ -346,51 +670,33 @@ func (s *Server) startBeamProcess(buildRoot string) (*beamProcess, error) { cmd: handle, stdin: stdin, stdout: stdout, + stderr: &stderrBuf, + pending: make(map[uint32]chan beamResponse), startedAt: time.Now(), ready: make(chan struct{}), closed: make(chan struct{}), + notify: func(notification beamNotification) { + s.handleBeamNotification(buildRoot, notification) + }, } go func() { - type readyResult struct { - status byte - err error - } - readyCh := make(chan readyResult, 1) - go func() { - var status byte - if err := binary.Read(stdout, binary.BigEndian, &status); err != nil { - readyCh <- readyResult{err: err} - return - } - var readyLen uint32 - if err := binary.Read(stdout, binary.BigEndian, &readyLen); err != nil { - readyCh <- readyResult{err: err} - return - } - readyCh <- readyResult{status: status} - }() - select { - case r := <-readyCh: - if r.err != nil { - bp.startErr = fmt.Errorf("BEAM ready: %w", r.err) + case <-bp.ready: + if bp.startErr != nil { _ = cmd.Process.Kill() <-done s.notifyOTPMismatch(stderrBuf.String()) - } else if r.status != 0 { - bp.startErr = fmt.Errorf("BEAM failed to initialize (status %d)", r.status) - _ = cmd.Process.Kill() - } else { - log.Printf("BEAM: started persistent process (pid %d)", cmd.Process.Pid) } case <-time.After(beamStuckTimeout): - bp.startErr = fmt.Errorf("BEAM startup timed out") + bp.finishStartup(fmt.Errorf("BEAM startup timed out")) _ = cmd.Process.Kill() + <-done } - close(bp.ready) }() + go bp.readLoop() + return bp, nil } @@ -420,8 +726,11 @@ func (s *Server) getBeamProcess(ctx context.Context, buildRoot string) *beamProc s.beamMu.Lock() defer s.beamMu.Unlock() - if bp, ok := s.beams[buildRoot]; ok && bp.alive() { - return bp + if bp, ok := s.beams[buildRoot]; ok { + if bp.alive() { + return bp + } + log.Printf("BEAM: process for %s is dead, restarting", buildRoot) } if s.mixBin == "" { @@ -483,7 +792,7 @@ func (s *Server) formatContent(ctx context.Context, mixRoot, path, content strin select { case <-bp.ready: if bp.startErr != nil { - s.evictBeam(bp) + s.evictBeam(bp, fmt.Sprintf("formatContent: startup finished with error: %v", bp.startErr)) log.Printf("Formatting: BEAM process failed to start, falling back to mix format: %v", bp.startErr) return s.formatWithMixFormat(ctx, mixRoot, path, content) } @@ -493,7 +802,7 @@ func (s *Server) formatContent(ctx context.Context, mixRoot, path, content strin switch { case age > beamStuckTimeout: log.Printf("Formatting: BEAM process stuck (started %s ago), restarting", age.Truncate(time.Second)) - s.evictBeam(bp) + s.evictBeam(bp, fmt.Sprintf("formatContent: startup exceeded %s without becoming ready", beamStuckTimeout)) return s.formatWithMixFormat(ctx, mixRoot, path, content) case age > beamWaitTimeout: @@ -505,7 +814,7 @@ func (s *Server) formatContent(ctx context.Context, mixRoot, path, content strin if ctx.Err() != nil { return "", err } - s.evictBeam(bp) + s.evictBeam(bp, fmt.Sprintf("formatContent: Ready failed: %v", err)) log.Printf("Formatting: BEAM process failed to start, falling back to mix format: %v", err) return s.formatWithMixFormat(ctx, mixRoot, path, content) } @@ -522,8 +831,10 @@ func (s *Server) formatContent(ctx context.Context, mixRoot, path, content strin var formatErr *FormatError if errors.As(err, &formatErr) { log.Printf("Formatting: %s failed: %s", path, formatErr.Message) + } else if ctx.Err() != nil { + // Context cancelled — the BEAM is fine, the editor just moved on. } else { - s.evictBeam(bp) + s.evictBeam(bp, fmt.Sprintf("formatContent: Format request failed: %v", err)) log.Printf("Formatting: BEAM process crashed: %v", err) } return "", err @@ -533,16 +844,29 @@ func (s *Server) formatContent(ctx context.Context, mixRoot, path, content strin return result, nil } -func (s *Server) evictBeam(bp *beamProcess) { +func (s *Server) evictBeam(bp *beamProcess, reason string) { + if reason == "" { + reason = "no reason provided" + } + + buildRoot := "" s.beamMu.Lock() for key, b := range s.beams { if b == bp { delete(s.beams, key) + buildRoot = key break } } s.beamMu.Unlock() - bp.Close() + + if buildRoot != "" { + log.Printf("BEAM: evicting process for %s (pid %d): %s", buildRoot, bp.cmd.process.Pid, reason) + } else { + log.Printf("BEAM: evicting untracked process (pid %d): %s", bp.cmd.process.Pid, reason) + } + + bp.closeWithReason("evicted: " + reason) } func (s *Server) formatWithMixFormat(ctx context.Context, mixRoot, path, content string) (string, error) { @@ -720,7 +1044,7 @@ func (s *Server) closeBeams() { s.beamMu.Lock() defer s.beamMu.Unlock() for _, bp := range s.beams { - bp.Close() + bp.closeWithReason("server shutdown") } s.beams = nil } diff --git a/internal/lsp/formatter_test.go b/internal/lsp/formatter_test.go index c2e8185..05ebd8f 100644 --- a/internal/lsp/formatter_test.go +++ b/internal/lsp/formatter_test.go @@ -1,13 +1,18 @@ package lsp import ( + "bytes" "context" + "encoding/binary" + "fmt" + "io" "os" "os/exec" "path/filepath" "runtime" "strings" "testing" + "time" "go.lsp.dev/protocol" "go.lsp.dev/uri" @@ -59,6 +64,70 @@ func setupTestServerForFixture(t *testing.T, mixRoot string) (*Server, func()) { } } +func newTestBeamProcess(stdin io.WriteCloser, stdout io.ReadCloser, notify func(beamNotification)) *beamProcess { + bp := &beamProcess{ + cmd: &commandHandle{ + process: &os.Process{Pid: os.Getpid()}, + done: make(chan struct{}), + }, + stdin: stdin, + stdout: stdout, + stderr: &bytes.Buffer{}, + pending: make(map[uint32]chan beamResponse), + ready: make(chan struct{}), + closed: make(chan struct{}), + notify: notify, + } + bp.finishStartup(nil) + return bp +} + +func writeTestResponseFrame(t *testing.T, w io.Writer, reqID uint32, status byte, payload []byte) { + t.Helper() + var frame bytes.Buffer + frame.WriteByte(frameResponse) + if err := binary.Write(&frame, binary.BigEndian, reqID); err != nil { + t.Fatal(err) + } + frame.WriteByte(status) + if err := binary.Write(&frame, binary.BigEndian, uint32(len(payload))); err != nil { + t.Fatal(err) + } + frame.Write(payload) + if _, err := w.Write(frame.Bytes()); err != nil { + t.Fatal(err) + } +} + +func writeTestNotificationFrame(t *testing.T, w io.Writer, op byte, payload []byte) { + t.Helper() + var frame bytes.Buffer + frame.WriteByte(frameNotification) + frame.WriteByte(op) + if err := binary.Write(&frame, binary.BigEndian, uint32(len(payload))); err != nil { + t.Fatal(err) + } + frame.Write(payload) + if _, err := w.Write(frame.Bytes()); err != nil { + t.Fatal(err) + } +} + +func encodeTestModuleNamesPayload(t *testing.T, names []string) []byte { + t.Helper() + var payload bytes.Buffer + if err := binary.Write(&payload, binary.BigEndian, uint16(len(names))); err != nil { + t.Fatal(err) + } + for _, name := range names { + if err := binary.Write(&payload, binary.BigEndian, uint16(len(name))); err != nil { + t.Fatal(err) + } + payload.WriteString(name) + } + return payload.Bytes() +} + func TestFormatterServer_WithStylerPlugin(t *testing.T) { if _, err := exec.LookPath("mix"); err != nil { t.Skip("mix not available in PATH") @@ -480,3 +549,208 @@ func TestFindFormatterConfig_PerAppOverridesRoot(t *testing.T) { t.Errorf("expected app-level %s, got %s", appFormatter, got) } } + +func TestBeamProcess_DoRequestHandlesNotificationBeforeResponse(t *testing.T) { + reqReader, reqWriter := io.Pipe() + respReader, respWriter := io.Pipe() + + notifications := make(chan beamNotification, 1) + bp := newTestBeamProcess(reqWriter, respReader, func(notification beamNotification) { + notifications <- notification + }) + + readLoopDone := make(chan struct{}) + go func() { + bp.readLoop() + close(readLoopDone) + }() + + serverDone := make(chan struct{}) + go func() { + defer close(serverDone) + + frameType, err := readByte(reqReader) + if err != nil { + t.Error(err) + return + } + if frameType != frameRequest { + t.Errorf("expected request frame, got %d", frameType) + return + } + + reqID, err := readUint32(reqReader) + if err != nil { + t.Error(err) + return + } + + header := make([]byte, 6) + if _, err := io.ReadFull(reqReader, header); err != nil { + t.Error(err) + return + } + if header[0] != serviceCodeIntel || header[1] != codeIntelOpRuntimeInfo { + t.Errorf("unexpected request header service=%d op=%d", header[0], header[1]) + return + } + payloadLen := binary.BigEndian.Uint32(header[2:]) + if payloadLen != 0 { + t.Errorf("expected empty payload, got %d bytes", payloadLen) + return + } + + writeTestNotificationFrame(t, respWriter, beamNotificationOTPModulesReady, encodeTestModuleNamesPayload(t, []string{"code"})) + writeTestResponseFrame(t, respWriter, reqID, 0, []byte("ok")) + }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + var gotResponse string + if err := bp.doRequest(ctx, serviceCodeIntel, codeIntelOpRuntimeInfo, nil, func(status byte, payload []byte) error { + if status != 0 { + t.Fatalf("expected success status, got %d", status) + } + gotResponse = string(payload) + return nil + }); err != nil { + t.Fatal(err) + } + + if gotResponse != "ok" { + t.Fatalf("expected response payload %q, got %q", "ok", gotResponse) + } + + select { + case notification := <-notifications: + if notification.op != beamNotificationOTPModulesReady { + t.Fatalf("expected otp_modules_ready notification, got %d", notification.op) + } + names, err := decodeErlangModuleNames(notification.payload) + if err != nil { + t.Fatal(err) + } + if len(names) != 1 || names[0] != "code" { + t.Fatalf("unexpected notification payload: %v", names) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for notification") + } + + <-serverDone + _ = reqWriter.Close() + _ = reqReader.Close() + _ = respWriter.Close() + <-readLoopDone +} + +func TestBeamProcess_CanceledRequestDoesNotBlockSubsequentResponses(t *testing.T) { + reqReader, reqWriter := io.Pipe() + respReader, respWriter := io.Pipe() + + bp := newTestBeamProcess(reqWriter, respReader, nil) + + readLoopDone := make(chan struct{}) + go func() { + bp.readLoop() + close(readLoopDone) + }() + + firstRead := make(chan uint32, 1) + serverDone := make(chan struct{}) + go func() { + defer close(serverDone) + + readRequest := func() (uint32, byte, error) { + frameType, err := readByte(reqReader) + if err != nil { + return 0, 0, err + } + if frameType != frameRequest { + return 0, 0, fmt.Errorf("unexpected frame type %d", frameType) + } + reqID, err := readUint32(reqReader) + if err != nil { + return 0, 0, err + } + header := make([]byte, 6) + if _, err := io.ReadFull(reqReader, header); err != nil { + return 0, 0, err + } + payloadLen := binary.BigEndian.Uint32(header[2:]) + if payloadLen > 0 { + if _, err := io.CopyN(io.Discard, reqReader, int64(payloadLen)); err != nil { + return 0, 0, err + } + } + return reqID, header[1], nil + } + + firstReqID, firstOp, err := readRequest() + if err != nil { + t.Error(err) + return + } + if firstOp != 0x11 { + t.Errorf("expected first op 0x11, got %d", firstOp) + return + } + firstRead <- firstReqID + + secondReqID, secondOp, err := readRequest() + if err != nil { + t.Error(err) + return + } + if secondOp != 0x12 { + t.Errorf("expected second op 0x12, got %d", secondOp) + return + } + + writeTestResponseFrame(t, respWriter, secondReqID, 0, []byte("second")) + writeTestResponseFrame(t, respWriter, firstReqID, 0, []byte("first")) + }() + + firstCtx, cancelFirst := context.WithCancel(context.Background()) + firstErrCh := make(chan error, 1) + go func() { + firstErrCh <- bp.doRequest(firstCtx, serviceCodeIntel, 0x11, nil, func(status byte, payload []byte) error { + return fmt.Errorf("canceled request should not receive a response: status=%d payload=%q", status, string(payload)) + }) + }() + + firstReqID := <-firstRead + if firstReqID == 0 { + t.Fatal("expected non-zero first request id") + } + cancelFirst() + + secondCtx, cancelSecond := context.WithTimeout(context.Background(), 2*time.Second) + defer cancelSecond() + + var gotSecond string + if err := bp.doRequest(secondCtx, serviceCodeIntel, 0x12, nil, func(status byte, payload []byte) error { + if status != 0 { + t.Fatalf("expected success status, got %d", status) + } + gotSecond = string(payload) + return nil + }); err != nil { + t.Fatal(err) + } + + if gotSecond != "second" { + t.Fatalf("expected second response payload %q, got %q", "second", gotSecond) + } + + if err := <-firstErrCh; err != context.Canceled { + t.Fatalf("expected first request to return context.Canceled, got %v", err) + } + + <-serverDone + _ = reqWriter.Close() + _ = reqReader.Close() + _ = respWriter.Close() + <-readLoopDone +} diff --git a/internal/lsp/server.go b/internal/lsp/server.go index dc0a0e2..2f00296 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -2,7 +2,9 @@ package lsp import ( "bufio" + "bytes" "context" + "encoding/binary" "encoding/json" "errors" "fmt" @@ -51,6 +53,20 @@ type usingCacheEntry struct { aliases map[string]string // alias short name → full module injected by __using__ } +type erlangBuildRootState struct { + runtimeKey string + loading bool +} + +type erlangRuntimeCache struct { + otpRelease string + codeRootDir string + moduleNames map[string]bool + exports map[string][]ErlangExport + loading bool + readyCh chan struct{} +} + type Server struct { store *store.Store docs *DocumentStore @@ -66,6 +82,10 @@ type Server struct { beams map[string]*beamProcess // build root → persistent BEAM process beamMu sync.Mutex + erlangBuildRoots map[string]*erlangBuildRootState // build root → runtime resolution state + erlangRuntimeCache map[string]*erlangRuntimeCache // runtime key → cached OTP modules/exports + erlangRuntimeMu sync.Mutex + usingCache map[string]*usingCacheEntry // module name → parsed __using__ result usingCacheMu sync.RWMutex @@ -97,13 +117,15 @@ func (s *Server) debugNow() time.Time { func NewServer(s *store.Store, projectRoot string) *Server { return &Server{ - store: s, - docs: NewDocumentStore(), - projectRoot: projectRoot, - explicitRoot: projectRoot != "", - followDelegates: true, - usingCache: make(map[string]*usingCacheEntry), - depsCache: make(map[string]bool), + store: s, + docs: NewDocumentStore(), + projectRoot: projectRoot, + explicitRoot: projectRoot != "", + followDelegates: true, + erlangBuildRoots: make(map[string]*erlangBuildRootState), + erlangRuntimeCache: make(map[string]*erlangRuntimeCache), + usingCache: make(map[string]*usingCacheEntry), + depsCache: make(map[string]bool), } } @@ -435,16 +457,18 @@ func (s *Server) Exit(ctx context.Context) error { // === Document Sync === func (s *Server) DidOpen(ctx context.Context, params *protocol.DidOpenTextDocumentParams) error { - s.docs.Set(string(params.TextDocument.URI), params.TextDocument.Text) + docURI := string(params.TextDocument.URI) + s.docs.Set(docURI, params.TextDocument.Text) + path := uriToPath(params.TextDocument.URI) // Eagerly start the persistent BEAM process so the first format is instant. // Skip deps and stdlib files — we don't format those. - path := uriToPath(params.TextDocument.URI) if path != "" && isFormattableFile(path) && s.isProjectFile(path) && !s.isDepsFile(path) { - go func() { - buildRoot := s.findBuildRoot(filepath.Dir(path)) + buildRoot := s.findBuildRoot(filepath.Dir(path)) + go func(path, buildRoot string) { _ = s.getBeamProcess(context.Background(), buildRoot) - }() + s.startErlangModuleLoad(path) + }(path, buildRoot) } return nil @@ -453,7 +477,8 @@ func (s *Server) DidOpen(ctx context.Context, params *protocol.DidOpenTextDocume func (s *Server) DidChange(ctx context.Context, params *protocol.DidChangeTextDocumentParams) error { if len(params.ContentChanges) > 0 { // Full sync mode — last change contains the full text - s.docs.Set(string(params.TextDocument.URI), params.ContentChanges[len(params.ContentChanges)-1].Text) + text := params.ContentChanges[len(params.ContentChanges)-1].Text + s.docs.Set(string(params.TextDocument.URI), text) } return nil } @@ -473,13 +498,17 @@ func (s *Server) DidSave(ctx context.Context, params *protocol.DidSaveTextDocume // config is picked up on the next format request. if filepath.Base(path) == ".formatter.exs" { buildRoot := s.findBuildRoot(filepath.Dir(path)) + var bp *beamProcess s.beamMu.Lock() - if bp, ok := s.beams[buildRoot]; ok { + if existing, ok := s.beams[buildRoot]; ok { delete(s.beams, buildRoot) + bp = existing + } + s.beamMu.Unlock() + if bp != nil { bp.Close() log.Printf("BEAM: restarting for %s (.formatter.exs changed)", buildRoot) } - s.beamMu.Unlock() return nil } @@ -724,10 +753,137 @@ func filterOutTypes(results []store.LookupResult) []store.LookupResult { return results } +func (s *Server) erlangBuildRoot(filePath string) string { + if filePath != "" { + return s.findBuildRoot(filepath.Dir(filePath)) + } + return s.findBuildRoot(s.projectRoot) +} + +func erlangRuntimeKey(info *ErlangRuntimeInfo) string { + if info == nil { + return "" + } + return info.OTPRelease + "\x00" + info.CodeRootDir +} + +func decodeErlangModuleNames(payload []byte) ([]string, error) { + reader := bytes.NewReader(payload) + var count uint16 + if err := binary.Read(reader, binary.BigEndian, &count); err != nil { + return nil, fmt.Errorf("read module count: %w", err) + } + + names := make([]string, 0, count) + for i := 0; i < int(count); i++ { + var nameLen uint16 + if err := binary.Read(reader, binary.BigEndian, &nameLen); err != nil { + return nil, fmt.Errorf("read module name length: %w", err) + } + nameBuf := make([]byte, nameLen) + if _, err := io.ReadFull(reader, nameBuf); err != nil { + return nil, fmt.Errorf("read module name: %w", err) + } + names = append(names, string(nameBuf)) + } + + return names, nil +} + +func (s *Server) clearErlangWarmup(runtimeKey string, readyCh chan struct{}) { + var notifyCh chan struct{} + + s.erlangRuntimeMu.Lock() + cache := s.erlangRuntimeCache[runtimeKey] + if cache != nil && cache.readyCh == readyCh { + cache.loading = false + notifyCh = cache.readyCh + cache.readyCh = nil + } + s.erlangRuntimeMu.Unlock() + + if notifyCh != nil { + close(notifyCh) + } +} + +func (s *Server) completeErlangWarmup(buildRoot string, names []string) { + moduleNames := make(map[string]bool, len(names)) + for _, name := range names { + moduleNames[name] = true + } + + var notifyCh chan struct{} + s.erlangRuntimeMu.Lock() + state := s.erlangBuildRoots[buildRoot] + if state == nil || state.runtimeKey == "" { + s.erlangRuntimeMu.Unlock() + return + } + + cache := s.erlangRuntimeCache[state.runtimeKey] + if cache == nil { + cache = &erlangRuntimeCache{exports: make(map[string][]ErlangExport)} + s.erlangRuntimeCache[state.runtimeKey] = cache + } + if cache.exports == nil { + cache.exports = make(map[string][]ErlangExport) + } + cache.moduleNames = moduleNames + cache.loading = false + notifyCh = cache.readyCh + cache.readyCh = nil + s.erlangRuntimeMu.Unlock() + + if notifyCh != nil { + close(notifyCh) + } +} + +func (s *Server) failErlangWarmup(buildRoot string) { + var notifyCh chan struct{} + + s.erlangRuntimeMu.Lock() + state := s.erlangBuildRoots[buildRoot] + if state != nil && state.runtimeKey != "" { + if cache := s.erlangRuntimeCache[state.runtimeKey]; cache != nil { + cache.loading = false + notifyCh = cache.readyCh + cache.readyCh = nil + } + } + s.erlangRuntimeMu.Unlock() + + if notifyCh != nil { + close(notifyCh) + } +} + +func (s *Server) handleBeamNotification(buildRoot string, notification beamNotification) { + switch notification.op { + case beamNotificationOTPModulesReady: + names, err := decodeErlangModuleNames(notification.payload) + if err != nil { + log.Printf("failed to decode Erlang module notification: %v", err) + s.failErlangWarmup(buildRoot) + return + } + s.completeErlangWarmup(buildRoot, names) + + case beamNotificationOTPModulesFailed: + if len(notification.payload) > 0 { + log.Printf("failed to warm Erlang modules: %s", strings.TrimSpace(string(notification.payload))) + } else { + log.Printf("failed to warm Erlang modules") + } + s.failErlangWarmup(buildRoot) + } +} + // erlangHover fetches documentation for an Erlang module/function via the // BEAM process's CodeIntel service. func (s *Server) erlangHover(ctx context.Context, filePath, module, function string) (*protocol.Hover, error) { - buildRoot := s.findBuildRoot(filepath.Dir(filePath)) + buildRoot := s.erlangBuildRoot(filePath) bp := s.getBeamProcess(ctx, buildRoot) if bp == nil { return nil, nil @@ -752,7 +908,7 @@ func (s *Server) erlangHover(ctx context.Context, filePath, module, function str // erlangDefinition resolves an Erlang module/function to its .erl source via // the BEAM process's CodeIntel service. func (s *Server) erlangDefinition(ctx context.Context, filePath, module, function string) ([]protocol.Location, error) { - buildRoot := s.findBuildRoot(filepath.Dir(filePath)) + buildRoot := s.erlangBuildRoot(filePath) bp := s.getBeamProcess(ctx, buildRoot) if bp == nil { s.debugf("Definition: no BEAM process available for Erlang resolution") @@ -780,6 +936,243 @@ func (s *Server) erlangDefinition(ctx context.Context, filePath, module, functio }}, nil } +// startErlangModuleLoad kicks off a background warmup for the current file's +// runtime. Build roots first resolve to a runtime fingerprint, then runtime +// caches (OTP module names + exports) are loaded once per runtime. +func (s *Server) startErlangModuleLoad(filePath string) { + buildRoot := s.erlangBuildRoot(filePath) + + s.erlangRuntimeMu.Lock() + state, ok := s.erlangBuildRoots[buildRoot] + if !ok { + state = &erlangBuildRootState{} + s.erlangBuildRoots[buildRoot] = state + } + + if state.runtimeKey != "" { + runtimeKey := state.runtimeKey + cache := s.erlangRuntimeCache[runtimeKey] + switch { + case cache != nil && cache.moduleNames != nil: + s.erlangRuntimeMu.Unlock() + return + case cache != nil && cache.loading: + s.erlangRuntimeMu.Unlock() + return + default: + if cache == nil { + cache = &erlangRuntimeCache{exports: make(map[string][]ErlangExport)} + s.erlangRuntimeCache[runtimeKey] = cache + } + if cache.exports == nil { + cache.exports = make(map[string][]ErlangExport) + } + cache.loading = true + cache.readyCh = make(chan struct{}) + readyCh := cache.readyCh + s.erlangRuntimeMu.Unlock() + go s.loadErlangRuntimeCache(buildRoot, runtimeKey, readyCh) + return + } + } + + if state.loading { + s.erlangRuntimeMu.Unlock() + return + } + + state.loading = true + s.erlangRuntimeMu.Unlock() + go s.resolveErlangRuntime(buildRoot) +} + +func (s *Server) resolveErlangRuntime(buildRoot string) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + bp := s.getBeamProcess(ctx, buildRoot) + if bp == nil { + s.erlangRuntimeMu.Lock() + if state := s.erlangBuildRoots[buildRoot]; state != nil { + state.loading = false + } + s.erlangRuntimeMu.Unlock() + return + } + if err := bp.Ready(ctx); err != nil { + s.erlangRuntimeMu.Lock() + if state := s.erlangBuildRoots[buildRoot]; state != nil { + state.loading = false + } + s.erlangRuntimeMu.Unlock() + return + } + + info, err := bp.ErlangRuntimeInfo(ctx) + if err != nil { + s.erlangRuntimeMu.Lock() + if state := s.erlangBuildRoots[buildRoot]; state != nil { + state.loading = false + } + s.erlangRuntimeMu.Unlock() + return + } + + runtimeKey := erlangRuntimeKey(info) + if runtimeKey == "" { + s.erlangRuntimeMu.Lock() + if state := s.erlangBuildRoots[buildRoot]; state != nil { + state.loading = false + } + s.erlangRuntimeMu.Unlock() + return + } + + startLoad := false + var readyCh chan struct{} + s.erlangRuntimeMu.Lock() + state := s.erlangBuildRoots[buildRoot] + if state == nil { + state = &erlangBuildRootState{} + s.erlangBuildRoots[buildRoot] = state + } + state.runtimeKey = runtimeKey + state.loading = false + + cache := s.erlangRuntimeCache[runtimeKey] + if cache == nil { + cache = &erlangRuntimeCache{ + otpRelease: info.OTPRelease, + codeRootDir: info.CodeRootDir, + exports: make(map[string][]ErlangExport), + } + s.erlangRuntimeCache[runtimeKey] = cache + } + if cache.exports == nil { + cache.exports = make(map[string][]ErlangExport) + } + if cache.otpRelease == "" { + cache.otpRelease = info.OTPRelease + } + if cache.codeRootDir == "" { + cache.codeRootDir = info.CodeRootDir + } + if cache.moduleNames == nil && !cache.loading { + cache.loading = true + cache.readyCh = make(chan struct{}) + readyCh = cache.readyCh + startLoad = true + } + s.erlangRuntimeMu.Unlock() + + if startLoad { + go s.loadErlangRuntimeCache(buildRoot, runtimeKey, readyCh) + } +} + +func (s *Server) loadErlangRuntimeCache(buildRoot, runtimeKey string, readyCh chan struct{}) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + bp := s.getBeamProcess(ctx, buildRoot) + if bp == nil { + s.clearErlangWarmup(runtimeKey, readyCh) + return + } + if err := bp.Ready(ctx); err != nil { + s.clearErlangWarmup(runtimeKey, readyCh) + return + } + + if err := bp.WarmOTPModuleNames(ctx); err != nil { + log.Printf("failed to start OTP module warmup: %v", err) + s.clearErlangWarmup(runtimeKey, readyCh) + return + } + + select { + case <-readyCh: + return + case <-ctx.Done(): + s.clearErlangWarmup(runtimeKey, readyCh) + return + } +} + +func (s *Server) erlangModuleNamesForFile(filePath string) map[string]bool { + buildRoot := s.erlangBuildRoot(filePath) + + s.erlangRuntimeMu.Lock() + defer s.erlangRuntimeMu.Unlock() + + state := s.erlangBuildRoots[buildRoot] + if state == nil || state.runtimeKey == "" { + return nil + } + cache := s.erlangRuntimeCache[state.runtimeKey] + if cache == nil { + return nil + } + return cache.moduleNames +} + +func (s *Server) erlangModulesAvailable(filePath string) bool { + moduleNames := s.erlangModuleNamesForFile(filePath) + return moduleNames != nil +} + +// getErlangExports returns the cached exports for an Erlang module, fetching +// from the BEAM on first access. Export caches are shared across build roots +// that resolve to the same runtime fingerprint. +func (s *Server) getErlangExports(ctx context.Context, filePath, module string) []ErlangExport { + buildRoot := s.erlangBuildRoot(filePath) + + s.erlangRuntimeMu.Lock() + state := s.erlangBuildRoots[buildRoot] + if state == nil || state.runtimeKey == "" { + s.erlangRuntimeMu.Unlock() + s.startErlangModuleLoad(filePath) + return nil + } + + runtimeKey := state.runtimeKey + cache := s.erlangRuntimeCache[runtimeKey] + if cache == nil || cache.moduleNames == nil { + s.erlangRuntimeMu.Unlock() + s.startErlangModuleLoad(filePath) + return nil + } + if exports, ok := cache.exports[module]; ok { + s.erlangRuntimeMu.Unlock() + return exports + } + s.erlangRuntimeMu.Unlock() + + bp := s.getBeamProcess(ctx, buildRoot) + if bp == nil { + return nil + } + if err := bp.Ready(ctx); err != nil { + return nil + } + + exports, err := bp.ErlangExports(ctx, module) + if err != nil { + return nil + } + + s.erlangRuntimeMu.Lock() + cache = s.erlangRuntimeCache[runtimeKey] + if cache != nil { + if cache.exports == nil { + cache.exports = make(map[string][]ErlangExport) + } + cache.exports[module] = exports + } + s.erlangRuntimeMu.Unlock() + return exports +} + func lineRange(line int) protocol.Range { return protocol.Range{ Start: protocol.Position{Line: uint32(line), Character: 0}, @@ -1003,6 +1396,7 @@ func (s *Server) ColorPresentation(ctx context.Context, params *protocol.ColorPr } func (s *Server) Completion(ctx context.Context, params *protocol.CompletionParams) (*protocol.CompletionList, error) { docURI := string(params.TextDocument.URI) + filePath := uriToPath(params.TextDocument.URI) text, ok := s.docs.Get(docURI) if !ok { @@ -1068,9 +1462,70 @@ func (s *Server) Completion(ctx context.Context, params *protocol.CompletionPara Start: protocol.Position{Line: uint32(lineNum), Character: uint32(prefixStartCol)}, End: protocol.Position{Line: uint32(lineNum), Character: uint32(col)}, } + inPipe := IsPipeContext(lines[lineNum], prefixStartCol) + + // Erlang module/function completions (prefix starts with ":") + if strings.HasPrefix(prefix, ":") { + s.startErlangModuleLoad(filePath) + if !s.erlangModulesAvailable(filePath) { + return nil, nil + } + moduleNames := s.erlangModuleNamesForFile(filePath) + if moduleNames == nil { + return nil, nil + } + erlPrefix := prefix[1:] // strip leading colon + var items []protocol.CompletionItem + + if afterDot || strings.Contains(erlPrefix, ".") { + // :module.func — complete exported functions + var erlModule, erlFuncPrefix string + if dotIdx := strings.IndexByte(erlPrefix, '.'); dotIdx >= 0 { + erlModule = erlPrefix[:dotIdx] + erlFuncPrefix = erlPrefix[dotIdx+1:] + } else { + erlModule = erlPrefix + } + if moduleNames[erlModule] { + for _, e := range s.getErlangExports(ctx, filePath, erlModule) { + if strings.HasPrefix(e.Function, erlFuncPrefix) { + item := protocol.CompletionItem{ + Label: e.Function, + Kind: protocol.CompletionItemKindFunction, + Detail: fmt.Sprintf(":%s.%s/%d", erlModule, e.Function, e.Arity), + } + applySnippet(&item, e.Function, e.Arity, e.Params, inPipe, s.snippetSupport) + items = append(items, item) + } + } + } + } else { + // :mod — complete module names + for name := range moduleNames { + if strings.HasPrefix(name, erlPrefix) { + items = append(items, protocol.CompletionItem{ + Label: ":" + name, + Kind: protocol.CompletionItemKindModule, + Detail: "Erlang module", + TextEdit: &protocol.TextEdit{ + Range: prefixRange, + NewText: ":" + name, + }, + }) + } + } + } + + if len(items) == 0 { + return nil, nil + } + return &protocol.CompletionList{ + IsIncomplete: len(items) >= 100, + Items: items, + }, nil + } moduleRef, funcPrefix := ExtractModuleAndFunction(prefix) - inPipe := IsPipeContext(lines[lineNum], prefixStartCol) // "Module.func." or "variable." — dot after a function call result or // map/struct field access. We have no type info to complete the result. @@ -1194,12 +1649,17 @@ func (s *Server) Completion(ctx context.Context, params *protocol.CompletionPara } } - for _, mod := range ExtractImports(text) { + imports := ExtractImports(text) + imports = append(imports, "Kernel") + for _, mod := range imports { results, err := s.store.ListModuleFunctions(mod, true) if err != nil { continue } for _, r := range results { + if s.snippetSupport && elixirFormSnippets[r.Function] != "" { + continue + } key := funcKey(r.Function, r.Arity) if strings.HasPrefix(r.Function, funcPrefix) && !seen[key] { seen[key] = true @@ -1241,6 +1701,21 @@ func (s *Server) Completion(ctx context.Context, params *protocol.CompletionPara }) } } + + if s.snippetSupport { + for name, snippet := range elixirFormSnippets { + if strings.HasPrefix(name, funcPrefix) && !seen[name] { + seen[name] = true + items = append(items, protocol.CompletionItem{ + Label: name, + Kind: protocol.CompletionItemKindKeyword, + Detail: "special form", + InsertText: snippet, + InsertTextFormat: protocol.InsertTextFormatSnippet, + }) + } + } + } } if len(items) == 0 { @@ -1767,14 +2242,29 @@ func funcKey(name string, arity int) string { return name + "/" + strconv.Itoa(arity) } +var elixirFormSnippets = map[string]string{ + "for": "for ${1:pattern} <- ${2:enumerable} do\n\t$0\nend", + "with": "with ${1:pattern} <- ${2:expression} do\n\t$0\nend", + "case": "case ${1:expression} do\n\t${2:pattern} ->\n\t\t$0\nend", + "cond": "cond do\n\t${1:condition} ->\n\t\t$0\nend", + "if": "if ${1:condition} do\n\t$0\nend", + "unless": "unless ${1:condition} do\n\t$0\nend", + "receive": "receive do\n\t${1:pattern} ->\n\t\t$0\nend", + "try": "try do\n\t$0\nrescue\n\t${1:exception} ->\n\t\t${2:handler}\nend", + "quote": "quote do\n\t$0\nend", + "fn": "fn ${1:args} -> $0 end", +} + func applySnippet(item *protocol.CompletionItem, name string, arity int, params string, inPipe bool, useSnippets bool) { item.Label = fmt.Sprintf("%s/%d", name, arity) item.FilterText = name snippetArity := arity snippetParams := params + paramStartIndex := 1 if inPipe && arity > 0 { snippetArity-- + paramStartIndex = 2 if snippetParams != "" { if commaIdx := strings.IndexByte(snippetParams, ','); commaIdx >= 0 { snippetParams = snippetParams[commaIdx+1:] @@ -1786,7 +2276,7 @@ func applySnippet(item *protocol.CompletionItem, name string, arity int, params if !useSnippets { if snippetArity > 0 { - item.InsertText = functionCallText(name, snippetArity, snippetParams) + item.InsertText = functionCallText(name, snippetArity, snippetParams, paramStartIndex) } else { item.InsertText = name + "()" } @@ -1795,33 +2285,36 @@ func applySnippet(item *protocol.CompletionItem, name string, arity int, params if snippetArity > 0 { item.InsertTextFormat = protocol.InsertTextFormatSnippet - item.InsertText = functionSnippet(name, snippetArity, snippetParams) + item.InsertText = functionSnippet(name, snippetArity, snippetParams, paramStartIndex) } else { item.InsertText = name + "()" } } -func functionSnippet(name string, arity int, params string) string { - return buildCallText(name, arity, params, true) +func functionSnippet(name string, arity int, params string, paramStartIndex int) string { + return buildCallText(name, arity, params, true, paramStartIndex) } -func functionCallText(name string, arity int, params string) string { - return buildCallText(name, arity, params, false) +func functionCallText(name string, arity int, params string, paramStartIndex int) string { + return buildCallText(name, arity, params, false, paramStartIndex) } -func buildCallText(name string, arity int, params string, snippet bool) string { +func buildCallText(name string, arity int, params string, snippet bool, paramStartIndex int) string { + if paramStartIndex < 1 { + paramStartIndex = 1 + } var paramNames []string if params != "" { paramNames = strings.Split(params, ",") } var args []string - for i := 1; i <= arity; i++ { - paramName := fmt.Sprintf("arg%d", i) - if i-1 < len(paramNames) { - paramName = paramNames[i-1] + for i := 0; i < arity; i++ { + paramName := fmt.Sprintf("arg%d", paramStartIndex+i) + if i < len(paramNames) { + paramName = paramNames[i] } if snippet { - args = append(args, fmt.Sprintf("${%d:%s}", i, paramName)) + args = append(args, fmt.Sprintf("${%d:%s}", i+1, paramName)) } else { args = append(args, paramName) } @@ -2890,6 +3383,12 @@ func (s *Server) Hover(ctx context.Context, params *protocol.HoverParams) (*prot return s.hoverFromFile(functionName, results[0]) } + // Fallback: bare identifier might be an Erlang built-in type or function + // (e.g. pos_integer, binary, term, length, is_atom) + if hover, _ := s.erlangHover(ctx, uriToPath(protocol.DocumentURI(docURI)), "erlang", functionName); hover != nil { + return hover, nil + } + return nil, nil } diff --git a/internal/lsp/server_test.go b/internal/lsp/server_test.go index 51568a4..da614e7 100644 --- a/internal/lsp/server_test.go +++ b/internal/lsp/server_test.go @@ -122,10 +122,10 @@ end } } -// waitFor polls condition every 10ms until it returns true or one second elapses. +// waitFor polls condition every 10ms until it returns true or two seconds elapse. func waitFor(t *testing.T, condition func() bool) { t.Helper() - deadline := time.Now().Add(time.Second) + deadline := time.Now().Add(2 * time.Second) for time.Now().Before(deadline) { if condition() { return @@ -500,6 +500,133 @@ end } } +func TestApplySnippet_PipeGenericParamNamesPreserveOriginalIndex(t *testing.T) { + var item protocol.CompletionItem + applySnippet(&item, "call", 3, "", true, true) + + if item.Label != "call/3" { + t.Fatalf("expected label call/3, got %q", item.Label) + } + if item.FilterText != "call" { + t.Fatalf("expected filter text call, got %q", item.FilterText) + } + if item.InsertText != "call(${1:arg2}, ${2:arg3})$0" { + t.Fatalf("expected piped generic snippet to start at arg2, got %q", item.InsertText) + } + if item.InsertTextFormat != protocol.InsertTextFormatSnippet { + t.Fatalf("expected snippet insert format, got %v", item.InsertTextFormat) + } +} + +func TestCompletion_ElixirFormSnippets(t *testing.T) { + server, cleanup := setupTestServer(t) + defer cleanup() + + uri := "file:///test.ex" + + tests := []struct { + prefix string + label string + snippet string + }{ + {"fo", "for", "for ${1:pattern} <- ${2:enumerable} do\n\t$0\nend"}, + {"wi", "with", "with ${1:pattern} <- ${2:expression} do\n\t$0\nend"}, + {"cas", "case", "case ${1:expression} do\n\t${2:pattern} ->\n\t\t$0\nend"}, + {"con", "cond", "cond do\n\t${1:condition} ->\n\t\t$0\nend"}, + {"i", "if", "if ${1:condition} do\n\t$0\nend"}, + {"unl", "unless", "unless ${1:condition} do\n\t$0\nend"}, + {"rec", "receive", "receive do\n\t${1:pattern} ->\n\t\t$0\nend"}, + {"tr", "try", "try do\n\t$0\nrescue\n\t${1:exception} ->\n\t\t${2:handler}\nend"}, + {"quo", "quote", "quote do\n\t$0\nend"}, + {"f", "fn", "fn ${1:args} -> $0 end"}, + } + + for _, tt := range tests { + t.Run(tt.label, func(t *testing.T) { + server.docs.Set(uri, " "+tt.prefix) + items := completionAt(t, server, uri, 0, uint32(2+len(tt.prefix))) + + var found bool + for _, item := range items { + if item.Label == tt.label { + found = true + if item.InsertText != tt.snippet { + t.Errorf("expected snippet %q, got %q", tt.snippet, item.InsertText) + } + if item.InsertTextFormat != protocol.InsertTextFormatSnippet { + t.Error("expected InsertTextFormatSnippet") + } + if item.Kind != protocol.CompletionItemKindKeyword { + t.Errorf("expected Keyword kind, got %v", item.Kind) + } + break + } + } + if !found { + t.Errorf("expected to find completion item %q", tt.label) + } + }) + } +} + +func TestCompletion_ElixirFormSnippets_NoDuplicateWithKernel(t *testing.T) { + server, cleanup := setupTestServer(t) + defer cleanup() + + indexFile(t, server.store, server.projectRoot, "lib/kernel.ex", `defmodule Kernel do + defmacro if(condition, clauses) do + :ok + end + + defmacro unless(condition, clauses) do + :ok + end +end +`) + + uri := "file:///test.ex" + server.docs.Set(uri, " i") + items := completionAt(t, server, uri, 0, 3) + + var count int + for _, item := range items { + if item.Label == "if" || item.Label == "if/2" { + count++ + } + } + if count != 1 { + t.Errorf("expected exactly 1 'if' completion, got %d", count) + } + + // The one we get should be the snippet form, not the function-call form + for _, item := range items { + if item.Label == "if" { + if item.Kind != protocol.CompletionItemKindKeyword { + t.Errorf("expected Keyword kind for 'if', got %v", item.Kind) + } + if item.InsertText != elixirFormSnippets["if"] { + t.Errorf("expected form snippet for 'if', got %q", item.InsertText) + } + } + } +} + +func TestCompletion_ElixirFormSnippets_NoSnippetSupport(t *testing.T) { + server, cleanup := setupTestServer(t) + defer cleanup() + server.snippetSupport = false + + uri := "file:///test.ex" + server.docs.Set(uri, " fo") + items := completionAt(t, server, uri, 0, 4) + + for _, item := range items { + if item.Label == "for" { + t.Error("form snippets should not appear when client lacks snippet support") + } + } +} + func TestCompletion_MultiArity(t *testing.T) { server, cleanup := setupTestServer(t) defer cleanup() @@ -1462,6 +1589,238 @@ end` } } +func TestCompletion_ErlangUsesFileBuildRootInMonorepo(t *testing.T) { + if _, err := exec.LookPath("mix"); err != nil { + t.Skip("mix not available in PATH") + } + + server, cleanup := setupTestServer(t) + defer cleanup() + defer server.closeBeams() + + appRoot := filepath.Join(server.projectRoot, "apps", "tiger") + if err := os.MkdirAll(filepath.Join(appRoot, "_build"), 0755); err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(filepath.Join(appRoot, "lib", "tiger"), 0755); err != nil { + t.Fatal(err) + } + + filePath := filepath.Join(appRoot, "lib", "tiger", "docusign.ex") + docURI := string(uri.File(filePath)) + + server.docs.Set(docURI, " :c") + _ = completionAt(t, server, docURI, 0, 4) + + waitFor(t, func() bool { + server.beamMu.Lock() + defer server.beamMu.Unlock() + return len(server.beams) > 0 + }) + waitFor(t, func() bool { + return server.erlangModulesAvailable(filePath) + }) + + server.docs.Set(docURI, " :code.") + items := completionAt(t, server, docURI, 0, 8) + if len(items) == 0 { + t.Fatal("expected Erlang exports for :code.") + } + if !hasCompletionItem(items, "all_loaded") { + t.Fatal("expected code exports to include all_loaded") + } + + server.beamMu.Lock() + defer server.beamMu.Unlock() + var buildRoots []string + for buildRoot := range server.beams { + buildRoots = append(buildRoots, buildRoot) + } + if len(server.beams) != 1 { + t.Fatalf("expected exactly 1 BEAM process, got %d", len(server.beams)) + } + if _, ok := server.beams[appRoot]; !ok { + t.Fatalf("expected BEAM for app build root %s, got %v", appRoot, buildRoots) + } + if _, ok := server.beams[server.projectRoot]; ok { + t.Fatalf("did not expect BEAM for project root %s", server.projectRoot) + } +} + +func TestDidOpen_WarmsErlangModulesInBackground(t *testing.T) { + if _, err := exec.LookPath("mix"); err != nil { + t.Skip("mix not available in PATH") + } + + server, cleanup := setupTestServer(t) + defer cleanup() + defer server.closeBeams() + + appRoot := filepath.Join(server.projectRoot, "apps", "tiger") + if err := os.MkdirAll(filepath.Join(appRoot, "_build"), 0755); err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(filepath.Join(appRoot, "lib", "tiger"), 0755); err != nil { + t.Fatal(err) + } + + filePath := filepath.Join(appRoot, "lib", "tiger", "docusign.ex") + docURI := string(uri.File(filePath)) + + err := server.DidOpen(context.Background(), &protocol.DidOpenTextDocumentParams{ + TextDocument: protocol.TextDocumentItem{ + URI: protocol.DocumentURI(docURI), + LanguageID: "elixir", + Version: 1, + Text: " :c", + }, + }) + if err != nil { + t.Fatal(err) + } + + waitFor(t, func() bool { + return server.erlangModulesAvailable(filePath) + }) + + items := completionAt(t, server, docURI, 0, 4) + if !hasCompletionItem(items, ":code") { + t.Fatal("expected background warmup to make :code available on first completion") + } +} + +func TestCompletion_ErlangFunctionSnippet(t *testing.T) { + if _, err := exec.LookPath("mix"); err != nil { + t.Skip("mix not available in PATH") + } + + server, cleanup := setupTestServer(t) + defer cleanup() + defer server.closeBeams() + + appRoot := filepath.Join(server.projectRoot, "apps", "tiger") + if err := os.MkdirAll(filepath.Join(appRoot, "_build"), 0755); err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(filepath.Join(appRoot, "lib", "tiger"), 0755); err != nil { + t.Fatal(err) + } + + filePath := filepath.Join(appRoot, "lib", "tiger", "docusign.ex") + docURI := string(uri.File(filePath)) + + err := server.DidOpen(context.Background(), &protocol.DidOpenTextDocumentParams{ + TextDocument: protocol.TextDocumentItem{ + URI: protocol.DocumentURI(docURI), + LanguageID: "elixir", + Version: 1, + Text: " :ets.", + }, + }) + if err != nil { + t.Fatal(err) + } + + waitFor(t, func() bool { + return server.erlangModulesAvailable(filePath) + }) + + items := completionAt(t, server, docURI, 0, 7) + + var found bool + for _, item := range items { + if item.Label == "new/2" { + found = true + if item.FilterText != "new" { + t.Errorf("new/2: expected filter text 'new', got %q", item.FilterText) + } + if item.InsertText != "new(${1:name}, ${2:options})$0" { + t.Errorf("new/2: expected snippet insert text, got %q", item.InsertText) + } + if item.InsertTextFormat != protocol.InsertTextFormatSnippet { + t.Errorf("new/2: expected snippet format, got %v", item.InsertTextFormat) + } + if item.Detail != ":ets.new/2" { + t.Errorf("new/2: expected detail :ets.new/2, got %q", item.Detail) + } + break + } + } + if !found { + t.Fatal("expected to find Erlang completion item new/2") + } +} + +func TestDidOpen_SharesErlangRuntimeCacheAcrossBuildRoots(t *testing.T) { + if _, err := exec.LookPath("mix"); err != nil { + t.Skip("mix not available in PATH") + } + + server, cleanup := setupTestServer(t) + defer cleanup() + defer server.closeBeams() + + appOne := filepath.Join(server.projectRoot, "apps", "tiger") + appTwo := filepath.Join(server.projectRoot, "apps", "lynx") + for _, appRoot := range []string{appOne, appTwo} { + if err := os.MkdirAll(filepath.Join(appRoot, "_build"), 0755); err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(filepath.Join(appRoot, "lib"), 0755); err != nil { + t.Fatal(err) + } + } + + fileOne := filepath.Join(appOne, "lib", "one.ex") + fileTwo := filepath.Join(appTwo, "lib", "two.ex") + + open := func(path, text string) { + t.Helper() + err := server.DidOpen(context.Background(), &protocol.DidOpenTextDocumentParams{ + TextDocument: protocol.TextDocumentItem{ + URI: protocol.DocumentURI(uri.File(path)), + LanguageID: "elixir", + Version: 1, + Text: text, + }, + }) + if err != nil { + t.Fatal(err) + } + } + + open(fileOne, " :c") + waitFor(t, func() bool { + return server.erlangModulesAvailable(fileOne) + }) + + open(fileTwo, " :c") + waitFor(t, func() bool { + return server.erlangModulesAvailable(fileTwo) + }) + + buildRootOne := server.erlangBuildRoot(fileOne) + buildRootTwo := server.erlangBuildRoot(fileTwo) + + server.erlangRuntimeMu.Lock() + defer server.erlangRuntimeMu.Unlock() + + stateOne := server.erlangBuildRoots[buildRootOne] + stateTwo := server.erlangBuildRoots[buildRootTwo] + if stateOne == nil || stateOne.runtimeKey == "" { + t.Fatalf("expected runtime key for %s", buildRootOne) + } + if stateTwo == nil || stateTwo.runtimeKey == "" { + t.Fatalf("expected runtime key for %s", buildRootTwo) + } + if stateOne.runtimeKey != stateTwo.runtimeKey { + t.Fatalf("expected shared runtime key, got %q and %q", stateOne.runtimeKey, stateTwo.runtimeKey) + } + if len(server.erlangRuntimeCache) != 1 { + t.Fatalf("expected exactly 1 Erlang runtime cache, got %d", len(server.erlangRuntimeCache)) + } +} + func TestDefinition_ModuleKeyword(t *testing.T) { server, cleanup := setupTestServer(t) defer cleanup() @@ -3026,6 +3385,149 @@ func TestFormatter_RestartAfterCrash(t *testing.T) { } } +func TestDidSave_FormatterConfigRestartDoesNotHoldBeamMuWhileClosing(t *testing.T) { + server, cleanup := setupTestServer(t) + defer cleanup() + + configPath := filepath.Join(server.projectRoot, ".formatter.exs") + if err := os.WriteFile(configPath, []byte("[]"), 0644); err != nil { + t.Fatal(err) + } + + cmd := exec.Command("sleep", "30") + stdin, err := cmd.StdinPipe() + if err != nil { + t.Fatal(err) + } + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + cmdDone := make(chan struct{}) + go func() { + _ = cmd.Wait() + close(cmdDone) + }() + defer func() { + _ = cmd.Process.Kill() + <-cmdDone + }() + + bp := &beamProcess{ + cmd: &commandHandle{process: cmd.Process, done: cmdDone}, + stdin: stdin, + ready: make(chan struct{}), + closed: make(chan struct{}), + } + bp.writeMu.Lock() + writeLocked := true + defer func() { + if writeLocked { + bp.writeMu.Unlock() + } + }() + + buildRoot := server.findBuildRoot(filepath.Dir(configPath)) + server.beams = map[string]*beamProcess{buildRoot: bp} + + done := make(chan error, 1) + go func() { + done <- server.DidSave(context.Background(), &protocol.DidSaveTextDocumentParams{ + TextDocument: protocol.TextDocumentIdentifier{URI: protocol.DocumentURI(uri.File(configPath))}, + }) + }() + + waitFor(t, func() bool { + if !server.beamMu.TryLock() { + return false + } + defer server.beamMu.Unlock() + _, ok := server.beams[buildRoot] + return !ok + }) + + select { + case err := <-done: + t.Fatalf("DidSave returned before Close was unblocked: %v", err) + default: + } + + if !server.beamMu.TryLock() { + t.Fatal("expected beamMu to be released while Close was blocked") + } + server.beamMu.Unlock() + + bp.writeMu.Unlock() + writeLocked = false + + select { + case err := <-done: + if err != nil { + t.Fatal(err) + } + case <-time.After(2 * time.Second): + t.Fatal("DidSave did not finish after Close was unblocked") + } +} + +func TestDidSave_FormatterConfigRestartPicksUpUpdatedConfig(t *testing.T) { + if _, err := exec.LookPath("mix"); err != nil { + t.Skip("mix not available in PATH") + } + + server, cleanup := setupTestServer(t) + defer cleanup() + + if err := os.WriteFile(filepath.Join(server.projectRoot, "mix.exs"), []byte(""), 0644); err != nil { + t.Fatal(err) + } + + configPath := filepath.Join(server.projectRoot, ".formatter.exs") + if err := os.WriteFile(configPath, []byte("[]"), 0644); err != nil { + t.Fatal(err) + } + + filePath := filepath.Join(server.projectRoot, "lib", "test.ex") + if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil { + t.Fatal(err) + } + docURI := string(uri.File(filePath)) + input := "defmodule Test do\n def hello, do: :world\nend\n" + server.docs.Set(docURI, input) + + edits, err := server.Formatting(context.Background(), &protocol.DocumentFormattingParams{ + TextDocument: protocol.TextDocumentIdentifier{URI: protocol.DocumentURI(docURI)}, + }) + if err != nil { + t.Fatal(err) + } + if edits != nil { + t.Fatalf("expected initial formatting to match input, got %#v", edits) + } + + if err := os.WriteFile(configPath, []byte("[force_do_end_blocks: true]\n"), 0644); err != nil { + t.Fatal(err) + } + if err := server.DidSave(context.Background(), &protocol.DidSaveTextDocumentParams{ + TextDocument: protocol.TextDocumentIdentifier{URI: protocol.DocumentURI(uri.File(configPath))}, + }); err != nil { + t.Fatal(err) + } + + server.docs.Set(docURI, input) + edits, err = server.Formatting(context.Background(), &protocol.DocumentFormattingParams{ + TextDocument: protocol.TextDocumentIdentifier{URI: protocol.DocumentURI(docURI)}, + }) + if err != nil { + t.Fatal(err) + } + if edits == nil { + t.Fatal("expected formatting edits after .formatter.exs change") + } + if !strings.Contains(edits[0].NewText, "def hello do") { + t.Fatalf("expected updated formatter config to force do/end blocks, got:\n%s", edits[0].NewText) + } +} + func TestFormatter_WillSaveWaitUntil(t *testing.T) { server, cleanup := setupTestServer(t) defer cleanup() From 8a98b184f448393f61420567697bf2fcf746bbaf Mon Sep 17 00:00:00 2001 From: Jesse Herrick Date: Fri, 1 May 2026 13:39:24 -0600 Subject: [PATCH 7/7] Improve LSP token reuse and BEAM formatter robustness Reuse cached token streams for completion and common Elixir source queries, including token-aware completion contexts that skip strings and comments. Make the persistent BEAM formatter safer by synchronizing stderr capture and restarting when formatter config changes outside DidSave. --- internal/lsp/elixir.go | 259 +++++++++++++++++++++++++++------ internal/lsp/elixir_test.go | 110 ++++++++++++++ internal/lsp/formatter.go | 120 ++++++++++++--- internal/lsp/formatter_test.go | 2 +- internal/lsp/server.go | 124 ++++++++++------ internal/lsp/server_test.go | 90 ++++++++++++ 6 files changed, 593 insertions(+), 112 deletions(-) diff --git a/internal/lsp/elixir.go b/internal/lsp/elixir.go index cb9005b..783585e 100644 --- a/internal/lsp/elixir.go +++ b/internal/lsp/elixir.go @@ -122,19 +122,178 @@ func (tf *TokenizedFile) ExtractAliasesInScope(targetLine int) map[string]string return extractAliasesFromTokens(tf.source, tf.tokens, targetLine) } +// ExtractAliases parses all alias declarations from the tokenized file. +func (tf *TokenizedFile) ExtractAliases() map[string]string { + return extractAliasesFromTokens(tf.source, tf.tokens, -1) +} + // ExtractImports returns all import declarations from the tokenized file. func (tf *TokenizedFile) ExtractImports() []string { - var imports []string - for i := 0; i < tf.n; i++ { - if tf.tokens[i].Kind == parser.TokImport { - j := tokNextSig(tf.tokens, tf.n, i+1) - mod, _ := tokCollectModuleName(tf.source, tf.tokens, tf.n, j) - if mod != "" { - imports = append(imports, mod) + return extractImportsFromTokens(tf.source, tf.tokens) +} + +// ExtractUses returns module names from all `use Module` declarations. +func (tf *TokenizedFile) ExtractUses() []string { + return extractUsesFromTokens(tf.source, tf.tokens) +} + +// ExtractUsesWithOpts parses all `use Module` declarations with keyword opts. +func (tf *TokenizedFile) ExtractUsesWithOpts(aliases map[string]string) []UseCall { + return extractUsesWithOptsFromTokens(tf.source, tf.tokens, aliases) +} + +// FindBufferFunctions scans the tokenized file for all function and type definitions. +func (tf *TokenizedFile) FindBufferFunctions() []BufferFunction { + return findBufferFunctionsFromTokens(tf.source, tf.tokens) +} + +// ExtractAliasBlockParent detects whether targetLine is inside a multi-line alias block. +func (tf *TokenizedFile) ExtractAliasBlockParent(targetLine int) (string, bool) { + return extractAliasBlockParentFromTokens(tf.source, tf.tokens, targetLine) +} + +// CompletionContext describes the token-aware completion prefix at the cursor. +type CompletionContext struct { + Prefix string + AfterDot bool + StartCol int +} + +// Empty returns true if no completion should be offered at the cursor. +func (c CompletionContext) Empty() bool { + return c.Prefix == "" && !c.AfterDot +} + +// CompletionContextAtCursor extracts the completion prefix at the given 0-based +// line/column using the cached token stream. Unlike ExtractCompletionContext, +// this ignores strings/comments/heredocs and treats `::` distinctly from `:atom`. +func (tf *TokenizedFile) CompletionContextAtCursor(line, col int) CompletionContext { + return CompletionContextAtCursor(tf.tokens, tf.source, tf.lineStarts, line, col) +} + +// CompletionContextAtCursor extracts the token-aware completion context at the +// given 0-based line/column. +func CompletionContextAtCursor(tokens []parser.Token, source []byte, lineStarts []int, line, col int) CompletionContext { + if line < 0 || line >= len(lineStarts) || col <= 0 { + return CompletionContext{} + } + + lineStart := lineStarts[line] + lineEnd := len(source) + if line+1 < len(lineStarts) { + lineEnd = lineStarts[line+1] - 1 // exclude the newline byte + } + maxCol := lineEnd - lineStart + if maxCol < 0 { + maxCol = 0 + } + if col > maxCol { + col = maxCol + } + + offset := parser.LineColToOffset(lineStarts, line, col) + if offset <= lineStart { + return CompletionContext{} + } + + idx := parser.TokenAtOffset(tokens, offset-1) + if idx < 0 { + return CompletionContext{} + } + + tok := tokens[idx] + if tok.Kind == parser.TokDot { + exprIdx := idx - 1 + if exprIdx < 0 || !isCompletionSegmentToken(tokens[exprIdx].Kind) { + return CompletionContext{} + } + startIdx := completionChainStart(tokens, exprIdx) + prefix := buildCompletionPrefix(source, tokens, startIdx, exprIdx, tok.Start) + if prefix == "" { + return CompletionContext{} + } + return CompletionContext{ + Prefix: prefix, + AfterDot: true, + StartCol: tokens[startIdx].Start - lineStart, + } + } + + if !isCompletionSegmentToken(tok.Kind) { + return CompletionContext{} + } + + startIdx := completionChainStart(tokens, idx) + prefix := buildCompletionPrefix(source, tokens, startIdx, idx, offset) + if prefix == "" { + return CompletionContext{} + } + return CompletionContext{ + Prefix: prefix, + AfterDot: false, + StartCol: tokens[startIdx].Start - lineStart, + } +} + +func completionChainStart(tokens []parser.Token, idx int) int { + startIdx := idx + for startIdx >= 2 { + dotIdx := startIdx - 1 + prevIdx := startIdx - 2 + if tokens[dotIdx].Kind == parser.TokDot && isCompletionModuleToken(tokens[prevIdx].Kind) { + startIdx = prevIdx + continue + } + break + } + return startIdx +} + +func buildCompletionPrefix(source []byte, tokens []parser.Token, startIdx, endIdx, endOffset int) string { + var b strings.Builder + for i := startIdx; i <= endIdx; i++ { + tok := tokens[i] + switch tok.Kind { + case parser.TokDot: + b.WriteByte('.') + default: + if !isCompletionSegmentToken(tok.Kind) { + return "" } + end := tok.End + if i == endIdx && endOffset < end { + end = endOffset + } + if end <= tok.Start { + return "" + } + b.Write(source[tok.Start:end]) } } - return imports + return b.String() +} + +func isCompletionModuleToken(k parser.TokenKind) bool { + return k == parser.TokModule || k == parser.TokAtom +} + +func isCompletionFunctionToken(k parser.TokenKind) bool { + switch k { + case parser.TokIdent, + parser.TokDefmodule, parser.TokDefprotocol, parser.TokDefimpl, + parser.TokDefstruct, parser.TokDefexception, parser.TokDefdelegate, + parser.TokDefmacro, parser.TokDefmacrop, parser.TokDefguard, + parser.TokDefguardp, parser.TokDefp, parser.TokDef, + parser.TokAlias, parser.TokImport, parser.TokUse, parser.TokRequire, + parser.TokDo, parser.TokEnd, parser.TokFn, parser.TokWhen: + return true + default: + return false + } +} + +func isCompletionSegmentToken(k parser.TokenKind) bool { + return isCompletionModuleToken(k) || isCompletionFunctionToken(k) } func isExprChar(b byte) bool { @@ -416,16 +575,20 @@ func ExtractAliasBlockParent(lines []string, targetLine int) (string, bool) { return "", false } - // Use tokenizer for accurate parsing source := []byte(strings.Join(lines, "\n")) - tokens := parser.Tokenize(source) + return extractAliasBlockParentFromTokens(source, parser.Tokenize(source), targetLine) +} + +func extractAliasBlockParentFromTokens(source []byte, tokens []parser.Token, targetLine int) (string, bool) { n := len(tokens) + if targetLine < 0 || n == 0 { + return "", false + } - // targetLine is 0-based; token.Line is 1-based targetLine1 := targetLine + 1 // Find the token position for the target line - targetIdx := 0 + targetIdx := n - 1 for i, tok := range tokens { if tok.Line >= targetLine1 { targetIdx = i @@ -693,9 +856,11 @@ type BufferFunction struct { // Private types (@typep) are included since they are accessible within the same file. func FindBufferFunctions(text string) []BufferFunction { source := []byte(text) - tokens := parser.Tokenize(source) - n := len(tokens) + return findBufferFunctionsFromTokens(source, parser.Tokenize(source)) +} +func findBufferFunctionsFromTokens(source []byte, tokens []parser.Token) []BufferFunction { + n := len(tokens) seen := make(map[string]bool) var results []BufferFunction @@ -723,15 +888,16 @@ func FindBufferFunctions(text string) []BufferFunction { minArity := maxArity - defaultCount for arity := minArity; arity <= maxArity; arity++ { key := name + "/" + strconv.Itoa(arity) - if !seen[key] { - seen[key] = true - results = append(results, BufferFunction{ - Name: name, - Arity: arity, - Kind: kind, - Params: parser.JoinParams(paramNames, arity), - }) + if seen[key] { + continue } + seen[key] = true + results = append(results, BufferFunction{ + Name: name, + Arity: arity, + Kind: kind, + Params: parser.JoinParams(paramNames, arity), + }) } case parser.TokAttrType: @@ -754,10 +920,11 @@ func FindBufferFunctions(text string) []BufferFunction { arity, _, _, _ = parser.CollectParams(source, tokens, n, pj) } key := name + "/" + strconv.Itoa(arity) - if !seen[key] { - seen[key] = true - results = append(results, BufferFunction{Name: name, Arity: arity, Kind: kind}) + if seen[key] { + continue } + seen[key] = true + results = append(results, BufferFunction{Name: name, Arity: arity, Kind: kind}) } } return results @@ -938,16 +1105,20 @@ var ( // Returns a slice of full module names. func ExtractImports(text string) []string { source := []byte(text) - tokens := parser.Tokenize(source) + return extractImportsFromTokens(source, parser.Tokenize(source)) +} + +func extractImportsFromTokens(source []byte, tokens []parser.Token) []string { n := len(tokens) var imports []string for i := 0; i < n; i++ { - if tokens[i].Kind == parser.TokImport { - j := tokNextSig(tokens, n, i+1) - mod, _ := tokCollectModuleName(source, tokens, n, j) - if mod != "" { - imports = append(imports, mod) - } + if tokens[i].Kind != parser.TokImport { + continue + } + j := tokNextSig(tokens, n, i+1) + mod, _ := tokCollectModuleName(source, tokens, n, j) + if mod != "" { + imports = append(imports, mod) } } return imports @@ -1132,16 +1303,20 @@ func parseHelperQuoteBlock(lines []string, helperName string, fileAliases map[st // ExtractUses returns module names from all `use Module` declarations. func ExtractUses(text string) []string { source := []byte(text) - tokens := parser.Tokenize(source) + return extractUsesFromTokens(source, parser.Tokenize(source)) +} + +func extractUsesFromTokens(source []byte, tokens []parser.Token) []string { n := len(tokens) var uses []string for i := 0; i < n; i++ { - if tokens[i].Kind == parser.TokUse { - j := tokNextSig(tokens, n, i+1) - mod, _ := tokCollectModuleName(source, tokens, n, j) - if mod != "" { - uses = append(uses, mod) - } + if tokens[i].Kind != parser.TokUse { + continue + } + j := tokNextSig(tokens, n, i+1) + mod, _ := tokCollectModuleName(source, tokens, n, j) + if mod != "" { + uses = append(uses, mod) } } return uses @@ -1158,7 +1333,10 @@ type UseCall struct { // provided map. Handles opts spanning multiple lines via the tokenizer. func ExtractUsesWithOpts(text string, aliases map[string]string) []UseCall { source := []byte(text) - tokens := parser.Tokenize(source) + return extractUsesWithOptsFromTokens(source, parser.Tokenize(source), aliases) +} + +func extractUsesWithOptsFromTokens(source []byte, tokens []parser.Token, aliases map[string]string) []UseCall { n := len(tokens) var calls []UseCall @@ -1173,7 +1351,6 @@ func ExtractUsesWithOpts(text string, aliases map[string]string) []UseCall { } module := parser.ResolveModuleRef(modName, aliases, "") - // Check for comma after module name → keyword opts follow nk := tokNextSig(tokens, n, k) if nk < n && tokens[nk].Kind == parser.TokComma { opts := tokCollectKeywordModuleOpts(source, tokens, n, nk+1, aliases) diff --git a/internal/lsp/elixir_test.go b/internal/lsp/elixir_test.go index b0d05d8..1f33c3d 100644 --- a/internal/lsp/elixir_test.go +++ b/internal/lsp/elixir_test.go @@ -280,6 +280,116 @@ func TestExpressionAtCursor(t *testing.T) { } } +func TestCompletionContextAtCursor(t *testing.T) { + tests := []struct { + name string + code string + line int + col int + wantPrefix string + wantAfterDot bool + wantStartCol int + }{ + { + name: "module prefix", + code: " MyApp.Han", + line: 0, + col: 11, + wantPrefix: "MyApp.Han", + wantAfterDot: false, + wantStartCol: 2, + }, + { + name: "after dot", + code: " Foo.", + line: 0, + col: 6, + wantPrefix: "Foo", + wantAfterDot: true, + wantStartCol: 2, + }, + { + name: "function prefix after dot", + code: " Foo.ba", + line: 0, + col: 8, + wantPrefix: "Foo.ba", + wantAfterDot: false, + wantStartCol: 2, + }, + { + name: "mid-word cursor truncates current token", + code: " Enum.map_reduce", + line: 0, + col: 10, + wantPrefix: "Enum.map", + wantAfterDot: false, + wantStartCol: 2, + }, + { + name: "erlang module prefix", + code: " :lis", + line: 0, + col: 6, + wantPrefix: ":lis", + wantAfterDot: false, + wantStartCol: 2, + }, + { + name: "double colon does not create atom prefix", + code: " value::foo", + line: 0, + col: 12, + wantPrefix: "foo", + wantAfterDot: false, + wantStartCol: 9, + }, + { + name: "string is ignored", + code: ` "MyApp.Acc"`, + line: 0, + col: 12, + wantPrefix: "", + wantAfterDot: false, + wantStartCol: 0, + }, + { + name: "comment is ignored", + code: " # MyApp.Acc", + line: 0, + col: 13, + wantPrefix: "", + wantAfterDot: false, + wantStartCol: 0, + }, + { + name: "heredoc is ignored", + code: " \"\"\"\n MyApp.Acc\n \"\"\"", + line: 1, + col: 11, + wantPrefix: "", + wantAfterDot: false, + wantStartCol: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tokens, source, lineStarts := tokenize(tt.code) + ctx := CompletionContextAtCursor(tokens, source, lineStarts, tt.line, tt.col) + if ctx.Prefix != tt.wantPrefix { + t.Errorf("Prefix = %q, want %q", ctx.Prefix, tt.wantPrefix) + } + if ctx.AfterDot != tt.wantAfterDot { + t.Errorf("AfterDot = %v, want %v", ctx.AfterDot, tt.wantAfterDot) + } + if ctx.StartCol != tt.wantStartCol { + t.Errorf("StartCol = %d, want %d", ctx.StartCol, tt.wantStartCol) + } + }) + } +} + func TestFullExpressionAtCursor(t *testing.T) { code := " Foo.Bar.baz(123)" tokens, source, lineStarts := tokenize(code) diff --git a/internal/lsp/formatter.go b/internal/lsp/formatter.go index 9e523e9..550e250 100644 --- a/internal/lsp/formatter.go +++ b/internal/lsp/formatter.go @@ -56,20 +56,22 @@ const ( ) type beamProcess struct { - cmd *commandHandle - stdin io.WriteCloser - stdout io.ReadCloser - stderr *bytes.Buffer // rolling stderr capture for crash diagnostics - writeMu sync.Mutex - pendingMu sync.Mutex - pending map[uint32]chan beamResponse - nextReqID uint32 - startedAt time.Time // when the process was launched - ready chan struct{} // closed when the BEAM has sent the ready signal - startErr error // non-nil if startup failed; set before ready is closed - startOnce sync.Once - closed chan struct{} // closed by Close(); makes alive() return false immediately - notify func(beamNotification) + cmd *commandHandle + stdin io.WriteCloser + stdout io.ReadCloser + stderr *stderrCapture // rolling stderr capture for crash diagnostics + writeMu sync.Mutex + pendingMu sync.Mutex + pending map[uint32]chan beamResponse + nextReqID uint32 + formatterConfigMu sync.Mutex + formatterConfigStamps map[string]fileStamp + startedAt time.Time // when the process was launched + ready chan struct{} // closed when the BEAM has sent the ready signal + startErr error // non-nil if startup failed; set before ready is closed + startOnce sync.Once + closed chan struct{} // closed by Close(); makes alive() return false immediately + notify func(beamNotification) } type beamResponse struct { @@ -83,6 +85,58 @@ type beamNotification struct { payload []byte } +type stderrCapture struct { + mu sync.Mutex + buf []byte + max int +} + +func newStderrCapture() *stderrCapture { + return &stderrCapture{max: 4096} +} + +func (c *stderrCapture) Write(p []byte) (int, error) { + c.mu.Lock() + defer c.mu.Unlock() + + c.buf = append(c.buf, p...) + if len(c.buf) > c.max { + c.buf = append([]byte(nil), c.buf[len(c.buf)-c.max:]...) + } + return len(p), nil +} + +func (c *stderrCapture) String() string { + c.mu.Lock() + defer c.mu.Unlock() + return string(append([]byte(nil), c.buf...)) +} + +func (c *stderrCapture) Tail(n int) string { + c.mu.Lock() + defer c.mu.Unlock() + + b := c.buf + if len(b) > n { + b = b[len(b)-n:] + } + return string(append([]byte(nil), b...)) +} + +type fileStamp struct { + exists bool + mtime int64 + size int64 +} + +func statFileStamp(path string) fileStamp { + info, err := os.Stat(path) + if err != nil { + return fileStamp{} + } + return fileStamp{exists: true, mtime: info.ModTime().UnixNano(), size: info.Size()} +} + // commandHandle wraps the process so we can check liveness. type commandHandle struct { process *os.Process @@ -106,11 +160,7 @@ func (bp *beamProcess) recentStderr() string { if bp.stderr == nil { return "" } - b := bp.stderr.Bytes() - if len(b) > 512 { - b = b[len(b)-512:] - } - return strings.TrimSpace(string(b)) + return strings.TrimSpace(bp.stderr.Tail(512)) } // wrapError annotates a read/write error with recent stderr output if the @@ -127,6 +177,23 @@ func (bp *beamProcess) wrapError(err error) error { return err } +func (bp *beamProcess) formatterConfigChanged(formatterExs string) bool { + stamp := statFileStamp(formatterExs) + + bp.formatterConfigMu.Lock() + defer bp.formatterConfigMu.Unlock() + + if bp.formatterConfigStamps == nil { + bp.formatterConfigStamps = make(map[string]fileStamp) + } + prev, ok := bp.formatterConfigStamps[formatterExs] + if ok && prev != stamp { + return true + } + bp.formatterConfigStamps[formatterExs] = stamp + return false +} + // Ready blocks until the process has finished startup. Returns startErr if // the BEAM failed to initialize, or ctx.Err() if the caller gives up first. func (bp *beamProcess) Ready(ctx context.Context) error { @@ -651,8 +718,8 @@ func (s *Server) startBeamProcess(buildRoot string) (*beamProcess, error) { if err != nil { return nil, err } - var stderrBuf bytes.Buffer - cmd.Stderr = io.MultiWriter(os.Stderr, &stderrBuf) + stderrBuf := newStderrCapture() + cmd.Stderr = io.MultiWriter(os.Stderr, stderrBuf) if err := cmd.Start(); err != nil { return nil, fmt.Errorf("start BEAM: %w", err) @@ -670,7 +737,7 @@ func (s *Server) startBeamProcess(buildRoot string) (*beamProcess, error) { cmd: handle, stdin: stdin, stdout: stdout, - stderr: &stderrBuf, + stderr: stderrBuf, pending: make(map[uint32]chan beamResponse), startedAt: time.Now(), ready: make(chan struct{}), @@ -787,6 +854,15 @@ func (s *Server) formatContent(ctx context.Context, mixRoot, path, content strin log.Printf("Formatting: BEAM process unavailable, falling back to mix format") return s.formatWithMixFormat(ctx, mixRoot, path, content) } + if bp.formatterConfigChanged(formatterExs) { + s.evictBeam(bp, fmt.Sprintf("formatter config changed: %s", formatterExs)) + bp = s.getBeamProcess(ctx, buildRoot) + if bp == nil { + log.Printf("Formatting: BEAM process unavailable after formatter config change, falling back to mix format") + return s.formatWithMixFormat(ctx, mixRoot, path, content) + } + _ = bp.formatterConfigChanged(formatterExs) + } // Check if already ready (non-blocking) select { diff --git a/internal/lsp/formatter_test.go b/internal/lsp/formatter_test.go index 05ebd8f..fb8d873 100644 --- a/internal/lsp/formatter_test.go +++ b/internal/lsp/formatter_test.go @@ -72,7 +72,7 @@ func newTestBeamProcess(stdin io.WriteCloser, stdout io.ReadCloser, notify func( }, stdin: stdin, stdout: stdout, - stderr: &bytes.Buffer{}, + stderr: newStderrCapture(), pending: make(map[uint32]chan beamResponse), ready: make(chan struct{}), closed: make(chan struct{}), diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 2f00296..4f1d83e 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -488,6 +488,21 @@ func (s *Server) DidClose(ctx context.Context, params *protocol.DidCloseTextDocu return nil } +func (s *Server) restartBeamForFormatterConfig(path string) { + buildRoot := s.findBuildRoot(filepath.Dir(path)) + var bp *beamProcess + s.beamMu.Lock() + if existing, ok := s.beams[buildRoot]; ok { + delete(s.beams, buildRoot) + bp = existing + } + s.beamMu.Unlock() + if bp != nil { + bp.Close() + log.Printf("BEAM: restarting for %s (.formatter.exs changed)", buildRoot) + } +} + func (s *Server) DidSave(ctx context.Context, params *protocol.DidSaveTextDocumentParams) error { path := uriToPath(params.TextDocument.URI) if path == "" { @@ -497,18 +512,7 @@ func (s *Server) DidSave(ctx context.Context, params *protocol.DidSaveTextDocume // Restart the BEAM process when .formatter.exs changes so the new // config is picked up on the next format request. if filepath.Base(path) == ".formatter.exs" { - buildRoot := s.findBuildRoot(filepath.Dir(path)) - var bp *beamProcess - s.beamMu.Lock() - if existing, ok := s.beams[buildRoot]; ok { - delete(s.beams, buildRoot) - bp = existing - } - s.beamMu.Unlock() - if bp != nil { - bp.Close() - log.Printf("BEAM: restarting for %s (.formatter.exs changed)", buildRoot) - } + s.restartBeamForFormatterConfig(path) return nil } @@ -608,13 +612,13 @@ func (s *Server) Definition(ctx context.Context, params *protocol.DefinitionPara moduleRef, functionName := ExtractModuleAndFunction(expr) if moduleRef != "" { - if aliasParent, inBlock := ExtractAliasBlockParent(lines, lineNum); inBlock { + if aliasParent, inBlock := tf.ExtractAliasBlockParent(lineNum); inBlock { moduleRef = aliasParent + "." + moduleRef } } aliases := tf.ExtractAliasesInScope(lineNum) - s.mergeAliasesFromUse(text, aliases) + s.mergeAliasesFromUseTokenized(tf, aliases) s.debugf("Definition: expr=%q module=%q function=%q", expr, moduleRef, functionName) // Bare identifier — check variable first (cheap tree-sitter lookup), then functions @@ -636,7 +640,7 @@ func (s *Server) Definition(ctx context.Context, params *protocol.DefinitionPara } currentModule := tf.FirstDefmodule() - fullModule := s.resolveBareFunctionModule(uriToPath(protocol.DocumentURI(docURI)), text, lines, lineNum, functionName, aliases) + fullModule := s.resolveBareFunctionModule(uriToPath(protocol.DocumentURI(docURI)), text, tf, lineNum, functionName, aliases) s.debugf("Definition: resolved bare %q -> %q", functionName, fullModule) if fullModule == "" { s.debugf("Definition: could not resolve bare function %q", functionName) @@ -1256,7 +1260,7 @@ func (s *Server) CodeAction(ctx context.Context, params *protocol.CodeActionPara } aliases := tf.ExtractAliasesInScope(lineNum) - s.mergeAliasesFromUse(text, aliases) + s.mergeAliasesFromUseTokenized(tf, aliases) // Check if the first segment is already aliased — if so, the reference // already resolves and no code action is needed. @@ -1411,10 +1415,16 @@ func (s *Server) Completion(ctx context.Context, params *protocol.CompletionPara return nil, nil } - prefix, afterDot, prefixStartCol := ExtractCompletionContext(lines[lineNum], col) + tf := s.docs.GetTokenizedFile(docURI) + if tf == nil { + tf = NewTokenizedFile(text) + } + + completionCtx := tf.CompletionContextAtCursor(lineNum, col) + prefix, afterDot, prefixStartCol := completionCtx.Prefix, completionCtx.AfterDot, completionCtx.StartCol // Inside a multi-line alias block: complete child module segments under the parent. - if aliasParent, inBlock := ExtractAliasBlockParent(lines, lineNum); inBlock { + if aliasParent, inBlock := tf.ExtractAliasBlockParent(lineNum); inBlock { searchParent := aliasParent segmentPrefix := prefix labelPrefix := "" @@ -1536,8 +1546,8 @@ func (s *Server) Completion(ctx context.Context, params *protocol.CompletionPara var items []protocol.CompletionItem if moduleRef != "" && (afterDot || funcPrefix != "") { - aliases := ExtractAliases(text) - s.mergeAliasesFromUse(text, aliases) + aliases := tf.ExtractAliases() + s.mergeAliasesFromUseTokenized(tf, aliases) resolved := resolveModule(moduleRef, aliases) results, err := s.store.ListModuleFunctions(resolved, true) if err != nil { @@ -1573,8 +1583,8 @@ func (s *Server) Completion(ctx context.Context, params *protocol.CompletionPara } } } else if moduleRef != "" { - aliases := ExtractAliases(text) - s.mergeAliasesFromUse(text, aliases) + aliases := tf.ExtractAliases() + s.mergeAliasesFromUseTokenized(tf, aliases) seenModules := make(map[string]bool) addModuleItem := func(label, detail string) { @@ -1635,7 +1645,7 @@ func (s *Server) Completion(ctx context.Context, params *protocol.CompletionPara } else if funcPrefix != "" { seen := make(map[string]bool) - for _, bf := range FindBufferFunctions(text) { + for _, bf := range tf.FindBufferFunctions() { key := funcKey(bf.Name, bf.Arity) if strings.HasPrefix(bf.Name, funcPrefix) && !seen[key] { seen[key] = true @@ -1649,7 +1659,7 @@ func (s *Server) Completion(ctx context.Context, params *protocol.CompletionPara } } - imports := ExtractImports(text) + imports := tf.ExtractImports() imports = append(imports, "Kernel") for _, mod := range imports { results, err := s.store.ListModuleFunctions(mod, true) @@ -1679,10 +1689,10 @@ func (s *Server) Completion(ctx context.Context, params *protocol.CompletionPara } // Check use-injected imports and inline defs (including transitive use chains) - aliases := ExtractAliases(text) - s.mergeAliasesFromUse(text, aliases) + aliases := tf.ExtractAliases() + s.mergeAliasesFromUseTokenized(tf, aliases) visitedCompletion := make(map[string]bool) - for _, usedModule := range ExtractUses(text) { + for _, usedModule := range tf.ExtractUses() { s.addCompletionsFromUsing(resolveModule(usedModule, aliases), funcPrefix, seen, &items, visitedCompletion, inPipe, s.snippetSupport) } @@ -2119,15 +2129,19 @@ func (s *Server) addCompletionsFromUsing(moduleName, funcPrefix string, seen map // resolveBareFunctionModule finds the module that defines a bare function name. // Mirrors the go-to-definition priority: current file modules → imports → use chains → Kernel. // Callers should pass pre-computed aliases to avoid redundant ExtractAliases scans. -func (s *Server) resolveBareFunctionModule(filePath, text string, lines []string, lineNum int, functionName string, aliases map[string]string) string { +func (s *Server) resolveBareFunctionModule(filePath, text string, tf *TokenizedFile, lineNum int, functionName string, aliases map[string]string) string { // Check all modules in the current file with a single query, preferring // the one closest to the cursor line (handles sibling nested modules). if mod, ok := s.store.LookupFunctionInFile(filePath, functionName, lineNum+1); ok { return mod } + if tf == nil { + tf = NewTokenizedFile(text) + } + // Explicit imports (direct definitions only — fast store lookup) - imports := ExtractImports(text) + imports := tf.ExtractImports() for _, mod := range imports { if results, err := s.store.LookupFunction(mod, functionName); err == nil && len(results) > 0 { return mod @@ -2136,7 +2150,7 @@ func (s *Server) resolveBareFunctionModule(filePath, text string, lines []string // Use chains — use opts-aware resolution so `import unquote(mod)` patterns // resolve to the consumer-provided module rather than always using the default. - for _, uc := range ExtractUsesWithOpts(text, aliases) { + for _, uc := range tf.ExtractUsesWithOpts(aliases) { if mod := s.resolveModuleViaUseChainWithOpts(uc.Module, functionName, uc.Opts, map[string]bool{}); mod != "" { return mod } @@ -2174,7 +2188,17 @@ func resolveModule(moduleRef string, aliases map[string]string) string { // declarations in the file. For example, if the file has `use MyApp.Schema` // and MyApp.Schema.__using__ contains `alias MyApp.Repo`, then Repo is added. func (s *Server) mergeAliasesFromUse(text string, aliases map[string]string) { - useCalls := ExtractUsesWithOpts(text, aliases) + s.mergeAliasesFromUseCalls(ExtractUsesWithOpts(text, aliases), aliases) +} + +func (s *Server) mergeAliasesFromUseTokenized(tf *TokenizedFile, aliases map[string]string) { + if tf == nil { + return + } + s.mergeAliasesFromUseCalls(tf.ExtractUsesWithOpts(aliases), aliases) +} + +func (s *Server) mergeAliasesFromUseCalls(useCalls []UseCall, aliases map[string]string) { visited := make(map[string]bool) for _, uc := range useCalls { s.mergeAliasesFromUsingEntry(uc.Module, aliases, visited) @@ -2583,6 +2607,10 @@ func (s *Server) DidChangeWatchedFiles(ctx context.Context, params *protocol.Did if path == "" { continue } + if filepath.Base(path) == ".formatter.exs" { + s.restartBeamForFormatterConfig(path) + continue + } switch change.Type { case protocol.FileChangeTypeCreated, protocol.FileChangeTypeChanged: go func(filePath string) { @@ -3341,13 +3369,13 @@ func (s *Server) Hover(ctx context.Context, params *protocol.HoverParams) (*prot // Inside a multi-line alias block like "alias MyModule.{ Something }", // prepend the parent so "Something" resolves to "MyModule.Something". if moduleRef != "" { - if aliasParent, inBlock := ExtractAliasBlockParent(lines, lineNum); inBlock { + if aliasParent, inBlock := tf.ExtractAliasBlockParent(lineNum); inBlock { moduleRef = aliasParent + "." + moduleRef } } aliases := tf.ExtractAliasesInScope(lineNum) - s.mergeAliasesFromUse(text, aliases) + s.mergeAliasesFromUseTokenized(tf, aliases) if moduleRef == "" { if functionName == "" { @@ -3355,7 +3383,7 @@ func (s *Server) Hover(ctx context.Context, params *protocol.HoverParams) (*prot } currentModule := tf.FirstDefmodule() - fullModule := s.resolveBareFunctionModule(uriToPath(protocol.DocumentURI(docURI)), text, lines, lineNum, functionName, aliases) + fullModule := s.resolveBareFunctionModule(uriToPath(protocol.DocumentURI(docURI)), text, tf, lineNum, functionName, aliases) if fullModule != "" { // Current module — hover from the buffer directly @@ -3531,7 +3559,7 @@ func (s *Server) PrepareRename(ctx context.Context, params *protocol.PrepareRena // Try module/function rename via the index if !exprCtx.Empty() { - aliases := ExtractAliasesInScope(text, lineNum) + aliases := tf.ExtractAliasesInScope(lineNum) // Detect `as:` aliases — these are file-local renames, not module renames. // An `as:` alias has a short name that differs from the last segment of @@ -3549,7 +3577,7 @@ func (s *Server) PrepareRename(ctx context.Context, params *protocol.PrepareRena } } - s.mergeAliasesFromUse(text, aliases) + s.mergeAliasesFromUseTokenized(tf, aliases) var tokenName string var fullModule string @@ -3560,7 +3588,7 @@ func (s *Server) PrepareRename(ctx context.Context, params *protocol.PrepareRena if moduleRef != "" { fullModule = resolveModule(moduleRef, aliases) } else { - fullModule = s.resolveBareFunctionModule(uriToPath(protocol.DocumentURI(docURI)), text, lines, lineNum, functionName, aliases) + fullModule = s.resolveBareFunctionModule(uriToPath(protocol.DocumentURI(docURI)), text, tf, lineNum, functionName, aliases) } found = fullModule != "" } else if moduleRef != "" { @@ -3686,13 +3714,13 @@ func (s *Server) References(ctx context.Context, params *protocol.ReferenceParam moduleRef, functionName := ExtractModuleAndFunction(expr) if moduleRef != "" { - if aliasParent, inBlock := ExtractAliasBlockParent(lines, lineNum); inBlock { + if aliasParent, inBlock := tf.ExtractAliasBlockParent(lineNum); inBlock { moduleRef = aliasParent + "." + moduleRef } } aliases := tf.ExtractAliasesInScope(lineNum) - s.mergeAliasesFromUse(text, aliases) + s.mergeAliasesFromUseTokenized(tf, aliases) s.debugf("References: expr=%q module=%q function=%q", expr, moduleRef, functionName) var fullModule string @@ -3726,7 +3754,7 @@ func (s *Server) References(ctx context.Context, params *protocol.ReferenceParam } // Bare function — resolve to its defining module - fullModule = s.resolveBareFunctionModule(uriToPath(protocol.DocumentURI(docURI)), text, lines, lineNum, functionName, aliases) + fullModule = s.resolveBareFunctionModule(uriToPath(protocol.DocumentURI(docURI)), text, tf, lineNum, functionName, aliases) s.debugf("References: resolved bare %q -> %q", functionName, fullModule) if fullModule == "" { s.debugf("References: could not resolve bare function %q", functionName) @@ -3756,7 +3784,7 @@ func (s *Server) References(ctx context.Context, params *protocol.ReferenceParam // findModulesWhoseUsingImports scan entirely. var injectors []string if functionName != "" && moduleRef == "" { - useCalls := ExtractUsesWithOpts(text, aliases) + useCalls := tf.ExtractUsesWithOpts(aliases) visited := make(map[string]bool) for _, uc := range useCalls { if s.lookupInUsingEntry(uc.Module, functionName, uc.Opts, visited) != nil { @@ -3935,7 +3963,7 @@ func (s *Server) Rename(ctx context.Context, params *protocol.RenameParams) (*pr // Try module/function rename via the index if !renameCtx.Empty() { - aliases := ExtractAliasesInScope(text, lineNum) + aliases := tf.ExtractAliasesInScope(lineNum) // Detect `as:` aliases — file-local rename of the alias name, not // the underlying module. This check runs before merging use-injected @@ -3969,14 +3997,14 @@ func (s *Server) Rename(ctx context.Context, params *protocol.RenameParams) (*pr } } - s.mergeAliasesFromUse(text, aliases) + s.mergeAliasesFromUseTokenized(tf, aliases) if functionName != "" { var fullModule string if moduleRef != "" { fullModule = resolveModule(moduleRef, aliases) } else { - fullModule = s.resolveBareFunctionModule(uriToPath(protocol.DocumentURI(docURI)), text, lines, lineNum, functionName, aliases) + fullModule = s.resolveBareFunctionModule(uriToPath(protocol.DocumentURI(docURI)), text, tf, lineNum, functionName, aliases) } if fullModule != "" { if !isValidFunctionName(params.NewName) { @@ -4977,7 +5005,7 @@ func (s *Server) SignatureHelp(ctx context.Context, params *protocol.SignatureHe } aliases := tf.ExtractAliasesInScope(lineNum) - s.mergeAliasesFromUse(text, aliases) + s.mergeAliasesFromUseTokenized(tf, aliases) lines := strings.Split(text, "\n") // Resolve the function to a store lookup result @@ -5142,7 +5170,7 @@ func (s *Server) TypeDefinition(ctx context.Context, params *protocol.TypeDefini } aliases := tf.ExtractAliasesInScope(lineNum) - s.mergeAliasesFromUse(text, aliases) + s.mergeAliasesFromUseTokenized(tf, aliases) fullModule := s.resolveModuleWithNesting(typeCtx.ModuleRef, aliases, uriToPath(protocol.DocumentURI(docURI)), lineNum) results, err := s.store.LookupFunction(fullModule, typeName) @@ -5215,12 +5243,12 @@ func (s *Server) PrepareCallHierarchy(ctx context.Context, params *protocol.Call } aliases := tf.ExtractAliasesInScope(lineNum) - s.mergeAliasesFromUse(text, aliases) + s.mergeAliasesFromUseTokenized(tf, aliases) var fullModule string if callCtx.ModuleRef != "" { fullModule = resolveModule(callCtx.ModuleRef, aliases) } else { - fullModule = s.resolveBareFunctionModule(uriToPath(protocol.DocumentURI(docURI)), text, lines, lineNum, functionName, aliases) + fullModule = s.resolveBareFunctionModule(uriToPath(protocol.DocumentURI(docURI)), text, tf, lineNum, functionName, aliases) } if fullModule == "" { return nil, nil diff --git a/internal/lsp/server_test.go b/internal/lsp/server_test.go index da614e7..48eced8 100644 --- a/internal/lsp/server_test.go +++ b/internal/lsp/server_test.go @@ -1198,6 +1198,40 @@ func TestCompletion_NoResults(t *testing.T) { } } +func TestCompletion_IgnoresStringsAndComments(t *testing.T) { + server, cleanup := setupTestServer(t) + defer cleanup() + + indexFile(t, server.store, server.projectRoot, "lib/accounts.ex", `defmodule MyApp.Accounts do + def create(attrs) do + :ok + end +end +`) + + uri := "file:///test.ex" + + t.Run("string", func(t *testing.T) { + line := ` "MyApp.Acc"` + server.docs.Set(uri, line) + + items := completionAt(t, server, uri, 0, uint32(len(line)-1)) + if len(items) != 0 { + t.Errorf("expected no completions inside string, got %d: %v", len(items), items) + } + }) + + t.Run("comment", func(t *testing.T) { + line := " # MyApp.Acc" + server.docs.Set(uri, line) + + items := completionAt(t, server, uri, 0, uint32(len(line))) + if len(items) != 0 { + t.Errorf("expected no completions inside comment, got %d: %v", len(items), items) + } + }) +} + func TestCompletion_FunctionResultDotNoResults(t *testing.T) { server, cleanup := setupTestServer(t) defer cleanup() @@ -3528,6 +3562,62 @@ func TestDidSave_FormatterConfigRestartPicksUpUpdatedConfig(t *testing.T) { } } +func TestFormatter_ExternalFormatterConfigChangePicksUpUpdatedConfig(t *testing.T) { + if _, err := exec.LookPath("mix"); err != nil { + t.Skip("mix not available in PATH") + } + + server, cleanup := setupTestServer(t) + defer cleanup() + defer server.closeBeams() + + if err := os.WriteFile(filepath.Join(server.projectRoot, "mix.exs"), []byte(""), 0644); err != nil { + t.Fatal(err) + } + + configPath := filepath.Join(server.projectRoot, ".formatter.exs") + if err := os.WriteFile(configPath, []byte("[]"), 0644); err != nil { + t.Fatal(err) + } + + filePath := filepath.Join(server.projectRoot, "lib", "test.ex") + if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil { + t.Fatal(err) + } + docURI := string(uri.File(filePath)) + input := "defmodule Test do\n def hello, do: :world\nend\n" + server.docs.Set(docURI, input) + + edits, err := server.Formatting(context.Background(), &protocol.DocumentFormattingParams{ + TextDocument: protocol.TextDocumentIdentifier{URI: protocol.DocumentURI(docURI)}, + }) + if err != nil { + t.Fatal(err) + } + if edits != nil { + t.Fatalf("expected initial formatting to match input, got %#v", edits) + } + + time.Sleep(10 * time.Millisecond) + if err := os.WriteFile(configPath, []byte("[force_do_end_blocks: true]\n"), 0644); err != nil { + t.Fatal(err) + } + + server.docs.Set(docURI, input) + edits, err = server.Formatting(context.Background(), &protocol.DocumentFormattingParams{ + TextDocument: protocol.TextDocumentIdentifier{URI: protocol.DocumentURI(docURI)}, + }) + if err != nil { + t.Fatal(err) + } + if edits == nil { + t.Fatal("expected formatting edits after external .formatter.exs change") + } + if !strings.Contains(edits[0].NewText, "def hello do") { + t.Fatalf("expected external formatter config change to force do/end blocks, got:\n%s", edits[0].NewText) + } +} + func TestFormatter_WillSaveWaitUntil(t *testing.T) { server, cleanup := setupTestServer(t) defer cleanup()