import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.UnknownHostException;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.regex.Pattern;

import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
/**
 * REST endpoints that forward requests to a fixed set of allowlisted hosts.
 *
 * <p>Caller-supplied absolute URLs ({@code /proxy}, {@code /fetch-content}) must pass
 * {@link SSRFProtectionService#isUrlAllowed} against {@link #ALLOWED_PROXY_HOSTS}. The webhook and
 * gateway endpoints never accept a host from the caller directly: the host comes from an allowlist
 * and only a sanitized path is taken from the request.
 *
 * <p>Every outbound request is issued with a {@link URI} object rather than a String. The String
 * overloads of {@link RestTemplate} treat their argument as a URI template and re-parse/re-encode
 * it with a parser other than the one used for validation.
 *
 * <p>NOTE(review): redirect handling depends on how {@code secureRestTemplate} is configured. A
 * template that follows redirects could be bounced from an allowlisted host to an internal one.
 * Confirm that redirects are disabled.
 */
@RestController
public class SecureProxyController {

    /** Stdlib logger; no logging facade is imported by this file. */
    private static final Logger logger = Logger.getLogger(SecureProxyController.class.getName());

    private final RestTemplate secureRestTemplate;
    private final SSRFProtectionService ssrfProtection;

    /** External hosts the proxy, fetch and webhook endpoints may contact (exact, lowercase match). */
    private static final Set<String> ALLOWED_PROXY_HOSTS = Set.of(
            "api.trusted-partner.com",
            "public-api.example.com",
            "webhook.partner.org"
    );

    /** Logical service name to internal host, used by the API gateway endpoint. */
    private static final Map<String, String> ALLOWED_SERVICES = Map.of(
            "user-service", "user-api.internal.company.com",
            "order-service", "order-api.internal.company.com",
            "payment-service", "payment-api.internal.company.com"
    );

    public SecureProxyController(RestTemplate secureRestTemplate, SSRFProtectionService ssrfProtection) {
        this.secureRestTemplate = secureRestTemplate;
        this.ssrfProtection = ssrfProtection;
    }

    /**
     * GETs a caller-supplied URL if it passes the SSRF allowlist checks.
     *
     * @param url absolute URL to fetch
     * @return 200 with the upstream body, 400 if the URL is not allowed, 500 if the request fails
     */
    @GetMapping("/proxy")
    public ResponseEntity<String> proxyRequest(@RequestParam String url) {
        return fetchAllowedUrl(url, "Proxy request failed: ", "Request failed");
    }

    /**
     * POSTs {@code payload} to {@code https://<hostname><path>} on an allowlisted webhook host.
     *
     * @param hostname target host; must be in {@link #ALLOWED_PROXY_HOSTS} (case-insensitive)
     * @param path     request path, restricted by {@link SSRFProtectionService#sanitizePath}
     * @param payload  request body forwarded verbatim
     * @return 200 on success, 400 on validation failure, 500 if the call fails
     */
    @PostMapping("/webhook/{hostname}")
    public ResponseEntity<String> callWebhook(
            @PathVariable String hostname,
            @RequestParam String path,
            @RequestBody String payload) {
        // Normalize once, locale-independently, so that the value checked against the allowlist is
        // exactly the value used to build the URL. The default locale would map 'I' to a dotless
        // 'ı' in Turkish.
        String normalizedHost = hostname.toLowerCase(Locale.ROOT);
        if (!ALLOWED_PROXY_HOSTS.contains(normalizedHost)) {
            return ResponseEntity.badRequest().body("Hostname not allowed");
        }

        String sanitizedPath = ssrfProtection.sanitizePath(path);
        if (sanitizedPath == null) {
            return ResponseEntity.badRequest().body("Invalid path");
        }

        Optional<String> urlOpt = ssrfProtection.buildSecureUrl(
                "https", normalizedHost, sanitizedPath, null
        );
        if (urlOpt.isEmpty()) {
            return ResponseEntity.badRequest().body("Failed to build secure URL");
        }
        String url = urlOpt.get();

        // Apply the same denylist / resolved-address checks as the proxy endpoints, so a webhook
        // host whose DNS points at an internal address is refused here too.
        if (!ssrfProtection.isUrlAllowed(url, ALLOWED_PROXY_HOSTS)) {
            return ResponseEntity.badRequest().body("URL not allowed");
        }

        try {
            // The response body is not needed. With RestTemplate's default error handler,
            // 4xx/5xx replies surface as exceptions and reach the catch below.
            // Confirm that secureRestTemplate keeps that handler.
            secureRestTemplate.postForEntity(URI.create(url), payload, String.class);
            return ResponseEntity.ok("Webhook called successfully");
        } catch (Exception e) {
            logger.log(Level.SEVERE, e, () -> "Webhook call failed: " + normalizedHost);
            return ResponseEntity.internalServerError().body("Webhook failed");
        }
    }

