Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ import jakarta.ws.rs.{Consumes, DELETE, GET, POST, PUT, Path, Produces}
import org.apache.texera.auth.JwtParser.parseToken
import org.apache.texera.auth.SessionUser
import org.apache.texera.auth.util.{ComputingUnitAccess, HeaderField}
import org.apache.texera.common.config.{GuiConfig, KubernetesConfig, LLMConfig}
import org.apache.texera.common.config.{GuiConfig, LLMConfig}
import org.apache.texera.dao.SqlServer
import org.apache.texera.dao.jooq.generated.enums.PrivilegeEnum
import org.apache.texera.dao.jooq.generated.tables.daos.WorkflowComputingUnitDao

import java.net.URLDecoder
import java.nio.charset.StandardCharsets
Expand Down Expand Up @@ -136,12 +138,25 @@ object AccessControlResource extends LazyLogging {
}

// Dynamic Routing Logic
val workflowComputingUnitPoolName = KubernetesConfig.computeUnitPoolName
val workflowComputingUnitPoolNamespace = KubernetesConfig.computeUnitPoolNamespace
val workflowComputingUnitPoolPort = KubernetesConfig.computeUnitPortNumber

val targetHost =
s"computing-unit-$cuidInt.$workflowComputingUnitPoolName-svc.$workflowComputingUnitPoolNamespace.svc.cluster.local:$workflowComputingUnitPoolPort"
// Route to the URI recorded for the computing unit (written by the managing
// service when the pod is created). This recorded URI is the single source
// of truth for where the unit is reachable, allowing units to live anywhere
// the gateway can route to. If no URI has been recorded, the unit is not
// routable and the connection is refused.
val cuDao = new WorkflowComputingUnitDao(
SqlServer.getInstance().createDSLContext().configuration()
)
val unit = cuDao.fetchOneByCuid(cuidInt)
val recordedUri = Option(unit).flatMap(u => Option(u.getUri)).map(_.trim).filter(_.nonEmpty)

val targetHost = recordedUri match {
case Some(uri) =>
logger.info(s"Routing CU $cuidInt to recorded host: $uri")
uri
case None =>
logger.warn(s"Refusing CU $cuidInt: no URI recorded for the computing unit")
return Response.status(Response.Status.FORBIDDEN).build()
}

Response
.ok()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ class AccessControlResourceSpec
private val testURI: String = "http://localhost:8080/"
private val testPath: String = "/api/executions/1/stats/1"

// The host:port the managing service records for a computing unit when it
// creates the pod. The access-control-service routes to this recorded URI.
private val testRecordedUri: String =
"computing-unit-2.compute-unit-svc.default.svc.cluster.local:8888"

private val testUser1: User = {
val user = new User()
user.setUid(1)
Expand All @@ -81,6 +86,31 @@ class AccessControlResourceSpec
cu.setType(WorkflowComputingUnitTypeEnum.kubernetes)
cu.setCuid(2)
cu.setName("test-cu")
cu.setUri(testRecordedUri)
cu
}

// A computing unit the user can access but for which no URI was ever recorded
// (e.g. the pod was never created). Such a unit is not routable and must be
// refused.
private val testCUNoUri: WorkflowComputingUnit = {
val cu = new WorkflowComputingUnit()
cu.setUid(2)
cu.setType(WorkflowComputingUnitTypeEnum.kubernetes)
cu.setCuid(3)
cu.setName("test-cu-no-uri")
cu
}

// A computing unit whose recorded URI is blank/whitespace-only — also treated
// as "no URI recorded" and refused.
private val testCUBlankUri: WorkflowComputingUnit = {
val cu = new WorkflowComputingUnit()
cu.setUid(2)
cu.setType(WorkflowComputingUnitTypeEnum.kubernetes)
cu.setCuid(4)
cu.setName("test-cu-blank-uri")
cu.setUri(" ")
cu
}

Expand All @@ -96,12 +126,18 @@ class AccessControlResourceSpec
userDao.insert(testUser1)
userDao.insert(testUser2)
computingUnitDao.insert(testCU)

val cuAccess = new ComputingUnitUserAccess()
cuAccess.setUid(testUser1.getUid)
cuAccess.setCuid(testCU.getCuid)
cuAccess.setPrivilege(PrivilegeEnum.WRITE)
computingUnitOfUserDao.insert(cuAccess)
computingUnitDao.insert(testCUNoUri)
computingUnitDao.insert(testCUBlankUri)

// Grant testUser1 WRITE access to every test computing unit so the routing
// logic (not the access check) is what each routing test exercises.
Seq(testCU, testCUNoUri, testCUBlankUri).foreach { cu =>
val cuAccess = new ComputingUnitUserAccess()
cuAccess.setUid(testUser1.getUid)
cuAccess.setCuid(cu.getCuid)
cuAccess.setPrivilege(PrivilegeEnum.WRITE)
computingUnitOfUserDao.insert(cuAccess)
}

val claims = JwtAuth.jwtClaims(testUser1, 1)
token = JwtAuth.jwtToken(claims)
Expand Down Expand Up @@ -232,6 +268,23 @@ class AccessControlResourceSpec
response.getHeaderString(HeaderField.UserId) shouldBe testUser1.getUid.toString
response.getHeaderString(HeaderField.UserName) shouldBe testUser1.getName
response.getHeaderString(HeaderField.UserEmail) shouldBe testUser1.getEmail
// Envoy routes by the rewritten Host header, which must be the URI recorded
// for the computing unit.
response.getHeaderString("Host") shouldBe testRecordedUri
}

it should "refuse the connection when no URI is recorded for the computing unit" in {
val (uri, headers) = mockRequest(testPath, Some(testCUNoUri.getCuid.toString))
val response = new AccessControlResource().authorizeGet(uri, headers)

response.getStatus shouldBe Response.Status.FORBIDDEN.getStatusCode
}

it should "refuse the connection when the recorded URI is blank" in {
val (uri, headers) = mockRequest(testPath, Some(testCUBlankUri.getCuid.toString))
val response = new AccessControlResource().authorizeGet(uri, headers)

response.getStatus shouldBe Response.Status.FORBIDDEN.getStatusCode
}

private def mockRequest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ object KubernetesClient {
private val podNamePrefix = "computing-unit"

def generatePodURI(cuid: Int): String = {
s"${generatePodName(cuid)}.${KubernetesConfig.computeUnitServiceName}.$namespace.svc.cluster.local"
s"${generatePodName(cuid)}.${KubernetesConfig.computeUnitServiceName}.$namespace.svc.cluster.local:${KubernetesConfig.computeUnitPortNumber}"
}

def generatePodName(cuid: Int): String = s"$podNamePrefix-$cuid"
Expand Down
Loading