Overview
Hyperparameter optimization is an essential part of getting the best performance out of AI-based models.
Building deep learning models on PyTorch Lightning makes distributed data-parallel training (e.g., the ddp strategy) easy to set up, which is why it is so widely used.
This note therefore introduces a simple recipe for hyperparameter search on PyTorch Lightning based models.
Description
With hyperopt, methods such as TPE ([NIPS'11] Algorithms for Hyper-Parameter Optimization) can be applied to hyperparameter search with very little code.
Combining PyTorch Lightning training with hyperopt is a convenient way to optimize hyperparameters.
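As a quick reference, here is a minimal hyperopt sketch, independent of the Lightning example below; the quadratic objective and the search space are made up purely for illustration:

from hyperopt import hp, fmin, tpe

# Toy search space and objective: minimize (x - 3)^2 over x in [-10, 10]
space = {'x': hp.uniform('x', -10, 10)}

def objective(params):
    return (params['x'] - 3) ** 2

# tpe.suggest proposes the next candidate based on all previous evaluations
best = fmin(fn=objective, space=space, algo=tpe.suggest, max_evals=50)
print(best)  # best 'x' found, should end up close to 3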
TPE
To be written up 🚀
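In the meantime, a rough sketch of the idea from the cited paper: instead of modeling p(y | x) directly, TPE models two densities of the hyperparameters, conditioned on whether the observed loss fell below a quantile threshold y*,

    l(x) = p(x | y < y*),    g(x) = p(x | y >= y*),

and the expected improvement turns out to be a monotone function of the ratio l(x)/g(x). Each iteration therefore proposes the candidate x that maximizes l(x)/g(x): a point that looks likely under the good trials and unlikely under the bad ones.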
Example code
Define a search space over the hyperparameters and an objective function, then let hyperopt search. Running the example code below performs the search:
python hyper_search.py --num_experiments=10
Caution: when training with the ddp strategy in PyTorch Lightning, the objective function is forked into multiple processes for distributed data-parallel training, so two extra pieces are needed: (1) making the sampled params identical across all processes, and (2) a semaphore so that only as many processes as there are GPUs acquire the lock at a time. The example may not be the cleanest implementation, but it works.
import os
import time
import logging
import itertools
from functools import partial
import fire
import torch
import joblib
import torch.nn as nn
import lightning as pl
import torch.optim as optim
from multiprocessing import Semaphore
from lightning.pytorch.utilities import rank_zero_only
from torch.utils.data import DataLoader, TensorDataset
from hyperopt import hp, fmin, tpe, space_eval, Trials, STATUS_OK
logging.basicConfig(level=logging.INFO)
torch.set_float32_matmul_precision('high')
logger = logging.getLogger(__name__)

# Define your model, dataset, and training/validation functions here
class SimpleModel(pl.LightningModule):
    def __init__(self, hidden_dim):
        super(SimpleModel, self).__init__()
        self.layer = nn.Linear(28 * 28, hidden_dim)

    def forward(self, x):
        return torch.relu(self.layer(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.cross_entropy(y_hat, y)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)

@rank_zero_only
def logger_info(content):
    logger.info(f"[RankZeroOnly] {content}")

@rank_zero_only
def remove_params_file_if_exists(params_file):
    if os.path.exists(params_file):
        os.remove(params_file)

@rank_zero_only
def create_params_file_if_not_exists(params, params_file):
    if not os.path.exists(params_file):
        joblib.dump(params, params_file)

def objective(params, semaphore, exp_counter):
    exp_id = next(exp_counter)
    time.sleep(exp_id + 1)
    # Synchronize params across all processes (local ranks) through a file
    params_file = f'params_{exp_id}.lock'
    create_params_file_if_not_exists(params, params_file)
    # Critical Section: Acquire/Release semaphore lock based on exp_id
    with semaphore:
        logger_info(f"Acquire semaphore[{exp_id}]")
        # Load params from the file-based synchronized param
        while not os.path.exists(params_file):
            time.sleep(1)
        params = joblib.load(params_file)
        logger_info(f"Try {params} on exp_id={exp_id}")
        # logger.info(f"Try {params} on exp_id={exp_id}, pid={os.getpid()}, ppid={os.getppid()}, params={params})")
        # Training code here
        model = SimpleModel(params['hidden_dim'])
        trainer = pl.Trainer(
            max_epochs=5,
            strategy='ddp',
            precision='16-mixed',
            devices=torch.cuda.device_count()
        )
        # Assuming you have your dataset and dataloaders ready
        # Replace with your actual data loading logic
        dataset = TensorDataset(torch.randn(1000, 28 * 28), torch.randint(0, 10, (1000,)))
        train_loader = DataLoader(dataset, batch_size=params['batch_size'])
        trainer.fit(model, train_loader)
        # Validation code here
        val_loss = 0.0  # Replace with actual validation loss computation
        logger_info(f"Release semaphore[{exp_id}]")
        remove_params_file_if_exists(params_file)
    return {'loss': val_loss, 'status': STATUS_OK}

def run(num_experiments=3):
    # Define counter for experiments
    exp_counter = itertools.count(start=0, step=1)
    # Define number of devices (processes per exp_id)
    num_devices_per_experiment = torch.cuda.device_count()  # Number of GPUs
    logger.info(f"num_experiments: {num_experiments}, num_devices_per_experiment: {num_devices_per_experiment}")
    # Define your parameter space for hyperparameter optimization
    space = {
        'batch_size': hp.choice('batch_size', list(range(12, 64, 4))),
        'hidden_dim': hp.choice('hidden_dim', list(range(10, 32, 2))),
        # Add other hyperparameters here
    }
    # Create a semaphore that admits up to num_devices_per_experiment processes
    semaphore = Semaphore(num_devices_per_experiment)
    # Start hyperparameter optimization
    trials = Trials()
    best_params = fmin(
        fn=partial(objective, semaphore=semaphore, exp_counter=exp_counter),
        space=space,
        algo=tpe.suggest,
        max_evals=num_experiments,
        trials=trials,
    )
    logger_info(f"Best hyperparameter: {space_eval(space, best_params)}")
    logger_info(f"Trials: {trials.results}")

if __name__ == "__main__":
    fire.Fire(run)
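The val_loss placeholder in objective() is where a real validation metric would go. One possible way to fill it in, shown only as a sketch (it assumes you add a validation_step to SimpleModel and have a val_loader, neither of which is part of the original example), is to log a val_loss metric and read it back from trainer.callback_metrics:

# Assumed addition inside SimpleModel:
def validation_step(self, batch, batch_idx):
    x, y = batch
    loss = nn.functional.cross_entropy(self(x), y)
    self.log("val_loss", loss, sync_dist=True)  # sync_dist averages the metric across DDP ranks

# Assumed addition inside objective(), after trainer.fit(...):
trainer.validate(model, val_loader)
val_loss = float(trainer.callback_metrics["val_loss"])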