Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 107 additions & 2 deletions pkg/cmd/gpucreate/gpucreate.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ type CreateResult struct {
// searchFilterFlags holds the search filter flag values for create
type searchFilterFlags struct {
gpuName string
region string
provider string
minVRAM float64
minTotalVRAM float64
Expand All @@ -141,7 +142,7 @@ type searchFilterFlags struct {

// hasUserFilters returns true if the user specified any search filter flags
func (f *searchFilterFlags) hasUserFilters() bool {
return f.gpuName != "" || f.provider != "" || f.minVRAM > 0 || f.minTotalVRAM > 0 ||
return f.gpuName != "" || f.region != "" || f.provider != "" || f.minVRAM > 0 || f.minTotalVRAM > 0 ||
f.minCapability > 0 || f.minDisk > 0 || f.maxBootTime > 0 ||
f.stoppable || f.rebootable || f.flexPorts
}
Expand Down Expand Up @@ -230,13 +231,26 @@ func NewCmdGPUCreate(t *terminal.Terminal, gpuCreateStore GPUCreateStore) *cobra
ComposeFile: composeFile,
LaunchableID: launchableID,
LaunchableInfo: launchableInfo,
Region: filters.region,
}

if filters.region != "" {
if err := validateRegionExists(filters.region, gpuCreateStore); err != nil {
return err
}
}

opts.InstanceTypes, err = resolveInstanceTypes(cmd, gpuCreateStore, opts, types, &filters)
if err != nil {
return err
}

if filters.region != "" {
if err := validateRegion(filters.region, opts.InstanceTypes, gpuCreateStore); err != nil {
return err
}
}

if dryRun {
return runDryRun(t, gpuCreateStore, opts.InstanceTypes, &filters)
}
Expand Down Expand Up @@ -278,6 +292,7 @@ func registerCreateFlags(cmd *cobra.Command, name, instanceTypes *string, count,
cmd.Flags().StringVar(containerImage, "container-image", "", "Container image URL (required for container mode)")
cmd.Flags().StringVar(composeFile, "compose-file", "", "Docker compose file path or URL (required for compose mode)")
cmd.Flags().StringVarP(launchable, "launchable", "l", "", "Launchable ID or URL to deploy (e.g., env-XXX or console URL)")
cmd.Flags().StringVarP(&filters.region, "region", "r", "", "Region/location to deploy the instance (e.g., us-east-1, us-central1)")

cmd.Flags().StringVarP(&filters.gpuName, "gpu-name", "g", "", "Filter by GPU name (e.g., A100, H100)")
cmd.Flags().StringVar(&filters.provider, "provider", "", "Filter by provider/cloud (e.g., aws, gcp)")
Expand Down Expand Up @@ -315,6 +330,7 @@ type GPUCreateOptions struct {
ComposeFile string
LaunchableID string
LaunchableInfo *store.LaunchableResponse // populated when LaunchableID is set
Region string
}

// parseLaunchableID extracts a launchable ID from either a raw ID (env-XXX) or
Expand Down Expand Up @@ -493,7 +509,7 @@ func searchInstances(s GPUCreateStore, filters *searchFilterFlags) ([]gpusearch.
}

instances := gpusearch.ProcessInstances(response.Items)
filtered := gpusearch.FilterInstances(instances, filters.gpuName, filters.provider, "", filters.minVRAM,
filtered := gpusearch.FilterInstances(instances, filters.gpuName, filters.region, filters.provider, "", filters.minVRAM,
minTotalVRAM, minCapability, 0, minDisk, 0, maxBootTime, filters.stoppable, filters.rebootable, filters.flexPorts, true)
gpusearch.SortInstances(filtered, sortBy, filters.descending)

Expand Down Expand Up @@ -539,6 +555,91 @@ func runDryRun(t *terminal.Terminal, s GPUCreateStore, specs []InstanceSpec, fil
return nil
}

// validateRegionExists checks that the given region appears in at least one instance type in the catalog.
func validateRegionExists(region string, store GPUCreateStore) error {
response, err := store.GetInstanceTypes(false)
if err != nil {
return breverrors.WrapAndTrace(err)
}
if response == nil || len(response.Items) == 0 {
return nil
}

regionLower := strings.ToLower(region)
for _, item := range response.Items {
if typeSupportsRegion(item, regionLower) {
return nil
}
}

return breverrors.NewValidationError(
fmt.Sprintf("region %q is not offered by any instance type -- use 'brev search --json' to list valid regions",
region),
)
}

// validateRegion checks that every requested instance type is available in the given region.
func validateRegion(region string, types []InstanceSpec, store GPUCreateStore) error {
if len(types) == 0 {
return nil
}

response, err := store.GetInstanceTypes(false)
if err != nil {
return breverrors.WrapAndTrace(err)
}
if response == nil || len(response.Items) == 0 {
return nil
}

catalog := make(map[string]gpusearch.InstanceType, len(response.Items))
for _, item := range response.Items {
catalog[item.Type] = item
}

regionLower := strings.ToLower(region)
var unsupported []string
var unknown []string

for _, spec := range types {
item, ok := catalog[spec.Type]
if !ok {
unknown = append(unknown, spec.Type)
continue
}
if !typeSupportsRegion(item, regionLower) {
unsupported = append(unsupported, spec.Type)
}
}

if len(unknown) > 0 {
return breverrors.NewValidationError(
fmt.Sprintf("unknown instance type(s) %s -- use 'brev search' to list available types", strings.Join(unknown, ", ")),
)
}
if len(unsupported) > 0 {
return breverrors.NewValidationError(
fmt.Sprintf("region %q is not available for instance type(s) %s -- use 'brev search --region %s' to find compatible types",
region, strings.Join(unsupported, ", "), region),
)
}
return nil
}

// typeSupportsRegion reports whether an instance type lists the given region (already lowercased)
// in either its primary Location or AvailableLocations, using substring matching.
func typeSupportsRegion(item gpusearch.InstanceType, regionLower string) bool {
if strings.Contains(strings.ToLower(item.Location), regionLower) {
return true
}
for _, loc := range item.AvailableLocations {
if strings.Contains(strings.ToLower(loc), regionLower) {
return true
}
}
return false
}

// orDefault returns val if it's non-zero, otherwise returns def
func orDefault(val, def float64) float64 {
if val > 0 {
Expand Down Expand Up @@ -993,6 +1094,10 @@ func (c *createContext) createWorkspace(name string, spec InstanceSpec) (*entity
}
}

if c.opts.Region != "" {
cwOptions.Location = c.opts.Region
}

// Apply launchable config or build mode
if c.opts.LaunchableID != "" {
applyLaunchableConfig(cwOptions, c.opts.LaunchableID, c.opts.LaunchableInfo)
Expand Down
Loading
Loading