    /**
     * GETs a caller-supplied URL if it passes the SSRF allowlist checks.
     *
     * @param targetUrl absolute URL to fetch
     * @return 200 with the upstream body, 400 if the URL is not allowed, 500 if the request fails
     */
    @GetMapping("/fetch-content")
    public ResponseEntity<String> fetchContent(@RequestParam String targetUrl) {
        return fetchAllowedUrl(targetUrl, "Content fetch failed: ", "Fetch failed");
    }

    /**
     * POSTs {@code data} to an endpoint of one of the internal services in {@link #ALLOWED_SERVICES}.
     *
     * <p>This endpoint does not call {@code isUrlAllowed}. The targets are internal hosts that
     * presumably resolve to private addresses, which that check would reject. Confirm before
     * adding it.
     *
     * @param service  logical service name (key of {@link #ALLOWED_SERVICES})
     * @param endpoint request path, restricted by {@link SSRFProtectionService#sanitizePath}
     * @param data     request body forwarded to the service
     * @return 200 with the service response, 400 on validation failure, 500 if the call fails
     */
    @PostMapping("/api-gateway")
    public ResponseEntity<Object> callApi(
            @RequestParam String service,
            @RequestParam String endpoint,
            @RequestBody Object data) {
        String hostname = ALLOWED_SERVICES.get(service);
        if (hostname == null) {
            // Do not reflect the raw request parameter into the response body, because that is an
            // injection vector. Valid names are listed by /api-gateway/services.
            return ResponseEntity.badRequest().body("Unknown service");
        }

        String sanitizedEndpoint = ssrfProtection.sanitizePath(endpoint);
        if (sanitizedEndpoint == null) {
            return ResponseEntity.badRequest().body("Invalid endpoint");
        }

        Optional<String> urlOpt = ssrfProtection.buildSecureUrl(
                "https", hostname, sanitizedEndpoint, null
        );
        if (urlOpt.isEmpty()) {
            return ResponseEntity.badRequest().body("Failed to build API URL");
        }

        try {
            Object response = secureRestTemplate.postForObject(URI.create(urlOpt.get()), data, Object.class);
            return ResponseEntity.ok(response);
        } catch (Exception e) {
            // Both values were validated above (allowlist key / sanitized path), so they are safe to log.
            logger.log(Level.SEVERE, e, () -> "API call failed: " + service + " " + endpoint);
            return ResponseEntity.internalServerError().body("API call failed");
        }
    }

    /**
     * Lists the logical service names accepted by {@code /api-gateway}.
     *
     * @return the (unmodifiable) key set of {@link #ALLOWED_SERVICES}
     */
    @GetMapping("/api-gateway/services")
    public ResponseEntity<Set<String>> getAllowedServices() {
        return ResponseEntity.ok(ALLOWED_SERVICES.keySet());
    }

