diff --git a/backend/app/services/ssh_key_service.py b/backend/app/services/ssh_key_service.py new file mode 100644 index 0000000..e880522 --- /dev/null +++ b/backend/app/services/ssh_key_service.py @@ -0,0 +1,231 @@ +""" +SSH Key Service. + +Business logic for SSH key management including: +- Creating SSH keys with encryption +- Listing and retrieving SSH keys +- Deleting SSH keys (with usage check) +- Decrypting private keys for Git operations +""" +import time +import hashlib +import base64 +from typing import List, Optional +from sqlalchemy.orm import Session + +from app.models.ssh_key import SshKey +from app.models.server import Server +from app.security import encrypt_data, decrypt_data +from app.config import get_settings + + +class SshKeyService: + """ + Service for managing SSH keys. + + Handles encryption, decryption, and validation of SSH private keys. + """ + + def __init__(self, db: Session): + """ + Initialize the service with a database session. + + Args: + db: SQLAlchemy database session + """ + self.db = db + self.settings = get_settings() + + def create_ssh_key(self, name: str, private_key: str, password: Optional[str] = None) -> SshKey: + """ + Create a new SSH key with encryption. + + Args: + name: Unique name for the SSH key + private_key: SSH private key content (PEM/OpenSSH format) + password: Optional password for key protection (not stored, used for deployment) + + Returns: + Created SshKey model instance + + Raises: + ValueError: If name already exists or private_key is invalid + """ + # Check if name already exists + existing_key = self.db.query(SshKey).filter_by(name=name).first() + if existing_key: + raise ValueError(f"SSH key with name '{name}' already exists") + + # Validate SSH private key format + if not self._is_valid_ssh_key(private_key): + raise ValueError("Invalid SSH private key format. Must be a valid PEM or OpenSSH private key.") + + # Generate fingerprint + fingerprint = self._generate_fingerprint(private_key) + + # Encrypt the private key + encrypted_key = encrypt_data( + private_key.encode('utf-8'), + self.settings.encrypt_key + ) + + # Store encrypted key as base64 for database storage + encrypted_key_b64 = base64.b64encode(encrypted_key).decode('utf-8') + + # Create the SSH key record + ssh_key = SshKey( + name=name, + private_key=encrypted_key_b64, + fingerprint=fingerprint, + created_at=int(time.time()) + ) + + self.db.add(ssh_key) + self.db.commit() + self.db.refresh(ssh_key) + + return ssh_key + + def list_ssh_keys(self) -> List[SshKey]: + """ + List all SSH keys. + + Returns: + List of all SshKey model instances (without decrypted keys) + """ + return self.db.query(SshKey).all() + + def get_ssh_key(self, key_id: int) -> Optional[SshKey]: + """ + Get an SSH key by ID. + + Args: + key_id: ID of the SSH key + + Returns: + SshKey model instance or None if not found + """ + return self.db.query(SshKey).filter_by(id=key_id).first() + + def delete_ssh_key(self, key_id: int) -> bool: + """ + Delete an SSH key. + + Args: + key_id: ID of the SSH key to delete + + Returns: + True if deleted, False if not found + + Raises: + ValueError: If key is in use by a server + """ + ssh_key = self.get_ssh_key(key_id) + if not ssh_key: + return False + + # Check if key is in use by any server + servers_using_key = self.db.query(Server).filter_by(ssh_key_id=key_id).count() + if servers_using_key > 0: + raise ValueError( + f"Cannot delete SSH key '{ssh_key.name}'. " + f"It is in use by {servers_using_key} server(s)." + ) + + self.db.delete(ssh_key) + self.db.commit() + + return True + + def get_decrypted_key(self, key_id: int) -> str: + """ + Get the decrypted private key for Git operations. + + Args: + key_id: ID of the SSH key + + Returns: + Decrypted private key as a string + + Raises: + ValueError: If key not found + """ + ssh_key = self.get_ssh_key(key_id) + if not ssh_key: + raise ValueError(f"SSH key with ID {key_id} not found") + + # Decode from base64 first, then decrypt + encrypted_key = base64.b64decode(ssh_key.private_key.encode('utf-8')) + decrypted = decrypt_data( + encrypted_key, + self.settings.encrypt_key + ) + + return decrypted.decode('utf-8') + + def _is_valid_ssh_key(self, private_key: str) -> bool: + """ + Validate if the provided string is a valid SSH private key. + + Args: + private_key: Private key content to validate + + Returns: + True if valid SSH private key format, False otherwise + """ + if not private_key or not private_key.strip(): + return False + + # Check for common SSH private key markers + valid_markers = [ + "-----BEGIN RSA PRIVATE KEY-----", + "-----BEGIN OPENSSH PRIVATE KEY-----", + "-----BEGIN DSA PRIVATE KEY-----", + "-----BEGIN EC PRIVATE KEY-----", + "-----BEGIN ED25519 PRIVATE KEY-----", + "-----BEGIN PGP PRIVATE KEY BLOCK-----", # GPG keys + ] + + private_key_stripped = private_key.strip() + for marker in valid_markers: + if marker in private_key_stripped: + return True + + return False + + def _generate_fingerprint(self, private_key: str) -> str: + """ + Generate a fingerprint for an SSH private key. + + For simplicity, we use SHA256 hash of the public key portion. + In production, you'd use cryptography or paramiko to extract + the actual public key and generate a proper SSH fingerprint. + + Args: + private_key: Private key content + + Returns: + Fingerprint string (SHA256 format) + """ + # Extract the key data (between BEGIN and END markers) + lines = private_key.strip().split('\n') + key_data = [] + + in_key_section = False + for line in lines: + if '-----BEGIN' in line: + in_key_section = True + continue + if '-----END' in line: + in_key_section = False + continue + if in_key_section and not line.startswith('---'): + key_data.append(line.strip()) + + key_content = ''.join(key_data) + + # Generate SHA256 hash + sha256_hash = hashlib.sha256(key_content.encode('utf-8')).digest() + b64_hash = base64.b64encode(sha256_hash).decode('utf-8').rstrip('=') + + return f"SHA256:{b64_hash}" diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 2451f5a..2013cea 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -41,12 +41,16 @@ def db_session(db_engine): def test_encrypt_key(): """测试加密密钥.""" import base64 - return base64.b64encode(b'test-key-32-bytes-long-1234567890').decode() + return base64.b64encode(b'test-key-32-bytes-long-123456789').decode() @pytest.fixture(scope="function") def test_env_vars(db_path, test_encrypt_key, monkeypatch): """设置测试环境变量.""" + # Clear global settings to ensure fresh config + import app.config + app.config._settings = None + monkeypatch.setenv("GM_ENCRYPT_KEY", test_encrypt_key) monkeypatch.setenv("GM_API_TOKEN", "test-token") monkeypatch.setenv("GM_DATA_DIR", str(db_path.parent)) diff --git a/backend/tests/test_services/test_ssh_key_service.py b/backend/tests/test_services/test_ssh_key_service.py new file mode 100644 index 0000000..1f03262 --- /dev/null +++ b/backend/tests/test_services/test_ssh_key_service.py @@ -0,0 +1,283 @@ +""" +Tests for SSH Key Service. +""" +import base64 +import pytest +import time +from app.models.ssh_key import SshKey +from app.models.server import Server +from app.services.ssh_key_service import SshKeyService +from app.config import get_settings + + +# Test SSH key samples +VALID_SSH_KEY = """-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW +QyNTUxOQAAACB/pDNwjNcznNaRlLNF5G9hCQNjbqNZ7QeKyLIy/nvHAAAAJi/vqmQv6pk +AAAAAtzc2gtZWQyNTUxOQAAACB/pDNwjNcznNaRlLNF5G9hCQNjbqNZ7QeKyLIy/nvHAA +AAAEBD0cWNQnpLDUYEGNMSgVIApVJfCFuRfGG3uxJZRKLvqH+kM3CM1zOc1pGUssXkb2E +JA2uuo1ntB4rIsjL+e8cAAAADm1lc3NlbmdlckBrZW50cm9zBAgMEBQ= +-----END OPENSSH PRIVATE KEY----- +""" + +VALID_SSH_KEY_2 = """-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW +QyNTUxOQAAACB/pDNwjNcznNaRlLNF5G9hCQNjbqNZ7QeKyLIy/nvHAAAAJi/vqmQv6pk +AAAAAtzc2gtZWQyNTUxOQAAACB/pDNwjNcznNaRlLNF5G9hCQNjbqNZ7QeKyLIy/nvHAA +AAAEBD0cWNQnpLDUYEGNMSgVIApVJfCFuRfGG3uxJZRKLvqH+kM3CM1zOc1pGUssXkb2E +JA2uuo1ntB4rIsjL+e8cAAAADm1lc3NlbmdlckBrZW50cm9zBAgMEBQ= +-----END OPENSSH PRIVATE KEY----- +""" + + +def test_create_ssh_key_success(db_session, test_env_vars): + """Test successful SSH key creation with encryption.""" + service = SshKeyService(db_session) + + key = service.create_ssh_key( + name="test-key", + private_key=VALID_SSH_KEY, + password=None + ) + + assert key.id is not None + assert key.name == "test-key" + assert key.fingerprint is not None + assert key.created_at is not None + # The private key should be encrypted (different from original) + assert key.private_key != VALID_SSH_KEY + + +def test_create_ssh_key_with_duplicate_name(db_session, test_env_vars): + """Test that duplicate SSH key names are not allowed.""" + service = SshKeyService(db_session) + + service.create_ssh_key( + name="duplicate-key", + private_key=VALID_SSH_KEY, + password=None + ) + + with pytest.raises(ValueError, match="already exists"): + service.create_ssh_key( + name="duplicate-key", + private_key=VALID_SSH_KEY_2, + password=None + ) + + +def test_create_ssh_key_with_invalid_key(db_session, test_env_vars): + """Test that invalid SSH keys are rejected.""" + service = SshKeyService(db_session) + + with pytest.raises(ValueError, match="Invalid SSH private key"): + service.create_ssh_key( + name="invalid-key", + private_key="not-a-valid-ssh-key", + password=None + ) + + +def test_list_ssh_keys_empty(db_session, test_env_vars): + """Test listing SSH keys when none exist.""" + service = SshKeyService(db_session) + + keys = service.list_ssh_keys() + + assert keys == [] + + +def test_list_ssh_keys_multiple(db_session, test_env_vars): + """Test listing multiple SSH keys.""" + service = SshKeyService(db_session) + + service.create_ssh_key( + name="key-1", + private_key=VALID_SSH_KEY, + password=None + ) + service.create_ssh_key( + name="key-2", + private_key=VALID_SSH_KEY_2, + password=None + ) + + keys = service.list_ssh_keys() + + assert len(keys) == 2 + assert any(k.name == "key-1" for k in keys) + assert any(k.name == "key-2" for k in keys) + + +def test_get_ssh_key_by_id(db_session, test_env_vars): + """Test getting an SSH key by ID.""" + service = SshKeyService(db_session) + + created_key = service.create_ssh_key( + name="get-test-key", + private_key=VALID_SSH_KEY, + password=None + ) + + retrieved_key = service.get_ssh_key(created_key.id) + + assert retrieved_key is not None + assert retrieved_key.id == created_key.id + assert retrieved_key.name == "get-test-key" + + +def test_get_ssh_key_not_found(db_session, test_env_vars): + """Test getting a non-existent SSH key.""" + service = SshKeyService(db_session) + + key = service.get_ssh_key(99999) + + assert key is None + + +def test_delete_ssh_key_success(db_session, test_env_vars): + """Test successful SSH key deletion.""" + service = SshKeyService(db_session) + + created_key = service.create_ssh_key( + name="delete-test-key", + private_key=VALID_SSH_KEY, + password=None + ) + + result = service.delete_ssh_key(created_key.id) + + assert result is True + + # Verify the key is deleted + retrieved_key = service.get_ssh_key(created_key.id) + assert retrieved_key is None + + +def test_delete_ssh_key_in_use(db_session, test_env_vars): + """Test that SSH keys in use cannot be deleted.""" + service = SshKeyService(db_session) + + created_key = service.create_ssh_key( + name="in-use-key", + private_key=VALID_SSH_KEY, + password=None + ) + + # Create a server that uses this SSH key + server = Server( + name="test-server", + url="https://gitea.example.com", + api_token="test-token", + ssh_key_id=created_key.id, + local_path="/tmp/test", + created_at=int(time.time()), + updated_at=int(time.time()) + ) + db_session.add(server) + db_session.commit() + + with pytest.raises(ValueError, match="is in use"): + service.delete_ssh_key(created_key.id) + + +def test_delete_ssh_key_not_found(db_session, test_env_vars): + """Test deleting a non-existent SSH key.""" + service = SshKeyService(db_session) + + result = service.delete_ssh_key(99999) + + assert result is False + + +def test_get_decrypted_key(db_session, test_env_vars): + """Test getting a decrypted SSH private key.""" + service = SshKeyService(db_session) + + created_key = service.create_ssh_key( + name="decrypt-test-key", + private_key=VALID_SSH_KEY, + password=None + ) + + decrypted_key = service.get_decrypted_key(created_key.id) + + assert decrypted_key == VALID_SSH_KEY + + +def test_get_decrypted_key_not_found(db_session, test_env_vars): + """Test getting decrypted key for non-existent ID.""" + service = SshKeyService(db_session) + + with pytest.raises(ValueError, match="SSH key with ID 99999 not found"): + service.get_decrypted_key(99999) + + +def test_ssh_key_fingerprint_generation(db_session, test_env_vars): + """Test that SSH key fingerprints are generated correctly.""" + service = SshKeyService(db_session) + + key = service.create_ssh_key( + name="fingerprint-key", + private_key=VALID_SSH_KEY, + password=None + ) + + assert key.fingerprint is not None + assert len(key.fingerprint) > 0 + # Fingerprints typically start with SHA256: or MD5: + assert ":" in key.fingerprint or len(key.fingerprint) == 47 # SHA256 format + + +def test_encryption_is_different(db_session, test_env_vars): + """Test that encrypted keys are different from plaintext.""" + service = SshKeyService(db_session) + + service.create_ssh_key( + name="encryption-test", + private_key=VALID_SSH_KEY, + password=None + ) + + # Get the raw database record + db_key = db_session.query(SshKey).filter_by(name="encryption-test").first() + + # The stored key should be encrypted + assert db_key.private_key != VALID_SSH_KEY + # Should be base64 encoded (longer) + assert len(db_key.private_key) > len(VALID_SSH_KEY) + + +def test_create_ssh_key_with_password_protection(db_session, test_env_vars): + """Test creating SSH key that has password protection.""" + service = SshKeyService(db_session) + + # This test verifies we can store password-protected keys + # The service doesn't validate the password, just stores the key + key = service.create_ssh_key( + name="password-protected-key", + private_key=VALID_SSH_KEY, + password=None # Password would be used when key is deployed + ) + + assert key is not None + assert key.name == "password-protected-key" + + +def test_concurrent_same_name_creation(db_session, test_env_vars): + """Test that concurrent creation with same name is handled.""" + service = SshKeyService(db_session) + + service.create_ssh_key( + name="concurrent-key", + private_key=VALID_SSH_KEY, + password=None + ) + + # Second creation should fail + with pytest.raises(ValueError): + service.create_ssh_key( + name="concurrent-key", + private_key=VALID_SSH_KEY_2, + password=None + )