updated style mimciking fine tuning
This commit is contained in:
@@ -910,28 +910,49 @@ class StylingDataPipeline:
|
||||
def load_and_preprocess(self, config: StylingConfig) -> Tuple[Dict[str, List[Dict]], Dict[str, Any]]:
|
||||
"""Load and preprocess data"""
|
||||
|
||||
# Load data
|
||||
if config.data_source == "huggingface":
|
||||
raw_splits = self.hf_loader.load(config)
|
||||
processed_splits = self.hf_loader.preprocess(raw_splits, config)
|
||||
elif config.data_source == "custom":
|
||||
raw_splits = self.custom_loader.load(config)
|
||||
processed_splits = self.custom_loader.preprocess(raw_splits, config)
|
||||
else:
|
||||
raise ValueError(f"Unsupported data source: {config.data_source}")
|
||||
logger.info(f"Starting data loading and preprocessing...")
|
||||
logger.info(f"Data source: {config.data_source}")
|
||||
|
||||
# Validate processed data
|
||||
is_valid, errors = self.validator.validate_styling_data(processed_splits, config, is_processed=True)
|
||||
if not is_valid:
|
||||
logger.error("Data validation failed:")
|
||||
for error in errors:
|
||||
logger.error(f" - {error}")
|
||||
raise ValueError("Data validation failed")
|
||||
|
||||
# Analyze dataset
|
||||
analysis = self.validator.analyze_dataset(processed_splits, config, is_processed=True)
|
||||
|
||||
return processed_splits, analysis
|
||||
try:
|
||||
# Load data
|
||||
if config.data_source == "huggingface":
|
||||
logger.info("Loading HuggingFace dataset...")
|
||||
raw_splits = self.hf_loader.load(config)
|
||||
logger.info("Preprocessing HuggingFace dataset...")
|
||||
processed_splits = self.hf_loader.preprocess(raw_splits, config)
|
||||
elif config.data_source == "custom":
|
||||
logger.info("Loading custom dataset...")
|
||||
raw_splits = self.custom_loader.load(config)
|
||||
logger.info("Preprocessing custom dataset...")
|
||||
processed_splits = self.custom_loader.preprocess(raw_splits, config)
|
||||
else:
|
||||
raise ValueError(f"Unsupported data source: {config.data_source}")
|
||||
|
||||
logger.info(f"Data loading and preprocessing completed successfully")
|
||||
logger.info(f"Raw splits: {list(raw_splits.keys())}")
|
||||
logger.info(f"Processed splits: {list(processed_splits.keys())}")
|
||||
|
||||
# Validate processed data
|
||||
logger.info("Validating processed data...")
|
||||
is_valid, errors = self.validator.validate_styling_data(processed_splits, config, is_processed=True)
|
||||
if not is_valid:
|
||||
logger.error("Data validation failed:")
|
||||
for error in errors:
|
||||
logger.error(f" - {error}")
|
||||
raise ValueError("Data validation failed")
|
||||
|
||||
logger.info("Data validation passed")
|
||||
|
||||
# Analyze dataset
|
||||
logger.info("Analyzing dataset...")
|
||||
analysis = self.validator.analyze_dataset(processed_splits, config, is_processed=True)
|
||||
logger.info("Dataset analysis completed")
|
||||
|
||||
return processed_splits, analysis
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in load_and_preprocess: {e}")
|
||||
raise
|
||||
|
||||
def convert_to_alpaca_format(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[Dict]]:
|
||||
"""Convert styling data to Alpaca format with instruction"""
|
||||
@@ -1481,6 +1502,9 @@ def main():
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error running pipeline: {e}")
|
||||
import traceback
|
||||
print("Full error traceback:")
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user