    /**
     * Validates {@code url} against {@link #ALLOWED_PROXY_HOSTS} and GETs it. Shared by
     * {@code /proxy} and {@code /fetch-content}.
     *
     * @param url         caller-supplied absolute URL
     * @param logPrefix   text prepended to the URL in the failure log entry
     * @param failureBody response body returned with a 500 when the request fails
     */
    private ResponseEntity<String> fetchAllowedUrl(String url, String logPrefix, String failureBody) {
        if (!ssrfProtection.isUrlAllowed(url, ALLOWED_PROXY_HOSTS)) {
            return ResponseEntity.badRequest().body("URL not allowed");
        }
        try {
            // isUrlAllowed has already parsed this exact string with java.net.URI, so URI.create
            // cannot fail here. Requesting that URI means the target is what was validated.
            String body = secureRestTemplate.getForObject(URI.create(url), String.class);
            return ResponseEntity.ok(body);
        } catch (Exception e) {
            logger.log(Level.SEVERE, e, () -> logPrefix + url);
            return ResponseEntity.internalServerError().body(failureBody);
        }
    }
}
/**
 * Validation helpers used to guard outbound HTTP requests against server-side request forgery.
 *
 * <p>{@link #isUrlAllowed} layers several checks, in this order:
 * <ol>
 *   <li>an http/https scheme allowlist;</li>
 *   <li>rejection of embedded credentials;</li>
 *   <li>a denylist of loopback and cloud-metadata host names;</li>
 *   <li>an exact-match host allowlist;</li>
 *   <li>a check that every address the host resolves to is publicly routable.</li>
 * </ol>
 * The address check is a point-in-time DNS lookup. The HTTP client resolves the name again when
 * it connects, so on its own this check does not defeat DNS rebinding. The exact host allowlist is
 * the primary control.
 *
 * <p>NOTE(review): Java permits only one public top-level class per source file, so this class
 * should live in its own {@code SSRFProtectionService.java}.
 */
@Service
public class SSRFProtectionService {

    /** Schemes an outbound request may use; anything else (file, ftp, jar, gopher, ...) is refused. */
    private static final Set<String> ALLOWED_SCHEMES = Set.of("http", "https");

    /** Host names that must never be contacted, compared after lowercasing and bracket removal. */
    private static final Set<String> DANGEROUS_HOSTS = Set.of(
            "localhost", "127.0.0.1", "0.0.0.0", "::1",
            "169.254.169.254", "metadata.google.internal", "metadata.azure.com"
    );

    /** Characters permitted in a sanitized path; compiled once instead of on every call. */
    private static final Pattern SAFE_PATH = Pattern.compile("[a-zA-Z0-9/_.-]+");

    /**
     * Decides whether {@code url} may be fetched.
     *
     * <p>NOTE(review): the port is not restricted; any port on an allowlisted host is accepted.
     *
     * @param url          absolute URL supplied by the caller; {@code null} is rejected
     * @param allowedHosts host names that may be contacted. The URL's host is lowercased before an
     *                     exact {@code contains} lookup, so entries must be lowercase.
     * @return {@code true} only if every check passes; an unparseable URL yields {@code false}
     */
    public boolean isUrlAllowed(String url, Set<String> allowedHosts) {
        if (url == null || allowedHosts == null) {
            return false;
        }

        final URI uri;
        try {
            uri = new URI(url);
        } catch (URISyntaxException e) {
            return false;
        }

        // Allowlist rather than blocklist: a blocklist silently admits schemes nobody thought of
        // (jar:, netdoc:, ...).
        String scheme = uri.getScheme();
        if (scheme == null || !ALLOWED_SCHEMES.contains(scheme.toLowerCase(Locale.ROOT))) {
            return false;
        }

        // Credentials in the authority are a classic source of parser disagreement between
        // validators and HTTP clients, so refuse them outright.
        if (uri.getRawUserInfo() != null) {
            return false;
        }

        String host = uri.getHost();
        if (host == null) {
            return false;
        }
        host = stripIpv6Brackets(host.toLowerCase(Locale.ROOT));

        if (DANGEROUS_HOSTS.contains(host)) {
            return false;
        }

        // Check the allowlist before doing DNS, so arbitrary attacker-chosen names never trigger
        // a lookup from this server.
        if (!allowedHosts.contains(host)) {
            return false;
        }

        return !resolvesToRestrictedAddress(host);
    }

