Files

208 lines
6.4 KiB
Python

# -*- coding: utf-8 -*- #
# Copyright 2025 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to SSH into a Cloud Run Deployment."""
import argparse
import enum
import json
import subprocess
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.compute import iap_tunnel
from googlecloudsdk.command_lib.compute import ssh_utils
from googlecloudsdk.command_lib.util.ssh import ssh
def ProjectIdToProjectNumber(project_id):
  """Looks up the project number that belongs to a Cloud project id.

  Args:
    project_id: str, the project id to resolve.

  Returns:
    The `projectNumber` field of the project returned by the Cloud Resource
    Manager v1 `projects.Get` call.
  """
  api_name, api_version = "cloudresourcemanager", "v1"
  messages = apis.GetMessagesModule(api_name, api_version)
  client = apis.GetClientInstance(api_name, api_version)
  get_request = messages.CloudresourcemanagerProjectsGetRequest(
      projectId=project_id
  )
  return client.projects.Get(get_request).projectNumber
def CreateSshTunnelArgs(
    track,
    project_number,
    project,
    deployment_name,
    workload_type,
    region,
    instance_id=None,
    container_id=None,
    iap_tunnel_url_override=None,
):
  """Builds the SshTunnelArgs describing a Cloud Run IAP tunnel.

  Args:
    track: ReleaseTrack, the currently running release track; its `prefix` is
      stored on the result.
    project_number: str, the project number (string with digits).
    project: str, the project id.
    deployment_name: str, the name of the deployment.
    workload_type: Ssh.WorkloadType, the type of the workload.
    region: str, the region of the deployment.
    instance_id: str, the instance id (optional).
    container_id: str, the container id (optional).
    iap_tunnel_url_override: str, the IAP tunnel URL override (optional). When
      truthy it is forwarded as a `--iap-tunnel-url-override=` pass-through
      argument.

  Returns:
    SshTunnelArgs populated with the track prefix, project, region and the
    Cloud Run specific arguments.
  """
  tunnel_args = iap_tunnel.SshTunnelArgs()
  tunnel_args.track = track.prefix
  tunnel_args.cloud_run_args = {
      "deployment_name": deployment_name,
      "workload_type": workload_type,
      "project_number": project_number,
      "instance_id": instance_id,
      "container_id": container_id,
  }
  tunnel_args.region = region
  tunnel_args.project = project

  # Nothing more to add unless an override URL was supplied.
  if not iap_tunnel_url_override:
    return tunnel_args

  override_flag = "--iap-tunnel-url-override=" + iap_tunnel_url_override
  tunnel_args.pass_through_args.append(override_flag)
  return tunnel_args
