first commit
This commit is contained in:
@@ -0,0 +1,265 @@
|
||||
const { ModelVersion, TrainingData } = require('../models');
|
||||
const logger = require('../utils/logger');
|
||||
|
||||
const getModelStatus = async (req, res) => {
|
||||
try {
|
||||
const activeModels = await ModelVersion.findAll({
|
||||
where: { is_active: true },
|
||||
order: [['created_at', 'DESC']]
|
||||
});
|
||||
|
||||
const modelStatus = {
|
||||
MODEL1: null,
|
||||
QUERYMODEL: null
|
||||
};
|
||||
|
||||
activeModels.forEach(model => {
|
||||
if (model.model_type === 'MODEL1') {
|
||||
modelStatus.MODEL1 = {
|
||||
id: model.id,
|
||||
version: model.version,
|
||||
deployment_status: model.deployment_status,
|
||||
performance_metrics: model.performance_metrics,
|
||||
last_updated: model.updated_at
|
||||
};
|
||||
} else if (model.model_type === 'QUERYMODEL') {
|
||||
modelStatus.QUERYMODEL = {
|
||||
id: model.id,
|
||||
version: model.version,
|
||||
deployment_status: model.deployment_status,
|
||||
performance_metrics: model.performance_metrics,
|
||||
last_updated: model.updated_at
|
||||
};
|
||||
}
|
||||
});
|
||||
|
||||
res.json({
|
||||
success: true,
|
||||
data: { modelStatus }
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Get model status error:', error);
|
||||
res.status(500).json({
|
||||
success: false,
|
||||
error: 'Internal server error'
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const getModelVersions = async (req, res) => {
|
||||
try {
|
||||
const { modelType, page = 1, limit = 10 } = req.query;
|
||||
|
||||
const whereClause = {};
|
||||
if (modelType) whereClause.model_type = modelType;
|
||||
|
||||
const models = await ModelVersion.findAndCountAll({
|
||||
where: whereClause,
|
||||
order: [['created_at', 'DESC']],
|
||||
limit: parseInt(limit),
|
||||
offset: (parseInt(page) - 1) * parseInt(limit)
|
||||
});
|
||||
|
||||
res.json({
|
||||
success: true,
|
||||
data: {
|
||||
models: models.rows,
|
||||
pagination: {
|
||||
page: parseInt(page),
|
||||
limit: parseInt(limit),
|
||||
total: models.count,
|
||||
pages: Math.ceil(models.count / parseInt(limit))
|
||||
}
|
||||
}
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Get model versions error:', error);
|
||||
res.status(500).json({
|
||||
success: false,
|
||||
error: 'Internal server error'
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const createModelVersion = async (req, res) => {
|
||||
try {
|
||||
const {
|
||||
modelName,
|
||||
modelType,
|
||||
baseModel,
|
||||
fineTuningMethod,
|
||||
hyperparameters = {}
|
||||
} = req.body;
|
||||
|
||||
if (!modelName || !modelType || !baseModel) {
|
||||
return res.status(400).json({
|
||||
success: false,
|
||||
error: 'Model name, type, and base model are required'
|
||||
});
|
||||
}
|
||||
|
||||
const validModelTypes = ['MODEL1', 'QUERYMODEL'];
|
||||
if (!validModelTypes.includes(modelType)) {
|
||||
return res.status(400).json({
|
||||
success: false,
|
||||
error: 'Invalid model type'
|
||||
});
|
||||
}
|
||||
|
||||
// Deactivate current active model of same type
|
||||
await ModelVersion.update(
|
||||
{ is_active: false },
|
||||
{ where: { model_type: modelType, is_active: true } }
|
||||
);
|
||||
|
||||
const model = await ModelVersion.create({
|
||||
model_name: modelName,
|
||||
version: `v${Date.now()}`,
|
||||
model_type: modelType,
|
||||
base_model: baseModel,
|
||||
fine_tuning_method: fineTuningMethod,
|
||||
hyperparameters,
|
||||
deployment_status: 'training'
|
||||
});
|
||||
|
||||
logger.info(`New model version created: ${model.id}`);
|
||||
|
||||
res.status(201).json({
|
||||
success: true,
|
||||
data: { model }
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Create model version error:', error);
|
||||
res.status(500).json({
|
||||
success: false,
|
||||
error: 'Internal server error'
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const updateModelVersion = async (req, res) => {
|
||||
try {
|
||||
const { modelId } = req.params;
|
||||
const {
|
||||
deploymentStatus,
|
||||
performanceMetrics,
|
||||
modelPath,
|
||||
trainingLog
|
||||
} = req.body;
|
||||
|
||||
const model = await ModelVersion.findByPk(modelId);
|
||||
|
||||
if (!model) {
|
||||
return res.status(404).json({
|
||||
success: false,
|
||||
error: 'Model version not found'
|
||||
});
|
||||
}
|
||||
|
||||
const updateData = {};
|
||||
if (deploymentStatus) updateData.deployment_status = deploymentStatus;
|
||||
if (performanceMetrics) updateData.performance_metrics = performanceMetrics;
|
||||
if (modelPath) updateData.model_path = modelPath;
|
||||
if (trainingLog) updateData.training_log = trainingLog;
|
||||
|
||||
await model.update(updateData);
|
||||
|
||||
logger.info(`Model version updated: ${modelId}`);
|
||||
|
||||
res.json({
|
||||
success: true,
|
||||
data: { model }
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Update model version error:', error);
|
||||
res.status(500).json({
|
||||
success: false,
|
||||
error: 'Internal server error'
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const activateModel = async (req, res) => {
|
||||
try {
|
||||
const { modelId } = req.params;
|
||||
|
||||
const model = await ModelVersion.findByPk(modelId);
|
||||
|
||||
if (!model) {
|
||||
return res.status(404).json({
|
||||
success: false,
|
||||
error: 'Model version not found'
|
||||
});
|
||||
}
|
||||
|
||||
// Deactivate other models of same type
|
||||
await ModelVersion.update(
|
||||
{ is_active: false },
|
||||
{ where: { model_type: model.model_type, is_active: true } }
|
||||
);
|
||||
|
||||
// Activate this model
|
||||
await model.update({
|
||||
is_active: true,
|
||||
deployment_status: 'deployed'
|
||||
});
|
||||
|
||||
logger.info(`Model activated: ${modelId}`);
|
||||
|
||||
res.json({
|
||||
success: true,
|
||||
data: { model }
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Activate model error:', error);
|
||||
res.status(500).json({
|
||||
success: false,
|
||||
error: 'Internal server error'
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const getTrainingData = async (req, res) => {
|
||||
try {
|
||||
const { modelVersionId, dataType, page = 1, limit = 10 } = req.query;
|
||||
|
||||
const whereClause = {};
|
||||
if (modelVersionId) whereClause.model_version_id = modelVersionId;
|
||||
if (dataType) whereClause.data_type = dataType;
|
||||
|
||||
const trainingData = await TrainingData.findAndCountAll({
|
||||
where: whereClause,
|
||||
order: [['created_at', 'DESC']],
|
||||
limit: parseInt(limit),
|
||||
offset: (parseInt(page) - 1) * parseInt(limit)
|
||||
});
|
||||
|
||||
res.json({
|
||||
success: true,
|
||||
data: {
|
||||
trainingData: trainingData.rows,
|
||||
pagination: {
|
||||
page: parseInt(page),
|
||||
limit: parseInt(limit),
|
||||
total: trainingData.count,
|
||||
pages: Math.ceil(trainingData.count / parseInt(limit))
|
||||
}
|
||||
}
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Get training data error:', error);
|
||||
res.status(500).json({
|
||||
success: false,
|
||||
error: 'Internal server error'
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getModelStatus,
|
||||
getModelVersions,
|
||||
createModelVersion,
|
||||
updateModelVersion,
|
||||
activateModel,
|
||||
getTrainingData
|
||||
};
|
||||
Reference in New Issue
Block a user