    /**
     * Normalizes a caller-supplied URL path.
     *
     * @param path raw path; {@code null} or blank maps to {@code "/"}
     * @return the path with a leading {@code /}, or {@code null} if it contains {@code ..},
     *         {@code //}, a NUL character, or any character outside {@code [a-zA-Z0-9/_.-]}
     */
    public String sanitizePath(String path) {
        if (path == null || path.trim().isEmpty()) {
            return "/";
        }

        // Reject traversal, empty segments and NUL bytes before any further processing.
        if (path.contains("..") || path.contains("//") || path.contains("\0")) {
            return null;
        }

        String normalized = path.startsWith("/") ? path : "/" + path;
        return SAFE_PATH.matcher(normalized).matches() ? normalized : null;
    }

    /**
     * Builds a percent-encoded absolute URL from its parts.
     *
     * <p>Components are encoded after building. Characters such as {@code &}, {@code =} and
     * {@code #} inside query-parameter values are therefore escaped and cannot inject extra
     * parameters or truncate the URL. Because the result is already encoded, pass it to an HTTP
     * client as {@code URI.create(result)}. Passing it to a URI-template String API may encode it
     * a second time.
     *
     * @param scheme URL scheme, e.g. {@code "https"}
     * @param host   host name
     * @param path   path, expected to be pre-sanitized (see {@link #sanitizePath})
     * @param params optional query parameters; may be {@code null}
     * @return the encoded URL, or empty if the builder rejects the components
     */
    public Optional<String> buildSecureUrl(String scheme, String host, String path, Map<String, String> params) {
        try {
            UriComponentsBuilder builder = UriComponentsBuilder.newInstance()
                    .scheme(scheme)
                    .host(host)
                    .path(path);
            if (params != null) {
                params.forEach((name, value) -> builder.queryParam(name, value));
            }
            return Optional.of(builder.build().encode().toUriString());
        } catch (RuntimeException e) {
            return Optional.empty();
        }
    }

    /** Removes the surrounding brackets that {@link URI#getHost()} keeps on IPv6 literals. */
    private static String stripIpv6Brackets(String host) {
        if (host.length() >= 2 && host.startsWith("[") && host.endsWith("]")) {
            return host.substring(1, host.length() - 1);
        }
        return host;
    }

    /**
     * Returns {@code true} if any address {@code host} resolves to is non-public, or if the host
     * cannot be resolved at all (fail closed).
     */
    private boolean resolvesToRestrictedAddress(String host) {
        final InetAddress[] addresses;
        try {
            // Check every record: a host may have several A/AAAA records, and the client may
            // connect to any of them.
            addresses = InetAddress.getAllByName(host);
        } catch (UnknownHostException | SecurityException e) {
            return true;
        }
        for (InetAddress address : addresses) {
            if (isRestricted(address)) {
                return true;
            }
        }
        return false;
    }

    /** Returns {@code true} for wildcard, loopback, link-local, site-local, multicast and IPv6 ULA addresses. */
    private static boolean isRestricted(InetAddress address) {
        return address.isAnyLocalAddress()
                || address.isLoopbackAddress()
                || address.isLinkLocalAddress()
                || address.isSiteLocalAddress()
                || address.isMulticastAddress()
                || isIpv6UniqueLocal(address);
    }

    /**
     * Returns {@code true} for IPv6 unique-local addresses (fc00::/7). {@code isSiteLocalAddress}
     * only covers the deprecated fec0::/10 range for IPv6.
     */
    private static boolean isIpv6UniqueLocal(InetAddress address) {
        return address instanceof Inet6Address && (address.getAddress()[0] & 0xFE) == 0xFC;
    }
}