diff --git a/access-control-service/src/main/scala/org/apache/texera/service/resource/AccessControlResource.scala b/access-control-service/src/main/scala/org/apache/texera/service/resource/AccessControlResource.scala index 4f1691287fd..792a0dfd8a1 100644 --- a/access-control-service/src/main/scala/org/apache/texera/service/resource/AccessControlResource.scala +++ b/access-control-service/src/main/scala/org/apache/texera/service/resource/AccessControlResource.scala @@ -27,8 +27,10 @@ import jakarta.ws.rs.{Consumes, DELETE, GET, POST, PUT, Path, Produces} import org.apache.texera.auth.JwtParser.parseToken import org.apache.texera.auth.SessionUser import org.apache.texera.auth.util.{ComputingUnitAccess, HeaderField} -import org.apache.texera.common.config.{GuiConfig, KubernetesConfig, LLMConfig} +import org.apache.texera.common.config.{GuiConfig, LLMConfig} +import org.apache.texera.dao.SqlServer import org.apache.texera.dao.jooq.generated.enums.PrivilegeEnum +import org.apache.texera.dao.jooq.generated.tables.daos.WorkflowComputingUnitDao import java.net.URLDecoder import java.nio.charset.StandardCharsets @@ -136,12 +138,25 @@ object AccessControlResource extends LazyLogging { } // Dynamic Routing Logic - val workflowComputingUnitPoolName = KubernetesConfig.computeUnitPoolName - val workflowComputingUnitPoolNamespace = KubernetesConfig.computeUnitPoolNamespace - val workflowComputingUnitPoolPort = KubernetesConfig.computeUnitPortNumber - - val targetHost = - s"computing-unit-$cuidInt.$workflowComputingUnitPoolName-svc.$workflowComputingUnitPoolNamespace.svc.cluster.local:$workflowComputingUnitPoolPort" + // Route to the URI recorded for the computing unit (written by the managing + // service when the pod is created). This recorded URI is the single source + // of truth for where the unit is reachable, allowing units to live anywhere + // the gateway can route to. If no URI has been recorded, the unit is not + // routable and the connection is refused. + val cuDao = new WorkflowComputingUnitDao( + SqlServer.getInstance().createDSLContext().configuration() + ) + val unit = cuDao.fetchOneByCuid(cuidInt) + val recordedUri = Option(unit).flatMap(u => Option(u.getUri)).map(_.trim).filter(_.nonEmpty) + + val targetHost = recordedUri match { + case Some(uri) => + logger.info(s"Routing CU $cuidInt to recorded host: $uri") + uri + case None => + logger.warn(s"Refusing CU $cuidInt: no URI recorded for the computing unit") + return Response.status(Response.Status.FORBIDDEN).build() + } Response .ok() diff --git a/access-control-service/src/test/scala/org/apache/texera/AccessControlResourceSpec.scala b/access-control-service/src/test/scala/org/apache/texera/AccessControlResourceSpec.scala index 3dfe81d89d5..365f5f885f0 100644 --- a/access-control-service/src/test/scala/org/apache/texera/AccessControlResourceSpec.scala +++ b/access-control-service/src/test/scala/org/apache/texera/AccessControlResourceSpec.scala @@ -55,6 +55,11 @@ class AccessControlResourceSpec private val testURI: String = "http://localhost:8080/" private val testPath: String = "/api/executions/1/stats/1" + // The host:port the managing service records for a computing unit when it + // creates the pod. The access-control-service routes to this recorded URI. + private val testRecordedUri: String = + "computing-unit-2.compute-unit-svc.default.svc.cluster.local:8888" + private val testUser1: User = { val user = new User() user.setUid(1) @@ -81,6 +86,31 @@ class AccessControlResourceSpec cu.setType(WorkflowComputingUnitTypeEnum.kubernetes) cu.setCuid(2) cu.setName("test-cu") + cu.setUri(testRecordedUri) + cu + } + + // A computing unit the user can access but for which no URI was ever recorded + // (e.g. the pod was never created). Such a unit is not routable and must be + // refused. + private val testCUNoUri: WorkflowComputingUnit = { + val cu = new WorkflowComputingUnit() + cu.setUid(2) + cu.setType(WorkflowComputingUnitTypeEnum.kubernetes) + cu.setCuid(3) + cu.setName("test-cu-no-uri") + cu + } + + // A computing unit whose recorded URI is blank/whitespace-only — also treated + // as "no URI recorded" and refused. + private val testCUBlankUri: WorkflowComputingUnit = { + val cu = new WorkflowComputingUnit() + cu.setUid(2) + cu.setType(WorkflowComputingUnitTypeEnum.kubernetes) + cu.setCuid(4) + cu.setName("test-cu-blank-uri") + cu.setUri(" ") cu } @@ -96,12 +126,18 @@ class AccessControlResourceSpec userDao.insert(testUser1) userDao.insert(testUser2) computingUnitDao.insert(testCU) - - val cuAccess = new ComputingUnitUserAccess() - cuAccess.setUid(testUser1.getUid) - cuAccess.setCuid(testCU.getCuid) - cuAccess.setPrivilege(PrivilegeEnum.WRITE) - computingUnitOfUserDao.insert(cuAccess) + computingUnitDao.insert(testCUNoUri) + computingUnitDao.insert(testCUBlankUri) + + // Grant testUser1 WRITE access to every test computing unit so the routing + // logic (not the access check) is what each routing test exercises. + Seq(testCU, testCUNoUri, testCUBlankUri).foreach { cu => + val cuAccess = new ComputingUnitUserAccess() + cuAccess.setUid(testUser1.getUid) + cuAccess.setCuid(cu.getCuid) + cuAccess.setPrivilege(PrivilegeEnum.WRITE) + computingUnitOfUserDao.insert(cuAccess) + } val claims = JwtAuth.jwtClaims(testUser1, 1) token = JwtAuth.jwtToken(claims) @@ -232,6 +268,23 @@ class AccessControlResourceSpec response.getHeaderString(HeaderField.UserId) shouldBe testUser1.getUid.toString response.getHeaderString(HeaderField.UserName) shouldBe testUser1.getName response.getHeaderString(HeaderField.UserEmail) shouldBe testUser1.getEmail + // Envoy routes by the rewritten Host header, which must be the URI recorded + // for the computing unit. + response.getHeaderString("Host") shouldBe testRecordedUri + } + + it should "refuse the connection when no URI is recorded for the computing unit" in { + val (uri, headers) = mockRequest(testPath, Some(testCUNoUri.getCuid.toString)) + val response = new AccessControlResource().authorizeGet(uri, headers) + + response.getStatus shouldBe Response.Status.FORBIDDEN.getStatusCode + } + + it should "refuse the connection when the recorded URI is blank" in { + val (uri, headers) = mockRequest(testPath, Some(testCUBlankUri.getCuid.toString)) + val response = new AccessControlResource().authorizeGet(uri, headers) + + response.getStatus shouldBe Response.Status.FORBIDDEN.getStatusCode } private def mockRequest( diff --git a/computing-unit-managing-service/src/main/scala/org/apache/texera/service/util/KubernetesClient.scala b/computing-unit-managing-service/src/main/scala/org/apache/texera/service/util/KubernetesClient.scala index 5177ebaf471..4f1d391cb30 100644 --- a/computing-unit-managing-service/src/main/scala/org/apache/texera/service/util/KubernetesClient.scala +++ b/computing-unit-managing-service/src/main/scala/org/apache/texera/service/util/KubernetesClient.scala @@ -35,7 +35,7 @@ object KubernetesClient { private val podNamePrefix = "computing-unit" def generatePodURI(cuid: Int): String = { - s"${generatePodName(cuid)}.${KubernetesConfig.computeUnitServiceName}.$namespace.svc.cluster.local" + s"${generatePodName(cuid)}.${KubernetesConfig.computeUnitServiceName}.$namespace.svc.cluster.local:${KubernetesConfig.computeUnitPortNumber}" } def generatePodName(cuid: Int): String = s"$podNamePrefix-$cuid"