class Ssh:
  """SSH into a Cloud Run Deployment.

  Constructing an instance resolves the project number and the workload's
  service account; `Run` then creates the certificate and launches the SSH
  command through an IAP tunnel.
  """

  class WorkloadType(enum.Enum):
    """The type of the deployment."""

    WORKER_POOL = "worker_pool"
    JOB = "job"
    SERVICE = "service"

  def __init__(self, args: argparse.Namespace, workload_type: WorkloadType):
    """Initialize the SSH library.

    Args:
      args: argparse.Namespace, the parsed command line. Must provide
        `deployment_name`, `project`, `region` and `release_track`;
        `instance`, `container` and `iap_tunnel_url_override` are optional and
        default to None when absent.
      workload_type: Ssh.WorkloadType, the type of the workload to connect to.

    Raises:
      ValueError: if the workload type is unsupported or describing the
        workload fails.
    """
    self.deployment_name = args.deployment_name
    self.workload_type = workload_type
    self.project = args.project
    self.project_number = ProjectIdToProjectNumber(args.project)
    self.instance = getattr(args, "instance", None)
    self.container = getattr(args, "container", None)
    self.region = args.region
    self.release_track = args.release_track
    self.iap_tunnel_url_override = getattr(
        args, "iap_tunnel_url_override", None
    )
    # Needs deployment_name, workload_type, region and project set above.
    self.service_account = self._GetServiceAccountFromWorkload()

  @staticmethod
  def _ExtractServiceAccount(workload_data):
    """Pulls the service account name out of a described workload.

    Services and worker pools keep the value at
    `spec.template.spec.serviceAccountName`. Jobs nest the task template one
    level deeper (`spec.template.spec.template.spec.serviceAccountName`), so
    that location is consulted when the first one has no value.

    Args:
      workload_data: dict, the parsed JSON output of the describe command.

    Returns:
      The service account name, or None if the description does not carry one.
    """
    # `or {}` (rather than a .get default) also tolerates explicit JSON nulls.
    spec = (workload_data or {}).get("spec") or {}
    template_spec = (spec.get("template") or {}).get("spec") or {}
    service_account = template_spec.get("serviceAccountName")
    if service_account is not None:
      return service_account
    task_spec = (template_spec.get("template") or {}).get("spec") or {}
    return task_spec.get("serviceAccountName")

  def _GetServiceAccountFromWorkload(self):
    """Retrieves the service account from the Cloud Run workload.

    Runs the matching `gcloud ... describe` command with JSON output and reads
    the service account name from the result.

    Returns:
      The service account name, or None if the workload description has none.

    Raises:
      ValueError: if the workload type is unsupported or the describe command
        exits with a non-zero status.
    """
    command = ["gcloud"]
    if self.workload_type == self.WorkloadType.SERVICE:
      command.extend(["run", "services", "describe"])
    elif self.workload_type == self.WorkloadType.WORKER_POOL:
      command.extend(["beta", "run", "worker-pools", "describe"])
    elif self.workload_type == self.WorkloadType.JOB:
      command.extend(["run", "jobs", "describe"])
    else:
      raise ValueError(f"Unsupported workload type: {self.workload_type}")
    command.extend([
        self.deployment_name,
        "--region",
        self.region,
        "--project",
        self.project,
        "--format",
        "json",
    ])
    try:
      output = subprocess.check_output(command, stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as e:
      message = f"Failed to get service account: {e}"
      # stderr is captured above; surface it so the real cause is not lost.
      stderr_text = e.stderr
      if isinstance(stderr_text, bytes):
        stderr_text = stderr_text.decode("utf-8", errors="replace")
      stderr_text = (stderr_text or "").strip()
      if stderr_text:
        message = f"{message}\n{stderr_text}"
      raise ValueError(message) from e
    return self._ExtractServiceAccount(json.loads(output))

  def HostKeyAlias(self):
    """Returns the host key alias for the SSH connection.

    The alias embeds the instance and container only when both are set;
    otherwise a fixed default alias is used.
    """
    if self.instance and self.container:
      return f"cloud-run-{self.instance}-{self.container}"
    return "cloud-run-default"

  def Run(self):
    """Run the SSH command.

    Ensures local SSH keys exist, creates the certificate for the deployment,
    builds the IAP tunnel arguments and SSH options, and executes the SSH
    command as user `root`.

    Returns:
      The result of running the constructed `ssh.SSHCommand`.
    """
    env = ssh.Environment.Current()
    env.RequireSSH()
    keys = ssh.Keys.FromFilename()
    keys.EnsureKeysExist(overwrite=False)
    user = "root"
    # Note: this actually creates the certificate.
    ssh.GetOsloginState(
        None,
        None,
        user,
        keys.GetPublicKey().ToEntry(),
        None,
        self.release_track,
        cloud_run_params={
            "deployment_name": self.deployment_name,
            "project_id": self.project,
            "region": self.region,
            "service_account": self.service_account,
        },
    )
    cert_file = ssh.CertFileFromCloudRunDeployment(
        project=self.project,
        region=self.region,
        deployment=self.deployment_name,
    )
    # The alias doubles as the destination address of the remote.
    dest_addr = self.HostKeyAlias()
    remote = ssh.Remote(dest_addr, user)
    iap_tunnel_args = CreateSshTunnelArgs(
        self.release_track,
        self.project_number,
        self.project,
        self.deployment_name,
        self.workload_type,
        self.region,
        self.instance,
        self.container,
        self.iap_tunnel_url_override,
    )
    ssh_helper = ssh_utils.BaseSSHCLIHelper()
    ssh_options = ssh_helper.GetConfig(
        host_key_alias=dest_addr,
        strict_host_key_checking="no",
    )
    return ssh.SSHCommand(
        remote=remote,
        cert_file=cert_file,
        iap_tunnel_args=iap_tunnel_args,
        options=ssh_options,
        identity_file=keys.key_file,
    ).Run(env)