Tomato's BLOG

[딥러닝] 텐서플로우 데이터셋(tensorflow_datasets) 활용 본문

개발/머신러닝, 딥러닝

[딥러닝] 텐서플로우 데이터셋(tensorflow_datasets) 활용

토마토Tomato 2021. 2. 4. 19:53
import tensorflow_datasets as tfds

train_dataset = tfds.load('iris', split='train[:80%]')
valid_dataset = tfds.load('iris', split='train[80%:]')

 

  • 'iris': 불러오고자 하는 데이터 셋 이름
  • split='train': 'train' set을 가져온다
    • 만약 해당 데이터셋에 'test'와 'validation'이 있을 경우 불어올 수 있다.
    • 'train[:80%]'는 80%까지 가져오겠다는 말
    • 파이썬 문법처럼 음수도 사용가능. (e.g. 'train[-20%:]')

데이터 확인

train_dataset.take(5)

처음 데이터 5개를 리스트(?) 형식으로 반환해준다. 

 

X, y 분리 및 one hot encoding으로 변환 (preprocessing)

dataset 정보는 텐서플로우 데이터셋 공식문서에서 확인(label class 종류, dictionary label 값 등 확인할 수 있음)

def proprocess(data):
	x = data['features']
    y = data['label']
    y = tf.one_hot(y, 3)
    return x, y

tf.one_hot: label을 one-hot encoding으로 변환해줌. 3은 label의 개수(공식문서 num_classes에서 확인하면 됨)

Comments