359 lines
13 KiB
Python
359 lines
13 KiB
Python
# -*- coding: utf-8 -*- #
|
|
# Copyright 2025 Google LLC. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Generates optimized Kubernetes manifests for GKE Inference Quickstart."""
|
|
|
|
from apitools.base.py import exceptions as apitools_exceptions
|
|
from googlecloudsdk.api_lib.ai.recommender import util
|
|
from googlecloudsdk.api_lib.util import exceptions as api_lib_exceptions
|
|
from googlecloudsdk.calliope import base
|
|
from googlecloudsdk.core import exceptions
|
|
from googlecloudsdk.core import log
|
|
from googlecloudsdk.core.util import files
|
|
|
|
|
|
@base.DefaultUniverseOnly
|
|
@base.ReleaseTracks(base.ReleaseTrack.GA)
|
|
class Create(base.CreateCommand):
|
|
"""Generate ready-to-deploy Kubernetes manifests with compute, load balancing, and autoscaling capabilities.
|
|
|
|
To get supported model, model servers, and model server versions, run `gcloud
|
|
alpha container ai profiles model-and-server-combinations list`. To get
|
|
supported accelerators with their performance metrics, run `gcloud alpha
|
|
container ai profiles accelerators list`.
|
|
"""
|
|
|
|
@staticmethod
|
|
def Args(parser):
|
|
parser.add_argument(
|
|
"--model",
|
|
required=True,
|
|
help="The model.",
|
|
)
|
|
parser.add_argument(
|
|
"--model-server",
|
|
required=True,
|
|
help="The model server.",
|
|
)
|
|
parser.add_argument(
|
|
"--model-server-version",
|
|
help=(
|
|
"The model server version. If not specified, this defaults to the"
|
|
" latest version."
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--target-ntpot-milliseconds",
|
|
type=int,
|
|
help=(
|
|
"The maximum normalized time per output token (NTPOT) in"
|
|
" milliseconds. NTPOT is measured as the request_latency /"
|
|
" output_tokens. If this is set, the manifests will include"
|
|
" Horizontal Pod Autoscaler (HPA) resources which automatically"
|
|
" adjust the model server replica count in response to changes in"
|
|
" model server load to keep p50 NTPOT below the specified"
|
|
" threshold. If the provided target-ntpot-milliseconds is too low"
|
|
" to achieve, the HPA manifest will not be generated. "
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--target-ttft-milliseconds",
|
|
type=int,
|
|
help=(
|
|
"If specified, results will only show accelerators that can meet"
|
|
" the latency target and will show their throughput performances at"
|
|
" the target ttft target to achieve, the HPA manifest will not be"
|
|
" generated. "
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--accelerator-type",
|
|
required=True,
|
|
help="The accelerator type.",
|
|
)
|
|
parser.add_argument(
|
|
"--namespace",
|
|
help=(
|
|
"The namespace to deploy the manifests in. Default namespace is"
|
|
" 'default'."
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--output",
|
|
choices=["manifest", "comments", "all"],
|
|
default="all",
|
|
help="The output to display. Default is all.",
|
|
)
|
|
parser.add_argument(
|
|
"--output-path",
|
|
help=(
|
|
"The path to save the output to. If not specified, output to the"
|
|
" terminal."
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--model-bucket-uri",
|
|
help=(
|
|
"The Google Cloud Storage bucket URI to load the model from. This"
|
|
" URI must point to the directory containing the model's config"
|
|
" file (config.json) and model weights. If unspecified, defaults to"
|
|
" loading the model from Hugging Face."
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--target-itl-milliseconds",
|
|
type=int,
|
|
help=(
|
|
"The target inter-token latency (ITL) in milliseconds. If this is"
|
|
" set, the manifest will include Horizontal Pod Autoscaler (HPA)"
|
|
" resources which automatically adjust the model server replica"
|
|
" count in response to changes in model server load to keep p50 ITL"
|
|
" below the specified threshold. If the provided"
|
|
" target-itl-milliseconds is too low to achieve, the HPA manifest"
|
|
" will not be generated."
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--use-case",
|
|
help=(
|
|
"The manifest will be optimized for this use case. Options are:"
|
|
" Advanced Customer Support, Code Completion, Text Summarization,"
|
|
" Chatbot (ShareGPT), Code Generation, Deep Research. Will default"
|
|
" to Chatbot if not specified."
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--serving-stack",
|
|
required=False,
|
|
help=(
|
|
"The serving stack to filter manifests by. If not"
|
|
" provided, manifests for all serving stacks that support"
|
|
" the given model and model server will be considered."
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--serving-stack-version",
|
|
required=False,
|
|
help=(
|
|
"The serving stack version to filter manifests by. If not"
|
|
" provided, manifests for all versions that support"
|
|
" the given model and model server will be considered."
|
|
),
|
|
)
|
|
|
|
def Run(self, args):
|
|
client = util.GetClientInstance(base.ReleaseTrack.GA)
|
|
messages = util.GetMessagesModule(base.ReleaseTrack.GA)
|
|
|
|
try:
|
|
model_server_info = messages.ModelServerInfo(
|
|
model=args.model,
|
|
modelServer=args.model_server,
|
|
modelServerVersion=args.model_server_version,
|
|
)
|
|
performance_requirements = messages.PerformanceRequirements()
|
|
if args.target_ntpot_milliseconds:
|
|
performance_requirements.targetNtpotMilliseconds = (
|
|
args.target_ntpot_milliseconds
|
|
)
|
|
if args.target_ttft_milliseconds:
|
|
performance_requirements.targetTtftMilliseconds = (
|
|
args.target_ttft_milliseconds
|
|
)
|
|
if args.target_itl_milliseconds:
|
|
performance_requirements.targetItlMilliseconds = (
|
|
args.target_itl_milliseconds
|
|
)
|
|
storage_config = messages.StorageConfig()
|
|
if args.model_bucket_uri:
|
|
storage_config.modelBucketUri = args.model_bucket_uri
|
|
|
|
serving_stack = None
|
|
if args.serving_stack:
|
|
serving_stack = messages.ServingStack(
|
|
name=args.serving_stack,
|
|
)
|
|
if args.serving_stack_version:
|
|
serving_stack.version = args.serving_stack_version
|
|
|
|
request = messages.GenerateOptimizedManifestRequest(
|
|
modelServerInfo=model_server_info,
|
|
acceleratorType=args.accelerator_type,
|
|
kubernetesNamespace=args.namespace,
|
|
servingStack=serving_stack,
|
|
)
|
|
if (
|
|
performance_requirements.targetNtpotMilliseconds is not None
|
|
or performance_requirements.targetTtftMilliseconds is not None
|
|
or performance_requirements.targetItlMilliseconds is not None
|
|
):
|
|
request.performanceRequirements = performance_requirements
|
|
if storage_config.modelBucketUri is not None:
|
|
request.storageConfig = storage_config
|
|
if args.use_case:
|
|
request.useCase = args.use_case
|
|
response = client.optimizedManifest.Generate(request)
|
|
return response
|
|
except apitools_exceptions.HttpError as error:
|
|
raise api_lib_exceptions.HttpException(error, util.HTTP_ERROR_FORMAT)
|
|
|
|
def Display(self, args, resources):
|
|
if not resources:
|
|
log.out.Print("No manifests generated.")
|
|
return
|
|
|
|
output_content = ""
|
|
if args.output != "comments":
|
|
for manifest in resources.kubernetesManifests:
|
|
output_content += manifest.content + "\n---\n"
|
|
|
|
if resources.comments:
|
|
comment_string = "\n".join([f"# {line}" for line in resources.comments])
|
|
output_content += comment_string
|
|
|
|
if args.output_path:
|
|
try:
|
|
with files.FileWriter(args.output_path, output_content) as f:
|
|
f.write(output_content)
|
|
log.out.Print(f"Output saved to {args.output_path}")
|
|
except exceptions.Error as e:
|
|
log.error(f"An error occurred while saving output to file: {e}")
|
|
else:
|
|
log.out.Print(output_content)
|
|
|
|
|
|
@base.DefaultUniverseOnly
|
|
@base.ReleaseTracks(base.ReleaseTrack.ALPHA)
|
|
class CreateAlpha(base.CreateCommand):
|
|
"""Generate ready-to-deploy Kubernetes manifests with compute, load balancing, and autoscaling capabilities.
|
|
|
|
To get supported model, model servers, and model server versions, run `gcloud
|
|
alpha container ai profiles model-and-server-combinations list`. To get
|
|
supported accelerators with their performance metrics, run `gcloud alpha
|
|
container ai profiles accelerators list`.
|
|
"""
|
|
|
|
@staticmethod
|
|
def Args(parser):
|
|
parser.add_argument(
|
|
"--model",
|
|
required=True,
|
|
help="The model.",
|
|
)
|
|
parser.add_argument(
|
|
"--model-server",
|
|
required=True,
|
|
help="The model server.",
|
|
)
|
|
parser.add_argument(
|
|
"--model-server-version",
|
|
help=(
|
|
"The model server version. If not specified, this defaults to the"
|
|
" latest version."
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--target-ntpot-milliseconds",
|
|
type=int,
|
|
help=(
|
|
"The maximum normalized time per output token (NTPOT) in"
|
|
" milliseconds. NTPOT is measured as the request_latency /"
|
|
" output_tokens. If this is set, the manifests will include"
|
|
" Horizontal Pod Autoscaler (HPA) resources which automatically"
|
|
" adjust the model server replica count in response to changes in"
|
|
" model server load to keep p50 NTPOT below the specified"
|
|
" threshold. If the provided target-ntpot-milliseconds is too low"
|
|
" to achieve, the HPA manifest will not be generated. "
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--accelerator-type",
|
|
required=True,
|
|
help="The accelerator type.",
|
|
)
|
|
parser.add_argument(
|
|
"--namespace",
|
|
help=(
|
|
"The namespace to deploy the manifests in. Default namespace is"
|
|
" 'default'."
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--output",
|
|
choices=["manifest", "comments", "all"],
|
|
default="all",
|
|
help="The output to display. Default is all.",
|
|
)
|
|
parser.add_argument(
|
|
"--output-path",
|
|
help=(
|
|
"The path to save the output to. If not specified, output to the"
|
|
" terminal."
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--model-bucket-uri",
|
|
hidden=True,
|
|
help=(
|
|
"GCS bucket URI to pull model from. If not specified, default"
|
|
" to the model hoster."
|
|
),
|
|
)
|
|
|
|
def Run(self, args):
|
|
client = util.GetClientInstance(base.ReleaseTrack.ALPHA)
|
|
messages = util.GetMessagesModule(base.ReleaseTrack.ALPHA)
|
|
|
|
try:
|
|
request = messages.GkerecommenderOptimizedManifestRequest(
|
|
modelAndModelServerInfo_modelName=args.model,
|
|
modelAndModelServerInfo_modelServerName=args.model_server,
|
|
modelAndModelServerInfo_modelServerVersion=args.model_server_version,
|
|
targetNtpotMilliseconds=args.target_ntpot_milliseconds,
|
|
acceleratorType=args.accelerator_type,
|
|
kubernetesNamespace=args.namespace,
|
|
storageConfig_modelBucketUri=args.model_bucket_uri,
|
|
)
|
|
response = client.v1alpha1.OptimizedManifest(request)
|
|
return response
|
|
except exceptions.Error as e:
|
|
log.error(f"An error has occurred: {e}")
|
|
log.status.Print(f"An error has occurred: {e}")
|
|
return []
|
|
|
|
def Display(self, args, resources):
|
|
if not resources:
|
|
log.out.Print("No manifests generated.")
|
|
return
|
|
|
|
output_content = ""
|
|
if args.output == "manifest" or args.output == "all":
|
|
for manifest in resources.k8sManifests:
|
|
output_content += manifest.content + "\n---\n"
|
|
|
|
if args.output == "comments" or args.output == "all":
|
|
if resources.comments:
|
|
comment_string = "\n".join([f"# {line}" for line in resources.comments])
|
|
output_content += comment_string
|
|
|
|
if args.output_path:
|
|
try:
|
|
with files.FileWriter(args.output_path, output_content) as f:
|
|
f.write(output_content)
|
|
log.out.Print(f"Output saved to {args.output_path}")
|
|
except exceptions.Error as e:
|
|
log.error(f"An error occurred while saving output to file: {e}")
|
|
else:
|
|
log.out.Print(output_content)
|