Skip to content

Commit 6eaae35

Browse files
authored
Merge pull request #502 from randomir/cli-tweaks
CLI tweaks
2 parents cc91a4d + 1995a11 commit 6eaae35

File tree

3 files changed

+198
-120
lines changed

3 files changed

+198
-120
lines changed

dwave/cloud/cli.py

Lines changed: 172 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
import subprocess
2020
import pkg_resources
2121

22-
from functools import partial
22+
from collections.abc import Sequence
23+
from functools import partial, wraps
2324
from timeit import default_timer as timer
2425

2526
from typing import Dict, List
@@ -315,40 +316,32 @@ def _config_create(config_file, profile, ask_full=False):
315316
click.echo("Configuration saved.")
316317

317318

318-
def _ping(config_file, profile, endpoint, region, client_type, solver_def,
319-
sampling_params, request_timeout, polling_timeout, output):
320-
"""Helper method for the ping command that uses `output()` for info output
321-
and raises `CLIError()` on handled errors.
322-
323-
This function is invariant to output format and/or error signaling mechanism.
319+
def _get_client_solver(config, output=None):
320+
"""Helper function to return an instantiated client, and solver, validating
321+
parameters in the process, while wrapping errors in `CLIError` and using
322+
`output` writer as a centralized printer.
324323
"""
325-
params = {}
326-
if sampling_params is not None:
327-
try:
328-
params = json.loads(sampling_params)
329-
assert isinstance(params, dict)
330-
except:
331-
raise CLIError("sampling parameters required as JSON-encoded "
332-
"map of param names to values", code=99)
324+
if output is None:
325+
output = click.echo
333326

334-
config = dict(config_file=config_file, profile=profile,
335-
endpoint=endpoint, region=region,
336-
client=client_type, solver=solver_def)
337-
if request_timeout is not None:
338-
config.update(request_timeout=request_timeout)
339-
if polling_timeout is not None:
340-
config.update(polling_timeout=polling_timeout)
327+
# get client
341328
try:
342329
client = Client.from_config(**config)
343330
except Exception as e:
344331
raise CLIError("Invalid configuration: {}".format(e), code=1)
332+
333+
config_file = config.get('config_file')
345334
if config_file:
346335
output("Using configuration file: {config_file}", config_file=config_file)
336+
337+
profile = config.get('profile')
347338
if profile:
348339
output("Using profile: {profile}", profile=profile)
340+
349341
output("Using endpoint: {endpoint}", endpoint=client.endpoint)
342+
output("Using region: {region}", region=client.region)
350343

351-
t0 = timer()
344+
# get solver
352345
try:
353346
solver = client.get_solver()
354347
except SolverAuthenticationError:
@@ -372,38 +365,71 @@ def _ping(config_file, profile, endpoint, region, client_type, solver_def,
372365
except Exception as e:
373366
raise CLIError("Unexpected error while fetching solver: {!r}".format(e), 5)
374367

375-
if hasattr(solver, 'nodes'):
376-
# structured solver: use the first existing node
377-
problem = ({min(solver.nodes): 0}, {})
378-
else:
379-
# unstructured solver doesn't constrain problem graph
380-
problem = ({0: 1}, {})
381-
382-
t1 = timer()
383368
output("Using solver: {solver_id}", solver_id=solver.id)
384369

370+
return (client, solver)
371+
372+
373+
def _sample(solver, problem, params, output):
374+
"""Blocking sample call with error handling and using custom printer."""
375+
385376
try:
386-
future = solver.sample_ising(*problem, **params)
387-
timing = future.timing
377+
response = solver.sample_ising(*problem, **params)
378+
problem_id = response.wait_id()
379+
output("Submitted problem ID: {problem_id}", problem_id=problem_id)
380+
response.wait()
388381
except RequestTimeout:
389382
raise CLIError("API connection timed out.", 8)
390383
except PollingTimeout:
391384
raise CLIError("Polling timeout exceeded.", 9)
392385
except Exception as e:
393386
raise CLIError("Sampling error: {!r}".format(e), 10)
394-
output("Submitted problem ID: {problem_id}", problem_id=future.id)
395387

396-
t2 = timer()
397-
output("\nWall clock time:")
398-
output(" * Solver definition fetch: {wallclock_solver_definition:.3f} ms", wallclock_solver_definition=(t1-t0)*1000.0)
399-
output(" * Problem submit and results fetch: {wallclock_sampling:.3f} ms", wallclock_sampling=(t2-t1)*1000.0)
400-
output(" * Total: {wallclock_total:.3f} ms", wallclock_total=(t2-t0)*1000.0)
401-
if timing:
402-
output("\nQPU timing:")
403-
for component, duration in sorted(timing.items()):
404-
output(" * %(name)s = {%(name)s} us" % {"name": component}, **{component: duration})
405-
else:
406-
output("\nQPU timing data not available.")
388+
return response
389+
390+
391+
def standardized_output(fn):
392+
"""Decorator that captures `CLIError`s from `fn` and formats output.
393+
394+
The decorated function (cli command) receives `output()` for info output
395+
and should raise `CLIError()` (for handled errors) to output error messages.
396+
397+
The function itself can be invariant to output format and/or error signaling
398+
mechanism.
399+
"""
400+
401+
@wraps(fn)
402+
def wrapped(*args, **kwargs):
403+
# text/json output taken from callee args
404+
json_output = kwargs.get('json_output', False)
405+
406+
now = utcnow()
407+
info = dict(datetime=now.isoformat(), timestamp=datetime_to_timestamp(now), code=0)
408+
409+
def output(fmt, maxlen=None, **params):
410+
info.update(params)
411+
if not json_output:
412+
msg = fmt.format(**params)
413+
if maxlen is not None:
414+
msg = strtrunc(msg, maxlen)
415+
click.echo(msg)
416+
417+
def flush():
418+
if json_output:
419+
click.echo(json.dumps(info))
420+
421+
try:
422+
fn(*args, output=output, **kwargs)
423+
except CLIError as error:
424+
output("Error: {error} (code: {code})", error=str(error), code=error.code)
425+
sys.exit(error.code)
426+
except Exception as error:
427+
output("Unhandled error: {error}", error=str(error))
428+
sys.exit(127)
429+
finally:
430+
flush()
431+
432+
return wrapped
407433

408434

409435
@cli.command()
@@ -415,35 +441,59 @@ def _ping(config_file, profile, endpoint, region, client_type, solver_def,
415441
help='Connection and read timeouts (in seconds) for all API requests')
416442
@click.option('--polling-timeout', default=None, type=float,
417443
help='Problem polling timeout in seconds (time-to-solution timeout)')
444+
@click.option('--label', default='dwave ping', type=str, help='Problem label')
418445
@click.option('--json', 'json_output', default=False, is_flag=True,
419446
help='JSON output')
420-
def ping(config_file, profile, endpoint, region, client_type, solver_def,
421-
sampling_params, json_output, request_timeout, polling_timeout):
447+
@standardized_output
448+
def ping(*, config_file, profile, endpoint, region, client_type, solver_def,
449+
sampling_params, request_timeout, polling_timeout, label, json_output,
450+
output):
422451
"""Ping the QPU by submitting a single-qubit problem."""
423452

424-
now = utcnow()
425-
info = dict(datetime=now.isoformat(), timestamp=datetime_to_timestamp(now), code=0)
453+
# parse params (TODO: move to click validator)
454+
params = {}
455+
if sampling_params is not None:
456+
try:
457+
params = json.loads(sampling_params)
458+
assert isinstance(params, dict)
459+
except:
460+
raise CLIError("sampling parameters required as JSON-encoded "
461+
"map of param names to values", code=99)
426462

427-
def output(fmt, **kwargs):
428-
info.update(kwargs)
429-
if not json_output:
430-
click.echo(fmt.format(**kwargs))
463+
if label:
464+
params.update(label=label)
431465

432-
def flush():
433-
if json_output:
434-
click.echo(json.dumps(info))
466+
config = dict(
467+
config_file=config_file, profile=profile,
468+
endpoint=endpoint, region=region,
469+
client=client_type, solver=solver_def,
470+
request_timeout=request_timeout, polling_timeout=polling_timeout)
435471

436-
try:
437-
_ping(config_file, profile, endpoint, region, client_type, solver_def,
438-
sampling_params, request_timeout, polling_timeout, output)
439-
except CLIError as error:
440-
output("Error: {error} (code: {code})", error=str(error), code=error.code)
441-
sys.exit(error.code)
442-
except Exception as error:
443-
output("Unhandled error: {error}", error=str(error))
444-
sys.exit(127)
445-
finally:
446-
flush()
472+
t0 = timer()
473+
client, solver = _get_client_solver(config, output)
474+
475+
# generate problem
476+
if hasattr(solver, 'nodes'):
477+
# structured solver: use the first existing node
478+
problem = ({min(solver.nodes): 0}, {})
479+
else:
480+
# unstructured solver doesn't constrain problem graph
481+
problem = ({0: 1}, {})
482+
483+
t1 = timer()
484+
response = _sample(solver, problem, params, output)
485+
486+
t2 = timer()
487+
output("\nWall clock time:")
488+
output(" * Solver definition fetch: {wallclock_solver_definition:.3f} ms", wallclock_solver_definition=(t1-t0)*1000.0)
489+
output(" * Problem submit and results fetch: {wallclock_sampling:.3f} ms", wallclock_sampling=(t2-t1)*1000.0)
490+
output(" * Total: {wallclock_total:.3f} ms", wallclock_total=(t2-t0)*1000.0)
491+
if response.timing:
492+
output("\nQPU timing:")
493+
for component, duration in sorted(response.timing.items()):
494+
output(" * %(name)s = {%(name)s} us" % {"name": component}, **{component: duration})
495+
else:
496+
output("\nQPU timing data not available.")
447497

448498

449499
@cli.command()
@@ -508,76 +558,75 @@ def solvers(config_file, profile, endpoint, region, client_type, solver_def,
508558
help='List/dict of couplings for Ising model problem formulation')
509559
@click.option('--random-problem', '-r', default=False, is_flag=True,
510560
help='Submit a valid random problem using all qubits')
511-
@click.option('--num-reads', '-n', default=1, type=int,
561+
@click.option('--num-reads', '-n', default=None, type=int,
512562
help='Number of reads/samples')
563+
@click.option('--label', default='dwave sample', type=str, help='Problem label')
564+
@click.option('--sampling-params', '-m', default=None,
565+
help='Sampling parameters, JSON encoded')
513566
@click.option('--verbose', '-v', default=False, is_flag=True,
514567
help='Increase output verbosity')
515-
def sample(config_file, profile, endpoint, region, client_type, solver_def,
516-
biases, couplings, random_problem, num_reads, verbose):
568+
@click.option('--json', 'json_output', default=False, is_flag=True,
569+
help='JSON output')
570+
@standardized_output
571+
def sample(*, config_file, profile, endpoint, region, client_type, solver_def,
572+
biases, couplings, random_problem, num_reads, label, sampling_params,
573+
verbose, json_output, output):
517574
"""Submit Ising-formulated problem and return samples."""
518575

519-
# TODO: de-dup wrt ping
576+
# we'll limit max line len in non-verbose mode
577+
maxlen = None if verbose else 120
520578

521-
def echo(s, maxlen=100):
522-
click.echo(s if verbose else strtrunc(s, maxlen))
579+
# parse params (TODO: move to click validator)
580+
params = {}
581+
if sampling_params is not None:
582+
try:
583+
params = json.loads(sampling_params)
584+
assert isinstance(params, dict)
585+
except:
586+
raise CLIError("sampling parameters required as JSON-encoded "
587+
"map of param names to values", code=99)
523588

524-
try:
525-
client = Client.from_config(
526-
config_file=config_file, profile=profile,
527-
endpoint=endpoint, region=region,
528-
client=client_type, solver=solver_def)
529-
except Exception as e:
530-
click.echo("Invalid configuration: {}".format(e))
531-
return 1
532-
if config_file:
533-
echo("Using configuration file: {}".format(config_file))
534-
if profile:
535-
echo("Using profile: {}".format(profile))
536-
echo("Using endpoint: {}".format(client.endpoint))
589+
if num_reads is not None:
590+
params.update(num_reads=num_reads)
537591

538-
try:
539-
solver = client.get_solver()
540-
except SolverAuthenticationError:
541-
click.echo("Authentication error. Check credentials in your configuration file.")
542-
return 1
543-
except (InvalidAPIResponseError, UnsupportedSolverError):
544-
click.echo("Invalid or unexpected API response.")
545-
return 2
546-
except SolverNotFoundError:
547-
click.echo("Solver with the specified features does not exist.")
548-
return 3
592+
if label:
593+
params.update(label=label)
594+
595+
# TODO: add other params, like timeout?
596+
config = dict(
597+
config_file=config_file, profile=profile,
598+
endpoint=endpoint, region=region,
599+
client=client_type, solver=solver_def)
549600

550-
echo("Using solver: {}".format(solver.id))
601+
client, solver = _get_client_solver(config, output)
551602

552603
if random_problem:
553604
linear, quadratic = generate_random_ising_problem(solver)
554605
else:
555606
try:
556-
linear = ast.literal_eval(biases) if biases else []
607+
linear = ast.literal_eval(biases) if biases else {}
608+
if isinstance(linear, Sequence):
609+
linear = dict(enumerate(linear))
557610
except Exception as e:
558-
click.echo("Invalid biases: {}".format(e))
611+
raise CLIError(f"Invalid biases: {e}", code=99)
559612
try:
560613
quadratic = ast.literal_eval(couplings) if couplings else {}
561614
except Exception as e:
562-
click.echo("Invalid couplings: {}".format(e))
615+
raise CLIError(f"Invalid couplings: {e}", code=99)
563616

564-
echo("Using qubit biases: {!r}".format(linear))
565-
echo("Using qubit couplings: {!r}".format(quadratic))
566-
echo("Number of samples: {}".format(num_reads))
617+
output("Using qubit biases: {linear}", linear=list(linear.items()), maxlen=maxlen)
618+
output("Using qubit couplings: {quadratic}", quadratic=list(quadratic.items()), maxlen=maxlen)
619+
output("Sampling parameters: {sampling_params}", sampling_params=params)
567620

568-
try:
569-
result = solver.sample_ising(linear, quadratic, num_reads=num_reads)
570-
result.result()
571-
except Exception as e:
572-
click.echo(e)
573-
return 4
621+
response = _sample(
622+
solver, problem=(linear, quadratic), params=params, output=output)
574623

575624
if verbose:
576-
click.echo("Result: {!r}".format(result))
625+
output("Result: {response!r}", response=response.result())
577626

578-
echo("Samples: {!r}".format(result.samples))
579-
echo("Occurrences: {!r}".format(result.occurrences))
580-
echo("Energies: {!r}".format(result.energies))
627+
output("Samples: {samples!r}", samples=response.samples, maxlen=maxlen)
628+
output("Occurrences: {num_occurrences!r}", num_occurrences=response.num_occurrences, maxlen=maxlen)
629+
output("Energies: {energies!r}", energies=response.energies, maxlen=maxlen)
581630

582631

583632
@cli.command()
@@ -796,10 +845,16 @@ def _install_contrib_package(name, verbose=0, prompt=True):
796845
@click.option('--install-all', '--all', '-a', default=False, is_flag=True,
797846
help='Install all non-open-source packages '\
798847
'available and accept licenses without prompting')
848+
@click.option('--full', 'ask_full', default=False, is_flag=True,
849+
help='Configure non-essential options (such as endpoint and solver).')
799850
@click.option('--verbose', '-v', count=True,
800851
help='Increase output verbosity (additive, up to 4 times)')
801-
def setup(install_all, verbose):
802-
"""Setup optional Ocean packages and configuration file(s)."""
852+
def setup(install_all, ask_full, verbose):
853+
"""Setup optional Ocean packages and configuration file(s).
854+
855+
Equivalent to running `dwave install [--all]`, followed by
856+
`dwave config create [--full]`.
857+
"""
803858

804859
contrib = get_contrib_packages()
805860
packages = list(contrib)
@@ -824,4 +879,4 @@ def setup(install_all, verbose):
824879
_install_contrib_package(pkg, verbose=verbose, prompt=not install_all)
825880

826881
click.echo("Creating the D-Wave configuration file.")
827-
return _config_create(config_file=None, profile=None, ask_full=False)
882+
return _config_create(config_file=None, profile=None, ask_full=ask_full)

0 commit comments

Comments
 (0)