# split_dataset_test.py (2.7 KB)

  1. def split_data_into_parts(total_data_count, num_parts=4, percentage=0.05):
  2. """
  3. Splits the total data into four parts, each containing a specified percentage of the total data.
  4. Each part will contain unique, non-overlapping elements.
  5. Args:
  6. total_data_count (int): The total number of data points.
  7. num_parts (int): The number of parts to divide the data into (default is 4).
  8. percentage (float): The percentage of data points each part should contain (default is 0.05).
  9. Returns:
  10. List[List[int]]: A list of lists, where each inner list contains the indices for one part.
  11. """
  12. # Calculate the number of elements in each part
  13. num_elements_per_part = int(total_data_count * percentage)
  14. # Ensure that we have enough data to split into the desired number of parts
  15. if num_elements_per_part * num_parts > total_data_count:
  16. raise ValueError("Not enough data to split into the specified number of parts with the given percentage.")
  17. # Generate a list of all indices
  18. all_indices = list(range(total_data_count))
  19. # Split the indices into non-overlapping parts
  20. parts = []
  21. for i in range(num_parts):
  22. start_idx = i * num_elements_per_part
  23. end_idx = start_idx + num_elements_per_part
  24. part_indices = all_indices[start_idx:end_idx]
  25. parts.append(part_indices)
  26. return parts
  27. def get_percentage_segment(index, total):
  28. # 计算每段的长度(5% 的数据)
  29. segment_size = max(1, int(total * 0.05))
  30. # 计算开始索引和结束索引
  31. start = index * segment_size
  32. end = start + segment_size
  33. # 确保结束索引不超过总数
  34. if end > total:
  35. end = total
  36. # 返回指定段的索引列表
  37. return list(range(start, end))
  38. def find_index_in_parts(parts, index):
  39. """
  40. Finds the part containing the given index.
  41. Args:
  42. parts (List[List[int]]): A list of parts, where each part is a list of indices.
  43. index (int): The index to search for.
  44. Returns:
  45. Tuple[bool, int]: A tuple containing a boolean indicating if the index is found,
  46. and the index of the part if found, otherwise -1.
  47. """
  48. for i, part in enumerate(parts):
  49. if index in part:
  50. return True, i
  51. return False, -1
  52. # Example usage
  53. total_data_count = 1000 # Example total number of data points
  54. parts = split_data_into_parts(total_data_count)
  55. # Check if index 123 is in any of the parts
  56. index_to_find = 123
  57. found, part_index = find_index_in_parts(parts, index_to_find)
  58. for part in parts:
  59. print(part)
  60. if found:
  61. print(f"Index {index_to_find} is in part {part_index + 1}")
  62. else:
  63. print(f"Index {index_to_find} is not in any of the parts")
  64. print(get_percentage_segment(1, 200))